1 //===- TypeFromLLVM.cpp - type translation from LLVM to MLIR IR -===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Target/LLVMIR/TypeFromLLVM.h"
10 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/MLIRContext.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 #include "llvm/IR/DataLayout.h"
16 #include "llvm/IR/DerivedTypes.h"
17 #include "llvm/IR/Type.h"
24 /// Support for translating LLVM IR types to MLIR LLVM dialect types.
25 class TypeFromLLVMIRTranslatorImpl
{
27 /// Constructs a class creating types in the given MLIR context.
28 TypeFromLLVMIRTranslatorImpl(MLIRContext
&context
) : context(context
) {}
30 /// Translates the given type.
31 Type
translateType(llvm::Type
*type
) {
32 if (knownTranslations
.count(type
))
33 return knownTranslations
.lookup(type
);
36 llvm::TypeSwitch
<llvm::Type
*, Type
>(type
)
37 .Case
<llvm::ArrayType
, llvm::FunctionType
, llvm::IntegerType
,
38 llvm::PointerType
, llvm::StructType
, llvm::FixedVectorType
,
39 llvm::ScalableVectorType
, llvm::TargetExtType
>(
40 [this](auto *type
) { return this->translate(type
); })
41 .Default([this](llvm::Type
*type
) {
42 return translatePrimitiveType(type
);
44 knownTranslations
.try_emplace(type
, translated
);
49 /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
51 Type
translatePrimitiveType(llvm::Type
*type
) {
53 return LLVM::LLVMVoidType::get(&context
);
55 return Float16Type::get(&context
);
56 if (type
->isBFloatTy())
57 return BFloat16Type::get(&context
);
58 if (type
->isFloatTy())
59 return Float32Type::get(&context
);
60 if (type
->isDoubleTy())
61 return Float64Type::get(&context
);
62 if (type
->isFP128Ty())
63 return Float128Type::get(&context
);
64 if (type
->isX86_FP80Ty())
65 return Float80Type::get(&context
);
66 if (type
->isX86_AMXTy())
67 return LLVM::LLVMX86AMXType::get(&context
);
68 if (type
->isPPC_FP128Ty())
69 return LLVM::LLVMPPCFP128Type::get(&context
);
70 if (type
->isLabelTy())
71 return LLVM::LLVMLabelType::get(&context
);
72 if (type
->isMetadataTy())
73 return LLVM::LLVMMetadataType::get(&context
);
74 if (type
->isTokenTy())
75 return LLVM::LLVMTokenType::get(&context
);
76 llvm_unreachable("not a primitive type");
79 /// Translates the given array type.
80 Type
translate(llvm::ArrayType
*type
) {
81 return LLVM::LLVMArrayType::get(translateType(type
->getElementType()),
82 type
->getNumElements());
85 /// Translates the given function type.
86 Type
translate(llvm::FunctionType
*type
) {
87 SmallVector
<Type
, 8> paramTypes
;
88 translateTypes(type
->params(), paramTypes
);
89 return LLVM::LLVMFunctionType::get(translateType(type
->getReturnType()),
90 paramTypes
, type
->isVarArg());
93 /// Translates the given integer type.
94 Type
translate(llvm::IntegerType
*type
) {
95 return IntegerType::get(&context
, type
->getBitWidth());
98 /// Translates the given pointer type.
99 Type
translate(llvm::PointerType
*type
) {
100 return LLVM::LLVMPointerType::get(&context
, type
->getAddressSpace());
103 /// Translates the given structure type.
104 Type
translate(llvm::StructType
*type
) {
105 SmallVector
<Type
, 8> subtypes
;
106 if (type
->isLiteral()) {
107 translateTypes(type
->subtypes(), subtypes
);
108 return LLVM::LLVMStructType::getLiteral(&context
, subtypes
,
112 if (type
->isOpaque())
113 return LLVM::LLVMStructType::getOpaque(type
->getName(), &context
);
115 // With opaque pointers, types in LLVM can't be recursive anymore. Note that
116 // using getIdentified is not possible, as type names in LLVM are not
117 // guaranteed to be unique.
118 translateTypes(type
->subtypes(), subtypes
);
119 LLVM::LLVMStructType translated
= LLVM::LLVMStructType::getNewIdentified(
120 &context
, type
->getName(), subtypes
, type
->isPacked());
121 knownTranslations
.try_emplace(type
, translated
);
125 /// Translates the given fixed-vector type.
126 Type
translate(llvm::FixedVectorType
*type
) {
127 return LLVM::getFixedVectorType(translateType(type
->getElementType()),
128 type
->getNumElements());
131 /// Translates the given scalable-vector type.
132 Type
translate(llvm::ScalableVectorType
*type
) {
133 return LLVM::LLVMScalableVectorType::get(
134 translateType(type
->getElementType()), type
->getMinNumElements());
137 /// Translates the given target extension type.
138 Type
translate(llvm::TargetExtType
*type
) {
139 SmallVector
<Type
> typeParams
;
140 translateTypes(type
->type_params(), typeParams
);
142 return LLVM::LLVMTargetExtType::get(&context
, type
->getName(), typeParams
,
146 /// Translates a list of types.
147 void translateTypes(ArrayRef
<llvm::Type
*> types
,
148 SmallVectorImpl
<Type
> &result
) {
149 result
.reserve(result
.size() + types
.size());
150 for (llvm::Type
*type
: types
)
151 result
.push_back(translateType(type
));
154 /// Map of known translations. Serves as a cache and as recursion stopper for
155 /// translating recursive structs.
156 llvm::DenseMap
<llvm::Type
*, Type
> knownTranslations
;
158 /// The context in which MLIR types are created.
159 MLIRContext
&context
;
162 } // namespace detail
166 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext
&context
)
167 : impl(new detail::TypeFromLLVMIRTranslatorImpl(context
)) {}
169 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() = default;
171 Type
LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type
*type
) {
172 return impl
->translateType(type
);