[mlir][scf]: Add value bound between scf for loop yield and result (#123200)
[llvm-project.git] / mlir / lib / Target / LLVMIR / TypeFromLLVM.cpp
blobea990ca7aefbe030173c4898c5fb84945e5dbdd6
1 //===- TypeFromLLVM.cpp - type translation from LLVM to MLIR IR -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Target/LLVMIR/TypeFromLLVM.h"
10 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
11 #include "mlir/IR/BuiltinTypes.h"
12 #include "mlir/IR/MLIRContext.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 #include "llvm/IR/DataLayout.h"
16 #include "llvm/IR/DerivedTypes.h"
17 #include "llvm/IR/Type.h"
19 using namespace mlir;
21 namespace mlir {
22 namespace LLVM {
23 namespace detail {
24 /// Support for translating LLVM IR types to MLIR LLVM dialect types.
25 class TypeFromLLVMIRTranslatorImpl {
26 public:
27 /// Constructs a class creating types in the given MLIR context.
28 TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
30 /// Translates the given type.
31 Type translateType(llvm::Type *type) {
32 if (knownTranslations.count(type))
33 return knownTranslations.lookup(type);
35 Type translated =
36 llvm::TypeSwitch<llvm::Type *, Type>(type)
37 .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
38 llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
39 llvm::ScalableVectorType, llvm::TargetExtType>(
40 [this](auto *type) { return this->translate(type); })
41 .Default([this](llvm::Type *type) {
42 return translatePrimitiveType(type);
43 });
44 knownTranslations.try_emplace(type, translated);
45 return translated;
48 private:
49 /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
50 /// type.
51 Type translatePrimitiveType(llvm::Type *type) {
52 if (type->isVoidTy())
53 return LLVM::LLVMVoidType::get(&context);
54 if (type->isHalfTy())
55 return Float16Type::get(&context);
56 if (type->isBFloatTy())
57 return BFloat16Type::get(&context);
58 if (type->isFloatTy())
59 return Float32Type::get(&context);
60 if (type->isDoubleTy())
61 return Float64Type::get(&context);
62 if (type->isFP128Ty())
63 return Float128Type::get(&context);
64 if (type->isX86_FP80Ty())
65 return Float80Type::get(&context);
66 if (type->isX86_AMXTy())
67 return LLVM::LLVMX86AMXType::get(&context);
68 if (type->isPPC_FP128Ty())
69 return LLVM::LLVMPPCFP128Type::get(&context);
70 if (type->isLabelTy())
71 return LLVM::LLVMLabelType::get(&context);
72 if (type->isMetadataTy())
73 return LLVM::LLVMMetadataType::get(&context);
74 if (type->isTokenTy())
75 return LLVM::LLVMTokenType::get(&context);
76 llvm_unreachable("not a primitive type");
79 /// Translates the given array type.
80 Type translate(llvm::ArrayType *type) {
81 return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
82 type->getNumElements());
85 /// Translates the given function type.
86 Type translate(llvm::FunctionType *type) {
87 SmallVector<Type, 8> paramTypes;
88 translateTypes(type->params(), paramTypes);
89 return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
90 paramTypes, type->isVarArg());
93 /// Translates the given integer type.
94 Type translate(llvm::IntegerType *type) {
95 return IntegerType::get(&context, type->getBitWidth());
98 /// Translates the given pointer type.
99 Type translate(llvm::PointerType *type) {
100 return LLVM::LLVMPointerType::get(&context, type->getAddressSpace());
103 /// Translates the given structure type.
104 Type translate(llvm::StructType *type) {
105 SmallVector<Type, 8> subtypes;
106 if (type->isLiteral()) {
107 translateTypes(type->subtypes(), subtypes);
108 return LLVM::LLVMStructType::getLiteral(&context, subtypes,
109 type->isPacked());
112 if (type->isOpaque())
113 return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
115 // With opaque pointers, types in LLVM can't be recursive anymore. Note that
116 // using getIdentified is not possible, as type names in LLVM are not
117 // guaranteed to be unique.
118 translateTypes(type->subtypes(), subtypes);
119 LLVM::LLVMStructType translated = LLVM::LLVMStructType::getNewIdentified(
120 &context, type->getName(), subtypes, type->isPacked());
121 knownTranslations.try_emplace(type, translated);
122 return translated;
125 /// Translates the given fixed-vector type.
126 Type translate(llvm::FixedVectorType *type) {
127 return LLVM::getFixedVectorType(translateType(type->getElementType()),
128 type->getNumElements());
131 /// Translates the given scalable-vector type.
132 Type translate(llvm::ScalableVectorType *type) {
133 return LLVM::LLVMScalableVectorType::get(
134 translateType(type->getElementType()), type->getMinNumElements());
137 /// Translates the given target extension type.
138 Type translate(llvm::TargetExtType *type) {
139 SmallVector<Type> typeParams;
140 translateTypes(type->type_params(), typeParams);
142 return LLVM::LLVMTargetExtType::get(&context, type->getName(), typeParams,
143 type->int_params());
146 /// Translates a list of types.
147 void translateTypes(ArrayRef<llvm::Type *> types,
148 SmallVectorImpl<Type> &result) {
149 result.reserve(result.size() + types.size());
150 for (llvm::Type *type : types)
151 result.push_back(translateType(type));
154 /// Map of known translations. Serves as a cache and as recursion stopper for
155 /// translating recursive structs.
156 llvm::DenseMap<llvm::Type *, Type> knownTranslations;
158 /// The context in which MLIR types are created.
159 MLIRContext &context;
162 } // namespace detail
163 } // namespace LLVM
164 } // namespace mlir
166 LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
167 : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {}
169 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() = default;
171 Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
172 return impl->translateType(type);