[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / mlir / lib / Target / LLVMIR / ModuleImport.cpp
blob191b84acd56fae7c986ba829e4ed267f4697b8b8
1 //===- ModuleImport.cpp - LLVM to MLIR conversion ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the import of an LLVM IR module into an LLVM dialect
10 // module.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Target/LLVMIR/ModuleImport.h"
15 #include "mlir/Target/LLVMIR/Import.h"
17 #include "AttrKindDetail.h"
18 #include "DataLayoutImporter.h"
19 #include "DebugImporter.h"
20 #include "LoopAnnotationImporter.h"
22 #include "mlir/Dialect/DLTI/DLTI.h"
23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/Interfaces/DataLayoutInterfaces.h"
27 #include "mlir/Tools/mlir-translate/Translation.h"
29 #include "llvm/ADT/DepthFirstIterator.h"
30 #include "llvm/ADT/PostOrderIterator.h"
31 #include "llvm/ADT/ScopeExit.h"
32 #include "llvm/ADT/StringSet.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/IR/Comdat.h"
35 #include "llvm/IR/Constants.h"
36 #include "llvm/IR/InlineAsm.h"
37 #include "llvm/IR/InstIterator.h"
38 #include "llvm/IR/Instructions.h"
39 #include "llvm/IR/IntrinsicInst.h"
40 #include "llvm/IR/Metadata.h"
41 #include "llvm/IR/Operator.h"
42 #include "llvm/Support/ModRef.h"
44 using namespace mlir;
45 using namespace mlir::LLVM;
46 using namespace mlir::LLVM::detail;
48 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
50 // Utility to print an LLVM value as a string for passing to emitError().
51 // FIXME: Diagnostic should be able to natively handle types that have
52 // operator << (raw_ostream&) defined.
53 static std::string diag(const llvm::Value &value) {
54 std::string str;
55 llvm::raw_string_ostream os(str);
56 os << value;
57 return os.str();
60 // Utility to print an LLVM metadata node as a string for passing
61 // to emitError(). The module argument is needed to print the nodes
62 // canonically numbered.
63 static std::string diagMD(const llvm::Metadata *node,
64 const llvm::Module *module) {
65 std::string str;
66 llvm::raw_string_ostream os(str);
67 node->print(os, module, /*IsForDebug=*/true);
68 return os.str();
71 /// Returns the name of the global_ctors global variables.
72 static constexpr StringRef getGlobalCtorsVarName() {
73 return "llvm.global_ctors";
76 /// Returns the name of the global_dtors global variables.
77 static constexpr StringRef getGlobalDtorsVarName() {
78 return "llvm.global_dtors";
81 /// Returns the symbol name for the module-level comdat operation. It must not
82 /// conflict with the user namespace.
83 static constexpr StringRef getGlobalComdatOpName() {
84 return "__llvm_global_comdat";
87 /// Converts the sync scope identifier of `inst` to the string representation
88 /// necessary to build an atomic LLVM dialect operation. Returns the empty
89 /// string if the operation has either no sync scope or the default system-level
90 /// sync scope attached. The atomic operations only set their sync scope
91 /// attribute if they have a non-default sync scope attached.
92 static StringRef getLLVMSyncScope(llvm::Instruction *inst) {
93 std::optional<llvm::SyncScope::ID> syncScopeID =
94 llvm::getAtomicSyncScopeID(inst);
95 if (!syncScopeID)
96 return "";
98 // Search the sync scope name for the given identifier. The default
99 // system-level sync scope thereby maps to the empty string.
100 SmallVector<StringRef> syncScopeName;
101 llvm::LLVMContext &llvmContext = inst->getContext();
102 llvmContext.getSyncScopeNames(syncScopeName);
103 auto *it = llvm::find_if(syncScopeName, [&](StringRef name) {
104 return *syncScopeID == llvmContext.getOrInsertSyncScopeID(name);
106 if (it != syncScopeName.end())
107 return *it;
108 llvm_unreachable("incorrect sync scope identifier");
111 /// Converts an array of unsigned indices to a signed integer position array.
112 static SmallVector<int64_t> getPositionFromIndices(ArrayRef<unsigned> indices) {
113 SmallVector<int64_t> position;
114 llvm::append_range(position, indices);
115 return position;
118 /// Converts the LLVM instructions that have a generated MLIR builder. Using a
119 /// static implementation method called from the module import ensures the
120 /// builders have to use the `moduleImport` argument and cannot directly call
121 /// import methods. As a result, both the intrinsic and the instruction MLIR
122 /// builders have to use the `moduleImport` argument and none of them has direct
123 /// access to the private module import methods.
124 static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
125 llvm::Instruction *inst,
126 ModuleImport &moduleImport,
127 LLVMImportInterface &iface) {
128 // Copy the operands to an LLVM operands array reference for conversion.
129 SmallVector<llvm::Value *> operands(inst->operands());
130 ArrayRef<llvm::Value *> llvmOperands(operands);
132 // Convert all instructions that provide an MLIR builder.
133 if (iface.isConvertibleInstruction(inst->getOpcode()))
134 return iface.convertInstruction(odsBuilder, inst, llvmOperands,
135 moduleImport);
136 // TODO: Implement the `convertInstruction` hooks in the
137 // `LLVMDialectLLVMIRImportInterface` and move the following include there.
138 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
139 return failure();
142 /// Get a topologically sorted list of blocks for the given basic blocks.
143 static SetVector<llvm::BasicBlock *>
144 getTopologicallySortedBlocks(ArrayRef<llvm::BasicBlock *> basicBlocks) {
145 SetVector<llvm::BasicBlock *> blocks;
146 for (llvm::BasicBlock *basicBlock : basicBlocks) {
147 if (!blocks.contains(basicBlock)) {
148 llvm::ReversePostOrderTraversal<llvm::BasicBlock *> traversal(basicBlock);
149 blocks.insert(traversal.begin(), traversal.end());
152 assert(blocks.size() == basicBlocks.size() && "some blocks are not sorted");
153 return blocks;
156 ModuleImport::ModuleImport(ModuleOp mlirModule,
157 std::unique_ptr<llvm::Module> llvmModule,
158 bool emitExpensiveWarnings,
159 bool importEmptyDICompositeTypes)
160 : builder(mlirModule->getContext()), context(mlirModule->getContext()),
161 mlirModule(mlirModule), llvmModule(std::move(llvmModule)),
162 iface(mlirModule->getContext()),
163 typeTranslator(*mlirModule->getContext()),
164 debugImporter(std::make_unique<DebugImporter>(
165 mlirModule, importEmptyDICompositeTypes)),
166 loopAnnotationImporter(
167 std::make_unique<LoopAnnotationImporter>(*this, builder)),
168 emitExpensiveWarnings(emitExpensiveWarnings) {
169 builder.setInsertionPointToStart(mlirModule.getBody());
172 ComdatOp ModuleImport::getGlobalComdatOp() {
173 if (globalComdatOp)
174 return globalComdatOp;
176 OpBuilder::InsertionGuard guard(builder);
177 builder.setInsertionPointToEnd(mlirModule.getBody());
178 globalComdatOp =
179 builder.create<ComdatOp>(mlirModule.getLoc(), getGlobalComdatOpName());
180 globalInsertionOp = globalComdatOp;
181 return globalComdatOp;
184 LogicalResult ModuleImport::processTBAAMetadata(const llvm::MDNode *node) {
185 Location loc = mlirModule.getLoc();
187 // If `node` is a valid TBAA root node, then return its optional identity
188 // string, otherwise return failure.
189 auto getIdentityIfRootNode =
190 [&](const llvm::MDNode *node) -> FailureOr<std::optional<StringRef>> {
191 // Root node, e.g.:
192 // !0 = !{!"Simple C/C++ TBAA"}
193 // !1 = !{}
194 if (node->getNumOperands() > 1)
195 return failure();
196 // If the operand is MDString, then assume that this is a root node.
197 if (node->getNumOperands() == 1)
198 if (const auto *op0 = dyn_cast<const llvm::MDString>(node->getOperand(0)))
199 return std::optional<StringRef>{op0->getString()};
200 return std::optional<StringRef>{};
203 // If `node` looks like a TBAA type descriptor metadata,
204 // then return true, if it is a valid node, and false otherwise.
205 // If it does not look like a TBAA type descriptor metadata, then
206 // return std::nullopt.
207 // If `identity` and `memberTypes/Offsets` are non-null, then they will
208 // contain the converted metadata operands for a valid TBAA node (i.e. when
209 // true is returned).
210 auto isTypeDescriptorNode = [&](const llvm::MDNode *node,
211 StringRef *identity = nullptr,
212 SmallVectorImpl<TBAAMemberAttr> *members =
213 nullptr) -> std::optional<bool> {
214 unsigned numOperands = node->getNumOperands();
215 // Type descriptor, e.g.:
216 // !1 = !{!"int", !0, /*optional*/i64 0} /* scalar int type */
217 // !2 = !{!"agg_t", !1, i64 0} /* struct agg_t { int x; } */
218 if (numOperands < 2)
219 return std::nullopt;
221 // TODO: support "new" format (D41501) for type descriptors,
222 // where the first operand is an MDNode.
223 const auto *identityNode =
224 dyn_cast<const llvm::MDString>(node->getOperand(0));
225 if (!identityNode)
226 return std::nullopt;
228 // This should be a type descriptor node.
229 if (identity)
230 *identity = identityNode->getString();
232 for (unsigned pairNum = 0, e = numOperands / 2; pairNum < e; ++pairNum) {
233 const auto *memberNode =
234 dyn_cast<const llvm::MDNode>(node->getOperand(2 * pairNum + 1));
235 if (!memberNode) {
236 emitError(loc) << "operand '" << 2 * pairNum + 1 << "' must be MDNode: "
237 << diagMD(node, llvmModule.get());
238 return false;
240 int64_t offset = 0;
241 if (2 * pairNum + 2 >= numOperands) {
242 // Allow for optional 0 offset in 2-operand nodes.
243 if (numOperands != 2) {
244 emitError(loc) << "missing member offset: "
245 << diagMD(node, llvmModule.get());
246 return false;
248 } else {
249 auto *offsetCI = llvm::mdconst::dyn_extract<llvm::ConstantInt>(
250 node->getOperand(2 * pairNum + 2));
251 if (!offsetCI) {
252 emitError(loc) << "operand '" << 2 * pairNum + 2
253 << "' must be ConstantInt: "
254 << diagMD(node, llvmModule.get());
255 return false;
257 offset = offsetCI->getZExtValue();
260 if (members)
261 members->push_back(TBAAMemberAttr::get(
262 cast<TBAANodeAttr>(tbaaMapping.lookup(memberNode)), offset));
265 return true;
268 // If `node` looks like a TBAA access tag metadata,
269 // then return true, if it is a valid node, and false otherwise.
270 // If it does not look like a TBAA access tag metadata, then
271 // return std::nullopt.
272 // If the other arguments are non-null, then they will contain
273 // the converted metadata operands for a valid TBAA node (i.e. when true is
274 // returned).
275 auto isTagNode = [&](const llvm::MDNode *node,
276 TBAATypeDescriptorAttr *baseAttr = nullptr,
277 TBAATypeDescriptorAttr *accessAttr = nullptr,
278 int64_t *offset = nullptr,
279 bool *isConstant = nullptr) -> std::optional<bool> {
280 // Access tag, e.g.:
281 // !3 = !{!1, !1, i64 0} /* scalar int access */
282 // !4 = !{!2, !1, i64 0} /* agg_t::x access */
284 // Optional 4th argument is ConstantInt 0/1 identifying whether
285 // the location being accessed is "constant" (see for details:
286 // https://llvm.org/docs/LangRef.html#representation).
287 unsigned numOperands = node->getNumOperands();
288 if (numOperands != 3 && numOperands != 4)
289 return std::nullopt;
290 const auto *baseMD = dyn_cast<const llvm::MDNode>(node->getOperand(0));
291 const auto *accessMD = dyn_cast<const llvm::MDNode>(node->getOperand(1));
292 auto *offsetCI =
293 llvm::mdconst::dyn_extract<llvm::ConstantInt>(node->getOperand(2));
294 if (!baseMD || !accessMD || !offsetCI)
295 return std::nullopt;
296 // TODO: support "new" TBAA format, if needed (see D41501).
297 // In the "old" format the first operand of the access type
298 // metadata is MDString. We have to distinguish the formats,
299 // because access tags have the same structure, but different
300 // meaning for the operands.
301 if (accessMD->getNumOperands() < 1 ||
302 !isa<llvm::MDString>(accessMD->getOperand(0)))
303 return std::nullopt;
304 bool isConst = false;
305 if (numOperands == 4) {
306 auto *isConstantCI =
307 llvm::mdconst::dyn_extract<llvm::ConstantInt>(node->getOperand(3));
308 if (!isConstantCI) {
309 emitError(loc) << "operand '3' must be ConstantInt: "
310 << diagMD(node, llvmModule.get());
311 return false;
313 isConst = isConstantCI->getValue()[0];
315 if (baseAttr)
316 *baseAttr = cast<TBAATypeDescriptorAttr>(tbaaMapping.lookup(baseMD));
317 if (accessAttr)
318 *accessAttr = cast<TBAATypeDescriptorAttr>(tbaaMapping.lookup(accessMD));
319 if (offset)
320 *offset = offsetCI->getZExtValue();
321 if (isConstant)
322 *isConstant = isConst;
323 return true;
326 // Do a post-order walk over the TBAA Graph. Since a correct TBAA Graph is a
327 // DAG, a post-order walk guarantees that we convert any metadata node we
328 // depend on, prior to converting the current node.
329 DenseSet<const llvm::MDNode *> seen;
330 SmallVector<const llvm::MDNode *> workList;
331 workList.push_back(node);
332 while (!workList.empty()) {
333 const llvm::MDNode *current = workList.back();
334 if (tbaaMapping.contains(current)) {
335 // Already converted. Just pop from the worklist.
336 workList.pop_back();
337 continue;
340 // If any child of this node is not yet converted, don't pop the current
341 // node from the worklist but push the not-yet-converted children in the
342 // front of the worklist.
343 bool anyChildNotConverted = false;
344 for (const llvm::MDOperand &operand : current->operands())
345 if (auto *childNode = dyn_cast_or_null<const llvm::MDNode>(operand.get()))
346 if (!tbaaMapping.contains(childNode)) {
347 workList.push_back(childNode);
348 anyChildNotConverted = true;
351 if (anyChildNotConverted) {
352 // If this is the second time we failed to convert an element in the
353 // worklist it must be because a child is dependent on it being converted
354 // and we have a cycle in the graph. Cycles are not allowed in TBAA
355 // graphs.
356 if (!seen.insert(current).second)
357 return emitError(loc) << "has cycle in TBAA graph: "
358 << diagMD(current, llvmModule.get());
360 continue;
363 // Otherwise simply import the current node.
364 workList.pop_back();
366 FailureOr<std::optional<StringRef>> rootNodeIdentity =
367 getIdentityIfRootNode(current);
368 if (succeeded(rootNodeIdentity)) {
369 StringAttr stringAttr = *rootNodeIdentity
370 ? builder.getStringAttr(**rootNodeIdentity)
371 : nullptr;
372 // The root nodes do not have operands, so we can create
373 // the TBAARootAttr on the first walk.
374 tbaaMapping.insert({current, builder.getAttr<TBAARootAttr>(stringAttr)});
375 continue;
378 StringRef identity;
379 SmallVector<TBAAMemberAttr> members;
380 if (std::optional<bool> isValid =
381 isTypeDescriptorNode(current, &identity, &members)) {
382 assert(isValid.value() && "type descriptor node must be valid");
384 tbaaMapping.insert({current, builder.getAttr<TBAATypeDescriptorAttr>(
385 identity, members)});
386 continue;
389 TBAATypeDescriptorAttr baseAttr, accessAttr;
390 int64_t offset;
391 bool isConstant;
392 if (std::optional<bool> isValid =
393 isTagNode(current, &baseAttr, &accessAttr, &offset, &isConstant)) {
394 assert(isValid.value() && "access tag node must be valid");
395 tbaaMapping.insert(
396 {current, builder.getAttr<TBAATagAttr>(baseAttr, accessAttr, offset,
397 isConstant)});
398 continue;
401 return emitError(loc) << "unsupported TBAA node format: "
402 << diagMD(current, llvmModule.get());
404 return success();
407 LogicalResult
408 ModuleImport::processAccessGroupMetadata(const llvm::MDNode *node) {
409 Location loc = mlirModule.getLoc();
410 if (failed(loopAnnotationImporter->translateAccessGroup(node, loc)))
411 return emitError(loc) << "unsupported access group node: "
412 << diagMD(node, llvmModule.get());
413 return success();
416 LogicalResult
417 ModuleImport::processAliasScopeMetadata(const llvm::MDNode *node) {
418 Location loc = mlirModule.getLoc();
419 // Helper that verifies the node has a self reference operand.
420 auto verifySelfRef = [](const llvm::MDNode *node) {
421 return node->getNumOperands() != 0 &&
422 node == dyn_cast<llvm::MDNode>(node->getOperand(0));
424 // Helper that verifies the given operand is a string or does not exist.
425 auto verifyDescription = [](const llvm::MDNode *node, unsigned idx) {
426 return idx >= node->getNumOperands() ||
427 isa<llvm::MDString>(node->getOperand(idx));
429 // Helper that creates an alias scope domain attribute.
430 auto createAliasScopeDomainOp = [&](const llvm::MDNode *aliasDomain) {
431 StringAttr description = nullptr;
432 if (aliasDomain->getNumOperands() >= 2)
433 if (auto *operand = dyn_cast<llvm::MDString>(aliasDomain->getOperand(1)))
434 description = builder.getStringAttr(operand->getString());
435 return builder.getAttr<AliasScopeDomainAttr>(
436 DistinctAttr::create(builder.getUnitAttr()), description);
439 // Collect the alias scopes and domains to translate them.
440 for (const llvm::MDOperand &operand : node->operands()) {
441 if (const auto *scope = dyn_cast<llvm::MDNode>(operand)) {
442 llvm::AliasScopeNode aliasScope(scope);
443 const llvm::MDNode *domain = aliasScope.getDomain();
445 // Verify the scope node points to valid scope metadata which includes
446 // verifying its domain. Perform the verification before looking it up in
447 // the alias scope mapping since it could have been inserted as a domain
448 // node before.
449 if (!verifySelfRef(scope) || !domain || !verifyDescription(scope, 2))
450 return emitError(loc) << "unsupported alias scope node: "
451 << diagMD(scope, llvmModule.get());
452 if (!verifySelfRef(domain) || !verifyDescription(domain, 1))
453 return emitError(loc) << "unsupported alias domain node: "
454 << diagMD(domain, llvmModule.get());
456 if (aliasScopeMapping.contains(scope))
457 continue;
459 // Convert the domain metadata node if it has not been translated before.
460 auto it = aliasScopeMapping.find(aliasScope.getDomain());
461 if (it == aliasScopeMapping.end()) {
462 auto aliasScopeDomainOp = createAliasScopeDomainOp(domain);
463 it = aliasScopeMapping.try_emplace(domain, aliasScopeDomainOp).first;
466 // Convert the scope metadata node if it has not been converted before.
467 StringAttr description = nullptr;
468 if (!aliasScope.getName().empty())
469 description = builder.getStringAttr(aliasScope.getName());
470 auto aliasScopeOp = builder.getAttr<AliasScopeAttr>(
471 DistinctAttr::create(builder.getUnitAttr()),
472 cast<AliasScopeDomainAttr>(it->second), description);
473 aliasScopeMapping.try_emplace(aliasScope.getNode(), aliasScopeOp);
476 return success();
479 FailureOr<SmallVector<AliasScopeAttr>>
480 ModuleImport::lookupAliasScopeAttrs(const llvm::MDNode *node) const {
481 SmallVector<AliasScopeAttr> aliasScopes;
482 aliasScopes.reserve(node->getNumOperands());
483 for (const llvm::MDOperand &operand : node->operands()) {
484 auto *node = cast<llvm::MDNode>(operand.get());
485 aliasScopes.push_back(
486 dyn_cast_or_null<AliasScopeAttr>(aliasScopeMapping.lookup(node)));
488 // Return failure if one of the alias scope lookups failed.
489 if (llvm::is_contained(aliasScopes, nullptr))
490 return failure();
491 return aliasScopes;
494 void ModuleImport::addDebugIntrinsic(llvm::CallInst *intrinsic) {
495 debugIntrinsics.insert(intrinsic);
498 LogicalResult ModuleImport::convertLinkerOptionsMetadata() {
499 for (const llvm::NamedMDNode &named : llvmModule->named_metadata()) {
500 if (named.getName() != "llvm.linker.options")
501 continue;
502 // llvm.linker.options operands are lists of strings.
503 for (const llvm::MDNode *md : named.operands()) {
504 SmallVector<StringRef> options;
505 options.reserve(md->getNumOperands());
506 for (const llvm::MDOperand &option : md->operands())
507 options.push_back(cast<llvm::MDString>(option)->getString());
508 builder.create<LLVM::LinkerOptionsOp>(mlirModule.getLoc(),
509 builder.getStrArrayAttr(options));
512 return success();
515 LogicalResult ModuleImport::convertMetadata() {
516 OpBuilder::InsertionGuard guard(builder);
517 builder.setInsertionPointToEnd(mlirModule.getBody());
518 for (const llvm::Function &func : llvmModule->functions()) {
519 for (const llvm::Instruction &inst : llvm::instructions(func)) {
520 // Convert access group metadata nodes.
521 if (llvm::MDNode *node =
522 inst.getMetadata(llvm::LLVMContext::MD_access_group))
523 if (failed(processAccessGroupMetadata(node)))
524 return failure();
526 // Convert alias analysis metadata nodes.
527 llvm::AAMDNodes aliasAnalysisNodes = inst.getAAMetadata();
528 if (!aliasAnalysisNodes)
529 continue;
530 if (aliasAnalysisNodes.TBAA)
531 if (failed(processTBAAMetadata(aliasAnalysisNodes.TBAA)))
532 return failure();
533 if (aliasAnalysisNodes.Scope)
534 if (failed(processAliasScopeMetadata(aliasAnalysisNodes.Scope)))
535 return failure();
536 if (aliasAnalysisNodes.NoAlias)
537 if (failed(processAliasScopeMetadata(aliasAnalysisNodes.NoAlias)))
538 return failure();
541 if (failed(convertLinkerOptionsMetadata()))
542 return failure();
543 return success();
546 void ModuleImport::processComdat(const llvm::Comdat *comdat) {
547 if (comdatMapping.contains(comdat))
548 return;
550 ComdatOp comdatOp = getGlobalComdatOp();
551 OpBuilder::InsertionGuard guard(builder);
552 builder.setInsertionPointToEnd(&comdatOp.getBody().back());
553 auto selectorOp = builder.create<ComdatSelectorOp>(
554 mlirModule.getLoc(), comdat->getName(),
555 convertComdatFromLLVM(comdat->getSelectionKind()));
556 auto symbolRef =
557 SymbolRefAttr::get(builder.getContext(), getGlobalComdatOpName(),
558 FlatSymbolRefAttr::get(selectorOp.getSymNameAttr()));
559 comdatMapping.try_emplace(comdat, symbolRef);
562 LogicalResult ModuleImport::convertComdats() {
563 for (llvm::GlobalVariable &globalVar : llvmModule->globals())
564 if (globalVar.hasComdat())
565 processComdat(globalVar.getComdat());
566 for (llvm::Function &func : llvmModule->functions())
567 if (func.hasComdat())
568 processComdat(func.getComdat());
569 return success();
572 LogicalResult ModuleImport::convertGlobals() {
573 for (llvm::GlobalVariable &globalVar : llvmModule->globals()) {
574 if (globalVar.getName() == getGlobalCtorsVarName() ||
575 globalVar.getName() == getGlobalDtorsVarName()) {
576 if (failed(convertGlobalCtorsAndDtors(&globalVar))) {
577 return emitError(UnknownLoc::get(context))
578 << "unhandled global variable: " << diag(globalVar);
580 continue;
582 if (failed(convertGlobal(&globalVar))) {
583 return emitError(UnknownLoc::get(context))
584 << "unhandled global variable: " << diag(globalVar);
587 return success();
590 LogicalResult ModuleImport::convertDataLayout() {
591 Location loc = mlirModule.getLoc();
592 DataLayoutImporter dataLayoutImporter(context, llvmModule->getDataLayout());
593 if (!dataLayoutImporter.getDataLayout())
594 return emitError(loc, "cannot translate data layout: ")
595 << dataLayoutImporter.getLastToken();
597 for (StringRef token : dataLayoutImporter.getUnhandledTokens())
598 emitWarning(loc, "unhandled data layout token: ") << token;
600 mlirModule->setAttr(DLTIDialect::kDataLayoutAttrName,
601 dataLayoutImporter.getDataLayout());
602 return success();
605 LogicalResult ModuleImport::convertFunctions() {
606 for (llvm::Function &func : llvmModule->functions())
607 if (failed(processFunction(&func)))
608 return failure();
609 return success();
612 void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
613 Operation *op) {
614 SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
615 inst->getAllMetadataOtherThanDebugLoc(allMetadata);
616 for (auto &[kind, node] : allMetadata) {
617 if (!iface.isConvertibleMetadata(kind))
618 continue;
619 if (failed(iface.setMetadataAttrs(builder, kind, node, op, *this))) {
620 if (emitExpensiveWarnings) {
621 Location loc = debugImporter->translateLoc(inst->getDebugLoc());
622 emitWarning(loc) << "unhandled metadata: "
623 << diagMD(node, llvmModule.get()) << " on "
624 << diag(*inst);
630 void ModuleImport::setIntegerOverflowFlags(llvm::Instruction *inst,
631 Operation *op) const {
632 auto iface = cast<IntegerOverflowFlagsInterface>(op);
634 IntegerOverflowFlags value = {};
635 value = bitEnumSet(value, IntegerOverflowFlags::nsw, inst->hasNoSignedWrap());
636 value =
637 bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap());
639 iface.setOverflowFlags(value);
642 void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
643 Operation *op) const {
644 auto iface = cast<FastmathFlagsInterface>(op);
646 // Even if the imported operation implements the fastmath interface, the
647 // original instruction may not have fastmath flags set. Exit if an
648 // instruction, such as a non floating-point function call, does not have
649 // fastmath flags.
650 if (!isa<llvm::FPMathOperator>(inst))
651 return;
652 llvm::FastMathFlags flags = inst->getFastMathFlags();
654 // Set the fastmath bits flag-by-flag.
655 FastmathFlags value = {};
656 value = bitEnumSet(value, FastmathFlags::nnan, flags.noNaNs());
657 value = bitEnumSet(value, FastmathFlags::ninf, flags.noInfs());
658 value = bitEnumSet(value, FastmathFlags::nsz, flags.noSignedZeros());
659 value = bitEnumSet(value, FastmathFlags::arcp, flags.allowReciprocal());
660 value = bitEnumSet(value, FastmathFlags::contract, flags.allowContract());
661 value = bitEnumSet(value, FastmathFlags::afn, flags.approxFunc());
662 value = bitEnumSet(value, FastmathFlags::reassoc, flags.allowReassoc());
663 FastmathFlagsAttr attr = FastmathFlagsAttr::get(builder.getContext(), value);
664 iface->setAttr(iface.getFastmathAttrName(), attr);
667 /// Returns if `type` is a scalar integer or floating-point type.
668 static bool isScalarType(Type type) {
669 return isa<IntegerType, FloatType>(type);
672 /// Returns `type` if it is a builtin integer or floating-point vector type that
673 /// can be used to create an attribute or nullptr otherwise. If provided,
674 /// `arrayShape` is added to the shape of the vector to create an attribute that
675 /// matches an array of vectors.
676 static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
677 if (!LLVM::isCompatibleVectorType(type))
678 return {};
680 llvm::ElementCount numElements = LLVM::getVectorNumElements(type);
681 if (numElements.isScalable()) {
682 emitError(UnknownLoc::get(type.getContext()))
683 << "scalable vectors not supported";
684 return {};
687 // An LLVM dialect vector can only contain scalars.
688 Type elementType = LLVM::getVectorElementType(type);
689 if (!isScalarType(elementType))
690 return {};
692 SmallVector<int64_t> shape(arrayShape.begin(), arrayShape.end());
693 shape.push_back(numElements.getKnownMinValue());
694 return VectorType::get(shape, elementType);
697 Type ModuleImport::getBuiltinTypeForAttr(Type type) {
698 if (!type)
699 return {};
701 // Return builtin integer and floating-point types as is.
702 if (isScalarType(type))
703 return type;
705 // Return builtin vectors of integer and floating-point types as is.
706 if (Type vectorType = getVectorTypeForAttr(type))
707 return vectorType;
709 // Multi-dimensional array types are converted to tensors or vectors,
710 // depending on the innermost type being a scalar or a vector.
711 SmallVector<int64_t> arrayShape;
712 while (auto arrayType = dyn_cast<LLVMArrayType>(type)) {
713 arrayShape.push_back(arrayType.getNumElements());
714 type = arrayType.getElementType();
716 if (isScalarType(type))
717 return RankedTensorType::get(arrayShape, type);
718 return getVectorTypeForAttr(type, arrayShape);
721 /// Returns an integer or float attribute for the provided scalar constant
722 /// `constScalar` or nullptr if the conversion fails.
723 static TypedAttr getScalarConstantAsAttr(OpBuilder &builder,
724 llvm::Constant *constScalar) {
725 MLIRContext *context = builder.getContext();
727 // Convert scalar intergers.
728 if (auto *constInt = dyn_cast<llvm::ConstantInt>(constScalar)) {
729 return builder.getIntegerAttr(
730 IntegerType::get(context, constInt->getBitWidth()),
731 constInt->getValue());
734 // Convert scalar floats.
735 if (auto *constFloat = dyn_cast<llvm::ConstantFP>(constScalar)) {
736 llvm::Type *type = constFloat->getType();
737 FloatType floatType =
738 type->isBFloatTy()
739 ? FloatType::getBF16(context)
740 : LLVM::detail::getFloatType(context, type->getScalarSizeInBits());
741 if (!floatType) {
742 emitError(UnknownLoc::get(builder.getContext()))
743 << "unexpected floating-point type";
744 return {};
746 return builder.getFloatAttr(floatType, constFloat->getValueAPF());
748 return {};
751 /// Returns an integer or float attribute array for the provided constant
752 /// sequence `constSequence` or nullptr if the conversion fails.
753 static SmallVector<Attribute>
754 getSequenceConstantAsAttrs(OpBuilder &builder,
755 llvm::ConstantDataSequential *constSequence) {
756 SmallVector<Attribute> elementAttrs;
757 elementAttrs.reserve(constSequence->getNumElements());
758 for (auto idx : llvm::seq<int64_t>(0, constSequence->getNumElements())) {
759 llvm::Constant *constElement = constSequence->getElementAsConstant(idx);
760 elementAttrs.push_back(getScalarConstantAsAttr(builder, constElement));
762 return elementAttrs;
765 Attribute ModuleImport::getConstantAsAttr(llvm::Constant *constant) {
766 // Convert scalar constants.
767 if (Attribute scalarAttr = getScalarConstantAsAttr(builder, constant))
768 return scalarAttr;
770 // Convert function references.
771 if (auto *func = dyn_cast<llvm::Function>(constant))
772 return SymbolRefAttr::get(builder.getContext(), func->getName());
774 // Returns the static shape of the provided type if possible.
775 auto getConstantShape = [&](llvm::Type *type) {
776 return llvm::dyn_cast_if_present<ShapedType>(
777 getBuiltinTypeForAttr(convertType(type)));
780 // Convert one-dimensional constant arrays or vectors that store 1/2/4/8-byte
781 // integer or half/bfloat/float/double values.
782 if (auto *constArray = dyn_cast<llvm::ConstantDataSequential>(constant)) {
783 if (constArray->isString())
784 return builder.getStringAttr(constArray->getAsString());
785 auto shape = getConstantShape(constArray->getType());
786 if (!shape)
787 return {};
788 // Convert splat constants to splat elements attributes.
789 auto *constVector = dyn_cast<llvm::ConstantDataVector>(constant);
790 if (constVector && constVector->isSplat()) {
791 // A vector is guaranteed to have at least size one.
792 Attribute splatAttr = getScalarConstantAsAttr(
793 builder, constVector->getElementAsConstant(0));
794 return SplatElementsAttr::get(shape, splatAttr);
796 // Convert non-splat constants to dense elements attributes.
797 SmallVector<Attribute> elementAttrs =
798 getSequenceConstantAsAttrs(builder, constArray);
799 return DenseElementsAttr::get(shape, elementAttrs);
802 // Convert multi-dimensional constant aggregates that store all kinds of
803 // integer and floating-point types.
804 if (auto *constAggregate = dyn_cast<llvm::ConstantAggregate>(constant)) {
805 auto shape = getConstantShape(constAggregate->getType());
806 if (!shape)
807 return {};
808 // Collect the aggregate elements in depths first order.
809 SmallVector<Attribute> elementAttrs;
810 SmallVector<llvm::Constant *> workList = {constAggregate};
811 while (!workList.empty()) {
812 llvm::Constant *current = workList.pop_back_val();
813 // Append any nested aggregates in reverse order to ensure the head
814 // element of the nested aggregates is at the back of the work list.
815 if (auto *constAggregate = dyn_cast<llvm::ConstantAggregate>(current)) {
816 for (auto idx :
817 reverse(llvm::seq<int64_t>(0, constAggregate->getNumOperands())))
818 workList.push_back(constAggregate->getAggregateElement(idx));
819 continue;
821 // Append the elements of nested constant arrays or vectors that store
822 // 1/2/4/8-byte integer or half/bfloat/float/double values.
823 if (auto *constArray = dyn_cast<llvm::ConstantDataSequential>(current)) {
824 SmallVector<Attribute> attrs =
825 getSequenceConstantAsAttrs(builder, constArray);
826 elementAttrs.append(attrs.begin(), attrs.end());
827 continue;
829 // Append nested scalar constants that store all kinds of integer and
830 // floating-point types.
831 if (Attribute scalarAttr = getScalarConstantAsAttr(builder, current)) {
832 elementAttrs.push_back(scalarAttr);
833 continue;
835 // Bail if the aggregate contains an unsupported constant type such as a
836 // constant expression.
837 return {};
839 return DenseElementsAttr::get(shape, elementAttrs);
842 // Convert zero aggregates.
843 if (auto *constZero = dyn_cast<llvm::ConstantAggregateZero>(constant)) {
844 auto shape = llvm::dyn_cast_if_present<ShapedType>(
845 getBuiltinTypeForAttr(convertType(constZero->getType())));
846 if (!shape)
847 return {};
848 // Convert zero aggregates with a static shape to splat elements attributes.
849 Attribute splatAttr = builder.getZeroAttr(shape.getElementType());
850 assert(splatAttr && "expected non-null zero attribute for scalar types");
851 return SplatElementsAttr::get(shape, splatAttr);
853 return {};
856 LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) {
857 // Insert the global after the last one or at the start of the module.
858 OpBuilder::InsertionGuard guard(builder);
859 if (!globalInsertionOp)
860 builder.setInsertionPointToStart(mlirModule.getBody());
861 else
862 builder.setInsertionPointAfter(globalInsertionOp);
864 Attribute valueAttr;
865 if (globalVar->hasInitializer())
866 valueAttr = getConstantAsAttr(globalVar->getInitializer());
867 Type type = convertType(globalVar->getValueType());
869 uint64_t alignment = 0;
870 llvm::MaybeAlign maybeAlign = globalVar->getAlign();
871 if (maybeAlign.has_value()) {
872 llvm::Align align = *maybeAlign;
873 alignment = align.value();
876 // Get the global expression associated with this global variable and convert
877 // it.
878 DIGlobalVariableExpressionAttr globalExpressionAttr;
879 SmallVector<llvm::DIGlobalVariableExpression *> globalExpressions;
880 globalVar->getDebugInfo(globalExpressions);
882 // There should only be a single global expression.
883 if (!globalExpressions.empty())
884 globalExpressionAttr =
885 debugImporter->translateGlobalVariableExpression(globalExpressions[0]);
887 GlobalOp globalOp = builder.create<GlobalOp>(
888 mlirModule.getLoc(), type, globalVar->isConstant(),
889 convertLinkageFromLLVM(globalVar->getLinkage()), globalVar->getName(),
890 valueAttr, alignment, /*addr_space=*/globalVar->getAddressSpace(),
891 /*dso_local=*/globalVar->isDSOLocal(),
892 /*thread_local=*/globalVar->isThreadLocal(), /*comdat=*/SymbolRefAttr(),
893 /*attrs=*/ArrayRef<NamedAttribute>(), /*dbgExpr=*/globalExpressionAttr);
894 globalInsertionOp = globalOp;
896 if (globalVar->hasInitializer() && !valueAttr) {
897 clearRegionState();
898 Block *block = builder.createBlock(&globalOp.getInitializerRegion());
899 setConstantInsertionPointToStart(block);
900 FailureOr<Value> initializer =
901 convertConstantExpr(globalVar->getInitializer());
902 if (failed(initializer))
903 return failure();
904 builder.create<ReturnOp>(globalOp.getLoc(), *initializer);
906 if (globalVar->hasAtLeastLocalUnnamedAddr()) {
907 globalOp.setUnnamedAddr(
908 convertUnnamedAddrFromLLVM(globalVar->getUnnamedAddr()));
910 if (globalVar->hasSection())
911 globalOp.setSection(globalVar->getSection());
912 globalOp.setVisibility_(
913 convertVisibilityFromLLVM(globalVar->getVisibility()));
915 if (globalVar->hasComdat())
916 globalOp.setComdatAttr(comdatMapping.lookup(globalVar->getComdat()));
918 return success();
921 LogicalResult
922 ModuleImport::convertGlobalCtorsAndDtors(llvm::GlobalVariable *globalVar) {
923 if (!globalVar->hasInitializer() || !globalVar->hasAppendingLinkage())
924 return failure();
925 auto *initializer =
926 dyn_cast<llvm::ConstantArray>(globalVar->getInitializer());
927 if (!initializer)
928 return failure();
930 SmallVector<Attribute> funcs;
931 SmallVector<int32_t> priorities;
932 for (llvm::Value *operand : initializer->operands()) {
933 auto *aggregate = dyn_cast<llvm::ConstantAggregate>(operand);
934 if (!aggregate || aggregate->getNumOperands() != 3)
935 return failure();
937 auto *priority = dyn_cast<llvm::ConstantInt>(aggregate->getOperand(0));
938 auto *func = dyn_cast<llvm::Function>(aggregate->getOperand(1));
939 auto *data = dyn_cast<llvm::Constant>(aggregate->getOperand(2));
940 if (!priority || !func || !data)
941 return failure();
943 // GlobalCtorsOps and GlobalDtorsOps do not support non-null data fields.
944 if (!data->isNullValue())
945 return failure();
947 funcs.push_back(FlatSymbolRefAttr::get(context, func->getName()));
948 priorities.push_back(priority->getValue().getZExtValue());
951 OpBuilder::InsertionGuard guard(builder);
952 if (!globalInsertionOp)
953 builder.setInsertionPointToStart(mlirModule.getBody());
954 else
955 builder.setInsertionPointAfter(globalInsertionOp);
957 if (globalVar->getName() == getGlobalCtorsVarName()) {
958 globalInsertionOp = builder.create<LLVM::GlobalCtorsOp>(
959 mlirModule.getLoc(), builder.getArrayAttr(funcs),
960 builder.getI32ArrayAttr(priorities));
961 return success();
963 globalInsertionOp = builder.create<LLVM::GlobalDtorsOp>(
964 mlirModule.getLoc(), builder.getArrayAttr(funcs),
965 builder.getI32ArrayAttr(priorities));
966 return success();
969 SetVector<llvm::Constant *>
970 ModuleImport::getConstantsToConvert(llvm::Constant *constant) {
971 // Return the empty set if the constant has been translated before.
972 if (valueMapping.contains(constant))
973 return {};
975 // Traverse the constants in post-order and stop the traversal if a constant
976 // already has a `valueMapping` from an earlier constant translation or if the
977 // constant is traversed a second time.
978 SetVector<llvm::Constant *> orderedSet;
979 SetVector<llvm::Constant *> workList;
980 DenseMap<llvm::Constant *, SmallVector<llvm::Constant *>> adjacencyLists;
981 workList.insert(constant);
982 while (!workList.empty()) {
983 llvm::Constant *current = workList.back();
984 // Collect all dependencies of the current constant and add them to the
985 // adjacency list if none has been computed before.
986 auto adjacencyIt = adjacencyLists.find(current);
987 if (adjacencyIt == adjacencyLists.end()) {
988 adjacencyIt = adjacencyLists.try_emplace(current).first;
989 // Add all constant operands to the adjacency list and skip any other
990 // values such as basic block addresses.
991 for (llvm::Value *operand : current->operands())
992 if (auto *constDependency = dyn_cast<llvm::Constant>(operand))
993 adjacencyIt->getSecond().push_back(constDependency);
994 // Use the getElementValue method to add the dependencies of zero
995 // initialized aggregate constants since they do not take any operands.
996 if (auto *constAgg = dyn_cast<llvm::ConstantAggregateZero>(current)) {
997 unsigned numElements = constAgg->getElementCount().getFixedValue();
998 for (unsigned i = 0, e = numElements; i != e; ++i)
999 adjacencyIt->getSecond().push_back(constAgg->getElementValue(i));
1002 // Add the current constant to the `orderedSet` of the traversed nodes if
1003 // all its dependencies have been traversed before. Additionally, remove the
1004 // constant from the `workList` and continue the traversal.
1005 if (adjacencyIt->getSecond().empty()) {
1006 orderedSet.insert(current);
1007 workList.pop_back();
1008 continue;
1010 // Add the next dependency from the adjacency list to the `workList` and
1011 // continue the traversal. Remove the dependency from the adjacency list to
1012 // mark that it has been processed. Only enqueue the dependency if it has no
1013 // `valueMapping` from an earlier translation and if it has not been
1014 // enqueued before.
1015 llvm::Constant *dependency = adjacencyIt->getSecond().pop_back_val();
1016 if (valueMapping.contains(dependency) || workList.contains(dependency) ||
1017 orderedSet.contains(dependency))
1018 continue;
1019 workList.insert(dependency);
1022 return orderedSet;
1025 FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
1026 Location loc = UnknownLoc::get(context);
1028 // Convert constants that can be represented as attributes.
1029 if (Attribute attr = getConstantAsAttr(constant)) {
1030 Type type = convertType(constant->getType());
1031 if (auto symbolRef = dyn_cast<FlatSymbolRefAttr>(attr)) {
1032 return builder.create<AddressOfOp>(loc, type, symbolRef.getValue())
1033 .getResult();
1035 return builder.create<ConstantOp>(loc, type, attr).getResult();
1038 // Convert null pointer constants.
1039 if (auto *nullPtr = dyn_cast<llvm::ConstantPointerNull>(constant)) {
1040 Type type = convertType(nullPtr->getType());
1041 return builder.create<ZeroOp>(loc, type).getResult();
1044 // Convert none token constants.
1045 if (isa<llvm::ConstantTokenNone>(constant)) {
1046 return builder.create<NoneTokenOp>(loc).getResult();
1049 // Convert poison.
1050 if (auto *poisonVal = dyn_cast<llvm::PoisonValue>(constant)) {
1051 Type type = convertType(poisonVal->getType());
1052 return builder.create<PoisonOp>(loc, type).getResult();
1055 // Convert undef.
1056 if (auto *undefVal = dyn_cast<llvm::UndefValue>(constant)) {
1057 Type type = convertType(undefVal->getType());
1058 return builder.create<UndefOp>(loc, type).getResult();
1061 // Convert global variable accesses.
1062 if (auto *globalVar = dyn_cast<llvm::GlobalVariable>(constant)) {
1063 Type type = convertType(globalVar->getType());
1064 auto symbolRef = FlatSymbolRefAttr::get(context, globalVar->getName());
1065 return builder.create<AddressOfOp>(loc, type, symbolRef).getResult();
1068 // Convert constant expressions.
1069 if (auto *constExpr = dyn_cast<llvm::ConstantExpr>(constant)) {
1070 // Convert the constant expression to a temporary LLVM instruction and
1071 // translate it using the `processInstruction` method. Delete the
1072 // instruction after the translation and remove it from `valueMapping`,
1073 // since later calls to `getAsInstruction` may return the same address
1074 // resulting in a conflicting `valueMapping` entry.
1075 llvm::Instruction *inst = constExpr->getAsInstruction();
1076 auto guard = llvm::make_scope_exit([&]() {
1077 assert(!noResultOpMapping.contains(inst) &&
1078 "expected constant expression to return a result");
1079 valueMapping.erase(inst);
1080 inst->deleteValue();
1082 // Note: `processInstruction` does not call `convertConstant` recursively
1083 // since all constant dependencies have been converted before.
1084 assert(llvm::all_of(inst->operands(), [&](llvm::Value *value) {
1085 return valueMapping.contains(value);
1086 }));
1087 if (failed(processInstruction(inst)))
1088 return failure();
1089 return lookupValue(inst);
1092 // Convert aggregate constants.
1093 if (isa<llvm::ConstantAggregate>(constant) ||
1094 isa<llvm::ConstantAggregateZero>(constant)) {
1095 // Lookup the aggregate elements that have been converted before.
1096 SmallVector<Value> elementValues;
1097 if (auto *constAgg = dyn_cast<llvm::ConstantAggregate>(constant)) {
1098 elementValues.reserve(constAgg->getNumOperands());
1099 for (llvm::Value *operand : constAgg->operands())
1100 elementValues.push_back(lookupValue(operand));
1102 if (auto *constAgg = dyn_cast<llvm::ConstantAggregateZero>(constant)) {
1103 unsigned numElements = constAgg->getElementCount().getFixedValue();
1104 elementValues.reserve(numElements);
1105 for (unsigned i = 0, e = numElements; i != e; ++i)
1106 elementValues.push_back(lookupValue(constAgg->getElementValue(i)));
1108 assert(llvm::count(elementValues, nullptr) == 0 &&
1109 "expected all elements have been converted before");
1111 // Generate an UndefOp as root value and insert the aggregate elements.
1112 Type rootType = convertType(constant->getType());
1113 bool isArrayOrStruct = isa<LLVMArrayType, LLVMStructType>(rootType);
1114 assert((isArrayOrStruct || LLVM::isCompatibleVectorType(rootType)) &&
1115 "unrecognized aggregate type");
1116 Value root = builder.create<UndefOp>(loc, rootType);
1117 for (const auto &it : llvm::enumerate(elementValues)) {
1118 if (isArrayOrStruct) {
1119 root = builder.create<InsertValueOp>(loc, root, it.value(), it.index());
1120 } else {
1121 Attribute indexAttr = builder.getI32IntegerAttr(it.index());
1122 Value indexValue =
1123 builder.create<ConstantOp>(loc, builder.getI32Type(), indexAttr);
1124 root = builder.create<InsertElementOp>(loc, rootType, root, it.value(),
1125 indexValue);
1128 return root;
1131 if (auto *constTargetNone = dyn_cast<llvm::ConstantTargetNone>(constant)) {
1132 LLVMTargetExtType targetExtType =
1133 cast<LLVMTargetExtType>(convertType(constTargetNone->getType()));
1134 assert(targetExtType.hasProperty(LLVMTargetExtType::HasZeroInit) &&
1135 "target extension type does not support zero-initialization");
1136 // Create llvm.mlir.zero operation to represent zero-initialization of
1137 // target extension type.
1138 return builder.create<LLVM::ZeroOp>(loc, targetExtType).getRes();
1141 StringRef error = "";
1142 if (isa<llvm::BlockAddress>(constant))
1143 error = " since blockaddress(...) is unsupported";
1145 return emitError(loc) << "unhandled constant: " << diag(*constant) << error;
1148 FailureOr<Value> ModuleImport::convertConstantExpr(llvm::Constant *constant) {
1149 // Only call the function for constants that have not been translated before
1150 // since it updates the constant insertion point assuming the converted
1151 // constant has been introduced at the end of the constant section.
1152 assert(!valueMapping.contains(constant) &&
1153 "expected constant has not been converted before");
1154 assert(constantInsertionBlock &&
1155 "expected the constant insertion block to be non-null");
1157 // Insert the constant after the last one or at the start of the entry block.
1158 OpBuilder::InsertionGuard guard(builder);
1159 if (!constantInsertionOp)
1160 builder.setInsertionPointToStart(constantInsertionBlock);
1161 else
1162 builder.setInsertionPointAfter(constantInsertionOp);
1164 // Convert all constants of the expression and add them to `valueMapping`.
1165 SetVector<llvm::Constant *> constantsToConvert =
1166 getConstantsToConvert(constant);
1167 for (llvm::Constant *constantToConvert : constantsToConvert) {
1168 FailureOr<Value> converted = convertConstant(constantToConvert);
1169 if (failed(converted))
1170 return failure();
1171 mapValue(constantToConvert, *converted);
1174 // Update the constant insertion point and return the converted constant.
1175 Value result = lookupValue(constant);
1176 constantInsertionOp = result.getDefiningOp();
1177 return result;
1180 FailureOr<Value> ModuleImport::convertValue(llvm::Value *value) {
1181 assert(!isa<llvm::MetadataAsValue>(value) &&
1182 "expected value to not be metadata");
1184 // Return the mapped value if it has been converted before.
1185 auto it = valueMapping.find(value);
1186 if (it != valueMapping.end())
1187 return it->getSecond();
1189 // Convert constants such as immediate values that have no mapping yet.
1190 if (auto *constant = dyn_cast<llvm::Constant>(value))
1191 return convertConstantExpr(constant);
1193 Location loc = UnknownLoc::get(context);
1194 if (auto *inst = dyn_cast<llvm::Instruction>(value))
1195 loc = translateLoc(inst->getDebugLoc());
1196 return emitError(loc) << "unhandled value: " << diag(*value);
1199 FailureOr<Value> ModuleImport::convertMetadataValue(llvm::Value *value) {
1200 // A value may be wrapped as metadata, for example, when passed to a debug
1201 // intrinsic. Unwrap these values before the conversion.
1202 auto *nodeAsVal = dyn_cast<llvm::MetadataAsValue>(value);
1203 if (!nodeAsVal)
1204 return failure();
1205 auto *node = dyn_cast<llvm::ValueAsMetadata>(nodeAsVal->getMetadata());
1206 if (!node)
1207 return failure();
1208 value = node->getValue();
1210 // Return the mapped value if it has been converted before.
1211 auto it = valueMapping.find(value);
1212 if (it != valueMapping.end())
1213 return it->getSecond();
1215 // Convert constants such as immediate values that have no mapping yet.
1216 if (auto *constant = dyn_cast<llvm::Constant>(value))
1217 return convertConstantExpr(constant);
1218 return failure();
1221 FailureOr<SmallVector<Value>>
1222 ModuleImport::convertValues(ArrayRef<llvm::Value *> values) {
1223 SmallVector<Value> remapped;
1224 remapped.reserve(values.size());
1225 for (llvm::Value *value : values) {
1226 FailureOr<Value> converted = convertValue(value);
1227 if (failed(converted))
1228 return failure();
1229 remapped.push_back(*converted);
1231 return remapped;
1234 LogicalResult ModuleImport::convertIntrinsicArguments(
1235 ArrayRef<llvm::Value *> values, ArrayRef<unsigned> immArgPositions,
1236 ArrayRef<StringLiteral> immArgAttrNames, SmallVectorImpl<Value> &valuesOut,
1237 SmallVectorImpl<NamedAttribute> &attrsOut) {
1238 assert(immArgPositions.size() == immArgAttrNames.size() &&
1239 "LLVM `immArgPositions` and MLIR `immArgAttrNames` should have equal "
1240 "length");
1242 SmallVector<llvm::Value *> operands(values);
1243 for (auto [immArgPos, immArgName] :
1244 llvm::zip(immArgPositions, immArgAttrNames)) {
1245 auto &value = operands[immArgPos];
1246 auto *constant = llvm::cast<llvm::Constant>(value);
1247 auto attr = getScalarConstantAsAttr(builder, constant);
1248 assert(attr && attr.getType().isIntOrFloat() &&
1249 "expected immarg to be float or integer constant");
1250 auto nameAttr = StringAttr::get(attr.getContext(), immArgName);
1251 attrsOut.push_back({nameAttr, attr});
1252 // Mark matched attribute values as null (so they can be removed below).
1253 value = nullptr;
1256 for (llvm::Value *value : operands) {
1257 if (!value)
1258 continue;
1259 auto mlirValue = convertValue(value);
1260 if (failed(mlirValue))
1261 return failure();
1262 valuesOut.push_back(*mlirValue);
1265 return success();
1268 IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
1269 IntegerAttr integerAttr;
1270 FailureOr<Value> converted = convertValue(value);
1271 bool success = succeeded(converted) &&
1272 matchPattern(*converted, m_Constant(&integerAttr));
1273 assert(success && "expected a constant integer value");
1274 (void)success;
1275 return integerAttr;
1278 FloatAttr ModuleImport::matchFloatAttr(llvm::Value *value) {
1279 FloatAttr floatAttr;
1280 FailureOr<Value> converted = convertValue(value);
1281 bool success =
1282 succeeded(converted) && matchPattern(*converted, m_Constant(&floatAttr));
1283 assert(success && "expected a constant float value");
1284 (void)success;
1285 return floatAttr;
1288 DILocalVariableAttr ModuleImport::matchLocalVariableAttr(llvm::Value *value) {
1289 auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);
1290 auto *node = cast<llvm::DILocalVariable>(nodeAsVal->getMetadata());
1291 return debugImporter->translate(node);
1294 DILabelAttr ModuleImport::matchLabelAttr(llvm::Value *value) {
1295 auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);
1296 auto *node = cast<llvm::DILabel>(nodeAsVal->getMetadata());
1297 return debugImporter->translate(node);
1300 FPExceptionBehaviorAttr
1301 ModuleImport::matchFPExceptionBehaviorAttr(llvm::Value *value) {
1302 auto *metadata = cast<llvm::MetadataAsValue>(value);
1303 auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
1304 std::optional<llvm::fp::ExceptionBehavior> optLLVM =
1305 llvm::convertStrToExceptionBehavior(mdstr->getString());
1306 assert(optLLVM && "Expecting FP exception behavior");
1307 return builder.getAttr<FPExceptionBehaviorAttr>(
1308 convertFPExceptionBehaviorFromLLVM(*optLLVM));
1311 RoundingModeAttr ModuleImport::matchRoundingModeAttr(llvm::Value *value) {
1312 auto *metadata = cast<llvm::MetadataAsValue>(value);
1313 auto *mdstr = cast<llvm::MDString>(metadata->getMetadata());
1314 std::optional<llvm::RoundingMode> optLLVM =
1315 llvm::convertStrToRoundingMode(mdstr->getString());
1316 assert(optLLVM && "Expecting rounding mode");
1317 return builder.getAttr<RoundingModeAttr>(
1318 convertRoundingModeFromLLVM(*optLLVM));
1321 FailureOr<SmallVector<AliasScopeAttr>>
1322 ModuleImport::matchAliasScopeAttrs(llvm::Value *value) {
1323 auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);
1324 auto *node = cast<llvm::MDNode>(nodeAsVal->getMetadata());
1325 return lookupAliasScopeAttrs(node);
1328 Location ModuleImport::translateLoc(llvm::DILocation *loc) {
1329 return debugImporter->translateLoc(loc);
1332 LogicalResult
1333 ModuleImport::convertBranchArgs(llvm::Instruction *branch,
1334 llvm::BasicBlock *target,
1335 SmallVectorImpl<Value> &blockArguments) {
1336 for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
1337 auto *phiInst = cast<llvm::PHINode>(&*inst);
1338 llvm::Value *value = phiInst->getIncomingValueForBlock(branch->getParent());
1339 FailureOr<Value> converted = convertValue(value);
1340 if (failed(converted))
1341 return failure();
1342 blockArguments.push_back(*converted);
1344 return success();
1347 LogicalResult
1348 ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst,
1349 SmallVectorImpl<Type> &types,
1350 SmallVectorImpl<Value> &operands) {
1351 if (!callInst->getType()->isVoidTy())
1352 types.push_back(convertType(callInst->getType()));
1354 if (!callInst->getCalledFunction()) {
1355 FailureOr<Value> called = convertValue(callInst->getCalledOperand());
1356 if (failed(called))
1357 return failure();
1358 operands.push_back(*called);
1360 SmallVector<llvm::Value *> args(callInst->args());
1361 FailureOr<SmallVector<Value>> arguments = convertValues(args);
1362 if (failed(arguments))
1363 return failure();
1364 llvm::append_range(operands, *arguments);
1365 return success();
1368 LogicalResult ModuleImport::convertIntrinsic(llvm::CallInst *inst) {
1369 if (succeeded(iface.convertIntrinsic(builder, inst, *this)))
1370 return success();
1372 Location loc = translateLoc(inst->getDebugLoc());
1373 return emitError(loc) << "unhandled intrinsic: " << diag(*inst);
1376 LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
1377 // Convert all instructions that do not provide an MLIR builder.
1378 Location loc = translateLoc(inst->getDebugLoc());
1379 if (inst->getOpcode() == llvm::Instruction::Br) {
1380 auto *brInst = cast<llvm::BranchInst>(inst);
1382 SmallVector<Block *> succBlocks;
1383 SmallVector<SmallVector<Value>> succBlockArgs;
1384 for (auto i : llvm::seq<unsigned>(0, brInst->getNumSuccessors())) {
1385 llvm::BasicBlock *succ = brInst->getSuccessor(i);
1386 SmallVector<Value> blockArgs;
1387 if (failed(convertBranchArgs(brInst, succ, blockArgs)))
1388 return failure();
1389 succBlocks.push_back(lookupBlock(succ));
1390 succBlockArgs.push_back(blockArgs);
1393 if (!brInst->isConditional()) {
1394 auto brOp = builder.create<LLVM::BrOp>(loc, succBlockArgs.front(),
1395 succBlocks.front());
1396 mapNoResultOp(inst, brOp);
1397 return success();
1399 FailureOr<Value> condition = convertValue(brInst->getCondition());
1400 if (failed(condition))
1401 return failure();
1402 auto condBrOp = builder.create<LLVM::CondBrOp>(
1403 loc, *condition, succBlocks.front(), succBlockArgs.front(),
1404 succBlocks.back(), succBlockArgs.back());
1405 mapNoResultOp(inst, condBrOp);
1406 return success();
1408 if (inst->getOpcode() == llvm::Instruction::Switch) {
1409 auto *swInst = cast<llvm::SwitchInst>(inst);
1410 // Process the condition value.
1411 FailureOr<Value> condition = convertValue(swInst->getCondition());
1412 if (failed(condition))
1413 return failure();
1414 SmallVector<Value> defaultBlockArgs;
1415 // Process the default case.
1416 llvm::BasicBlock *defaultBB = swInst->getDefaultDest();
1417 if (failed(convertBranchArgs(swInst, defaultBB, defaultBlockArgs)))
1418 return failure();
1420 // Process the cases.
1421 unsigned numCases = swInst->getNumCases();
1422 SmallVector<SmallVector<Value>> caseOperands(numCases);
1423 SmallVector<ValueRange> caseOperandRefs(numCases);
1424 SmallVector<APInt> caseValues(numCases);
1425 SmallVector<Block *> caseBlocks(numCases);
1426 for (const auto &it : llvm::enumerate(swInst->cases())) {
1427 const llvm::SwitchInst::CaseHandle &caseHandle = it.value();
1428 llvm::BasicBlock *succBB = caseHandle.getCaseSuccessor();
1429 if (failed(convertBranchArgs(swInst, succBB, caseOperands[it.index()])))
1430 return failure();
1431 caseOperandRefs[it.index()] = caseOperands[it.index()];
1432 caseValues[it.index()] = caseHandle.getCaseValue()->getValue();
1433 caseBlocks[it.index()] = lookupBlock(succBB);
1436 auto switchOp = builder.create<SwitchOp>(
1437 loc, *condition, lookupBlock(defaultBB), defaultBlockArgs, caseValues,
1438 caseBlocks, caseOperandRefs);
1439 mapNoResultOp(inst, switchOp);
1440 return success();
1442 if (inst->getOpcode() == llvm::Instruction::PHI) {
1443 Type type = convertType(inst->getType());
1444 mapValue(inst, builder.getInsertionBlock()->addArgument(
1445 type, translateLoc(inst->getDebugLoc())));
1446 return success();
1448 if (inst->getOpcode() == llvm::Instruction::Call) {
1449 auto *callInst = cast<llvm::CallInst>(inst);
1451 SmallVector<Type> types;
1452 SmallVector<Value> operands;
1453 if (failed(convertCallTypeAndOperands(callInst, types, operands)))
1454 return failure();
1456 auto funcTy =
1457 dyn_cast<LLVMFunctionType>(convertType(callInst->getFunctionType()));
1458 if (!funcTy)
1459 return failure();
1461 CallOp callOp;
1463 if (llvm::Function *callee = callInst->getCalledFunction()) {
1464 callOp = builder.create<CallOp>(
1465 loc, funcTy, SymbolRefAttr::get(context, callee->getName()),
1466 operands);
1467 } else {
1468 callOp = builder.create<CallOp>(loc, funcTy, operands);
1470 callOp.setCConv(convertCConvFromLLVM(callInst->getCallingConv()));
1471 setFastmathFlagsAttr(inst, callOp);
1472 if (!callInst->getType()->isVoidTy())
1473 mapValue(inst, callOp.getResult());
1474 else
1475 mapNoResultOp(inst, callOp);
1476 return success();
1478 if (inst->getOpcode() == llvm::Instruction::LandingPad) {
1479 auto *lpInst = cast<llvm::LandingPadInst>(inst);
1481 SmallVector<Value> operands;
1482 operands.reserve(lpInst->getNumClauses());
1483 for (auto i : llvm::seq<unsigned>(0, lpInst->getNumClauses())) {
1484 FailureOr<Value> operand = convertValue(lpInst->getClause(i));
1485 if (failed(operand))
1486 return failure();
1487 operands.push_back(*operand);
1490 Type type = convertType(lpInst->getType());
1491 auto lpOp =
1492 builder.create<LandingpadOp>(loc, type, lpInst->isCleanup(), operands);
1493 mapValue(inst, lpOp);
1494 return success();
1496 if (inst->getOpcode() == llvm::Instruction::Invoke) {
1497 auto *invokeInst = cast<llvm::InvokeInst>(inst);
1499 SmallVector<Type> types;
1500 SmallVector<Value> operands;
1501 if (failed(convertCallTypeAndOperands(invokeInst, types, operands)))
1502 return failure();
1504 // Check whether the invoke result is an argument to the normal destination
1505 // block.
1506 bool invokeResultUsedInPhi = llvm::any_of(
1507 invokeInst->getNormalDest()->phis(), [&](const llvm::PHINode &phi) {
1508 return phi.getIncomingValueForBlock(invokeInst->getParent()) ==
1509 invokeInst;
1512 Block *normalDest = lookupBlock(invokeInst->getNormalDest());
1513 Block *directNormalDest = normalDest;
1514 if (invokeResultUsedInPhi) {
1515 // The invoke result cannot be an argument to the normal destination
1516 // block, as that would imply using the invoke operation result in its
1517 // definition, so we need to create a dummy block to serve as an
1518 // intermediate destination.
1519 OpBuilder::InsertionGuard g(builder);
1520 directNormalDest = builder.createBlock(normalDest);
1523 SmallVector<Value> unwindArgs;
1524 if (failed(convertBranchArgs(invokeInst, invokeInst->getUnwindDest(),
1525 unwindArgs)))
1526 return failure();
1528 auto funcTy =
1529 dyn_cast<LLVMFunctionType>(convertType(invokeInst->getFunctionType()));
1530 if (!funcTy)
1531 return failure();
1533 // Create the invoke operation. Normal destination block arguments will be
1534 // added later on to handle the case in which the operation result is
1535 // included in this list.
1536 InvokeOp invokeOp;
1537 if (llvm::Function *callee = invokeInst->getCalledFunction()) {
1538 invokeOp = builder.create<InvokeOp>(
1539 loc, funcTy,
1540 SymbolRefAttr::get(builder.getContext(), callee->getName()), operands,
1541 directNormalDest, ValueRange(),
1542 lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
1543 } else {
1544 invokeOp = builder.create<InvokeOp>(
1545 loc, funcTy, /*callee=*/nullptr, operands, directNormalDest,
1546 ValueRange(), lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
1548 invokeOp.setCConv(convertCConvFromLLVM(invokeInst->getCallingConv()));
1549 if (!invokeInst->getType()->isVoidTy())
1550 mapValue(inst, invokeOp.getResults().front());
1551 else
1552 mapNoResultOp(inst, invokeOp);
1554 SmallVector<Value> normalArgs;
1555 if (failed(convertBranchArgs(invokeInst, invokeInst->getNormalDest(),
1556 normalArgs)))
1557 return failure();
1559 if (invokeResultUsedInPhi) {
1560 // The dummy normal dest block will just host an unconditional branch
1561 // instruction to the normal destination block passing the required block
1562 // arguments (including the invoke operation's result).
1563 OpBuilder::InsertionGuard g(builder);
1564 builder.setInsertionPointToStart(directNormalDest);
1565 builder.create<LLVM::BrOp>(loc, normalArgs, normalDest);
1566 } else {
1567 // If the invoke operation's result is not a block argument to the normal
1568 // destination block, just add the block arguments as usual.
1569 assert(llvm::none_of(
1570 normalArgs,
1571 [&](Value val) { return val.getDefiningOp() == invokeOp; }) &&
1572 "An llvm.invoke operation cannot pass its result as a block "
1573 "argument.");
1574 invokeOp.getNormalDestOperandsMutable().append(normalArgs);
1577 return success();
1579 if (inst->getOpcode() == llvm::Instruction::GetElementPtr) {
1580 auto *gepInst = cast<llvm::GetElementPtrInst>(inst);
1581 Type sourceElementType = convertType(gepInst->getSourceElementType());
1582 FailureOr<Value> basePtr = convertValue(gepInst->getOperand(0));
1583 if (failed(basePtr))
1584 return failure();
1586 // Treat every indices as dynamic since GEPOp::build will refine those
1587 // indices into static attributes later. One small downside of this
1588 // approach is that many unused `llvm.mlir.constant` would be emitted
1589 // at first place.
1590 SmallVector<GEPArg> indices;
1591 for (llvm::Value *operand : llvm::drop_begin(gepInst->operand_values())) {
1592 FailureOr<Value> index = convertValue(operand);
1593 if (failed(index))
1594 return failure();
1595 indices.push_back(*index);
1598 Type type = convertType(inst->getType());
1599 auto gepOp = builder.create<GEPOp>(loc, type, sourceElementType, *basePtr,
1600 indices, gepInst->isInBounds());
1601 mapValue(inst, gepOp);
1602 return success();
1605 // Convert all instructions that have an mlirBuilder.
1606 if (succeeded(convertInstructionImpl(builder, inst, *this, iface)))
1607 return success();
1609 return emitError(loc) << "unhandled instruction: " << diag(*inst);
1612 LogicalResult ModuleImport::processInstruction(llvm::Instruction *inst) {
1613 // FIXME: Support uses of SubtargetData.
1614 // FIXME: Add support for call / operand attributes.
1615 // FIXME: Add support for the indirectbr, cleanupret, catchret, catchswitch,
1616 // callbr, vaarg, catchpad, cleanuppad instructions.
1618 // Convert LLVM intrinsics calls to MLIR intrinsics.
1619 if (auto *intrinsic = dyn_cast<llvm::IntrinsicInst>(inst))
1620 return convertIntrinsic(intrinsic);
1622 // Convert all remaining LLVM instructions to MLIR operations.
1623 return convertInstruction(inst);
1626 FlatSymbolRefAttr ModuleImport::getPersonalityAsAttr(llvm::Function *f) {
1627 if (!f->hasPersonalityFn())
1628 return nullptr;
1630 llvm::Constant *pf = f->getPersonalityFn();
1632 // If it directly has a name, we can use it.
1633 if (pf->hasName())
1634 return SymbolRefAttr::get(builder.getContext(), pf->getName());
1636 // If it doesn't have a name, currently, only function pointers that are
1637 // bitcast to i8* are parsed.
1638 if (auto *ce = dyn_cast<llvm::ConstantExpr>(pf)) {
1639 if (ce->getOpcode() == llvm::Instruction::BitCast &&
1640 ce->getType() == llvm::PointerType::getUnqual(f->getContext())) {
1641 if (auto *func = dyn_cast<llvm::Function>(ce->getOperand(0)))
1642 return SymbolRefAttr::get(builder.getContext(), func->getName());
1645 return FlatSymbolRefAttr();
1648 static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) {
1649 llvm::MemoryEffects memEffects = func->getMemoryEffects();
1651 auto othermem = convertModRefInfoFromLLVM(
1652 memEffects.getModRef(llvm::MemoryEffects::Location::Other));
1653 auto argMem = convertModRefInfoFromLLVM(
1654 memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
1655 auto inaccessibleMem = convertModRefInfoFromLLVM(
1656 memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
1657 auto memAttr = MemoryEffectsAttr::get(funcOp.getContext(), othermem, argMem,
1658 inaccessibleMem);
1659 // Only set the attr when it does not match the default value.
1660 if (memAttr.isReadWrite())
1661 return;
1662 funcOp.setMemoryAttr(memAttr);
1665 // List of LLVM IR attributes that map to an explicit attribute on the MLIR
1666 // LLVMFuncOp.
1667 static constexpr std::array ExplicitAttributes{
1668 StringLiteral("aarch64_pstate_sm_enabled"),
1669 StringLiteral("aarch64_pstate_sm_body"),
1670 StringLiteral("aarch64_pstate_sm_compatible"),
1671 StringLiteral("aarch64_new_za"),
1672 StringLiteral("aarch64_preserves_za"),
1673 StringLiteral("aarch64_in_za"),
1674 StringLiteral("aarch64_out_za"),
1675 StringLiteral("aarch64_inout_za"),
1676 StringLiteral("vscale_range"),
1677 StringLiteral("frame-pointer"),
1678 StringLiteral("target-features"),
1679 StringLiteral("unsafe-fp-math"),
1680 StringLiteral("no-infs-fp-math"),
1681 StringLiteral("no-nans-fp-math"),
1682 StringLiteral("approx-func-fp-math"),
1683 StringLiteral("no-signed-zeros-fp-math"),
1686 static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
1687 MLIRContext *context = funcOp.getContext();
1688 SmallVector<Attribute> passthroughs;
1689 llvm::AttributeSet funcAttrs = func->getAttributes().getAttributes(
1690 llvm::AttributeList::AttrIndex::FunctionIndex);
1691 for (llvm::Attribute attr : funcAttrs) {
1692 // Skip the memory attribute since the LLVMFuncOp has an explicit memory
1693 // attribute.
1694 if (attr.hasAttribute(llvm::Attribute::Memory))
1695 continue;
1697 // Skip invalid type attributes.
1698 if (attr.isTypeAttribute()) {
1699 emitWarning(funcOp.getLoc(),
1700 "type attributes on a function are invalid, skipping it");
1701 continue;
1704 StringRef attrName;
1705 if (attr.isStringAttribute())
1706 attrName = attr.getKindAsString();
1707 else
1708 attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum());
1709 auto keyAttr = StringAttr::get(context, attrName);
1711 // Skip attributes that map to an explicit attribute on the LLVMFuncOp.
1712 if (llvm::is_contained(ExplicitAttributes, attrName))
1713 continue;
1715 if (attr.isStringAttribute()) {
1716 StringRef val = attr.getValueAsString();
1717 if (val.empty()) {
1718 passthroughs.push_back(keyAttr);
1719 continue;
1721 passthroughs.push_back(
1722 ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
1723 continue;
1725 if (attr.isIntAttribute()) {
1726 auto val = std::to_string(attr.getValueAsInt());
1727 passthroughs.push_back(
1728 ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
1729 continue;
1731 if (attr.isEnumAttribute()) {
1732 passthroughs.push_back(keyAttr);
1733 continue;
1736 llvm_unreachable("unexpected attribute kind");
1739 if (!passthroughs.empty())
1740 funcOp.setPassthroughAttr(ArrayAttr::get(context, passthroughs));
1743 void ModuleImport::processFunctionAttributes(llvm::Function *func,
1744 LLVMFuncOp funcOp) {
1745 processMemoryEffects(func, funcOp);
1746 processPassthroughAttrs(func, funcOp);
1748 if (func->hasFnAttribute("aarch64_pstate_sm_enabled"))
1749 funcOp.setArmStreaming(true);
1750 else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
1751 funcOp.setArmLocallyStreaming(true);
1752 else if (func->hasFnAttribute("aarch64_pstate_sm_compatible"))
1753 funcOp.setArmStreamingCompatible(true);
1755 if (func->hasFnAttribute("aarch64_new_za"))
1756 funcOp.setArmNewZa(true);
1757 else if (func->hasFnAttribute("aarch64_in_za"))
1758 funcOp.setArmInZa(true);
1759 else if (func->hasFnAttribute("aarch64_out_za"))
1760 funcOp.setArmOutZa(true);
1761 else if (func->hasFnAttribute("aarch64_inout_za"))
1762 funcOp.setArmInoutZa(true);
1763 else if (func->hasFnAttribute("aarch64_preserves_za"))
1764 funcOp.setArmPreservesZa(true);
1766 llvm::Attribute attr = func->getFnAttribute(llvm::Attribute::VScaleRange);
1767 if (attr.isValid()) {
1768 MLIRContext *context = funcOp.getContext();
1769 auto intTy = IntegerType::get(context, 32);
1770 funcOp.setVscaleRangeAttr(LLVM::VScaleRangeAttr::get(
1771 context, IntegerAttr::get(intTy, attr.getVScaleRangeMin()),
1772 IntegerAttr::get(intTy, attr.getVScaleRangeMax().value_or(0))));
1775 // Process frame-pointer attribute.
1776 if (func->hasFnAttribute("frame-pointer")) {
1777 StringRef stringRefFramePointerKind =
1778 func->getFnAttribute("frame-pointer").getValueAsString();
1779 funcOp.setFramePointerAttr(LLVM::FramePointerKindAttr::get(
1780 funcOp.getContext(), LLVM::framePointerKind::symbolizeFramePointerKind(
1781 stringRefFramePointerKind)
1782 .value()));
1785 if (llvm::Attribute attr = func->getFnAttribute("target-cpu");
1786 attr.isStringAttribute())
1787 funcOp.setTargetCpuAttr(StringAttr::get(context, attr.getValueAsString()));
1789 if (llvm::Attribute attr = func->getFnAttribute("target-features");
1790 attr.isStringAttribute())
1791 funcOp.setTargetFeaturesAttr(
1792 LLVM::TargetFeaturesAttr::get(context, attr.getValueAsString()));
1794 if (llvm::Attribute attr = func->getFnAttribute("unsafe-fp-math");
1795 attr.isStringAttribute())
1796 funcOp.setUnsafeFpMath(attr.getValueAsBool());
1798 if (llvm::Attribute attr = func->getFnAttribute("no-infs-fp-math");
1799 attr.isStringAttribute())
1800 funcOp.setNoInfsFpMath(attr.getValueAsBool());
1802 if (llvm::Attribute attr = func->getFnAttribute("no-nans-fp-math");
1803 attr.isStringAttribute())
1804 funcOp.setNoNansFpMath(attr.getValueAsBool());
1806 if (llvm::Attribute attr = func->getFnAttribute("approx-func-fp-math");
1807 attr.isStringAttribute())
1808 funcOp.setApproxFuncFpMath(attr.getValueAsBool());
1810 if (llvm::Attribute attr = func->getFnAttribute("no-signed-zeros-fp-math");
1811 attr.isStringAttribute())
1812 funcOp.setNoSignedZerosFpMath(attr.getValueAsBool());
1815 DictionaryAttr
1816 ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
1817 OpBuilder &builder) {
1818 SmallVector<NamedAttribute> paramAttrs;
1819 for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
1820 auto llvmAttr = llvmParamAttrs.getAttribute(llvmKind);
1821 // Skip attributes that are not attached.
1822 if (!llvmAttr.isValid())
1823 continue;
1824 Attribute mlirAttr;
1825 if (llvmAttr.isTypeAttribute())
1826 mlirAttr = TypeAttr::get(convertType(llvmAttr.getValueAsType()));
1827 else if (llvmAttr.isIntAttribute())
1828 mlirAttr = builder.getI64IntegerAttr(llvmAttr.getValueAsInt());
1829 else if (llvmAttr.isEnumAttribute())
1830 mlirAttr = builder.getUnitAttr();
1831 else
1832 llvm_unreachable("unexpected parameter attribute kind");
1833 paramAttrs.push_back(builder.getNamedAttr(mlirName, mlirAttr));
1836 return builder.getDictionaryAttr(paramAttrs);
1839 void ModuleImport::convertParameterAttributes(llvm::Function *func,
1840 LLVMFuncOp funcOp,
1841 OpBuilder &builder) {
1842 auto llvmAttrs = func->getAttributes();
1843 for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
1844 llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(i);
1845 funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder));
1847 // Convert the result attributes and attach them wrapped in an ArrayAttribute
1848 // to the funcOp.
1849 llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
1850 if (!llvmResAttr.hasAttributes())
1851 return;
1852 funcOp.setResAttrsAttr(
1853 builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
1856 LogicalResult ModuleImport::processFunction(llvm::Function *func) {
1857 clearRegionState();
1859 auto functionType =
1860 dyn_cast<LLVMFunctionType>(convertType(func->getFunctionType()));
1861 if (func->isIntrinsic() &&
1862 iface.isConvertibleIntrinsic(func->getIntrinsicID()))
1863 return success();
1865 bool dsoLocal = func->hasLocalLinkage();
1866 CConv cconv = convertCConvFromLLVM(func->getCallingConv());
1868 // Insert the function at the end of the module.
1869 OpBuilder::InsertionGuard guard(builder);
1870 builder.setInsertionPoint(mlirModule.getBody(), mlirModule.getBody()->end());
1872 Location loc = debugImporter->translateFuncLocation(func);
1873 LLVMFuncOp funcOp = builder.create<LLVMFuncOp>(
1874 loc, func->getName(), functionType,
1875 convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv);
1877 convertParameterAttributes(func, funcOp, builder);
1879 if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func))
1880 funcOp.setPersonalityAttr(personality);
1881 else if (func->hasPersonalityFn())
1882 emitWarning(funcOp.getLoc(), "could not deduce personality, skipping it");
1884 if (func->hasGC())
1885 funcOp.setGarbageCollector(StringRef(func->getGC()));
1887 if (func->hasAtLeastLocalUnnamedAddr())
1888 funcOp.setUnnamedAddr(convertUnnamedAddrFromLLVM(func->getUnnamedAddr()));
1890 if (func->hasSection())
1891 funcOp.setSection(StringRef(func->getSection()));
1893 funcOp.setVisibility_(convertVisibilityFromLLVM(func->getVisibility()));
1895 if (func->hasComdat())
1896 funcOp.setComdatAttr(comdatMapping.lookup(func->getComdat()));
1898 if (llvm::MaybeAlign maybeAlign = func->getAlign())
1899 funcOp.setAlignment(maybeAlign->value());
1901 // Handle Function attributes.
1902 processFunctionAttributes(func, funcOp);
1904 // Convert non-debug metadata by using the dialect interface.
1905 SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
1906 func->getAllMetadata(allMetadata);
1907 for (auto &[kind, node] : allMetadata) {
1908 if (!iface.isConvertibleMetadata(kind))
1909 continue;
1910 if (failed(iface.setMetadataAttrs(builder, kind, node, funcOp, *this))) {
1911 emitWarning(funcOp.getLoc())
1912 << "unhandled function metadata: " << diagMD(node, llvmModule.get())
1913 << " on " << diag(*func);
1917 if (func->isDeclaration())
1918 return success();
1920 // Collect the set of basic blocks reachable from the function's entry block.
1921 // This step is crucial as LLVM IR can contain unreachable blocks that
1922 // self-dominate. As a result, an operation might utilize a variable it
1923 // defines, which the import does not support. Given that MLIR lacks block
1924 // label support, we can safely remove unreachable blocks, as there are no
1925 // indirect branch instructions that could potentially target these blocks.
1926 llvm::df_iterator_default_set<llvm::BasicBlock *> reachable;
1927 for (llvm::BasicBlock *basicBlock : llvm::depth_first_ext(func, reachable))
1928 (void)basicBlock;
1930 // Eagerly create all reachable blocks.
1931 SmallVector<llvm::BasicBlock *> reachableBasicBlocks;
1932 for (llvm::BasicBlock &basicBlock : *func) {
1933 // Skip unreachable blocks.
1934 if (!reachable.contains(&basicBlock))
1935 continue;
1936 Region &body = funcOp.getBody();
1937 Block *block = builder.createBlock(&body, body.end());
1938 mapBlock(&basicBlock, block);
1939 reachableBasicBlocks.push_back(&basicBlock);
1942 // Add function arguments to the entry block.
1943 for (const auto &it : llvm::enumerate(func->args())) {
1944 BlockArgument blockArg = funcOp.getFunctionBody().addArgument(
1945 functionType.getParamType(it.index()), funcOp.getLoc());
1946 mapValue(&it.value(), blockArg);
1949 // Process the blocks in topological order. The ordered traversal ensures
1950 // operands defined in a dominating block have a valid mapping to an MLIR
1951 // value once a block is translated.
1952 SetVector<llvm::BasicBlock *> blocks =
1953 getTopologicallySortedBlocks(reachableBasicBlocks);
1954 setConstantInsertionPointToStart(lookupBlock(blocks.front()));
1955 for (llvm::BasicBlock *basicBlock : blocks)
1956 if (failed(processBasicBlock(basicBlock, lookupBlock(basicBlock))))
1957 return failure();
1959 // Process the debug intrinsics that require a delayed conversion after
1960 // everything else was converted.
1961 if (failed(processDebugIntrinsics()))
1962 return failure();
1964 return success();
1967 /// Checks if `dbgIntr` is a kill location that holds metadata instead of an SSA
1968 /// value.
1969 static bool isMetadataKillLocation(llvm::DbgVariableIntrinsic *dbgIntr) {
1970 if (!dbgIntr->isKillLocation())
1971 return false;
1972 llvm::Value *value = dbgIntr->getArgOperand(0);
1973 auto *nodeAsVal = dyn_cast<llvm::MetadataAsValue>(value);
1974 if (!nodeAsVal)
1975 return false;
1976 return !isa<llvm::ValueAsMetadata>(nodeAsVal->getMetadata());
1979 LogicalResult
1980 ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr,
1981 DominanceInfo &domInfo) {
1982 Location loc = translateLoc(dbgIntr->getDebugLoc());
1983 auto emitUnsupportedWarning = [&]() {
1984 if (emitExpensiveWarnings)
1985 emitWarning(loc) << "dropped intrinsic: " << diag(*dbgIntr);
1986 return success();
1988 // Drop debug intrinsics with arg lists.
1989 // TODO: Support debug intrinsics that have arg lists.
1990 if (dbgIntr->hasArgList())
1991 return emitUnsupportedWarning();
1992 // Kill locations can have metadata nodes as location operand. This
1993 // cannot be converted to poison as the type cannot be reconstructed.
1994 // TODO: find a way to support this case.
1995 if (isMetadataKillLocation(dbgIntr))
1996 return emitUnsupportedWarning();
1997 // Drop debug intrinsics if the associated variable information cannot be
1998 // translated due to cyclic debug metadata.
1999 // TODO: Support cyclic debug metadata.
2000 DILocalVariableAttr localVariableAttr =
2001 matchLocalVariableAttr(dbgIntr->getArgOperand(1));
2002 if (!localVariableAttr)
2003 return emitUnsupportedWarning();
2004 FailureOr<Value> argOperand = convertMetadataValue(dbgIntr->getArgOperand(0));
2005 if (failed(argOperand))
2006 return emitError(loc) << "failed to convert a debug intrinsic operand: "
2007 << diag(*dbgIntr);
2009 // Ensure that the debug instrinsic is inserted right after its operand is
2010 // defined. Otherwise, the operand might not necessarily dominate the
2011 // intrinsic. If the defining operation is a terminator, insert the intrinsic
2012 // into a dominated block.
2013 OpBuilder::InsertionGuard guard(builder);
2014 if (Operation *op = argOperand->getDefiningOp();
2015 op && op->hasTrait<OpTrait::IsTerminator>()) {
2016 // Find a dominated block that can hold the debug intrinsic.
2017 auto dominatedBlocks = domInfo.getNode(op->getBlock())->children();
2018 // If no block is dominated by the terminator, this intrinisc cannot be
2019 // converted.
2020 if (dominatedBlocks.empty())
2021 return emitUnsupportedWarning();
2022 // Set insertion point before the terminator, to avoid inserting something
2023 // before landingpads.
2024 Block *dominatedBlock = (*dominatedBlocks.begin())->getBlock();
2025 builder.setInsertionPoint(dominatedBlock->getTerminator());
2026 } else {
2027 builder.setInsertionPointAfterValue(*argOperand);
2029 auto locationExprAttr =
2030 debugImporter->translateExpression(dbgIntr->getExpression());
2031 Operation *op =
2032 llvm::TypeSwitch<llvm::DbgVariableIntrinsic *, Operation *>(dbgIntr)
2033 .Case([&](llvm::DbgDeclareInst *) {
2034 return builder.create<LLVM::DbgDeclareOp>(
2035 loc, *argOperand, localVariableAttr, locationExprAttr);
2037 .Case([&](llvm::DbgValueInst *) {
2038 return builder.create<LLVM::DbgValueOp>(
2039 loc, *argOperand, localVariableAttr, locationExprAttr);
2041 mapNoResultOp(dbgIntr, op);
2042 setNonDebugMetadataAttrs(dbgIntr, op);
2043 return success();
2046 LogicalResult ModuleImport::processDebugIntrinsics() {
2047 DominanceInfo domInfo;
2048 for (llvm::Instruction *inst : debugIntrinsics) {
2049 auto *intrCall = cast<llvm::DbgVariableIntrinsic>(inst);
2050 if (failed(processDebugIntrinsic(intrCall, domInfo)))
2051 return failure();
2053 return success();
2056 LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb,
2057 Block *block) {
2058 builder.setInsertionPointToStart(block);
2059 for (llvm::Instruction &inst : *bb) {
2060 if (failed(processInstruction(&inst)))
2061 return failure();
2063 // Skip additional processing when the instructions is a debug intrinsics
2064 // that was not yet converted.
2065 if (debugIntrinsics.contains(&inst))
2066 continue;
2068 // Set the non-debug metadata attributes on the imported operation and emit
2069 // a warning if an instruction other than a phi instruction is dropped
2070 // during the import.
2071 if (Operation *op = lookupOperation(&inst)) {
2072 setNonDebugMetadataAttrs(&inst, op);
2073 } else if (inst.getOpcode() != llvm::Instruction::PHI) {
2074 if (emitExpensiveWarnings) {
2075 Location loc = debugImporter->translateLoc(inst.getDebugLoc());
2076 emitWarning(loc) << "dropped instruction: " << diag(inst);
2080 return success();
2083 FailureOr<SmallVector<AccessGroupAttr>>
2084 ModuleImport::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
2085 return loopAnnotationImporter->lookupAccessGroupAttrs(node);
2088 LoopAnnotationAttr
2089 ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node,
2090 Location loc) const {
2091 return loopAnnotationImporter->translateLoopAnnotation(node, loc);
2094 OwningOpRef<ModuleOp>
2095 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
2096 MLIRContext *context, bool emitExpensiveWarnings,
2097 bool dropDICompositeTypeElements) {
2098 // Preload all registered dialects to allow the import to iterate the
2099 // registered LLVMImportDialectInterface implementations and query the
2100 // supported LLVM IR constructs before starting the translation. Assumes the
2101 // LLVM and DLTI dialects that convert the core LLVM IR constructs have been
2102 // registered before.
2103 assert(llvm::is_contained(context->getAvailableDialects(),
2104 LLVMDialect::getDialectNamespace()));
2105 assert(llvm::is_contained(context->getAvailableDialects(),
2106 DLTIDialect::getDialectNamespace()));
2107 context->loadAllAvailableDialects();
2108 OwningOpRef<ModuleOp> module(ModuleOp::create(FileLineColLoc::get(
2109 StringAttr::get(context, llvmModule->getSourceFileName()), /*line=*/0,
2110 /*column=*/0)));
2112 ModuleImport moduleImport(module.get(), std::move(llvmModule),
2113 emitExpensiveWarnings, dropDICompositeTypeElements);
2114 if (failed(moduleImport.initializeImportInterface()))
2115 return {};
2116 if (failed(moduleImport.convertDataLayout()))
2117 return {};
2118 if (failed(moduleImport.convertComdats()))
2119 return {};
2120 if (failed(moduleImport.convertMetadata()))
2121 return {};
2122 if (failed(moduleImport.convertGlobals()))
2123 return {};
2124 if (failed(moduleImport.convertFunctions()))
2125 return {};
2127 return module;