[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / Conversion / FuncToLLVM / FuncToLLVM.cpp
blobc046ea1b824fc85ee27b9d95f700aaf7bfa2c870
1 //===- FuncToLLVM.cpp - Func to LLVM dialect conversion -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR Func and builtin dialects
10 // into the LLVM IR dialect.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16 #include "mlir/Analysis/DataLayoutAnalysis.h"
17 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
18 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
19 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
20 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
21 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
22 #include "mlir/Conversion/LLVMCommon/Pattern.h"
23 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"
25 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
26 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
27 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
28 #include "mlir/Dialect/Utils/StaticValueUtils.h"
29 #include "mlir/IR/Attributes.h"
30 #include "mlir/IR/Builders.h"
31 #include "mlir/IR/BuiltinAttributeInterfaces.h"
32 #include "mlir/IR/BuiltinAttributes.h"
33 #include "mlir/IR/BuiltinOps.h"
34 #include "mlir/IR/IRMapping.h"
35 #include "mlir/IR/PatternMatch.h"
36 #include "mlir/IR/SymbolTable.h"
37 #include "mlir/IR/TypeUtilities.h"
38 #include "mlir/Transforms/DialectConversion.h"
39 #include "mlir/Transforms/Passes.h"
40 #include "llvm/ADT/SmallVector.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/IR/DerivedTypes.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/Type.h"
45 #include "llvm/Support/Casting.h"
46 #include "llvm/Support/CommandLine.h"
47 #include "llvm/Support/FormatVariadic.h"
48 #include <algorithm>
49 #include <functional>
50 #include <optional>
52 namespace mlir {
53 #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
54 #define GEN_PASS_DEF_SETLLVMMODULEDATALAYOUTPASS
55 #include "mlir/Conversion/Passes.h.inc"
56 } // namespace mlir
58 using namespace mlir;
60 #define PASS_NAME "convert-func-to-llvm"
62 static constexpr StringRef varargsAttrName = "func.varargs";
63 static constexpr StringRef linkageAttrName = "llvm.linkage";
64 static constexpr StringRef barePtrAttrName = "llvm.bareptr";
66 /// Return `true` if the `op` should use bare pointer calling convention.
67 static bool shouldUseBarePtrCallConv(Operation *op,
68 const LLVMTypeConverter *typeConverter) {
69 return (op && op->hasAttr(barePtrAttrName)) ||
70 typeConverter->getOptions().useBarePtrCallConv;
73 /// Only retain those attributes that are not constructed by
74 /// `LLVMFuncOp::build`.
75 static void filterFuncAttributes(FunctionOpInterface func,
76 SmallVectorImpl<NamedAttribute> &result) {
77 for (const NamedAttribute &attr : func->getDiscardableAttrs()) {
78 if (attr.getName() == linkageAttrName ||
79 attr.getName() == varargsAttrName ||
80 attr.getName() == LLVM::LLVMDialect::getReadnoneAttrName())
81 continue;
82 result.push_back(attr);
86 /// Propagate argument/results attributes.
87 static void propagateArgResAttrs(OpBuilder &builder, bool resultStructType,
88 FunctionOpInterface funcOp,
89 LLVM::LLVMFuncOp wrapperFuncOp) {
90 auto argAttrs = funcOp.getAllArgAttrs();
91 if (!resultStructType) {
92 if (auto resAttrs = funcOp.getAllResultAttrs())
93 wrapperFuncOp.setAllResultAttrs(resAttrs);
94 if (argAttrs)
95 wrapperFuncOp.setAllArgAttrs(argAttrs);
96 } else {
97 SmallVector<Attribute> argAttributes;
98 // Only modify the argument and result attributes when the result is now
99 // an argument.
100 if (argAttrs) {
101 argAttributes.push_back(builder.getDictionaryAttr({}));
102 argAttributes.append(argAttrs.begin(), argAttrs.end());
103 wrapperFuncOp.setAllArgAttrs(argAttributes);
106 cast<FunctionOpInterface>(wrapperFuncOp.getOperation())
107 .setVisibility(funcOp.getVisibility());
110 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
111 /// arguments instead of unpacked arguments. This function can be called from C
112 /// by passing a pointer to a C struct corresponding to a memref descriptor.
113 /// Similarly, returned memrefs are passed via pointers to a C struct that is
114 /// passed as additional argument.
115 /// Internally, the auxiliary function unpacks the descriptor into individual
116 /// components and forwards them to `newFuncOp` and forwards the results to
117 /// the extra arguments.
118 static void wrapForExternalCallers(OpBuilder &rewriter, Location loc,
119 const LLVMTypeConverter &typeConverter,
120 FunctionOpInterface funcOp,
121 LLVM::LLVMFuncOp newFuncOp) {
122 auto type = cast<FunctionType>(funcOp.getFunctionType());
123 auto [wrapperFuncType, resultStructType] =
124 typeConverter.convertFunctionTypeCWrapper(type);
126 SmallVector<NamedAttribute> attributes;
127 filterFuncAttributes(funcOp, attributes);
129 auto wrapperFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
130 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
131 wrapperFuncType, LLVM::Linkage::External, /*dsoLocal=*/false,
132 /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
133 propagateArgResAttrs(rewriter, !!resultStructType, funcOp, wrapperFuncOp);
135 OpBuilder::InsertionGuard guard(rewriter);
136 rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock(rewriter));
138 SmallVector<Value, 8> args;
139 size_t argOffset = resultStructType ? 1 : 0;
140 for (auto [index, argType] : llvm::enumerate(type.getInputs())) {
141 Value arg = wrapperFuncOp.getArgument(index + argOffset);
142 if (auto memrefType = dyn_cast<MemRefType>(argType)) {
143 Value loaded = rewriter.create<LLVM::LoadOp>(
144 loc, typeConverter.convertType(memrefType), arg);
145 MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args);
146 continue;
148 if (isa<UnrankedMemRefType>(argType)) {
149 Value loaded = rewriter.create<LLVM::LoadOp>(
150 loc, typeConverter.convertType(argType), arg);
151 UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args);
152 continue;
155 args.push_back(arg);
158 auto call = rewriter.create<LLVM::CallOp>(loc, newFuncOp, args);
160 if (resultStructType) {
161 rewriter.create<LLVM::StoreOp>(loc, call.getResult(),
162 wrapperFuncOp.getArgument(0));
163 rewriter.create<LLVM::ReturnOp>(loc, ValueRange{});
164 } else {
165 rewriter.create<LLVM::ReturnOp>(loc, call.getResults());
169 /// Creates an auxiliary function with pointer-to-memref-descriptor-struct
170 /// arguments instead of unpacked arguments. Creates a body for the (external)
171 /// `newFuncOp` that allocates a memref descriptor on stack, packs the
172 /// individual arguments into this descriptor and passes a pointer to it into
173 /// the auxiliary function. If the result of the function cannot be directly
174 /// returned, we write it to a special first argument that provides a pointer
175 /// to a corresponding struct. This auxiliary external function is now
176 /// compatible with functions defined in C using pointers to C structs
177 /// corresponding to a memref descriptor.
178 static void wrapExternalFunction(OpBuilder &builder, Location loc,
179 const LLVMTypeConverter &typeConverter,
180 FunctionOpInterface funcOp,
181 LLVM::LLVMFuncOp newFuncOp) {
182 OpBuilder::InsertionGuard guard(builder);
184 auto [wrapperType, resultStructType] =
185 typeConverter.convertFunctionTypeCWrapper(
186 cast<FunctionType>(funcOp.getFunctionType()));
187 // This conversion can only fail if it could not convert one of the argument
188 // types. But since it has been applied to a non-wrapper function before, it
189 // should have failed earlier and not reach this point at all.
190 assert(wrapperType && "unexpected type conversion failure");
192 SmallVector<NamedAttribute, 4> attributes;
193 filterFuncAttributes(funcOp, attributes);
195 // Create the auxiliary function.
196 auto wrapperFunc = builder.create<LLVM::LLVMFuncOp>(
197 loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(),
198 wrapperType, LLVM::Linkage::External, /*dsoLocal=*/false,
199 /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr, attributes);
200 propagateArgResAttrs(builder, !!resultStructType, funcOp, wrapperFunc);
202 // The wrapper that we synthetize here should only be visible in this module.
203 newFuncOp.setLinkage(LLVM::Linkage::Private);
204 builder.setInsertionPointToStart(newFuncOp.addEntryBlock(builder));
206 // Get a ValueRange containing arguments.
207 FunctionType type = cast<FunctionType>(funcOp.getFunctionType());
208 SmallVector<Value, 8> args;
209 args.reserve(type.getNumInputs());
210 ValueRange wrapperArgsRange(newFuncOp.getArguments());
212 if (resultStructType) {
213 // Allocate the struct on the stack and pass the pointer.
214 Type resultType = cast<LLVM::LLVMFunctionType>(wrapperType).getParamType(0);
215 Value one = builder.create<LLVM::ConstantOp>(
216 loc, typeConverter.convertType(builder.getIndexType()),
217 builder.getIntegerAttr(builder.getIndexType(), 1));
218 Value result =
219 builder.create<LLVM::AllocaOp>(loc, resultType, resultStructType, one);
220 args.push_back(result);
223 // Iterate over the inputs of the original function and pack values into
224 // memref descriptors if the original type is a memref.
225 for (Type input : type.getInputs()) {
226 Value arg;
227 int numToDrop = 1;
228 auto memRefType = dyn_cast<MemRefType>(input);
229 auto unrankedMemRefType = dyn_cast<UnrankedMemRefType>(input);
230 if (memRefType || unrankedMemRefType) {
231 numToDrop = memRefType
232 ? MemRefDescriptor::getNumUnpackedValues(memRefType)
233 : UnrankedMemRefDescriptor::getNumUnpackedValues();
234 Value packed =
235 memRefType
236 ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType,
237 wrapperArgsRange.take_front(numToDrop))
238 : UnrankedMemRefDescriptor::pack(
239 builder, loc, typeConverter, unrankedMemRefType,
240 wrapperArgsRange.take_front(numToDrop));
242 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
243 Value one = builder.create<LLVM::ConstantOp>(
244 loc, typeConverter.convertType(builder.getIndexType()),
245 builder.getIntegerAttr(builder.getIndexType(), 1));
246 Value allocated = builder.create<LLVM::AllocaOp>(
247 loc, ptrTy, packed.getType(), one, /*alignment=*/0);
248 builder.create<LLVM::StoreOp>(loc, packed, allocated);
249 arg = allocated;
250 } else {
251 arg = wrapperArgsRange[0];
254 args.push_back(arg);
255 wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop);
257 assert(wrapperArgsRange.empty() && "did not map some of the arguments");
259 auto call = builder.create<LLVM::CallOp>(loc, wrapperFunc, args);
261 if (resultStructType) {
262 Value result =
263 builder.create<LLVM::LoadOp>(loc, resultStructType, args.front());
264 builder.create<LLVM::ReturnOp>(loc, result);
265 } else {
266 builder.create<LLVM::ReturnOp>(loc, call.getResults());
270 /// Inserts `llvm.load` ops in the function body to restore the expected pointee
271 /// value from `llvm.byval`/`llvm.byref` function arguments that were converted
272 /// to LLVM pointer types.
273 static void restoreByValRefArgumentType(
274 ConversionPatternRewriter &rewriter, const LLVMTypeConverter &typeConverter,
275 ArrayRef<std::optional<NamedAttribute>> byValRefNonPtrAttrs,
276 ArrayRef<BlockArgument> oldBlockArgs, LLVM::LLVMFuncOp funcOp) {
277 // Nothing to do for function declarations.
278 if (funcOp.isExternal())
279 return;
281 ConversionPatternRewriter::InsertionGuard guard(rewriter);
282 rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
284 for (const auto &[arg, oldArg, byValRefAttr] :
285 llvm::zip(funcOp.getArguments(), oldBlockArgs, byValRefNonPtrAttrs)) {
286 // Skip argument if no `llvm.byval` or `llvm.byref` attribute.
287 if (!byValRefAttr)
288 continue;
290 // Insert load to retrieve the actual argument passed by value/reference.
291 assert(isa<LLVM::LLVMPointerType>(arg.getType()) &&
292 "Expected LLVM pointer type for argument with "
293 "`llvm.byval`/`llvm.byref` attribute");
294 Type resTy = typeConverter.convertType(
295 cast<TypeAttr>(byValRefAttr->getValue()).getValue());
297 auto valueArg = rewriter.create<LLVM::LoadOp>(arg.getLoc(), resTy, arg);
298 rewriter.replaceUsesOfBlockArgument(oldArg, valueArg);
302 FailureOr<LLVM::LLVMFuncOp>
303 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
304 ConversionPatternRewriter &rewriter,
305 const LLVMTypeConverter &converter) {
306 // Check the funcOp has `FunctionType`.
307 auto funcTy = dyn_cast<FunctionType>(funcOp.getFunctionType());
308 if (!funcTy)
309 return rewriter.notifyMatchFailure(
310 funcOp, "Only support FunctionOpInterface with FunctionType");
312 // Keep track of the entry block arguments. They will be needed later.
313 SmallVector<BlockArgument> oldBlockArgs =
314 llvm::to_vector(funcOp.getArguments());
316 // Convert the original function arguments. They are converted using the
317 // LLVMTypeConverter provided to this legalization pattern.
318 auto varargsAttr = funcOp->getAttrOfType<BoolAttr>(varargsAttrName);
319 // Gather `llvm.byval` and `llvm.byref` arguments whose type convertion was
320 // overriden with an LLVM pointer type for later processing.
321 SmallVector<std::optional<NamedAttribute>> byValRefNonPtrAttrs;
322 TypeConverter::SignatureConversion result(funcOp.getNumArguments());
323 auto llvmType = converter.convertFunctionSignature(
324 funcOp, varargsAttr && varargsAttr.getValue(),
325 shouldUseBarePtrCallConv(funcOp, &converter), result,
326 byValRefNonPtrAttrs);
327 if (!llvmType)
328 return rewriter.notifyMatchFailure(funcOp, "signature conversion failed");
330 // Create an LLVM function, use external linkage by default until MLIR
331 // functions have linkage.
332 LLVM::Linkage linkage = LLVM::Linkage::External;
333 if (funcOp->hasAttr(linkageAttrName)) {
334 auto attr =
335 dyn_cast<mlir::LLVM::LinkageAttr>(funcOp->getAttr(linkageAttrName));
336 if (!attr) {
337 funcOp->emitError() << "Contains " << linkageAttrName
338 << " attribute not of type LLVM::LinkageAttr";
339 return rewriter.notifyMatchFailure(
340 funcOp, "Contains linkage attribute not of type LLVM::LinkageAttr");
342 linkage = attr.getLinkage();
345 SmallVector<NamedAttribute, 4> attributes;
346 filterFuncAttributes(funcOp, attributes);
347 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
348 funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
349 /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C, /*comdat=*/nullptr,
350 attributes);
351 cast<FunctionOpInterface>(newFuncOp.getOperation())
352 .setVisibility(funcOp.getVisibility());
354 // Create a memory effect attribute corresponding to readnone.
355 StringRef readnoneAttrName = LLVM::LLVMDialect::getReadnoneAttrName();
356 if (funcOp->hasAttr(readnoneAttrName)) {
357 auto attr = funcOp->getAttrOfType<UnitAttr>(readnoneAttrName);
358 if (!attr) {
359 funcOp->emitError() << "Contains " << readnoneAttrName
360 << " attribute not of type UnitAttr";
361 return rewriter.notifyMatchFailure(
362 funcOp, "Contains readnone attribute not of type UnitAttr");
364 auto memoryAttr = LLVM::MemoryEffectsAttr::get(
365 rewriter.getContext(),
366 {LLVM::ModRefInfo::NoModRef, LLVM::ModRefInfo::NoModRef,
367 LLVM::ModRefInfo::NoModRef});
368 newFuncOp.setMemoryEffectsAttr(memoryAttr);
371 // Propagate argument/result attributes to all converted arguments/result
372 // obtained after converting a given original argument/result.
373 if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
374 assert(!resAttrDicts.empty() && "expected array to be non-empty");
375 if (funcOp.getNumResults() == 1)
376 newFuncOp.setAllResultAttrs(resAttrDicts);
378 if (ArrayAttr argAttrDicts = funcOp.getAllArgAttrs()) {
379 SmallVector<Attribute> newArgAttrs(
380 cast<LLVM::LLVMFunctionType>(llvmType).getNumParams());
381 for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
382 // Some LLVM IR attribute have a type attached to them. During FuncOp ->
383 // LLVMFuncOp conversion these types may have changed. Account for that
384 // change by converting attributes' types as well.
385 SmallVector<NamedAttribute, 4> convertedAttrs;
386 auto attrsDict = cast<DictionaryAttr>(argAttrDicts[i]);
387 convertedAttrs.reserve(attrsDict.size());
388 for (const NamedAttribute &attr : attrsDict) {
389 const auto convert = [&](const NamedAttribute &attr) {
390 return TypeAttr::get(converter.convertType(
391 cast<TypeAttr>(attr.getValue()).getValue()));
393 if (attr.getName().getValue() ==
394 LLVM::LLVMDialect::getByValAttrName()) {
395 convertedAttrs.push_back(rewriter.getNamedAttr(
396 LLVM::LLVMDialect::getByValAttrName(), convert(attr)));
397 } else if (attr.getName().getValue() ==
398 LLVM::LLVMDialect::getByRefAttrName()) {
399 convertedAttrs.push_back(rewriter.getNamedAttr(
400 LLVM::LLVMDialect::getByRefAttrName(), convert(attr)));
401 } else if (attr.getName().getValue() ==
402 LLVM::LLVMDialect::getStructRetAttrName()) {
403 convertedAttrs.push_back(rewriter.getNamedAttr(
404 LLVM::LLVMDialect::getStructRetAttrName(), convert(attr)));
405 } else if (attr.getName().getValue() ==
406 LLVM::LLVMDialect::getInAllocaAttrName()) {
407 convertedAttrs.push_back(rewriter.getNamedAttr(
408 LLVM::LLVMDialect::getInAllocaAttrName(), convert(attr)));
409 } else {
410 convertedAttrs.push_back(attr);
413 auto mapping = result.getInputMapping(i);
414 assert(mapping && "unexpected deletion of function argument");
415 // Only attach the new argument attributes if there is a one-to-one
416 // mapping from old to new types. Otherwise, attributes might be
417 // attached to types that they do not support.
418 if (mapping->size == 1) {
419 newArgAttrs[mapping->inputNo] =
420 DictionaryAttr::get(rewriter.getContext(), convertedAttrs);
421 continue;
423 // TODO: Implement custom handling for types that expand to multiple
424 // function arguments.
425 for (size_t j = 0; j < mapping->size; ++j)
426 newArgAttrs[mapping->inputNo + j] =
427 DictionaryAttr::get(rewriter.getContext(), {});
429 if (!newArgAttrs.empty())
430 newFuncOp.setAllArgAttrs(rewriter.getArrayAttr(newArgAttrs));
433 rewriter.inlineRegionBefore(funcOp.getFunctionBody(), newFuncOp.getBody(),
434 newFuncOp.end());
435 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), converter,
436 &result))) {
437 return rewriter.notifyMatchFailure(funcOp,
438 "region types conversion failed");
441 // Fix the type mismatch between the materialized `llvm.ptr` and the expected
442 // pointee type in the function body when converting `llvm.byval`/`llvm.byref`
443 // function arguments.
444 restoreByValRefArgumentType(rewriter, converter, byValRefNonPtrAttrs,
445 oldBlockArgs, newFuncOp);
447 if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
448 if (funcOp->getAttrOfType<UnitAttr>(
449 LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
450 if (newFuncOp.isVarArg())
451 return funcOp.emitError("C interface for variadic functions is not "
452 "supported yet.");
454 if (newFuncOp.isExternal())
455 wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp,
456 newFuncOp);
457 else
458 wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
459 newFuncOp);
463 return newFuncOp;
466 namespace {
468 /// FuncOp legalization pattern that converts MemRef arguments to pointers to
469 /// MemRef descriptors (LLVM struct data types) containing all the MemRef type
470 /// information.
471 struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
472 FuncOpConversion(const LLVMTypeConverter &converter)
473 : ConvertOpToLLVMPattern(converter) {}
475 LogicalResult
476 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
477 ConversionPatternRewriter &rewriter) const override {
478 FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
479 cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
480 *getTypeConverter());
481 if (failed(newFuncOp))
482 return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");
484 rewriter.eraseOp(funcOp);
485 return success();
489 struct ConstantOpLowering : public ConvertOpToLLVMPattern<func::ConstantOp> {
490 using ConvertOpToLLVMPattern<func::ConstantOp>::ConvertOpToLLVMPattern;
492 LogicalResult
493 matchAndRewrite(func::ConstantOp op, OpAdaptor adaptor,
494 ConversionPatternRewriter &rewriter) const override {
495 auto type = typeConverter->convertType(op.getResult().getType());
496 if (!type || !LLVM::isCompatibleType(type))
497 return rewriter.notifyMatchFailure(op, "failed to convert result type");
499 auto newOp =
500 rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type, op.getValue());
501 for (const NamedAttribute &attr : op->getAttrs()) {
502 if (attr.getName().strref() == "value")
503 continue;
504 newOp->setAttr(attr.getName(), attr.getValue());
506 rewriter.replaceOp(op, newOp->getResults());
507 return success();
511 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
512 // passes the pointer to the MemRef across function boundaries.
513 template <typename CallOpType>
514 struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
515 using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
516 using Super = CallOpInterfaceLowering<CallOpType>;
517 using Base = ConvertOpToLLVMPattern<CallOpType>;
519 LogicalResult matchAndRewriteImpl(CallOpType callOp,
520 typename CallOpType::Adaptor adaptor,
521 ConversionPatternRewriter &rewriter,
522 bool useBarePtrCallConv = false) const {
523 // Pack the result types into a struct.
524 Type packedResult = nullptr;
525 unsigned numResults = callOp.getNumResults();
526 auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
528 if (numResults != 0) {
529 if (!(packedResult = this->getTypeConverter()->packFunctionResults(
530 resultTypes, useBarePtrCallConv)))
531 return failure();
534 if (useBarePtrCallConv) {
535 for (auto it : callOp->getOperands()) {
536 Type operandType = it.getType();
537 if (isa<UnrankedMemRefType>(operandType)) {
538 // Unranked memref is not supported in the bare pointer calling
539 // convention.
540 return failure();
544 auto promoted = this->getTypeConverter()->promoteOperands(
545 callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
546 adaptor.getOperands(), rewriter, useBarePtrCallConv);
547 auto newOp = rewriter.create<LLVM::CallOp>(
548 callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
549 promoted, callOp->getAttrs());
551 newOp.getProperties().operandSegmentSizes = {
552 static_cast<int32_t>(promoted.size()), 0};
553 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
555 SmallVector<Value, 4> results;
556 if (numResults < 2) {
557 // If < 2 results, packing did not do anything and we can just return.
558 results.append(newOp.result_begin(), newOp.result_end());
559 } else {
560 // Otherwise, it had been converted to an operation producing a structure.
561 // Extract individual results from the structure and return them as list.
562 results.reserve(numResults);
563 for (unsigned i = 0; i < numResults; ++i) {
564 results.push_back(rewriter.create<LLVM::ExtractValueOp>(
565 callOp.getLoc(), newOp->getResult(0), i));
569 if (useBarePtrCallConv) {
570 // For the bare-ptr calling convention, promote memref results to
571 // descriptors.
572 assert(results.size() == resultTypes.size() &&
573 "The number of arguments and types doesn't match");
574 this->getTypeConverter()->promoteBarePtrsToDescriptors(
575 rewriter, callOp.getLoc(), resultTypes, results);
576 } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
577 resultTypes, results,
578 /*toDynamic=*/false))) {
579 return failure();
582 rewriter.replaceOp(callOp, results);
583 return success();
587 class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
588 public:
589 CallOpLowering(const LLVMTypeConverter &typeConverter,
590 // Can be nullptr.
591 const SymbolTable *symbolTable, PatternBenefit benefit = 1)
592 : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
593 symbolTable(symbolTable) {}
595 LogicalResult
596 matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
597 ConversionPatternRewriter &rewriter) const override {
598 bool useBarePtrCallConv = false;
599 if (getTypeConverter()->getOptions().useBarePtrCallConv) {
600 useBarePtrCallConv = true;
601 } else if (symbolTable != nullptr) {
602 // Fast lookup.
603 Operation *callee =
604 symbolTable->lookup(callOp.getCalleeAttr().getValue());
605 useBarePtrCallConv =
606 callee != nullptr && callee->hasAttr(barePtrAttrName);
607 } else {
608 // Warning: This is a linear lookup.
609 Operation *callee =
610 SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
611 useBarePtrCallConv =
612 callee != nullptr && callee->hasAttr(barePtrAttrName);
614 return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
617 private:
618 const SymbolTable *symbolTable = nullptr;
621 struct CallIndirectOpLowering
622 : public CallOpInterfaceLowering<func::CallIndirectOp> {
623 using Super::Super;
625 LogicalResult
626 matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
627 ConversionPatternRewriter &rewriter) const override {
628 return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
632 struct UnrealizedConversionCastOpLowering
633 : public ConvertOpToLLVMPattern<UnrealizedConversionCastOp> {
634 using ConvertOpToLLVMPattern<
635 UnrealizedConversionCastOp>::ConvertOpToLLVMPattern;
637 LogicalResult
638 matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor,
639 ConversionPatternRewriter &rewriter) const override {
640 SmallVector<Type> convertedTypes;
641 if (succeeded(typeConverter->convertTypes(op.getOutputs().getTypes(),
642 convertedTypes)) &&
643 convertedTypes == adaptor.getInputs().getTypes()) {
644 rewriter.replaceOp(op, adaptor.getInputs());
645 return success();
648 convertedTypes.clear();
649 if (succeeded(typeConverter->convertTypes(adaptor.getInputs().getTypes(),
650 convertedTypes)) &&
651 convertedTypes == op.getOutputs().getType()) {
652 rewriter.replaceOp(op, adaptor.getInputs());
653 return success();
655 return failure();
659 // Special lowering pattern for `ReturnOps`. Unlike all other operations,
660 // `ReturnOp` interacts with the function signature and must have as many
661 // operands as the function has return values. Because in LLVM IR, functions
662 // can only return 0 or 1 value, we pack multiple values into a structure type.
663 // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if
664 // necessary before returning it
665 struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
666 using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
668 LogicalResult
669 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
670 ConversionPatternRewriter &rewriter) const override {
671 Location loc = op.getLoc();
672 unsigned numArguments = op.getNumOperands();
673 SmallVector<Value, 4> updatedOperands;
675 auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
676 bool useBarePtrCallConv =
677 shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
678 if (useBarePtrCallConv) {
679 // For the bare-ptr calling convention, extract the aligned pointer to
680 // be returned from the memref descriptor.
681 for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
682 Type oldTy = std::get<0>(it).getType();
683 Value newOperand = std::get<1>(it);
684 if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
685 cast<BaseMemRefType>(oldTy))) {
686 MemRefDescriptor memrefDesc(newOperand);
687 newOperand = memrefDesc.allocatedPtr(rewriter, loc);
688 } else if (isa<UnrankedMemRefType>(oldTy)) {
689 // Unranked memref is not supported in the bare pointer calling
690 // convention.
691 return failure();
693 updatedOperands.push_back(newOperand);
695 } else {
696 updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
697 (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
698 updatedOperands,
699 /*toDynamic=*/true);
702 // If ReturnOp has 0 or 1 operand, create it and return immediately.
703 if (numArguments <= 1) {
704 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
705 op, TypeRange(), updatedOperands, op->getAttrs());
706 return success();
709 // Otherwise, we need to pack the arguments into an LLVM struct type before
710 // returning.
711 auto packedType = getTypeConverter()->packFunctionResults(
712 op.getOperandTypes(), useBarePtrCallConv);
713 if (!packedType) {
714 return rewriter.notifyMatchFailure(op, "could not convert result types");
717 Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
718 for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
719 packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
721 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
722 op->getAttrs());
723 return success();
726 } // namespace
728 void mlir::populateFuncToLLVMFuncOpConversionPattern(
729 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
730 patterns.add<FuncOpConversion>(converter);
733 void mlir::populateFuncToLLVMConversionPatterns(
734 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
735 const SymbolTable *symbolTable) {
736 populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
737 patterns.add<CallIndirectOpLowering>(converter);
738 patterns.add<CallOpLowering>(converter, symbolTable);
739 patterns.add<ConstantOpLowering>(converter);
740 patterns.add<ReturnOpLowering>(converter);
743 namespace {
744 /// A pass converting Func operations into the LLVM IR dialect.
745 struct ConvertFuncToLLVMPass
746 : public impl::ConvertFuncToLLVMPassBase<ConvertFuncToLLVMPass> {
747 using Base::Base;
749 /// Run the dialect converter on the module.
750 void runOnOperation() override {
751 ModuleOp m = getOperation();
752 StringRef dataLayout;
753 auto dataLayoutAttr = dyn_cast_or_null<StringAttr>(
754 m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()));
755 if (dataLayoutAttr)
756 dataLayout = dataLayoutAttr.getValue();
758 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
759 dataLayout, [this](const Twine &message) {
760 getOperation().emitError() << message.str();
761 }))) {
762 signalPassFailure();
763 return;
766 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
768 LowerToLLVMOptions options(&getContext(),
769 dataLayoutAnalysis.getAtOrAbove(m));
770 options.useBarePtrCallConv = useBarePtrCallConv;
771 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
772 options.overrideIndexBitwidth(indexBitwidth);
773 options.dataLayout = llvm::DataLayout(dataLayout);
775 LLVMTypeConverter typeConverter(&getContext(), options,
776 &dataLayoutAnalysis);
778 std::optional<SymbolTable> optSymbolTable = std::nullopt;
779 const SymbolTable *symbolTable = nullptr;
780 if (!options.useBarePtrCallConv) {
781 optSymbolTable.emplace(m);
782 symbolTable = &optSymbolTable.value();
785 RewritePatternSet patterns(&getContext());
786 populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
788 // TODO(https://github.com/llvm/llvm-project/issues/70982): Remove these in
789 // favor of their dedicated conversion passes.
790 arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
791 cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
793 LLVMConversionTarget target(getContext());
794 if (failed(applyPartialConversion(m, target, std::move(patterns))))
795 signalPassFailure();
799 struct SetLLVMModuleDataLayoutPass
800 : public impl::SetLLVMModuleDataLayoutPassBase<
801 SetLLVMModuleDataLayoutPass> {
802 using Base::Base;
804 /// Run the dialect converter on the module.
805 void runOnOperation() override {
806 if (failed(LLVM::LLVMDialect::verifyDataLayoutString(
807 this->dataLayout, [this](const Twine &message) {
808 getOperation().emitError() << message.str();
809 }))) {
810 signalPassFailure();
811 return;
813 ModuleOp m = getOperation();
814 m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
815 StringAttr::get(m.getContext(), this->dataLayout));
818 } // namespace
820 //===----------------------------------------------------------------------===//
821 // ConvertToLLVMPatternInterface implementation
822 //===----------------------------------------------------------------------===//
824 namespace {
825 /// Implement the interface to convert Func to LLVM.
826 struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
827 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
828 /// Hook for derived dialect interface to provide conversion patterns
829 /// and mark dialect legal for the conversion target.
830 void populateConvertToLLVMConversionPatterns(
831 ConversionTarget &target, LLVMTypeConverter &typeConverter,
832 RewritePatternSet &patterns) const final {
833 populateFuncToLLVMConversionPatterns(typeConverter, patterns);
836 } // namespace
838 void mlir::registerConvertFuncToLLVMInterface(DialectRegistry &registry) {
839 registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
840 dialect->addInterfaces<FuncToLLVMDialectInterface>();