1 //===-- TargetRewrite.cpp -------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest.
10 // LLVM expects different lowering idioms to be used for distinct target
11 // triples. These distinctions are handled by this pass.
13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
15 //===----------------------------------------------------------------------===//
17 #include "flang/Optimizer/CodeGen/CodeGen.h"
20 #include "flang/Optimizer/Builder/Character.h"
21 #include "flang/Optimizer/Builder/FIRBuilder.h"
22 #include "flang/Optimizer/Builder/Todo.h"
23 #include "flang/Optimizer/Dialect/FIRDialect.h"
24 #include "flang/Optimizer/Dialect/FIROps.h"
25 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
26 #include "flang/Optimizer/Dialect/FIRType.h"
27 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
28 #include "mlir/Transforms/DialectConversion.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
35 #define GEN_PASS_DEF_TARGETREWRITEPASS
36 #include "flang/Optimizer/CodeGen/CGPasses.h.inc"
39 #define DEBUG_TYPE "flang-target-rewrite"
43 /// Fixups for updating a FuncOp's arguments and return values.
56 FixupTy(Codes code
, std::size_t index
, std::size_t second
= 0)
57 : code
{code
}, index
{index
}, second
{second
} {}
58 FixupTy(Codes code
, std::size_t index
,
59 std::function
<void(mlir::func::FuncOp
)> &&finalizer
)
60 : code
{code
}, index
{index
}, finalizer
{finalizer
} {}
61 FixupTy(Codes code
, std::size_t index
, std::size_t second
,
62 std::function
<void(mlir::func::FuncOp
)> &&finalizer
)
63 : code
{code
}, index
{index
}, second
{second
}, finalizer
{finalizer
} {}
68 std::optional
<std::function
<void(mlir::func::FuncOp
)>> finalizer
{};
71 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code
72 /// generation that traverses the FIR and modifies types and operations to a
73 /// form that is appropriate for the specific target. LLVM IR has specific
74 /// idioms that are used for distinct target processor and ABI combinations.
75 class TargetRewrite
: public fir::impl::TargetRewritePassBase
<TargetRewrite
> {
77 TargetRewrite(const fir::TargetRewriteOptions
&options
) {
78 noCharacterConversion
= options
.noCharacterConversion
;
79 noComplexConversion
= options
.noComplexConversion
;
82 void runOnOperation() override final
{
83 auto &context
= getContext();
84 mlir::OpBuilder
rewriter(&context
);
86 auto mod
= getModule();
87 if (!forcedTargetTriple
.empty())
88 fir::setTargetTriple(mod
, forcedTargetTriple
);
90 auto specifics
= fir::CodeGenSpecifics::get(
91 mod
.getContext(), fir::getTargetTriple(mod
), fir::getKindMapping(mod
));
92 setMembers(specifics
.get(), &rewriter
);
94 // Perform type conversion on signatures and call sites.
95 if (mlir::failed(convertTypes(mod
))) {
96 mlir::emitError(mlir::UnknownLoc::get(&context
),
97 "error in converting types to target abi");
101 // Convert ops in target-specific patterns.
102 mod
.walk([&](mlir::Operation
*op
) {
103 if (auto call
= mlir::dyn_cast
<fir::CallOp
>(op
)) {
104 if (!hasPortableSignature(call
.getFunctionType(), op
))
106 } else if (auto dispatch
= mlir::dyn_cast
<fir::DispatchOp
>(op
)) {
107 if (!hasPortableSignature(dispatch
.getFunctionType(), op
))
108 convertCallOp(dispatch
);
109 } else if (auto addr
= mlir::dyn_cast
<fir::AddrOfOp
>(op
)) {
110 if (addr
.getType().isa
<mlir::FunctionType
>() &&
111 !hasPortableSignature(addr
.getType(), op
))
119 mlir::ModuleOp
getModule() { return getOperation(); }
121 template <typename A
, typename B
, typename C
>
122 std::optional
<std::function
<mlir::Value(mlir::Operation
*)>>
123 rewriteCallComplexResultType(mlir::Location loc
, A ty
, B
&newResTys
,
124 B
&newInTys
, C
&newOpers
) {
125 if (noComplexConversion
) {
126 newResTys
.push_back(ty
);
129 auto m
= specifics
->complexReturnType(loc
, ty
.getElementType());
130 // Currently targets mandate COMPLEX is a single aggregate or packed
131 // scalar, including the sret case.
132 assert(m
.size() == 1 && "target of complex return not supported");
133 auto resTy
= std::get
<mlir::Type
>(m
[0]);
134 auto attr
= std::get
<fir::CodeGenSpecifics::Attributes
>(m
[0]);
136 assert(fir::isa_ref_type(resTy
) && "must be a memory reference type");
138 rewriter
->create
<fir::AllocaOp
>(loc
, fir::dyn_cast_ptrEleTy(resTy
));
139 newInTys
.push_back(resTy
);
140 newOpers
.push_back(stack
);
141 return [=](mlir::Operation
*) -> mlir::Value
{
142 auto memTy
= fir::ReferenceType::get(ty
);
143 auto cast
= rewriter
->create
<fir::ConvertOp
>(loc
, memTy
, stack
);
144 return rewriter
->create
<fir::LoadOp
>(loc
, cast
);
147 newResTys
.push_back(resTy
);
148 return [=](mlir::Operation
*call
) -> mlir::Value
{
149 auto mem
= rewriter
->create
<fir::AllocaOp
>(loc
, resTy
);
150 rewriter
->create
<fir::StoreOp
>(loc
, call
->getResult(0), mem
);
151 auto memTy
= fir::ReferenceType::get(ty
);
152 auto cast
= rewriter
->create
<fir::ConvertOp
>(loc
, memTy
, mem
);
153 return rewriter
->create
<fir::LoadOp
>(loc
, cast
);
157 template <typename A
, typename B
, typename C
>
158 void rewriteCallComplexInputType(A ty
, mlir::Value oper
, B
&newInTys
,
160 if (noComplexConversion
) {
161 newInTys
.push_back(ty
);
162 newOpers
.push_back(oper
);
166 auto *ctx
= ty
.getContext();
167 mlir::Location loc
= mlir::UnknownLoc::get(ctx
);
168 if (auto *op
= oper
.getDefiningOp())
170 auto m
= specifics
->complexArgumentType(loc
, ty
.getElementType());
172 // COMPLEX is a single aggregate
173 auto resTy
= std::get
<mlir::Type
>(m
[0]);
174 auto attr
= std::get
<fir::CodeGenSpecifics::Attributes
>(m
[0]);
175 auto oldRefTy
= fir::ReferenceType::get(ty
);
176 if (attr
.isByVal()) {
177 auto mem
= rewriter
->create
<fir::AllocaOp
>(loc
, ty
);
178 rewriter
->create
<fir::StoreOp
>(loc
, oper
, mem
);
179 newOpers
.push_back(rewriter
->create
<fir::ConvertOp
>(loc
, resTy
, mem
));
181 auto mem
= rewriter
->create
<fir::AllocaOp
>(loc
, resTy
);
182 auto cast
= rewriter
->create
<fir::ConvertOp
>(loc
, oldRefTy
, mem
);
183 rewriter
->create
<fir::StoreOp
>(loc
, oper
, cast
);
184 newOpers
.push_back(rewriter
->create
<fir::LoadOp
>(loc
, mem
));
186 newInTys
.push_back(resTy
);
188 assert(m
.size() == 2);
189 // COMPLEX is split into 2 separate arguments
190 auto iTy
= rewriter
->getIntegerType(32);
191 for (auto e
: llvm::enumerate(m
)) {
192 auto &tup
= e
.value();
193 auto ty
= std::get
<mlir::Type
>(tup
);
194 auto index
= e
.index();
195 auto idx
= rewriter
->getIntegerAttr(iTy
, index
);
196 auto val
= rewriter
->create
<fir::ExtractValueOp
>(
197 loc
, ty
, oper
, rewriter
->getArrayAttr(idx
));
198 newInTys
.push_back(ty
);
199 newOpers
.push_back(val
);
204 // Convert fir.call and fir.dispatch Ops.
205 template <typename A
>
206 void convertCallOp(A callOp
) {
207 auto fnTy
= callOp
.getFunctionType();
208 auto loc
= callOp
.getLoc();
209 rewriter
->setInsertionPoint(callOp
);
210 llvm::SmallVector
<mlir::Type
> newResTys
;
211 llvm::SmallVector
<mlir::Type
> newInTys
;
212 llvm::SmallVector
<mlir::Value
> newOpers
;
214 // If the call is indirect, the first argument must still be the function
217 if constexpr (std::is_same_v
<std::decay_t
<A
>, fir::CallOp
>) {
218 if (!callOp
.getCallee()) {
219 newInTys
.push_back(fnTy
.getInput(0));
220 newOpers
.push_back(callOp
.getOperand(0));
224 dropFront
= 1; // First operand is the polymorphic object.
227 // Determine the rewrite function, `wrap`, for the result value.
228 std::optional
<std::function
<mlir::Value(mlir::Operation
*)>> wrap
;
229 if (fnTy
.getResults().size() == 1) {
230 mlir::Type ty
= fnTy
.getResult(0);
231 llvm::TypeSwitch
<mlir::Type
>(ty
)
232 .template Case
<fir::ComplexType
>([&](fir::ComplexType cmplx
) {
233 wrap
= rewriteCallComplexResultType(loc
, cmplx
, newResTys
, newInTys
,
236 .template Case
<mlir::ComplexType
>([&](mlir::ComplexType cmplx
) {
237 wrap
= rewriteCallComplexResultType(loc
, cmplx
, newResTys
, newInTys
,
240 .Default([&](mlir::Type ty
) { newResTys
.push_back(ty
); });
241 } else if (fnTy
.getResults().size() > 1) {
242 TODO(loc
, "multiple results not supported yet");
245 llvm::SmallVector
<mlir::Type
> trailingInTys
;
246 llvm::SmallVector
<mlir::Value
> trailingOpers
;
247 unsigned passArgShift
= 0;
248 for (auto e
: llvm::enumerate(
249 llvm::zip(fnTy
.getInputs().drop_front(dropFront
),
250 callOp
.getOperands().drop_front(dropFront
)))) {
251 mlir::Type ty
= std::get
<0>(e
.value());
252 mlir::Value oper
= std::get
<1>(e
.value());
253 unsigned index
= e
.index();
254 llvm::TypeSwitch
<mlir::Type
>(ty
)
255 .template Case
<fir::BoxCharType
>([&](fir::BoxCharType boxTy
) {
257 if constexpr (std::is_same_v
<std::decay_t
<A
>, fir::CallOp
>) {
258 if (noCharacterConversion
) {
259 newInTys
.push_back(boxTy
);
260 newOpers
.push_back(oper
);
263 sret
= callOp
.getCallee() &&
265 index
, getModule().lookupSymbol
<mlir::func::FuncOp
>(
266 *callOp
.getCallee()));
268 // TODO: dispatch case; how do we put arguments on a call?
269 // We cannot put both an sret and the dispatch object first.
271 TODO(loc
, "dispatch + sret not supported yet");
273 auto m
= specifics
->boxcharArgumentType(boxTy
.getEleTy(), sret
);
274 auto unbox
= rewriter
->create
<fir::UnboxCharOp
>(
275 loc
, std::get
<mlir::Type
>(m
[0]), std::get
<mlir::Type
>(m
[1]),
277 // unboxed CHARACTER arguments
278 for (auto e
: llvm::enumerate(m
)) {
279 unsigned idx
= e
.index();
281 std::get
<fir::CodeGenSpecifics::Attributes
>(e
.value());
282 auto argTy
= std::get
<mlir::Type
>(e
.value());
283 if (attr
.isAppend()) {
284 trailingInTys
.push_back(argTy
);
285 trailingOpers
.push_back(unbox
.getResult(idx
));
287 newInTys
.push_back(argTy
);
288 newOpers
.push_back(unbox
.getResult(idx
));
292 .template Case
<fir::ComplexType
>([&](fir::ComplexType cmplx
) {
293 rewriteCallComplexInputType(cmplx
, oper
, newInTys
, newOpers
);
295 .template Case
<mlir::ComplexType
>([&](mlir::ComplexType cmplx
) {
296 rewriteCallComplexInputType(cmplx
, oper
, newInTys
, newOpers
);
298 .template Case
<mlir::TupleType
>([&](mlir::TupleType tuple
) {
299 if (fir::isCharacterProcedureTuple(tuple
)) {
300 mlir::ModuleOp module
= getModule();
301 if constexpr (std::is_same_v
<std::decay_t
<A
>, fir::CallOp
>) {
302 if (callOp
.getCallee()) {
303 llvm::StringRef charProcAttr
=
304 fir::getCharacterProcedureDummyAttrName();
305 // The charProcAttr attribute is only used as a safety to
306 // confirm that this is a dummy procedure and should be split.
307 // It cannot be used to match because attributes are not
308 // available in case of indirect calls.
309 auto funcOp
= module
.lookupSymbol
<mlir::func::FuncOp
>(
310 *callOp
.getCallee());
312 !funcOp
.template getArgAttrOfType
<mlir::UnitAttr
>(
313 index
, charProcAttr
))
314 mlir::emitError(loc
, "tuple argument will be split even "
315 "though it does not have the `" +
316 charProcAttr
+ "` attribute");
319 mlir::Type funcPointerType
= tuple
.getType(0);
320 mlir::Type lenType
= tuple
.getType(1);
321 fir::KindMapping kindMap
= fir::getKindMapping(module
);
322 fir::FirOpBuilder
builder(*rewriter
, kindMap
);
323 auto [funcPointer
, len
] =
324 fir::factory::extractCharacterProcedureTuple(builder
, loc
,
326 newInTys
.push_back(funcPointerType
);
327 newOpers
.push_back(funcPointer
);
328 trailingInTys
.push_back(lenType
);
329 trailingOpers
.push_back(len
);
331 newInTys
.push_back(tuple
);
332 newOpers
.push_back(oper
);
335 .Default([&](mlir::Type ty
) {
336 if constexpr (std::is_same_v
<std::decay_t
<A
>, fir::DispatchOp
>) {
337 if (callOp
.getPassArgPos() && *callOp
.getPassArgPos() == index
)
338 passArgShift
= newOpers
.size() - *callOp
.getPassArgPos();
340 newInTys
.push_back(ty
);
341 newOpers
.push_back(oper
);
344 newInTys
.insert(newInTys
.end(), trailingInTys
.begin(), trailingInTys
.end());
345 newOpers
.insert(newOpers
.end(), trailingOpers
.begin(), trailingOpers
.end());
346 if constexpr (std::is_same_v
<std::decay_t
<A
>, fir::CallOp
>) {
348 if (callOp
.getCallee()) {
350 rewriter
->create
<A
>(loc
, *callOp
.getCallee(), newResTys
, newOpers
);
352 // Force new type on the input operand.
353 newOpers
[0].setType(mlir::FunctionType::get(
355 mlir::TypeRange
{newInTys
}.drop_front(dropFront
), newResTys
));
356 newCall
= rewriter
->create
<A
>(loc
, newResTys
, newOpers
);
358 LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall
<< '\n');
360 replaceOp(callOp
, (*wrap
)(newCall
.getOperation()));
362 replaceOp(callOp
, newCall
.getResults());
364 fir::DispatchOp dispatchOp
= rewriter
->create
<A
>(
365 loc
, newResTys
, rewriter
->getStringAttr(callOp
.getMethod()),
366 callOp
.getOperands()[0], newOpers
,
367 rewriter
->getI32IntegerAttr(*callOp
.getPassArgPos() + passArgShift
));
369 replaceOp(callOp
, (*wrap
)(dispatchOp
.getOperation()));
371 replaceOp(callOp
, dispatchOp
.getResults());
375 // Result type fixup for fir::ComplexType and mlir::ComplexType
376 template <typename A
, typename B
>
377 void lowerComplexSignatureRes(mlir::Location loc
, A cmplx
, B
&newResTys
,
379 if (noComplexConversion
) {
380 newResTys
.push_back(cmplx
);
383 specifics
->complexReturnType(loc
, cmplx
.getElementType())) {
384 auto argTy
= std::get
<mlir::Type
>(tup
);
385 if (std::get
<fir::CodeGenSpecifics::Attributes
>(tup
).isSRet())
386 newInTys
.push_back(argTy
);
388 newResTys
.push_back(argTy
);
393 // Argument type fixup for fir::ComplexType and mlir::ComplexType
394 template <typename A
, typename B
>
395 void lowerComplexSignatureArg(mlir::Location loc
, A cmplx
, B
&newInTys
) {
396 if (noComplexConversion
)
397 newInTys
.push_back(cmplx
);
400 specifics
->complexArgumentType(loc
, cmplx
.getElementType()))
401 newInTys
.push_back(std::get
<mlir::Type
>(tup
));
404 /// Taking the address of a function. Modify the signature as needed.
405 void convertAddrOp(fir::AddrOfOp addrOp
) {
406 rewriter
->setInsertionPoint(addrOp
);
407 auto addrTy
= addrOp
.getType().cast
<mlir::FunctionType
>();
408 llvm::SmallVector
<mlir::Type
> newResTys
;
409 llvm::SmallVector
<mlir::Type
> newInTys
;
410 auto loc
= addrOp
.getLoc();
411 for (mlir::Type ty
: addrTy
.getResults()) {
412 llvm::TypeSwitch
<mlir::Type
>(ty
)
413 .Case
<fir::ComplexType
>([&](fir::ComplexType ty
) {
414 lowerComplexSignatureRes(loc
, ty
, newResTys
, newInTys
);
416 .Case
<mlir::ComplexType
>([&](mlir::ComplexType ty
) {
417 lowerComplexSignatureRes(loc
, ty
, newResTys
, newInTys
);
419 .Default([&](mlir::Type ty
) { newResTys
.push_back(ty
); });
421 llvm::SmallVector
<mlir::Type
> trailingInTys
;
422 for (mlir::Type ty
: addrTy
.getInputs()) {
423 llvm::TypeSwitch
<mlir::Type
>(ty
)
424 .Case
<fir::BoxCharType
>([&](auto box
) {
425 if (noCharacterConversion
) {
426 newInTys
.push_back(box
);
428 for (auto &tup
: specifics
->boxcharArgumentType(box
.getEleTy())) {
429 auto attr
= std::get
<fir::CodeGenSpecifics::Attributes
>(tup
);
430 auto argTy
= std::get
<mlir::Type
>(tup
);
431 llvm::SmallVector
<mlir::Type
> &vec
=
432 attr
.isAppend() ? trailingInTys
: newInTys
;
433 vec
.push_back(argTy
);
437 .Case
<fir::ComplexType
>([&](fir::ComplexType ty
) {
438 lowerComplexSignatureArg(loc
, ty
, newInTys
);
440 .Case
<mlir::ComplexType
>([&](mlir::ComplexType ty
) {
441 lowerComplexSignatureArg(loc
, ty
, newInTys
);
443 .Case
<mlir::TupleType
>([&](mlir::TupleType tuple
) {
444 if (fir::isCharacterProcedureTuple(tuple
)) {
445 newInTys
.push_back(tuple
.getType(0));
446 trailingInTys
.push_back(tuple
.getType(1));
448 newInTys
.push_back(ty
);
451 .Default([&](mlir::Type ty
) { newInTys
.push_back(ty
); });
453 // append trailing input types
454 newInTys
.insert(newInTys
.end(), trailingInTys
.begin(), trailingInTys
.end());
455 // replace this op with a new one with the updated signature
456 auto newTy
= rewriter
->getFunctionType(newInTys
, newResTys
);
457 auto newOp
= rewriter
->create
<fir::AddrOfOp
>(addrOp
.getLoc(), newTy
,
459 replaceOp(addrOp
, newOp
.getResult());
462 /// Convert the type signatures on all the functions present in the module.
463 /// As the type signature is being changed, this must also update the
464 /// function itself to use any new arguments, etc.
465 mlir::LogicalResult
convertTypes(mlir::ModuleOp mod
) {
466 for (auto fn
: mod
.getOps
<mlir::func::FuncOp
>())
467 convertSignature(fn
);
468 return mlir::success();
471 // Returns true if the function should be interoperable with C.
472 static bool isFuncWithCCallingConvention(mlir::Operation
*op
) {
473 auto funcOp
= mlir::dyn_cast
<mlir::func::FuncOp
>(op
);
476 return op
->hasAttrOfType
<mlir::UnitAttr
>(
477 fir::FIROpsDialect::getFirRuntimeAttrName()) ||
478 op
->hasAttrOfType
<mlir::StringAttr
>(fir::getSymbolAttrName());
481 /// If the signature does not need any special target-specific conversions,
482 /// then it is considered portable for any target, and this function will
483 /// return `true`. Otherwise, the signature is not portable and `false` is
485 bool hasPortableSignature(mlir::Type signature
, mlir::Operation
*op
) {
486 assert(signature
.isa
<mlir::FunctionType
>());
487 auto func
= signature
.dyn_cast
<mlir::FunctionType
>();
488 bool hasCCallingConv
= isFuncWithCCallingConvention(op
);
489 for (auto ty
: func
.getResults())
490 if ((ty
.isa
<fir::BoxCharType
>() && !noCharacterConversion
) ||
491 (fir::isa_complex(ty
) && !noComplexConversion
) ||
492 (ty
.isa
<mlir::IntegerType
>() && hasCCallingConv
)) {
493 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature
<< " for target\n");
496 for (auto ty
: func
.getInputs())
497 if (((ty
.isa
<fir::BoxCharType
>() || fir::isCharacterProcedureTuple(ty
)) &&
498 !noCharacterConversion
) ||
499 (fir::isa_complex(ty
) && !noComplexConversion
) ||
500 (ty
.isa
<mlir::IntegerType
>() && hasCCallingConv
)) {
501 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature
<< " for target\n");
507 /// Determine if the signature has host associations. The host association
508 /// argument may need special target specific rewriting.
509 static bool hasHostAssociations(mlir::func::FuncOp func
) {
510 std::size_t end
= func
.getFunctionType().getInputs().size();
511 for (std::size_t i
= 0; i
< end
; ++i
)
512 if (func
.getArgAttrOfType
<mlir::UnitAttr
>(i
, fir::getHostAssocAttrName()))
517 /// Rewrite the signatures and body of the `FuncOp`s in the module for
518 /// the immediately subsequent target code gen.
519 void convertSignature(mlir::func::FuncOp func
) {
520 auto funcTy
= func
.getFunctionType().cast
<mlir::FunctionType
>();
521 if (hasPortableSignature(funcTy
, func
) && !hasHostAssociations(func
))
523 llvm::SmallVector
<mlir::Type
> newResTys
;
524 llvm::SmallVector
<mlir::Type
> newInTys
;
525 llvm::SmallVector
<std::pair
<unsigned, mlir::NamedAttribute
>> savedAttrs
;
526 llvm::SmallVector
<std::pair
<unsigned, mlir::NamedAttribute
>> extraAttrs
;
527 llvm::SmallVector
<FixupTy
> fixups
;
528 llvm::SmallVector
<std::pair
<unsigned, mlir::NamedAttrList
>, 1> resultAttrs
;
530 // Save argument attributes in case there is a shift so we can replace them
532 for (auto e
: llvm::enumerate(funcTy
.getInputs())) {
533 unsigned index
= e
.index();
534 llvm::ArrayRef
<mlir::NamedAttribute
> attrs
=
535 mlir::function_interface_impl::getArgAttrs(func
, index
);
536 for (mlir::NamedAttribute attr
: attrs
) {
537 savedAttrs
.push_back({index
, attr
});
541 // Convert return value(s)
542 for (auto ty
: funcTy
.getResults())
543 llvm::TypeSwitch
<mlir::Type
>(ty
)
544 .Case
<fir::ComplexType
>([&](fir::ComplexType cmplx
) {
545 if (noComplexConversion
)
546 newResTys
.push_back(cmplx
);
548 doComplexReturn(func
, cmplx
, newResTys
, newInTys
, fixups
);
550 .Case
<mlir::ComplexType
>([&](mlir::ComplexType cmplx
) {
551 if (noComplexConversion
)
552 newResTys
.push_back(cmplx
);
554 doComplexReturn(func
, cmplx
, newResTys
, newInTys
, fixups
);
556 .Case
<mlir::IntegerType
>([&](mlir::IntegerType intTy
) {
557 auto m
= specifics
->integerArgumentType(func
.getLoc(), intTy
);
558 assert(m
.size() == 1);
559 auto attr
= std::get
<fir::CodeGenSpecifics::Attributes
>(m
[0]);
560 auto retTy
= std::get
<mlir::Type
>(m
[0]);
561 std::size_t resId
= newResTys
.size();
562 llvm::StringRef extensionAttrName
= attr
.getIntExtensionAttrName();
563 if (!extensionAttrName
.empty() &&
564 isFuncWithCCallingConvention(func
))
565 resultAttrs
.emplace_back(
566 resId
, rewriter
->getNamedAttr(extensionAttrName
,
567 rewriter
->getUnitAttr()));
568 newResTys
.push_back(retTy
);
570 .Default([&](mlir::Type ty
) { newResTys
.push_back(ty
); });
572 // Saved potential shift in argument. Handling of result can add arguments
573 // at the beginning of the function signature.
574 unsigned argumentShift
= newInTys
.size();
577 llvm::SmallVector
<mlir::Type
> trailingTys
;
578 for (auto e
: llvm::enumerate(funcTy
.getInputs())) {
580 unsigned index
= e
.index();
581 llvm::TypeSwitch
<mlir::Type
>(ty
)
582 .Case
<fir::BoxCharType
>([&](fir::BoxCharType boxTy
) {
583 if (noCharacterConversion
) {
584 newInTys
.push_back(boxTy
);
586 // Convert a CHARACTER argument type. This can involve separating
587 // the pointer and the LEN into two arguments and moving the LEN
588 // argument to the end of the arg list.
589 bool sret
= functionArgIsSRet(index
, func
);
590 for (auto e
: llvm::enumerate(specifics
->boxcharArgumentType(
591 boxTy
.getEleTy(), sret
))) {
592 auto &tup
= e
.value();
593 auto index
= e
.index();
594 auto attr
= std::get
<fir::CodeGenSpecifics::Attributes
>(tup
);
595 auto argTy
= std::get
<mlir::Type
>(tup
);
596 if (attr
.isAppend()) {
597 trailingTys
.push_back(argTy
);
600 fixups
.emplace_back(FixupTy::Codes::CharPair
,
601 newInTys
.size(), index
);
603 fixups
.emplace_back(FixupTy::Codes::Trailing
,
604 newInTys
.size(), trailingTys
.size());
606 newInTys
.push_back(argTy
);
611 .Case
<fir::ComplexType
>([&](fir::ComplexType cmplx
) {
612 if (noComplexConversion
)
613 newInTys
.push_back(cmplx
);
615 doComplexArg(func
, cmplx
, newInTys
, fixups
);
617 .Case
<mlir::ComplexType
>([&](mlir::ComplexType cmplx
) {
618 if (noComplexConversion
)
619 newInTys
.push_back(cmplx
);
621 doComplexArg(func
, cmplx
, newInTys
, fixups
);
623 .Case
<mlir::TupleType
>([&](mlir::TupleType tuple
) {
624 if (fir::isCharacterProcedureTuple(tuple
)) {
625 fixups
.emplace_back(FixupTy::Codes::TrailingCharProc
,
626 newInTys
.size(), trailingTys
.size());
627 newInTys
.push_back(tuple
.getType(0));
628 trailingTys
.push_back(tuple
.getType(1));
630 newInTys
.push_back(ty
);
633 .Case
<mlir::IntegerType
>([&](mlir::IntegerType intTy
) {
634 auto m
= specifics
->integerArgumentType(func
.getLoc(), intTy
);
635 assert(m
.size() == 1);
636 auto attr
= std::get
<fir::CodeGenSpecifics::Attributes
>(m
[0]);
637 auto argTy
= std::get
<mlir::Type
>(m
[0]);
638 auto argNo
= newInTys
.size();
639 llvm::StringRef extensionAttrName
= attr
.getIntExtensionAttrName();
640 if (!extensionAttrName
.empty() &&
641 isFuncWithCCallingConvention(func
))
642 fixups
.emplace_back(FixupTy::Codes::ArgumentType
, argNo
,
643 [=](mlir::func::FuncOp func
) {
645 argNo
, extensionAttrName
,
646 mlir::UnitAttr::get(func
.getContext()));
649 newInTys
.push_back(argTy
);
651 .Default([&](mlir::Type ty
) { newInTys
.push_back(ty
); });
653 if (func
.getArgAttrOfType
<mlir::UnitAttr
>(index
,
654 fir::getHostAssocAttrName())) {
655 extraAttrs
.push_back(
656 {newInTys
.size() - 1,
657 rewriter
->getNamedAttr("llvm.nest", rewriter
->getUnitAttr())});
662 // If the function has a body, then apply the fixups to the arguments and
663 // return ops as required. These fixups are done in place.
664 auto loc
= func
.getLoc();
665 const auto fixupSize
= fixups
.size();
666 const auto oldArgTys
= func
.getFunctionType().getInputs();
668 for (std::remove_const_t
<decltype(fixupSize
)> i
= 0; i
< fixupSize
; ++i
) {
669 const auto &fixup
= fixups
[i
];
670 switch (fixup
.code
) {
671 case FixupTy::Codes::ArgumentAsLoad
: {
672 // Argument was pass-by-value, but is now pass-by-reference and
673 // possibly with a different element type.
674 auto newArg
= func
.front().insertArgument(fixup
.index
,
675 newInTys
[fixup
.index
], loc
);
676 rewriter
->setInsertionPointToStart(&func
.front());
678 fir::ReferenceType::get(oldArgTys
[fixup
.index
- offset
]);
679 auto cast
= rewriter
->create
<fir::ConvertOp
>(loc
, oldArgTy
, newArg
);
680 auto load
= rewriter
->create
<fir::LoadOp
>(loc
, cast
);
681 func
.getArgument(fixup
.index
+ 1).replaceAllUsesWith(load
);
682 func
.front().eraseArgument(fixup
.index
+ 1);
684 case FixupTy::Codes::ArgumentType
: {
685 // Argument is pass-by-value, but its type has likely been modified to
686 // suit the target ABI convention.
688 fir::ReferenceType::get(oldArgTys
[fixup
.index
- offset
]);
689 // If type did not change, keep the original argument.
690 if (newInTys
[fixup
.index
] == oldArgTy
)
693 auto newArg
= func
.front().insertArgument(fixup
.index
,
694 newInTys
[fixup
.index
], loc
);
695 rewriter
->setInsertionPointToStart(&func
.front());
697 rewriter
->create
<fir::AllocaOp
>(loc
, newInTys
[fixup
.index
]);
698 rewriter
->create
<fir::StoreOp
>(loc
, newArg
, mem
);
699 auto cast
= rewriter
->create
<fir::ConvertOp
>(loc
, oldArgTy
, mem
);
700 mlir::Value load
= rewriter
->create
<fir::LoadOp
>(loc
, cast
);
701 func
.getArgument(fixup
.index
+ 1).replaceAllUsesWith(load
);
702 func
.front().eraseArgument(fixup
.index
+ 1);
703 LLVM_DEBUG(llvm::dbgs()
704 << "old argument: " << oldArgTy
.getEleTy()
705 << ", repl: " << load
<< ", new argument: "
706 << func
.getArgument(fixup
.index
).getType() << '\n');
708 case FixupTy::Codes::CharPair
: {
709 // The FIR boxchar argument has been split into a pair of distinct
710 // arguments that are in juxtaposition to each other.
711 auto newArg
= func
.front().insertArgument(fixup
.index
,
712 newInTys
[fixup
.index
], loc
);
713 if (fixup
.second
== 1) {
714 rewriter
->setInsertionPointToStart(&func
.front());
715 auto boxTy
= oldArgTys
[fixup
.index
- offset
- fixup
.second
];
716 auto box
= rewriter
->create
<fir::EmboxCharOp
>(
717 loc
, boxTy
, func
.front().getArgument(fixup
.index
- 1), newArg
);
718 func
.getArgument(fixup
.index
+ 1).replaceAllUsesWith(box
);
719 func
.front().eraseArgument(fixup
.index
+ 1);
723 case FixupTy::Codes::ReturnAsStore
: {
724 // The value being returned is now being returned in memory (callee
725 // stack space) through a hidden reference argument.
726 auto newArg
= func
.front().insertArgument(fixup
.index
,
727 newInTys
[fixup
.index
], loc
);
729 func
.walk([&](mlir::func::ReturnOp ret
) {
730 rewriter
->setInsertionPoint(ret
);
731 auto oldOper
= ret
.getOperand(0);
732 auto oldOperTy
= fir::ReferenceType::get(oldOper
.getType());
734 rewriter
->create
<fir::ConvertOp
>(loc
, oldOperTy
, newArg
);
735 rewriter
->create
<fir::StoreOp
>(loc
, oldOper
, cast
);
736 rewriter
->create
<mlir::func::ReturnOp
>(loc
);
740 case FixupTy::Codes::ReturnType
: {
741 // The function is still returning a value, but its type has likely
742 // changed to suit the target ABI convention.
743 func
.walk([&](mlir::func::ReturnOp ret
) {
744 rewriter
->setInsertionPoint(ret
);
745 auto oldOper
= ret
.getOperand(0);
746 auto oldOperTy
= fir::ReferenceType::get(oldOper
.getType());
748 rewriter
->create
<fir::AllocaOp
>(loc
, newResTys
[fixup
.index
]);
749 auto cast
= rewriter
->create
<fir::ConvertOp
>(loc
, oldOperTy
, mem
);
750 rewriter
->create
<fir::StoreOp
>(loc
, oldOper
, cast
);
751 mlir::Value load
= rewriter
->create
<fir::LoadOp
>(loc
, mem
);
752 rewriter
->create
<mlir::func::ReturnOp
>(loc
, load
);
756 case FixupTy::Codes::Split
: {
757 // The FIR argument has been split into a pair of distinct arguments
758 // that are in juxtaposition to each other. (For COMPLEX value.)
759 auto newArg
= func
.front().insertArgument(fixup
.index
,
760 newInTys
[fixup
.index
], loc
);
761 if (fixup
.second
== 1) {
762 rewriter
->setInsertionPointToStart(&func
.front());
763 auto cplxTy
= oldArgTys
[fixup
.index
- offset
- fixup
.second
];
764 auto undef
= rewriter
->create
<fir::UndefOp
>(loc
, cplxTy
);
765 auto iTy
= rewriter
->getIntegerType(32);
766 auto zero
= rewriter
->getIntegerAttr(iTy
, 0);
767 auto one
= rewriter
->getIntegerAttr(iTy
, 1);
768 auto cplx1
= rewriter
->create
<fir::InsertValueOp
>(
769 loc
, cplxTy
, undef
, func
.front().getArgument(fixup
.index
- 1),
770 rewriter
->getArrayAttr(zero
));
771 auto cplx
= rewriter
->create
<fir::InsertValueOp
>(
772 loc
, cplxTy
, cplx1
, newArg
, rewriter
->getArrayAttr(one
));
773 func
.getArgument(fixup
.index
+ 1).replaceAllUsesWith(cplx
);
774 func
.front().eraseArgument(fixup
.index
+ 1);
778 case FixupTy::Codes::Trailing
: {
779 // The FIR argument has been split into a pair of distinct arguments.
780 // The first part of the pair appears in the original argument
781 // position. The second part of the pair is appended after all the
782 // original arguments. (Boxchar arguments.)
783 auto newBufArg
= func
.front().insertArgument(
784 fixup
.index
, newInTys
[fixup
.index
], loc
);
786 func
.front().addArgument(trailingTys
[fixup
.second
], loc
);
787 auto boxTy
= oldArgTys
[fixup
.index
- offset
];
788 rewriter
->setInsertionPointToStart(&func
.front());
789 auto box
= rewriter
->create
<fir::EmboxCharOp
>(loc
, boxTy
, newBufArg
,
791 func
.getArgument(fixup
.index
+ 1).replaceAllUsesWith(box
);
792 func
.front().eraseArgument(fixup
.index
+ 1);
794 case FixupTy::Codes::TrailingCharProc
: {
795 // The FIR character procedure argument tuple must be split into a
796 // pair of distinct arguments. The first part of the pair appears in
797 // the original argument position. The second part of the pair is
798 // appended after all the original arguments.
799 auto newProcPointerArg
= func
.front().insertArgument(
800 fixup
.index
, newInTys
[fixup
.index
], loc
);
802 func
.front().addArgument(trailingTys
[fixup
.second
], loc
);
803 auto tupleType
= oldArgTys
[fixup
.index
- offset
];
804 rewriter
->setInsertionPointToStart(&func
.front());
805 fir::KindMapping kindMap
= fir::getKindMapping(getModule());
806 fir::FirOpBuilder
builder(*rewriter
, kindMap
);
807 auto tuple
= fir::factory::createCharacterProcedureTuple(
808 builder
, loc
, tupleType
, newProcPointerArg
, newLenArg
);
809 func
.getArgument(fixup
.index
+ 1).replaceAllUsesWith(tuple
);
810 func
.front().eraseArgument(fixup
.index
+ 1);
816 // Set the new type and finalize the arguments, etc.
817 newInTys
.insert(newInTys
.end(), trailingTys
.begin(), trailingTys
.end());
819 mlir::FunctionType::get(func
.getContext(), newInTys
, newResTys
);
820 LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy
<< '\n');
821 func
.setType(newFuncTy
);
823 for (std::pair
<unsigned, mlir::NamedAttribute
> extraAttr
: extraAttrs
)
824 func
.setArgAttr(extraAttr
.first
, extraAttr
.second
.getName(),
825 extraAttr
.second
.getValue());
827 for (auto [resId
, resAttrList
] : resultAttrs
)
828 for (mlir::NamedAttribute resAttr
: resAttrList
)
829 func
.setResultAttr(resId
, resAttr
.getName(), resAttr
.getValue());
831 // Replace attributes to the correct argument if there was an argument shift
833 if (argumentShift
> 0) {
834 for (std::pair
<unsigned, mlir::NamedAttribute
> savedAttr
: savedAttrs
) {
835 func
.removeArgAttr(savedAttr
.first
, savedAttr
.second
.getName());
836 func
.setArgAttr(savedAttr
.first
+ argumentShift
,
837 savedAttr
.second
.getName(),
838 savedAttr
.second
.getValue());
842 for (auto &fixup
: fixups
)
844 (*fixup
.finalizer
)(func
);
/// Check whether argument `index` of `func` carries an `llvm.sret`
/// attribute, looked up as an `mlir::TypeAttr` via getArgAttrOfType.
/// NOTE(review): the branch body and return statements (original lines
/// 849-852) are missing from this listing; confirm the exact result for a
/// present vs. absent attribute against the full source.
847 inline bool functionArgIsSRet(unsigned index
, mlir::func::FuncOp func
) {
// Look up the sret type attribute on argument `index`.
848 if (auto attr
= func
.getArgAttrOfType
<mlir::TypeAttr
>(index
, "llvm.sret"))
853 /// Convert a complex return value. This can involve converting the return
854 /// value to a "hidden" first argument or packing the complex into a wide
/// target-specific result type.
///
/// The lowering is obtained from `specifics->complexReturnType` for the
/// element type of `cmplx`, and exactly one (attributes, type) entry is
/// asserted. If `noComplexConversion` is set, `cmplx` is pushed onto
/// `newResTys` unchanged. Otherwise one of two shapes is produced:
///  - ReturnAsStore: the lowered type is pushed onto `newInTys` (its index is
///    `newInTys.size()` before the push), and a finalizer later marks that
///    argument `llvm.sret` (plus `llvm.align` when an alignment is present);
///  - ReturnType: the lowered type is pushed onto `newResTys`, with an
///    `llvm.align` finalizer when an alignment is present.
/// The finalizers are lambdas stored in `fixups`; `[=]` captures `this`
/// implicitly, so `rewriter` is dereferenced when a finalizer runs, not when
/// it is created, and must still be non-null at that point.
/// NOTE(review): original lines 855, 861-863, 866, 869, 872, 881-882, 889,
/// 891-892, 896, 900-901, 903 and 905-906 are missing from this listing,
/// including the condition that selects between the two shapes; confirm
/// against the full source.
856 template <typename A
, typename B
, typename C
>
857 void doComplexReturn(mlir::func::FuncOp func
, A cmplx
, B
&newResTys
,
858 B
&newInTys
, C
&fixups
) {
// Conversion disabled: keep the complex type as the result type.
859 if (noComplexConversion
) {
860 newResTys
.push_back(cmplx
);
// Ask the target how a complex of this element type is returned.
864 specifics
->complexReturnType(func
.getLoc(), cmplx
.getElementType());
865 assert(m
.size() == 1);
867 auto attr
= std::get
<fir::CodeGenSpecifics::Attributes
>(tup
);
868 auto argTy
= std::get
<mlir::Type
>(tup
);
// Result returned through memory: the new argument's position is the
// current size of newInTys.
870 unsigned argNo
= newInTys
.size();
871 if (auto align
= attr
.getAlignment())
873 FixupTy::Codes::ReturnAsStore
, argNo
, [=](mlir::func::FuncOp func
) {
874 auto elemType
= fir::dyn_cast_ptrOrBoxEleTy(
875 func
.getFunctionType().getInput(argNo
));
876 func
.setArgAttr(argNo
, "llvm.sret",
877 mlir::TypeAttr::get(elemType
));
878 func
.setArgAttr(argNo
, "llvm.align",
879 rewriter
->getIntegerAttr(
880 rewriter
->getIntegerType(32), align
));
// Variant without an alignment: only the `llvm.sret` attribute is set.
883 fixups
.emplace_back(FixupTy::Codes::ReturnAsStore
, argNo
,
884 [=](mlir::func::FuncOp func
) {
885 auto elemType
= fir::dyn_cast_ptrOrBoxEleTy(
886 func
.getFunctionType().getInput(argNo
));
887 func
.setArgAttr(argNo
, "llvm.sret",
888 mlir::TypeAttr::get(elemType
));
890 newInTys
.push_back(argTy
);
// Result returned directly in the lowered type.
// NOTE(review): `[=]` copies the whole `newResTys` container into the
// closure, so `newResTys.size()` inside the finalizer is the size before
// `argTy` is pushed, i.e. the index of the new result. The call on the
// missing line 896 uses that index with "llvm.align"; confirm it targets a
// result attribute rather than an argument attribute. Capturing only the
// index would also avoid copying the container.
893 if (auto align
= attr
.getAlignment())
894 fixups
.emplace_back(FixupTy::Codes::ReturnType
, newResTys
.size(),
895 [=](mlir::func::FuncOp func
) {
897 newResTys
.size(), "llvm.align",
898 rewriter
->getIntegerAttr(
899 rewriter
->getIntegerType(32), align
));
// Variant without an alignment: no finalizer is needed.
902 fixups
.emplace_back(FixupTy::Codes::ReturnType
, newResTys
.size());
904 newResTys
.push_back(argTy
);
907 /// Convert a complex argument value. This can involve storing the value to
908 /// a temporary memory location or factoring the value into two distinct
/// arguments.
///
/// The lowering comes from `specifics->complexArgumentType` for the element
/// type of `cmplx`. Each returned (attributes, type) entry yields a fixup
/// and pushes its type onto `newInTys`. When more than one entry is
/// returned, the fixups use the Split code and carry the entry's position
/// as their second index; otherwise they use ArgumentType. Entries whose
/// attributes report isByVal() instead get an ArgumentAsLoad fixup whose
/// finalizer sets `llvm.byval` (plus an alignment attribute built from
/// `align` when present; that attribute's name is on a missing line).
/// If `noComplexConversion` is set, `cmplx` is pushed onto `newInTys`
/// unchanged.
/// NOTE(review): original lines 909, 914-916, 934-935, 938-939, 946-947,
/// 949, 954-955, 957 and 959-962 are missing from this listing; confirm
/// against the full source (in particular that the push at line 958 sits
/// inside the loop, as the per-entry `argNo` suggests).
910 template <typename A
, typename B
, typename C
>
911 void doComplexArg(mlir::func::FuncOp func
, A cmplx
, B
&newInTys
, C
&fixups
) {
// Conversion disabled: pass the complex type through unchanged.
912 if (noComplexConversion
) {
913 newInTys
.push_back(cmplx
);
// Ask the target how a complex of this element type is passed.
917 specifics
->complexArgumentType(func
.getLoc(), cmplx
.getElementType());
// Multiple parts mean the value is split across several arguments.
918 const auto fixupCode
=
919 m
.size() > 1 ? FixupTy::Codes::Split
: FixupTy::Codes::ArgumentType
;
920 for (auto e
: llvm::enumerate(m
)) {
921 auto &tup
= e
.value();
922 auto index
= e
.index();
923 auto attr
= std::get
<fir::CodeGenSpecifics::Attributes
>(tup
);
924 auto argTy
= std::get
<mlir::Type
>(tup
);
// Position this part will occupy once its type is pushed below.
925 auto argNo
= newInTys
.size();
// By-value part: passed through memory and reloaded (ArgumentAsLoad).
// `[=]` captures `this` implicitly, so `rewriter` is read when the
// finalizer runs and must still be non-null then.
926 if (attr
.isByVal()) {
927 if (auto align
= attr
.getAlignment())
928 fixups
.emplace_back(FixupTy::Codes::ArgumentAsLoad
, argNo
,
929 [=](mlir::func::FuncOp func
) {
930 auto elemType
= fir::dyn_cast_ptrOrBoxEleTy(
931 func
.getFunctionType().getInput(argNo
));
932 func
.setArgAttr(argNo
, "llvm.byval",
933 mlir::TypeAttr::get(elemType
));
936 rewriter
->getIntegerAttr(
937 rewriter
->getIntegerType(32), align
));
// Variant without an alignment; `newInTys.size()` equals `argNo` here
// because nothing has been pushed since it was computed.
940 fixups
.emplace_back(FixupTy::Codes::ArgumentAsLoad
, newInTys
.size(),
941 [=](mlir::func::FuncOp func
) {
942 auto elemType
= fir::dyn_cast_ptrOrBoxEleTy(
943 func
.getFunctionType().getInput(argNo
));
944 func
.setArgAttr(argNo
, "llvm.byval",
945 mlir::TypeAttr::get(elemType
));
// Part passed directly in the lowered type; attach `llvm.align` when the
// target supplies an alignment.
948 if (auto align
= attr
.getAlignment())
950 fixupCode
, argNo
, index
, [=](mlir::func::FuncOp func
) {
951 func
.setArgAttr(argNo
, "llvm.align",
952 rewriter
->getIntegerAttr(
953 rewriter
->getIntegerType(32), align
));
956 fixups
.emplace_back(fixupCode
, argNo
, index
);
958 newInTys
.push_back(argTy
);
963 // Replace `op` and remove it.
// Redirects every use of `op`'s results to `newValues`, then calls
// dropAllReferences() so `op` no longer references other values.
// NOTE(review): the statement that actually removes `op` (original lines
// 967-969) is missing from this listing; confirm against the full source.
964 void replaceOp(mlir::Operation
*op
, mlir::ValueRange newValues
) {
965 op
->replaceAllUsesWith(newValues
);
966 op
->dropAllReferences();
/// Install the per-run helpers used by the conversion routines: `s`, the
/// target codegen specifics, and `r`, the builder. clearMembers() passes
/// nullptr for both.
/// NOTE(review): the body (original lines 971-974) is missing from this
/// listing; presumably it assigns `specifics = s` and `rewriter = r`.
/// Confirm against the full source.
970 inline void setMembers(fir::CodeGenSpecifics
*s
, mlir::OpBuilder
*r
) {
975 inline void clearMembers() { setMembers(nullptr, nullptr); }
977 fir::CodeGenSpecifics
*specifics
= nullptr;
978 mlir::OpBuilder
*rewriter
= nullptr;
982 std::unique_ptr
<mlir::OperationPass
<mlir::ModuleOp
>>
983 fir::createFirTargetRewritePass(const fir::TargetRewriteOptions
&options
) {
984 return std::make_unique
<TargetRewrite
>(options
);