[analyzer] Avoid out-of-order node traversal on void return (#117863)
[llvm-project.git] / mlir / lib / Conversion / LLVMCommon / TypeConverter.cpp
blob59b0f5c9b09bcdd7d097b22757b0204a2647744e
1 //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
10 #include "MemRefDescriptor.h"
11 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
12 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14 #include "llvm/ADT/ScopeExit.h"
15 #include "llvm/Support/Threading.h"
16 #include <memory>
17 #include <mutex>
18 #include <optional>
20 using namespace mlir;
22 SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
24 // Most of the time, the entry already exists in the map.
25 std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
26 std::defer_lock);
27 if (getContext().isMultithreadingEnabled())
28 lock.lock();
29 auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
30 if (recursiveStack != conversionCallStack.end())
31 return *recursiveStack->second;
34 // First time this thread gets here, we have to get an exclusive access to
35 // inset in the map
36 std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
37 auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
38 llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
39 return *recursiveStackInserted.first->second;
42 /// Create an LLVMTypeConverter using default LowerToLLVMOptions.
43 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
44 const DataLayoutAnalysis *analysis)
45 : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
47 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
48 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
49 const LowerToLLVMOptions &options,
50 const DataLayoutAnalysis *analysis)
51 : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
52 dataLayoutAnalysis(analysis) {
53 assert(llvmDialect && "LLVM IR dialect is not registered");
55 // Register conversions for the builtin types.
56 addConversion([&](ComplexType type) { return convertComplexType(type); });
57 addConversion([&](FloatType type) { return convertFloatType(type); });
58 addConversion([&](FunctionType type) { return convertFunctionType(type); });
59 addConversion([&](IndexType type) { return convertIndexType(type); });
60 addConversion([&](IntegerType type) { return convertIntegerType(type); });
61 addConversion([&](MemRefType type) { return convertMemRefType(type); });
62 addConversion(
63 [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
64 addConversion([&](VectorType type) -> std::optional<Type> {
65 FailureOr<Type> llvmType = convertVectorType(type);
66 if (failed(llvmType))
67 return std::nullopt;
68 return llvmType;
69 });
71 // LLVM-compatible types are legal, so add a pass-through conversion. Do this
72 // before the conversions below since conversions are attempted in reverse
73 // order and those should take priority.
74 addConversion([](Type type) {
75 return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
76 : std::nullopt;
77 });
79 addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
80 -> std::optional<LogicalResult> {
81 // Fastpath for types that won't be converted by this callback anyway.
82 if (LLVM::isCompatibleType(type)) {
83 results.push_back(type);
84 return success();
87 if (type.isIdentified()) {
88 auto convertedType = LLVM::LLVMStructType::getIdentified(
89 type.getContext(), ("_Converted." + type.getName()).str());
91 SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack();
92 if (llvm::count(recursiveStack, type)) {
93 results.push_back(convertedType);
94 return success();
96 recursiveStack.push_back(type);
97 auto popConversionCallStack = llvm::make_scope_exit(
98 [&recursiveStack]() { recursiveStack.pop_back(); });
100 SmallVector<Type> convertedElemTypes;
101 convertedElemTypes.reserve(type.getBody().size());
102 if (failed(convertTypes(type.getBody(), convertedElemTypes)))
103 return std::nullopt;
105 // If the converted type has not been initialized yet, just set its body
106 // to be the converted arguments and return.
107 if (!convertedType.isInitialized()) {
108 if (failed(
109 convertedType.setBody(convertedElemTypes, type.isPacked()))) {
110 return failure();
112 results.push_back(convertedType);
113 return success();
116 // If it has been initialized, has the same body and packed bit, just use
117 // it. This ensures that recursive structs keep being recursive rather
118 // than including a non-updated name.
119 if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
120 convertedType.isPacked() == type.isPacked()) {
121 results.push_back(convertedType);
122 return success();
125 return failure();
128 SmallVector<Type> convertedSubtypes;
129 convertedSubtypes.reserve(type.getBody().size());
130 if (failed(convertTypes(type.getBody(), convertedSubtypes)))
131 return std::nullopt;
133 results.push_back(LLVM::LLVMStructType::getLiteral(
134 type.getContext(), convertedSubtypes, type.isPacked()));
135 return success();
137 addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
138 if (auto element = convertType(type.getElementType()))
139 return LLVM::LLVMArrayType::get(element, type.getNumElements());
140 return std::nullopt;
142 addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
143 Type convertedResType = convertType(type.getReturnType());
144 if (!convertedResType)
145 return std::nullopt;
147 SmallVector<Type> convertedArgTypes;
148 convertedArgTypes.reserve(type.getNumParams());
149 if (failed(convertTypes(type.getParams(), convertedArgTypes)))
150 return std::nullopt;
152 return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
153 type.isVarArg());
156 // Helper function that checks if the given value range is a bare pointer.
157 auto isBarePointer = [](ValueRange values) {
158 return values.size() == 1 &&
159 isa<LLVM::LLVMPointerType>(values.front().getType());
162 // Argument materializations convert from the new block argument types
163 // (multiple SSA values that make up a memref descriptor) back to the
164 // original block argument type. The dialect conversion framework will then
165 // insert a target materialization from the original block argument type to
166 // a legal type.
167 addArgumentMaterialization([&](OpBuilder &builder,
168 UnrankedMemRefType resultType,
169 ValueRange inputs, Location loc) {
170 // Note: Bare pointers are not supported for unranked memrefs because a
171 // memref descriptor cannot be built just from a bare pointer.
172 if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
173 return Value();
174 Value desc =
175 UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
176 // An argument materialization must return a value of type
177 // `resultType`, so insert a cast from the memref descriptor type
178 // (!llvm.struct) to the original memref type.
179 return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
180 .getResult(0);
182 addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
183 ValueRange inputs, Location loc) {
184 Value desc;
185 if (isBarePointer(inputs)) {
186 desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
187 inputs[0]);
188 } else if (TypeRange(inputs) ==
189 getMemRefDescriptorFields(resultType,
190 /*unpackAggregates=*/true)) {
191 desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
192 } else {
193 // The inputs are neither a bare pointer nor an unpacked memref
194 // descriptor. This materialization function cannot be used.
195 return Value();
197 // An argument materialization must return a value of type `resultType`,
198 // so insert a cast from the memref descriptor type (!llvm.struct) to the
199 // original memref type.
200 return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
201 .getResult(0);
203 // Add generic source and target materializations to handle cases where
204 // non-LLVM types persist after an LLVM conversion.
205 addSourceMaterialization([&](OpBuilder &builder, Type resultType,
206 ValueRange inputs, Location loc) {
207 if (inputs.size() != 1)
208 return Value();
210 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
211 .getResult(0);
213 addTargetMaterialization([&](OpBuilder &builder, Type resultType,
214 ValueRange inputs, Location loc) {
215 if (inputs.size() != 1)
216 return Value();
218 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
219 .getResult(0);
222 // Integer memory spaces map to themselves.
223 addTypeAttributeConversion(
224 [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
227 /// Returns the MLIR context.
228 MLIRContext &LLVMTypeConverter::getContext() const {
229 return *getDialect()->getContext();
232 Type LLVMTypeConverter::getIndexType() const {
233 return IntegerType::get(&getContext(), getIndexTypeBitwidth());
236 unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
237 return options.dataLayout.getPointerSizeInBits(addressSpace);
240 Type LLVMTypeConverter::convertIndexType(IndexType type) const {
241 return getIndexType();
244 Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
245 return IntegerType::get(&getContext(), type.getWidth());
248 Type LLVMTypeConverter::convertFloatType(FloatType type) const {
249 if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
250 type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
251 type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
252 type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
253 type.isFloat8E8M0FNU())
254 return IntegerType::get(&getContext(), type.getWidth());
255 return type;
258 // Convert a `ComplexType` to an LLVM type. The result is a complex number
259 // struct with entries for the
260 // 1. real part and for the
261 // 2. imaginary part.
262 Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
263 auto elementType = convertType(type.getElementType());
264 return LLVM::LLVMStructType::getLiteral(&getContext(),
265 {elementType, elementType});
268 // Except for signatures, MLIR function types are converted into LLVM
269 // pointer-to-function types.
270 Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
271 return LLVM::LLVMPointerType::get(type.getContext());
274 /// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
275 /// function arguments. Returns an empty container if none of these attributes
276 /// are found in any of the arguments.
277 static void
278 filterByValRefArgAttrs(FunctionOpInterface funcOp,
279 SmallVectorImpl<std::optional<NamedAttribute>> &result) {
280 assert(result.empty() && "Unexpected non-empty output");
281 result.resize(funcOp.getNumArguments(), std::nullopt);
282 bool foundByValByRefAttrs = false;
283 for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
284 for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
285 if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
286 namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
287 foundByValByRefAttrs = true;
288 result[argIdx] = namedAttr;
289 break;
294 if (!foundByValByRefAttrs)
295 result.clear();
298 // Function types are converted to LLVM Function types by recursively converting
299 // argument and result types. If MLIR Function has zero results, the LLVM
300 // Function has one VoidType result. If MLIR Function has more than one result,
301 // they are into an LLVM StructType in their order of appearance.
302 // If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
303 // `llvm.byref` function arguments which are not LLVM pointers are overridden
304 // with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
305 // converted to LLVM pointer types are removed from 'byValRefNonPtrAttrs`.
306 Type LLVMTypeConverter::convertFunctionSignatureImpl(
307 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
308 LLVMTypeConverter::SignatureConversion &result,
309 SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
310 // Select the argument converter depending on the calling convention.
311 useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
312 auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
313 : structFuncArgTypeConverter;
314 // Convert argument types one by one and check for errors.
315 for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
316 SmallVector<Type, 8> converted;
317 if (failed(funcArgConverter(*this, type, converted)))
318 return {};
320 // Rewrite converted type of `llvm.byval` or `llvm.byref` function
321 // argument that was not converted to an LLVM pointer types.
322 if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
323 converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
324 // If the argument was already converted to an LLVM pointer type, we stop
325 // tracking it as it doesn't need more processing.
326 if (isa<LLVM::LLVMPointerType>(converted[0]))
327 (*byValRefNonPtrAttrs)[idx] = std::nullopt;
328 else
329 converted[0] = LLVM::LLVMPointerType::get(&getContext());
332 result.addInputs(idx, converted);
335 // If function does not return anything, create the void result type,
336 // if it returns on element, convert it, otherwise pack the result types into
337 // a struct.
338 Type resultType =
339 funcTy.getNumResults() == 0
340 ? LLVM::LLVMVoidType::get(&getContext())
341 : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
342 if (!resultType)
343 return {};
344 return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
345 isVariadic);
348 Type LLVMTypeConverter::convertFunctionSignature(
349 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
350 LLVMTypeConverter::SignatureConversion &result) const {
351 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
352 result,
353 /*byValRefNonPtrAttrs=*/nullptr);
356 Type LLVMTypeConverter::convertFunctionSignature(
357 FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
358 LLVMTypeConverter::SignatureConversion &result,
359 SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
360 // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
361 // that were not converted to LLVM pointer types will be returned for further
362 // processing.
363 filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
364 auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
365 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
366 result, &byValRefNonPtrAttrs);
369 /// Converts the function type to a C-compatible format, in particular using
370 /// pointers to memref descriptors for arguments.
371 std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
372 LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
373 SmallVector<Type, 4> inputs;
375 Type resultType = type.getNumResults() == 0
376 ? LLVM::LLVMVoidType::get(&getContext())
377 : packFunctionResults(type.getResults());
378 if (!resultType)
379 return {};
381 auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
382 auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
383 if (structType) {
384 // Struct types cannot be safely returned via C interface. Make this a
385 // pointer argument, instead.
386 inputs.push_back(ptrType);
387 resultType = LLVM::LLVMVoidType::get(&getContext());
390 for (Type t : type.getInputs()) {
391 auto converted = convertType(t);
392 if (!converted || !LLVM::isCompatibleType(converted))
393 return {};
394 if (isa<MemRefType, UnrankedMemRefType>(t))
395 converted = ptrType;
396 inputs.push_back(converted);
399 return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
402 /// Convert a memref type into a list of LLVM IR types that will form the
403 /// memref descriptor. The result contains the following types:
404 /// 1. The pointer to the allocated data buffer, followed by
405 /// 2. The pointer to the aligned data buffer, followed by
406 /// 3. A lowered `index`-type integer containing the distance between the
407 /// beginning of the buffer and the first element to be accessed through the
408 /// view, followed by
409 /// 4. An array containing as many `index`-type integers as the rank of the
410 /// MemRef: the array represents the size, in number of elements, of the memref
411 /// along the given dimension. For constant MemRef dimensions, the
412 /// corresponding size entry is a constant whose runtime value must match the
413 /// static value, followed by
414 /// 5. A second array containing as many `index`-type integers as the rank of
415 /// the MemRef: the second array represents the "stride" (in tensor abstraction
416 /// sense), i.e. the number of consecutive elements of the underlying buffer.
417 /// TODO: add assertions for the static cases.
419 /// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
420 /// are expanded into individual index-type elements.
422 /// template <typename Elem, typename Index, size_t Rank>
423 /// struct {
424 /// Elem *allocatedPtr;
425 /// Elem *alignedPtr;
426 /// Index offset;
427 /// Index sizes[Rank]; // omitted when rank == 0
428 /// Index strides[Rank]; // omitted when rank == 0
429 /// };
430 SmallVector<Type, 5>
431 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
432 bool unpackAggregates) const {
433 if (!isStrided(type)) {
434 emitError(
435 UnknownLoc::get(type.getContext()),
436 "conversion to strided form failed either due to non-strided layout "
437 "maps (which should have been normalized away) or other reasons");
438 return {};
441 Type elementType = convertType(type.getElementType());
442 if (!elementType)
443 return {};
445 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
446 if (failed(addressSpace)) {
447 emitError(UnknownLoc::get(type.getContext()),
448 "conversion of memref memory space ")
449 << type.getMemorySpace()
450 << " to integer address space "
451 "failed. Consider adding memory space conversions.";
452 return {};
454 auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
456 auto indexTy = getIndexType();
458 SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
459 auto rank = type.getRank();
460 if (rank == 0)
461 return results;
463 if (unpackAggregates)
464 results.insert(results.end(), 2 * rank, indexTy);
465 else
466 results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
467 return results;
470 unsigned
471 LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
472 const DataLayout &layout) const {
473 // Compute the descriptor size given that of its components indicated above.
474 unsigned space = *getMemRefAddressSpace(type);
475 return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
476 (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
479 /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
480 /// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
481 Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
482 // When converting a MemRefType to a struct with descriptor fields, do not
483 // unpack the `sizes` and `strides` arrays.
484 SmallVector<Type, 5> types =
485 getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
486 if (types.empty())
487 return {};
488 return LLVM::LLVMStructType::getLiteral(&getContext(), types);
491 /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
492 /// that will form the unranked memref descriptor. In particular, the fields
493 /// for an unranked memref descriptor are:
494 /// 1. index-typed rank, the dynamic rank of this MemRef
495 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
496 /// stack allocated (alloca) copy of a MemRef descriptor that got casted to
497 /// be unranked.
498 SmallVector<Type, 2>
499 LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
500 return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())};
503 unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
504 UnrankedMemRefType type, const DataLayout &layout) const {
505 // Compute the descriptor size given that of its components indicated above.
506 unsigned space = *getMemRefAddressSpace(type);
507 return layout.getTypeSize(getIndexType()) +
508 llvm::divideCeil(getPointerBitwidth(space), 8);
511 Type LLVMTypeConverter::convertUnrankedMemRefType(
512 UnrankedMemRefType type) const {
513 if (!convertType(type.getElementType()))
514 return {};
515 return LLVM::LLVMStructType::getLiteral(&getContext(),
516 getUnrankedMemRefDescriptorFields());
519 FailureOr<unsigned>
520 LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
521 if (!type.getMemorySpace()) // Default memory space -> 0.
522 return 0;
523 std::optional<Attribute> converted =
524 convertTypeAttribute(type, type.getMemorySpace());
525 if (!converted)
526 return failure();
527 if (!(*converted)) // Conversion to default is 0.
528 return 0;
529 if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
530 if (explicitSpace.getType().isIndex() ||
531 explicitSpace.getType().isSignlessInteger())
532 return explicitSpace.getInt();
534 return failure();
537 // Check if a memref type can be converted to a bare pointer.
538 bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
539 if (isa<UnrankedMemRefType>(type))
540 // Unranked memref is not supported in the bare pointer calling convention.
541 return false;
543 // Check that the memref has static shape, strides and offset. Otherwise, it
544 // cannot be lowered to a bare pointer.
545 auto memrefTy = cast<MemRefType>(type);
546 if (!memrefTy.hasStaticShape())
547 return false;
549 int64_t offset = 0;
550 SmallVector<int64_t, 4> strides;
551 if (failed(getStridesAndOffset(memrefTy, strides, offset)))
552 return false;
554 for (int64_t stride : strides)
555 if (ShapedType::isDynamic(stride))
556 return false;
558 return !ShapedType::isDynamic(offset);
561 /// Convert a memref type to a bare pointer to the memref element type.
562 Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
563 if (!canConvertToBarePtr(type))
564 return {};
565 Type elementType = convertType(type.getElementType());
566 if (!elementType)
567 return {};
568 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
569 if (failed(addressSpace))
570 return {};
571 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
574 /// Convert an n-D vector type to an LLVM vector type:
575 /// * 0-D `vector<T>` are converted to vector<1xT>
576 /// * 1-D `vector<axT>` remains as is while,
577 /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
578 /// `!llvm.array<ax...array<jxvector<kxT>>>`.
579 /// As LLVM supports arrays of scalable vectors, this method will also convert
580 /// n-D scalable vectors provided that only the trailing dim is scalable.
581 FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
582 auto elementType = convertType(type.getElementType());
583 if (!elementType)
584 return {};
585 if (type.getShape().empty())
586 return VectorType::get({1}, elementType);
587 Type vectorType = VectorType::get(type.getShape().back(), elementType,
588 type.getScalableDims().back());
589 assert(LLVM::isCompatibleVectorType(vectorType) &&
590 "expected vector type compatible with the LLVM dialect");
591 // For n-D vector types for which a _non-trailing_ dim is scalable,
592 // return a failure. Supporting such cases would require LLVM
593 // to support something akin "scalable arrays" of vectors.
594 if (llvm::is_contained(type.getScalableDims().drop_back(), true))
595 return failure();
596 auto shape = type.getShape();
597 for (int i = shape.size() - 2; i >= 0; --i)
598 vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
599 return vectorType;
602 /// Convert a type in the context of the default or bare pointer calling
603 /// convention. Calling convention sensitive types, such as MemRefType and
604 /// UnrankedMemRefType, are converted following the specific rules for the
605 /// calling convention. Calling convention independent types are converted
606 /// following the default LLVM type conversions.
607 Type LLVMTypeConverter::convertCallingConventionType(
608 Type type, bool useBarePtrCallConv) const {
609 if (useBarePtrCallConv)
610 if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
611 return convertMemRefToBarePtr(memrefTy);
613 return convertType(type);
616 /// Promote the bare pointers in 'values' that resulted from memrefs to
617 /// descriptors. 'stdTypes' holds they types of 'values' before the conversion
618 /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
619 void LLVMTypeConverter::promoteBarePtrsToDescriptors(
620 ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
621 SmallVectorImpl<Value> &values) const {
622 assert(stdTypes.size() == values.size() &&
623 "The number of types and values doesn't match");
624 for (unsigned i = 0, end = values.size(); i < end; ++i)
625 if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
626 values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
627 memrefTy, values[i]);
630 /// Convert a non-empty list of types of values produced by an operation into an
631 /// LLVM-compatible type. In particular, if more than one value is
632 /// produced, create a literal structure with elements that correspond to each
633 /// of the types converted with `convertType`.
634 Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
635 assert(!types.empty() && "expected non-empty list of type");
636 if (types.size() == 1)
637 return convertType(types[0]);
639 SmallVector<Type> resultTypes;
640 resultTypes.reserve(types.size());
641 for (Type type : types) {
642 Type converted = convertType(type);
643 if (!converted || !LLVM::isCompatibleType(converted))
644 return {};
645 resultTypes.push_back(converted);
648 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
651 /// Convert a non-empty list of types to be returned from a function into an
652 /// LLVM-compatible type. In particular, if more than one value is returned,
653 /// create an LLVM dialect structure type with elements that correspond to each
654 /// of the types converted with `convertCallingConventionType`.
655 Type LLVMTypeConverter::packFunctionResults(TypeRange types,
656 bool useBarePtrCallConv) const {
657 assert(!types.empty() && "expected non-empty list of type");
659 useBarePtrCallConv |= options.useBarePtrCallConv;
660 if (types.size() == 1)
661 return convertCallingConventionType(types.front(), useBarePtrCallConv);
663 SmallVector<Type> resultTypes;
664 resultTypes.reserve(types.size());
665 for (auto t : types) {
666 auto converted = convertCallingConventionType(t, useBarePtrCallConv);
667 if (!converted || !LLVM::isCompatibleType(converted))
668 return {};
669 resultTypes.push_back(converted);
672 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
675 Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
676 OpBuilder &builder) const {
677 // Alloca with proper alignment. We do not expect optimizations of this
678 // alloca op and so we omit allocating at the entry block.
679 auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
680 Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
681 builder.getIndexAttr(1));
682 Value allocated =
683 builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
684 // Store into the alloca'ed descriptor.
685 builder.create<LLVM::StoreOp>(loc, operand, allocated);
686 return allocated;
689 SmallVector<Value, 4>
690 LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
691 ValueRange operands, OpBuilder &builder,
692 bool useBarePtrCallConv) const {
693 SmallVector<Value, 4> promotedOperands;
694 promotedOperands.reserve(operands.size());
695 useBarePtrCallConv |= options.useBarePtrCallConv;
696 for (auto it : llvm::zip(opOperands, operands)) {
697 auto operand = std::get<0>(it);
698 auto llvmOperand = std::get<1>(it);
700 if (useBarePtrCallConv) {
701 // For the bare-ptr calling convention, we only have to extract the
702 // aligned pointer of a memref.
703 if (dyn_cast<MemRefType>(operand.getType())) {
704 MemRefDescriptor desc(llvmOperand);
705 llvmOperand = desc.alignedPtr(builder, loc);
706 } else if (isa<UnrankedMemRefType>(operand.getType())) {
707 llvm_unreachable("Unranked memrefs are not supported");
709 } else {
710 if (isa<UnrankedMemRefType>(operand.getType())) {
711 UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
712 promotedOperands);
713 continue;
715 if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
716 MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
717 promotedOperands);
718 continue;
722 promotedOperands.push_back(llvmOperand);
724 return promotedOperands;
727 /// Callback to convert function argument types. It converts a MemRef function
728 /// argument to a list of non-aggregate types containing descriptor
729 /// information, and an UnrankedmemRef function argument to a list containing
730 /// the rank and a pointer to a descriptor struct.
731 LogicalResult
732 mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
733 SmallVectorImpl<Type> &result) {
734 if (auto memref = dyn_cast<MemRefType>(type)) {
735 // In signatures, Memref descriptors are expanded into lists of
736 // non-aggregate values.
737 auto converted =
738 converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
739 if (converted.empty())
740 return failure();
741 result.append(converted.begin(), converted.end());
742 return success();
744 if (isa<UnrankedMemRefType>(type)) {
745 auto converted = converter.getUnrankedMemRefDescriptorFields();
746 if (converted.empty())
747 return failure();
748 result.append(converted.begin(), converted.end());
749 return success();
751 auto converted = converter.convertType(type);
752 if (!converted)
753 return failure();
754 result.push_back(converted);
755 return success();
758 /// Callback to convert function argument types. It converts MemRef function
759 /// arguments to bare pointers to the MemRef element type.
760 LogicalResult
761 mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
762 SmallVectorImpl<Type> &result) {
763 auto llvmTy = converter.convertCallingConventionType(
764 type, /*useBarePointerCallConv=*/true);
765 if (!llvmTy)
766 return failure();
768 result.push_back(llvmTy);
769 return success();