1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16 #include "mlir/Analysis/DataLayoutAnalysis.h"
17 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
18 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
19 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
20 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
21 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
22 #include "mlir/Conversion/LLVMCommon/Pattern.h"
23 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"
25 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
26 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
27 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
28 #include "mlir/Dialect/Utils/StaticValueUtils.h"
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinAttributeInterfaces.h"
32 #include "mlir/IR/BuiltinAttributes.h"
33 #include "mlir/IR/BuiltinOps.h"
34 #include "mlir/IR/IRMapping.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/SymbolTable.h"
37 #include "mlir/IR/TypeUtilities.h"
38 #include "mlir/Transforms/DialectConversion.h"
39 #include "mlir/Transforms/Passes.h"
40 #include "llvm/ADT/SmallVector.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/IR/DerivedTypes.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/Type.h"
45 #include "llvm/Support/Casting.h"
46 #include "llvm/Support/CommandLine.h"
47 #include "llvm/Support/FormatVariadic.h"
53 #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
54 #define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
55 #include "mlir/Conversion/Passes.h.inc"
60 #define PASS_NAME "convert-func-to-llvm"
62 static constexpr StringRef varargsAttrName
= "func.varargs";
63 static constexpr StringRef linkageAttrName
= "llvm.linkage";
64 static constexpr StringRef barePtrAttrName
= "llvm.bareptr";
66 /// Return `true` if the `op` should use bare pointer calling convention.
67 static bool shouldUseBarePtrCallConv(Operation
*op
,
68 const LLVMTypeConverter
*typeConverter
) {
69 return (op
&& op
->hasAttr(barePtrAttrName
)) ||
70 typeConverter
->getOptions().useBarePtrCallConv
;
73 /// Only retain those attributes that are not constructed by
74 /// `LLVMFuncOp::build`.
75 static void filterFuncAttributes(FunctionOpInterface func
,
76 SmallVectorImpl
<NamedAttribute
> &result
) {
77 for (const NamedAttribute
&attr
: func
->getDiscardableAttrs()) {
78 if (attr
.getName() == linkageAttrName
||
79 attr
.getName() == varargsAttrName
||
80 attr
.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
82 result
.push_back(attr
);
86 /// Propagate argument/results attributes.
87 static void propagateArgResAttrs(OpBuilder
&builder
, bool resultStructType
,
88 FunctionOpInterface funcOp
,
89 LLVM::LLVMFuncOp wrapperFuncOp
) {
90 auto argAttrs
= funcOp
.getAllArgAttrs();
91 if (!resultStructType
) {
92 if (auto resAttrs
= funcOp
.getAllResultAttrs())
93 wrapperFuncOp
.setAllResultAttrs(resAttrs
);
95 wrapperFuncOp
.setAllArgAttrs(argAttrs
);
97 SmallVector
<Attribute
> argAttributes
;
98 // Only modify the argument and result attributes when the result is now
101 argAttributes
.push_back(builder
.getDictionaryAttr({}));
102 argAttributes
.append(argAttrs
.begin(), argAttrs
.end());
103 wrapperFuncOp
.setAllArgAttrs(argAttributes
);
106 cast
<FunctionOpInterface
>(wrapperFuncOp
.getOperation())
107 .setVisibility(funcOp
.getVisibility());
110 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
111 /// arguments instead of unpacked arguments. This function can be called from C
112 /// by passing a pointer to a C struct corresponding to a memref descriptor.
113 /// Similarly, returned memrefs are passed via pointers to a C struct that is
114 /// passed as additional argument.
115 /// Internally, the auxiliary function unpacks the descriptor into individual
116 /// components and forwards them to `newFuncOp` and forwards the results to
117 /// the extra arguments.
118 static void wrapForExternalCallers(OpBuilder
&rewriter
, Location loc
,
119 const LLVMTypeConverter
&typeConverter
,
120 FunctionOpInterface funcOp
,
121 LLVM::LLVMFuncOp newFuncOp
) {
122 auto type
= cast
<FunctionType
>(funcOp
.getFunctionType());
123 auto [wrapperFuncType
, resultStructType
] =
124 typeConverter
.convertFunctionTypeCWrapper(type
);
126 SmallVector
<NamedAttribute
> attributes
;
127 filterFuncAttributes(funcOp
, attributes
);
129 auto wrapperFuncOp
= rewriter
.create
<LLVM::LLVMFuncOp
>(
130 loc
, llvm::formatv("_mlir_ciface_{0}", funcOp
.getName()).str(),
131 wrapperFuncType
, LLVM::Linkage::External
, /*dsoLocal=*/false,
132 /*cconv=*/LLVM::CConv::C
, /*comdat=*/nullptr, attributes
);
133 propagateArgResAttrs(rewriter
, !!resultStructType
, funcOp
, wrapperFuncOp
);
135 OpBuilder::InsertionGuard
guard(rewriter
);
136 rewriter
.setInsertionPointToStart(wrapperFuncOp
.addEntryBlock(rewriter
));
138 SmallVector
<Value
, 8> args
;
139 size_t argOffset
= resultStructType
? 1 : 0;
140 for (auto [index
, argType
] : llvm::enumerate(type
.getInputs())) {
141 Value arg
= wrapperFuncOp
.getArgument(index
+ argOffset
);
142 if (auto memrefType
= dyn_cast
<MemRefType
>(argType
)) {
143 Value loaded
= rewriter
.create
<LLVM::LoadOp
>(
144 loc
, typeConverter
.convertType(memrefType
), arg
);
145 MemRefDescriptor::unpack(rewriter
, loc
, loaded
, memrefType
, args
);
148 if (isa
<UnrankedMemRefType
>(argType
)) {
149 Value loaded
= rewriter
.create
<LLVM::LoadOp
>(
150 loc
, typeConverter
.convertType(argType
), arg
);
151 UnrankedMemRefDescriptor::unpack(rewriter
, loc
, loaded
, args
);
158 auto call
= rewriter
.create
<LLVM::CallOp
>(loc
, newFuncOp
, args
);
160 if (resultStructType
) {
161 rewriter
.create
<LLVM::StoreOp
>(loc
, call
.getResult(),
162 wrapperFuncOp
.getArgument(0));
163 rewriter
.create
<LLVM::ReturnOp
>(loc
, ValueRange
{});
165 rewriter
.create
<LLVM::ReturnOp
>(loc
, call
.getResults());
169 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
170 /// arguments instead of unpacked arguments. Creates a body for the (external)
171 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
172 /// individual arguments into this descriptor and passes a pointer to it into
173 /// the auxiliary function. If the result of the function cannot be directly
174 /// returned, we write it to a special first argument that provides a pointer
175 /// to a corresponding struct. This auxiliary external function is now
176 /// compatible with functions defined in C using pointers to C structs
177 /// corresponding to a memref descriptor.
178 static void wrapExternalFunction(OpBuilder
&builder
, Location loc
,
179 const LLVMTypeConverter
&typeConverter
,
180 FunctionOpInterface funcOp
,
181 LLVM::LLVMFuncOp newFuncOp
) {
182 OpBuilder::InsertionGuard
guard(builder
);
184 auto [wrapperType
, resultStructType
] =
185 typeConverter
.convertFunctionTypeCWrapper(
186 cast
<FunctionType
>(funcOp
.getFunctionType()));
187 // This conversion can only fail if it could not convert one of the argument
188 // types. But since it has been applied to a non-wrapper function before, it
189 // should have failed earlier and not reach this point at all.
190 assert(wrapperType
&& "unexpected type conversion failure");
192 SmallVector
<NamedAttribute
, 4> attributes
;
193 filterFuncAttributes(funcOp
, attributes
);
195 // Create the auxiliary function.
196 auto wrapperFunc
= builder
.create
<LLVM::LLVMFuncOp
>(
197 loc
, llvm::formatv("_mlir_ciface_{0}", funcOp
.getName()).str(),
198 wrapperType
, LLVM::Linkage::External
, /*dsoLocal=*/false,
199 /*cconv=*/LLVM::CConv::C
, /*comdat=*/nullptr, attributes
);
200 propagateArgResAttrs(builder
, !!resultStructType
, funcOp
, wrapperFunc
);
202 // The wrapper that we synthetize here should only be visible in this module.
203 newFuncOp
.setLinkage(LLVM::Linkage::Private
);
204 builder
.setInsertionPointToStart(newFuncOp
.addEntryBlock(builder
));
206 // Get a ValueRange containing arguments.
207 FunctionType type
= cast
<FunctionType
>(funcOp
.getFunctionType());
208 SmallVector
<Value
, 8> args
;
209 args
.reserve(type
.getNumInputs());
210 ValueRange
wrapperArgsRange(newFuncOp
.getArguments());
212 if (resultStructType
) {
213 // Allocate the struct on the stack and pass the pointer.
214 Type resultType
= cast
<LLVM::LLVMFunctionType
>(wrapperType
).getParamType(0);
215 Value one
= builder
.create
<LLVM::ConstantOp
>(
216 loc
, typeConverter
.convertType(builder
.getIndexType()),
217 builder
.getIntegerAttr(builder
.getIndexType(), 1));
219 builder
.create
<LLVM::AllocaOp
>(loc
, resultType
, resultStructType
, one
);
220 args
.push_back(result
);
223 // Iterate over the inputs of the original function and pack values into
224 // memref descriptors if the original type is a memref.
225 for (Type input
: type
.getInputs()) {
228 auto memRefType
= dyn_cast
<MemRefType
>(input
);
229 auto unrankedMemRefType
= dyn_cast
<UnrankedMemRefType
>(input
);
230 if (memRefType
|| unrankedMemRefType
) {
231 numToDrop
= memRefType
232 ? MemRefDescriptor::getNumUnpackedValues(memRefType
)
233 : UnrankedMemRefDescriptor::getNumUnpackedValues();
236 ? MemRefDescriptor::pack(builder
, loc
, typeConverter
, memRefType
,
237 wrapperArgsRange
.take_front(numToDrop
))
238 : UnrankedMemRefDescriptor::pack(
239 builder
, loc
, typeConverter
, unrankedMemRefType
,
240 wrapperArgsRange
.take_front(numToDrop
));
242 auto ptrTy
= LLVM::LLVMPointerType::get(builder
.getContext());
243 Value one
= builder
.create
<LLVM::ConstantOp
>(
244 loc
, typeConverter
.convertType(builder
.getIndexType()),
245 builder
.getIntegerAttr(builder
.getIndexType(), 1));
246 Value allocated
= builder
.create
<LLVM::AllocaOp
>(
247 loc
, ptrTy
, packed
.getType(), one
, /*alignment=*/0);
248 builder
.create
<LLVM::StoreOp
>(loc
, packed
, allocated
);
251 arg
= wrapperArgsRange
[0];
255 wrapperArgsRange
= wrapperArgsRange
.drop_front(numToDrop
);
257 assert(wrapperArgsRange
.empty() && "did not map some of the arguments");
259 auto call
= builder
.create
<LLVM::CallOp
>(loc
, wrapperFunc
, args
);
261 if (resultStructType
) {
263 builder
.create
<LLVM::LoadOp
>(loc
, resultStructType
, args
.front());
264 builder
.create
<LLVM::ReturnOp
>(loc
, result
);
266 builder
.create
<LLVM::ReturnOp
>(loc
, call
.getResults());
270 /// Inserts `llvm.load` ops in the function body to restore the expected pointee
271 /// value from `llvm.byval`/`llvm.byref` function arguments that were converted
272 /// to LLVM pointer types.
273 static void restoreByValRefArgumentType(
274 ConversionPatternRewriter
&rewriter
, const LLVMTypeConverter
&typeConverter
,
275 ArrayRef
<std::optional
<NamedAttribute
>> byValRefNonPtrAttrs
,
276 ArrayRef
<BlockArgument
> oldBlockArgs
, LLVM::LLVMFuncOp funcOp
) {
277 // Nothing to do for function declarations.
278 if (funcOp
.isExternal())
281 ConversionPatternRewriter::InsertionGuard
guard(rewriter
);
282 rewriter
.setInsertionPointToStart(&funcOp
.getFunctionBody().front());
284 for (const auto &[arg
, oldArg
, byValRefAttr
] :
285 llvm::zip(funcOp
.getArguments(), oldBlockArgs
, byValRefNonPtrAttrs
)) {
286 // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
290 // Insert load to retrieve the actual argument passed by value/reference.
291 assert(isa
<LLVM::LLVMPointerType
>(arg
.getType()) &&
292 "Expected LLVM pointer type for argument with "
293 "`llvm.byval`/`llvm.byref` attribute");
294 Type resTy
= typeConverter
.convertType(
295 cast
<TypeAttr
>(byValRefAttr
->getValue()).getValue());
297 auto valueArg
= rewriter
.create
<LLVM::LoadOp
>(arg
.getLoc(), resTy
, arg
);
298 rewriter
.replaceUsesOfBlockArgument(oldArg
, valueArg
);
302 FailureOr
<LLVM::LLVMFuncOp
>
303 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp
,
304 ConversionPatternRewriter
&rewriter
,
305 const LLVMTypeConverter
&converter
) {
306 // Check the funcOp has `FunctionType`.
307 auto funcTy
= dyn_cast
<FunctionType
>(funcOp
.getFunctionType());
309 return rewriter
.notifyMatchFailure(
310 funcOp
, "Only support FunctionOpInterface with FunctionType");
312 // Keep track of the entry block arguments. They will be needed later.
313 SmallVector
<BlockArgument
> oldBlockArgs
=
314 llvm::to_vector(funcOp
.getArguments());
316 // Convert the original function arguments. They are converted using the
317 // LLVMTypeConverter provided to this legalization pattern.
318 auto varargsAttr
= funcOp
->getAttrOfType
<BoolAttr
>(varargsAttrName
);
319 // Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was
320 // overriden with an LLVM pointer type for later processing.
321 SmallVector
<std::optional
<NamedAttribute
>> byValRefNonPtrAttrs
;
322 TypeConverter::SignatureConversion
result(funcOp
.getNumArguments());
323 auto llvmType
= converter
.convertFunctionSignature(
324 funcOp
, varargsAttr
&& varargsAttr
.getValue(),
325 shouldUseBarePtrCallConv(funcOp
, &converter
), result
,
326 byValRefNonPtrAttrs
);
328 return rewriter
.notifyMatchFailure(funcOp
, "signature conversion failed");
330 // Create an LLVM function, use external linkage by default until MLIR
331 // functions have linkage.
332 LLVM::Linkage linkage
= LLVM::Linkage::External
;
333 if (funcOp
->hasAttr(linkageAttrName
)) {
335 dyn_cast
<mlir::LLVM::LinkageAttr
>(funcOp
->getAttr(linkageAttrName
));
337 funcOp
->emitError() << "Contains " << linkageAttrName
338 << " attribute not of type LLVM::LinkageAttr";
339 return rewriter
.notifyMatchFailure(
340 funcOp
, "Contains linkage attribute not of type LLVM::LinkageAttr");
342 linkage
= attr
.getLinkage();
345 SmallVector
<NamedAttribute
, 4> attributes
;
346 filterFuncAttributes(funcOp
, attributes
);
347 auto newFuncOp
= rewriter
.create
<LLVM::LLVMFuncOp
>(
348 funcOp
.getLoc(), funcOp
.getName(), llvmType
, linkage
,
349 /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C
, /*comdat=*/nullptr,
351 cast
<FunctionOpInterface
>(newFuncOp
.getOperation())
352 .setVisibility(funcOp
.getVisibility());
354 // Create a memory effect attribute corresponding to readnone.
355 StringRef readnoneAttrName
= LLVM::LLVMDialect::getReadnoneAttrName();
356 if (funcOp
->hasAttr(readnoneAttrName
)) {
357 auto attr
= funcOp
->getAttrOfType
<UnitAttr
>(readnoneAttrName
);
359 funcOp
->emitError() << "Contains " << readnoneAttrName
360 << " attribute not of type UnitAttr";
361 return rewriter
.notifyMatchFailure(
362 funcOp
, "Contains readnone attribute not of type UnitAttr");
364 auto memoryAttr
= LLVM::MemoryEffectsAttr::get(
365 rewriter
.getContext(),
366 {LLVM::ModRefInfo::NoModRef
, LLVM::ModRefInfo::NoModRef
,
367 LLVM::ModRefInfo::NoModRef
});
368 newFuncOp
.setMemoryEffectsAttr(memoryAttr
);
371 // Propagate argument/result attributes to all converted arguments/result
372 // obtained after converting a given original argument/result.
373 if (ArrayAttr resAttrDicts
= funcOp
.getAllResultAttrs()) {
374 assert(!resAttrDicts
.empty() && "expected array to be non-empty");
375 if (funcOp
.getNumResults() == 1)
376 newFuncOp
.setAllResultAttrs(resAttrDicts
);
378 if (ArrayAttr argAttrDicts
= funcOp
.getAllArgAttrs()) {
379 SmallVector
<Attribute
> newArgAttrs(
380 cast
<LLVM::LLVMFunctionType
>(llvmType
).getNumParams());
381 for (unsigned i
= 0, e
= funcOp
.getNumArguments(); i
< e
; ++i
) {
382 // Some LLVM IR attribute have a type attached to them. During FuncOp ->
383 // LLVMFuncOp conversion these types may have changed. Account for that
384 // change by converting attributes' types as well.
385 SmallVector
<NamedAttribute
, 4> convertedAttrs
;
386 auto attrsDict
= cast
<DictionaryAttr
>(argAttrDicts
[i
]);
387 convertedAttrs
.reserve(attrsDict
.size());
388 for (const NamedAttribute
&attr
: attrsDict
) {
389 const auto convert
= [&](const NamedAttribute
&attr
) {
390 return TypeAttr::get(converter
.convertType(
391 cast
<TypeAttr
>(attr
.getValue()).getValue()));
393 if (attr
.getName().getValue() ==
394 LLVM::LLVMDialect::getByValAttrName()) {
395 convertedAttrs
.push_back(rewriter
.getNamedAttr(
396 LLVM::LLVMDialect::getByValAttrName(), convert(attr
)));
397 } else if (attr
.getName().getValue() ==
398 LLVM::LLVMDialect::getByRefAttrName()) {
399 convertedAttrs
.push_back(rewriter
.getNamedAttr(
400 LLVM::LLVMDialect::getByRefAttrName(), convert(attr
)));
401 } else if (attr
.getName().getValue() ==
402 LLVM::LLVMDialect::getStructRetAttrName()) {
403 convertedAttrs
.push_back(rewriter
.getNamedAttr(
404 LLVM::LLVMDialect::getStructRetAttrName(), convert(attr
)));
405 } else if (attr
.getName().getValue() ==
406 LLVM::LLVMDialect::getInAllocaAttrName()) {
407 convertedAttrs
.push_back(rewriter
.getNamedAttr(
408 LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr
)));
410 convertedAttrs
.push_back(attr
);
413 auto mapping
= result
.getInputMapping(i
);
414 assert(mapping
&& "unexpected deletion of function argument");
415 // Only attach the new argument attributes if there is a one-to-one
416 // mapping from old to new types. Otherwise, attributes might be
417 // attached to types that they do not support.
418 if (mapping
->size
== 1) {
419 newArgAttrs
[mapping
->inputNo
] =
420 DictionaryAttr::get(rewriter
.getContext(), convertedAttrs
);
423 // TODO: Implement custom handling for types that expand to multiple
424 // function arguments.
425 for (size_t j
= 0; j
< mapping
->size
; ++j
)
426 newArgAttrs
[mapping
->inputNo
+ j
] =
427 DictionaryAttr::get(rewriter
.getContext(), {});
429 if (!newArgAttrs
.empty())
430 newFuncOp
.setAllArgAttrs(rewriter
.getArrayAttr(newArgAttrs
));
433 rewriter
.inlineRegionBefore(funcOp
.getFunctionBody(), newFuncOp
.getBody(),
435 if (failed(rewriter
.convertRegionTypes(&newFuncOp
.getBody(), converter
,
437 return rewriter
.notifyMatchFailure(funcOp
,
438 "region types conversion failed");
441 // Fix the type mismatch between the materialized `llvm.ptr` and the expected
442 // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
443 // function arguments.
444 restoreByValRefArgumentType(rewriter
, converter
, byValRefNonPtrAttrs
,
445 oldBlockArgs
, newFuncOp
);
447 if (!shouldUseBarePtrCallConv(funcOp
, &converter
)) {
448 if (funcOp
->getAttrOfType
<UnitAttr
>(
449 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
450 if (newFuncOp
.isVarArg())
451 return funcOp
.emitError("C interface for variadic functions is not "
454 if (newFuncOp
.isExternal())
455 wrapExternalFunction(rewriter
, funcOp
->getLoc(), converter
, funcOp
,
458 wrapForExternalCallers(rewriter
, funcOp
->getLoc(), converter
, funcOp
,
468 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
469 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
471 struct FuncOpConversion
: public ConvertOpToLLVMPattern
<func::FuncOp
> {
472 FuncOpConversion(const LLVMTypeConverter
&converter
)
473 : ConvertOpToLLVMPattern(converter
) {}
476 matchAndRewrite(func::FuncOp funcOp
, OpAdaptor adaptor
,
477 ConversionPatternRewriter
&rewriter
) const override
{
478 FailureOr
<LLVM::LLVMFuncOp
> newFuncOp
= mlir::convertFuncOpToLLVMFuncOp(
479 cast
<FunctionOpInterface
>(funcOp
.getOperation()), rewriter
,
480 *getTypeConverter());
481 if (failed(newFuncOp
))
482 return rewriter
.notifyMatchFailure(funcOp
, "Could not convert funcop");
484 rewriter
.eraseOp(funcOp
);
489 struct ConstantOpLowering
: public ConvertOpToLLVMPattern
<func::ConstantOp
> {
490 using ConvertOpToLLVMPattern
<func::ConstantOp
>::ConvertOpToLLVMPattern
;
493 matchAndRewrite(func::ConstantOp op
, OpAdaptor adaptor
,
494 ConversionPatternRewriter
&rewriter
) const override
{
495 auto type
= typeConverter
->convertType(op
.getResult().getType());
496 if (!type
|| !LLVM::isCompatibleType(type
))
497 return rewriter
.notifyMatchFailure(op
, "failed to convert result type");
500 rewriter
.create
<LLVM::AddressOfOp
>(op
.getLoc(), type
, op
.getValue());
501 for (const NamedAttribute
&attr
: op
->getAttrs()) {
502 if (attr
.getName().strref() == "value")
504 newOp
->setAttr(attr
.getName(), attr
.getValue());
506 rewriter
.replaceOp(op
, newOp
->getResults());
511 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
512 // passes the pointer to the MemRef across function boundaries.
513 template <typename CallOpType
>
514 struct CallOpInterfaceLowering
: public ConvertOpToLLVMPattern
<CallOpType
> {
515 using ConvertOpToLLVMPattern
<CallOpType
>::ConvertOpToLLVMPattern
;
516 using Super
= CallOpInterfaceLowering
<CallOpType
>;
517 using Base
= ConvertOpToLLVMPattern
<CallOpType
>;
519 LogicalResult
matchAndRewriteImpl(CallOpType callOp
,
520 typename
CallOpType::Adaptor adaptor
,
521 ConversionPatternRewriter
&rewriter
,
522 bool useBarePtrCallConv
= false) const {
523 // Pack the result types into a struct.
524 Type packedResult
= nullptr;
525 unsigned numResults
= callOp
.getNumResults();
526 auto resultTypes
= llvm::to_vector
<4>(callOp
.getResultTypes());
528 if (numResults
!= 0) {
529 if (!(packedResult
= this->getTypeConverter()->packFunctionResults(
530 resultTypes
, useBarePtrCallConv
)))
534 if (useBarePtrCallConv
) {
535 for (auto it
: callOp
->getOperands()) {
536 Type operandType
= it
.getType();
537 if (isa
<UnrankedMemRefType
>(operandType
)) {
538 // Unranked memref is not supported in the bare pointer calling
544 auto promoted
= this->getTypeConverter()->promoteOperands(
545 callOp
.getLoc(), /*opOperands=*/callOp
->getOperands(),
546 adaptor
.getOperands(), rewriter
, useBarePtrCallConv
);
547 auto newOp
= rewriter
.create
<LLVM::CallOp
>(
548 callOp
.getLoc(), packedResult
? TypeRange(packedResult
) : TypeRange(),
549 promoted
, callOp
->getAttrs());
551 newOp
.getProperties().operandSegmentSizes
= {
552 static_cast<int32_t>(promoted
.size()), 0};
553 newOp
.getProperties().op_bundle_sizes
= rewriter
.getDenseI32ArrayAttr({});
555 SmallVector
<Value
, 4> results
;
556 if (numResults
< 2) {
557 // If < 2 results, packing did not do anything and we can just return.
558 results
.append(newOp
.result_begin(), newOp
.result_end());
560 // Otherwise, it had been converted to an operation producing a structure.
561 // Extract individual results from the structure and return them as list.
562 results
.reserve(numResults
);
563 for (unsigned i
= 0; i
< numResults
; ++i
) {
564 results
.push_back(rewriter
.create
<LLVM::ExtractValueOp
>(
565 callOp
.getLoc(), newOp
->getResult(0), i
));
569 if (useBarePtrCallConv
) {
570 // For the bare-ptr calling convention, promote memref results to
572 assert(results
.size() == resultTypes
.size() &&
573 "The number of arguments and types doesn't match");
574 this->getTypeConverter()->promoteBarePtrsToDescriptors(
575 rewriter
, callOp
.getLoc(), resultTypes
, results
);
576 } else if (failed(this->copyUnrankedDescriptors(rewriter
, callOp
.getLoc(),
577 resultTypes
, results
,
578 /*toDynamic=*/false))) {
582 rewriter
.replaceOp(callOp
, results
);
587 class CallOpLowering
: public CallOpInterfaceLowering
<func::CallOp
> {
589 CallOpLowering(const LLVMTypeConverter
&typeConverter
,
591 const SymbolTable
*symbolTable
, PatternBenefit benefit
= 1)
592 : CallOpInterfaceLowering
<func::CallOp
>(typeConverter
, benefit
),
593 symbolTable(symbolTable
) {}
596 matchAndRewrite(func::CallOp callOp
, OpAdaptor adaptor
,
597 ConversionPatternRewriter
&rewriter
) const override
{
598 bool useBarePtrCallConv
= false;
599 if (getTypeConverter()->getOptions().useBarePtrCallConv
) {
600 useBarePtrCallConv
= true;
601 } else if (symbolTable
!= nullptr) {
604 symbolTable
->lookup(callOp
.getCalleeAttr().getValue());
606 callee
!= nullptr && callee
->hasAttr(barePtrAttrName
);
608 // Warning: This is a linear lookup.
610 SymbolTable::lookupNearestSymbolFrom(callOp
, callOp
.getCalleeAttr());
612 callee
!= nullptr && callee
->hasAttr(barePtrAttrName
);
614 return matchAndRewriteImpl(callOp
, adaptor
, rewriter
, useBarePtrCallConv
);
618 const SymbolTable
*symbolTable
= nullptr;
621 struct CallIndirectOpLowering
622 : public CallOpInterfaceLowering
<func::CallIndirectOp
> {
626 matchAndRewrite(func::CallIndirectOp callIndirectOp
, OpAdaptor adaptor
,
627 ConversionPatternRewriter
&rewriter
) const override
{
628 return matchAndRewriteImpl(callIndirectOp
, adaptor
, rewriter
);
632 struct UnrealizedConversionCastOpLowering
633 : public ConvertOpToLLVMPattern
<UnrealizedConversionCastOp
> {
634 using ConvertOpToLLVMPattern
<
635 UnrealizedConversionCastOp
>::ConvertOpToLLVMPattern
;
638 matchAndRewrite(UnrealizedConversionCastOp op
, OpAdaptor adaptor
,
639 ConversionPatternRewriter
&rewriter
) const override
{
640 SmallVector
<Type
> convertedTypes
;
641 if (succeeded(typeConverter
->convertTypes(op
.getOutputs().getTypes(),
643 convertedTypes
== adaptor
.getInputs().getTypes()) {
644 rewriter
.replaceOp(op
, adaptor
.getInputs());
648 convertedTypes
.clear();
649 if (succeeded(typeConverter
->convertTypes(adaptor
.getInputs().getTypes(),
651 convertedTypes
== op
.getOutputs().getType()) {
652 rewriter
.replaceOp(op
, adaptor
.getInputs());
659 // Special lowering pattern for `ReturnOps`. Unlike all other operations,
660 // `ReturnOp` interacts with the function signature and must have as many
661 // operands as the function has return values. Because in LLVM IR, functions
662 // can only return 0 or 1 value, we pack multiple values into a structure type.
663 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
664 // necessary before returning it
665 struct ReturnOpLowering
: public ConvertOpToLLVMPattern
<func::ReturnOp
> {
666 using ConvertOpToLLVMPattern
<func::ReturnOp
>::ConvertOpToLLVMPattern
;
669 matchAndRewrite(func::ReturnOp op
, OpAdaptor adaptor
,
670 ConversionPatternRewriter
&rewriter
) const override
{
671 Location loc
= op
.getLoc();
672 unsigned numArguments
= op
.getNumOperands();
673 SmallVector
<Value
, 4> updatedOperands
;
675 auto funcOp
= op
->getParentOfType
<LLVM::LLVMFuncOp
>();
676 bool useBarePtrCallConv
=
677 shouldUseBarePtrCallConv(funcOp
, this->getTypeConverter());
678 if (useBarePtrCallConv
) {
679 // For the bare-ptr calling convention, extract the aligned pointer to
680 // be returned from the memref descriptor.
681 for (auto it
: llvm::zip(op
->getOperands(), adaptor
.getOperands())) {
682 Type oldTy
= std::get
<0>(it
).getType();
683 Value newOperand
= std::get
<1>(it
);
684 if (isa
<MemRefType
>(oldTy
) && getTypeConverter()->canConvertToBarePtr(
685 cast
<BaseMemRefType
>(oldTy
))) {
686 MemRefDescriptor
memrefDesc(newOperand
);
687 newOperand
= memrefDesc
.allocatedPtr(rewriter
, loc
);
688 } else if (isa
<UnrankedMemRefType
>(oldTy
)) {
689 // Unranked memref is not supported in the bare pointer calling
693 updatedOperands
.push_back(newOperand
);
696 updatedOperands
= llvm::to_vector
<4>(adaptor
.getOperands());
697 (void)copyUnrankedDescriptors(rewriter
, loc
, op
.getOperands().getTypes(),
702 // If ReturnOp has 0 or 1 operand, create it and return immediately.
703 if (numArguments
<= 1) {
704 rewriter
.replaceOpWithNewOp
<LLVM::ReturnOp
>(
705 op
, TypeRange(), updatedOperands
, op
->getAttrs());
709 // Otherwise, we need to pack the arguments into an LLVM struct type before
711 auto packedType
= getTypeConverter()->packFunctionResults(
712 op
.getOperandTypes(), useBarePtrCallConv
);
714 return rewriter
.notifyMatchFailure(op
, "could not convert result types");
717 Value packed
= rewriter
.create
<LLVM::UndefOp
>(loc
, packedType
);
718 for (auto [idx
, operand
] : llvm::enumerate(updatedOperands
)) {
719 packed
= rewriter
.create
<LLVM::InsertValueOp
>(loc
, packed
, operand
, idx
);
721 rewriter
.replaceOpWithNewOp
<LLVM::ReturnOp
>(op
, TypeRange(), packed
,
728 void mlir::populateFuncToLLVMFuncOpConversionPattern(
729 const LLVMTypeConverter
&converter
, RewritePatternSet
&patterns
) {
730 patterns
.add
<FuncOpConversion
>(converter
);
733 void mlir::populateFuncToLLVMConversionPatterns(
734 const LLVMTypeConverter
&converter
, RewritePatternSet
&patterns
,
735 const SymbolTable
*symbolTable
) {
736 populateFuncToLLVMFuncOpConversionPattern(converter
, patterns
);
737 patterns
.add
<CallIndirectOpLowering
>(converter
);
738 patterns
.add
<CallOpLowering
>(converter
, symbolTable
);
739 patterns
.add
<ConstantOpLowering
>(converter
);
740 patterns
.add
<ReturnOpLowering
>(converter
);
744 /// A pass converting Func operations into the LLVM IR dialect.
745 struct ConvertFuncToLLVMPass
746 : public impl::ConvertFuncToLLVMPassBase
<ConvertFuncToLLVMPass
> {
749 /// Run the dialect converter on the module.
750 void runOnOperation() override
{
751 ModuleOp m
= getOperation();
752 StringRef dataLayout
;
753 auto dataLayoutAttr
= dyn_cast_or_null
<StringAttr
>(
754 m
->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
756 dataLayout
= dataLayoutAttr
.getValue();
758 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
759 dataLayout
, [this](const Twine
&message
) {
760 getOperation().emitError() << message
.str();
766 const auto &dataLayoutAnalysis
= getAnalysis
<DataLayoutAnalysis
>();
768 LowerToLLVMOptions
options(&getContext(),
769 dataLayoutAnalysis
.getAtOrAbove(m
));
770 options
.useBarePtrCallConv
= useBarePtrCallConv
;
771 if (indexBitwidth
!= kDeriveIndexBitwidthFromDataLayout
)
772 options
.overrideIndexBitwidth(indexBitwidth
);
773 options
.dataLayout
= llvm::DataLayout(dataLayout
);
775 LLVMTypeConverter
typeConverter(&getContext(), options
,
776 &dataLayoutAnalysis
);
778 std::optional
<SymbolTable
> optSymbolTable
= std::nullopt
;
779 const SymbolTable
*symbolTable
= nullptr;
780 if (!options
.useBarePtrCallConv
) {
781 optSymbolTable
.emplace(m
);
782 symbolTable
= &optSymbolTable
.value();
785 RewritePatternSet
patterns(&getContext());
786 populateFuncToLLVMConversionPatterns(typeConverter
, patterns
, symbolTable
);
788 // TODO(https://github.com/llvm/llvm-project/issues/70982): Remove these in
789 // favor of their dedicated conversion passes.
790 arith::populateArithToLLVMConversionPatterns(typeConverter
, patterns
);
791 cf::populateControlFlowToLLVMConversionPatterns(typeConverter
, patterns
);
793 LLVMConversionTarget
target(getContext());
794 if (failed(applyPartialConversion(m
, target
, std::move(patterns
))))
799 struct SetLLVMModuleDataLayoutPass
800 : public impl::SetLLVMModuleDataLayoutPassBase
<
801 SetLLVMModuleDataLayoutPass
> {
804 /// Run the dialect converter on the module.
805 void runOnOperation() override
{
806 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
807 this->dataLayout
, [this](const Twine
&message
) {
808 getOperation().emitError() << message
.str();
813 ModuleOp m
= getOperation();
814 m
->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
815 StringAttr::get(m
.getContext(), this->dataLayout
));
820 //===----------------------------------------------------------------------===//
821 // ConvertToLLVMPatternInterface implementation
822 //===----------------------------------------------------------------------===//
825 /// Implement the interface to convert Func to LLVM.
826 struct FuncToLLVMDialectInterface
: public ConvertToLLVMPatternInterface
{
827 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface
;
828 /// Hook for derived dialect interface to provide conversion patterns
829 /// and mark dialect legal for the conversion target.
830 void populateConvertToLLVMConversionPatterns(
831 ConversionTarget
&target
, LLVMTypeConverter
&typeConverter
,
832 RewritePatternSet
&patterns
) const final
{
833 populateFuncToLLVMConversionPatterns(typeConverter
, patterns
);
838 void mlir::registerConvertFuncToLLVMInterface(DialectRegistry
®istry
) {
839 registry
.addExtension(+[](MLIRContext
*ctx
, func::FuncDialect
*dialect
) {
840 dialect
->addInterfaces
<FuncToLLVMDialectInterface
>();