[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / lib / Conversion / LLVMCommon / TypeConverter.cpp
blobce91424e7a577e25fe7eab8509f15b9c096531c5
1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
10 #include "MemRefDescriptor.h"
11 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
16 #include <memory>
17 #include <mutex>
18 #include <optional>
20 using namespace mlir;
22 SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
24 // Most of the time, the entry already exists in the map.
25 std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
26 std::defer_lock);
27 if (getContext().isMultithreadingEnabled())
28 lock.lock();
29 auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
30 if (recursiveStack != conversionCallStack.end())
31 return *recursiveStack->second;
34 // First time this thread gets here, we have to get an exclusive access to
35 // inset in the map
36 std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
37 auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
38 llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
39 return *recursiveStackInserted.first->second;
42 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
43 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
44 const DataLayoutAnalysis *analysis)
45 : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
47 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
48 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
49 const LowerToLLVMOptions &options,
50 const DataLayoutAnalysis *analysis)
51 : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
52 dataLayoutAnalysis(analysis) {
53 assert(llvmDialect && "LLVM IR dialect is not registered");
55 // Register conversions for the builtin types.
56 addConversion([&](ComplexType type) { return convertComplexType(type); });
57 addConversion([&](FloatType type) { return convertFloatType(type); });
58 addConversion([&](FunctionType type) { return convertFunctionType(type); });
59 addConversion([&](IndexType type) { return convertIndexType(type); });
60 addConversion([&](IntegerType type) { return convertIntegerType(type); });
61 addConversion([&](MemRefType type) { return convertMemRefType(type); });
62 addConversion(
63 [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
64 addConversion([&](VectorType type) -> std::optional<Type> {
65 FailureOr<Type> llvmType = convertVectorType(type);
66 if (failed(llvmType))
67 return std::nullopt;
68 return llvmType;
69 });
71 // LLVM-compatible types are legal, so add a pass-through conversion. Do this
72 // before the conversions below since conversions are attempted in reverse
73 // order and those should take priority.
74 addConversion([](Type type) {
75 return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
76 : std::nullopt;
77 });
79 addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
80 -> std::optional<LogicalResult> {
81 // Fastpath for types that won't be converted by this callback anyway.
82 if (LLVM::isCompatibleType(type)) {
83 results.push_back(type);
84 return success();
87 if (type.isIdentified()) {
88 auto convertedType = LLVM::LLVMStructType::getIdentified(
89 type.getContext(), ("_Converted." + type.getName()).str());
91 SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack();
92 if (llvm::count(recursiveStack, type)) {
93 results.push_back(convertedType);
94 return success();
96 recursiveStack.push_back(type);
97 auto popConversionCallStack = llvm::make_scope_exit(
98 [&recursiveStack]() { recursiveStack.pop_back(); });
100 SmallVector<Type> convertedElemTypes;
101 convertedElemTypes.reserve(type.getBody().size());
102 if (failed(convertTypes(type.getBody(), convertedElemTypes)))
103 return std::nullopt;
105 // If the converted type has not been initialized yet, just set its body
106 // to be the converted arguments and return.
107 if (!convertedType.isInitialized()) {
108 if (failed(
109 convertedType.setBody(convertedElemTypes, type.isPacked()))) {
110 return failure();
112 results.push_back(convertedType);
113 return success();
116 // If it has been initialized, has the same body and packed bit, just use
117 // it. This ensures that recursive structs keep being recursive rather
118 // than including a non-updated name.
119 if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
120 convertedType.isPacked() == type.isPacked()) {
121 results.push_back(convertedType);
122 return success();
125 return failure();
128 SmallVector<Type> convertedSubtypes;
129 convertedSubtypes.reserve(type.getBody().size());
130 if (failed(convertTypes(type.getBody(), convertedSubtypes)))
131 return std::nullopt;
133 results.push_back(LLVM::LLVMStructType::getLiteral(
134 type.getContext(), convertedSubtypes, type.isPacked()));
135 return success();
137 addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
138 if (auto element = convertType(type.getElementType()))
139 return LLVM::LLVMArrayType::get(element, type.getNumElements());
140 return std::nullopt;
142 addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
143 Type convertedResType = convertType(type.getReturnType());
144 if (!convertedResType)
145 return std::nullopt;
147 SmallVector<Type> convertedArgTypes;
148 convertedArgTypes.reserve(type.getNumParams());
149 if (failed(convertTypes(type.getParams(), convertedArgTypes)))
150 return std::nullopt;
152 return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
153 type.isVarArg());
156 // Argument materializations convert from the new block argument types
157 // (multiple SSA values that make up a memref descriptor) back to the
158 // original block argument type. The dialect conversion framework will then
159 // insert a target materialization from the original block argument type to
160 // a legal type.
161 addArgumentMaterialization([&](OpBuilder &builder,
162 UnrankedMemRefType resultType,
163 ValueRange inputs, Location loc) {
164 if (inputs.size() == 1) {
165 // Bare pointers are not supported for unranked memrefs because a
166 // memref descriptor cannot be built just from a bare pointer.
167 return Value();
169 Value desc =
170 UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
171 // An argument materialization must return a value of type
172 // `resultType`, so insert a cast from the memref descriptor type
173 // (!llvm.struct) to the original memref type.
174 return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
175 .getResult(0);
177 addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
178 ValueRange inputs, Location loc) {
179 Value desc;
180 if (inputs.size() == 1) {
181 // This is a bare pointer. We allow bare pointers only for function entry
182 // blocks.
183 BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
184 if (!barePtr)
185 return Value();
186 Block *block = barePtr.getOwner();
187 if (!block->isEntryBlock() ||
188 !isa<FunctionOpInterface>(block->getParentOp()))
189 return Value();
190 desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
191 inputs[0]);
192 } else {
193 desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
195 // An argument materialization must return a value of type `resultType`,
196 // so insert a cast from the memref descriptor type (!llvm.struct) to the
197 // original memref type.
198 return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
199 .getResult(0);
201 // Add generic source and target materializations to handle cases where
202 // non-LLVM types persist after an LLVM conversion.
203 addSourceMaterialization([&](OpBuilder &builder, Type resultType,
204 ValueRange inputs, Location loc) {
205 if (inputs.size() != 1)
206 return Value();
208 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
209 .getResult(0);
211 addTargetMaterialization([&](OpBuilder &builder, Type resultType,
212 ValueRange inputs, Location loc) {
213 if (inputs.size() != 1)
214 return Value();
216 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
217 .getResult(0);
220 // Integer memory spaces map to themselves.
221 addTypeAttributeConversion(
222 [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
225 /// Returns the MLIR context.
226 MLIRContext &LLVMTypeConverter::getContext() const {
227 return *getDialect()->getContext();
230 Type LLVMTypeConverter::getIndexType() const {
231 return IntegerType::get(&getContext(), getIndexTypeBitwidth());
234 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
235 return options.dataLayout.getPointerSizeInBits(addressSpace);
238 Type LLVMTypeConverter::convertIndexType(IndexType type) const {
239 return getIndexType();
242 Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
243 return IntegerType::get(&getContext(), type.getWidth());
246 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
247 if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
248 type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
249 type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
250 type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
251 type.isFloat8E8M0FNU())
252 return IntegerType::get(&getContext(), type.getWidth());
253 return type;
256 // Convert a `ComplexType` to an LLVM type. The result is a complex number
257 // struct with entries for the
258 // 1. real part and for the
259 // 2. imaginary part.
260 Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
261 auto elementType = convertType(type.getElementType());
262 return LLVM::LLVMStructType::getLiteral(&getContext(),
263 {elementType, elementType});
266 // Except for signatures, MLIR function types are converted into LLVM
267 // pointer-to-function types.
268 Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
269 return LLVM::LLVMPointerType::get(type.getContext());
272 /// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
273 /// function arguments. Returns an empty container if none of these attributes
274 /// are found in any of the arguments.
275 static void
276 filterByValRefArgAttrs(FunctionOpInterface funcOp,
277 SmallVectorImpl<std::optional<NamedAttribute>> &result) {
278 assert(result.empty() && "Unexpected non-empty output");
279 result.resize(funcOp.getNumArguments(), std::nullopt);
280 bool foundByValByRefAttrs = false;
281 for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
282 for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
283 if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
284 namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
285 foundByValByRefAttrs = true;
286 result[argIdx] = namedAttr;
287 break;
292 if (!foundByValByRefAttrs)
293 result.clear();
296 // Function types are converted to LLVM Function types by recursively converting
297 // argument and result types. If MLIR Function has zero results, the LLVM
298 // Function has one VoidType result. If MLIR Function has more than one result,
299 // they are into an LLVM StructType in their order of appearance.
300 // If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
301 // `llvm.byref` function arguments which are not LLVM pointers are overridden
302 // with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
303 // converted to LLVM pointer types are removed from 'byValRefNonPtrAttrs`.
304 Type LLVMTypeConverter::convertFunctionSignatureImpl(
305 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
306 LLVMTypeConverter::SignatureConversion &result,
307 SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
308 // Select the argument converter depending on the calling convention.
309 useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
310 auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
311 : structFuncArgTypeConverter;
312 // Convert argument types one by one and check for errors.
313 for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
314 SmallVector<Type, 8> converted;
315 if (failed(funcArgConverter(*this, type, converted)))
316 return {};
318 // Rewrite converted type of `llvm.byval` or `llvm.byref` function
319 // argument that was not converted to an LLVM pointer types.
320 if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
321 converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
322 // If the argument was already converted to an LLVM pointer type, we stop
323 // tracking it as it doesn't need more processing.
324 if (isa<LLVM::LLVMPointerType>(converted[0]))
325 (*byValRefNonPtrAttrs)[idx] = std::nullopt;
326 else
327 converted[0] = LLVM::LLVMPointerType::get(&getContext());
330 result.addInputs(idx, converted);
333 // If function does not return anything, create the void result type,
334 // if it returns on element, convert it, otherwise pack the result types into
335 // a struct.
336 Type resultType =
337 funcTy.getNumResults() == 0
338 ? LLVM::LLVMVoidType::get(&getContext())
339 : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
340 if (!resultType)
341 return {};
342 return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
343 isVariadic);
346 Type LLVMTypeConverter::convertFunctionSignature(
347 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
348 LLVMTypeConverter::SignatureConversion &result) const {
349 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
350 result,
351 /*byValRefNonPtrAttrs=*/nullptr);
354 Type LLVMTypeConverter::convertFunctionSignature(
355 FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
356 LLVMTypeConverter::SignatureConversion &result,
357 SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
358 // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
359 // that were not converted to LLVM pointer types will be returned for further
360 // processing.
361 filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
362 auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
363 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
364 result, &byValRefNonPtrAttrs);
367 /// Converts the function type to a C-compatible format, in particular using
368 /// pointers to memref descriptors for arguments.
369 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
370 LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
371 SmallVector<Type, 4> inputs;
373 Type resultType = type.getNumResults() == 0
374 ? LLVM::LLVMVoidType::get(&getContext())
375 : packFunctionResults(type.getResults());
376 if (!resultType)
377 return {};
379 auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
380 auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
381 if (structType) {
382 // Struct types cannot be safely returned via C interface. Make this a
383 // pointer argument, instead.
384 inputs.push_back(ptrType);
385 resultType = LLVM::LLVMVoidType::get(&getContext());
388 for (Type t : type.getInputs()) {
389 auto converted = convertType(t);
390 if (!converted || !LLVM::isCompatibleType(converted))
391 return {};
392 if (isa<MemRefType, UnrankedMemRefType>(t))
393 converted = ptrType;
394 inputs.push_back(converted);
397 return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
400 /// Convert a memref type into a list of LLVM IR types that will form the
401 /// memref descriptor. The result contains the following types:
402 /// 1. The pointer to the allocated data buffer, followed by
403 /// 2. The pointer to the aligned data buffer, followed by
404 /// 3. A lowered `index`-type integer containing the distance between the
405 /// beginning of the buffer and the first element to be accessed through the
406 /// view, followed by
407 /// 4. An array containing as many `index`-type integers as the rank of the
408 /// MemRef: the array represents the size, in number of elements, of the memref
409 /// along the given dimension. For constant MemRef dimensions, the
410 /// corresponding size entry is a constant whose runtime value must match the
411 /// static value, followed by
412 /// 5. A second array containing as many `index`-type integers as the rank of
413 /// the MemRef: the second array represents the "stride" (in tensor abstraction
414 /// sense), i.e. the number of consecutive elements of the underlying buffer.
415 /// TODO: add assertions for the static cases.
417 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
418 /// are expanded into individual index-type elements.
420 /// template <typename Elem, typename Index, size_t Rank>
421 /// struct {
422 /// Elem *allocatedPtr;
423 /// Elem *alignedPtr;
424 /// Index offset;
425 /// Index sizes[Rank]; // omitted when rank == 0
426 /// Index strides[Rank]; // omitted when rank == 0
427 /// };
428 SmallVector<Type, 5>
429 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
430 bool unpackAggregates) const {
431 if (!isStrided(type)) {
432 emitError(
433 UnknownLoc::get(type.getContext()),
434 "conversion to strided form failed either due to non-strided layout "
435 "maps (which should have been normalized away) or other reasons");
436 return {};
439 Type elementType = convertType(type.getElementType());
440 if (!elementType)
441 return {};
443 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
444 if (failed(addressSpace)) {
445 emitError(UnknownLoc::get(type.getContext()),
446 "conversion of memref memory space ")
447 << type.getMemorySpace()
448 << " to integer address space "
449 "failed. Consider adding memory space conversions.";
450 return {};
452 auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
454 auto indexTy = getIndexType();
456 SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
457 auto rank = type.getRank();
458 if (rank == 0)
459 return results;
461 if (unpackAggregates)
462 results.insert(results.end(), 2 * rank, indexTy);
463 else
464 results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
465 return results;
468 unsigned
469 LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
470 const DataLayout &layout) const {
471 // Compute the descriptor size given that of its components indicated above.
472 unsigned space = *getMemRefAddressSpace(type);
473 return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
474 (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
477 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
478 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
479 Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
480 // When converting a MemRefType to a struct with descriptor fields, do not
481 // unpack the `sizes` and `strides` arrays.
482 SmallVector<Type, 5> types =
483 getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
484 if (types.empty())
485 return {};
486 return LLVM::LLVMStructType::getLiteral(&getContext(), types);
489 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
490 /// that will form the unranked memref descriptor. In particular, the fields
491 /// for an unranked memref descriptor are:
492 /// 1. index-typed rank, the dynamic rank of this MemRef
493 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
494 /// stack allocated (alloca) copy of a MemRef descriptor that got casted to
495 /// be unranked.
496 SmallVector<Type, 2>
497 LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
498 return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())};
501 unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
502 UnrankedMemRefType type, const DataLayout &layout) const {
503 // Compute the descriptor size given that of its components indicated above.
504 unsigned space = *getMemRefAddressSpace(type);
505 return layout.getTypeSize(getIndexType()) +
506 llvm::divideCeil(getPointerBitwidth(space), 8);
509 Type LLVMTypeConverter::convertUnrankedMemRefType(
510 UnrankedMemRefType type) const {
511 if (!convertType(type.getElementType()))
512 return {};
513 return LLVM::LLVMStructType::getLiteral(&getContext(),
514 getUnrankedMemRefDescriptorFields());
517 FailureOr<unsigned>
518 LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
519 if (!type.getMemorySpace()) // Default memory space -> 0.
520 return 0;
521 std::optional<Attribute> converted =
522 convertTypeAttribute(type, type.getMemorySpace());
523 if (!converted)
524 return failure();
525 if (!(*converted)) // Conversion to default is 0.
526 return 0;
527 if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
528 if (explicitSpace.getType().isIndex() ||
529 explicitSpace.getType().isSignlessInteger())
530 return explicitSpace.getInt();
532 return failure();
535 // Check if a memref type can be converted to a bare pointer.
536 bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
537 if (isa<UnrankedMemRefType>(type))
538 // Unranked memref is not supported in the bare pointer calling convention.
539 return false;
541 // Check that the memref has static shape, strides and offset. Otherwise, it
542 // cannot be lowered to a bare pointer.
543 auto memrefTy = cast<MemRefType>(type);
544 if (!memrefTy.hasStaticShape())
545 return false;
547 int64_t offset = 0;
548 SmallVector<int64_t, 4> strides;
549 if (failed(getStridesAndOffset(memrefTy, strides, offset)))
550 return false;
552 for (int64_t stride : strides)
553 if (ShapedType::isDynamic(stride))
554 return false;
556 return !ShapedType::isDynamic(offset);
559 /// Convert a memref type to a bare pointer to the memref element type.
560 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
561 if (!canConvertToBarePtr(type))
562 return {};
563 Type elementType = convertType(type.getElementType());
564 if (!elementType)
565 return {};
566 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
567 if (failed(addressSpace))
568 return {};
569 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
572 /// Convert an n-D vector type to an LLVM vector type:
573 /// * 0-D `vector<T>` are converted to vector<1xT>
574 /// * 1-D `vector<axT>` remains as is while,
575 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
576 /// `!llvm.array<ax...array<jxvector<kxT>>>`.
577 /// As LLVM supports arrays of scalable vectors, this method will also convert
578 /// n-D scalable vectors provided that only the trailing dim is scalable.
579 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
580 auto elementType = convertType(type.getElementType());
581 if (!elementType)
582 return {};
583 if (type.getShape().empty())
584 return VectorType::get({1}, elementType);
585 Type vectorType = VectorType::get(type.getShape().back(), elementType,
586 type.getScalableDims().back());
587 assert(LLVM::isCompatibleVectorType(vectorType) &&
588 "expected vector type compatible with the LLVM dialect");
589 // For n-D vector types for which a _non-trailing_ dim is scalable,
590 // return a failure. Supporting such cases would require LLVM
591 // to support something akin "scalable arrays" of vectors.
592 if (llvm::is_contained(type.getScalableDims().drop_back(), true))
593 return failure();
594 auto shape = type.getShape();
595 for (int i = shape.size() - 2; i >= 0; --i)
596 vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
597 return vectorType;
600 /// Convert a type in the context of the default or bare pointer calling
601 /// convention. Calling convention sensitive types, such as MemRefType and
602 /// UnrankedMemRefType, are converted following the specific rules for the
603 /// calling convention. Calling convention independent types are converted
604 /// following the default LLVM type conversions.
605 Type LLVMTypeConverter::convertCallingConventionType(
606 Type type, bool useBarePtrCallConv) const {
607 if (useBarePtrCallConv)
608 if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
609 return convertMemRefToBarePtr(memrefTy);
611 return convertType(type);
614 /// Promote the bare pointers in 'values' that resulted from memrefs to
615 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
616 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
617 void LLVMTypeConverter::promoteBarePtrsToDescriptors(
618 ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
619 SmallVectorImpl<Value> &values) const {
620 assert(stdTypes.size() == values.size() &&
621 "The number of types and values doesn't match");
622 for (unsigned i = 0, end = values.size(); i < end; ++i)
623 if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
624 values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
625 memrefTy, values[i]);
628 /// Convert a non-empty list of types of values produced by an operation into an
629 /// LLVM-compatible type. In particular, if more than one value is
630 /// produced, create a literal structure with elements that correspond to each
631 /// of the types converted with `convertType`.
632 Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
633 assert(!types.empty() && "expected non-empty list of type");
634 if (types.size() == 1)
635 return convertType(types[0]);
637 SmallVector<Type> resultTypes;
638 resultTypes.reserve(types.size());
639 for (Type type : types) {
640 Type converted = convertType(type);
641 if (!converted || !LLVM::isCompatibleType(converted))
642 return {};
643 resultTypes.push_back(converted);
646 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
649 /// Convert a non-empty list of types to be returned from a function into an
650 /// LLVM-compatible type. In particular, if more than one value is returned,
651 /// create an LLVM dialect structure type with elements that correspond to each
652 /// of the types converted with `convertCallingConventionType`.
653 Type LLVMTypeConverter::packFunctionResults(TypeRange types,
654 bool useBarePtrCallConv) const {
655 assert(!types.empty() && "expected non-empty list of type");
657 useBarePtrCallConv |= options.useBarePtrCallConv;
658 if (types.size() == 1)
659 return convertCallingConventionType(types.front(), useBarePtrCallConv);
661 SmallVector<Type> resultTypes;
662 resultTypes.reserve(types.size());
663 for (auto t : types) {
664 auto converted = convertCallingConventionType(t, useBarePtrCallConv);
665 if (!converted || !LLVM::isCompatibleType(converted))
666 return {};
667 resultTypes.push_back(converted);
670 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
673 Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
674 OpBuilder &builder) const {
675 // Alloca with proper alignment. We do not expect optimizations of this
676 // alloca op and so we omit allocating at the entry block.
677 auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
678 Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
679 builder.getIndexAttr(1));
680 Value allocated =
681 builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
682 // Store into the alloca'ed descriptor.
683 builder.create<LLVM::StoreOp>(loc, operand, allocated);
684 return allocated;
687 SmallVector<Value, 4>
688 LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
689 ValueRange operands, OpBuilder &builder,
690 bool useBarePtrCallConv) const {
691 SmallVector<Value, 4> promotedOperands;
692 promotedOperands.reserve(operands.size());
693 useBarePtrCallConv |= options.useBarePtrCallConv;
694 for (auto it : llvm::zip(opOperands, operands)) {
695 auto operand = std::get<0>(it);
696 auto llvmOperand = std::get<1>(it);
698 if (useBarePtrCallConv) {
699 // For the bare-ptr calling convention, we only have to extract the
700 // aligned pointer of a memref.
701 if (dyn_cast<MemRefType>(operand.getType())) {
702 MemRefDescriptor desc(llvmOperand);
703 llvmOperand = desc.alignedPtr(builder, loc);
704 } else if (isa<UnrankedMemRefType>(operand.getType())) {
705 llvm_unreachable("Unranked memrefs are not supported");
707 } else {
708 if (isa<UnrankedMemRefType>(operand.getType())) {
709 UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
710 promotedOperands);
711 continue;
713 if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
714 MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
715 promotedOperands);
716 continue;
720 promotedOperands.push_back(llvmOperand);
722 return promotedOperands;
725 /// Callback to convert function argument types. It converts a MemRef function
726 /// argument to a list of non-aggregate types containing descriptor
727 /// information, and an UnrankedmemRef function argument to a list containing
728 /// the rank and a pointer to a descriptor struct.
729 LogicalResult
730 mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
731 SmallVectorImpl<Type> &result) {
732 if (auto memref = dyn_cast<MemRefType>(type)) {
733 // In signatures, Memref descriptors are expanded into lists of
734 // non-aggregate values.
735 auto converted =
736 converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
737 if (converted.empty())
738 return failure();
739 result.append(converted.begin(), converted.end());
740 return success();
742 if (isa<UnrankedMemRefType>(type)) {
743 auto converted = converter.getUnrankedMemRefDescriptorFields();
744 if (converted.empty())
745 return failure();
746 result.append(converted.begin(), converted.end());
747 return success();
749 auto converted = converter.convertType(type);
750 if (!converted)
751 return failure();
752 result.push_back(converted);
753 return success();
756 /// Callback to convert function argument types. It converts MemRef function
757 /// arguments to bare pointers to the MemRef element type.
758 LogicalResult
759 mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
760 SmallVectorImpl<Type> &result) {
761 auto llvmTy = converter.convertCallingConventionType(
762 type, /*useBarePointerCallConv=*/true);
763 if (!llvmTy)
764 return failure();
766 result.push_back(llvmTy);
767 return success();