//===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "MemRefDescriptor.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Threading.h"
#include <cassert>
#include <memory>
#include <mutex>
#include <optional>
#include <shared_mutex>

using namespace mlir;
22 SmallVector
<Type
> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
24 // Most of the time, the entry already exists in the map.
25 std::shared_lock
<decltype(callStackMutex
)> lock(callStackMutex
,
27 if (getContext().isMultithreadingEnabled())
29 auto recursiveStack
= conversionCallStack
.find(llvm::get_threadid());
30 if (recursiveStack
!= conversionCallStack
.end())
31 return *recursiveStack
->second
;
34 // First time this thread gets here, we have to get an exclusive access to
36 std::unique_lock
<decltype(callStackMutex
)> lock(callStackMutex
);
37 auto recursiveStackInserted
= conversionCallStack
.insert(std::make_pair(
38 llvm::get_threadid(), std::make_unique
<SmallVector
<Type
>>()));
39 return *recursiveStackInserted
.first
->second
;
42 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
43 LLVMTypeConverter::LLVMTypeConverter(MLIRContext
*ctx
,
44 const DataLayoutAnalysis
*analysis
)
45 : LLVMTypeConverter(ctx
, LowerToLLVMOptions(ctx
), analysis
) {}
47 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
48 LLVMTypeConverter::LLVMTypeConverter(MLIRContext
*ctx
,
49 const LowerToLLVMOptions
&options
,
50 const DataLayoutAnalysis
*analysis
)
51 : llvmDialect(ctx
->getOrLoadDialect
<LLVM::LLVMDialect
>()), options(options
),
52 dataLayoutAnalysis(analysis
) {
53 assert(llvmDialect
&& "LLVM IR dialect is not registered");
55 // Register conversions for the builtin types.
56 addConversion([&](ComplexType type
) { return convertComplexType(type
); });
57 addConversion([&](FloatType type
) { return convertFloatType(type
); });
58 addConversion([&](FunctionType type
) { return convertFunctionType(type
); });
59 addConversion([&](IndexType type
) { return convertIndexType(type
); });
60 addConversion([&](IntegerType type
) { return convertIntegerType(type
); });
61 addConversion([&](MemRefType type
) { return convertMemRefType(type
); });
63 [&](UnrankedMemRefType type
) { return convertUnrankedMemRefType(type
); });
64 addConversion([&](VectorType type
) -> std::optional
<Type
> {
65 FailureOr
<Type
> llvmType
= convertVectorType(type
);
71 // LLVM-compatible types are legal, so add a pass-through conversion. Do this
72 // before the conversions below since conversions are attempted in reverse
73 // order and those should take priority.
74 addConversion([](Type type
) {
75 return LLVM::isCompatibleType(type
) ? std::optional
<Type
>(type
)
79 addConversion([&](LLVM::LLVMStructType type
, SmallVectorImpl
<Type
> &results
)
80 -> std::optional
<LogicalResult
> {
81 // Fastpath for types that won't be converted by this callback anyway.
82 if (LLVM::isCompatibleType(type
)) {
83 results
.push_back(type
);
87 if (type
.isIdentified()) {
88 auto convertedType
= LLVM::LLVMStructType::getIdentified(
89 type
.getContext(), ("_Converted." + type
.getName()).str());
91 SmallVectorImpl
<Type
> &recursiveStack
= getCurrentThreadRecursiveStack();
92 if (llvm::count(recursiveStack
, type
)) {
93 results
.push_back(convertedType
);
96 recursiveStack
.push_back(type
);
97 auto popConversionCallStack
= llvm::make_scope_exit(
98 [&recursiveStack
]() { recursiveStack
.pop_back(); });
100 SmallVector
<Type
> convertedElemTypes
;
101 convertedElemTypes
.reserve(type
.getBody().size());
102 if (failed(convertTypes(type
.getBody(), convertedElemTypes
)))
105 // If the converted type has not been initialized yet, just set its body
106 // to be the converted arguments and return.
107 if (!convertedType
.isInitialized()) {
109 convertedType
.setBody(convertedElemTypes
, type
.isPacked()))) {
112 results
.push_back(convertedType
);
116 // If it has been initialized, has the same body and packed bit, just use
117 // it. This ensures that recursive structs keep being recursive rather
118 // than including a non-updated name.
119 if (TypeRange(convertedType
.getBody()) == TypeRange(convertedElemTypes
) &&
120 convertedType
.isPacked() == type
.isPacked()) {
121 results
.push_back(convertedType
);
128 SmallVector
<Type
> convertedSubtypes
;
129 convertedSubtypes
.reserve(type
.getBody().size());
130 if (failed(convertTypes(type
.getBody(), convertedSubtypes
)))
133 results
.push_back(LLVM::LLVMStructType::getLiteral(
134 type
.getContext(), convertedSubtypes
, type
.isPacked()));
137 addConversion([&](LLVM::LLVMArrayType type
) -> std::optional
<Type
> {
138 if (auto element
= convertType(type
.getElementType()))
139 return LLVM::LLVMArrayType::get(element
, type
.getNumElements());
142 addConversion([&](LLVM::LLVMFunctionType type
) -> std::optional
<Type
> {
143 Type convertedResType
= convertType(type
.getReturnType());
144 if (!convertedResType
)
147 SmallVector
<Type
> convertedArgTypes
;
148 convertedArgTypes
.reserve(type
.getNumParams());
149 if (failed(convertTypes(type
.getParams(), convertedArgTypes
)))
152 return LLVM::LLVMFunctionType::get(convertedResType
, convertedArgTypes
,
156 // Argument materializations convert from the new block argument types
157 // (multiple SSA values that make up a memref descriptor) back to the
158 // original block argument type. The dialect conversion framework will then
159 // insert a target materialization from the original block argument type to
161 addArgumentMaterialization([&](OpBuilder
&builder
,
162 UnrankedMemRefType resultType
,
163 ValueRange inputs
, Location loc
) {
164 if (inputs
.size() == 1) {
165 // Bare pointers are not supported for unranked memrefs because a
166 // memref descriptor cannot be built just from a bare pointer.
170 UnrankedMemRefDescriptor::pack(builder
, loc
, *this, resultType
, inputs
);
171 // An argument materialization must return a value of type
172 // `resultType`, so insert a cast from the memref descriptor type
173 // (!llvm.struct) to the original memref type.
174 return builder
.create
<UnrealizedConversionCastOp
>(loc
, resultType
, desc
)
177 addArgumentMaterialization([&](OpBuilder
&builder
, MemRefType resultType
,
178 ValueRange inputs
, Location loc
) {
180 if (inputs
.size() == 1) {
181 // This is a bare pointer. We allow bare pointers only for function entry
183 BlockArgument barePtr
= dyn_cast
<BlockArgument
>(inputs
.front());
186 Block
*block
= barePtr
.getOwner();
187 if (!block
->isEntryBlock() ||
188 !isa
<FunctionOpInterface
>(block
->getParentOp()))
190 desc
= MemRefDescriptor::fromStaticShape(builder
, loc
, *this, resultType
,
193 desc
= MemRefDescriptor::pack(builder
, loc
, *this, resultType
, inputs
);
195 // An argument materialization must return a value of type `resultType`,
196 // so insert a cast from the memref descriptor type (!llvm.struct) to the
197 // original memref type.
198 return builder
.create
<UnrealizedConversionCastOp
>(loc
, resultType
, desc
)
201 // Add generic source and target materializations to handle cases where
202 // non-LLVM types persist after an LLVM conversion.
203 addSourceMaterialization([&](OpBuilder
&builder
, Type resultType
,
204 ValueRange inputs
, Location loc
) {
205 if (inputs
.size() != 1)
208 return builder
.create
<UnrealizedConversionCastOp
>(loc
, resultType
, inputs
)
211 addTargetMaterialization([&](OpBuilder
&builder
, Type resultType
,
212 ValueRange inputs
, Location loc
) {
213 if (inputs
.size() != 1)
216 return builder
.create
<UnrealizedConversionCastOp
>(loc
, resultType
, inputs
)
220 // Integer memory spaces map to themselves.
221 addTypeAttributeConversion(
222 [](BaseMemRefType memref
, IntegerAttr addrspace
) { return addrspace
; });
225 /// Returns the MLIR context.
226 MLIRContext
&LLVMTypeConverter::getContext() const {
227 return *getDialect()->getContext();
230 Type
LLVMTypeConverter::getIndexType() const {
231 return IntegerType::get(&getContext(), getIndexTypeBitwidth());
234 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace
) const {
235 return options
.dataLayout
.getPointerSizeInBits(addressSpace
);
238 Type
LLVMTypeConverter::convertIndexType(IndexType type
) const {
239 return getIndexType();
242 Type
LLVMTypeConverter::convertIntegerType(IntegerType type
) const {
243 return IntegerType::get(&getContext(), type
.getWidth());
246 Type
LLVMTypeConverter::convertFloatType(FloatType type
) const {
247 if (type
.isFloat8E5M2() || type
.isFloat8E4M3() || type
.isFloat8E4M3FN() ||
248 type
.isFloat8E5M2FNUZ() || type
.isFloat8E4M3FNUZ() ||
249 type
.isFloat8E4M3B11FNUZ() || type
.isFloat8E3M4() ||
250 type
.isFloat4E2M1FN() || type
.isFloat6E2M3FN() || type
.isFloat6E3M2FN() ||
251 type
.isFloat8E8M0FNU())
252 return IntegerType::get(&getContext(), type
.getWidth());
256 // Convert a `ComplexType` to an LLVM type. The result is a complex number
257 // struct with entries for the
258 // 1. real part and for the
259 // 2. imaginary part.
260 Type
LLVMTypeConverter::convertComplexType(ComplexType type
) const {
261 auto elementType
= convertType(type
.getElementType());
262 return LLVM::LLVMStructType::getLiteral(&getContext(),
263 {elementType
, elementType
});
266 // Except for signatures, MLIR function types are converted into LLVM
267 // pointer-to-function types.
268 Type
LLVMTypeConverter::convertFunctionType(FunctionType type
) const {
269 return LLVM::LLVMPointerType::get(type
.getContext());
272 /// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
273 /// function arguments. Returns an empty container if none of these attributes
274 /// are found in any of the arguments.
276 filterByValRefArgAttrs(FunctionOpInterface funcOp
,
277 SmallVectorImpl
<std::optional
<NamedAttribute
>> &result
) {
278 assert(result
.empty() && "Unexpected non-empty output");
279 result
.resize(funcOp
.getNumArguments(), std::nullopt
);
280 bool foundByValByRefAttrs
= false;
281 for (int argIdx
: llvm::seq(funcOp
.getNumArguments())) {
282 for (NamedAttribute namedAttr
: funcOp
.getArgAttrs(argIdx
)) {
283 if ((namedAttr
.getName() == LLVM::LLVMDialect::getByValAttrName() ||
284 namedAttr
.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
285 foundByValByRefAttrs
= true;
286 result
[argIdx
] = namedAttr
;
292 if (!foundByValByRefAttrs
)
296 // Function types are converted to LLVM Function types by recursively converting
297 // argument and result types. If MLIR Function has zero results, the LLVM
298 // Function has one VoidType result. If MLIR Function has more than one result,
299 // they are into an LLVM StructType in their order of appearance.
300 // If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
301 // `llvm.byref` function arguments which are not LLVM pointers are overridden
302 // with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
303 // converted to LLVM pointer types are removed from 'byValRefNonPtrAttrs`.
304 Type
LLVMTypeConverter::convertFunctionSignatureImpl(
305 FunctionType funcTy
, bool isVariadic
, bool useBarePtrCallConv
,
306 LLVMTypeConverter::SignatureConversion
&result
,
307 SmallVectorImpl
<std::optional
<NamedAttribute
>> *byValRefNonPtrAttrs
) const {
308 // Select the argument converter depending on the calling convention.
309 useBarePtrCallConv
= useBarePtrCallConv
|| options
.useBarePtrCallConv
;
310 auto funcArgConverter
= useBarePtrCallConv
? barePtrFuncArgTypeConverter
311 : structFuncArgTypeConverter
;
312 // Convert argument types one by one and check for errors.
313 for (auto [idx
, type
] : llvm::enumerate(funcTy
.getInputs())) {
314 SmallVector
<Type
, 8> converted
;
315 if (failed(funcArgConverter(*this, type
, converted
)))
318 // Rewrite converted type of `llvm.byval` or `llvm.byref` function
319 // argument that was not converted to an LLVM pointer types.
320 if (byValRefNonPtrAttrs
!= nullptr && !byValRefNonPtrAttrs
->empty() &&
321 converted
.size() == 1 && (*byValRefNonPtrAttrs
)[idx
].has_value()) {
322 // If the argument was already converted to an LLVM pointer type, we stop
323 // tracking it as it doesn't need more processing.
324 if (isa
<LLVM::LLVMPointerType
>(converted
[0]))
325 (*byValRefNonPtrAttrs
)[idx
] = std::nullopt
;
327 converted
[0] = LLVM::LLVMPointerType::get(&getContext());
330 result
.addInputs(idx
, converted
);
333 // If function does not return anything, create the void result type,
334 // if it returns on element, convert it, otherwise pack the result types into
337 funcTy
.getNumResults() == 0
338 ? LLVM::LLVMVoidType::get(&getContext())
339 : packFunctionResults(funcTy
.getResults(), useBarePtrCallConv
);
342 return LLVM::LLVMFunctionType::get(resultType
, result
.getConvertedTypes(),
346 Type
LLVMTypeConverter::convertFunctionSignature(
347 FunctionType funcTy
, bool isVariadic
, bool useBarePtrCallConv
,
348 LLVMTypeConverter::SignatureConversion
&result
) const {
349 return convertFunctionSignatureImpl(funcTy
, isVariadic
, useBarePtrCallConv
,
351 /*byValRefNonPtrAttrs=*/nullptr);
354 Type
LLVMTypeConverter::convertFunctionSignature(
355 FunctionOpInterface funcOp
, bool isVariadic
, bool useBarePtrCallConv
,
356 LLVMTypeConverter::SignatureConversion
&result
,
357 SmallVectorImpl
<std::optional
<NamedAttribute
>> &byValRefNonPtrAttrs
) const {
358 // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
359 // that were not converted to LLVM pointer types will be returned for further
361 filterByValRefArgAttrs(funcOp
, byValRefNonPtrAttrs
);
362 auto funcTy
= cast
<FunctionType
>(funcOp
.getFunctionType());
363 return convertFunctionSignatureImpl(funcTy
, isVariadic
, useBarePtrCallConv
,
364 result
, &byValRefNonPtrAttrs
);
367 /// Converts the function type to a C-compatible format, in particular using
368 /// pointers to memref descriptors for arguments.
369 std::pair
<LLVM::LLVMFunctionType
, LLVM::LLVMStructType
>
370 LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type
) const {
371 SmallVector
<Type
, 4> inputs
;
373 Type resultType
= type
.getNumResults() == 0
374 ? LLVM::LLVMVoidType::get(&getContext())
375 : packFunctionResults(type
.getResults());
379 auto ptrType
= LLVM::LLVMPointerType::get(type
.getContext());
380 auto structType
= dyn_cast
<LLVM::LLVMStructType
>(resultType
);
382 // Struct types cannot be safely returned via C interface. Make this a
383 // pointer argument, instead.
384 inputs
.push_back(ptrType
);
385 resultType
= LLVM::LLVMVoidType::get(&getContext());
388 for (Type t
: type
.getInputs()) {
389 auto converted
= convertType(t
);
390 if (!converted
|| !LLVM::isCompatibleType(converted
))
392 if (isa
<MemRefType
, UnrankedMemRefType
>(t
))
394 inputs
.push_back(converted
);
397 return {LLVM::LLVMFunctionType::get(resultType
, inputs
), structType
};
400 /// Convert a memref type into a list of LLVM IR types that will form the
401 /// memref descriptor. The result contains the following types:
402 /// 1. The pointer to the allocated data buffer, followed by
403 /// 2. The pointer to the aligned data buffer, followed by
404 /// 3. A lowered `index`-type integer containing the distance between the
405 /// beginning of the buffer and the first element to be accessed through the
406 /// view, followed by
407 /// 4. An array containing as many `index`-type integers as the rank of the
408 /// MemRef: the array represents the size, in number of elements, of the memref
409 /// along the given dimension. For constant MemRef dimensions, the
410 /// corresponding size entry is a constant whose runtime value must match the
411 /// static value, followed by
412 /// 5. A second array containing as many `index`-type integers as the rank of
413 /// the MemRef: the second array represents the "stride" (in tensor abstraction
414 /// sense), i.e. the number of consecutive elements of the underlying buffer.
415 /// TODO: add assertions for the static cases.
417 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
418 /// are expanded into individual index-type elements.
420 /// template <typename Elem, typename Index, size_t Rank>
422 /// Elem *allocatedPtr;
423 /// Elem *alignedPtr;
425 /// Index sizes[Rank]; // omitted when rank == 0
426 /// Index strides[Rank]; // omitted when rank == 0
429 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type
,
430 bool unpackAggregates
) const {
431 if (!isStrided(type
)) {
433 UnknownLoc::get(type
.getContext()),
434 "conversion to strided form failed either due to non-strided layout "
435 "maps (which should have been normalized away) or other reasons");
439 Type elementType
= convertType(type
.getElementType());
443 FailureOr
<unsigned> addressSpace
= getMemRefAddressSpace(type
);
444 if (failed(addressSpace
)) {
445 emitError(UnknownLoc::get(type
.getContext()),
446 "conversion of memref memory space ")
447 << type
.getMemorySpace()
448 << " to integer address space "
449 "failed. Consider adding memory space conversions.";
452 auto ptrTy
= LLVM::LLVMPointerType::get(type
.getContext(), *addressSpace
);
454 auto indexTy
= getIndexType();
456 SmallVector
<Type
, 5> results
= {ptrTy
, ptrTy
, indexTy
};
457 auto rank
= type
.getRank();
461 if (unpackAggregates
)
462 results
.insert(results
.end(), 2 * rank
, indexTy
);
464 results
.insert(results
.end(), 2, LLVM::LLVMArrayType::get(indexTy
, rank
));
469 LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type
,
470 const DataLayout
&layout
) const {
471 // Compute the descriptor size given that of its components indicated above.
472 unsigned space
= *getMemRefAddressSpace(type
);
473 return 2 * llvm::divideCeil(getPointerBitwidth(space
), 8) +
474 (1 + 2 * type
.getRank()) * layout
.getTypeSize(getIndexType());
477 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
478 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
479 Type
LLVMTypeConverter::convertMemRefType(MemRefType type
) const {
480 // When converting a MemRefType to a struct with descriptor fields, do not
481 // unpack the `sizes` and `strides` arrays.
482 SmallVector
<Type
, 5> types
=
483 getMemRefDescriptorFields(type
, /*unpackAggregates=*/false);
486 return LLVM::LLVMStructType::getLiteral(&getContext(), types
);
489 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
490 /// that will form the unranked memref descriptor. In particular, the fields
491 /// for an unranked memref descriptor are:
492 /// 1. index-typed rank, the dynamic rank of this MemRef
493 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
494 /// stack allocated (alloca) copy of a MemRef descriptor that got casted to
497 LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
498 return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())};
501 unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
502 UnrankedMemRefType type
, const DataLayout
&layout
) const {
503 // Compute the descriptor size given that of its components indicated above.
504 unsigned space
= *getMemRefAddressSpace(type
);
505 return layout
.getTypeSize(getIndexType()) +
506 llvm::divideCeil(getPointerBitwidth(space
), 8);
509 Type
LLVMTypeConverter::convertUnrankedMemRefType(
510 UnrankedMemRefType type
) const {
511 if (!convertType(type
.getElementType()))
513 return LLVM::LLVMStructType::getLiteral(&getContext(),
514 getUnrankedMemRefDescriptorFields());
518 LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type
) const {
519 if (!type
.getMemorySpace()) // Default memory space -> 0.
521 std::optional
<Attribute
> converted
=
522 convertTypeAttribute(type
, type
.getMemorySpace());
525 if (!(*converted
)) // Conversion to default is 0.
527 if (auto explicitSpace
= dyn_cast_if_present
<IntegerAttr
>(*converted
)) {
528 if (explicitSpace
.getType().isIndex() ||
529 explicitSpace
.getType().isSignlessInteger())
530 return explicitSpace
.getInt();
535 // Check if a memref type can be converted to a bare pointer.
536 bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type
) {
537 if (isa
<UnrankedMemRefType
>(type
))
538 // Unranked memref is not supported in the bare pointer calling convention.
541 // Check that the memref has static shape, strides and offset. Otherwise, it
542 // cannot be lowered to a bare pointer.
543 auto memrefTy
= cast
<MemRefType
>(type
);
544 if (!memrefTy
.hasStaticShape())
548 SmallVector
<int64_t, 4> strides
;
549 if (failed(getStridesAndOffset(memrefTy
, strides
, offset
)))
552 for (int64_t stride
: strides
)
553 if (ShapedType::isDynamic(stride
))
556 return !ShapedType::isDynamic(offset
);
559 /// Convert a memref type to a bare pointer to the memref element type.
560 Type
LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type
) const {
561 if (!canConvertToBarePtr(type
))
563 Type elementType
= convertType(type
.getElementType());
566 FailureOr
<unsigned> addressSpace
= getMemRefAddressSpace(type
);
567 if (failed(addressSpace
))
569 return LLVM::LLVMPointerType::get(type
.getContext(), *addressSpace
);
572 /// Convert an n-D vector type to an LLVM vector type:
573 /// * 0-D `vector<T>` are converted to vector<1xT>
574 /// * 1-D `vector<axT>` remains as is while,
575 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
576 /// `!llvm.array<ax...array<jxvector<kxT>>>`.
577 /// As LLVM supports arrays of scalable vectors, this method will also convert
578 /// n-D scalable vectors provided that only the trailing dim is scalable.
579 FailureOr
<Type
> LLVMTypeConverter::convertVectorType(VectorType type
) const {
580 auto elementType
= convertType(type
.getElementType());
583 if (type
.getShape().empty())
584 return VectorType::get({1}, elementType
);
585 Type vectorType
= VectorType::get(type
.getShape().back(), elementType
,
586 type
.getScalableDims().back());
587 assert(LLVM::isCompatibleVectorType(vectorType
) &&
588 "expected vector type compatible with the LLVM dialect");
589 // For n-D vector types for which a _non-trailing_ dim is scalable,
590 // return a failure. Supporting such cases would require LLVM
591 // to support something akin "scalable arrays" of vectors.
592 if (llvm::is_contained(type
.getScalableDims().drop_back(), true))
594 auto shape
= type
.getShape();
595 for (int i
= shape
.size() - 2; i
>= 0; --i
)
596 vectorType
= LLVM::LLVMArrayType::get(vectorType
, shape
[i
]);
600 /// Convert a type in the context of the default or bare pointer calling
601 /// convention. Calling convention sensitive types, such as MemRefType and
602 /// UnrankedMemRefType, are converted following the specific rules for the
603 /// calling convention. Calling convention independent types are converted
604 /// following the default LLVM type conversions.
605 Type
LLVMTypeConverter::convertCallingConventionType(
606 Type type
, bool useBarePtrCallConv
) const {
607 if (useBarePtrCallConv
)
608 if (auto memrefTy
= dyn_cast
<BaseMemRefType
>(type
))
609 return convertMemRefToBarePtr(memrefTy
);
611 return convertType(type
);
614 /// Promote the bare pointers in 'values' that resulted from memrefs to
615 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
616 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
617 void LLVMTypeConverter::promoteBarePtrsToDescriptors(
618 ConversionPatternRewriter
&rewriter
, Location loc
, ArrayRef
<Type
> stdTypes
,
619 SmallVectorImpl
<Value
> &values
) const {
620 assert(stdTypes
.size() == values
.size() &&
621 "The number of types and values doesn't match");
622 for (unsigned i
= 0, end
= values
.size(); i
< end
; ++i
)
623 if (auto memrefTy
= dyn_cast
<MemRefType
>(stdTypes
[i
]))
624 values
[i
] = MemRefDescriptor::fromStaticShape(rewriter
, loc
, *this,
625 memrefTy
, values
[i
]);
628 /// Convert a non-empty list of types of values produced by an operation into an
629 /// LLVM-compatible type. In particular, if more than one value is
630 /// produced, create a literal structure with elements that correspond to each
631 /// of the types converted with `convertType`.
632 Type
LLVMTypeConverter::packOperationResults(TypeRange types
) const {
633 assert(!types
.empty() && "expected non-empty list of type");
634 if (types
.size() == 1)
635 return convertType(types
[0]);
637 SmallVector
<Type
> resultTypes
;
638 resultTypes
.reserve(types
.size());
639 for (Type type
: types
) {
640 Type converted
= convertType(type
);
641 if (!converted
|| !LLVM::isCompatibleType(converted
))
643 resultTypes
.push_back(converted
);
646 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes
);
649 /// Convert a non-empty list of types to be returned from a function into an
650 /// LLVM-compatible type. In particular, if more than one value is returned,
651 /// create an LLVM dialect structure type with elements that correspond to each
652 /// of the types converted with `convertCallingConventionType`.
653 Type
LLVMTypeConverter::packFunctionResults(TypeRange types
,
654 bool useBarePtrCallConv
) const {
655 assert(!types
.empty() && "expected non-empty list of type");
657 useBarePtrCallConv
|= options
.useBarePtrCallConv
;
658 if (types
.size() == 1)
659 return convertCallingConventionType(types
.front(), useBarePtrCallConv
);
661 SmallVector
<Type
> resultTypes
;
662 resultTypes
.reserve(types
.size());
663 for (auto t
: types
) {
664 auto converted
= convertCallingConventionType(t
, useBarePtrCallConv
);
665 if (!converted
|| !LLVM::isCompatibleType(converted
))
667 resultTypes
.push_back(converted
);
670 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes
);
673 Value
LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc
, Value operand
,
674 OpBuilder
&builder
) const {
675 // Alloca with proper alignment. We do not expect optimizations of this
676 // alloca op and so we omit allocating at the entry block.
677 auto ptrType
= LLVM::LLVMPointerType::get(builder
.getContext());
678 Value one
= builder
.create
<LLVM::ConstantOp
>(loc
, builder
.getI64Type(),
679 builder
.getIndexAttr(1));
681 builder
.create
<LLVM::AllocaOp
>(loc
, ptrType
, operand
.getType(), one
);
682 // Store into the alloca'ed descriptor.
683 builder
.create
<LLVM::StoreOp
>(loc
, operand
, allocated
);
687 SmallVector
<Value
, 4>
688 LLVMTypeConverter::promoteOperands(Location loc
, ValueRange opOperands
,
689 ValueRange operands
, OpBuilder
&builder
,
690 bool useBarePtrCallConv
) const {
691 SmallVector
<Value
, 4> promotedOperands
;
692 promotedOperands
.reserve(operands
.size());
693 useBarePtrCallConv
|= options
.useBarePtrCallConv
;
694 for (auto it
: llvm::zip(opOperands
, operands
)) {
695 auto operand
= std::get
<0>(it
);
696 auto llvmOperand
= std::get
<1>(it
);
698 if (useBarePtrCallConv
) {
699 // For the bare-ptr calling convention, we only have to extract the
700 // aligned pointer of a memref.
701 if (dyn_cast
<MemRefType
>(operand
.getType())) {
702 MemRefDescriptor
desc(llvmOperand
);
703 llvmOperand
= desc
.alignedPtr(builder
, loc
);
704 } else if (isa
<UnrankedMemRefType
>(operand
.getType())) {
705 llvm_unreachable("Unranked memrefs are not supported");
708 if (isa
<UnrankedMemRefType
>(operand
.getType())) {
709 UnrankedMemRefDescriptor::unpack(builder
, loc
, llvmOperand
,
713 if (auto memrefType
= dyn_cast
<MemRefType
>(operand
.getType())) {
714 MemRefDescriptor::unpack(builder
, loc
, llvmOperand
, memrefType
,
720 promotedOperands
.push_back(llvmOperand
);
722 return promotedOperands
;
725 /// Callback to convert function argument types. It converts a MemRef function
726 /// argument to a list of non-aggregate types containing descriptor
727 /// information, and an UnrankedmemRef function argument to a list containing
728 /// the rank and a pointer to a descriptor struct.
730 mlir::structFuncArgTypeConverter(const LLVMTypeConverter
&converter
, Type type
,
731 SmallVectorImpl
<Type
> &result
) {
732 if (auto memref
= dyn_cast
<MemRefType
>(type
)) {
733 // In signatures, Memref descriptors are expanded into lists of
734 // non-aggregate values.
736 converter
.getMemRefDescriptorFields(memref
, /*unpackAggregates=*/true);
737 if (converted
.empty())
739 result
.append(converted
.begin(), converted
.end());
742 if (isa
<UnrankedMemRefType
>(type
)) {
743 auto converted
= converter
.getUnrankedMemRefDescriptorFields();
744 if (converted
.empty())
746 result
.append(converted
.begin(), converted
.end());
749 auto converted
= converter
.convertType(type
);
752 result
.push_back(converted
);
/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare (opaque) LLVM pointers in the memref's address space;
/// other types follow the default LLVM type conversions.
759 mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter
&converter
, Type type
,
760 SmallVectorImpl
<Type
> &result
) {
761 auto llvmTy
= converter
.convertCallingConventionType(
762 type
, /*useBarePointerCallConv=*/true);
766 result
.push_back(llvmTy
);