//===- ModuleImport.cpp - LLVM to MLIR conversion ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the import of an LLVM IR module into an LLVM dialect
// module.
//
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/ModuleImport.h"
#include "mlir/Target/LLVMIR/Import.h"

#include "AttrKindDetail.h"
#include "DataLayoutImporter.h"
#include "DebugImporter.h"
#include "LoopAnnotationImporter.h"

#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Tools/mlir-translate/Translation.h"

#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Comdat.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/ModRef.h"

using namespace mlir;
using namespace mlir::LLVM;
using namespace mlir::LLVM::detail;

#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
50 // Utility to print an LLVM value as a string for passing to emitError().
51 // FIXME: Diagnostic should be able to natively handle types that have
52 // operator << (raw_ostream&) defined.
53 static std::string
diag(const llvm::Value
&value
) {
55 llvm::raw_string_ostream
os(str
);
60 // Utility to print an LLVM metadata node as a string for passing
61 // to emitError(). The module argument is needed to print the nodes
62 // canonically numbered.
63 static std::string
diagMD(const llvm::Metadata
*node
,
64 const llvm::Module
*module
) {
66 llvm::raw_string_ostream
os(str
);
67 node
->print(os
, module
, /*IsForDebug=*/true);
71 /// Returns the name of the global_ctors global variables.
72 static constexpr StringRef
getGlobalCtorsVarName() {
73 return "llvm.global_ctors";
76 /// Returns the name of the global_dtors global variables.
77 static constexpr StringRef
getGlobalDtorsVarName() {
78 return "llvm.global_dtors";
81 /// Returns the symbol name for the module-level comdat operation. It must not
82 /// conflict with the user namespace.
83 static constexpr StringRef
getGlobalComdatOpName() {
84 return "__llvm_global_comdat";
87 /// Converts the sync scope identifier of `inst` to the string representation
88 /// necessary to build an atomic LLVM dialect operation. Returns the empty
89 /// string if the operation has either no sync scope or the default system-level
90 /// sync scope attached. The atomic operations only set their sync scope
91 /// attribute if they have a non-default sync scope attached.
92 static StringRef
getLLVMSyncScope(llvm::Instruction
*inst
) {
93 std::optional
<llvm::SyncScope::ID
> syncScopeID
=
94 llvm::getAtomicSyncScopeID(inst
);
98 // Search the sync scope name for the given identifier. The default
99 // system-level sync scope thereby maps to the empty string.
100 SmallVector
<StringRef
> syncScopeName
;
101 llvm::LLVMContext
&llvmContext
= inst
->getContext();
102 llvmContext
.getSyncScopeNames(syncScopeName
);
103 auto *it
= llvm::find_if(syncScopeName
, [&](StringRef name
) {
104 return *syncScopeID
== llvmContext
.getOrInsertSyncScopeID(name
);
106 if (it
!= syncScopeName
.end())
108 llvm_unreachable("incorrect sync scope identifier");
111 /// Converts an array of unsigned indices to a signed integer position array.
112 static SmallVector
<int64_t> getPositionFromIndices(ArrayRef
<unsigned> indices
) {
113 SmallVector
<int64_t> position
;
114 llvm::append_range(position
, indices
);
118 /// Converts the LLVM instructions that have a generated MLIR builder. Using a
119 /// static implementation method called from the module import ensures the
120 /// builders have to use the `moduleImport` argument and cannot directly call
121 /// import methods. As a result, both the intrinsic and the instruction MLIR
122 /// builders have to use the `moduleImport` argument and none of them has direct
123 /// access to the private module import methods.
124 static LogicalResult
convertInstructionImpl(OpBuilder
&odsBuilder
,
125 llvm::Instruction
*inst
,
126 ModuleImport
&moduleImport
,
127 LLVMImportInterface
&iface
) {
128 // Copy the operands to an LLVM operands array reference for conversion.
129 SmallVector
<llvm::Value
*> operands(inst
->operands());
130 ArrayRef
<llvm::Value
*> llvmOperands(operands
);
132 // Convert all instructions that provide an MLIR builder.
133 if (iface
.isConvertibleInstruction(inst
->getOpcode()))
134 return iface
.convertInstruction(odsBuilder
, inst
, llvmOperands
,
136 // TODO: Implement the `convertInstruction` hooks in the
137 // `LLVMDialectLLVMIRImportInterface` and move the following include there.
138 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
142 /// Get a topologically sorted list of blocks for the given basic blocks.
143 static SetVector
<llvm::BasicBlock
*>
144 getTopologicallySortedBlocks(ArrayRef
<llvm::BasicBlock
*> basicBlocks
) {
145 SetVector
<llvm::BasicBlock
*> blocks
;
146 for (llvm::BasicBlock
*basicBlock
: basicBlocks
) {
147 if (!blocks
.contains(basicBlock
)) {
148 llvm::ReversePostOrderTraversal
<llvm::BasicBlock
*> traversal(basicBlock
);
149 blocks
.insert(traversal
.begin(), traversal
.end());
152 assert(blocks
.size() == basicBlocks
.size() && "some blocks are not sorted");
156 ModuleImport::ModuleImport(ModuleOp mlirModule
,
157 std::unique_ptr
<llvm::Module
> llvmModule
,
158 bool emitExpensiveWarnings
,
159 bool importEmptyDICompositeTypes
)
160 : builder(mlirModule
->getContext()), context(mlirModule
->getContext()),
161 mlirModule(mlirModule
), llvmModule(std::move(llvmModule
)),
162 iface(mlirModule
->getContext()),
163 typeTranslator(*mlirModule
->getContext()),
164 debugImporter(std::make_unique
<DebugImporter
>(
165 mlirModule
, importEmptyDICompositeTypes
)),
166 loopAnnotationImporter(
167 std::make_unique
<LoopAnnotationImporter
>(*this, builder
)),
168 emitExpensiveWarnings(emitExpensiveWarnings
) {
169 builder
.setInsertionPointToStart(mlirModule
.getBody());
172 ComdatOp
ModuleImport::getGlobalComdatOp() {
174 return globalComdatOp
;
176 OpBuilder::InsertionGuard
guard(builder
);
177 builder
.setInsertionPointToEnd(mlirModule
.getBody());
179 builder
.create
<ComdatOp
>(mlirModule
.getLoc(), getGlobalComdatOpName());
180 globalInsertionOp
= globalComdatOp
;
181 return globalComdatOp
;
184 LogicalResult
ModuleImport::processTBAAMetadata(const llvm::MDNode
*node
) {
185 Location loc
= mlirModule
.getLoc();
187 // If `node` is a valid TBAA root node, then return its optional identity
188 // string, otherwise return failure.
189 auto getIdentityIfRootNode
=
190 [&](const llvm::MDNode
*node
) -> FailureOr
<std::optional
<StringRef
>> {
192 // !0 = !{!"Simple C/C++ TBAA"}
194 if (node
->getNumOperands() > 1)
196 // If the operand is MDString, then assume that this is a root node.
197 if (node
->getNumOperands() == 1)
198 if (const auto *op0
= dyn_cast
<const llvm::MDString
>(node
->getOperand(0)))
199 return std::optional
<StringRef
>{op0
->getString()};
200 return std::optional
<StringRef
>{};
203 // If `node` looks like a TBAA type descriptor metadata,
204 // then return true, if it is a valid node, and false otherwise.
205 // If it does not look like a TBAA type descriptor metadata, then
206 // return std::nullopt.
207 // If `identity` and `memberTypes/Offsets` are non-null, then they will
208 // contain the converted metadata operands for a valid TBAA node (i.e. when
209 // true is returned).
210 auto isTypeDescriptorNode
= [&](const llvm::MDNode
*node
,
211 StringRef
*identity
= nullptr,
212 SmallVectorImpl
<TBAAMemberAttr
> *members
=
213 nullptr) -> std::optional
<bool> {
214 unsigned numOperands
= node
->getNumOperands();
215 // Type descriptor, e.g.:
216 // !1 = !{!"int", !0, /*optional*/i64 0} /* scalar int type */
217 // !2 = !{!"agg_t", !1, i64 0} /* struct agg_t { int x; } */
221 // TODO: support "new" format (D41501) for type descriptors,
222 // where the first operand is an MDNode.
223 const auto *identityNode
=
224 dyn_cast
<const llvm::MDString
>(node
->getOperand(0));
228 // This should be a type descriptor node.
230 *identity
= identityNode
->getString();
232 for (unsigned pairNum
= 0, e
= numOperands
/ 2; pairNum
< e
; ++pairNum
) {
233 const auto *memberNode
=
234 dyn_cast
<const llvm::MDNode
>(node
->getOperand(2 * pairNum
+ 1));
236 emitError(loc
) << "operand '" << 2 * pairNum
+ 1 << "' must be MDNode: "
237 << diagMD(node
, llvmModule
.get());
241 if (2 * pairNum
+ 2 >= numOperands
) {
242 // Allow for optional 0 offset in 2-operand nodes.
243 if (numOperands
!= 2) {
244 emitError(loc
) << "missing member offset: "
245 << diagMD(node
, llvmModule
.get());
249 auto *offsetCI
= llvm::mdconst::dyn_extract
<llvm::ConstantInt
>(
250 node
->getOperand(2 * pairNum
+ 2));
252 emitError(loc
) << "operand '" << 2 * pairNum
+ 2
253 << "' must be ConstantInt: "
254 << diagMD(node
, llvmModule
.get());
257 offset
= offsetCI
->getZExtValue();
261 members
->push_back(TBAAMemberAttr::get(
262 cast
<TBAANodeAttr
>(tbaaMapping
.lookup(memberNode
)), offset
));
268 // If `node` looks like a TBAA access tag metadata,
269 // then return true, if it is a valid node, and false otherwise.
270 // If it does not look like a TBAA access tag metadata, then
271 // return std::nullopt.
272 // If the other arguments are non-null, then they will contain
273 // the converted metadata operands for a valid TBAA node (i.e. when true is
275 auto isTagNode
= [&](const llvm::MDNode
*node
,
276 TBAATypeDescriptorAttr
*baseAttr
= nullptr,
277 TBAATypeDescriptorAttr
*accessAttr
= nullptr,
278 int64_t *offset
= nullptr,
279 bool *isConstant
= nullptr) -> std::optional
<bool> {
281 // !3 = !{!1, !1, i64 0} /* scalar int access */
282 // !4 = !{!2, !1, i64 0} /* agg_t::x access */
284 // Optional 4th argument is ConstantInt 0/1 identifying whether
285 // the location being accessed is "constant" (see for details:
286 // https://llvm.org/docs/LangRef.html#representation).
287 unsigned numOperands
= node
->getNumOperands();
288 if (numOperands
!= 3 && numOperands
!= 4)
290 const auto *baseMD
= dyn_cast
<const llvm::MDNode
>(node
->getOperand(0));
291 const auto *accessMD
= dyn_cast
<const llvm::MDNode
>(node
->getOperand(1));
293 llvm::mdconst::dyn_extract
<llvm::ConstantInt
>(node
->getOperand(2));
294 if (!baseMD
|| !accessMD
|| !offsetCI
)
296 // TODO: support "new" TBAA format, if needed (see D41501).
297 // In the "old" format the first operand of the access type
298 // metadata is MDString. We have to distinguish the formats,
299 // because access tags have the same structure, but different
300 // meaning for the operands.
301 if (accessMD
->getNumOperands() < 1 ||
302 !isa
<llvm::MDString
>(accessMD
->getOperand(0)))
304 bool isConst
= false;
305 if (numOperands
== 4) {
307 llvm::mdconst::dyn_extract
<llvm::ConstantInt
>(node
->getOperand(3));
309 emitError(loc
) << "operand '3' must be ConstantInt: "
310 << diagMD(node
, llvmModule
.get());
313 isConst
= isConstantCI
->getValue()[0];
316 *baseAttr
= cast
<TBAATypeDescriptorAttr
>(tbaaMapping
.lookup(baseMD
));
318 *accessAttr
= cast
<TBAATypeDescriptorAttr
>(tbaaMapping
.lookup(accessMD
));
320 *offset
= offsetCI
->getZExtValue();
322 *isConstant
= isConst
;
326 // Do a post-order walk over the TBAA Graph. Since a correct TBAA Graph is a
327 // DAG, a post-order walk guarantees that we convert any metadata node we
328 // depend on, prior to converting the current node.
329 DenseSet
<const llvm::MDNode
*> seen
;
330 SmallVector
<const llvm::MDNode
*> workList
;
331 workList
.push_back(node
);
332 while (!workList
.empty()) {
333 const llvm::MDNode
*current
= workList
.back();
334 if (tbaaMapping
.contains(current
)) {
335 // Already converted. Just pop from the worklist.
340 // If any child of this node is not yet converted, don't pop the current
341 // node from the worklist but push the not-yet-converted children in the
342 // front of the worklist.
343 bool anyChildNotConverted
= false;
344 for (const llvm::MDOperand
&operand
: current
->operands())
345 if (auto *childNode
= dyn_cast_or_null
<const llvm::MDNode
>(operand
.get()))
346 if (!tbaaMapping
.contains(childNode
)) {
347 workList
.push_back(childNode
);
348 anyChildNotConverted
= true;
351 if (anyChildNotConverted
) {
352 // If this is the second time we failed to convert an element in the
353 // worklist it must be because a child is dependent on it being converted
354 // and we have a cycle in the graph. Cycles are not allowed in TBAA
356 if (!seen
.insert(current
).second
)
357 return emitError(loc
) << "has cycle in TBAA graph: "
358 << diagMD(current
, llvmModule
.get());
363 // Otherwise simply import the current node.
366 FailureOr
<std::optional
<StringRef
>> rootNodeIdentity
=
367 getIdentityIfRootNode(current
);
368 if (succeeded(rootNodeIdentity
)) {
369 StringAttr stringAttr
= *rootNodeIdentity
370 ? builder
.getStringAttr(**rootNodeIdentity
)
372 // The root nodes do not have operands, so we can create
373 // the TBAARootAttr on the first walk.
374 tbaaMapping
.insert({current
, builder
.getAttr
<TBAARootAttr
>(stringAttr
)});
379 SmallVector
<TBAAMemberAttr
> members
;
380 if (std::optional
<bool> isValid
=
381 isTypeDescriptorNode(current
, &identity
, &members
)) {
382 assert(isValid
.value() && "type descriptor node must be valid");
384 tbaaMapping
.insert({current
, builder
.getAttr
<TBAATypeDescriptorAttr
>(
385 identity
, members
)});
389 TBAATypeDescriptorAttr baseAttr
, accessAttr
;
392 if (std::optional
<bool> isValid
=
393 isTagNode(current
, &baseAttr
, &accessAttr
, &offset
, &isConstant
)) {
394 assert(isValid
.value() && "access tag node must be valid");
396 {current
, builder
.getAttr
<TBAATagAttr
>(baseAttr
, accessAttr
, offset
,
401 return emitError(loc
) << "unsupported TBAA node format: "
402 << diagMD(current
, llvmModule
.get());
408 ModuleImport::processAccessGroupMetadata(const llvm::MDNode
*node
) {
409 Location loc
= mlirModule
.getLoc();
410 if (failed(loopAnnotationImporter
->translateAccessGroup(node
, loc
)))
411 return emitError(loc
) << "unsupported access group node: "
412 << diagMD(node
, llvmModule
.get());
417 ModuleImport::processAliasScopeMetadata(const llvm::MDNode
*node
) {
418 Location loc
= mlirModule
.getLoc();
419 // Helper that verifies the node has a self reference operand.
420 auto verifySelfRef
= [](const llvm::MDNode
*node
) {
421 return node
->getNumOperands() != 0 &&
422 node
== dyn_cast
<llvm::MDNode
>(node
->getOperand(0));
424 // Helper that verifies the given operand is a string or does not exist.
425 auto verifyDescription
= [](const llvm::MDNode
*node
, unsigned idx
) {
426 return idx
>= node
->getNumOperands() ||
427 isa
<llvm::MDString
>(node
->getOperand(idx
));
429 // Helper that creates an alias scope domain attribute.
430 auto createAliasScopeDomainOp
= [&](const llvm::MDNode
*aliasDomain
) {
431 StringAttr description
= nullptr;
432 if (aliasDomain
->getNumOperands() >= 2)
433 if (auto *operand
= dyn_cast
<llvm::MDString
>(aliasDomain
->getOperand(1)))
434 description
= builder
.getStringAttr(operand
->getString());
435 return builder
.getAttr
<AliasScopeDomainAttr
>(
436 DistinctAttr::create(builder
.getUnitAttr()), description
);
439 // Collect the alias scopes and domains to translate them.
440 for (const llvm::MDOperand
&operand
: node
->operands()) {
441 if (const auto *scope
= dyn_cast
<llvm::MDNode
>(operand
)) {
442 llvm::AliasScopeNode
aliasScope(scope
);
443 const llvm::MDNode
*domain
= aliasScope
.getDomain();
445 // Verify the scope node points to valid scope metadata which includes
446 // verifying its domain. Perform the verification before looking it up in
447 // the alias scope mapping since it could have been inserted as a domain
449 if (!verifySelfRef(scope
) || !domain
|| !verifyDescription(scope
, 2))
450 return emitError(loc
) << "unsupported alias scope node: "
451 << diagMD(scope
, llvmModule
.get());
452 if (!verifySelfRef(domain
) || !verifyDescription(domain
, 1))
453 return emitError(loc
) << "unsupported alias domain node: "
454 << diagMD(domain
, llvmModule
.get());
456 if (aliasScopeMapping
.contains(scope
))
459 // Convert the domain metadata node if it has not been translated before.
460 auto it
= aliasScopeMapping
.find(aliasScope
.getDomain());
461 if (it
== aliasScopeMapping
.end()) {
462 auto aliasScopeDomainOp
= createAliasScopeDomainOp(domain
);
463 it
= aliasScopeMapping
.try_emplace(domain
, aliasScopeDomainOp
).first
;
466 // Convert the scope metadata node if it has not been converted before.
467 StringAttr description
= nullptr;
468 if (!aliasScope
.getName().empty())
469 description
= builder
.getStringAttr(aliasScope
.getName());
470 auto aliasScopeOp
= builder
.getAttr
<AliasScopeAttr
>(
471 DistinctAttr::create(builder
.getUnitAttr()),
472 cast
<AliasScopeDomainAttr
>(it
->second
), description
);
473 aliasScopeMapping
.try_emplace(aliasScope
.getNode(), aliasScopeOp
);
479 FailureOr
<SmallVector
<AliasScopeAttr
>>
480 ModuleImport::lookupAliasScopeAttrs(const llvm::MDNode
*node
) const {
481 SmallVector
<AliasScopeAttr
> aliasScopes
;
482 aliasScopes
.reserve(node
->getNumOperands());
483 for (const llvm::MDOperand
&operand
: node
->operands()) {
484 auto *node
= cast
<llvm::MDNode
>(operand
.get());
485 aliasScopes
.push_back(
486 dyn_cast_or_null
<AliasScopeAttr
>(aliasScopeMapping
.lookup(node
)));
488 // Return failure if one of the alias scope lookups failed.
489 if (llvm::is_contained(aliasScopes
, nullptr))
494 void ModuleImport::addDebugIntrinsic(llvm::CallInst
*intrinsic
) {
495 debugIntrinsics
.insert(intrinsic
);
498 LogicalResult
ModuleImport::convertLinkerOptionsMetadata() {
499 for (const llvm::NamedMDNode
&named
: llvmModule
->named_metadata()) {
500 if (named
.getName() != "llvm.linker.options")
502 // llvm.linker.options operands are lists of strings.
503 for (const llvm::MDNode
*md
: named
.operands()) {
504 SmallVector
<StringRef
> options
;
505 options
.reserve(md
->getNumOperands());
506 for (const llvm::MDOperand
&option
: md
->operands())
507 options
.push_back(cast
<llvm::MDString
>(option
)->getString());
508 builder
.create
<LLVM::LinkerOptionsOp
>(mlirModule
.getLoc(),
509 builder
.getStrArrayAttr(options
));
515 LogicalResult
ModuleImport::convertMetadata() {
516 OpBuilder::InsertionGuard
guard(builder
);
517 builder
.setInsertionPointToEnd(mlirModule
.getBody());
518 for (const llvm::Function
&func
: llvmModule
->functions()) {
519 for (const llvm::Instruction
&inst
: llvm::instructions(func
)) {
520 // Convert access group metadata nodes.
521 if (llvm::MDNode
*node
=
522 inst
.getMetadata(llvm::LLVMContext::MD_access_group
))
523 if (failed(processAccessGroupMetadata(node
)))
526 // Convert alias analysis metadata nodes.
527 llvm::AAMDNodes aliasAnalysisNodes
= inst
.getAAMetadata();
528 if (!aliasAnalysisNodes
)
530 if (aliasAnalysisNodes
.TBAA
)
531 if (failed(processTBAAMetadata(aliasAnalysisNodes
.TBAA
)))
533 if (aliasAnalysisNodes
.Scope
)
534 if (failed(processAliasScopeMetadata(aliasAnalysisNodes
.Scope
)))
536 if (aliasAnalysisNodes
.NoAlias
)
537 if (failed(processAliasScopeMetadata(aliasAnalysisNodes
.NoAlias
)))
541 if (failed(convertLinkerOptionsMetadata()))
546 void ModuleImport::processComdat(const llvm::Comdat
*comdat
) {
547 if (comdatMapping
.contains(comdat
))
550 ComdatOp comdatOp
= getGlobalComdatOp();
551 OpBuilder::InsertionGuard
guard(builder
);
552 builder
.setInsertionPointToEnd(&comdatOp
.getBody().back());
553 auto selectorOp
= builder
.create
<ComdatSelectorOp
>(
554 mlirModule
.getLoc(), comdat
->getName(),
555 convertComdatFromLLVM(comdat
->getSelectionKind()));
557 SymbolRefAttr::get(builder
.getContext(), getGlobalComdatOpName(),
558 FlatSymbolRefAttr::get(selectorOp
.getSymNameAttr()));
559 comdatMapping
.try_emplace(comdat
, symbolRef
);
562 LogicalResult
ModuleImport::convertComdats() {
563 for (llvm::GlobalVariable
&globalVar
: llvmModule
->globals())
564 if (globalVar
.hasComdat())
565 processComdat(globalVar
.getComdat());
566 for (llvm::Function
&func
: llvmModule
->functions())
567 if (func
.hasComdat())
568 processComdat(func
.getComdat());
572 LogicalResult
ModuleImport::convertGlobals() {
573 for (llvm::GlobalVariable
&globalVar
: llvmModule
->globals()) {
574 if (globalVar
.getName() == getGlobalCtorsVarName() ||
575 globalVar
.getName() == getGlobalDtorsVarName()) {
576 if (failed(convertGlobalCtorsAndDtors(&globalVar
))) {
577 return emitError(UnknownLoc::get(context
))
578 << "unhandled global variable: " << diag(globalVar
);
582 if (failed(convertGlobal(&globalVar
))) {
583 return emitError(UnknownLoc::get(context
))
584 << "unhandled global variable: " << diag(globalVar
);
590 LogicalResult
ModuleImport::convertDataLayout() {
591 Location loc
= mlirModule
.getLoc();
592 DataLayoutImporter
dataLayoutImporter(context
, llvmModule
->getDataLayout());
593 if (!dataLayoutImporter
.getDataLayout())
594 return emitError(loc
, "cannot translate data layout: ")
595 << dataLayoutImporter
.getLastToken();
597 for (StringRef token
: dataLayoutImporter
.getUnhandledTokens())
598 emitWarning(loc
, "unhandled data layout token: ") << token
;
600 mlirModule
->setAttr(DLTIDialect::kDataLayoutAttrName
,
601 dataLayoutImporter
.getDataLayout());
605 LogicalResult
ModuleImport::convertFunctions() {
606 for (llvm::Function
&func
: llvmModule
->functions())
607 if (failed(processFunction(&func
)))
612 void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction
*inst
,
614 SmallVector
<std::pair
<unsigned, llvm::MDNode
*>> allMetadata
;
615 inst
->getAllMetadataOtherThanDebugLoc(allMetadata
);
616 for (auto &[kind
, node
] : allMetadata
) {
617 if (!iface
.isConvertibleMetadata(kind
))
619 if (failed(iface
.setMetadataAttrs(builder
, kind
, node
, op
, *this))) {
620 if (emitExpensiveWarnings
) {
621 Location loc
= debugImporter
->translateLoc(inst
->getDebugLoc());
622 emitWarning(loc
) << "unhandled metadata: "
623 << diagMD(node
, llvmModule
.get()) << " on "
630 void ModuleImport::setIntegerOverflowFlags(llvm::Instruction
*inst
,
631 Operation
*op
) const {
632 auto iface
= cast
<IntegerOverflowFlagsInterface
>(op
);
634 IntegerOverflowFlags value
= {};
635 value
= bitEnumSet(value
, IntegerOverflowFlags::nsw
, inst
->hasNoSignedWrap());
637 bitEnumSet(value
, IntegerOverflowFlags::nuw
, inst
->hasNoUnsignedWrap());
639 iface
.setOverflowFlags(value
);
642 void ModuleImport::setFastmathFlagsAttr(llvm::Instruction
*inst
,
643 Operation
*op
) const {
644 auto iface
= cast
<FastmathFlagsInterface
>(op
);
646 // Even if the imported operation implements the fastmath interface, the
647 // original instruction may not have fastmath flags set. Exit if an
648 // instruction, such as a non floating-point function call, does not have
650 if (!isa
<llvm::FPMathOperator
>(inst
))
652 llvm::FastMathFlags flags
= inst
->getFastMathFlags();
654 // Set the fastmath bits flag-by-flag.
655 FastmathFlags value
= {};
656 value
= bitEnumSet(value
, FastmathFlags::nnan
, flags
.noNaNs());
657 value
= bitEnumSet(value
, FastmathFlags::ninf
, flags
.noInfs());
658 value
= bitEnumSet(value
, FastmathFlags::nsz
, flags
.noSignedZeros());
659 value
= bitEnumSet(value
, FastmathFlags::arcp
, flags
.allowReciprocal());
660 value
= bitEnumSet(value
, FastmathFlags::contract
, flags
.allowContract());
661 value
= bitEnumSet(value
, FastmathFlags::afn
, flags
.approxFunc());
662 value
= bitEnumSet(value
, FastmathFlags::reassoc
, flags
.allowReassoc());
663 FastmathFlagsAttr attr
= FastmathFlagsAttr::get(builder
.getContext(), value
);
664 iface
->setAttr(iface
.getFastmathAttrName(), attr
);
667 /// Returns if `type` is a scalar integer or floating-point type.
668 static bool isScalarType(Type type
) {
669 return isa
<IntegerType
, FloatType
>(type
);
672 /// Returns `type` if it is a builtin integer or floating-point vector type that
673 /// can be used to create an attribute or nullptr otherwise. If provided,
674 /// `arrayShape` is added to the shape of the vector to create an attribute that
675 /// matches an array of vectors.
676 static Type
getVectorTypeForAttr(Type type
, ArrayRef
<int64_t> arrayShape
= {}) {
677 if (!LLVM::isCompatibleVectorType(type
))
680 llvm::ElementCount numElements
= LLVM::getVectorNumElements(type
);
681 if (numElements
.isScalable()) {
682 emitError(UnknownLoc::get(type
.getContext()))
683 << "scalable vectors not supported";
687 // An LLVM dialect vector can only contain scalars.
688 Type elementType
= LLVM::getVectorElementType(type
);
689 if (!isScalarType(elementType
))
692 SmallVector
<int64_t> shape(arrayShape
.begin(), arrayShape
.end());
693 shape
.push_back(numElements
.getKnownMinValue());
694 return VectorType::get(shape
, elementType
);
697 Type
ModuleImport::getBuiltinTypeForAttr(Type type
) {
701 // Return builtin integer and floating-point types as is.
702 if (isScalarType(type
))
705 // Return builtin vectors of integer and floating-point types as is.
706 if (Type vectorType
= getVectorTypeForAttr(type
))
709 // Multi-dimensional array types are converted to tensors or vectors,
710 // depending on the innermost type being a scalar or a vector.
711 SmallVector
<int64_t> arrayShape
;
712 while (auto arrayType
= dyn_cast
<LLVMArrayType
>(type
)) {
713 arrayShape
.push_back(arrayType
.getNumElements());
714 type
= arrayType
.getElementType();
716 if (isScalarType(type
))
717 return RankedTensorType::get(arrayShape
, type
);
718 return getVectorTypeForAttr(type
, arrayShape
);
721 /// Returns an integer or float attribute for the provided scalar constant
722 /// `constScalar` or nullptr if the conversion fails.
723 static TypedAttr
getScalarConstantAsAttr(OpBuilder
&builder
,
724 llvm::Constant
*constScalar
) {
725 MLIRContext
*context
= builder
.getContext();
727 // Convert scalar intergers.
728 if (auto *constInt
= dyn_cast
<llvm::ConstantInt
>(constScalar
)) {
729 return builder
.getIntegerAttr(
730 IntegerType::get(context
, constInt
->getBitWidth()),
731 constInt
->getValue());
734 // Convert scalar floats.
735 if (auto *constFloat
= dyn_cast
<llvm::ConstantFP
>(constScalar
)) {
736 llvm::Type
*type
= constFloat
->getType();
737 FloatType floatType
=
739 ? FloatType::getBF16(context
)
740 : LLVM::detail::getFloatType(context
, type
->getScalarSizeInBits());
742 emitError(UnknownLoc::get(builder
.getContext()))
743 << "unexpected floating-point type";
746 return builder
.getFloatAttr(floatType
, constFloat
->getValueAPF());
751 /// Returns an integer or float attribute array for the provided constant
752 /// sequence `constSequence` or nullptr if the conversion fails.
753 static SmallVector
<Attribute
>
754 getSequenceConstantAsAttrs(OpBuilder
&builder
,
755 llvm::ConstantDataSequential
*constSequence
) {
756 SmallVector
<Attribute
> elementAttrs
;
757 elementAttrs
.reserve(constSequence
->getNumElements());
758 for (auto idx
: llvm::seq
<int64_t>(0, constSequence
->getNumElements())) {
759 llvm::Constant
*constElement
= constSequence
->getElementAsConstant(idx
);
760 elementAttrs
.push_back(getScalarConstantAsAttr(builder
, constElement
));
765 Attribute
ModuleImport::getConstantAsAttr(llvm::Constant
*constant
) {
766 // Convert scalar constants.
767 if (Attribute scalarAttr
= getScalarConstantAsAttr(builder
, constant
))
770 // Convert function references.
771 if (auto *func
= dyn_cast
<llvm::Function
>(constant
))
772 return SymbolRefAttr::get(builder
.getContext(), func
->getName());
774 // Returns the static shape of the provided type if possible.
775 auto getConstantShape
= [&](llvm::Type
*type
) {
776 return llvm::dyn_cast_if_present
<ShapedType
>(
777 getBuiltinTypeForAttr(convertType(type
)));
780 // Convert one-dimensional constant arrays or vectors that store 1/2/4/8-byte
781 // integer or half/bfloat/float/double values.
782 if (auto *constArray
= dyn_cast
<llvm::ConstantDataSequential
>(constant
)) {
783 if (constArray
->isString())
784 return builder
.getStringAttr(constArray
->getAsString());
785 auto shape
= getConstantShape(constArray
->getType());
788 // Convert splat constants to splat elements attributes.
789 auto *constVector
= dyn_cast
<llvm::ConstantDataVector
>(constant
);
790 if (constVector
&& constVector
->isSplat()) {
791 // A vector is guaranteed to have at least size one.
792 Attribute splatAttr
= getScalarConstantAsAttr(
793 builder
, constVector
->getElementAsConstant(0));
794 return SplatElementsAttr::get(shape
, splatAttr
);
796 // Convert non-splat constants to dense elements attributes.
797 SmallVector
<Attribute
> elementAttrs
=
798 getSequenceConstantAsAttrs(builder
, constArray
);
799 return DenseElementsAttr::get(shape
, elementAttrs
);
802 // Convert multi-dimensional constant aggregates that store all kinds of
803 // integer and floating-point types.
804 if (auto *constAggregate
= dyn_cast
<llvm::ConstantAggregate
>(constant
)) {
805 auto shape
= getConstantShape(constAggregate
->getType());
808 // Collect the aggregate elements in depths first order.
809 SmallVector
<Attribute
> elementAttrs
;
810 SmallVector
<llvm::Constant
*> workList
= {constAggregate
};
811 while (!workList
.empty()) {
812 llvm::Constant
*current
= workList
.pop_back_val();
813 // Append any nested aggregates in reverse order to ensure the head
814 // element of the nested aggregates is at the back of the work list.
815 if (auto *constAggregate
= dyn_cast
<llvm::ConstantAggregate
>(current
)) {
817 reverse(llvm::seq
<int64_t>(0, constAggregate
->getNumOperands())))
818 workList
.push_back(constAggregate
->getAggregateElement(idx
));
821 // Append the elements of nested constant arrays or vectors that store
822 // 1/2/4/8-byte integer or half/bfloat/float/double values.
823 if (auto *constArray
= dyn_cast
<llvm::ConstantDataSequential
>(current
)) {
824 SmallVector
<Attribute
> attrs
=
825 getSequenceConstantAsAttrs(builder
, constArray
);
826 elementAttrs
.append(attrs
.begin(), attrs
.end());
829 // Append nested scalar constants that store all kinds of integer and
830 // floating-point types.
831 if (Attribute scalarAttr
= getScalarConstantAsAttr(builder
, current
)) {
832 elementAttrs
.push_back(scalarAttr
);
835 // Bail if the aggregate contains a unsupported constant type such as a
836 // constant expression.
839 return DenseElementsAttr::get(shape
, elementAttrs
);
842 // Convert zero aggregates.
843 if (auto *constZero
= dyn_cast
<llvm::ConstantAggregateZero
>(constant
)) {
844 auto shape
= llvm::dyn_cast_if_present
<ShapedType
>(
845 getBuiltinTypeForAttr(convertType(constZero
->getType())));
848 // Convert zero aggregates with a static shape to splat elements attributes.
849 Attribute splatAttr
= builder
.getZeroAttr(shape
.getElementType());
850 assert(splatAttr
&& "expected non-null zero attribute for scalar types");
851 return SplatElementsAttr::get(shape
, splatAttr
);
856 LogicalResult
ModuleImport::convertGlobal(llvm::GlobalVariable
*globalVar
) {
857 // Insert the global after the last one or at the start of the module.
858 OpBuilder::InsertionGuard
guard(builder
);
859 if (!globalInsertionOp
)
860 builder
.setInsertionPointToStart(mlirModule
.getBody());
862 builder
.setInsertionPointAfter(globalInsertionOp
);
865 if (globalVar
->hasInitializer())
866 valueAttr
= getConstantAsAttr(globalVar
->getInitializer());
867 Type type
= convertType(globalVar
->getValueType());
869 uint64_t alignment
= 0;
870 llvm::MaybeAlign maybeAlign
= globalVar
->getAlign();
871 if (maybeAlign
.has_value()) {
872 llvm::Align align
= *maybeAlign
;
873 alignment
= align
.value();
876 // Get the global expression associated with this global variable and convert
878 DIGlobalVariableExpressionAttr globalExpressionAttr
;
879 SmallVector
<llvm::DIGlobalVariableExpression
*> globalExpressions
;
880 globalVar
->getDebugInfo(globalExpressions
);
882 // There should only be a single global expression.
883 if (!globalExpressions
.empty())
884 globalExpressionAttr
=
885 debugImporter
->translateGlobalVariableExpression(globalExpressions
[0]);
887 GlobalOp globalOp
= builder
.create
<GlobalOp
>(
888 mlirModule
.getLoc(), type
, globalVar
->isConstant(),
889 convertLinkageFromLLVM(globalVar
->getLinkage()), globalVar
->getName(),
890 valueAttr
, alignment
, /*addr_space=*/globalVar
->getAddressSpace(),
891 /*dso_local=*/globalVar
->isDSOLocal(),
892 /*thread_local=*/globalVar
->isThreadLocal(), /*comdat=*/SymbolRefAttr(),
893 /*attrs=*/ArrayRef
<NamedAttribute
>(), /*dbgExpr=*/globalExpressionAttr
);
894 globalInsertionOp
= globalOp
;
896 if (globalVar
->hasInitializer() && !valueAttr
) {
898 Block
*block
= builder
.createBlock(&globalOp
.getInitializerRegion());
899 setConstantInsertionPointToStart(block
);
900 FailureOr
<Value
> initializer
=
901 convertConstantExpr(globalVar
->getInitializer());
902 if (failed(initializer
))
904 builder
.create
<ReturnOp
>(globalOp
.getLoc(), *initializer
);
906 if (globalVar
->hasAtLeastLocalUnnamedAddr()) {
907 globalOp
.setUnnamedAddr(
908 convertUnnamedAddrFromLLVM(globalVar
->getUnnamedAddr()));
910 if (globalVar
->hasSection())
911 globalOp
.setSection(globalVar
->getSection());
912 globalOp
.setVisibility_(
913 convertVisibilityFromLLVM(globalVar
->getVisibility()));
915 if (globalVar
->hasComdat())
916 globalOp
.setComdatAttr(comdatMapping
.lookup(globalVar
->getComdat()));
922 ModuleImport::convertGlobalCtorsAndDtors(llvm::GlobalVariable
*globalVar
) {
923 if (!globalVar
->hasInitializer() || !globalVar
->hasAppendingLinkage())
926 dyn_cast
<llvm::ConstantArray
>(globalVar
->getInitializer());
930 SmallVector
<Attribute
> funcs
;
931 SmallVector
<int32_t> priorities
;
932 for (llvm::Value
*operand
: initializer
->operands()) {
933 auto *aggregate
= dyn_cast
<llvm::ConstantAggregate
>(operand
);
934 if (!aggregate
|| aggregate
->getNumOperands() != 3)
937 auto *priority
= dyn_cast
<llvm::ConstantInt
>(aggregate
->getOperand(0));
938 auto *func
= dyn_cast
<llvm::Function
>(aggregate
->getOperand(1));
939 auto *data
= dyn_cast
<llvm::Constant
>(aggregate
->getOperand(2));
940 if (!priority
|| !func
|| !data
)
943 // GlobalCtorsOps and GlobalDtorsOps do not support non-null data fields.
944 if (!data
->isNullValue())
947 funcs
.push_back(FlatSymbolRefAttr::get(context
, func
->getName()));
948 priorities
.push_back(priority
->getValue().getZExtValue());
951 OpBuilder::InsertionGuard
guard(builder
);
952 if (!globalInsertionOp
)
953 builder
.setInsertionPointToStart(mlirModule
.getBody());
955 builder
.setInsertionPointAfter(globalInsertionOp
);
957 if (globalVar
->getName() == getGlobalCtorsVarName()) {
958 globalInsertionOp
= builder
.create
<LLVM::GlobalCtorsOp
>(
959 mlirModule
.getLoc(), builder
.getArrayAttr(funcs
),
960 builder
.getI32ArrayAttr(priorities
));
963 globalInsertionOp
= builder
.create
<LLVM::GlobalDtorsOp
>(
964 mlirModule
.getLoc(), builder
.getArrayAttr(funcs
),
965 builder
.getI32ArrayAttr(priorities
));
969 SetVector
<llvm::Constant
*>
970 ModuleImport::getConstantsToConvert(llvm::Constant
*constant
) {
971 // Return the empty set if the constant has been translated before.
972 if (valueMapping
.contains(constant
))
975 // Traverse the constants in post-order and stop the traversal if a constant
976 // already has a `valueMapping` from an earlier constant translation or if the
977 // constant is traversed a second time.
978 SetVector
<llvm::Constant
*> orderedSet
;
979 SetVector
<llvm::Constant
*> workList
;
980 DenseMap
<llvm::Constant
*, SmallVector
<llvm::Constant
*>> adjacencyLists
;
981 workList
.insert(constant
);
982 while (!workList
.empty()) {
983 llvm::Constant
*current
= workList
.back();
984 // Collect all dependencies of the current constant and add them to the
985 // adjacency list if none has been computed before.
986 auto adjacencyIt
= adjacencyLists
.find(current
);
987 if (adjacencyIt
== adjacencyLists
.end()) {
988 adjacencyIt
= adjacencyLists
.try_emplace(current
).first
;
989 // Add all constant operands to the adjacency list and skip any other
990 // values such as basic block addresses.
991 for (llvm::Value
*operand
: current
->operands())
992 if (auto *constDependency
= dyn_cast
<llvm::Constant
>(operand
))
993 adjacencyIt
->getSecond().push_back(constDependency
);
994 // Use the getElementValue method to add the dependencies of zero
995 // initialized aggregate constants since they do not take any operands.
996 if (auto *constAgg
= dyn_cast
<llvm::ConstantAggregateZero
>(current
)) {
997 unsigned numElements
= constAgg
->getElementCount().getFixedValue();
998 for (unsigned i
= 0, e
= numElements
; i
!= e
; ++i
)
999 adjacencyIt
->getSecond().push_back(constAgg
->getElementValue(i
));
1002 // Add the current constant to the `orderedSet` of the traversed nodes if
1003 // all its dependencies have been traversed before. Additionally, remove the
1004 // constant from the `workList` and continue the traversal.
1005 if (adjacencyIt
->getSecond().empty()) {
1006 orderedSet
.insert(current
);
1007 workList
.pop_back();
1010 // Add the next dependency from the adjacency list to the `workList` and
1011 // continue the traversal. Remove the dependency from the adjacency list to
1012 // mark that it has been processed. Only enqueue the dependency if it has no
1013 // `valueMapping` from an earlier translation and if it has not been
1015 llvm::Constant
*dependency
= adjacencyIt
->getSecond().pop_back_val();
1016 if (valueMapping
.contains(dependency
) || workList
.contains(dependency
) ||
1017 orderedSet
.contains(dependency
))
1019 workList
.insert(dependency
);
1025 FailureOr
<Value
> ModuleImport::convertConstant(llvm::Constant
*constant
) {
1026 Location loc
= UnknownLoc::get(context
);
1028 // Convert constants that can be represented as attributes.
1029 if (Attribute attr
= getConstantAsAttr(constant
)) {
1030 Type type
= convertType(constant
->getType());
1031 if (auto symbolRef
= dyn_cast
<FlatSymbolRefAttr
>(attr
)) {
1032 return builder
.create
<AddressOfOp
>(loc
, type
, symbolRef
.getValue())
1035 return builder
.create
<ConstantOp
>(loc
, type
, attr
).getResult();
1038 // Convert null pointer constants.
1039 if (auto *nullPtr
= dyn_cast
<llvm::ConstantPointerNull
>(constant
)) {
1040 Type type
= convertType(nullPtr
->getType());
1041 return builder
.create
<ZeroOp
>(loc
, type
).getResult();
1044 // Convert none token constants.
1045 if (isa
<llvm::ConstantTokenNone
>(constant
)) {
1046 return builder
.create
<NoneTokenOp
>(loc
).getResult();
1050 if (auto *poisonVal
= dyn_cast
<llvm::PoisonValue
>(constant
)) {
1051 Type type
= convertType(poisonVal
->getType());
1052 return builder
.create
<PoisonOp
>(loc
, type
).getResult();
1056 if (auto *undefVal
= dyn_cast
<llvm::UndefValue
>(constant
)) {
1057 Type type
= convertType(undefVal
->getType());
1058 return builder
.create
<UndefOp
>(loc
, type
).getResult();
1061 // Convert global variable accesses.
1062 if (auto *globalVar
= dyn_cast
<llvm::GlobalVariable
>(constant
)) {
1063 Type type
= convertType(globalVar
->getType());
1064 auto symbolRef
= FlatSymbolRefAttr::get(context
, globalVar
->getName());
1065 return builder
.create
<AddressOfOp
>(loc
, type
, symbolRef
).getResult();
1068 // Convert constant expressions.
1069 if (auto *constExpr
= dyn_cast
<llvm::ConstantExpr
>(constant
)) {
1070 // Convert the constant expression to a temporary LLVM instruction and
1071 // translate it using the `processInstruction` method. Delete the
1072 // instruction after the translation and remove it from `valueMapping`,
1073 // since later calls to `getAsInstruction` may return the same address
1074 // resulting in a conflicting `valueMapping` entry.
1075 llvm::Instruction
*inst
= constExpr
->getAsInstruction();
1076 auto guard
= llvm::make_scope_exit([&]() {
1077 assert(!noResultOpMapping
.contains(inst
) &&
1078 "expected constant expression to return a result");
1079 valueMapping
.erase(inst
);
1080 inst
->deleteValue();
1082 // Note: `processInstruction` does not call `convertConstant` recursively
1083 // since all constant dependencies have been converted before.
1084 assert(llvm::all_of(inst
->operands(), [&](llvm::Value
*value
) {
1085 return valueMapping
.contains(value
);
1087 if (failed(processInstruction(inst
)))
1089 return lookupValue(inst
);
1092 // Convert aggregate constants.
1093 if (isa
<llvm::ConstantAggregate
>(constant
) ||
1094 isa
<llvm::ConstantAggregateZero
>(constant
)) {
1095 // Lookup the aggregate elements that have been converted before.
1096 SmallVector
<Value
> elementValues
;
1097 if (auto *constAgg
= dyn_cast
<llvm::ConstantAggregate
>(constant
)) {
1098 elementValues
.reserve(constAgg
->getNumOperands());
1099 for (llvm::Value
*operand
: constAgg
->operands())
1100 elementValues
.push_back(lookupValue(operand
));
1102 if (auto *constAgg
= dyn_cast
<llvm::ConstantAggregateZero
>(constant
)) {
1103 unsigned numElements
= constAgg
->getElementCount().getFixedValue();
1104 elementValues
.reserve(numElements
);
1105 for (unsigned i
= 0, e
= numElements
; i
!= e
; ++i
)
1106 elementValues
.push_back(lookupValue(constAgg
->getElementValue(i
)));
1108 assert(llvm::count(elementValues
, nullptr) == 0 &&
1109 "expected all elements have been converted before");
1111 // Generate an UndefOp as root value and insert the aggregate elements.
1112 Type rootType
= convertType(constant
->getType());
1113 bool isArrayOrStruct
= isa
<LLVMArrayType
, LLVMStructType
>(rootType
);
1114 assert((isArrayOrStruct
|| LLVM::isCompatibleVectorType(rootType
)) &&
1115 "unrecognized aggregate type");
1116 Value root
= builder
.create
<UndefOp
>(loc
, rootType
);
1117 for (const auto &it
: llvm::enumerate(elementValues
)) {
1118 if (isArrayOrStruct
) {
1119 root
= builder
.create
<InsertValueOp
>(loc
, root
, it
.value(), it
.index());
1121 Attribute indexAttr
= builder
.getI32IntegerAttr(it
.index());
1123 builder
.create
<ConstantOp
>(loc
, builder
.getI32Type(), indexAttr
);
1124 root
= builder
.create
<InsertElementOp
>(loc
, rootType
, root
, it
.value(),
1131 if (auto *constTargetNone
= dyn_cast
<llvm::ConstantTargetNone
>(constant
)) {
1132 LLVMTargetExtType targetExtType
=
1133 cast
<LLVMTargetExtType
>(convertType(constTargetNone
->getType()));
1134 assert(targetExtType
.hasProperty(LLVMTargetExtType::HasZeroInit
) &&
1135 "target extension type does not support zero-initialization");
1136 // Create llvm.mlir.zero operation to represent zero-initialization of
1137 // target extension type.
1138 return builder
.create
<LLVM::ZeroOp
>(loc
, targetExtType
).getRes();
1141 StringRef error
= "";
1142 if (isa
<llvm::BlockAddress
>(constant
))
1143 error
= " since blockaddress(...) is unsupported";
1145 return emitError(loc
) << "unhandled constant: " << diag(*constant
) << error
;
1148 FailureOr
<Value
> ModuleImport::convertConstantExpr(llvm::Constant
*constant
) {
1149 // Only call the function for constants that have not been translated before
1150 // since it updates the constant insertion point assuming the converted
1151 // constant has been introduced at the end of the constant section.
1152 assert(!valueMapping
.contains(constant
) &&
1153 "expected constant has not been converted before");
1154 assert(constantInsertionBlock
&&
1155 "expected the constant insertion block to be non-null");
1157 // Insert the constant after the last one or at the start of the entry block.
1158 OpBuilder::InsertionGuard
guard(builder
);
1159 if (!constantInsertionOp
)
1160 builder
.setInsertionPointToStart(constantInsertionBlock
);
1162 builder
.setInsertionPointAfter(constantInsertionOp
);
1164 // Convert all constants of the expression and add them to `valueMapping`.
1165 SetVector
<llvm::Constant
*> constantsToConvert
=
1166 getConstantsToConvert(constant
);
1167 for (llvm::Constant
*constantToConvert
: constantsToConvert
) {
1168 FailureOr
<Value
> converted
= convertConstant(constantToConvert
);
1169 if (failed(converted
))
1171 mapValue(constantToConvert
, *converted
);
1174 // Update the constant insertion point and return the converted constant.
1175 Value result
= lookupValue(constant
);
1176 constantInsertionOp
= result
.getDefiningOp();
1180 FailureOr
<Value
> ModuleImport::convertValue(llvm::Value
*value
) {
1181 assert(!isa
<llvm::MetadataAsValue
>(value
) &&
1182 "expected value to not be metadata");
1184 // Return the mapped value if it has been converted before.
1185 auto it
= valueMapping
.find(value
);
1186 if (it
!= valueMapping
.end())
1187 return it
->getSecond();
1189 // Convert constants such as immediate values that have no mapping yet.
1190 if (auto *constant
= dyn_cast
<llvm::Constant
>(value
))
1191 return convertConstantExpr(constant
);
1193 Location loc
= UnknownLoc::get(context
);
1194 if (auto *inst
= dyn_cast
<llvm::Instruction
>(value
))
1195 loc
= translateLoc(inst
->getDebugLoc());
1196 return emitError(loc
) << "unhandled value: " << diag(*value
);
1199 FailureOr
<Value
> ModuleImport::convertMetadataValue(llvm::Value
*value
) {
1200 // A value may be wrapped as metadata, for example, when passed to a debug
1201 // intrinsic. Unwrap these values before the conversion.
1202 auto *nodeAsVal
= dyn_cast
<llvm::MetadataAsValue
>(value
);
1205 auto *node
= dyn_cast
<llvm::ValueAsMetadata
>(nodeAsVal
->getMetadata());
1208 value
= node
->getValue();
1210 // Return the mapped value if it has been converted before.
1211 auto it
= valueMapping
.find(value
);
1212 if (it
!= valueMapping
.end())
1213 return it
->getSecond();
1215 // Convert constants such as immediate values that have no mapping yet.
1216 if (auto *constant
= dyn_cast
<llvm::Constant
>(value
))
1217 return convertConstantExpr(constant
);
1221 FailureOr
<SmallVector
<Value
>>
1222 ModuleImport::convertValues(ArrayRef
<llvm::Value
*> values
) {
1223 SmallVector
<Value
> remapped
;
1224 remapped
.reserve(values
.size());
1225 for (llvm::Value
*value
: values
) {
1226 FailureOr
<Value
> converted
= convertValue(value
);
1227 if (failed(converted
))
1229 remapped
.push_back(*converted
);
1234 LogicalResult
ModuleImport::convertIntrinsicArguments(
1235 ArrayRef
<llvm::Value
*> values
, ArrayRef
<unsigned> immArgPositions
,
1236 ArrayRef
<StringLiteral
> immArgAttrNames
, SmallVectorImpl
<Value
> &valuesOut
,
1237 SmallVectorImpl
<NamedAttribute
> &attrsOut
) {
1238 assert(immArgPositions
.size() == immArgAttrNames
.size() &&
1239 "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
1242 SmallVector
<llvm::Value
*> operands(values
);
1243 for (auto [immArgPos
, immArgName
] :
1244 llvm::zip(immArgPositions
, immArgAttrNames
)) {
1245 auto &value
= operands
[immArgPos
];
1246 auto *constant
= llvm::cast
<llvm::Constant
>(value
);
1247 auto attr
= getScalarConstantAsAttr(builder
, constant
);
1248 assert(attr
&& attr
.getType().isIntOrFloat() &&
1249 "expected immarg to be float or integer constant");
1250 auto nameAttr
= StringAttr::get(attr
.getContext(), immArgName
);
1251 attrsOut
.push_back({nameAttr
, attr
});
1252 // Mark matched attribute values as null (so they can be removed below).
1256 for (llvm::Value
*value
: operands
) {
1259 auto mlirValue
= convertValue(value
);
1260 if (failed(mlirValue
))
1262 valuesOut
.push_back(*mlirValue
);
1268 IntegerAttr
ModuleImport::matchIntegerAttr(llvm::Value
*value
) {
1269 IntegerAttr integerAttr
;
1270 FailureOr
<Value
> converted
= convertValue(value
);
1271 bool success
= succeeded(converted
) &&
1272 matchPattern(*converted
, m_Constant(&integerAttr
));
1273 assert(success
&& "expected a constant integer value");
1278 FloatAttr
ModuleImport::matchFloatAttr(llvm::Value
*value
) {
1279 FloatAttr floatAttr
;
1280 FailureOr
<Value
> converted
= convertValue(value
);
1282 succeeded(converted
) && matchPattern(*converted
, m_Constant(&floatAttr
));
1283 assert(success
&& "expected a constant float value");
1288 DILocalVariableAttr
ModuleImport::matchLocalVariableAttr(llvm::Value
*value
) {
1289 auto *nodeAsVal
= cast
<llvm::MetadataAsValue
>(value
);
1290 auto *node
= cast
<llvm::DILocalVariable
>(nodeAsVal
->getMetadata());
1291 return debugImporter
->translate(node
);
1294 DILabelAttr
ModuleImport::matchLabelAttr(llvm::Value
*value
) {
1295 auto *nodeAsVal
= cast
<llvm::MetadataAsValue
>(value
);
1296 auto *node
= cast
<llvm::DILabel
>(nodeAsVal
->getMetadata());
1297 return debugImporter
->translate(node
);
1300 FPExceptionBehaviorAttr
1301 ModuleImport::matchFPExceptionBehaviorAttr(llvm::Value
*value
) {
1302 auto *metadata
= cast
<llvm::MetadataAsValue
>(value
);
1303 auto *mdstr
= cast
<llvm::MDString
>(metadata
->getMetadata());
1304 std::optional
<llvm::fp::ExceptionBehavior
> optLLVM
=
1305 llvm::convertStrToExceptionBehavior(mdstr
->getString());
1306 assert(optLLVM
&& "Expecting FP exception behavior");
1307 return builder
.getAttr
<FPExceptionBehaviorAttr
>(
1308 convertFPExceptionBehaviorFromLLVM(*optLLVM
));
1311 RoundingModeAttr
ModuleImport::matchRoundingModeAttr(llvm::Value
*value
) {
1312 auto *metadata
= cast
<llvm::MetadataAsValue
>(value
);
1313 auto *mdstr
= cast
<llvm::MDString
>(metadata
->getMetadata());
1314 std::optional
<llvm::RoundingMode
> optLLVM
=
1315 llvm::convertStrToRoundingMode(mdstr
->getString());
1316 assert(optLLVM
&& "Expecting rounding mode");
1317 return builder
.getAttr
<RoundingModeAttr
>(
1318 convertRoundingModeFromLLVM(*optLLVM
));
1321 FailureOr
<SmallVector
<AliasScopeAttr
>>
1322 ModuleImport::matchAliasScopeAttrs(llvm::Value
*value
) {
1323 auto *nodeAsVal
= cast
<llvm::MetadataAsValue
>(value
);
1324 auto *node
= cast
<llvm::MDNode
>(nodeAsVal
->getMetadata());
1325 return lookupAliasScopeAttrs(node
);
1328 Location
ModuleImport::translateLoc(llvm::DILocation
*loc
) {
1329 return debugImporter
->translateLoc(loc
);
1333 ModuleImport::convertBranchArgs(llvm::Instruction
*branch
,
1334 llvm::BasicBlock
*target
,
1335 SmallVectorImpl
<Value
> &blockArguments
) {
1336 for (auto inst
= target
->begin(); isa
<llvm::PHINode
>(inst
); ++inst
) {
1337 auto *phiInst
= cast
<llvm::PHINode
>(&*inst
);
1338 llvm::Value
*value
= phiInst
->getIncomingValueForBlock(branch
->getParent());
1339 FailureOr
<Value
> converted
= convertValue(value
);
1340 if (failed(converted
))
1342 blockArguments
.push_back(*converted
);
1348 ModuleImport::convertCallTypeAndOperands(llvm::CallBase
*callInst
,
1349 SmallVectorImpl
<Type
> &types
,
1350 SmallVectorImpl
<Value
> &operands
) {
1351 if (!callInst
->getType()->isVoidTy())
1352 types
.push_back(convertType(callInst
->getType()));
1354 if (!callInst
->getCalledFunction()) {
1355 FailureOr
<Value
> called
= convertValue(callInst
->getCalledOperand());
1358 operands
.push_back(*called
);
1360 SmallVector
<llvm::Value
*> args(callInst
->args());
1361 FailureOr
<SmallVector
<Value
>> arguments
= convertValues(args
);
1362 if (failed(arguments
))
1364 llvm::append_range(operands
, *arguments
);
1368 LogicalResult
ModuleImport::convertIntrinsic(llvm::CallInst
*inst
) {
1369 if (succeeded(iface
.convertIntrinsic(builder
, inst
, *this)))
1372 Location loc
= translateLoc(inst
->getDebugLoc());
1373 return emitError(loc
) << "unhandled intrinsic: " << diag(*inst
);
1376 LogicalResult
ModuleImport::convertInstruction(llvm::Instruction
*inst
) {
1377 // Convert all instructions that do not provide an MLIR builder.
1378 Location loc
= translateLoc(inst
->getDebugLoc());
1379 if (inst
->getOpcode() == llvm::Instruction::Br
) {
1380 auto *brInst
= cast
<llvm::BranchInst
>(inst
);
1382 SmallVector
<Block
*> succBlocks
;
1383 SmallVector
<SmallVector
<Value
>> succBlockArgs
;
1384 for (auto i
: llvm::seq
<unsigned>(0, brInst
->getNumSuccessors())) {
1385 llvm::BasicBlock
*succ
= brInst
->getSuccessor(i
);
1386 SmallVector
<Value
> blockArgs
;
1387 if (failed(convertBranchArgs(brInst
, succ
, blockArgs
)))
1389 succBlocks
.push_back(lookupBlock(succ
));
1390 succBlockArgs
.push_back(blockArgs
);
1393 if (!brInst
->isConditional()) {
1394 auto brOp
= builder
.create
<LLVM::BrOp
>(loc
, succBlockArgs
.front(),
1395 succBlocks
.front());
1396 mapNoResultOp(inst
, brOp
);
1399 FailureOr
<Value
> condition
= convertValue(brInst
->getCondition());
1400 if (failed(condition
))
1402 auto condBrOp
= builder
.create
<LLVM::CondBrOp
>(
1403 loc
, *condition
, succBlocks
.front(), succBlockArgs
.front(),
1404 succBlocks
.back(), succBlockArgs
.back());
1405 mapNoResultOp(inst
, condBrOp
);
1408 if (inst
->getOpcode() == llvm::Instruction::Switch
) {
1409 auto *swInst
= cast
<llvm::SwitchInst
>(inst
);
1410 // Process the condition value.
1411 FailureOr
<Value
> condition
= convertValue(swInst
->getCondition());
1412 if (failed(condition
))
1414 SmallVector
<Value
> defaultBlockArgs
;
1415 // Process the default case.
1416 llvm::BasicBlock
*defaultBB
= swInst
->getDefaultDest();
1417 if (failed(convertBranchArgs(swInst
, defaultBB
, defaultBlockArgs
)))
1420 // Process the cases.
1421 unsigned numCases
= swInst
->getNumCases();
1422 SmallVector
<SmallVector
<Value
>> caseOperands(numCases
);
1423 SmallVector
<ValueRange
> caseOperandRefs(numCases
);
1424 SmallVector
<APInt
> caseValues(numCases
);
1425 SmallVector
<Block
*> caseBlocks(numCases
);
1426 for (const auto &it
: llvm::enumerate(swInst
->cases())) {
1427 const llvm::SwitchInst::CaseHandle
&caseHandle
= it
.value();
1428 llvm::BasicBlock
*succBB
= caseHandle
.getCaseSuccessor();
1429 if (failed(convertBranchArgs(swInst
, succBB
, caseOperands
[it
.index()])))
1431 caseOperandRefs
[it
.index()] = caseOperands
[it
.index()];
1432 caseValues
[it
.index()] = caseHandle
.getCaseValue()->getValue();
1433 caseBlocks
[it
.index()] = lookupBlock(succBB
);
1436 auto switchOp
= builder
.create
<SwitchOp
>(
1437 loc
, *condition
, lookupBlock(defaultBB
), defaultBlockArgs
, caseValues
,
1438 caseBlocks
, caseOperandRefs
);
1439 mapNoResultOp(inst
, switchOp
);
1442 if (inst
->getOpcode() == llvm::Instruction::PHI
) {
1443 Type type
= convertType(inst
->getType());
1444 mapValue(inst
, builder
.getInsertionBlock()->addArgument(
1445 type
, translateLoc(inst
->getDebugLoc())));
1448 if (inst
->getOpcode() == llvm::Instruction::Call
) {
1449 auto *callInst
= cast
<llvm::CallInst
>(inst
);
1451 SmallVector
<Type
> types
;
1452 SmallVector
<Value
> operands
;
1453 if (failed(convertCallTypeAndOperands(callInst
, types
, operands
)))
1457 dyn_cast
<LLVMFunctionType
>(convertType(callInst
->getFunctionType()));
1463 if (llvm::Function
*callee
= callInst
->getCalledFunction()) {
1464 callOp
= builder
.create
<CallOp
>(
1465 loc
, funcTy
, SymbolRefAttr::get(context
, callee
->getName()),
1468 callOp
= builder
.create
<CallOp
>(loc
, funcTy
, operands
);
1470 callOp
.setCConv(convertCConvFromLLVM(callInst
->getCallingConv()));
1471 setFastmathFlagsAttr(inst
, callOp
);
1472 if (!callInst
->getType()->isVoidTy())
1473 mapValue(inst
, callOp
.getResult());
1475 mapNoResultOp(inst
, callOp
);
1478 if (inst
->getOpcode() == llvm::Instruction::LandingPad
) {
1479 auto *lpInst
= cast
<llvm::LandingPadInst
>(inst
);
1481 SmallVector
<Value
> operands
;
1482 operands
.reserve(lpInst
->getNumClauses());
1483 for (auto i
: llvm::seq
<unsigned>(0, lpInst
->getNumClauses())) {
1484 FailureOr
<Value
> operand
= convertValue(lpInst
->getClause(i
));
1485 if (failed(operand
))
1487 operands
.push_back(*operand
);
1490 Type type
= convertType(lpInst
->getType());
1492 builder
.create
<LandingpadOp
>(loc
, type
, lpInst
->isCleanup(), operands
);
1493 mapValue(inst
, lpOp
);
1496 if (inst
->getOpcode() == llvm::Instruction::Invoke
) {
1497 auto *invokeInst
= cast
<llvm::InvokeInst
>(inst
);
1499 SmallVector
<Type
> types
;
1500 SmallVector
<Value
> operands
;
1501 if (failed(convertCallTypeAndOperands(invokeInst
, types
, operands
)))
1504 // Check whether the invoke result is an argument to the normal destination
1506 bool invokeResultUsedInPhi
= llvm::any_of(
1507 invokeInst
->getNormalDest()->phis(), [&](const llvm::PHINode
&phi
) {
1508 return phi
.getIncomingValueForBlock(invokeInst
->getParent()) ==
1512 Block
*normalDest
= lookupBlock(invokeInst
->getNormalDest());
1513 Block
*directNormalDest
= normalDest
;
1514 if (invokeResultUsedInPhi
) {
1515 // The invoke result cannot be an argument to the normal destination
1516 // block, as that would imply using the invoke operation result in its
1517 // definition, so we need to create a dummy block to serve as an
1518 // intermediate destination.
1519 OpBuilder::InsertionGuard
g(builder
);
1520 directNormalDest
= builder
.createBlock(normalDest
);
1523 SmallVector
<Value
> unwindArgs
;
1524 if (failed(convertBranchArgs(invokeInst
, invokeInst
->getUnwindDest(),
1529 dyn_cast
<LLVMFunctionType
>(convertType(invokeInst
->getFunctionType()));
1533 // Create the invoke operation. Normal destination block arguments will be
1534 // added later on to handle the case in which the operation result is
1535 // included in this list.
1537 if (llvm::Function
*callee
= invokeInst
->getCalledFunction()) {
1538 invokeOp
= builder
.create
<InvokeOp
>(
1540 SymbolRefAttr::get(builder
.getContext(), callee
->getName()), operands
,
1541 directNormalDest
, ValueRange(),
1542 lookupBlock(invokeInst
->getUnwindDest()), unwindArgs
);
1544 invokeOp
= builder
.create
<InvokeOp
>(
1545 loc
, funcTy
, /*callee=*/nullptr, operands
, directNormalDest
,
1546 ValueRange(), lookupBlock(invokeInst
->getUnwindDest()), unwindArgs
);
1548 invokeOp
.setCConv(convertCConvFromLLVM(invokeInst
->getCallingConv()));
1549 if (!invokeInst
->getType()->isVoidTy())
1550 mapValue(inst
, invokeOp
.getResults().front());
1552 mapNoResultOp(inst
, invokeOp
);
1554 SmallVector
<Value
> normalArgs
;
1555 if (failed(convertBranchArgs(invokeInst
, invokeInst
->getNormalDest(),
1559 if (invokeResultUsedInPhi
) {
1560 // The dummy normal dest block will just host an unconditional branch
1561 // instruction to the normal destination block passing the required block
1562 // arguments (including the invoke operation's result).
1563 OpBuilder::InsertionGuard
g(builder
);
1564 builder
.setInsertionPointToStart(directNormalDest
);
1565 builder
.create
<LLVM::BrOp
>(loc
, normalArgs
, normalDest
);
1567 // If the invoke operation's result is not a block argument to the normal
1568 // destination block, just add the block arguments as usual.
1569 assert(llvm::none_of(
1571 [&](Value val
) { return val
.getDefiningOp() == invokeOp
; }) &&
1572 "An llvm.invoke operation cannot pass its result as a block "
1574 invokeOp
.getNormalDestOperandsMutable().append(normalArgs
);
1579 if (inst
->getOpcode() == llvm::Instruction::GetElementPtr
) {
1580 auto *gepInst
= cast
<llvm::GetElementPtrInst
>(inst
);
1581 Type sourceElementType
= convertType(gepInst
->getSourceElementType());
1582 FailureOr
<Value
> basePtr
= convertValue(gepInst
->getOperand(0));
1583 if (failed(basePtr
))
1586 // Treat every indices as dynamic since GEPOp::build will refine those
1587 // indices into static attributes later. One small downside of this
1588 // approach is that many unused `llvm.mlir.constant` would be emitted
1590 SmallVector
<GEPArg
> indices
;
1591 for (llvm::Value
*operand
: llvm::drop_begin(gepInst
->operand_values())) {
1592 FailureOr
<Value
> index
= convertValue(operand
);
1595 indices
.push_back(*index
);
1598 Type type
= convertType(inst
->getType());
1599 auto gepOp
= builder
.create
<GEPOp
>(loc
, type
, sourceElementType
, *basePtr
,
1600 indices
, gepInst
->isInBounds());
1601 mapValue(inst
, gepOp
);
1605 // Convert all instructions that have an mlirBuilder.
1606 if (succeeded(convertInstructionImpl(builder
, inst
, *this, iface
)))
1609 return emitError(loc
) << "unhandled instruction: " << diag(*inst
);
1612 LogicalResult
ModuleImport::processInstruction(llvm::Instruction
*inst
) {
1613 // FIXME: Support uses of SubtargetData.
1614 // FIXME: Add support for call / operand attributes.
1615 // FIXME: Add support for the indirectbr, cleanupret, catchret, catchswitch,
1616 // callbr, vaarg, catchpad, cleanuppad instructions.
1618 // Convert LLVM intrinsics calls to MLIR intrinsics.
1619 if (auto *intrinsic
= dyn_cast
<llvm::IntrinsicInst
>(inst
))
1620 return convertIntrinsic(intrinsic
);
1622 // Convert all remaining LLVM instructions to MLIR operations.
1623 return convertInstruction(inst
);
1626 FlatSymbolRefAttr
ModuleImport::getPersonalityAsAttr(llvm::Function
*f
) {
1627 if (!f
->hasPersonalityFn())
1630 llvm::Constant
*pf
= f
->getPersonalityFn();
1632 // If it directly has a name, we can use it.
1634 return SymbolRefAttr::get(builder
.getContext(), pf
->getName());
1636 // If it doesn't have a name, currently, only function pointers that are
1637 // bitcast to i8* are parsed.
1638 if (auto *ce
= dyn_cast
<llvm::ConstantExpr
>(pf
)) {
1639 if (ce
->getOpcode() == llvm::Instruction::BitCast
&&
1640 ce
->getType() == llvm::PointerType::getUnqual(f
->getContext())) {
1641 if (auto *func
= dyn_cast
<llvm::Function
>(ce
->getOperand(0)))
1642 return SymbolRefAttr::get(builder
.getContext(), func
->getName());
1645 return FlatSymbolRefAttr();
1648 static void processMemoryEffects(llvm::Function
*func
, LLVMFuncOp funcOp
) {
1649 llvm::MemoryEffects memEffects
= func
->getMemoryEffects();
1651 auto othermem
= convertModRefInfoFromLLVM(
1652 memEffects
.getModRef(llvm::MemoryEffects::Location::Other
));
1653 auto argMem
= convertModRefInfoFromLLVM(
1654 memEffects
.getModRef(llvm::MemoryEffects::Location::ArgMem
));
1655 auto inaccessibleMem
= convertModRefInfoFromLLVM(
1656 memEffects
.getModRef(llvm::MemoryEffects::Location::InaccessibleMem
));
1657 auto memAttr
= MemoryEffectsAttr::get(funcOp
.getContext(), othermem
, argMem
,
1659 // Only set the attr when it does not match the default value.
1660 if (memAttr
.isReadWrite())
1662 funcOp
.setMemoryAttr(memAttr
);
1665 // List of LLVM IR attributes that map to an explicit attribute on the MLIR
1667 static constexpr std::array ExplicitAttributes
{
1668 StringLiteral("aarch64_pstate_sm_enabled"),
1669 StringLiteral("aarch64_pstate_sm_body"),
1670 StringLiteral("aarch64_pstate_sm_compatible"),
1671 StringLiteral("aarch64_new_za"),
1672 StringLiteral("aarch64_preserves_za"),
1673 StringLiteral("aarch64_in_za"),
1674 StringLiteral("aarch64_out_za"),
1675 StringLiteral("aarch64_inout_za"),
1676 StringLiteral("vscale_range"),
1677 StringLiteral("frame-pointer"),
1678 StringLiteral("target-features"),
1679 StringLiteral("unsafe-fp-math"),
1680 StringLiteral("no-infs-fp-math"),
1681 StringLiteral("no-nans-fp-math"),
1682 StringLiteral("approx-func-fp-math"),
1683 StringLiteral("no-signed-zeros-fp-math"),
1686 static void processPassthroughAttrs(llvm::Function
*func
, LLVMFuncOp funcOp
) {
1687 MLIRContext
*context
= funcOp
.getContext();
1688 SmallVector
<Attribute
> passthroughs
;
1689 llvm::AttributeSet funcAttrs
= func
->getAttributes().getAttributes(
1690 llvm::AttributeList::AttrIndex::FunctionIndex
);
1691 for (llvm::Attribute attr
: funcAttrs
) {
1692 // Skip the memory attribute since the LLVMFuncOp has an explicit memory
1694 if (attr
.hasAttribute(llvm::Attribute::Memory
))
1697 // Skip invalid type attributes.
1698 if (attr
.isTypeAttribute()) {
1699 emitWarning(funcOp
.getLoc(),
1700 "type attributes on a function are invalid, skipping it");
1705 if (attr
.isStringAttribute())
1706 attrName
= attr
.getKindAsString();
1708 attrName
= llvm::Attribute::getNameFromAttrKind(attr
.getKindAsEnum());
1709 auto keyAttr
= StringAttr::get(context
, attrName
);
1711 // Skip attributes that map to an explicit attribute on the LLVMFuncOp.
1712 if (llvm::is_contained(ExplicitAttributes
, attrName
))
1715 if (attr
.isStringAttribute()) {
1716 StringRef val
= attr
.getValueAsString();
1718 passthroughs
.push_back(keyAttr
);
1721 passthroughs
.push_back(
1722 ArrayAttr::get(context
, {keyAttr
, StringAttr::get(context
, val
)}));
1725 if (attr
.isIntAttribute()) {
1726 auto val
= std::to_string(attr
.getValueAsInt());
1727 passthroughs
.push_back(
1728 ArrayAttr::get(context
, {keyAttr
, StringAttr::get(context
, val
)}));
1731 if (attr
.isEnumAttribute()) {
1732 passthroughs
.push_back(keyAttr
);
1736 llvm_unreachable("unexpected attribute kind");
1739 if (!passthroughs
.empty())
1740 funcOp
.setPassthroughAttr(ArrayAttr::get(context
, passthroughs
));
1743 void ModuleImport::processFunctionAttributes(llvm::Function
*func
,
1744 LLVMFuncOp funcOp
) {
1745 processMemoryEffects(func
, funcOp
);
1746 processPassthroughAttrs(func
, funcOp
);
1748 if (func
->hasFnAttribute("aarch64_pstate_sm_enabled"))
1749 funcOp
.setArmStreaming(true);
1750 else if (func
->hasFnAttribute("aarch64_pstate_sm_body"))
1751 funcOp
.setArmLocallyStreaming(true);
1752 else if (func
->hasFnAttribute("aarch64_pstate_sm_compatible"))
1753 funcOp
.setArmStreamingCompatible(true);
1755 if (func
->hasFnAttribute("aarch64_new_za"))
1756 funcOp
.setArmNewZa(true);
1757 else if (func
->hasFnAttribute("aarch64_in_za"))
1758 funcOp
.setArmInZa(true);
1759 else if (func
->hasFnAttribute("aarch64_out_za"))
1760 funcOp
.setArmOutZa(true);
1761 else if (func
->hasFnAttribute("aarch64_inout_za"))
1762 funcOp
.setArmInoutZa(true);
1763 else if (func
->hasFnAttribute("aarch64_preserves_za"))
1764 funcOp
.setArmPreservesZa(true);
1766 llvm::Attribute attr
= func
->getFnAttribute(llvm::Attribute::VScaleRange
);
1767 if (attr
.isValid()) {
1768 MLIRContext
*context
= funcOp
.getContext();
1769 auto intTy
= IntegerType::get(context
, 32);
1770 funcOp
.setVscaleRangeAttr(LLVM::VScaleRangeAttr::get(
1771 context
, IntegerAttr::get(intTy
, attr
.getVScaleRangeMin()),
1772 IntegerAttr::get(intTy
, attr
.getVScaleRangeMax().value_or(0))));
1775 // Process frame-pointer attribute.
1776 if (func
->hasFnAttribute("frame-pointer")) {
1777 StringRef stringRefFramePointerKind
=
1778 func
->getFnAttribute("frame-pointer").getValueAsString();
1779 funcOp
.setFramePointerAttr(LLVM::FramePointerKindAttr::get(
1780 funcOp
.getContext(), LLVM::framePointerKind::symbolizeFramePointerKind(
1781 stringRefFramePointerKind
)
1785 if (llvm::Attribute attr
= func
->getFnAttribute("target-cpu");
1786 attr
.isStringAttribute())
1787 funcOp
.setTargetCpuAttr(StringAttr::get(context
, attr
.getValueAsString()));
1789 if (llvm::Attribute attr
= func
->getFnAttribute("target-features");
1790 attr
.isStringAttribute())
1791 funcOp
.setTargetFeaturesAttr(
1792 LLVM::TargetFeaturesAttr::get(context
, attr
.getValueAsString()));
1794 if (llvm::Attribute attr
= func
->getFnAttribute("unsafe-fp-math");
1795 attr
.isStringAttribute())
1796 funcOp
.setUnsafeFpMath(attr
.getValueAsBool());
1798 if (llvm::Attribute attr
= func
->getFnAttribute("no-infs-fp-math");
1799 attr
.isStringAttribute())
1800 funcOp
.setNoInfsFpMath(attr
.getValueAsBool());
1802 if (llvm::Attribute attr
= func
->getFnAttribute("no-nans-fp-math");
1803 attr
.isStringAttribute())
1804 funcOp
.setNoNansFpMath(attr
.getValueAsBool());
1806 if (llvm::Attribute attr
= func
->getFnAttribute("approx-func-fp-math");
1807 attr
.isStringAttribute())
1808 funcOp
.setApproxFuncFpMath(attr
.getValueAsBool());
1810 if (llvm::Attribute attr
= func
->getFnAttribute("no-signed-zeros-fp-math");
1811 attr
.isStringAttribute())
1812 funcOp
.setNoSignedZerosFpMath(attr
.getValueAsBool());
1816 ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs
,
1817 OpBuilder
&builder
) {
1818 SmallVector
<NamedAttribute
> paramAttrs
;
1819 for (auto [llvmKind
, mlirName
] : getAttrKindToNameMapping()) {
1820 auto llvmAttr
= llvmParamAttrs
.getAttribute(llvmKind
);
1821 // Skip attributes that are not attached.
1822 if (!llvmAttr
.isValid())
1825 if (llvmAttr
.isTypeAttribute())
1826 mlirAttr
= TypeAttr::get(convertType(llvmAttr
.getValueAsType()));
1827 else if (llvmAttr
.isIntAttribute())
1828 mlirAttr
= builder
.getI64IntegerAttr(llvmAttr
.getValueAsInt());
1829 else if (llvmAttr
.isEnumAttribute())
1830 mlirAttr
= builder
.getUnitAttr();
1832 llvm_unreachable("unexpected parameter attribute kind");
1833 paramAttrs
.push_back(builder
.getNamedAttr(mlirName
, mlirAttr
));
1836 return builder
.getDictionaryAttr(paramAttrs
);
1839 void ModuleImport::convertParameterAttributes(llvm::Function
*func
,
1841 OpBuilder
&builder
) {
1842 auto llvmAttrs
= func
->getAttributes();
1843 for (size_t i
= 0, e
= funcOp
.getNumArguments(); i
< e
; ++i
) {
1844 llvm::AttributeSet llvmArgAttrs
= llvmAttrs
.getParamAttrs(i
);
1845 funcOp
.setArgAttrs(i
, convertParameterAttribute(llvmArgAttrs
, builder
));
1847 // Convert the result attributes and attach them wrapped in an ArrayAttribute
1849 llvm::AttributeSet llvmResAttr
= llvmAttrs
.getRetAttrs();
1850 if (!llvmResAttr
.hasAttributes())
1852 funcOp
.setResAttrsAttr(
1853 builder
.getArrayAttr(convertParameterAttribute(llvmResAttr
, builder
)));
1856 LogicalResult
ModuleImport::processFunction(llvm::Function
*func
) {
1860 dyn_cast
<LLVMFunctionType
>(convertType(func
->getFunctionType()));
1861 if (func
->isIntrinsic() &&
1862 iface
.isConvertibleIntrinsic(func
->getIntrinsicID()))
1865 bool dsoLocal
= func
->hasLocalLinkage();
1866 CConv cconv
= convertCConvFromLLVM(func
->getCallingConv());
1868 // Insert the function at the end of the module.
1869 OpBuilder::InsertionGuard
guard(builder
);
1870 builder
.setInsertionPoint(mlirModule
.getBody(), mlirModule
.getBody()->end());
1872 Location loc
= debugImporter
->translateFuncLocation(func
);
1873 LLVMFuncOp funcOp
= builder
.create
<LLVMFuncOp
>(
1874 loc
, func
->getName(), functionType
,
1875 convertLinkageFromLLVM(func
->getLinkage()), dsoLocal
, cconv
);
1877 convertParameterAttributes(func
, funcOp
, builder
);
1879 if (FlatSymbolRefAttr personality
= getPersonalityAsAttr(func
))
1880 funcOp
.setPersonalityAttr(personality
);
1881 else if (func
->hasPersonalityFn())
1882 emitWarning(funcOp
.getLoc(), "could not deduce personality, skipping it");
1885 funcOp
.setGarbageCollector(StringRef(func
->getGC()));
1887 if (func
->hasAtLeastLocalUnnamedAddr())
1888 funcOp
.setUnnamedAddr(convertUnnamedAddrFromLLVM(func
->getUnnamedAddr()));
1890 if (func
->hasSection())
1891 funcOp
.setSection(StringRef(func
->getSection()));
1893 funcOp
.setVisibility_(convertVisibilityFromLLVM(func
->getVisibility()));
1895 if (func
->hasComdat())
1896 funcOp
.setComdatAttr(comdatMapping
.lookup(func
->getComdat()));
1898 if (llvm::MaybeAlign maybeAlign
= func
->getAlign())
1899 funcOp
.setAlignment(maybeAlign
->value());
1901 // Handle Function attributes.
1902 processFunctionAttributes(func
, funcOp
);
1904 // Convert non-debug metadata by using the dialect interface.
1905 SmallVector
<std::pair
<unsigned, llvm::MDNode
*>> allMetadata
;
1906 func
->getAllMetadata(allMetadata
);
1907 for (auto &[kind
, node
] : allMetadata
) {
1908 if (!iface
.isConvertibleMetadata(kind
))
1910 if (failed(iface
.setMetadataAttrs(builder
, kind
, node
, funcOp
, *this))) {
1911 emitWarning(funcOp
.getLoc())
1912 << "unhandled function metadata: " << diagMD(node
, llvmModule
.get())
1913 << " on " << diag(*func
);
1917 if (func
->isDeclaration())
1920 // Collect the set of basic blocks reachable from the function's entry block.
1921 // This step is crucial as LLVM IR can contain unreachable blocks that
1922 // self-dominate. As a result, an operation might utilize a variable it
1923 // defines, which the import does not support. Given that MLIR lacks block
1924 // label support, we can safely remove unreachable blocks, as there are no
1925 // indirect branch instructions that could potentially target these blocks.
1926 llvm::df_iterator_default_set
<llvm::BasicBlock
*> reachable
;
1927 for (llvm::BasicBlock
*basicBlock
: llvm::depth_first_ext(func
, reachable
))
1930 // Eagerly create all reachable blocks.
1931 SmallVector
<llvm::BasicBlock
*> reachableBasicBlocks
;
1932 for (llvm::BasicBlock
&basicBlock
: *func
) {
1933 // Skip unreachable blocks.
1934 if (!reachable
.contains(&basicBlock
))
1936 Region
&body
= funcOp
.getBody();
1937 Block
*block
= builder
.createBlock(&body
, body
.end());
1938 mapBlock(&basicBlock
, block
);
1939 reachableBasicBlocks
.push_back(&basicBlock
);
1942 // Add function arguments to the entry block.
1943 for (const auto &it
: llvm::enumerate(func
->args())) {
1944 BlockArgument blockArg
= funcOp
.getFunctionBody().addArgument(
1945 functionType
.getParamType(it
.index()), funcOp
.getLoc());
1946 mapValue(&it
.value(), blockArg
);
1949 // Process the blocks in topological order. The ordered traversal ensures
1950 // operands defined in a dominating block have a valid mapping to an MLIR
1951 // value once a block is translated.
1952 SetVector
<llvm::BasicBlock
*> blocks
=
1953 getTopologicallySortedBlocks(reachableBasicBlocks
);
1954 setConstantInsertionPointToStart(lookupBlock(blocks
.front()));
1955 for (llvm::BasicBlock
*basicBlock
: blocks
)
1956 if (failed(processBasicBlock(basicBlock
, lookupBlock(basicBlock
))))
1959 // Process the debug intrinsics that require a delayed conversion after
1960 // everything else was converted.
1961 if (failed(processDebugIntrinsics()))
1967 /// Checks if `dbgIntr` is a kill location that holds metadata instead of an SSA
1969 static bool isMetadataKillLocation(llvm::DbgVariableIntrinsic
*dbgIntr
) {
1970 if (!dbgIntr
->isKillLocation())
1972 llvm::Value
*value
= dbgIntr
->getArgOperand(0);
1973 auto *nodeAsVal
= dyn_cast
<llvm::MetadataAsValue
>(value
);
1976 return !isa
<llvm::ValueAsMetadata
>(nodeAsVal
->getMetadata());
1980 ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic
*dbgIntr
,
1981 DominanceInfo
&domInfo
) {
1982 Location loc
= translateLoc(dbgIntr
->getDebugLoc());
1983 auto emitUnsupportedWarning
= [&]() {
1984 if (emitExpensiveWarnings
)
1985 emitWarning(loc
) << "dropped intrinsic: " << diag(*dbgIntr
);
1988 // Drop debug intrinsics with arg lists.
1989 // TODO: Support debug intrinsics that have arg lists.
1990 if (dbgIntr
->hasArgList())
1991 return emitUnsupportedWarning();
1992 // Kill locations can have metadata nodes as location operand. This
1993 // cannot be converted to poison as the type cannot be reconstructed.
1994 // TODO: find a way to support this case.
1995 if (isMetadataKillLocation(dbgIntr
))
1996 return emitUnsupportedWarning();
1997 // Drop debug intrinsics if the associated variable information cannot be
1998 // translated due to cyclic debug metadata.
1999 // TODO: Support cyclic debug metadata.
2000 DILocalVariableAttr localVariableAttr
=
2001 matchLocalVariableAttr(dbgIntr
->getArgOperand(1));
2002 if (!localVariableAttr
)
2003 return emitUnsupportedWarning();
2004 FailureOr
<Value
> argOperand
= convertMetadataValue(dbgIntr
->getArgOperand(0));
2005 if (failed(argOperand
))
2006 return emitError(loc
) << "failed to convert a debug intrinsic operand: "
2009 // Ensure that the debug instrinsic is inserted right after its operand is
2010 // defined. Otherwise, the operand might not necessarily dominate the
2011 // intrinsic. If the defining operation is a terminator, insert the intrinsic
2012 // into a dominated block.
2013 OpBuilder::InsertionGuard
guard(builder
);
2014 if (Operation
*op
= argOperand
->getDefiningOp();
2015 op
&& op
->hasTrait
<OpTrait::IsTerminator
>()) {
2016 // Find a dominated block that can hold the debug intrinsic.
2017 auto dominatedBlocks
= domInfo
.getNode(op
->getBlock())->children();
2018 // If no block is dominated by the terminator, this intrinisc cannot be
2020 if (dominatedBlocks
.empty())
2021 return emitUnsupportedWarning();
2022 // Set insertion point before the terminator, to avoid inserting something
2023 // before landingpads.
2024 Block
*dominatedBlock
= (*dominatedBlocks
.begin())->getBlock();
2025 builder
.setInsertionPoint(dominatedBlock
->getTerminator());
2027 builder
.setInsertionPointAfterValue(*argOperand
);
2029 auto locationExprAttr
=
2030 debugImporter
->translateExpression(dbgIntr
->getExpression());
2032 llvm::TypeSwitch
<llvm::DbgVariableIntrinsic
*, Operation
*>(dbgIntr
)
2033 .Case([&](llvm::DbgDeclareInst
*) {
2034 return builder
.create
<LLVM::DbgDeclareOp
>(
2035 loc
, *argOperand
, localVariableAttr
, locationExprAttr
);
2037 .Case([&](llvm::DbgValueInst
*) {
2038 return builder
.create
<LLVM::DbgValueOp
>(
2039 loc
, *argOperand
, localVariableAttr
, locationExprAttr
);
2041 mapNoResultOp(dbgIntr
, op
);
2042 setNonDebugMetadataAttrs(dbgIntr
, op
);
2046 LogicalResult
ModuleImport::processDebugIntrinsics() {
2047 DominanceInfo domInfo
;
2048 for (llvm::Instruction
*inst
: debugIntrinsics
) {
2049 auto *intrCall
= cast
<llvm::DbgVariableIntrinsic
>(inst
);
2050 if (failed(processDebugIntrinsic(intrCall
, domInfo
)))
2056 LogicalResult
ModuleImport::processBasicBlock(llvm::BasicBlock
*bb
,
2058 builder
.setInsertionPointToStart(block
);
2059 for (llvm::Instruction
&inst
: *bb
) {
2060 if (failed(processInstruction(&inst
)))
2063 // Skip additional processing when the instructions is a debug intrinsics
2064 // that was not yet converted.
2065 if (debugIntrinsics
.contains(&inst
))
2068 // Set the non-debug metadata attributes on the imported operation and emit
2069 // a warning if an instruction other than a phi instruction is dropped
2070 // during the import.
2071 if (Operation
*op
= lookupOperation(&inst
)) {
2072 setNonDebugMetadataAttrs(&inst
, op
);
2073 } else if (inst
.getOpcode() != llvm::Instruction::PHI
) {
2074 if (emitExpensiveWarnings
) {
2075 Location loc
= debugImporter
->translateLoc(inst
.getDebugLoc());
2076 emitWarning(loc
) << "dropped instruction: " << diag(inst
);
2083 FailureOr
<SmallVector
<AccessGroupAttr
>>
2084 ModuleImport::lookupAccessGroupAttrs(const llvm::MDNode
*node
) const {
2085 return loopAnnotationImporter
->lookupAccessGroupAttrs(node
);
2089 ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode
*node
,
2090 Location loc
) const {
2091 return loopAnnotationImporter
->translateLoopAnnotation(node
, loc
);
2094 OwningOpRef
<ModuleOp
>
2095 mlir::translateLLVMIRToModule(std::unique_ptr
<llvm::Module
> llvmModule
,
2096 MLIRContext
*context
, bool emitExpensiveWarnings
,
2097 bool dropDICompositeTypeElements
) {
2098 // Preload all registered dialects to allow the import to iterate the
2099 // registered LLVMImportDialectInterface implementations and query the
2100 // supported LLVM IR constructs before starting the translation. Assumes the
2101 // LLVM and DLTI dialects that convert the core LLVM IR constructs have been
2102 // registered before.
2103 assert(llvm::is_contained(context
->getAvailableDialects(),
2104 LLVMDialect::getDialectNamespace()));
2105 assert(llvm::is_contained(context
->getAvailableDialects(),
2106 DLTIDialect::getDialectNamespace()));
2107 context
->loadAllAvailableDialects();
2108 OwningOpRef
<ModuleOp
> module(ModuleOp::create(FileLineColLoc::get(
2109 StringAttr::get(context
, llvmModule
->getSourceFileName()), /*line=*/0,
2112 ModuleImport
moduleImport(module
.get(), std::move(llvmModule
),
2113 emitExpensiveWarnings
, dropDICompositeTypeElements
);
2114 if (failed(moduleImport
.initializeImportInterface()))
2116 if (failed(moduleImport
.convertDataLayout()))
2118 if (failed(moduleImport
.convertComdats()))
2120 if (failed(moduleImport
.convertMetadata()))
2122 if (failed(moduleImport
.convertGlobals()))
2124 if (failed(moduleImport
.convertFunctions()))