Break circular dependency between FIR dialect and utilities
[llvm-project.git] / flang / lib / Optimizer / CodeGen / TargetRewrite.cpp
blob0957e399f4e56998e93aa0f1efcad88b473f2174
1 //===-- TargetRewrite.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest.
10 // LLVM expects different lowering idioms to be used for distinct target
11 // triples. These distinctions are handled by this pass.
13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
15 //===----------------------------------------------------------------------===//
#include "flang/Optimizer/CodeGen/CodeGen.h"

#include "Target.h"
#include "flang/Optimizer/Builder/Character.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <functional>
#include <optional>
#include <utility>
34 namespace fir {
35 #define GEN_PASS_DEF_TARGETREWRITEPASS
36 #include "flang/Optimizer/CodeGen/CGPasses.h.inc"
37 } // namespace fir
39 #define DEBUG_TYPE "flang-target-rewrite"
41 namespace {
43 /// Fixups for updating a FuncOp's arguments and return values.
44 struct FixupTy {
45 enum class Codes {
46 ArgumentAsLoad,
47 ArgumentType,
48 CharPair,
49 ReturnAsStore,
50 ReturnType,
51 Split,
52 Trailing,
53 TrailingCharProc
56 FixupTy(Codes code, std::size_t index, std::size_t second = 0)
57 : code{code}, index{index}, second{second} {}
58 FixupTy(Codes code, std::size_t index,
59 std::function<void(mlir::func::FuncOp)> &&finalizer)
60 : code{code}, index{index}, finalizer{finalizer} {}
61 FixupTy(Codes code, std::size_t index, std::size_t second,
62 std::function<void(mlir::func::FuncOp)> &&finalizer)
63 : code{code}, index{index}, second{second}, finalizer{finalizer} {}
65 Codes code;
66 std::size_t index;
67 std::size_t second{};
68 std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{};
69 }; // namespace
71 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code
72 /// generation that traverses the FIR and modifies types and operations to a
73 /// form that is appropriate for the specific target. LLVM IR has specific
74 /// idioms that are used for distinct target processor and ABI combinations.
75 class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
76 public:
77 TargetRewrite(const fir::TargetRewriteOptions &options) {
78 noCharacterConversion = options.noCharacterConversion;
79 noComplexConversion = options.noComplexConversion;
82 void runOnOperation() override final {
83 auto &context = getContext();
84 mlir::OpBuilder rewriter(&context);
86 auto mod = getModule();
87 if (!forcedTargetTriple.empty())
88 fir::setTargetTriple(mod, forcedTargetTriple);
90 auto specifics = fir::CodeGenSpecifics::get(
91 mod.getContext(), fir::getTargetTriple(mod), fir::getKindMapping(mod));
92 setMembers(specifics.get(), &rewriter);
94 // Perform type conversion on signatures and call sites.
95 if (mlir::failed(convertTypes(mod))) {
96 mlir::emitError(mlir::UnknownLoc::get(&context),
97 "error in converting types to target abi");
98 signalPassFailure();
101 // Convert ops in target-specific patterns.
102 mod.walk([&](mlir::Operation *op) {
103 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
104 if (!hasPortableSignature(call.getFunctionType(), op))
105 convertCallOp(call);
106 } else if (auto dispatch = mlir::dyn_cast<fir::DispatchOp>(op)) {
107 if (!hasPortableSignature(dispatch.getFunctionType(), op))
108 convertCallOp(dispatch);
109 } else if (auto addr = mlir::dyn_cast<fir::AddrOfOp>(op)) {
110 if (addr.getType().isa<mlir::FunctionType>() &&
111 !hasPortableSignature(addr.getType(), op))
112 convertAddrOp(addr);
116 clearMembers();
119 mlir::ModuleOp getModule() { return getOperation(); }
121 template <typename A, typename B, typename C>
122 std::optional<std::function<mlir::Value(mlir::Operation *)>>
123 rewriteCallComplexResultType(mlir::Location loc, A ty, B &newResTys,
124 B &newInTys, C &newOpers) {
125 if (noComplexConversion) {
126 newResTys.push_back(ty);
127 return std::nullopt;
129 auto m = specifics->complexReturnType(loc, ty.getElementType());
130 // Currently targets mandate COMPLEX is a single aggregate or packed
131 // scalar, including the sret case.
132 assert(m.size() == 1 && "target of complex return not supported");
133 auto resTy = std::get<mlir::Type>(m[0]);
134 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
135 if (attr.isSRet()) {
136 assert(fir::isa_ref_type(resTy) && "must be a memory reference type");
137 mlir::Value stack =
138 rewriter->create<fir::AllocaOp>(loc, fir::dyn_cast_ptrEleTy(resTy));
139 newInTys.push_back(resTy);
140 newOpers.push_back(stack);
141 return [=](mlir::Operation *) -> mlir::Value {
142 auto memTy = fir::ReferenceType::get(ty);
143 auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, stack);
144 return rewriter->create<fir::LoadOp>(loc, cast);
147 newResTys.push_back(resTy);
148 return [=](mlir::Operation *call) -> mlir::Value {
149 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
150 rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem);
151 auto memTy = fir::ReferenceType::get(ty);
152 auto cast = rewriter->create<fir::ConvertOp>(loc, memTy, mem);
153 return rewriter->create<fir::LoadOp>(loc, cast);
157 template <typename A, typename B, typename C>
158 void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys,
159 C &newOpers) {
160 if (noComplexConversion) {
161 newInTys.push_back(ty);
162 newOpers.push_back(oper);
163 return;
166 auto *ctx = ty.getContext();
167 mlir::Location loc = mlir::UnknownLoc::get(ctx);
168 if (auto *op = oper.getDefiningOp())
169 loc = op->getLoc();
170 auto m = specifics->complexArgumentType(loc, ty.getElementType());
171 if (m.size() == 1) {
172 // COMPLEX is a single aggregate
173 auto resTy = std::get<mlir::Type>(m[0]);
174 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
175 auto oldRefTy = fir::ReferenceType::get(ty);
176 if (attr.isByVal()) {
177 auto mem = rewriter->create<fir::AllocaOp>(loc, ty);
178 rewriter->create<fir::StoreOp>(loc, oper, mem);
179 newOpers.push_back(rewriter->create<fir::ConvertOp>(loc, resTy, mem));
180 } else {
181 auto mem = rewriter->create<fir::AllocaOp>(loc, resTy);
182 auto cast = rewriter->create<fir::ConvertOp>(loc, oldRefTy, mem);
183 rewriter->create<fir::StoreOp>(loc, oper, cast);
184 newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem));
186 newInTys.push_back(resTy);
187 } else {
188 assert(m.size() == 2);
189 // COMPLEX is split into 2 separate arguments
190 auto iTy = rewriter->getIntegerType(32);
191 for (auto e : llvm::enumerate(m)) {
192 auto &tup = e.value();
193 auto ty = std::get<mlir::Type>(tup);
194 auto index = e.index();
195 auto idx = rewriter->getIntegerAttr(iTy, index);
196 auto val = rewriter->create<fir::ExtractValueOp>(
197 loc, ty, oper, rewriter->getArrayAttr(idx));
198 newInTys.push_back(ty);
199 newOpers.push_back(val);
204 // Convert fir.call and fir.dispatch Ops.
205 template <typename A>
206 void convertCallOp(A callOp) {
207 auto fnTy = callOp.getFunctionType();
208 auto loc = callOp.getLoc();
209 rewriter->setInsertionPoint(callOp);
210 llvm::SmallVector<mlir::Type> newResTys;
211 llvm::SmallVector<mlir::Type> newInTys;
212 llvm::SmallVector<mlir::Value> newOpers;
214 // If the call is indirect, the first argument must still be the function
215 // to call.
216 int dropFront = 0;
217 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
218 if (!callOp.getCallee()) {
219 newInTys.push_back(fnTy.getInput(0));
220 newOpers.push_back(callOp.getOperand(0));
221 dropFront = 1;
223 } else {
224 dropFront = 1; // First operand is the polymorphic object.
227 // Determine the rewrite function, `wrap`, for the result value.
228 std::optional<std::function<mlir::Value(mlir::Operation *)>> wrap;
229 if (fnTy.getResults().size() == 1) {
230 mlir::Type ty = fnTy.getResult(0);
231 llvm::TypeSwitch<mlir::Type>(ty)
232 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
233 wrap = rewriteCallComplexResultType(loc, cmplx, newResTys, newInTys,
234 newOpers);
236 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
237 wrap = rewriteCallComplexResultType(loc, cmplx, newResTys, newInTys,
238 newOpers);
240 .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
241 } else if (fnTy.getResults().size() > 1) {
242 TODO(loc, "multiple results not supported yet");
245 llvm::SmallVector<mlir::Type> trailingInTys;
246 llvm::SmallVector<mlir::Value> trailingOpers;
247 unsigned passArgShift = 0;
248 for (auto e : llvm::enumerate(
249 llvm::zip(fnTy.getInputs().drop_front(dropFront),
250 callOp.getOperands().drop_front(dropFront)))) {
251 mlir::Type ty = std::get<0>(e.value());
252 mlir::Value oper = std::get<1>(e.value());
253 unsigned index = e.index();
254 llvm::TypeSwitch<mlir::Type>(ty)
255 .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
256 bool sret;
257 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
258 if (noCharacterConversion) {
259 newInTys.push_back(boxTy);
260 newOpers.push_back(oper);
261 return;
263 sret = callOp.getCallee() &&
264 functionArgIsSRet(
265 index, getModule().lookupSymbol<mlir::func::FuncOp>(
266 *callOp.getCallee()));
267 } else {
268 // TODO: dispatch case; how do we put arguments on a call?
269 // We cannot put both an sret and the dispatch object first.
270 sret = false;
271 TODO(loc, "dispatch + sret not supported yet");
273 auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret);
274 auto unbox = rewriter->create<fir::UnboxCharOp>(
275 loc, std::get<mlir::Type>(m[0]), std::get<mlir::Type>(m[1]),
276 oper);
277 // unboxed CHARACTER arguments
278 for (auto e : llvm::enumerate(m)) {
279 unsigned idx = e.index();
280 auto attr =
281 std::get<fir::CodeGenSpecifics::Attributes>(e.value());
282 auto argTy = std::get<mlir::Type>(e.value());
283 if (attr.isAppend()) {
284 trailingInTys.push_back(argTy);
285 trailingOpers.push_back(unbox.getResult(idx));
286 } else {
287 newInTys.push_back(argTy);
288 newOpers.push_back(unbox.getResult(idx));
292 .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
293 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
295 .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
296 rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers);
298 .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
299 if (fir::isCharacterProcedureTuple(tuple)) {
300 mlir::ModuleOp module = getModule();
301 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
302 if (callOp.getCallee()) {
303 llvm::StringRef charProcAttr =
304 fir::getCharacterProcedureDummyAttrName();
305 // The charProcAttr attribute is only used as a safety to
306 // confirm that this is a dummy procedure and should be split.
307 // It cannot be used to match because attributes are not
308 // available in case of indirect calls.
309 auto funcOp = module.lookupSymbol<mlir::func::FuncOp>(
310 *callOp.getCallee());
311 if (funcOp &&
312 !funcOp.template getArgAttrOfType<mlir::UnitAttr>(
313 index, charProcAttr))
314 mlir::emitError(loc, "tuple argument will be split even "
315 "though it does not have the `" +
316 charProcAttr + "` attribute");
319 mlir::Type funcPointerType = tuple.getType(0);
320 mlir::Type lenType = tuple.getType(1);
321 fir::KindMapping kindMap = fir::getKindMapping(module);
322 fir::FirOpBuilder builder(*rewriter, kindMap);
323 auto [funcPointer, len] =
324 fir::factory::extractCharacterProcedureTuple(builder, loc,
325 oper);
326 newInTys.push_back(funcPointerType);
327 newOpers.push_back(funcPointer);
328 trailingInTys.push_back(lenType);
329 trailingOpers.push_back(len);
330 } else {
331 newInTys.push_back(tuple);
332 newOpers.push_back(oper);
335 .Default([&](mlir::Type ty) {
336 if constexpr (std::is_same_v<std::decay_t<A>, fir::DispatchOp>) {
337 if (callOp.getPassArgPos() && *callOp.getPassArgPos() == index)
338 passArgShift = newOpers.size() - *callOp.getPassArgPos();
340 newInTys.push_back(ty);
341 newOpers.push_back(oper);
344 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
345 newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());
346 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) {
347 fir::CallOp newCall;
348 if (callOp.getCallee()) {
349 newCall =
350 rewriter->create<A>(loc, *callOp.getCallee(), newResTys, newOpers);
351 } else {
352 // Force new type on the input operand.
353 newOpers[0].setType(mlir::FunctionType::get(
354 callOp.getContext(),
355 mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys));
356 newCall = rewriter->create<A>(loc, newResTys, newOpers);
358 LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n');
359 if (wrap)
360 replaceOp(callOp, (*wrap)(newCall.getOperation()));
361 else
362 replaceOp(callOp, newCall.getResults());
363 } else {
364 fir::DispatchOp dispatchOp = rewriter->create<A>(
365 loc, newResTys, rewriter->getStringAttr(callOp.getMethod()),
366 callOp.getOperands()[0], newOpers,
367 rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift));
368 if (wrap)
369 replaceOp(callOp, (*wrap)(dispatchOp.getOperation()));
370 else
371 replaceOp(callOp, dispatchOp.getResults());
375 // Result type fixup for fir::ComplexType and mlir::ComplexType
376 template <typename A, typename B>
377 void lowerComplexSignatureRes(mlir::Location loc, A cmplx, B &newResTys,
378 B &newInTys) {
379 if (noComplexConversion) {
380 newResTys.push_back(cmplx);
381 } else {
382 for (auto &tup :
383 specifics->complexReturnType(loc, cmplx.getElementType())) {
384 auto argTy = std::get<mlir::Type>(tup);
385 if (std::get<fir::CodeGenSpecifics::Attributes>(tup).isSRet())
386 newInTys.push_back(argTy);
387 else
388 newResTys.push_back(argTy);
393 // Argument type fixup for fir::ComplexType and mlir::ComplexType
394 template <typename A, typename B>
395 void lowerComplexSignatureArg(mlir::Location loc, A cmplx, B &newInTys) {
396 if (noComplexConversion)
397 newInTys.push_back(cmplx);
398 else
399 for (auto &tup :
400 specifics->complexArgumentType(loc, cmplx.getElementType()))
401 newInTys.push_back(std::get<mlir::Type>(tup));
404 /// Taking the address of a function. Modify the signature as needed.
405 void convertAddrOp(fir::AddrOfOp addrOp) {
406 rewriter->setInsertionPoint(addrOp);
407 auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
408 llvm::SmallVector<mlir::Type> newResTys;
409 llvm::SmallVector<mlir::Type> newInTys;
410 auto loc = addrOp.getLoc();
411 for (mlir::Type ty : addrTy.getResults()) {
412 llvm::TypeSwitch<mlir::Type>(ty)
413 .Case<fir::ComplexType>([&](fir::ComplexType ty) {
414 lowerComplexSignatureRes(loc, ty, newResTys, newInTys);
416 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
417 lowerComplexSignatureRes(loc, ty, newResTys, newInTys);
419 .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
421 llvm::SmallVector<mlir::Type> trailingInTys;
422 for (mlir::Type ty : addrTy.getInputs()) {
423 llvm::TypeSwitch<mlir::Type>(ty)
424 .Case<fir::BoxCharType>([&](auto box) {
425 if (noCharacterConversion) {
426 newInTys.push_back(box);
427 } else {
428 for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
429 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
430 auto argTy = std::get<mlir::Type>(tup);
431 llvm::SmallVector<mlir::Type> &vec =
432 attr.isAppend() ? trailingInTys : newInTys;
433 vec.push_back(argTy);
437 .Case<fir::ComplexType>([&](fir::ComplexType ty) {
438 lowerComplexSignatureArg(loc, ty, newInTys);
440 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
441 lowerComplexSignatureArg(loc, ty, newInTys);
443 .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
444 if (fir::isCharacterProcedureTuple(tuple)) {
445 newInTys.push_back(tuple.getType(0));
446 trailingInTys.push_back(tuple.getType(1));
447 } else {
448 newInTys.push_back(ty);
451 .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
453 // append trailing input types
454 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end());
455 // replace this op with a new one with the updated signature
456 auto newTy = rewriter->getFunctionType(newInTys, newResTys);
457 auto newOp = rewriter->create<fir::AddrOfOp>(addrOp.getLoc(), newTy,
458 addrOp.getSymbol());
459 replaceOp(addrOp, newOp.getResult());
462 /// Convert the type signatures on all the functions present in the module.
463 /// As the type signature is being changed, this must also update the
464 /// function itself to use any new arguments, etc.
465 mlir::LogicalResult convertTypes(mlir::ModuleOp mod) {
466 for (auto fn : mod.getOps<mlir::func::FuncOp>())
467 convertSignature(fn);
468 return mlir::success();
471 // Returns true if the function should be interoperable with C.
472 static bool isFuncWithCCallingConvention(mlir::Operation *op) {
473 auto funcOp = mlir::dyn_cast<mlir::func::FuncOp>(op);
474 if (!funcOp)
475 return false;
476 return op->hasAttrOfType<mlir::UnitAttr>(
477 fir::FIROpsDialect::getFirRuntimeAttrName()) ||
478 op->hasAttrOfType<mlir::StringAttr>(fir::getSymbolAttrName());
481 /// If the signature does not need any special target-specific conversions,
482 /// then it is considered portable for any target, and this function will
483 /// return `true`. Otherwise, the signature is not portable and `false` is
484 /// returned.
485 bool hasPortableSignature(mlir::Type signature, mlir::Operation *op) {
486 assert(signature.isa<mlir::FunctionType>());
487 auto func = signature.dyn_cast<mlir::FunctionType>();
488 bool hasCCallingConv = isFuncWithCCallingConvention(op);
489 for (auto ty : func.getResults())
490 if ((ty.isa<fir::BoxCharType>() && !noCharacterConversion) ||
491 (fir::isa_complex(ty) && !noComplexConversion) ||
492 (ty.isa<mlir::IntegerType>() && hasCCallingConv)) {
493 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
494 return false;
496 for (auto ty : func.getInputs())
497 if (((ty.isa<fir::BoxCharType>() || fir::isCharacterProcedureTuple(ty)) &&
498 !noCharacterConversion) ||
499 (fir::isa_complex(ty) && !noComplexConversion) ||
500 (ty.isa<mlir::IntegerType>() && hasCCallingConv)) {
501 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n");
502 return false;
504 return true;
507 /// Determine if the signature has host associations. The host association
508 /// argument may need special target specific rewriting.
509 static bool hasHostAssociations(mlir::func::FuncOp func) {
510 std::size_t end = func.getFunctionType().getInputs().size();
511 for (std::size_t i = 0; i < end; ++i)
512 if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName()))
513 return true;
514 return false;
517 /// Rewrite the signatures and body of the `FuncOp`s in the module for
518 /// the immediately subsequent target code gen.
519 void convertSignature(mlir::func::FuncOp func) {
520 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
521 if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
522 return;
523 llvm::SmallVector<mlir::Type> newResTys;
524 llvm::SmallVector<mlir::Type> newInTys;
525 llvm::SmallVector<std::pair<unsigned, mlir::NamedAttribute>> savedAttrs;
526 llvm::SmallVector<std::pair<unsigned, mlir::NamedAttribute>> extraAttrs;
527 llvm::SmallVector<FixupTy> fixups;
528 llvm::SmallVector<std::pair<unsigned, mlir::NamedAttrList>, 1> resultAttrs;
530 // Save argument attributes in case there is a shift so we can replace them
531 // correctly.
532 for (auto e : llvm::enumerate(funcTy.getInputs())) {
533 unsigned index = e.index();
534 llvm::ArrayRef<mlir::NamedAttribute> attrs =
535 mlir::function_interface_impl::getArgAttrs(func, index);
536 for (mlir::NamedAttribute attr : attrs) {
537 savedAttrs.push_back({index, attr});
541 // Convert return value(s)
542 for (auto ty : funcTy.getResults())
543 llvm::TypeSwitch<mlir::Type>(ty)
544 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
545 if (noComplexConversion)
546 newResTys.push_back(cmplx);
547 else
548 doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
550 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
551 if (noComplexConversion)
552 newResTys.push_back(cmplx);
553 else
554 doComplexReturn(func, cmplx, newResTys, newInTys, fixups);
556 .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
557 auto m = specifics->integerArgumentType(func.getLoc(), intTy);
558 assert(m.size() == 1);
559 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
560 auto retTy = std::get<mlir::Type>(m[0]);
561 std::size_t resId = newResTys.size();
562 llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName();
563 if (!extensionAttrName.empty() &&
564 isFuncWithCCallingConvention(func))
565 resultAttrs.emplace_back(
566 resId, rewriter->getNamedAttr(extensionAttrName,
567 rewriter->getUnitAttr()));
568 newResTys.push_back(retTy);
570 .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
572 // Saved potential shift in argument. Handling of result can add arguments
573 // at the beginning of the function signature.
574 unsigned argumentShift = newInTys.size();
576 // Convert arguments
577 llvm::SmallVector<mlir::Type> trailingTys;
578 for (auto e : llvm::enumerate(funcTy.getInputs())) {
579 auto ty = e.value();
580 unsigned index = e.index();
581 llvm::TypeSwitch<mlir::Type>(ty)
582 .Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
583 if (noCharacterConversion) {
584 newInTys.push_back(boxTy);
585 } else {
586 // Convert a CHARACTER argument type. This can involve separating
587 // the pointer and the LEN into two arguments and moving the LEN
588 // argument to the end of the arg list.
589 bool sret = functionArgIsSRet(index, func);
590 for (auto e : llvm::enumerate(specifics->boxcharArgumentType(
591 boxTy.getEleTy(), sret))) {
592 auto &tup = e.value();
593 auto index = e.index();
594 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
595 auto argTy = std::get<mlir::Type>(tup);
596 if (attr.isAppend()) {
597 trailingTys.push_back(argTy);
598 } else {
599 if (sret) {
600 fixups.emplace_back(FixupTy::Codes::CharPair,
601 newInTys.size(), index);
602 } else {
603 fixups.emplace_back(FixupTy::Codes::Trailing,
604 newInTys.size(), trailingTys.size());
606 newInTys.push_back(argTy);
611 .Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
612 if (noComplexConversion)
613 newInTys.push_back(cmplx);
614 else
615 doComplexArg(func, cmplx, newInTys, fixups);
617 .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
618 if (noComplexConversion)
619 newInTys.push_back(cmplx);
620 else
621 doComplexArg(func, cmplx, newInTys, fixups);
623 .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
624 if (fir::isCharacterProcedureTuple(tuple)) {
625 fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
626 newInTys.size(), trailingTys.size());
627 newInTys.push_back(tuple.getType(0));
628 trailingTys.push_back(tuple.getType(1));
629 } else {
630 newInTys.push_back(ty);
633 .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
634 auto m = specifics->integerArgumentType(func.getLoc(), intTy);
635 assert(m.size() == 1);
636 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
637 auto argTy = std::get<mlir::Type>(m[0]);
638 auto argNo = newInTys.size();
639 llvm::StringRef extensionAttrName = attr.getIntExtensionAttrName();
640 if (!extensionAttrName.empty() &&
641 isFuncWithCCallingConvention(func))
642 fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
643 [=](mlir::func::FuncOp func) {
644 func.setArgAttr(
645 argNo, extensionAttrName,
646 mlir::UnitAttr::get(func.getContext()));
649 newInTys.push_back(argTy);
651 .Default([&](mlir::Type ty) { newInTys.push_back(ty); });
653 if (func.getArgAttrOfType<mlir::UnitAttr>(index,
654 fir::getHostAssocAttrName())) {
655 extraAttrs.push_back(
656 {newInTys.size() - 1,
657 rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
661 if (!func.empty()) {
662 // If the function has a body, then apply the fixups to the arguments and
663 // return ops as required. These fixups are done in place.
664 auto loc = func.getLoc();
665 const auto fixupSize = fixups.size();
666 const auto oldArgTys = func.getFunctionType().getInputs();
667 int offset = 0;
668 for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) {
669 const auto &fixup = fixups[i];
670 switch (fixup.code) {
671 case FixupTy::Codes::ArgumentAsLoad: {
672 // Argument was pass-by-value, but is now pass-by-reference and
673 // possibly with a different element type.
674 auto newArg = func.front().insertArgument(fixup.index,
675 newInTys[fixup.index], loc);
676 rewriter->setInsertionPointToStart(&func.front());
677 auto oldArgTy =
678 fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
679 auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, newArg);
680 auto load = rewriter->create<fir::LoadOp>(loc, cast);
681 func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
682 func.front().eraseArgument(fixup.index + 1);
683 } break;
684 case FixupTy::Codes::ArgumentType: {
685 // Argument is pass-by-value, but its type has likely been modified to
686 // suit the target ABI convention.
687 auto oldArgTy =
688 fir::ReferenceType::get(oldArgTys[fixup.index - offset]);
689 // If type did not change, keep the original argument.
690 if (newInTys[fixup.index] == oldArgTy)
691 break;
693 auto newArg = func.front().insertArgument(fixup.index,
694 newInTys[fixup.index], loc);
695 rewriter->setInsertionPointToStart(&func.front());
696 auto mem =
697 rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]);
698 rewriter->create<fir::StoreOp>(loc, newArg, mem);
699 auto cast = rewriter->create<fir::ConvertOp>(loc, oldArgTy, mem);
700 mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast);
701 func.getArgument(fixup.index + 1).replaceAllUsesWith(load);
702 func.front().eraseArgument(fixup.index + 1);
703 LLVM_DEBUG(llvm::dbgs()
704 << "old argument: " << oldArgTy.getEleTy()
705 << ", repl: " << load << ", new argument: "
706 << func.getArgument(fixup.index).getType() << '\n');
707 } break;
708 case FixupTy::Codes::CharPair: {
709 // The FIR boxchar argument has been split into a pair of distinct
710 // arguments that are in juxtaposition to each other.
711 auto newArg = func.front().insertArgument(fixup.index,
712 newInTys[fixup.index], loc);
713 if (fixup.second == 1) {
714 rewriter->setInsertionPointToStart(&func.front());
715 auto boxTy = oldArgTys[fixup.index - offset - fixup.second];
716 auto box = rewriter->create<fir::EmboxCharOp>(
717 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg);
718 func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
719 func.front().eraseArgument(fixup.index + 1);
720 offset++;
722 } break;
723 case FixupTy::Codes::ReturnAsStore: {
724 // The value being returned is now being returned in memory (callee
725 // stack space) through a hidden reference argument.
726 auto newArg = func.front().insertArgument(fixup.index,
727 newInTys[fixup.index], loc);
728 offset++;
729 func.walk([&](mlir::func::ReturnOp ret) {
730 rewriter->setInsertionPoint(ret);
731 auto oldOper = ret.getOperand(0);
732 auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
733 auto cast =
734 rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
735 rewriter->create<fir::StoreOp>(loc, oldOper, cast);
736 rewriter->create<mlir::func::ReturnOp>(loc);
737 ret.erase();
739 } break;
740 case FixupTy::Codes::ReturnType: {
741 // The function is still returning a value, but its type has likely
742 // changed to suit the target ABI convention.
743 func.walk([&](mlir::func::ReturnOp ret) {
744 rewriter->setInsertionPoint(ret);
745 auto oldOper = ret.getOperand(0);
746 auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
747 auto mem =
748 rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]);
749 auto cast = rewriter->create<fir::ConvertOp>(loc, oldOperTy, mem);
750 rewriter->create<fir::StoreOp>(loc, oldOper, cast);
751 mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem);
752 rewriter->create<mlir::func::ReturnOp>(loc, load);
753 ret.erase();
755 } break;
756 case FixupTy::Codes::Split: {
757 // The FIR argument has been split into a pair of distinct arguments
758 // that are in juxtaposition to each other. (For COMPLEX value.)
759 auto newArg = func.front().insertArgument(fixup.index,
760 newInTys[fixup.index], loc);
761 if (fixup.second == 1) {
762 rewriter->setInsertionPointToStart(&func.front());
763 auto cplxTy = oldArgTys[fixup.index - offset - fixup.second];
764 auto undef = rewriter->create<fir::UndefOp>(loc, cplxTy);
765 auto iTy = rewriter->getIntegerType(32);
766 auto zero = rewriter->getIntegerAttr(iTy, 0);
767 auto one = rewriter->getIntegerAttr(iTy, 1);
768 auto cplx1 = rewriter->create<fir::InsertValueOp>(
769 loc, cplxTy, undef, func.front().getArgument(fixup.index - 1),
770 rewriter->getArrayAttr(zero));
771 auto cplx = rewriter->create<fir::InsertValueOp>(
772 loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one));
773 func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx);
774 func.front().eraseArgument(fixup.index + 1);
775 offset++;
777 } break;
778 case FixupTy::Codes::Trailing: {
779 // The FIR argument has been split into a pair of distinct arguments.
780 // The first part of the pair appears in the original argument
781 // position. The second part of the pair is appended after all the
782 // original arguments. (Boxchar arguments.)
783 auto newBufArg = func.front().insertArgument(
784 fixup.index, newInTys[fixup.index], loc);
785 auto newLenArg =
786 func.front().addArgument(trailingTys[fixup.second], loc);
787 auto boxTy = oldArgTys[fixup.index - offset];
788 rewriter->setInsertionPointToStart(&func.front());
789 auto box = rewriter->create<fir::EmboxCharOp>(loc, boxTy, newBufArg,
790 newLenArg);
791 func.getArgument(fixup.index + 1).replaceAllUsesWith(box);
792 func.front().eraseArgument(fixup.index + 1);
793 } break;
794 case FixupTy::Codes::TrailingCharProc: {
795 // The FIR character procedure argument tuple must be split into a
796 // pair of distinct arguments. The first part of the pair appears in
797 // the original argument position. The second part of the pair is
798 // appended after all the original arguments.
799 auto newProcPointerArg = func.front().insertArgument(
800 fixup.index, newInTys[fixup.index], loc);
801 auto newLenArg =
802 func.front().addArgument(trailingTys[fixup.second], loc);
803 auto tupleType = oldArgTys[fixup.index - offset];
804 rewriter->setInsertionPointToStart(&func.front());
805 fir::KindMapping kindMap = fir::getKindMapping(getModule());
806 fir::FirOpBuilder builder(*rewriter, kindMap);
807 auto tuple = fir::factory::createCharacterProcedureTuple(
808 builder, loc, tupleType, newProcPointerArg, newLenArg);
809 func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple);
810 func.front().eraseArgument(fixup.index + 1);
811 } break;
816 // Set the new type and finalize the arguments, etc.
817 newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end());
818 auto newFuncTy =
819 mlir::FunctionType::get(func.getContext(), newInTys, newResTys);
820 LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n');
821 func.setType(newFuncTy);
823 for (std::pair<unsigned, mlir::NamedAttribute> extraAttr : extraAttrs)
824 func.setArgAttr(extraAttr.first, extraAttr.second.getName(),
825 extraAttr.second.getValue());
827 for (auto [resId, resAttrList] : resultAttrs)
828 for (mlir::NamedAttribute resAttr : resAttrList)
829 func.setResultAttr(resId, resAttr.getName(), resAttr.getValue());
831 // Replace attributes to the correct argument if there was an argument shift
832 // to the right.
833 if (argumentShift > 0) {
834 for (std::pair<unsigned, mlir::NamedAttribute> savedAttr : savedAttrs) {
835 func.removeArgAttr(savedAttr.first, savedAttr.second.getName());
836 func.setArgAttr(savedAttr.first + argumentShift,
837 savedAttr.second.getName(),
838 savedAttr.second.getValue());
842 for (auto &fixup : fixups)
843 if (fixup.finalizer)
844 (*fixup.finalizer)(func);
847 inline bool functionArgIsSRet(unsigned index, mlir::func::FuncOp func) {
848 if (auto attr = func.getArgAttrOfType<mlir::TypeAttr>(index, "llvm.sret"))
849 return true;
850 return false;
853 /// Convert a complex return value. This can involve converting the return
854 /// value to a "hidden" first argument or packing the complex into a wide
855 /// GPR.
856 template <typename A, typename B, typename C>
857 void doComplexReturn(mlir::func::FuncOp func, A cmplx, B &newResTys,
858 B &newInTys, C &fixups) {
859 if (noComplexConversion) {
860 newResTys.push_back(cmplx);
861 return;
863 auto m =
864 specifics->complexReturnType(func.getLoc(), cmplx.getElementType());
865 assert(m.size() == 1);
866 auto &tup = m[0];
867 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
868 auto argTy = std::get<mlir::Type>(tup);
869 if (attr.isSRet()) {
870 unsigned argNo = newInTys.size();
871 if (auto align = attr.getAlignment())
872 fixups.emplace_back(
873 FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
874 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
875 func.getFunctionType().getInput(argNo));
876 func.setArgAttr(argNo, "llvm.sret",
877 mlir::TypeAttr::get(elemType));
878 func.setArgAttr(argNo, "llvm.align",
879 rewriter->getIntegerAttr(
880 rewriter->getIntegerType(32), align));
882 else
883 fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
884 [=](mlir::func::FuncOp func) {
885 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
886 func.getFunctionType().getInput(argNo));
887 func.setArgAttr(argNo, "llvm.sret",
888 mlir::TypeAttr::get(elemType));
890 newInTys.push_back(argTy);
891 return;
892 } else {
893 if (auto align = attr.getAlignment())
894 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size(),
895 [=](mlir::func::FuncOp func) {
896 func.setArgAttr(
897 newResTys.size(), "llvm.align",
898 rewriter->getIntegerAttr(
899 rewriter->getIntegerType(32), align));
901 else
902 fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size());
904 newResTys.push_back(argTy);
907 /// Convert a complex argument value. This can involve storing the value to
908 /// a temporary memory location or factoring the value into two distinct
909 /// arguments.
910 template <typename A, typename B, typename C>
911 void doComplexArg(mlir::func::FuncOp func, A cmplx, B &newInTys, C &fixups) {
912 if (noComplexConversion) {
913 newInTys.push_back(cmplx);
914 return;
916 auto m =
917 specifics->complexArgumentType(func.getLoc(), cmplx.getElementType());
918 const auto fixupCode =
919 m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType;
920 for (auto e : llvm::enumerate(m)) {
921 auto &tup = e.value();
922 auto index = e.index();
923 auto attr = std::get<fir::CodeGenSpecifics::Attributes>(tup);
924 auto argTy = std::get<mlir::Type>(tup);
925 auto argNo = newInTys.size();
926 if (attr.isByVal()) {
927 if (auto align = attr.getAlignment())
928 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo,
929 [=](mlir::func::FuncOp func) {
930 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
931 func.getFunctionType().getInput(argNo));
932 func.setArgAttr(argNo, "llvm.byval",
933 mlir::TypeAttr::get(elemType));
934 func.setArgAttr(
935 argNo, "llvm.align",
936 rewriter->getIntegerAttr(
937 rewriter->getIntegerType(32), align));
939 else
940 fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(),
941 [=](mlir::func::FuncOp func) {
942 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
943 func.getFunctionType().getInput(argNo));
944 func.setArgAttr(argNo, "llvm.byval",
945 mlir::TypeAttr::get(elemType));
947 } else {
948 if (auto align = attr.getAlignment())
949 fixups.emplace_back(
950 fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
951 func.setArgAttr(argNo, "llvm.align",
952 rewriter->getIntegerAttr(
953 rewriter->getIntegerType(32), align));
955 else
956 fixups.emplace_back(fixupCode, argNo, index);
958 newInTys.push_back(argTy);
962 private:
963 // Replace `op` and remove it.
964 void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) {
965 op->replaceAllUsesWith(newValues);
966 op->dropAllReferences();
967 op->erase();
970 inline void setMembers(fir::CodeGenSpecifics *s, mlir::OpBuilder *r) {
971 specifics = s;
972 rewriter = r;
975 inline void clearMembers() { setMembers(nullptr, nullptr); }
977 fir::CodeGenSpecifics *specifics = nullptr;
978 mlir::OpBuilder *rewriter = nullptr;
979 }; // namespace
980 } // namespace
982 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
983 fir::createFirTargetRewritePass(const fir::TargetRewriteOptions &options) {
984 return std::make_unique<TargetRewrite>(options);