[Workflow] Roll back some settings since they caused more issues
[llvm-project.git] / mlir / lib / Target / LLVMIR / ModuleImport.cpp
blobc6c30880d4f2c154582a3450ead43bd2bfc0ee0f
1 //===- ModuleImport.cpp - LLVM to MLIR conversion ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the import of an LLVM IR module into an LLVM dialect
10 // module.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Target/LLVMIR/ModuleImport.h"
15 #include "mlir/Target/LLVMIR/Import.h"
17 #include "AttrKindDetail.h"
18 #include "DataLayoutImporter.h"
19 #include "DebugImporter.h"
20 #include "LoopAnnotationImporter.h"
22 #include "mlir/Dialect/DLTI/DLTI.h"
23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/Interfaces/DataLayoutInterfaces.h"
27 #include "mlir/Tools/mlir-translate/Translation.h"
29 #include "llvm/ADT/PostOrderIterator.h"
30 #include "llvm/ADT/ScopeExit.h"
31 #include "llvm/ADT/StringSet.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/IR/Comdat.h"
34 #include "llvm/IR/Constants.h"
35 #include "llvm/IR/InlineAsm.h"
36 #include "llvm/IR/InstIterator.h"
37 #include "llvm/IR/Instructions.h"
38 #include "llvm/IR/IntrinsicInst.h"
39 #include "llvm/IR/Metadata.h"
40 #include "llvm/IR/Operator.h"
41 #include "llvm/Support/ModRef.h"
43 using namespace mlir;
44 using namespace mlir::LLVM;
45 using namespace mlir::LLVM::detail;
47 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc"
49 // Utility to print an LLVM value as a string for passing to emitError().
50 // FIXME: Diagnostic should be able to natively handle types that have
51 // operator << (raw_ostream&) defined.
52 static std::string diag(const llvm::Value &value) {
53 std::string str;
54 llvm::raw_string_ostream os(str);
55 os << value;
56 return os.str();
59 // Utility to print an LLVM metadata node as a string for passing
60 // to emitError(). The module argument is needed to print the nodes
61 // canonically numbered.
62 static std::string diagMD(const llvm::Metadata *node,
63 const llvm::Module *module) {
64 std::string str;
65 llvm::raw_string_ostream os(str);
66 node->print(os, module, /*IsForDebug=*/true);
67 return os.str();
70 /// Returns the name of the global_ctors global variables.
71 static constexpr StringRef getGlobalCtorsVarName() {
72 return "llvm.global_ctors";
75 /// Returns the name of the global_dtors global variables.
76 static constexpr StringRef getGlobalDtorsVarName() {
77 return "llvm.global_dtors";
80 /// Returns the symbol name for the module-level comdat operation. It must not
81 /// conflict with the user namespace.
82 static constexpr StringRef getGlobalComdatOpName() {
83 return "__llvm_global_comdat";
86 /// Converts the sync scope identifier of `inst` to the string representation
87 /// necessary to build an atomic LLVM dialect operation. Returns the empty
88 /// string if the operation has either no sync scope or the default system-level
89 /// sync scope attached. The atomic operations only set their sync scope
90 /// attribute if they have a non-default sync scope attached.
91 static StringRef getLLVMSyncScope(llvm::Instruction *inst) {
92 std::optional<llvm::SyncScope::ID> syncScopeID =
93 llvm::getAtomicSyncScopeID(inst);
94 if (!syncScopeID)
95 return "";
97 // Search the sync scope name for the given identifier. The default
98 // system-level sync scope thereby maps to the empty string.
99 SmallVector<StringRef> syncScopeName;
100 llvm::LLVMContext &llvmContext = inst->getContext();
101 llvmContext.getSyncScopeNames(syncScopeName);
102 auto *it = llvm::find_if(syncScopeName, [&](StringRef name) {
103 return *syncScopeID == llvmContext.getOrInsertSyncScopeID(name);
105 if (it != syncScopeName.end())
106 return *it;
107 llvm_unreachable("incorrect sync scope identifier");
110 /// Converts an array of unsigned indices to a signed integer position array.
111 static SmallVector<int64_t> getPositionFromIndices(ArrayRef<unsigned> indices) {
112 SmallVector<int64_t> position;
113 llvm::append_range(position, indices);
114 return position;
/// Converts the LLVM instructions that have a generated MLIR builder. Using a
/// static implementation method called from the module import ensures the
/// builders have to use the `moduleImport` argument and cannot directly call
/// import methods. As a result, both the intrinsic and the instruction MLIR
/// builders have to use the `moduleImport` argument and none of them has direct
/// access to the private module import methods.
///
/// Returns failure() if the included generated code does not return first,
/// i.e. when no generated builder handles `inst`.
/// NOTE(review): the generated .inc body presumably refers to the local names
/// declared here (`odsBuilder`, `inst`, `moduleImport`, `llvmOperands`); check
/// the generator before renaming any of them.
static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
                                            llvm::Instruction *inst,
                                            ModuleImport &moduleImport) {
  // Copy the operands to an LLVM operands array reference for conversion.
  SmallVector<llvm::Value *> operands(inst->operands());
  ArrayRef<llvm::Value *> llvmOperands(operands);

  // Convert all instructions that provide an MLIR builder.
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
  return failure();
}
135 /// Get a topologically sorted list of blocks for the given function.
136 static SetVector<llvm::BasicBlock *>
137 getTopologicallySortedBlocks(llvm::Function *func) {
138 SetVector<llvm::BasicBlock *> blocks;
139 for (llvm::BasicBlock &bb : *func) {
140 if (!blocks.contains(&bb)) {
141 llvm::ReversePostOrderTraversal<llvm::BasicBlock *> traversal(&bb);
142 blocks.insert(traversal.begin(), traversal.end());
145 assert(blocks.size() == func->size() && "some blocks are not sorted");
147 return blocks;
/// Creates an importer that translates `llvmModule` into `mlirModule` and
/// takes ownership of the LLVM module. `emitExpensiveWarnings` enables
/// diagnostics that are costly to produce (e.g. printing unhandled metadata
/// nodes). The builder initially inserts at the start of the module body.
///
/// NOTE: inside the initializer list and the body, `mlirModule` and
/// `llvmModule` name the constructor parameters, which shadow the members of
/// the same name. `loopAnnotationImporter` receives a reference to `builder`.
/// NOTE(review): the initializer order is assumed to match the member
/// declaration order in ModuleImport.h.
ModuleImport::ModuleImport(ModuleOp mlirModule,
                           std::unique_ptr<llvm::Module> llvmModule,
                           bool emitExpensiveWarnings)
    : builder(mlirModule->getContext()), context(mlirModule->getContext()),
      mlirModule(mlirModule), llvmModule(std::move(llvmModule)),
      iface(mlirModule->getContext()),
      typeTranslator(*mlirModule->getContext()),
      debugImporter(std::make_unique<DebugImporter>(mlirModule)),
      loopAnnotationImporter(
          std::make_unique<LoopAnnotationImporter>(*this, builder)),
      emitExpensiveWarnings(emitExpensiveWarnings) {
  builder.setInsertionPointToStart(mlirModule.getBody());
}
164 ComdatOp ModuleImport::getGlobalComdatOp() {
165 if (globalComdatOp)
166 return globalComdatOp;
168 OpBuilder::InsertionGuard guard(builder);
169 builder.setInsertionPointToEnd(mlirModule.getBody());
170 globalComdatOp =
171 builder.create<ComdatOp>(mlirModule.getLoc(), getGlobalComdatOpName());
172 globalInsertionOp = globalComdatOp;
173 return globalComdatOp;
176 LogicalResult ModuleImport::processTBAAMetadata(const llvm::MDNode *node) {
177 Location loc = mlirModule.getLoc();
179 // If `node` is a valid TBAA root node, then return its optional identity
180 // string, otherwise return failure.
181 auto getIdentityIfRootNode =
182 [&](const llvm::MDNode *node) -> FailureOr<std::optional<StringRef>> {
183 // Root node, e.g.:
184 // !0 = !{!"Simple C/C++ TBAA"}
185 // !1 = !{}
186 if (node->getNumOperands() > 1)
187 return failure();
188 // If the operand is MDString, then assume that this is a root node.
189 if (node->getNumOperands() == 1)
190 if (const auto *op0 = dyn_cast<const llvm::MDString>(node->getOperand(0)))
191 return std::optional<StringRef>{op0->getString()};
192 return std::optional<StringRef>{};
195 // If `node` looks like a TBAA type descriptor metadata,
196 // then return true, if it is a valid node, and false otherwise.
197 // If it does not look like a TBAA type descriptor metadata, then
198 // return std::nullopt.
199 // If `identity` and `memberTypes/Offsets` are non-null, then they will
200 // contain the converted metadata operands for a valid TBAA node (i.e. when
201 // true is returned).
202 auto isTypeDescriptorNode = [&](const llvm::MDNode *node,
203 StringRef *identity = nullptr,
204 SmallVectorImpl<TBAAMemberAttr> *members =
205 nullptr) -> std::optional<bool> {
206 unsigned numOperands = node->getNumOperands();
207 // Type descriptor, e.g.:
208 // !1 = !{!"int", !0, /*optional*/i64 0} /* scalar int type */
209 // !2 = !{!"agg_t", !1, i64 0} /* struct agg_t { int x; } */
210 if (numOperands < 2)
211 return std::nullopt;
213 // TODO: support "new" format (D41501) for type descriptors,
214 // where the first operand is an MDNode.
215 const auto *identityNode =
216 dyn_cast<const llvm::MDString>(node->getOperand(0));
217 if (!identityNode)
218 return std::nullopt;
220 // This should be a type descriptor node.
221 if (identity)
222 *identity = identityNode->getString();
224 for (unsigned pairNum = 0, e = numOperands / 2; pairNum < e; ++pairNum) {
225 const auto *memberNode =
226 dyn_cast<const llvm::MDNode>(node->getOperand(2 * pairNum + 1));
227 if (!memberNode) {
228 emitError(loc) << "operand '" << 2 * pairNum + 1 << "' must be MDNode: "
229 << diagMD(node, llvmModule.get());
230 return false;
232 int64_t offset = 0;
233 if (2 * pairNum + 2 >= numOperands) {
234 // Allow for optional 0 offset in 2-operand nodes.
235 if (numOperands != 2) {
236 emitError(loc) << "missing member offset: "
237 << diagMD(node, llvmModule.get());
238 return false;
240 } else {
241 auto *offsetCI = llvm::mdconst::dyn_extract<llvm::ConstantInt>(
242 node->getOperand(2 * pairNum + 2));
243 if (!offsetCI) {
244 emitError(loc) << "operand '" << 2 * pairNum + 2
245 << "' must be ConstantInt: "
246 << diagMD(node, llvmModule.get());
247 return false;
249 offset = offsetCI->getZExtValue();
252 if (members)
253 members->push_back(TBAAMemberAttr::get(
254 cast<TBAANodeAttr>(tbaaMapping.lookup(memberNode)), offset));
257 return true;
260 // If `node` looks like a TBAA access tag metadata,
261 // then return true, if it is a valid node, and false otherwise.
262 // If it does not look like a TBAA access tag metadata, then
263 // return std::nullopt.
264 // If the other arguments are non-null, then they will contain
265 // the converted metadata operands for a valid TBAA node (i.e. when true is
266 // returned).
267 auto isTagNode = [&](const llvm::MDNode *node,
268 TBAATypeDescriptorAttr *baseAttr = nullptr,
269 TBAATypeDescriptorAttr *accessAttr = nullptr,
270 int64_t *offset = nullptr,
271 bool *isConstant = nullptr) -> std::optional<bool> {
272 // Access tag, e.g.:
273 // !3 = !{!1, !1, i64 0} /* scalar int access */
274 // !4 = !{!2, !1, i64 0} /* agg_t::x access */
276 // Optional 4th argument is ConstantInt 0/1 identifying whether
277 // the location being accessed is "constant" (see for details:
278 // https://llvm.org/docs/LangRef.html#representation).
279 unsigned numOperands = node->getNumOperands();
280 if (numOperands != 3 && numOperands != 4)
281 return std::nullopt;
282 const auto *baseMD = dyn_cast<const llvm::MDNode>(node->getOperand(0));
283 const auto *accessMD = dyn_cast<const llvm::MDNode>(node->getOperand(1));
284 auto *offsetCI =
285 llvm::mdconst::dyn_extract<llvm::ConstantInt>(node->getOperand(2));
286 if (!baseMD || !accessMD || !offsetCI)
287 return std::nullopt;
288 // TODO: support "new" TBAA format, if needed (see D41501).
289 // In the "old" format the first operand of the access type
290 // metadata is MDString. We have to distinguish the formats,
291 // because access tags have the same structure, but different
292 // meaning for the operands.
293 if (accessMD->getNumOperands() < 1 ||
294 !isa<llvm::MDString>(accessMD->getOperand(0)))
295 return std::nullopt;
296 bool isConst = false;
297 if (numOperands == 4) {
298 auto *isConstantCI =
299 llvm::mdconst::dyn_extract<llvm::ConstantInt>(node->getOperand(3));
300 if (!isConstantCI) {
301 emitError(loc) << "operand '3' must be ConstantInt: "
302 << diagMD(node, llvmModule.get());
303 return false;
305 isConst = isConstantCI->getValue()[0];
307 if (baseAttr)
308 *baseAttr = cast<TBAATypeDescriptorAttr>(tbaaMapping.lookup(baseMD));
309 if (accessAttr)
310 *accessAttr = cast<TBAATypeDescriptorAttr>(tbaaMapping.lookup(accessMD));
311 if (offset)
312 *offset = offsetCI->getZExtValue();
313 if (isConstant)
314 *isConstant = isConst;
315 return true;
318 // Do a post-order walk over the TBAA Graph. Since a correct TBAA Graph is a
319 // DAG, a post-order walk guarantees that we convert any metadata node we
320 // depend on, prior to converting the current node.
321 DenseSet<const llvm::MDNode *> seen;
322 SmallVector<const llvm::MDNode *> workList;
323 workList.push_back(node);
324 while (!workList.empty()) {
325 const llvm::MDNode *current = workList.back();
326 if (tbaaMapping.contains(current)) {
327 // Already converted. Just pop from the worklist.
328 workList.pop_back();
329 continue;
332 // If any child of this node is not yet converted, don't pop the current
333 // node from the worklist but push the not-yet-converted children in the
334 // front of the worklist.
335 bool anyChildNotConverted = false;
336 for (const llvm::MDOperand &operand : current->operands())
337 if (auto *childNode = dyn_cast_or_null<const llvm::MDNode>(operand.get()))
338 if (!tbaaMapping.contains(childNode)) {
339 workList.push_back(childNode);
340 anyChildNotConverted = true;
343 if (anyChildNotConverted) {
344 // If this is the second time we failed to convert an element in the
345 // worklist it must be because a child is dependent on it being converted
346 // and we have a cycle in the graph. Cycles are not allowed in TBAA
347 // graphs.
348 if (!seen.insert(current).second)
349 return emitError(loc) << "has cycle in TBAA graph: "
350 << diagMD(current, llvmModule.get());
352 continue;
355 // Otherwise simply import the current node.
356 workList.pop_back();
358 FailureOr<std::optional<StringRef>> rootNodeIdentity =
359 getIdentityIfRootNode(current);
360 if (succeeded(rootNodeIdentity)) {
361 StringAttr stringAttr = *rootNodeIdentity
362 ? builder.getStringAttr(**rootNodeIdentity)
363 : nullptr;
364 // The root nodes do not have operands, so we can create
365 // the TBAARootAttr on the first walk.
366 tbaaMapping.insert({current, builder.getAttr<TBAARootAttr>(stringAttr)});
367 continue;
370 StringRef identity;
371 SmallVector<TBAAMemberAttr> members;
372 if (std::optional<bool> isValid =
373 isTypeDescriptorNode(current, &identity, &members)) {
374 assert(isValid.value() && "type descriptor node must be valid");
376 tbaaMapping.insert({current, builder.getAttr<TBAATypeDescriptorAttr>(
377 identity, members)});
378 continue;
381 TBAATypeDescriptorAttr baseAttr, accessAttr;
382 int64_t offset;
383 bool isConstant;
384 if (std::optional<bool> isValid =
385 isTagNode(current, &baseAttr, &accessAttr, &offset, &isConstant)) {
386 assert(isValid.value() && "access tag node must be valid");
387 tbaaMapping.insert(
388 {current, builder.getAttr<TBAATagAttr>(baseAttr, accessAttr, offset,
389 isConstant)});
390 continue;
393 return emitError(loc) << "unsupported TBAA node format: "
394 << diagMD(current, llvmModule.get());
396 return success();
399 LogicalResult
400 ModuleImport::processAccessGroupMetadata(const llvm::MDNode *node) {
401 Location loc = mlirModule.getLoc();
402 if (failed(loopAnnotationImporter->translateAccessGroup(node, loc)))
403 return emitError(loc) << "unsupported access group node: "
404 << diagMD(node, llvmModule.get());
405 return success();
408 LogicalResult
409 ModuleImport::processAliasScopeMetadata(const llvm::MDNode *node) {
410 Location loc = mlirModule.getLoc();
411 // Helper that verifies the node has a self reference operand.
412 auto verifySelfRef = [](const llvm::MDNode *node) {
413 return node->getNumOperands() != 0 &&
414 node == dyn_cast<llvm::MDNode>(node->getOperand(0));
416 // Helper that verifies the given operand is a string or does not exist.
417 auto verifyDescription = [](const llvm::MDNode *node, unsigned idx) {
418 return idx >= node->getNumOperands() ||
419 isa<llvm::MDString>(node->getOperand(idx));
421 // Helper that creates an alias scope domain attribute.
422 auto createAliasScopeDomainOp = [&](const llvm::MDNode *aliasDomain) {
423 StringAttr description = nullptr;
424 if (aliasDomain->getNumOperands() >= 2)
425 if (auto *operand = dyn_cast<llvm::MDString>(aliasDomain->getOperand(1)))
426 description = builder.getStringAttr(operand->getString());
427 return builder.getAttr<AliasScopeDomainAttr>(
428 DistinctAttr::create(builder.getUnitAttr()), description);
431 // Collect the alias scopes and domains to translate them.
432 for (const llvm::MDOperand &operand : node->operands()) {
433 if (const auto *scope = dyn_cast<llvm::MDNode>(operand)) {
434 llvm::AliasScopeNode aliasScope(scope);
435 const llvm::MDNode *domain = aliasScope.getDomain();
437 // Verify the scope node points to valid scope metadata which includes
438 // verifying its domain. Perform the verification before looking it up in
439 // the alias scope mapping since it could have been inserted as a domain
440 // node before.
441 if (!verifySelfRef(scope) || !domain || !verifyDescription(scope, 2))
442 return emitError(loc) << "unsupported alias scope node: "
443 << diagMD(scope, llvmModule.get());
444 if (!verifySelfRef(domain) || !verifyDescription(domain, 1))
445 return emitError(loc) << "unsupported alias domain node: "
446 << diagMD(domain, llvmModule.get());
448 if (aliasScopeMapping.contains(scope))
449 continue;
451 // Convert the domain metadata node if it has not been translated before.
452 auto it = aliasScopeMapping.find(aliasScope.getDomain());
453 if (it == aliasScopeMapping.end()) {
454 auto aliasScopeDomainOp = createAliasScopeDomainOp(domain);
455 it = aliasScopeMapping.try_emplace(domain, aliasScopeDomainOp).first;
458 // Convert the scope metadata node if it has not been converted before.
459 StringAttr description = nullptr;
460 if (!aliasScope.getName().empty())
461 description = builder.getStringAttr(aliasScope.getName());
462 auto aliasScopeOp = builder.getAttr<AliasScopeAttr>(
463 DistinctAttr::create(builder.getUnitAttr()),
464 cast<AliasScopeDomainAttr>(it->second), description);
465 aliasScopeMapping.try_emplace(aliasScope.getNode(), aliasScopeOp);
468 return success();
471 FailureOr<SmallVector<AliasScopeAttr>>
472 ModuleImport::lookupAliasScopeAttrs(const llvm::MDNode *node) const {
473 SmallVector<AliasScopeAttr> aliasScopes;
474 aliasScopes.reserve(node->getNumOperands());
475 for (const llvm::MDOperand &operand : node->operands()) {
476 auto *node = cast<llvm::MDNode>(operand.get());
477 aliasScopes.push_back(
478 dyn_cast_or_null<AliasScopeAttr>(aliasScopeMapping.lookup(node)));
480 // Return failure if one of the alias scope lookups failed.
481 if (llvm::is_contained(aliasScopes, nullptr))
482 return failure();
483 return aliasScopes;
/// Records the call `intrinsic` in `debugIntrinsics`.
void ModuleImport::addDebugIntrinsic(llvm::CallInst *intrinsic) {
  debugIntrinsics.insert(intrinsic);
}
490 LogicalResult ModuleImport::convertMetadata() {
491 OpBuilder::InsertionGuard guard(builder);
492 builder.setInsertionPointToEnd(mlirModule.getBody());
493 for (const llvm::Function &func : llvmModule->functions()) {
494 for (const llvm::Instruction &inst : llvm::instructions(func)) {
495 // Convert access group metadata nodes.
496 if (llvm::MDNode *node =
497 inst.getMetadata(llvm::LLVMContext::MD_access_group))
498 if (failed(processAccessGroupMetadata(node)))
499 return failure();
501 // Convert alias analysis metadata nodes.
502 llvm::AAMDNodes aliasAnalysisNodes = inst.getAAMetadata();
503 if (!aliasAnalysisNodes)
504 continue;
505 if (aliasAnalysisNodes.TBAA)
506 if (failed(processTBAAMetadata(aliasAnalysisNodes.TBAA)))
507 return failure();
508 if (aliasAnalysisNodes.Scope)
509 if (failed(processAliasScopeMetadata(aliasAnalysisNodes.Scope)))
510 return failure();
511 if (aliasAnalysisNodes.NoAlias)
512 if (failed(processAliasScopeMetadata(aliasAnalysisNodes.NoAlias)))
513 return failure();
516 return success();
519 void ModuleImport::processComdat(const llvm::Comdat *comdat) {
520 if (comdatMapping.contains(comdat))
521 return;
523 ComdatOp comdatOp = getGlobalComdatOp();
524 OpBuilder::InsertionGuard guard(builder);
525 builder.setInsertionPointToEnd(&comdatOp.getBody().back());
526 auto selectorOp = builder.create<ComdatSelectorOp>(
527 mlirModule.getLoc(), comdat->getName(),
528 convertComdatFromLLVM(comdat->getSelectionKind()));
529 auto symbolRef =
530 SymbolRefAttr::get(builder.getContext(), getGlobalComdatOpName(),
531 FlatSymbolRefAttr::get(selectorOp.getSymNameAttr()));
532 comdatMapping.try_emplace(comdat, symbolRef);
535 LogicalResult ModuleImport::convertComdats() {
536 for (llvm::GlobalVariable &globalVar : llvmModule->globals())
537 if (globalVar.hasComdat())
538 processComdat(globalVar.getComdat());
539 for (llvm::Function &func : llvmModule->functions())
540 if (func.hasComdat())
541 processComdat(func.getComdat());
542 return success();
545 LogicalResult ModuleImport::convertGlobals() {
546 for (llvm::GlobalVariable &globalVar : llvmModule->globals()) {
547 if (globalVar.getName() == getGlobalCtorsVarName() ||
548 globalVar.getName() == getGlobalDtorsVarName()) {
549 if (failed(convertGlobalCtorsAndDtors(&globalVar))) {
550 return emitError(UnknownLoc::get(context))
551 << "unhandled global variable: " << diag(globalVar);
553 continue;
555 if (failed(convertGlobal(&globalVar))) {
556 return emitError(UnknownLoc::get(context))
557 << "unhandled global variable: " << diag(globalVar);
560 return success();
563 LogicalResult ModuleImport::convertDataLayout() {
564 Location loc = mlirModule.getLoc();
565 DataLayoutImporter dataLayoutImporter(context, llvmModule->getDataLayout());
566 if (!dataLayoutImporter.getDataLayout())
567 return emitError(loc, "cannot translate data layout: ")
568 << dataLayoutImporter.getLastToken();
570 for (StringRef token : dataLayoutImporter.getUnhandledTokens())
571 emitWarning(loc, "unhandled data layout token: ") << token;
573 mlirModule->setAttr(DLTIDialect::kDataLayoutAttrName,
574 dataLayoutImporter.getDataLayout());
575 return success();
578 LogicalResult ModuleImport::convertFunctions() {
579 for (llvm::Function &func : llvmModule->functions())
580 if (failed(processFunction(&func)))
581 return failure();
582 return success();
585 void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
586 Operation *op) {
587 SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
588 inst->getAllMetadataOtherThanDebugLoc(allMetadata);
589 for (auto &[kind, node] : allMetadata) {
590 if (!iface.isConvertibleMetadata(kind))
591 continue;
592 if (failed(iface.setMetadataAttrs(builder, kind, node, op, *this))) {
593 if (emitExpensiveWarnings) {
594 Location loc = debugImporter->translateLoc(inst->getDebugLoc());
595 emitWarning(loc) << "unhandled metadata: "
596 << diagMD(node, llvmModule.get()) << " on "
597 << diag(*inst);
603 void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
604 Operation *op) const {
605 auto iface = cast<FastmathFlagsInterface>(op);
607 // Even if the imported operation implements the fastmath interface, the
608 // original instruction may not have fastmath flags set. Exit if an
609 // instruction, such as a non floating-point function call, does not have
610 // fastmath flags.
611 if (!isa<llvm::FPMathOperator>(inst))
612 return;
613 llvm::FastMathFlags flags = inst->getFastMathFlags();
615 // Set the fastmath bits flag-by-flag.
616 FastmathFlags value = {};
617 value = bitEnumSet(value, FastmathFlags::nnan, flags.noNaNs());
618 value = bitEnumSet(value, FastmathFlags::ninf, flags.noInfs());
619 value = bitEnumSet(value, FastmathFlags::nsz, flags.noSignedZeros());
620 value = bitEnumSet(value, FastmathFlags::arcp, flags.allowReciprocal());
621 value = bitEnumSet(value, FastmathFlags::contract, flags.allowContract());
622 value = bitEnumSet(value, FastmathFlags::afn, flags.approxFunc());
623 value = bitEnumSet(value, FastmathFlags::reassoc, flags.allowReassoc());
624 FastmathFlagsAttr attr = FastmathFlagsAttr::get(builder.getContext(), value);
625 iface->setAttr(iface.getFastmathAttrName(), attr);
628 /// Returns if `type` is a scalar integer or floating-point type.
629 static bool isScalarType(Type type) {
630 return isa<IntegerType, FloatType>(type);
633 /// Returns `type` if it is a builtin integer or floating-point vector type that
634 /// can be used to create an attribute or nullptr otherwise. If provided,
635 /// `arrayShape` is added to the shape of the vector to create an attribute that
636 /// matches an array of vectors.
637 static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
638 if (!LLVM::isCompatibleVectorType(type))
639 return {};
641 llvm::ElementCount numElements = LLVM::getVectorNumElements(type);
642 if (numElements.isScalable()) {
643 emitError(UnknownLoc::get(type.getContext()))
644 << "scalable vectors not supported";
645 return {};
648 // An LLVM dialect vector can only contain scalars.
649 Type elementType = LLVM::getVectorElementType(type);
650 if (!isScalarType(elementType))
651 return {};
653 SmallVector<int64_t> shape(arrayShape.begin(), arrayShape.end());
654 shape.push_back(numElements.getKnownMinValue());
655 return VectorType::get(shape, elementType);
658 Type ModuleImport::getBuiltinTypeForAttr(Type type) {
659 if (!type)
660 return {};
662 // Return builtin integer and floating-point types as is.
663 if (isScalarType(type))
664 return type;
666 // Return builtin vectors of integer and floating-point types as is.
667 if (Type vectorType = getVectorTypeForAttr(type))
668 return vectorType;
670 // Multi-dimensional array types are converted to tensors or vectors,
671 // depending on the innermost type being a scalar or a vector.
672 SmallVector<int64_t> arrayShape;
673 while (auto arrayType = dyn_cast<LLVMArrayType>(type)) {
674 arrayShape.push_back(arrayType.getNumElements());
675 type = arrayType.getElementType();
677 if (isScalarType(type))
678 return RankedTensorType::get(arrayShape, type);
679 return getVectorTypeForAttr(type, arrayShape);
682 /// Returns an integer or float attribute for the provided scalar constant
683 /// `constScalar` or nullptr if the conversion fails.
684 static Attribute getScalarConstantAsAttr(OpBuilder &builder,
685 llvm::Constant *constScalar) {
686 MLIRContext *context = builder.getContext();
688 // Convert scalar intergers.
689 if (auto *constInt = dyn_cast<llvm::ConstantInt>(constScalar)) {
690 return builder.getIntegerAttr(
691 IntegerType::get(context, constInt->getType()->getBitWidth()),
692 constInt->getValue());
695 // Convert scalar floats.
696 if (auto *constFloat = dyn_cast<llvm::ConstantFP>(constScalar)) {
697 llvm::Type *type = constFloat->getType();
698 FloatType floatType =
699 type->isBFloatTy()
700 ? FloatType::getBF16(context)
701 : LLVM::detail::getFloatType(context, type->getScalarSizeInBits());
702 if (!floatType) {
703 emitError(UnknownLoc::get(builder.getContext()))
704 << "unexpected floating-point type";
705 return {};
707 return builder.getFloatAttr(floatType, constFloat->getValueAPF());
709 return {};
712 /// Returns an integer or float attribute array for the provided constant
713 /// sequence `constSequence` or nullptr if the conversion fails.
714 static SmallVector<Attribute>
715 getSequenceConstantAsAttrs(OpBuilder &builder,
716 llvm::ConstantDataSequential *constSequence) {
717 SmallVector<Attribute> elementAttrs;
718 elementAttrs.reserve(constSequence->getNumElements());
719 for (auto idx : llvm::seq<int64_t>(0, constSequence->getNumElements())) {
720 llvm::Constant *constElement = constSequence->getElementAsConstant(idx);
721 elementAttrs.push_back(getScalarConstantAsAttr(builder, constElement));
723 return elementAttrs;
726 Attribute ModuleImport::getConstantAsAttr(llvm::Constant *constant) {
727 // Convert scalar constants.
728 if (Attribute scalarAttr = getScalarConstantAsAttr(builder, constant))
729 return scalarAttr;
731 // Convert function references.
732 if (auto *func = dyn_cast<llvm::Function>(constant))
733 return SymbolRefAttr::get(builder.getContext(), func->getName());
735 // Returns the static shape of the provided type if possible.
736 auto getConstantShape = [&](llvm::Type *type) {
737 return llvm::dyn_cast_if_present<ShapedType>(
738 getBuiltinTypeForAttr(convertType(type)));
741 // Convert one-dimensional constant arrays or vectors that store 1/2/4/8-byte
742 // integer or half/bfloat/float/double values.
743 if (auto *constArray = dyn_cast<llvm::ConstantDataSequential>(constant)) {
744 if (constArray->isString())
745 return builder.getStringAttr(constArray->getAsString());
746 auto shape = getConstantShape(constArray->getType());
747 if (!shape)
748 return {};
749 // Convert splat constants to splat elements attributes.
750 auto *constVector = dyn_cast<llvm::ConstantDataVector>(constant);
751 if (constVector && constVector->isSplat()) {
752 // A vector is guaranteed to have at least size one.
753 Attribute splatAttr = getScalarConstantAsAttr(
754 builder, constVector->getElementAsConstant(0));
755 return SplatElementsAttr::get(shape, splatAttr);
757 // Convert non-splat constants to dense elements attributes.
758 SmallVector<Attribute> elementAttrs =
759 getSequenceConstantAsAttrs(builder, constArray);
760 return DenseElementsAttr::get(shape, elementAttrs);
763 // Convert multi-dimensional constant aggregates that store all kinds of
764 // integer and floating-point types.
765 if (auto *constAggregate = dyn_cast<llvm::ConstantAggregate>(constant)) {
766 auto shape = getConstantShape(constAggregate->getType());
767 if (!shape)
768 return {};
769 // Collect the aggregate elements in depths first order.
770 SmallVector<Attribute> elementAttrs;
771 SmallVector<llvm::Constant *> workList = {constAggregate};
772 while (!workList.empty()) {
773 llvm::Constant *current = workList.pop_back_val();
774 // Append any nested aggregates in reverse order to ensure the head
775 // element of the nested aggregates is at the back of the work list.
776 if (auto *constAggregate = dyn_cast<llvm::ConstantAggregate>(current)) {
777 for (auto idx :
778 reverse(llvm::seq<int64_t>(0, constAggregate->getNumOperands())))
779 workList.push_back(constAggregate->getAggregateElement(idx));
780 continue;
782 // Append the elements of nested constant arrays or vectors that store
783 // 1/2/4/8-byte integer or half/bfloat/float/double values.
784 if (auto *constArray = dyn_cast<llvm::ConstantDataSequential>(current)) {
785 SmallVector<Attribute> attrs =
786 getSequenceConstantAsAttrs(builder, constArray);
787 elementAttrs.append(attrs.begin(), attrs.end());
788 continue;
790 // Append nested scalar constants that store all kinds of integer and
791 // floating-point types.
792 if (Attribute scalarAttr = getScalarConstantAsAttr(builder, current)) {
793 elementAttrs.push_back(scalarAttr);
794 continue;
796 // Bail if the aggregate contains a unsupported constant type such as a
797 // constant expression.
798 return {};
800 return DenseElementsAttr::get(shape, elementAttrs);
803 // Convert zero aggregates.
804 if (auto *constZero = dyn_cast<llvm::ConstantAggregateZero>(constant)) {
805 auto shape = llvm::dyn_cast_if_present<ShapedType>(
806 getBuiltinTypeForAttr(convertType(constZero->getType())));
807 if (!shape)
808 return {};
809 // Convert zero aggregates with a static shape to splat elements attributes.
810 Attribute splatAttr = builder.getZeroAttr(shape.getElementType());
811 assert(splatAttr && "expected non-null zero attribute for scalar types");
812 return SplatElementsAttr::get(shape, splatAttr);
814 return {};
817 LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) {
818 // Insert the global after the last one or at the start of the module.
819 OpBuilder::InsertionGuard guard(builder);
820 if (!globalInsertionOp)
821 builder.setInsertionPointToStart(mlirModule.getBody());
822 else
823 builder.setInsertionPointAfter(globalInsertionOp);
825 Attribute valueAttr;
826 if (globalVar->hasInitializer())
827 valueAttr = getConstantAsAttr(globalVar->getInitializer());
828 Type type = convertType(globalVar->getValueType());
830 uint64_t alignment = 0;
831 llvm::MaybeAlign maybeAlign = globalVar->getAlign();
832 if (maybeAlign.has_value()) {
833 llvm::Align align = *maybeAlign;
834 alignment = align.value();
837 GlobalOp globalOp = builder.create<GlobalOp>(
838 mlirModule.getLoc(), type, globalVar->isConstant(),
839 convertLinkageFromLLVM(globalVar->getLinkage()), globalVar->getName(),
840 valueAttr, alignment, /*addr_space=*/globalVar->getAddressSpace(),
841 /*dso_local=*/globalVar->isDSOLocal(),
842 /*thread_local=*/globalVar->isThreadLocal());
843 globalInsertionOp = globalOp;
845 if (globalVar->hasInitializer() && !valueAttr) {
846 clearRegionState();
847 Block *block = builder.createBlock(&globalOp.getInitializerRegion());
848 setConstantInsertionPointToStart(block);
849 FailureOr<Value> initializer =
850 convertConstantExpr(globalVar->getInitializer());
851 if (failed(initializer))
852 return failure();
853 builder.create<ReturnOp>(globalOp.getLoc(), *initializer);
855 if (globalVar->hasAtLeastLocalUnnamedAddr()) {
856 globalOp.setUnnamedAddr(
857 convertUnnamedAddrFromLLVM(globalVar->getUnnamedAddr()));
859 if (globalVar->hasSection())
860 globalOp.setSection(globalVar->getSection());
861 globalOp.setVisibility_(
862 convertVisibilityFromLLVM(globalVar->getVisibility()));
864 if (globalVar->hasComdat())
865 globalOp.setComdatAttr(comdatMapping.lookup(globalVar->getComdat()));
867 return success();
870 LogicalResult
871 ModuleImport::convertGlobalCtorsAndDtors(llvm::GlobalVariable *globalVar) {
872 if (!globalVar->hasInitializer() || !globalVar->hasAppendingLinkage())
873 return failure();
874 auto *initializer =
875 dyn_cast<llvm::ConstantArray>(globalVar->getInitializer());
876 if (!initializer)
877 return failure();
879 SmallVector<Attribute> funcs;
880 SmallVector<int32_t> priorities;
881 for (llvm::Value *operand : initializer->operands()) {
882 auto *aggregate = dyn_cast<llvm::ConstantAggregate>(operand);
883 if (!aggregate || aggregate->getNumOperands() != 3)
884 return failure();
886 auto *priority = dyn_cast<llvm::ConstantInt>(aggregate->getOperand(0));
887 auto *func = dyn_cast<llvm::Function>(aggregate->getOperand(1));
888 auto *data = dyn_cast<llvm::Constant>(aggregate->getOperand(2));
889 if (!priority || !func || !data)
890 return failure();
892 // GlobalCtorsOps and GlobalDtorsOps do not support non-null data fields.
893 if (!data->isNullValue())
894 return failure();
896 funcs.push_back(FlatSymbolRefAttr::get(context, func->getName()));
897 priorities.push_back(priority->getValue().getZExtValue());
900 OpBuilder::InsertionGuard guard(builder);
901 if (!globalInsertionOp)
902 builder.setInsertionPointToStart(mlirModule.getBody());
903 else
904 builder.setInsertionPointAfter(globalInsertionOp);
906 if (globalVar->getName() == getGlobalCtorsVarName()) {
907 globalInsertionOp = builder.create<LLVM::GlobalCtorsOp>(
908 mlirModule.getLoc(), builder.getArrayAttr(funcs),
909 builder.getI32ArrayAttr(priorities));
910 return success();
912 globalInsertionOp = builder.create<LLVM::GlobalDtorsOp>(
913 mlirModule.getLoc(), builder.getArrayAttr(funcs),
914 builder.getI32ArrayAttr(priorities));
915 return success();
918 SetVector<llvm::Constant *>
919 ModuleImport::getConstantsToConvert(llvm::Constant *constant) {
920 // Return the empty set if the constant has been translated before.
921 if (valueMapping.contains(constant))
922 return {};
924 // Traverse the constants in post-order and stop the traversal if a constant
925 // already has a `valueMapping` from an earlier constant translation or if the
926 // constant is traversed a second time.
927 SetVector<llvm::Constant *> orderedSet;
928 SetVector<llvm::Constant *> workList;
929 DenseMap<llvm::Constant *, SmallVector<llvm::Constant *>> adjacencyLists;
930 workList.insert(constant);
931 while (!workList.empty()) {
932 llvm::Constant *current = workList.back();
933 // Collect all dependencies of the current constant and add them to the
934 // adjacency list if none has been computed before.
935 auto adjacencyIt = adjacencyLists.find(current);
936 if (adjacencyIt == adjacencyLists.end()) {
937 adjacencyIt = adjacencyLists.try_emplace(current).first;
938 // Add all constant operands to the adjacency list and skip any other
939 // values such as basic block addresses.
940 for (llvm::Value *operand : current->operands())
941 if (auto *constDependency = dyn_cast<llvm::Constant>(operand))
942 adjacencyIt->getSecond().push_back(constDependency);
943 // Use the getElementValue method to add the dependencies of zero
944 // initialized aggregate constants since they do not take any operands.
945 if (auto *constAgg = dyn_cast<llvm::ConstantAggregateZero>(current)) {
946 unsigned numElements = constAgg->getElementCount().getFixedValue();
947 for (unsigned i = 0, e = numElements; i != e; ++i)
948 adjacencyIt->getSecond().push_back(constAgg->getElementValue(i));
951 // Add the current constant to the `orderedSet` of the traversed nodes if
952 // all its dependencies have been traversed before. Additionally, remove the
953 // constant from the `workList` and continue the traversal.
954 if (adjacencyIt->getSecond().empty()) {
955 orderedSet.insert(current);
956 workList.pop_back();
957 continue;
959 // Add the next dependency from the adjacency list to the `workList` and
960 // continue the traversal. Remove the dependency from the adjacency list to
961 // mark that it has been processed. Only enqueue the dependency if it has no
962 // `valueMapping` from an earlier translation and if it has not been
963 // enqueued before.
964 llvm::Constant *dependency = adjacencyIt->getSecond().pop_back_val();
965 if (valueMapping.contains(dependency) || workList.contains(dependency) ||
966 orderedSet.contains(dependency))
967 continue;
968 workList.insert(dependency);
971 return orderedSet;
974 FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) {
975 Location loc = UnknownLoc::get(context);
977 // Convert constants that can be represented as attributes.
978 if (Attribute attr = getConstantAsAttr(constant)) {
979 Type type = convertType(constant->getType());
980 if (auto symbolRef = dyn_cast<FlatSymbolRefAttr>(attr)) {
981 return builder.create<AddressOfOp>(loc, type, symbolRef.getValue())
982 .getResult();
984 return builder.create<ConstantOp>(loc, type, attr).getResult();
987 // Convert null pointer constants.
988 if (auto *nullPtr = dyn_cast<llvm::ConstantPointerNull>(constant)) {
989 Type type = convertType(nullPtr->getType());
990 return builder.create<NullOp>(loc, type).getResult();
993 // Convert none token constants.
994 if (auto *noneToken = dyn_cast<llvm::ConstantTokenNone>(constant)) {
995 return builder.create<NoneTokenOp>(loc).getResult();
998 // Convert poison.
999 if (auto *poisonVal = dyn_cast<llvm::PoisonValue>(constant)) {
1000 Type type = convertType(poisonVal->getType());
1001 return builder.create<PoisonOp>(loc, type).getResult();
1004 // Convert undef.
1005 if (auto *undefVal = dyn_cast<llvm::UndefValue>(constant)) {
1006 Type type = convertType(undefVal->getType());
1007 return builder.create<UndefOp>(loc, type).getResult();
1010 // Convert global variable accesses.
1011 if (auto *globalVar = dyn_cast<llvm::GlobalVariable>(constant)) {
1012 Type type = convertType(globalVar->getType());
1013 auto symbolRef = FlatSymbolRefAttr::get(context, globalVar->getName());
1014 return builder.create<AddressOfOp>(loc, type, symbolRef).getResult();
1017 // Convert constant expressions.
1018 if (auto *constExpr = dyn_cast<llvm::ConstantExpr>(constant)) {
1019 // Convert the constant expression to a temporary LLVM instruction and
1020 // translate it using the `processInstruction` method. Delete the
1021 // instruction after the translation and remove it from `valueMapping`,
1022 // since later calls to `getAsInstruction` may return the same address
1023 // resulting in a conflicting `valueMapping` entry.
1024 llvm::Instruction *inst = constExpr->getAsInstruction();
1025 auto guard = llvm::make_scope_exit([&]() {
1026 assert(!noResultOpMapping.contains(inst) &&
1027 "expected constant expression to return a result");
1028 valueMapping.erase(inst);
1029 inst->deleteValue();
1031 // Note: `processInstruction` does not call `convertConstant` recursively
1032 // since all constant dependencies have been converted before.
1033 assert(llvm::all_of(inst->operands(), [&](llvm::Value *value) {
1034 return valueMapping.contains(value);
1035 }));
1036 if (failed(processInstruction(inst)))
1037 return failure();
1038 return lookupValue(inst);
1041 // Convert aggregate constants.
1042 if (isa<llvm::ConstantAggregate>(constant) ||
1043 isa<llvm::ConstantAggregateZero>(constant)) {
1044 // Lookup the aggregate elements that have been converted before.
1045 SmallVector<Value> elementValues;
1046 if (auto *constAgg = dyn_cast<llvm::ConstantAggregate>(constant)) {
1047 elementValues.reserve(constAgg->getNumOperands());
1048 for (llvm::Value *operand : constAgg->operands())
1049 elementValues.push_back(lookupValue(operand));
1051 if (auto *constAgg = dyn_cast<llvm::ConstantAggregateZero>(constant)) {
1052 unsigned numElements = constAgg->getElementCount().getFixedValue();
1053 elementValues.reserve(numElements);
1054 for (unsigned i = 0, e = numElements; i != e; ++i)
1055 elementValues.push_back(lookupValue(constAgg->getElementValue(i)));
1057 assert(llvm::count(elementValues, nullptr) == 0 &&
1058 "expected all elements have been converted before");
1060 // Generate an UndefOp as root value and insert the aggregate elements.
1061 Type rootType = convertType(constant->getType());
1062 bool isArrayOrStruct = isa<LLVMArrayType, LLVMStructType>(rootType);
1063 assert((isArrayOrStruct || LLVM::isCompatibleVectorType(rootType)) &&
1064 "unrecognized aggregate type");
1065 Value root = builder.create<UndefOp>(loc, rootType);
1066 for (const auto &it : llvm::enumerate(elementValues)) {
1067 if (isArrayOrStruct) {
1068 root = builder.create<InsertValueOp>(loc, root, it.value(), it.index());
1069 } else {
1070 Attribute indexAttr = builder.getI32IntegerAttr(it.index());
1071 Value indexValue =
1072 builder.create<ConstantOp>(loc, builder.getI32Type(), indexAttr);
1073 root = builder.create<InsertElementOp>(loc, rootType, root, it.value(),
1074 indexValue);
1077 return root;
1080 if (auto *constTargetNone = dyn_cast<llvm::ConstantTargetNone>(constant)) {
1081 LLVMTargetExtType targetExtType =
1082 cast<LLVMTargetExtType>(convertType(constTargetNone->getType()));
1083 assert(targetExtType.hasProperty(LLVMTargetExtType::HasZeroInit) &&
1084 "target extension type does not support zero-initialization");
1085 // Create llvm.mlir.zero operation to represent zero-initialization of
1086 // target extension type.
1087 return builder.create<LLVM::ZeroOp>(loc, targetExtType).getRes();
1090 StringRef error = "";
1091 if (isa<llvm::BlockAddress>(constant))
1092 error = " since blockaddress(...) is unsupported";
1094 return emitError(loc) << "unhandled constant: " << diag(*constant) << error;
1097 FailureOr<Value> ModuleImport::convertConstantExpr(llvm::Constant *constant) {
1098 // Only call the function for constants that have not been translated before
1099 // since it updates the constant insertion point assuming the converted
1100 // constant has been introduced at the end of the constant section.
1101 assert(!valueMapping.contains(constant) &&
1102 "expected constant has not been converted before");
1103 assert(constantInsertionBlock &&
1104 "expected the constant insertion block to be non-null");
1106 // Insert the constant after the last one or at the start of the entry block.
1107 OpBuilder::InsertionGuard guard(builder);
1108 if (!constantInsertionOp)
1109 builder.setInsertionPointToStart(constantInsertionBlock);
1110 else
1111 builder.setInsertionPointAfter(constantInsertionOp);
1113 // Convert all constants of the expression and add them to `valueMapping`.
1114 SetVector<llvm::Constant *> constantsToConvert =
1115 getConstantsToConvert(constant);
1116 for (llvm::Constant *constantToConvert : constantsToConvert) {
1117 FailureOr<Value> converted = convertConstant(constantToConvert);
1118 if (failed(converted))
1119 return failure();
1120 mapValue(constantToConvert, *converted);
1123 // Update the constant insertion point and return the converted constant.
1124 Value result = lookupValue(constant);
1125 constantInsertionOp = result.getDefiningOp();
1126 return result;
1129 FailureOr<Value> ModuleImport::convertValue(llvm::Value *value) {
1130 assert(!isa<llvm::MetadataAsValue>(value) &&
1131 "expected value to not be metadata");
1133 // Return the mapped value if it has been converted before.
1134 auto it = valueMapping.find(value);
1135 if (it != valueMapping.end())
1136 return it->getSecond();
1138 // Convert constants such as immediate values that have no mapping yet.
1139 if (auto *constant = dyn_cast<llvm::Constant>(value))
1140 return convertConstantExpr(constant);
1142 Location loc = UnknownLoc::get(context);
1143 if (auto *inst = dyn_cast<llvm::Instruction>(value))
1144 loc = translateLoc(inst->getDebugLoc());
1145 return emitError(loc) << "unhandled value: " << diag(*value);
1148 FailureOr<Value> ModuleImport::convertMetadataValue(llvm::Value *value) {
1149 // A value may be wrapped as metadata, for example, when passed to a debug
1150 // intrinsic. Unwrap these values before the conversion.
1151 auto *nodeAsVal = dyn_cast<llvm::MetadataAsValue>(value);
1152 if (!nodeAsVal)
1153 return failure();
1154 auto *node = dyn_cast<llvm::ValueAsMetadata>(nodeAsVal->getMetadata());
1155 if (!node)
1156 return failure();
1157 value = node->getValue();
1159 // Return the mapped value if it has been converted before.
1160 auto it = valueMapping.find(value);
1161 if (it != valueMapping.end())
1162 return it->getSecond();
1164 // Convert constants such as immediate values that have no mapping yet.
1165 if (auto *constant = dyn_cast<llvm::Constant>(value))
1166 return convertConstantExpr(constant);
1167 return failure();
1170 FailureOr<SmallVector<Value>>
1171 ModuleImport::convertValues(ArrayRef<llvm::Value *> values) {
1172 SmallVector<Value> remapped;
1173 remapped.reserve(values.size());
1174 for (llvm::Value *value : values) {
1175 FailureOr<Value> converted = convertValue(value);
1176 if (failed(converted))
1177 return failure();
1178 remapped.push_back(*converted);
1180 return remapped;
1183 IntegerAttr ModuleImport::matchIntegerAttr(llvm::Value *value) {
1184 IntegerAttr integerAttr;
1185 FailureOr<Value> converted = convertValue(value);
1186 bool success = succeeded(converted) &&
1187 matchPattern(*converted, m_Constant(&integerAttr));
1188 assert(success && "expected a constant integer value");
1189 (void)success;
1190 return integerAttr;
1193 FloatAttr ModuleImport::matchFloatAttr(llvm::Value *value) {
1194 FloatAttr floatAttr;
1195 FailureOr<Value> converted = convertValue(value);
1196 bool success =
1197 succeeded(converted) && matchPattern(*converted, m_Constant(&floatAttr));
1198 assert(success && "expected a constant float value");
1199 (void)success;
1200 return floatAttr;
1203 DILocalVariableAttr ModuleImport::matchLocalVariableAttr(llvm::Value *value) {
1204 auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);
1205 auto *node = cast<llvm::DILocalVariable>(nodeAsVal->getMetadata());
1206 return debugImporter->translate(node);
1209 DILabelAttr ModuleImport::matchLabelAttr(llvm::Value *value) {
1210 auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);
1211 auto *node = cast<llvm::DILabel>(nodeAsVal->getMetadata());
1212 return debugImporter->translate(node);
1215 FailureOr<SmallVector<AliasScopeAttr>>
1216 ModuleImport::matchAliasScopeAttrs(llvm::Value *value) {
1217 auto *nodeAsVal = cast<llvm::MetadataAsValue>(value);
1218 auto *node = cast<llvm::MDNode>(nodeAsVal->getMetadata());
1219 return lookupAliasScopeAttrs(node);
1222 Location ModuleImport::translateLoc(llvm::DILocation *loc) {
1223 return debugImporter->translateLoc(loc);
1226 LogicalResult
1227 ModuleImport::convertBranchArgs(llvm::Instruction *branch,
1228 llvm::BasicBlock *target,
1229 SmallVectorImpl<Value> &blockArguments) {
1230 for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
1231 auto *phiInst = cast<llvm::PHINode>(&*inst);
1232 llvm::Value *value = phiInst->getIncomingValueForBlock(branch->getParent());
1233 FailureOr<Value> converted = convertValue(value);
1234 if (failed(converted))
1235 return failure();
1236 blockArguments.push_back(*converted);
1238 return success();
1241 LogicalResult
1242 ModuleImport::convertCallTypeAndOperands(llvm::CallBase *callInst,
1243 SmallVectorImpl<Type> &types,
1244 SmallVectorImpl<Value> &operands) {
1245 if (!callInst->getType()->isVoidTy())
1246 types.push_back(convertType(callInst->getType()));
1248 if (!callInst->getCalledFunction()) {
1249 FailureOr<Value> called = convertValue(callInst->getCalledOperand());
1250 if (failed(called))
1251 return failure();
1252 operands.push_back(*called);
1254 SmallVector<llvm::Value *> args(callInst->args());
1255 FailureOr<SmallVector<Value>> arguments = convertValues(args);
1256 if (failed(arguments))
1257 return failure();
1258 llvm::append_range(operands, *arguments);
1259 return success();
1262 LogicalResult ModuleImport::convertIntrinsic(llvm::CallInst *inst) {
1263 if (succeeded(iface.convertIntrinsic(builder, inst, *this)))
1264 return success();
1266 Location loc = translateLoc(inst->getDebugLoc());
1267 return emitError(loc) << "unhandled intrinsic: " << diag(*inst);
1270 LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
1271 // Convert all instructions that do not provide an MLIR builder.
1272 Location loc = translateLoc(inst->getDebugLoc());
1273 if (inst->getOpcode() == llvm::Instruction::Br) {
1274 auto *brInst = cast<llvm::BranchInst>(inst);
1276 SmallVector<Block *> succBlocks;
1277 SmallVector<SmallVector<Value>> succBlockArgs;
1278 for (auto i : llvm::seq<unsigned>(0, brInst->getNumSuccessors())) {
1279 llvm::BasicBlock *succ = brInst->getSuccessor(i);
1280 SmallVector<Value> blockArgs;
1281 if (failed(convertBranchArgs(brInst, succ, blockArgs)))
1282 return failure();
1283 succBlocks.push_back(lookupBlock(succ));
1284 succBlockArgs.push_back(blockArgs);
1287 if (!brInst->isConditional()) {
1288 auto brOp = builder.create<LLVM::BrOp>(loc, succBlockArgs.front(),
1289 succBlocks.front());
1290 mapNoResultOp(inst, brOp);
1291 return success();
1293 FailureOr<Value> condition = convertValue(brInst->getCondition());
1294 if (failed(condition))
1295 return failure();
1296 auto condBrOp = builder.create<LLVM::CondBrOp>(
1297 loc, *condition, succBlocks.front(), succBlockArgs.front(),
1298 succBlocks.back(), succBlockArgs.back());
1299 mapNoResultOp(inst, condBrOp);
1300 return success();
1302 if (inst->getOpcode() == llvm::Instruction::Switch) {
1303 auto *swInst = cast<llvm::SwitchInst>(inst);
1304 // Process the condition value.
1305 FailureOr<Value> condition = convertValue(swInst->getCondition());
1306 if (failed(condition))
1307 return failure();
1308 SmallVector<Value> defaultBlockArgs;
1309 // Process the default case.
1310 llvm::BasicBlock *defaultBB = swInst->getDefaultDest();
1311 if (failed(convertBranchArgs(swInst, defaultBB, defaultBlockArgs)))
1312 return failure();
1314 // Process the cases.
1315 unsigned numCases = swInst->getNumCases();
1316 SmallVector<SmallVector<Value>> caseOperands(numCases);
1317 SmallVector<ValueRange> caseOperandRefs(numCases);
1318 SmallVector<APInt> caseValues(numCases);
1319 SmallVector<Block *> caseBlocks(numCases);
1320 for (const auto &it : llvm::enumerate(swInst->cases())) {
1321 const llvm::SwitchInst::CaseHandle &caseHandle = it.value();
1322 llvm::BasicBlock *succBB = caseHandle.getCaseSuccessor();
1323 if (failed(convertBranchArgs(swInst, succBB, caseOperands[it.index()])))
1324 return failure();
1325 caseOperandRefs[it.index()] = caseOperands[it.index()];
1326 caseValues[it.index()] = caseHandle.getCaseValue()->getValue();
1327 caseBlocks[it.index()] = lookupBlock(succBB);
1330 auto switchOp = builder.create<SwitchOp>(
1331 loc, *condition, lookupBlock(defaultBB), defaultBlockArgs, caseValues,
1332 caseBlocks, caseOperandRefs);
1333 mapNoResultOp(inst, switchOp);
1334 return success();
1336 if (inst->getOpcode() == llvm::Instruction::PHI) {
1337 Type type = convertType(inst->getType());
1338 mapValue(inst, builder.getInsertionBlock()->addArgument(
1339 type, translateLoc(inst->getDebugLoc())));
1340 return success();
1342 if (inst->getOpcode() == llvm::Instruction::Call) {
1343 auto *callInst = cast<llvm::CallInst>(inst);
1345 SmallVector<Type> types;
1346 SmallVector<Value> operands;
1347 if (failed(convertCallTypeAndOperands(callInst, types, operands)))
1348 return failure();
1350 CallOp callOp;
1351 if (llvm::Function *callee = callInst->getCalledFunction()) {
1352 callOp = builder.create<CallOp>(
1353 loc, types, SymbolRefAttr::get(context, callee->getName()), operands);
1354 } else {
1355 callOp = builder.create<CallOp>(loc, types, operands);
1357 setFastmathFlagsAttr(inst, callOp);
1358 if (!callInst->getType()->isVoidTy())
1359 mapValue(inst, callOp.getResult());
1360 else
1361 mapNoResultOp(inst, callOp);
1362 return success();
1364 if (inst->getOpcode() == llvm::Instruction::LandingPad) {
1365 auto *lpInst = cast<llvm::LandingPadInst>(inst);
1367 SmallVector<Value> operands;
1368 operands.reserve(lpInst->getNumClauses());
1369 for (auto i : llvm::seq<unsigned>(0, lpInst->getNumClauses())) {
1370 FailureOr<Value> operand = convertValue(lpInst->getClause(i));
1371 if (failed(operand))
1372 return failure();
1373 operands.push_back(*operand);
1376 Type type = convertType(lpInst->getType());
1377 auto lpOp =
1378 builder.create<LandingpadOp>(loc, type, lpInst->isCleanup(), operands);
1379 mapValue(inst, lpOp);
1380 return success();
1382 if (inst->getOpcode() == llvm::Instruction::Invoke) {
1383 auto *invokeInst = cast<llvm::InvokeInst>(inst);
1385 SmallVector<Type> types;
1386 SmallVector<Value> operands;
1387 if (failed(convertCallTypeAndOperands(invokeInst, types, operands)))
1388 return failure();
1390 // Check whether the invoke result is an argument to the normal destination
1391 // block.
1392 bool invokeResultUsedInPhi = llvm::any_of(
1393 invokeInst->getNormalDest()->phis(), [&](const llvm::PHINode &phi) {
1394 return phi.getIncomingValueForBlock(invokeInst->getParent()) ==
1395 invokeInst;
1398 Block *normalDest = lookupBlock(invokeInst->getNormalDest());
1399 Block *directNormalDest = normalDest;
1400 if (invokeResultUsedInPhi) {
1401 // The invoke result cannot be an argument to the normal destination
1402 // block, as that would imply using the invoke operation result in its
1403 // definition, so we need to create a dummy block to serve as an
1404 // intermediate destination.
1405 OpBuilder::InsertionGuard g(builder);
1406 directNormalDest = builder.createBlock(normalDest);
1409 SmallVector<Value> unwindArgs;
1410 if (failed(convertBranchArgs(invokeInst, invokeInst->getUnwindDest(),
1411 unwindArgs)))
1412 return failure();
1414 // Create the invoke operation. Normal destination block arguments will be
1415 // added later on to handle the case in which the operation result is
1416 // included in this list.
1417 InvokeOp invokeOp;
1418 if (llvm::Function *callee = invokeInst->getCalledFunction()) {
1419 invokeOp = builder.create<InvokeOp>(
1420 loc, types,
1421 SymbolRefAttr::get(builder.getContext(), callee->getName()), operands,
1422 directNormalDest, ValueRange(),
1423 lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
1424 } else {
1425 invokeOp = builder.create<InvokeOp>(
1426 loc, types, operands, directNormalDest, ValueRange(),
1427 lookupBlock(invokeInst->getUnwindDest()), unwindArgs);
1429 if (!invokeInst->getType()->isVoidTy())
1430 mapValue(inst, invokeOp.getResults().front());
1431 else
1432 mapNoResultOp(inst, invokeOp);
1434 SmallVector<Value> normalArgs;
1435 if (failed(convertBranchArgs(invokeInst, invokeInst->getNormalDest(),
1436 normalArgs)))
1437 return failure();
1439 if (invokeResultUsedInPhi) {
1440 // The dummy normal dest block will just host an unconditional branch
1441 // instruction to the normal destination block passing the required block
1442 // arguments (including the invoke operation's result).
1443 OpBuilder::InsertionGuard g(builder);
1444 builder.setInsertionPointToStart(directNormalDest);
1445 builder.create<LLVM::BrOp>(loc, normalArgs, normalDest);
1446 } else {
1447 // If the invoke operation's result is not a block argument to the normal
1448 // destination block, just add the block arguments as usual.
1449 assert(llvm::none_of(
1450 normalArgs,
1451 [&](Value val) { return val.getDefiningOp() == invokeOp; }) &&
1452 "An llvm.invoke operation cannot pass its result as a block "
1453 "argument.");
1454 invokeOp.getNormalDestOperandsMutable().append(normalArgs);
1457 return success();
1459 if (inst->getOpcode() == llvm::Instruction::GetElementPtr) {
1460 auto *gepInst = cast<llvm::GetElementPtrInst>(inst);
1461 Type sourceElementType = convertType(gepInst->getSourceElementType());
1462 FailureOr<Value> basePtr = convertValue(gepInst->getOperand(0));
1463 if (failed(basePtr))
1464 return failure();
1466 // Treat every indices as dynamic since GEPOp::build will refine those
1467 // indices into static attributes later. One small downside of this
1468 // approach is that many unused `llvm.mlir.constant` would be emitted
1469 // at first place.
1470 SmallVector<GEPArg> indices;
1471 for (llvm::Value *operand : llvm::drop_begin(gepInst->operand_values())) {
1472 FailureOr<Value> index = convertValue(operand);
1473 if (failed(index))
1474 return failure();
1475 indices.push_back(*index);
1478 Type type = convertType(inst->getType());
1479 auto gepOp = builder.create<GEPOp>(loc, type, sourceElementType, *basePtr,
1480 indices, gepInst->isInBounds());
1481 mapValue(inst, gepOp);
1482 return success();
1485 // Convert all instructions that have an mlirBuilder.
1486 if (succeeded(convertInstructionImpl(builder, inst, *this)))
1487 return success();
1489 return emitError(loc) << "unhandled instruction: " << diag(*inst);
1492 LogicalResult ModuleImport::processInstruction(llvm::Instruction *inst) {
1493 // FIXME: Support uses of SubtargetData.
1494 // FIXME: Add support for call / operand attributes.
1495 // FIXME: Add support for the indirectbr, cleanupret, catchret, catchswitch,
1496 // callbr, vaarg, catchpad, cleanuppad instructions.
1498 // Convert LLVM intrinsics calls to MLIR intrinsics.
1499 if (auto *intrinsic = dyn_cast<llvm::IntrinsicInst>(inst))
1500 return convertIntrinsic(intrinsic);
1502 // Convert all remaining LLVM instructions to MLIR operations.
1503 return convertInstruction(inst);
1506 FlatSymbolRefAttr ModuleImport::getPersonalityAsAttr(llvm::Function *f) {
1507 if (!f->hasPersonalityFn())
1508 return nullptr;
1510 llvm::Constant *pf = f->getPersonalityFn();
1512 // If it directly has a name, we can use it.
1513 if (pf->hasName())
1514 return SymbolRefAttr::get(builder.getContext(), pf->getName());
1516 // If it doesn't have a name, currently, only function pointers that are
1517 // bitcast to i8* are parsed.
1518 if (auto *ce = dyn_cast<llvm::ConstantExpr>(pf)) {
1519 if (ce->getOpcode() == llvm::Instruction::BitCast &&
1520 ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) {
1521 if (auto *func = dyn_cast<llvm::Function>(ce->getOperand(0)))
1522 return SymbolRefAttr::get(builder.getContext(), func->getName());
1525 return FlatSymbolRefAttr();
1528 static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) {
1529 llvm::MemoryEffects memEffects = func->getMemoryEffects();
1531 auto othermem = convertModRefInfoFromLLVM(
1532 memEffects.getModRef(llvm::MemoryEffects::Location::Other));
1533 auto argMem = convertModRefInfoFromLLVM(
1534 memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
1535 auto inaccessibleMem = convertModRefInfoFromLLVM(
1536 memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
1537 auto memAttr = MemoryEffectsAttr::get(funcOp.getContext(), othermem, argMem,
1538 inaccessibleMem);
1539 // Only set the attr when it does not match the default value.
1540 if (memAttr.isReadWrite())
1541 return;
1542 funcOp.setMemoryAttr(memAttr);
1545 static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) {
1546 MLIRContext *context = funcOp.getContext();
1547 SmallVector<Attribute> passthroughs;
1548 llvm::AttributeSet funcAttrs = func->getAttributes().getAttributes(
1549 llvm::AttributeList::AttrIndex::FunctionIndex);
1550 for (llvm::Attribute attr : funcAttrs) {
1551 // Skip the memory attribute since the LLVMFuncOp has an explicit memory
1552 // attribute.
1553 if (attr.hasAttribute(llvm::Attribute::Memory))
1554 continue;
1556 // Skip invalid type attributes.
1557 if (attr.isTypeAttribute()) {
1558 emitWarning(funcOp.getLoc(),
1559 "type attributes on a function are invalid, skipping it");
1560 continue;
1563 StringRef attrName;
1564 if (attr.isStringAttribute())
1565 attrName = attr.getKindAsString();
1566 else
1567 attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum());
1568 auto keyAttr = StringAttr::get(context, attrName);
1570 // Skip the aarch64_pstate_sm_<body|enabled> since the LLVMFuncOp has an
1571 // explicit attribute.
1572 if (attrName == "aarch64_pstate_sm_enabled" ||
1573 attrName == "aarch64_pstate_sm_body")
1574 continue;
1576 if (attr.isStringAttribute()) {
1577 StringRef val = attr.getValueAsString();
1578 if (val.empty()) {
1579 passthroughs.push_back(keyAttr);
1580 continue;
1582 passthroughs.push_back(
1583 ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
1584 continue;
1586 if (attr.isIntAttribute()) {
1587 auto val = std::to_string(attr.getValueAsInt());
1588 passthroughs.push_back(
1589 ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)}));
1590 continue;
1592 if (attr.isEnumAttribute()) {
1593 passthroughs.push_back(keyAttr);
1594 continue;
1597 llvm_unreachable("unexpected attribute kind");
1600 if (!passthroughs.empty())
1601 funcOp.setPassthroughAttr(ArrayAttr::get(context, passthroughs));
1604 void ModuleImport::processFunctionAttributes(llvm::Function *func,
1605 LLVMFuncOp funcOp) {
1606 processMemoryEffects(func, funcOp);
1607 processPassthroughAttrs(func, funcOp);
1609 if (func->hasFnAttribute("aarch64_pstate_sm_enabled"))
1610 funcOp.setArmStreaming(true);
1611 else if (func->hasFnAttribute("aarch64_pstate_sm_body"))
1612 funcOp.setArmLocallyStreaming(true);
1615 DictionaryAttr
1616 ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
1617 OpBuilder &builder) {
1618 SmallVector<NamedAttribute> paramAttrs;
1619 for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
1620 auto llvmAttr = llvmParamAttrs.getAttribute(llvmKind);
1621 // Skip attributes that are not attached.
1622 if (!llvmAttr.isValid())
1623 continue;
1624 Attribute mlirAttr;
1625 if (llvmAttr.isTypeAttribute())
1626 mlirAttr = TypeAttr::get(convertType(llvmAttr.getValueAsType()));
1627 else if (llvmAttr.isIntAttribute())
1628 mlirAttr = builder.getI64IntegerAttr(llvmAttr.getValueAsInt());
1629 else if (llvmAttr.isEnumAttribute())
1630 mlirAttr = builder.getUnitAttr();
1631 else
1632 llvm_unreachable("unexpected parameter attribute kind");
1633 paramAttrs.push_back(builder.getNamedAttr(mlirName, mlirAttr));
1636 return builder.getDictionaryAttr(paramAttrs);
1639 void ModuleImport::convertParameterAttributes(llvm::Function *func,
1640 LLVMFuncOp funcOp,
1641 OpBuilder &builder) {
1642 auto llvmAttrs = func->getAttributes();
1643 for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
1644 llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(i);
1645 funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder));
1647 // Convert the result attributes and attach them wrapped in an ArrayAttribute
1648 // to the funcOp.
1649 llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
1650 if (!llvmResAttr.hasAttributes())
1651 return;
1652 funcOp.setResAttrsAttr(
1653 builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
1656 LogicalResult ModuleImport::processFunction(llvm::Function *func) {
1657 clearRegionState();
1659 auto functionType =
1660 dyn_cast<LLVMFunctionType>(convertType(func->getFunctionType()));
1661 if (func->isIntrinsic() &&
1662 iface.isConvertibleIntrinsic(func->getIntrinsicID()))
1663 return success();
1665 bool dsoLocal = func->hasLocalLinkage();
1666 CConv cconv = convertCConvFromLLVM(func->getCallingConv());
1668 // Insert the function at the end of the module.
1669 OpBuilder::InsertionGuard guard(builder);
1670 builder.setInsertionPoint(mlirModule.getBody(), mlirModule.getBody()->end());
1672 Location loc = debugImporter->translateFuncLocation(func);
1673 LLVMFuncOp funcOp = builder.create<LLVMFuncOp>(
1674 loc, func->getName(), functionType,
1675 convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv);
1677 convertParameterAttributes(func, funcOp, builder);
1679 if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func))
1680 funcOp.setPersonalityAttr(personality);
1681 else if (func->hasPersonalityFn())
1682 emitWarning(funcOp.getLoc(), "could not deduce personality, skipping it");
1684 if (func->hasGC())
1685 funcOp.setGarbageCollector(StringRef(func->getGC()));
1687 if (func->hasAtLeastLocalUnnamedAddr())
1688 funcOp.setUnnamedAddr(convertUnnamedAddrFromLLVM(func->getUnnamedAddr()));
1690 if (func->hasSection())
1691 funcOp.setSection(StringRef(func->getSection()));
1693 funcOp.setVisibility_(convertVisibilityFromLLVM(func->getVisibility()));
1695 if (func->hasComdat())
1696 funcOp.setComdatAttr(comdatMapping.lookup(func->getComdat()));
1698 if (llvm::MaybeAlign maybeAlign = func->getAlign())
1699 funcOp.setAlignment(maybeAlign->value());
1701 // Handle Function attributes.
1702 processFunctionAttributes(func, funcOp);
1704 // Convert non-debug metadata by using the dialect interface.
1705 SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
1706 func->getAllMetadata(allMetadata);
1707 for (auto &[kind, node] : allMetadata) {
1708 if (!iface.isConvertibleMetadata(kind))
1709 continue;
1710 if (failed(iface.setMetadataAttrs(builder, kind, node, funcOp, *this))) {
1711 emitWarning(funcOp.getLoc())
1712 << "unhandled function metadata: " << diagMD(node, llvmModule.get())
1713 << " on " << diag(*func);
1717 if (func->isDeclaration())
1718 return success();
1720 // Eagerly create all blocks.
1721 for (llvm::BasicBlock &bb : *func) {
1722 Block *block =
1723 builder.createBlock(&funcOp.getBody(), funcOp.getBody().end());
1724 mapBlock(&bb, block);
1727 // Add function arguments to the entry block.
1728 for (const auto &it : llvm::enumerate(func->args())) {
1729 BlockArgument blockArg = funcOp.getFunctionBody().addArgument(
1730 functionType.getParamType(it.index()), funcOp.getLoc());
1731 mapValue(&it.value(), blockArg);
1734 // Process the blocks in topological order. The ordered traversal ensures
1735 // operands defined in a dominating block have a valid mapping to an MLIR
1736 // value once a block is translated.
1737 SetVector<llvm::BasicBlock *> blocks = getTopologicallySortedBlocks(func);
1738 setConstantInsertionPointToStart(lookupBlock(blocks.front()));
1739 for (llvm::BasicBlock *bb : blocks)
1740 if (failed(processBasicBlock(bb, lookupBlock(bb))))
1741 return failure();
1743 // Process the debug intrinsics that require a delayed conversion after
1744 // everything else was converted.
1745 if (failed(processDebugIntrinsics()))
1746 return failure();
1748 return success();
1751 /// Checks if `dbgIntr` is a kill location that holds metadata instead of an SSA
1752 /// value.
1753 static bool isMetadataKillLocation(llvm::DbgVariableIntrinsic *dbgIntr) {
1754 if (!dbgIntr->isKillLocation())
1755 return false;
1756 llvm::Value *value = dbgIntr->getArgOperand(0);
1757 auto *nodeAsVal = dyn_cast<llvm::MetadataAsValue>(value);
1758 if (!nodeAsVal)
1759 return false;
1760 return !isa<llvm::ValueAsMetadata>(nodeAsVal->getMetadata());
1763 LogicalResult
1764 ModuleImport::processDebugIntrinsic(llvm::DbgVariableIntrinsic *dbgIntr,
1765 DominanceInfo &domInfo) {
1766 Location loc = translateLoc(dbgIntr->getDebugLoc());
1767 auto emitUnsupportedWarning = [&]() {
1768 if (emitExpensiveWarnings)
1769 emitWarning(loc) << "dropped intrinsic: " << diag(*dbgIntr);
1770 return success();
1772 // Drop debug intrinsics with a non-empty debug expression.
1773 // TODO: Support debug intrinsics that evaluate a debug expression.
1774 if (dbgIntr->hasArgList() || dbgIntr->getExpression()->getNumElements() != 0)
1775 return emitUnsupportedWarning();
1776 // Kill locations can have metadata nodes as location operand. This
1777 // cannot be converted to poison as the type cannot be reconstructed.
1778 // TODO: find a way to support this case.
1779 if (isMetadataKillLocation(dbgIntr))
1780 return emitUnsupportedWarning();
1781 FailureOr<Value> argOperand = convertMetadataValue(dbgIntr->getArgOperand(0));
1782 if (failed(argOperand))
1783 return emitError(loc) << "failed to convert a debug intrinsic operand: "
1784 << diag(*dbgIntr);
1786 // Ensure that the debug instrinsic is inserted right after its operand is
1787 // defined. Otherwise, the operand might not necessarily dominate the
1788 // intrinsic. If the defining operation is a terminator, insert the intrinsic
1789 // into a dominated block.
1790 OpBuilder::InsertionGuard guard(builder);
1791 if (Operation *op = argOperand->getDefiningOp();
1792 op && op->hasTrait<OpTrait::IsTerminator>()) {
1793 // Find a dominated block that can hold the debug intrinsic.
1794 auto dominatedBlocks = domInfo.getNode(op->getBlock())->children();
1795 // If no block is dominated by the terminator, this intrinisc cannot be
1796 // converted.
1797 if (dominatedBlocks.empty())
1798 return emitUnsupportedWarning();
1799 // Set insertion point before the terminator, to avoid inserting something
1800 // before landingpads.
1801 Block *dominatedBlock = (*dominatedBlocks.begin())->getBlock();
1802 builder.setInsertionPoint(dominatedBlock->getTerminator());
1803 } else {
1804 builder.setInsertionPointAfterValue(*argOperand);
1806 DILocalVariableAttr localVariableAttr =
1807 matchLocalVariableAttr(dbgIntr->getArgOperand(1));
1808 Operation *op =
1809 llvm::TypeSwitch<llvm::DbgVariableIntrinsic *, Operation *>(dbgIntr)
1810 .Case([&](llvm::DbgDeclareInst *) {
1811 return builder.create<LLVM::DbgDeclareOp>(loc, *argOperand,
1812 localVariableAttr);
1814 .Case([&](llvm::DbgValueInst *) {
1815 return builder.create<LLVM::DbgValueOp>(loc, *argOperand,
1816 localVariableAttr);
1818 mapNoResultOp(dbgIntr, op);
1819 setNonDebugMetadataAttrs(dbgIntr, op);
1820 return success();
1823 LogicalResult ModuleImport::processDebugIntrinsics() {
1824 DominanceInfo domInfo;
1825 for (llvm::Instruction *inst : debugIntrinsics) {
1826 auto *intrCall = cast<llvm::DbgVariableIntrinsic>(inst);
1827 if (failed(processDebugIntrinsic(intrCall, domInfo)))
1828 return failure();
1830 return success();
1833 LogicalResult ModuleImport::processBasicBlock(llvm::BasicBlock *bb,
1834 Block *block) {
1835 builder.setInsertionPointToStart(block);
1836 for (llvm::Instruction &inst : *bb) {
1837 if (failed(processInstruction(&inst)))
1838 return failure();
1840 // Skip additional processing when the instructions is a debug intrinsics
1841 // that was not yet converted.
1842 if (debugIntrinsics.contains(&inst))
1843 continue;
1845 // Set the non-debug metadata attributes on the imported operation and emit
1846 // a warning if an instruction other than a phi instruction is dropped
1847 // during the import.
1848 if (Operation *op = lookupOperation(&inst)) {
1849 setNonDebugMetadataAttrs(&inst, op);
1850 } else if (inst.getOpcode() != llvm::Instruction::PHI) {
1851 if (emitExpensiveWarnings) {
1852 Location loc = debugImporter->translateLoc(inst.getDebugLoc());
1853 emitWarning(loc) << "dropped instruction: " << diag(inst);
1857 return success();
1860 FailureOr<SmallVector<AccessGroupAttr>>
1861 ModuleImport::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
1862 return loopAnnotationImporter->lookupAccessGroupAttrs(node);
1865 LoopAnnotationAttr
1866 ModuleImport::translateLoopAnnotationAttr(const llvm::MDNode *node,
1867 Location loc) const {
1868 return loopAnnotationImporter->translateLoopAnnotation(node, loc);
1871 OwningOpRef<ModuleOp>
1872 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
1873 MLIRContext *context,
1874 bool emitExpensiveWarnings) {
1875 // Preload all registered dialects to allow the import to iterate the
1876 // registered LLVMImportDialectInterface implementations and query the
1877 // supported LLVM IR constructs before starting the translation. Assumes the
1878 // LLVM and DLTI dialects that convert the core LLVM IR constructs have been
1879 // registered before.
1880 assert(llvm::is_contained(context->getAvailableDialects(),
1881 LLVMDialect::getDialectNamespace()));
1882 assert(llvm::is_contained(context->getAvailableDialects(),
1883 DLTIDialect::getDialectNamespace()));
1884 context->loadAllAvailableDialects();
1885 OwningOpRef<ModuleOp> module(ModuleOp::create(FileLineColLoc::get(
1886 StringAttr::get(context, llvmModule->getSourceFileName()), /*line=*/0,
1887 /*column=*/0)));
1889 ModuleImport moduleImport(module.get(), std::move(llvmModule),
1890 emitExpensiveWarnings);
1891 if (failed(moduleImport.initializeImportInterface()))
1892 return {};
1893 if (failed(moduleImport.convertDataLayout()))
1894 return {};
1895 if (failed(moduleImport.convertComdats()))
1896 return {};
1897 if (failed(moduleImport.convertMetadata()))
1898 return {};
1899 if (failed(moduleImport.convertGlobals()))
1900 return {};
1901 if (failed(moduleImport.convertFunctions()))
1902 return {};
1904 return module;