[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / mlir / lib / Target / LLVMIR / DebugImporter.cpp
blob1ab55b079b529456a99f0497d36d6bef41b2ef20
1 //===- DebugImporter.cpp - LLVM to MLIR Debug conversion ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "DebugImporter.h"
10 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
11 #include "mlir/IR/Attributes.h"
12 #include "mlir/IR/BuiltinAttributes.h"
13 #include "mlir/IR/Location.h"
14 #include "llvm/ADT/STLExtras.h"
15 #include "llvm/ADT/ScopeExit.h"
16 #include "llvm/ADT/SetOperations.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/BinaryFormat/Dwarf.h"
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/DebugInfoMetadata.h"
21 #include "llvm/IR/Metadata.h"
22 #include "llvm/Support/Casting.h"
23 #include "llvm/Support/ErrorHandling.h"
25 using namespace mlir;
26 using namespace mlir::LLVM;
27 using namespace mlir::LLVM::detail;
29 DebugImporter::DebugImporter(ModuleOp mlirModule,
30 bool dropDICompositeTypeElements)
31 : recursionPruner(mlirModule.getContext()),
32 context(mlirModule.getContext()), mlirModule(mlirModule),
33 dropDICompositeTypeElements(dropDICompositeTypeElements) {}
35 Location DebugImporter::translateFuncLocation(llvm::Function *func) {
36 llvm::DISubprogram *subprogram = func->getSubprogram();
37 if (!subprogram)
38 return UnknownLoc::get(context);
40 // Add a fused location to link the subprogram information.
41 StringAttr funcName = StringAttr::get(context, subprogram->getName());
42 StringAttr fileName = StringAttr::get(context, subprogram->getFilename());
43 return FusedLocWith<DISubprogramAttr>::get(
44 {NameLoc::get(funcName),
45 FileLineColLoc::get(fileName, subprogram->getLine(), /*column=*/0)},
46 translate(subprogram), context);
49 //===----------------------------------------------------------------------===//
50 // Attributes
51 //===----------------------------------------------------------------------===//
53 DIBasicTypeAttr DebugImporter::translateImpl(llvm::DIBasicType *node) {
54 return DIBasicTypeAttr::get(context, node->getTag(), node->getName(),
55 node->getSizeInBits(), node->getEncoding());
58 DICompileUnitAttr DebugImporter::translateImpl(llvm::DICompileUnit *node) {
59 std::optional<DIEmissionKind> emissionKind =
60 symbolizeDIEmissionKind(node->getEmissionKind());
61 std::optional<DINameTableKind> nameTableKind = symbolizeDINameTableKind(
62 static_cast<
63 std::underlying_type_t<llvm::DICompileUnit::DebugNameTableKind>>(
64 node->getNameTableKind()));
65 return DICompileUnitAttr::get(
66 context, getOrCreateDistinctID(node), node->getSourceLanguage(),
67 translate(node->getFile()), getStringAttrOrNull(node->getRawProducer()),
68 node->isOptimized(), emissionKind.value(), nameTableKind.value());
71 DICompositeTypeAttr DebugImporter::translateImpl(llvm::DICompositeType *node) {
72 std::optional<DIFlags> flags = symbolizeDIFlags(node->getFlags());
73 SmallVector<DINodeAttr> elements;
75 // A vector always requires an element.
76 bool isVectorType = flags && bitEnumContainsAll(*flags, DIFlags::Vector);
77 if (isVectorType || !dropDICompositeTypeElements) {
78 for (llvm::DINode *element : node->getElements()) {
79 assert(element && "expected a non-null element type");
80 elements.push_back(translate(element));
83 // Drop the elements parameter if any of the elements are invalid.
84 if (llvm::is_contained(elements, nullptr))
85 elements.clear();
86 DITypeAttr baseType = translate(node->getBaseType());
87 // Arrays require a base type, otherwise the debug metadata is considered to
88 // be malformed.
89 if (node->getTag() == llvm::dwarf::DW_TAG_array_type && !baseType)
90 return nullptr;
91 return DICompositeTypeAttr::get(
92 context, node->getTag(), /*recId=*/{},
93 getStringAttrOrNull(node->getRawName()), translate(node->getFile()),
94 node->getLine(), translate(node->getScope()), baseType,
95 flags.value_or(DIFlags::Zero), node->getSizeInBits(),
96 node->getAlignInBits(), elements);
99 DIDerivedTypeAttr DebugImporter::translateImpl(llvm::DIDerivedType *node) {
100 // Return nullptr if the base type is invalid.
101 DITypeAttr baseType = translate(node->getBaseType());
102 if (node->getBaseType() && !baseType)
103 return nullptr;
104 DINodeAttr extraData =
105 translate(dyn_cast_or_null<llvm::DINode>(node->getExtraData()));
106 return DIDerivedTypeAttr::get(
107 context, node->getTag(), getStringAttrOrNull(node->getRawName()),
108 baseType, node->getSizeInBits(), node->getAlignInBits(),
109 node->getOffsetInBits(), extraData);
112 DIFileAttr DebugImporter::translateImpl(llvm::DIFile *node) {
113 return DIFileAttr::get(context, node->getFilename(), node->getDirectory());
116 DILabelAttr DebugImporter::translateImpl(llvm::DILabel *node) {
117 // Return nullptr if the scope or type is a cyclic dependency.
118 DIScopeAttr scope = translate(node->getScope());
119 if (node->getScope() && !scope)
120 return nullptr;
121 return DILabelAttr::get(context, scope,
122 getStringAttrOrNull(node->getRawName()),
123 translate(node->getFile()), node->getLine());
126 DILexicalBlockAttr DebugImporter::translateImpl(llvm::DILexicalBlock *node) {
127 // Return nullptr if the scope or type is a cyclic dependency.
128 DIScopeAttr scope = translate(node->getScope());
129 if (node->getScope() && !scope)
130 return nullptr;
131 return DILexicalBlockAttr::get(context, scope, translate(node->getFile()),
132 node->getLine(), node->getColumn());
135 DILexicalBlockFileAttr
136 DebugImporter::translateImpl(llvm::DILexicalBlockFile *node) {
137 // Return nullptr if the scope or type is a cyclic dependency.
138 DIScopeAttr scope = translate(node->getScope());
139 if (node->getScope() && !scope)
140 return nullptr;
141 return DILexicalBlockFileAttr::get(context, scope, translate(node->getFile()),
142 node->getDiscriminator());
145 DIGlobalVariableAttr
146 DebugImporter::translateImpl(llvm::DIGlobalVariable *node) {
147 // Names of DIGlobalVariables can be empty. MLIR models them as null, instead
148 // of empty strings, so this special handling is necessary.
149 auto convertToStringAttr = [&](StringRef name) -> StringAttr {
150 if (name.empty())
151 return {};
152 return StringAttr::get(context, node->getName());
154 return DIGlobalVariableAttr::get(
155 context, translate(node->getScope()),
156 convertToStringAttr(node->getName()),
157 convertToStringAttr(node->getLinkageName()), translate(node->getFile()),
158 node->getLine(), translate(node->getType()), node->isLocalToUnit(),
159 node->isDefinition(), node->getAlignInBits());
162 DILocalVariableAttr DebugImporter::translateImpl(llvm::DILocalVariable *node) {
163 // Return nullptr if the scope or type is a cyclic dependency.
164 DIScopeAttr scope = translate(node->getScope());
165 if (node->getScope() && !scope)
166 return nullptr;
167 return DILocalVariableAttr::get(
168 context, scope, getStringAttrOrNull(node->getRawName()),
169 translate(node->getFile()), node->getLine(), node->getArg(),
170 node->getAlignInBits(), translate(node->getType()));
173 DIScopeAttr DebugImporter::translateImpl(llvm::DIScope *node) {
174 return cast<DIScopeAttr>(translate(static_cast<llvm::DINode *>(node)));
177 DIModuleAttr DebugImporter::translateImpl(llvm::DIModule *node) {
178 return DIModuleAttr::get(
179 context, translate(node->getFile()), translate(node->getScope()),
180 getStringAttrOrNull(node->getRawName()),
181 getStringAttrOrNull(node->getRawConfigurationMacros()),
182 getStringAttrOrNull(node->getRawIncludePath()),
183 getStringAttrOrNull(node->getRawAPINotesFile()), node->getLineNo(),
184 node->getIsDecl());
187 DINamespaceAttr DebugImporter::translateImpl(llvm::DINamespace *node) {
188 return DINamespaceAttr::get(context, getStringAttrOrNull(node->getRawName()),
189 translate(node->getScope()),
190 node->getExportSymbols());
193 DISubprogramAttr DebugImporter::translateImpl(llvm::DISubprogram *node) {
194 // Only definitions require a distinct identifier.
195 mlir::DistinctAttr id;
196 if (node->isDistinct())
197 id = getOrCreateDistinctID(node);
198 // Return nullptr if the scope or type is invalid.
199 DIScopeAttr scope = translate(node->getScope());
200 if (node->getScope() && !scope)
201 return nullptr;
202 std::optional<DISubprogramFlags> subprogramFlags =
203 symbolizeDISubprogramFlags(node->getSubprogram()->getSPFlags());
204 assert(subprogramFlags && "expected valid subprogram flags");
205 DISubroutineTypeAttr type = translate(node->getType());
206 if (node->getType() && !type)
207 return nullptr;
208 return DISubprogramAttr::get(context, id, translate(node->getUnit()), scope,
209 getStringAttrOrNull(node->getRawName()),
210 getStringAttrOrNull(node->getRawLinkageName()),
211 translate(node->getFile()), node->getLine(),
212 node->getScopeLine(), *subprogramFlags, type);
215 DISubrangeAttr DebugImporter::translateImpl(llvm::DISubrange *node) {
216 auto getIntegerAttrOrNull = [&](llvm::DISubrange::BoundType data) {
217 if (auto *constInt = llvm::dyn_cast_or_null<llvm::ConstantInt *>(data))
218 return IntegerAttr::get(IntegerType::get(context, 64),
219 constInt->getSExtValue());
220 return IntegerAttr();
222 IntegerAttr count = getIntegerAttrOrNull(node->getCount());
223 IntegerAttr upperBound = getIntegerAttrOrNull(node->getUpperBound());
224 // Either count or the upper bound needs to be present. Otherwise, the
225 // metadata is invalid. The conversion might fail due to unsupported DI nodes.
226 if (!count && !upperBound)
227 return {};
228 return DISubrangeAttr::get(
229 context, count, getIntegerAttrOrNull(node->getLowerBound()), upperBound,
230 getIntegerAttrOrNull(node->getStride()));
233 DISubroutineTypeAttr
234 DebugImporter::translateImpl(llvm::DISubroutineType *node) {
235 SmallVector<DITypeAttr> types;
236 for (llvm::DIType *type : node->getTypeArray()) {
237 if (!type) {
238 // A nullptr entry may appear at the beginning or the end of the
239 // subroutine types list modeling either a void result type or the type of
240 // a variadic argument. Translate the nullptr to an explicit
241 // DINullTypeAttr since the attribute list cannot contain a nullptr entry.
242 types.push_back(DINullTypeAttr::get(context));
243 continue;
245 types.push_back(translate(type));
247 // Return nullptr if any of the types is invalid.
248 if (llvm::is_contained(types, nullptr))
249 return nullptr;
250 return DISubroutineTypeAttr::get(context, node->getCC(), types);
253 DITypeAttr DebugImporter::translateImpl(llvm::DIType *node) {
254 return cast<DITypeAttr>(translate(static_cast<llvm::DINode *>(node)));
257 DINodeAttr DebugImporter::translate(llvm::DINode *node) {
258 if (!node)
259 return nullptr;
261 // Check for a cached instance.
262 if (DINodeAttr attr = nodeToAttr.lookup(node))
263 return attr;
265 // Register with the recursive translator. If it can be handled without
266 // recursing into it, return the result immediately.
267 if (DINodeAttr attr = recursionPruner.pruneOrPushTranslationStack(node))
268 return attr;
270 auto guard = llvm::make_scope_exit(
271 [&]() { recursionPruner.popTranslationStack(node); });
273 // Convert the debug metadata if possible.
274 auto translateNode = [this](llvm::DINode *node) -> DINodeAttr {
275 if (auto *casted = dyn_cast<llvm::DIBasicType>(node))
276 return translateImpl(casted);
277 if (auto *casted = dyn_cast<llvm::DICompileUnit>(node))
278 return translateImpl(casted);
279 if (auto *casted = dyn_cast<llvm::DICompositeType>(node))
280 return translateImpl(casted);
281 if (auto *casted = dyn_cast<llvm::DIDerivedType>(node))
282 return translateImpl(casted);
283 if (auto *casted = dyn_cast<llvm::DIFile>(node))
284 return translateImpl(casted);
285 if (auto *casted = dyn_cast<llvm::DIGlobalVariable>(node))
286 return translateImpl(casted);
287 if (auto *casted = dyn_cast<llvm::DILabel>(node))
288 return translateImpl(casted);
289 if (auto *casted = dyn_cast<llvm::DILexicalBlock>(node))
290 return translateImpl(casted);
291 if (auto *casted = dyn_cast<llvm::DILexicalBlockFile>(node))
292 return translateImpl(casted);
293 if (auto *casted = dyn_cast<llvm::DILocalVariable>(node))
294 return translateImpl(casted);
295 if (auto *casted = dyn_cast<llvm::DIModule>(node))
296 return translateImpl(casted);
297 if (auto *casted = dyn_cast<llvm::DINamespace>(node))
298 return translateImpl(casted);
299 if (auto *casted = dyn_cast<llvm::DISubprogram>(node))
300 return translateImpl(casted);
301 if (auto *casted = dyn_cast<llvm::DISubrange>(node))
302 return translateImpl(casted);
303 if (auto *casted = dyn_cast<llvm::DISubroutineType>(node))
304 return translateImpl(casted);
305 return nullptr;
307 if (DINodeAttr attr = translateNode(node)) {
308 auto [result, isSelfContained] =
309 recursionPruner.finalizeTranslation(node, attr);
310 // Only cache fully self-contained nodes.
311 if (isSelfContained)
312 nodeToAttr.try_emplace(node, result);
313 return result;
315 return nullptr;
318 //===----------------------------------------------------------------------===//
319 // RecursionPruner
320 //===----------------------------------------------------------------------===//
322 /// Get the `getRecSelf` constructor for the translated type of `node` if its
323 /// translated DITypeAttr supports recursion. Otherwise, returns nullptr.
324 static function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>
325 getRecSelfConstructor(llvm::DINode *node) {
326 using CtorType = function_ref<DIRecursiveTypeAttrInterface(DistinctAttr)>;
327 return TypeSwitch<llvm::DINode *, CtorType>(node)
328 .Case([&](llvm::DICompositeType *) {
329 return CtorType(DICompositeTypeAttr::getRecSelf);
331 .Default(CtorType());
334 DINodeAttr DebugImporter::RecursionPruner::pruneOrPushTranslationStack(
335 llvm::DINode *node) {
336 // If the node type is capable of being recursive, check if it's seen
337 // before.
338 auto recSelfCtor = getRecSelfConstructor(node);
339 if (recSelfCtor) {
340 // If a cyclic dependency is detected since the same node is being
341 // traversed twice, emit a recursive self type, and mark the duplicate
342 // node on the translationStack so it can emit a recursive decl type.
343 auto [iter, inserted] = translationStack.try_emplace(node);
344 if (!inserted) {
345 // The original node may have already been assigned a recursive ID from
346 // a different self-reference. Use that if possible.
347 DIRecursiveTypeAttrInterface recSelf = iter->second.recSelf;
348 if (!recSelf) {
349 DistinctAttr recId = nodeToRecId.lookup(node);
350 if (!recId) {
351 recId = DistinctAttr::create(UnitAttr::get(context));
352 nodeToRecId[node] = recId;
354 recSelf = recSelfCtor(recId);
355 iter->second.recSelf = recSelf;
357 // Inject the self-ref into the previous layer.
358 translationStack.back().second.unboundSelfRefs.insert(recSelf);
359 return cast<DINodeAttr>(recSelf);
363 return lookup(node);
366 std::pair<DINodeAttr, bool>
367 DebugImporter::RecursionPruner::finalizeTranslation(llvm::DINode *node,
368 DINodeAttr result) {
369 // If `node` is not a potentially recursive type, it will not be on the
370 // translation stack. Nothing to set in this case.
371 if (translationStack.empty())
372 return {result, true};
373 if (translationStack.back().first != node)
374 return {result, translationStack.back().second.unboundSelfRefs.empty()};
376 TranslationState &state = translationStack.back().second;
378 // If this node is actually recursive, set the recId onto `result`.
379 if (DIRecursiveTypeAttrInterface recSelf = state.recSelf) {
380 auto recType = cast<DIRecursiveTypeAttrInterface>(result);
381 result = cast<DINodeAttr>(recType.withRecId(recSelf.getRecId()));
382 // Remove this recSelf from the set of unbound selfRefs.
383 state.unboundSelfRefs.erase(recSelf);
386 // Insert the result into our internal cache if it's not self-contained.
387 if (!state.unboundSelfRefs.empty()) {
388 [[maybe_unused]] auto [_, inserted] = dependentCache.try_emplace(
389 node, DependentTranslation{result, state.unboundSelfRefs});
390 assert(inserted && "invalid state: caching the same DINode twice");
391 return {result, false};
393 return {result, true};
396 void DebugImporter::RecursionPruner::popTranslationStack(llvm::DINode *node) {
397 // If `node` is not a potentially recursive type, it will not be on the
398 // translation stack. Nothing to handle in this case.
399 if (translationStack.empty() || translationStack.back().first != node)
400 return;
402 // At the end of the stack, all unbound self-refs must be resolved already,
403 // and the entire cache should be accounted for.
404 TranslationState &currLayerState = translationStack.back().second;
405 if (translationStack.size() == 1) {
406 assert(currLayerState.unboundSelfRefs.empty() &&
407 "internal error: unbound recursive self reference at top level.");
408 translationStack.pop_back();
409 return;
412 // Copy unboundSelfRefs down to the previous level.
413 TranslationState &nextLayerState = (++translationStack.rbegin())->second;
414 nextLayerState.unboundSelfRefs.insert(currLayerState.unboundSelfRefs.begin(),
415 currLayerState.unboundSelfRefs.end());
416 translationStack.pop_back();
419 DINodeAttr DebugImporter::RecursionPruner::lookup(llvm::DINode *node) {
420 auto cacheIter = dependentCache.find(node);
421 if (cacheIter == dependentCache.end())
422 return {};
424 DependentTranslation &entry = cacheIter->second;
425 if (llvm::set_is_subset(entry.unboundSelfRefs,
426 translationStack.back().second.unboundSelfRefs))
427 return entry.attr;
429 // Stale cache entry.
430 dependentCache.erase(cacheIter);
431 return {};
434 //===----------------------------------------------------------------------===//
435 // Locations
436 //===----------------------------------------------------------------------===//
438 Location DebugImporter::translateLoc(llvm::DILocation *loc) {
439 if (!loc)
440 return UnknownLoc::get(context);
442 // Get the file location of the instruction.
443 Location result = FileLineColLoc::get(context, loc->getFilename(),
444 loc->getLine(), loc->getColumn());
446 // Add scope information.
447 assert(loc->getScope() && "expected non-null scope");
448 result = FusedLocWith<DIScopeAttr>::get({result}, translate(loc->getScope()),
449 context);
451 // Add call site information, if available.
452 if (llvm::DILocation *inlinedAt = loc->getInlinedAt())
453 result = CallSiteLoc::get(result, translateLoc(inlinedAt));
455 return result;
458 DIExpressionAttr DebugImporter::translateExpression(llvm::DIExpression *node) {
459 SmallVector<DIExpressionElemAttr> ops;
461 // Begin processing the operations.
462 for (const llvm::DIExpression::ExprOperand &op : node->expr_ops()) {
463 SmallVector<uint64_t> operands;
464 operands.reserve(op.getNumArgs());
465 for (const auto &i : llvm::seq(op.getNumArgs()))
466 operands.push_back(op.getArg(i));
467 const auto attr = DIExpressionElemAttr::get(context, op.getOp(), operands);
468 ops.push_back(attr);
470 return DIExpressionAttr::get(context, ops);
473 DIGlobalVariableExpressionAttr DebugImporter::translateGlobalVariableExpression(
474 llvm::DIGlobalVariableExpression *node) {
475 return DIGlobalVariableExpressionAttr::get(
476 context, translate(node->getVariable()),
477 translateExpression(node->getExpression()));
480 StringAttr DebugImporter::getStringAttrOrNull(llvm::MDString *stringNode) {
481 if (!stringNode)
482 return StringAttr();
483 return StringAttr::get(context, stringNode->getString());
486 DistinctAttr DebugImporter::getOrCreateDistinctID(llvm::DINode *node) {
487 DistinctAttr &id = nodeToDistinctAttr[node];
488 if (!id)
489 id = DistinctAttr::create(UnitAttr::get(context));
490 return id;