//===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "MemRefDescriptor.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Threading.h"

#include <memory>
#include <mutex>
#include <optional>
#include <shared_mutex>
#include <utility>

using namespace mlir;
22 SmallVector
<Type
> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
24 // Most of the time, the entry already exists in the map.
25 std::shared_lock
<decltype(callStackMutex
)> lock(callStackMutex
,
27 if (getContext().isMultithreadingEnabled())
29 auto recursiveStack
= conversionCallStack
.find(llvm::get_threadid());
30 if (recursiveStack
!= conversionCallStack
.end())
31 return *recursiveStack
->second
;
34 // First time this thread gets here, we have to get an exclusive access to
36 std::unique_lock
<decltype(callStackMutex
)> lock(callStackMutex
);
37 auto recursiveStackInserted
= conversionCallStack
.insert(std::make_pair(
38 llvm::get_threadid(), std::make_unique
<SmallVector
<Type
>>()));
39 return *recursiveStackInserted
.first
->second
;
42 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
43 LLVMTypeConverter::LLVMTypeConverter(MLIRContext
*ctx
,
44 const DataLayoutAnalysis
*analysis
)
45 : LLVMTypeConverter(ctx
, LowerToLLVMOptions(ctx
), analysis
) {}
47 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
48 LLVMTypeConverter::LLVMTypeConverter(MLIRContext
*ctx
,
49 const LowerToLLVMOptions
&options
,
50 const DataLayoutAnalysis
*analysis
)
51 : llvmDialect(ctx
->getOrLoadDialect
<LLVM::LLVMDialect
>()), options(options
),
52 dataLayoutAnalysis(analysis
) {
53 assert(llvmDialect
&& "LLVM IR dialect is not registered");
55 // Register conversions for the builtin types.
56 addConversion([&](ComplexType type
) { return convertComplexType(type
); });
57 addConversion([&](FloatType type
) { return convertFloatType(type
); });
58 addConversion([&](FunctionType type
) { return convertFunctionType(type
); });
59 addConversion([&](IndexType type
) { return convertIndexType(type
); });
60 addConversion([&](IntegerType type
) { return convertIntegerType(type
); });
61 addConversion([&](MemRefType type
) { return convertMemRefType(type
); });
63 [&](UnrankedMemRefType type
) { return convertUnrankedMemRefType(type
); });
64 addConversion([&](VectorType type
) -> std::optional
<Type
> {
65 FailureOr
<Type
> llvmType
= convertVectorType(type
);
71 // LLVM-compatible types are legal, so add a pass-through conversion. Do this
72 // before the conversions below since conversions are attempted in reverse
73 // order and those should take priority.
74 addConversion([](Type type
) {
75 return LLVM::isCompatibleType(type
) ? std::optional
<Type
>(type
)
79 addConversion([&](LLVM::LLVMStructType type
, SmallVectorImpl
<Type
> &results
)
80 -> std::optional
<LogicalResult
> {
81 // Fastpath for types that won't be converted by this callback anyway.
82 if (LLVM::isCompatibleType(type
)) {
83 results
.push_back(type
);
87 if (type
.isIdentified()) {
88 auto convertedType
= LLVM::LLVMStructType::getIdentified(
89 type
.getContext(), ("_Converted." + type
.getName()).str());
91 SmallVectorImpl
<Type
> &recursiveStack
= getCurrentThreadRecursiveStack();
92 if (llvm::count(recursiveStack
, type
)) {
93 results
.push_back(convertedType
);
96 recursiveStack
.push_back(type
);
97 auto popConversionCallStack
= llvm::make_scope_exit(
98 [&recursiveStack
]() { recursiveStack
.pop_back(); });
100 SmallVector
<Type
> convertedElemTypes
;
101 convertedElemTypes
.reserve(type
.getBody().size());
102 if (failed(convertTypes(type
.getBody(), convertedElemTypes
)))
105 // If the converted type has not been initialized yet, just set its body
106 // to be the converted arguments and return.
107 if (!convertedType
.isInitialized()) {
109 convertedType
.setBody(convertedElemTypes
, type
.isPacked()))) {
112 results
.push_back(convertedType
);
116 // If it has been initialized, has the same body and packed bit, just use
117 // it. This ensures that recursive structs keep being recursive rather
118 // than including a non-updated name.
119 if (TypeRange(convertedType
.getBody()) == TypeRange(convertedElemTypes
) &&
120 convertedType
.isPacked() == type
.isPacked()) {
121 results
.push_back(convertedType
);
128 SmallVector
<Type
> convertedSubtypes
;
129 convertedSubtypes
.reserve(type
.getBody().size());
130 if (failed(convertTypes(type
.getBody(), convertedSubtypes
)))
133 results
.push_back(LLVM::LLVMStructType::getLiteral(
134 type
.getContext(), convertedSubtypes
, type
.isPacked()));
137 addConversion([&](LLVM::LLVMArrayType type
) -> std::optional
<Type
> {
138 if (auto element
= convertType(type
.getElementType()))
139 return LLVM::LLVMArrayType::get(element
, type
.getNumElements());
142 addConversion([&](LLVM::LLVMFunctionType type
) -> std::optional
<Type
> {
143 Type convertedResType
= convertType(type
.getReturnType());
144 if (!convertedResType
)
147 SmallVector
<Type
> convertedArgTypes
;
148 convertedArgTypes
.reserve(type
.getNumParams());
149 if (failed(convertTypes(type
.getParams(), convertedArgTypes
)))
152 return LLVM::LLVMFunctionType::get(convertedResType
, convertedArgTypes
,
156 // Helper function that checks if the given value range is a bare pointer.
157 auto isBarePointer
= [](ValueRange values
) {
158 return values
.size() == 1 &&
159 isa
<LLVM::LLVMPointerType
>(values
.front().getType());
162 // Argument materializations convert from the new block argument types
163 // (multiple SSA values that make up a memref descriptor) back to the
164 // original block argument type. The dialect conversion framework will then
165 // insert a target materialization from the original block argument type to
167 addArgumentMaterialization([&](OpBuilder
&builder
,
168 UnrankedMemRefType resultType
,
169 ValueRange inputs
, Location loc
) {
170 // Note: Bare pointers are not supported for unranked memrefs because a
171 // memref descriptor cannot be built just from a bare pointer.
172 if (TypeRange(inputs
) != getUnrankedMemRefDescriptorFields())
175 UnrankedMemRefDescriptor::pack(builder
, loc
, *this, resultType
, inputs
);
176 // An argument materialization must return a value of type
177 // `resultType`, so insert a cast from the memref descriptor type
178 // (!llvm.struct) to the original memref type.
179 return builder
.create
<UnrealizedConversionCastOp
>(loc
, resultType
, desc
)
182 addArgumentMaterialization([&](OpBuilder
&builder
, MemRefType resultType
,
183 ValueRange inputs
, Location loc
) {
185 if (isBarePointer(inputs
)) {
186 desc
= MemRefDescriptor::fromStaticShape(builder
, loc
, *this, resultType
,
188 } else if (TypeRange(inputs
) ==
189 getMemRefDescriptorFields(resultType
,
190 /*unpackAggregates=*/true)) {
191 desc
= MemRefDescriptor::pack(builder
, loc
, *this, resultType
, inputs
);
193 // The inputs are neither a bare pointer nor an unpacked memref
194 // descriptor. This materialization function cannot be used.
197 // An argument materialization must return a value of type `resultType`,
198 // so insert a cast from the memref descriptor type (!llvm.struct) to the
199 // original memref type.
200 return builder
.create
<UnrealizedConversionCastOp
>(loc
, resultType
, desc
)
203 // Add generic source and target materializations to handle cases where
204 // non-LLVM types persist after an LLVM conversion.
205 addSourceMaterialization([&](OpBuilder
&builder
, Type resultType
,
206 ValueRange inputs
, Location loc
) {
207 if (inputs
.size() != 1)
210 return builder
.create
<UnrealizedConversionCastOp
>(loc
, resultType
, inputs
)
213 addTargetMaterialization([&](OpBuilder
&builder
, Type resultType
,
214 ValueRange inputs
, Location loc
) {
215 if (inputs
.size() != 1)
218 return builder
.create
<UnrealizedConversionCastOp
>(loc
, resultType
, inputs
)
222 // Integer memory spaces map to themselves.
223 addTypeAttributeConversion(
224 [](BaseMemRefType memref
, IntegerAttr addrspace
) { return addrspace
; });
227 /// Returns the MLIR context.
228 MLIRContext
&LLVMTypeConverter::getContext() const {
229 return *getDialect()->getContext();
232 Type
LLVMTypeConverter::getIndexType() const {
233 return IntegerType::get(&getContext(), getIndexTypeBitwidth());
236 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace
) const {
237 return options
.dataLayout
.getPointerSizeInBits(addressSpace
);
240 Type
LLVMTypeConverter::convertIndexType(IndexType type
) const {
241 return getIndexType();
244 Type
LLVMTypeConverter::convertIntegerType(IntegerType type
) const {
245 return IntegerType::get(&getContext(), type
.getWidth());
248 Type
LLVMTypeConverter::convertFloatType(FloatType type
) const {
249 if (type
.isFloat8E5M2() || type
.isFloat8E4M3() || type
.isFloat8E4M3FN() ||
250 type
.isFloat8E5M2FNUZ() || type
.isFloat8E4M3FNUZ() ||
251 type
.isFloat8E4M3B11FNUZ() || type
.isFloat8E3M4() ||
252 type
.isFloat4E2M1FN() || type
.isFloat6E2M3FN() || type
.isFloat6E3M2FN() ||
253 type
.isFloat8E8M0FNU())
254 return IntegerType::get(&getContext(), type
.getWidth());
258 // Convert a `ComplexType` to an LLVM type. The result is a complex number
259 // struct with entries for the
260 // 1. real part and for the
261 // 2. imaginary part.
262 Type
LLVMTypeConverter::convertComplexType(ComplexType type
) const {
263 auto elementType
= convertType(type
.getElementType());
264 return LLVM::LLVMStructType::getLiteral(&getContext(),
265 {elementType
, elementType
});
268 // Except for signatures, MLIR function types are converted into LLVM
269 // pointer-to-function types.
270 Type
LLVMTypeConverter::convertFunctionType(FunctionType type
) const {
271 return LLVM::LLVMPointerType::get(type
.getContext());
274 /// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
275 /// function arguments. Returns an empty container if none of these attributes
276 /// are found in any of the arguments.
278 filterByValRefArgAttrs(FunctionOpInterface funcOp
,
279 SmallVectorImpl
<std::optional
<NamedAttribute
>> &result
) {
280 assert(result
.empty() && "Unexpected non-empty output");
281 result
.resize(funcOp
.getNumArguments(), std::nullopt
);
282 bool foundByValByRefAttrs
= false;
283 for (int argIdx
: llvm::seq(funcOp
.getNumArguments())) {
284 for (NamedAttribute namedAttr
: funcOp
.getArgAttrs(argIdx
)) {
285 if ((namedAttr
.getName() == LLVM::LLVMDialect::getByValAttrName() ||
286 namedAttr
.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
287 foundByValByRefAttrs
= true;
288 result
[argIdx
] = namedAttr
;
294 if (!foundByValByRefAttrs
)
298 // Function types are converted to LLVM Function types by recursively converting
299 // argument and result types. If MLIR Function has zero results, the LLVM
300 // Function has one VoidType result. If MLIR Function has more than one result,
301 // they are into an LLVM StructType in their order of appearance.
302 // If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
303 // `llvm.byref` function arguments which are not LLVM pointers are overridden
304 // with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
305 // converted to LLVM pointer types are removed from 'byValRefNonPtrAttrs`.
306 Type
LLVMTypeConverter::convertFunctionSignatureImpl(
307 FunctionType funcTy
, bool isVariadic
, bool useBarePtrCallConv
,
308 LLVMTypeConverter::SignatureConversion
&result
,
309 SmallVectorImpl
<std::optional
<NamedAttribute
>> *byValRefNonPtrAttrs
) const {
310 // Select the argument converter depending on the calling convention.
311 useBarePtrCallConv
= useBarePtrCallConv
|| options
.useBarePtrCallConv
;
312 auto funcArgConverter
= useBarePtrCallConv
? barePtrFuncArgTypeConverter
313 : structFuncArgTypeConverter
;
314 // Convert argument types one by one and check for errors.
315 for (auto [idx
, type
] : llvm::enumerate(funcTy
.getInputs())) {
316 SmallVector
<Type
, 8> converted
;
317 if (failed(funcArgConverter(*this, type
, converted
)))
320 // Rewrite converted type of `llvm.byval` or `llvm.byref` function
321 // argument that was not converted to an LLVM pointer types.
322 if (byValRefNonPtrAttrs
!= nullptr && !byValRefNonPtrAttrs
->empty() &&
323 converted
.size() == 1 && (*byValRefNonPtrAttrs
)[idx
].has_value()) {
324 // If the argument was already converted to an LLVM pointer type, we stop
325 // tracking it as it doesn't need more processing.
326 if (isa
<LLVM::LLVMPointerType
>(converted
[0]))
327 (*byValRefNonPtrAttrs
)[idx
] = std::nullopt
;
329 converted
[0] = LLVM::LLVMPointerType::get(&getContext());
332 result
.addInputs(idx
, converted
);
335 // If function does not return anything, create the void result type,
336 // if it returns on element, convert it, otherwise pack the result types into
339 funcTy
.getNumResults() == 0
340 ? LLVM::LLVMVoidType::get(&getContext())
341 : packFunctionResults(funcTy
.getResults(), useBarePtrCallConv
);
344 return LLVM::LLVMFunctionType::get(resultType
, result
.getConvertedTypes(),
348 Type
LLVMTypeConverter::convertFunctionSignature(
349 FunctionType funcTy
, bool isVariadic
, bool useBarePtrCallConv
,
350 LLVMTypeConverter::SignatureConversion
&result
) const {
351 return convertFunctionSignatureImpl(funcTy
, isVariadic
, useBarePtrCallConv
,
353 /*byValRefNonPtrAttrs=*/nullptr);
356 Type
LLVMTypeConverter::convertFunctionSignature(
357 FunctionOpInterface funcOp
, bool isVariadic
, bool useBarePtrCallConv
,
358 LLVMTypeConverter::SignatureConversion
&result
,
359 SmallVectorImpl
<std::optional
<NamedAttribute
>> &byValRefNonPtrAttrs
) const {
360 // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
361 // that were not converted to LLVM pointer types will be returned for further
363 filterByValRefArgAttrs(funcOp
, byValRefNonPtrAttrs
);
364 auto funcTy
= cast
<FunctionType
>(funcOp
.getFunctionType());
365 return convertFunctionSignatureImpl(funcTy
, isVariadic
, useBarePtrCallConv
,
366 result
, &byValRefNonPtrAttrs
);
369 /// Converts the function type to a C-compatible format, in particular using
370 /// pointers to memref descriptors for arguments.
371 std::pair
<LLVM::LLVMFunctionType
, LLVM::LLVMStructType
>
372 LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type
) const {
373 SmallVector
<Type
, 4> inputs
;
375 Type resultType
= type
.getNumResults() == 0
376 ? LLVM::LLVMVoidType::get(&getContext())
377 : packFunctionResults(type
.getResults());
381 auto ptrType
= LLVM::LLVMPointerType::get(type
.getContext());
382 auto structType
= dyn_cast
<LLVM::LLVMStructType
>(resultType
);
384 // Struct types cannot be safely returned via C interface. Make this a
385 // pointer argument, instead.
386 inputs
.push_back(ptrType
);
387 resultType
= LLVM::LLVMVoidType::get(&getContext());
390 for (Type t
: type
.getInputs()) {
391 auto converted
= convertType(t
);
392 if (!converted
|| !LLVM::isCompatibleType(converted
))
394 if (isa
<MemRefType
, UnrankedMemRefType
>(t
))
396 inputs
.push_back(converted
);
399 return {LLVM::LLVMFunctionType::get(resultType
, inputs
), structType
};
402 /// Convert a memref type into a list of LLVM IR types that will form the
403 /// memref descriptor. The result contains the following types:
404 /// 1. The pointer to the allocated data buffer, followed by
405 /// 2. The pointer to the aligned data buffer, followed by
406 /// 3. A lowered `index`-type integer containing the distance between the
407 /// beginning of the buffer and the first element to be accessed through the
408 /// view, followed by
409 /// 4. An array containing as many `index`-type integers as the rank of the
410 /// MemRef: the array represents the size, in number of elements, of the memref
411 /// along the given dimension. For constant MemRef dimensions, the
412 /// corresponding size entry is a constant whose runtime value must match the
413 /// static value, followed by
414 /// 5. A second array containing as many `index`-type integers as the rank of
415 /// the MemRef: the second array represents the "stride" (in tensor abstraction
416 /// sense), i.e. the number of consecutive elements of the underlying buffer.
417 /// TODO: add assertions for the static cases.
419 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
420 /// are expanded into individual index-type elements.
422 /// template <typename Elem, typename Index, size_t Rank>
424 /// Elem *allocatedPtr;
425 /// Elem *alignedPtr;
427 /// Index sizes[Rank]; // omitted when rank == 0
428 /// Index strides[Rank]; // omitted when rank == 0
431 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type
,
432 bool unpackAggregates
) const {
433 if (!isStrided(type
)) {
435 UnknownLoc::get(type
.getContext()),
436 "conversion to strided form failed either due to non-strided layout "
437 "maps (which should have been normalized away) or other reasons");
441 Type elementType
= convertType(type
.getElementType());
445 FailureOr
<unsigned> addressSpace
= getMemRefAddressSpace(type
);
446 if (failed(addressSpace
)) {
447 emitError(UnknownLoc::get(type
.getContext()),
448 "conversion of memref memory space ")
449 << type
.getMemorySpace()
450 << " to integer address space "
451 "failed. Consider adding memory space conversions.";
454 auto ptrTy
= LLVM::LLVMPointerType::get(type
.getContext(), *addressSpace
);
456 auto indexTy
= getIndexType();
458 SmallVector
<Type
, 5> results
= {ptrTy
, ptrTy
, indexTy
};
459 auto rank
= type
.getRank();
463 if (unpackAggregates
)
464 results
.insert(results
.end(), 2 * rank
, indexTy
);
466 results
.insert(results
.end(), 2, LLVM::LLVMArrayType::get(indexTy
, rank
));
471 LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type
,
472 const DataLayout
&layout
) const {
473 // Compute the descriptor size given that of its components indicated above.
474 unsigned space
= *getMemRefAddressSpace(type
);
475 return 2 * llvm::divideCeil(getPointerBitwidth(space
), 8) +
476 (1 + 2 * type
.getRank()) * layout
.getTypeSize(getIndexType());
479 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
480 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
481 Type
LLVMTypeConverter::convertMemRefType(MemRefType type
) const {
482 // When converting a MemRefType to a struct with descriptor fields, do not
483 // unpack the `sizes` and `strides` arrays.
484 SmallVector
<Type
, 5> types
=
485 getMemRefDescriptorFields(type
, /*unpackAggregates=*/false);
488 return LLVM::LLVMStructType::getLiteral(&getContext(), types
);
491 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
492 /// that will form the unranked memref descriptor. In particular, the fields
493 /// for an unranked memref descriptor are:
494 /// 1. index-typed rank, the dynamic rank of this MemRef
495 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
496 /// stack allocated (alloca) copy of a MemRef descriptor that got casted to
499 LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
500 return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())};
503 unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
504 UnrankedMemRefType type
, const DataLayout
&layout
) const {
505 // Compute the descriptor size given that of its components indicated above.
506 unsigned space
= *getMemRefAddressSpace(type
);
507 return layout
.getTypeSize(getIndexType()) +
508 llvm::divideCeil(getPointerBitwidth(space
), 8);
511 Type
LLVMTypeConverter::convertUnrankedMemRefType(
512 UnrankedMemRefType type
) const {
513 if (!convertType(type
.getElementType()))
515 return LLVM::LLVMStructType::getLiteral(&getContext(),
516 getUnrankedMemRefDescriptorFields());
520 LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type
) const {
521 if (!type
.getMemorySpace()) // Default memory space -> 0.
523 std::optional
<Attribute
> converted
=
524 convertTypeAttribute(type
, type
.getMemorySpace());
527 if (!(*converted
)) // Conversion to default is 0.
529 if (auto explicitSpace
= dyn_cast_if_present
<IntegerAttr
>(*converted
)) {
530 if (explicitSpace
.getType().isIndex() ||
531 explicitSpace
.getType().isSignlessInteger())
532 return explicitSpace
.getInt();
537 // Check if a memref type can be converted to a bare pointer.
538 bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type
) {
539 if (isa
<UnrankedMemRefType
>(type
))
540 // Unranked memref is not supported in the bare pointer calling convention.
543 // Check that the memref has static shape, strides and offset. Otherwise, it
544 // cannot be lowered to a bare pointer.
545 auto memrefTy
= cast
<MemRefType
>(type
);
546 if (!memrefTy
.hasStaticShape())
550 SmallVector
<int64_t, 4> strides
;
551 if (failed(getStridesAndOffset(memrefTy
, strides
, offset
)))
554 for (int64_t stride
: strides
)
555 if (ShapedType::isDynamic(stride
))
558 return !ShapedType::isDynamic(offset
);
561 /// Convert a memref type to a bare pointer to the memref element type.
562 Type
LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type
) const {
563 if (!canConvertToBarePtr(type
))
565 Type elementType
= convertType(type
.getElementType());
568 FailureOr
<unsigned> addressSpace
= getMemRefAddressSpace(type
);
569 if (failed(addressSpace
))
571 return LLVM::LLVMPointerType::get(type
.getContext(), *addressSpace
);
574 /// Convert an n-D vector type to an LLVM vector type:
575 /// * 0-D `vector<T>` are converted to vector<1xT>
576 /// * 1-D `vector<axT>` remains as is while,
577 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
578 /// `!llvm.array<ax...array<jxvector<kxT>>>`.
579 /// As LLVM supports arrays of scalable vectors, this method will also convert
580 /// n-D scalable vectors provided that only the trailing dim is scalable.
581 FailureOr
<Type
> LLVMTypeConverter::convertVectorType(VectorType type
) const {
582 auto elementType
= convertType(type
.getElementType());
585 if (type
.getShape().empty())
586 return VectorType::get({1}, elementType
);
587 Type vectorType
= VectorType::get(type
.getShape().back(), elementType
,
588 type
.getScalableDims().back());
589 assert(LLVM::isCompatibleVectorType(vectorType
) &&
590 "expected vector type compatible with the LLVM dialect");
591 // For n-D vector types for which a _non-trailing_ dim is scalable,
592 // return a failure. Supporting such cases would require LLVM
593 // to support something akin "scalable arrays" of vectors.
594 if (llvm::is_contained(type
.getScalableDims().drop_back(), true))
596 auto shape
= type
.getShape();
597 for (int i
= shape
.size() - 2; i
>= 0; --i
)
598 vectorType
= LLVM::LLVMArrayType::get(vectorType
, shape
[i
]);
602 /// Convert a type in the context of the default or bare pointer calling
603 /// convention. Calling convention sensitive types, such as MemRefType and
604 /// UnrankedMemRefType, are converted following the specific rules for the
605 /// calling convention. Calling convention independent types are converted
606 /// following the default LLVM type conversions.
607 Type
LLVMTypeConverter::convertCallingConventionType(
608 Type type
, bool useBarePtrCallConv
) const {
609 if (useBarePtrCallConv
)
610 if (auto memrefTy
= dyn_cast
<BaseMemRefType
>(type
))
611 return convertMemRefToBarePtr(memrefTy
);
613 return convertType(type
);
616 /// Promote the bare pointers in 'values' that resulted from memrefs to
617 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
618 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
619 void LLVMTypeConverter::promoteBarePtrsToDescriptors(
620 ConversionPatternRewriter
&rewriter
, Location loc
, ArrayRef
<Type
> stdTypes
,
621 SmallVectorImpl
<Value
> &values
) const {
622 assert(stdTypes
.size() == values
.size() &&
623 "The number of types and values doesn't match");
624 for (unsigned i
= 0, end
= values
.size(); i
< end
; ++i
)
625 if (auto memrefTy
= dyn_cast
<MemRefType
>(stdTypes
[i
]))
626 values
[i
] = MemRefDescriptor::fromStaticShape(rewriter
, loc
, *this,
627 memrefTy
, values
[i
]);
630 /// Convert a non-empty list of types of values produced by an operation into an
631 /// LLVM-compatible type. In particular, if more than one value is
632 /// produced, create a literal structure with elements that correspond to each
633 /// of the types converted with `convertType`.
634 Type
LLVMTypeConverter::packOperationResults(TypeRange types
) const {
635 assert(!types
.empty() && "expected non-empty list of type");
636 if (types
.size() == 1)
637 return convertType(types
[0]);
639 SmallVector
<Type
> resultTypes
;
640 resultTypes
.reserve(types
.size());
641 for (Type type
: types
) {
642 Type converted
= convertType(type
);
643 if (!converted
|| !LLVM::isCompatibleType(converted
))
645 resultTypes
.push_back(converted
);
648 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes
);
651 /// Convert a non-empty list of types to be returned from a function into an
652 /// LLVM-compatible type. In particular, if more than one value is returned,
653 /// create an LLVM dialect structure type with elements that correspond to each
654 /// of the types converted with `convertCallingConventionType`.
655 Type
LLVMTypeConverter::packFunctionResults(TypeRange types
,
656 bool useBarePtrCallConv
) const {
657 assert(!types
.empty() && "expected non-empty list of type");
659 useBarePtrCallConv
|= options
.useBarePtrCallConv
;
660 if (types
.size() == 1)
661 return convertCallingConventionType(types
.front(), useBarePtrCallConv
);
663 SmallVector
<Type
> resultTypes
;
664 resultTypes
.reserve(types
.size());
665 for (auto t
: types
) {
666 auto converted
= convertCallingConventionType(t
, useBarePtrCallConv
);
667 if (!converted
|| !LLVM::isCompatibleType(converted
))
669 resultTypes
.push_back(converted
);
672 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes
);
675 Value
LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc
, Value operand
,
676 OpBuilder
&builder
) const {
677 // Alloca with proper alignment. We do not expect optimizations of this
678 // alloca op and so we omit allocating at the entry block.
679 auto ptrType
= LLVM::LLVMPointerType::get(builder
.getContext());
680 Value one
= builder
.create
<LLVM::ConstantOp
>(loc
, builder
.getI64Type(),
681 builder
.getIndexAttr(1));
683 builder
.create
<LLVM::AllocaOp
>(loc
, ptrType
, operand
.getType(), one
);
684 // Store into the alloca'ed descriptor.
685 builder
.create
<LLVM::StoreOp
>(loc
, operand
, allocated
);
689 SmallVector
<Value
, 4>
690 LLVMTypeConverter::promoteOperands(Location loc
, ValueRange opOperands
,
691 ValueRange operands
, OpBuilder
&builder
,
692 bool useBarePtrCallConv
) const {
693 SmallVector
<Value
, 4> promotedOperands
;
694 promotedOperands
.reserve(operands
.size());
695 useBarePtrCallConv
|= options
.useBarePtrCallConv
;
696 for (auto it
: llvm::zip(opOperands
, operands
)) {
697 auto operand
= std::get
<0>(it
);
698 auto llvmOperand
= std::get
<1>(it
);
700 if (useBarePtrCallConv
) {
701 // For the bare-ptr calling convention, we only have to extract the
702 // aligned pointer of a memref.
703 if (dyn_cast
<MemRefType
>(operand
.getType())) {
704 MemRefDescriptor
desc(llvmOperand
);
705 llvmOperand
= desc
.alignedPtr(builder
, loc
);
706 } else if (isa
<UnrankedMemRefType
>(operand
.getType())) {
707 llvm_unreachable("Unranked memrefs are not supported");
710 if (isa
<UnrankedMemRefType
>(operand
.getType())) {
711 UnrankedMemRefDescriptor::unpack(builder
, loc
, llvmOperand
,
715 if (auto memrefType
= dyn_cast
<MemRefType
>(operand
.getType())) {
716 MemRefDescriptor::unpack(builder
, loc
, llvmOperand
, memrefType
,
722 promotedOperands
.push_back(llvmOperand
);
724 return promotedOperands
;
727 /// Callback to convert function argument types. It converts a MemRef function
728 /// argument to a list of non-aggregate types containing descriptor
729 /// information, and an UnrankedmemRef function argument to a list containing
730 /// the rank and a pointer to a descriptor struct.
732 mlir::structFuncArgTypeConverter(const LLVMTypeConverter
&converter
, Type type
,
733 SmallVectorImpl
<Type
> &result
) {
734 if (auto memref
= dyn_cast
<MemRefType
>(type
)) {
735 // In signatures, Memref descriptors are expanded into lists of
736 // non-aggregate values.
738 converter
.getMemRefDescriptorFields(memref
, /*unpackAggregates=*/true);
739 if (converted
.empty())
741 result
.append(converted
.begin(), converted
.end());
744 if (isa
<UnrankedMemRefType
>(type
)) {
745 auto converted
= converter
.getUnrankedMemRefDescriptorFields();
746 if (converted
.empty())
748 result
.append(converted
.begin(), converted
.end());
751 auto converted
= converter
.convertType(type
);
754 result
.push_back(converted
);
758 /// Callback to convert function argument types. It converts MemRef function
759 /// arguments to bare pointers to the MemRef element type.
761 mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter
&converter
, Type type
,
762 SmallVectorImpl
<Type
> &result
) {
763 auto llvmTy
= converter
.convertCallingConventionType(
764 type
, /*useBarePointerCallConv=*/true);
768 result
.push_back(llvmTy
);