[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / Target / LLVMIR / LoopAnnotationImporter.cpp
blobe4905423347a21764754450390f3eccd3bb448e4
1 //===- LoopAnnotationImporter.cpp - Loop annotation import ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "LoopAnnotationImporter.h"
10 #include "llvm/IR/Constants.h"
12 using namespace mlir;
13 using namespace mlir::LLVM;
14 using namespace mlir::LLVM::detail;
16 namespace {
17 /// Helper class that keeps the state of one metadata to attribute conversion.
18 struct LoopMetadataConversion {
19 LoopMetadataConversion(const llvm::MDNode *node, Location loc,
20 LoopAnnotationImporter &loopAnnotationImporter)
21 : node(node), loc(loc), loopAnnotationImporter(loopAnnotationImporter),
22 ctx(loc->getContext()){};
23 /// Converts this structs loop metadata node into a LoopAnnotationAttr.
24 LoopAnnotationAttr convert();
26 /// Initializes the shared state for the conversion member functions.
27 LogicalResult initConversionState();
29 /// Helper function to get and erase a property.
30 const llvm::MDNode *lookupAndEraseProperty(StringRef name);
32 /// Helper functions to lookup and convert MDNodes into a specifc attribute
33 /// kind. These functions return null-attributes if there is no node with the
34 /// specified name, or failure, if the node is ill-formatted.
35 FailureOr<BoolAttr> lookupUnitNode(StringRef name);
36 FailureOr<BoolAttr> lookupBoolNode(StringRef name, bool negated = false);
37 FailureOr<BoolAttr> lookupIntNodeAsBoolAttr(StringRef name);
38 FailureOr<IntegerAttr> lookupIntNode(StringRef name);
39 FailureOr<llvm::MDNode *> lookupMDNode(StringRef name);
40 FailureOr<SmallVector<llvm::MDNode *>> lookupMDNodes(StringRef name);
41 FailureOr<LoopAnnotationAttr> lookupFollowupNode(StringRef name);
42 FailureOr<BoolAttr> lookupBooleanUnitNode(StringRef enableName,
43 StringRef disableName,
44 bool negated = false);
46 /// Conversion functions for sub-attributes.
47 FailureOr<LoopVectorizeAttr> convertVectorizeAttr();
48 FailureOr<LoopInterleaveAttr> convertInterleaveAttr();
49 FailureOr<LoopUnrollAttr> convertUnrollAttr();
50 FailureOr<LoopUnrollAndJamAttr> convertUnrollAndJamAttr();
51 FailureOr<LoopLICMAttr> convertLICMAttr();
52 FailureOr<LoopDistributeAttr> convertDistributeAttr();
53 FailureOr<LoopPipelineAttr> convertPipelineAttr();
54 FailureOr<LoopPeeledAttr> convertPeeledAttr();
55 FailureOr<LoopUnswitchAttr> convertUnswitchAttr();
56 FailureOr<SmallVector<AccessGroupAttr>> convertParallelAccesses();
57 FusedLoc convertStartLoc();
58 FailureOr<FusedLoc> convertEndLoc();
60 llvm::SmallVector<llvm::DILocation *, 2> locations;
61 llvm::StringMap<const llvm::MDNode *> propertyMap;
62 const llvm::MDNode *node;
63 Location loc;
64 LoopAnnotationImporter &loopAnnotationImporter;
65 MLIRContext *ctx;
67 } // namespace
69 LogicalResult LoopMetadataConversion::initConversionState() {
70 // Check if it's a valid node.
71 if (node->getNumOperands() == 0 ||
72 dyn_cast<llvm::MDNode>(node->getOperand(0)) != node)
73 return emitWarning(loc) << "invalid loop node";
75 for (const llvm::MDOperand &operand : llvm::drop_begin(node->operands())) {
76 if (auto *diLoc = dyn_cast<llvm::DILocation>(operand)) {
77 locations.push_back(diLoc);
78 continue;
81 auto *property = dyn_cast<llvm::MDNode>(operand);
82 if (!property)
83 return emitWarning(loc) << "expected all loop properties to be either "
84 "debug locations or metadata nodes";
86 if (property->getNumOperands() == 0)
87 return emitWarning(loc) << "cannot import empty loop property";
89 auto *nameNode = dyn_cast<llvm::MDString>(property->getOperand(0));
90 if (!nameNode)
91 return emitWarning(loc) << "cannot import loop property without a name";
92 StringRef name = nameNode->getString();
94 bool succ = propertyMap.try_emplace(name, property).second;
95 if (!succ)
96 return emitWarning(loc)
97 << "cannot import loop properties with duplicated names " << name;
100 return success();
103 const llvm::MDNode *
104 LoopMetadataConversion::lookupAndEraseProperty(StringRef name) {
105 auto it = propertyMap.find(name);
106 if (it == propertyMap.end())
107 return nullptr;
108 const llvm::MDNode *property = it->getValue();
109 propertyMap.erase(it);
110 return property;
113 FailureOr<BoolAttr> LoopMetadataConversion::lookupUnitNode(StringRef name) {
114 const llvm::MDNode *property = lookupAndEraseProperty(name);
115 if (!property)
116 return BoolAttr(nullptr);
118 if (property->getNumOperands() != 1)
119 return emitWarning(loc)
120 << "expected metadata node " << name << " to hold no value";
122 return BoolAttr::get(ctx, true);
125 FailureOr<BoolAttr> LoopMetadataConversion::lookupBooleanUnitNode(
126 StringRef enableName, StringRef disableName, bool negated) {
127 auto enable = lookupUnitNode(enableName);
128 auto disable = lookupUnitNode(disableName);
129 if (failed(enable) || failed(disable))
130 return failure();
132 if (*enable && *disable)
133 return emitWarning(loc)
134 << "expected metadata nodes " << enableName << " and " << disableName
135 << " to be mutually exclusive.";
137 if (*enable)
138 return BoolAttr::get(ctx, !negated);
140 if (*disable)
141 return BoolAttr::get(ctx, negated);
142 return BoolAttr(nullptr);
145 FailureOr<BoolAttr> LoopMetadataConversion::lookupBoolNode(StringRef name,
146 bool negated) {
147 const llvm::MDNode *property = lookupAndEraseProperty(name);
148 if (!property)
149 return BoolAttr(nullptr);
151 auto emitNodeWarning = [&]() {
152 return emitWarning(loc)
153 << "expected metadata node " << name << " to hold a boolean value";
156 if (property->getNumOperands() != 2)
157 return emitNodeWarning();
158 llvm::ConstantInt *val =
159 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
160 if (!val || val->getBitWidth() != 1)
161 return emitNodeWarning();
163 return BoolAttr::get(ctx, val->getValue().getLimitedValue(1) ^ negated);
166 FailureOr<BoolAttr>
167 LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name) {
168 const llvm::MDNode *property = lookupAndEraseProperty(name);
169 if (!property)
170 return BoolAttr(nullptr);
172 auto emitNodeWarning = [&]() {
173 return emitWarning(loc)
174 << "expected metadata node " << name << " to hold an integer value";
177 if (property->getNumOperands() != 2)
178 return emitNodeWarning();
179 llvm::ConstantInt *val =
180 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
181 if (!val || val->getBitWidth() != 32)
182 return emitNodeWarning();
184 return BoolAttr::get(ctx, val->getValue().getLimitedValue(1));
187 FailureOr<IntegerAttr> LoopMetadataConversion::lookupIntNode(StringRef name) {
188 const llvm::MDNode *property = lookupAndEraseProperty(name);
189 if (!property)
190 return IntegerAttr(nullptr);
192 auto emitNodeWarning = [&]() {
193 return emitWarning(loc)
194 << "expected metadata node " << name << " to hold an i32 value";
197 if (property->getNumOperands() != 2)
198 return emitNodeWarning();
200 llvm::ConstantInt *val =
201 llvm::mdconst::dyn_extract<llvm::ConstantInt>(property->getOperand(1));
202 if (!val || val->getBitWidth() != 32)
203 return emitNodeWarning();
205 return IntegerAttr::get(IntegerType::get(ctx, 32),
206 val->getValue().getLimitedValue());
209 FailureOr<llvm::MDNode *> LoopMetadataConversion::lookupMDNode(StringRef name) {
210 const llvm::MDNode *property = lookupAndEraseProperty(name);
211 if (!property)
212 return nullptr;
214 auto emitNodeWarning = [&]() {
215 return emitWarning(loc)
216 << "expected metadata node " << name << " to hold an MDNode";
219 if (property->getNumOperands() != 2)
220 return emitNodeWarning();
222 auto *node = dyn_cast<llvm::MDNode>(property->getOperand(1));
223 if (!node)
224 return emitNodeWarning();
226 return node;
229 FailureOr<SmallVector<llvm::MDNode *>>
230 LoopMetadataConversion::lookupMDNodes(StringRef name) {
231 const llvm::MDNode *property = lookupAndEraseProperty(name);
232 SmallVector<llvm::MDNode *> res;
233 if (!property)
234 return res;
236 auto emitNodeWarning = [&]() {
237 return emitWarning(loc) << "expected metadata node " << name
238 << " to hold one or multiple MDNodes";
241 if (property->getNumOperands() < 2)
242 return emitNodeWarning();
244 for (unsigned i = 1, e = property->getNumOperands(); i < e; ++i) {
245 auto *node = dyn_cast<llvm::MDNode>(property->getOperand(i));
246 if (!node)
247 return emitNodeWarning();
248 res.push_back(node);
251 return res;
254 FailureOr<LoopAnnotationAttr>
255 LoopMetadataConversion::lookupFollowupNode(StringRef name) {
256 auto node = lookupMDNode(name);
257 if (failed(node))
258 return failure();
259 if (*node == nullptr)
260 return LoopAnnotationAttr(nullptr);
262 return loopAnnotationImporter.translateLoopAnnotation(*node, loc);
265 static bool isEmptyOrNull(const Attribute attr) { return !attr; }
267 template <typename T>
268 static bool isEmptyOrNull(const SmallVectorImpl<T> &vec) {
269 return vec.empty();
272 /// Helper function that only creates and attribute of type T if all argument
273 /// conversion were successfull and at least one of them holds a non-null value.
274 template <typename T, typename... P>
275 static T createIfNonNull(MLIRContext *ctx, const P &...args) {
276 bool anyFailed = (failed(args) || ...);
277 if (anyFailed)
278 return {};
280 bool allEmpty = (isEmptyOrNull(*args) && ...);
281 if (allEmpty)
282 return {};
284 return T::get(ctx, *args...);
287 FailureOr<LoopVectorizeAttr> LoopMetadataConversion::convertVectorizeAttr() {
288 FailureOr<BoolAttr> enable =
289 lookupBoolNode("llvm.loop.vectorize.enable", true);
290 FailureOr<BoolAttr> predicateEnable =
291 lookupBoolNode("llvm.loop.vectorize.predicate.enable");
292 FailureOr<BoolAttr> scalableEnable =
293 lookupBoolNode("llvm.loop.vectorize.scalable.enable");
294 FailureOr<IntegerAttr> width = lookupIntNode("llvm.loop.vectorize.width");
295 FailureOr<LoopAnnotationAttr> followupVec =
296 lookupFollowupNode("llvm.loop.vectorize.followup_vectorized");
297 FailureOr<LoopAnnotationAttr> followupEpi =
298 lookupFollowupNode("llvm.loop.vectorize.followup_epilogue");
299 FailureOr<LoopAnnotationAttr> followupAll =
300 lookupFollowupNode("llvm.loop.vectorize.followup_all");
302 return createIfNonNull<LoopVectorizeAttr>(ctx, enable, predicateEnable,
303 scalableEnable, width, followupVec,
304 followupEpi, followupAll);
307 FailureOr<LoopInterleaveAttr> LoopMetadataConversion::convertInterleaveAttr() {
308 FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.interleave.count");
309 return createIfNonNull<LoopInterleaveAttr>(ctx, count);
312 FailureOr<LoopUnrollAttr> LoopMetadataConversion::convertUnrollAttr() {
313 FailureOr<BoolAttr> disable = lookupBooleanUnitNode(
314 "llvm.loop.unroll.enable", "llvm.loop.unroll.disable", /*negated=*/true);
315 FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.unroll.count");
316 FailureOr<BoolAttr> runtimeDisable =
317 lookupUnitNode("llvm.loop.unroll.runtime.disable");
318 FailureOr<BoolAttr> full = lookupUnitNode("llvm.loop.unroll.full");
319 FailureOr<LoopAnnotationAttr> followupUnrolled =
320 lookupFollowupNode("llvm.loop.unroll.followup_unrolled");
321 FailureOr<LoopAnnotationAttr> followupRemainder =
322 lookupFollowupNode("llvm.loop.unroll.followup_remainder");
323 FailureOr<LoopAnnotationAttr> followupAll =
324 lookupFollowupNode("llvm.loop.unroll.followup_all");
326 return createIfNonNull<LoopUnrollAttr>(ctx, disable, count, runtimeDisable,
327 full, followupUnrolled,
328 followupRemainder, followupAll);
331 FailureOr<LoopUnrollAndJamAttr>
332 LoopMetadataConversion::convertUnrollAndJamAttr() {
333 FailureOr<BoolAttr> disable = lookupBooleanUnitNode(
334 "llvm.loop.unroll_and_jam.enable", "llvm.loop.unroll_and_jam.disable",
335 /*negated=*/true);
336 FailureOr<IntegerAttr> count =
337 lookupIntNode("llvm.loop.unroll_and_jam.count");
338 FailureOr<LoopAnnotationAttr> followupOuter =
339 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_outer");
340 FailureOr<LoopAnnotationAttr> followupInner =
341 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_inner");
342 FailureOr<LoopAnnotationAttr> followupRemainderOuter =
343 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer");
344 FailureOr<LoopAnnotationAttr> followupRemainderInner =
345 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner");
346 FailureOr<LoopAnnotationAttr> followupAll =
347 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_all");
348 return createIfNonNull<LoopUnrollAndJamAttr>(
349 ctx, disable, count, followupOuter, followupInner, followupRemainderOuter,
350 followupRemainderInner, followupAll);
353 FailureOr<LoopLICMAttr> LoopMetadataConversion::convertLICMAttr() {
354 FailureOr<BoolAttr> disable = lookupUnitNode("llvm.licm.disable");
355 FailureOr<BoolAttr> versioningDisable =
356 lookupUnitNode("llvm.loop.licm_versioning.disable");
357 return createIfNonNull<LoopLICMAttr>(ctx, disable, versioningDisable);
360 FailureOr<LoopDistributeAttr> LoopMetadataConversion::convertDistributeAttr() {
361 FailureOr<BoolAttr> disable =
362 lookupBoolNode("llvm.loop.distribute.enable", true);
363 FailureOr<LoopAnnotationAttr> followupCoincident =
364 lookupFollowupNode("llvm.loop.distribute.followup_coincident");
365 FailureOr<LoopAnnotationAttr> followupSequential =
366 lookupFollowupNode("llvm.loop.distribute.followup_sequential");
367 FailureOr<LoopAnnotationAttr> followupFallback =
368 lookupFollowupNode("llvm.loop.distribute.followup_fallback");
369 FailureOr<LoopAnnotationAttr> followupAll =
370 lookupFollowupNode("llvm.loop.distribute.followup_all");
371 return createIfNonNull<LoopDistributeAttr>(ctx, disable, followupCoincident,
372 followupSequential,
373 followupFallback, followupAll);
376 FailureOr<LoopPipelineAttr> LoopMetadataConversion::convertPipelineAttr() {
377 FailureOr<BoolAttr> disable = lookupBoolNode("llvm.loop.pipeline.disable");
378 FailureOr<IntegerAttr> initiationinterval =
379 lookupIntNode("llvm.loop.pipeline.initiationinterval");
380 return createIfNonNull<LoopPipelineAttr>(ctx, disable, initiationinterval);
383 FailureOr<LoopPeeledAttr> LoopMetadataConversion::convertPeeledAttr() {
384 FailureOr<IntegerAttr> count = lookupIntNode("llvm.loop.peeled.count");
385 return createIfNonNull<LoopPeeledAttr>(ctx, count);
388 FailureOr<LoopUnswitchAttr> LoopMetadataConversion::convertUnswitchAttr() {
389 FailureOr<BoolAttr> partialDisable =
390 lookupUnitNode("llvm.loop.unswitch.partial.disable");
391 return createIfNonNull<LoopUnswitchAttr>(ctx, partialDisable);
394 FailureOr<SmallVector<AccessGroupAttr>>
395 LoopMetadataConversion::convertParallelAccesses() {
396 FailureOr<SmallVector<llvm::MDNode *>> nodes =
397 lookupMDNodes("llvm.loop.parallel_accesses");
398 if (failed(nodes))
399 return failure();
400 SmallVector<AccessGroupAttr> refs;
401 for (llvm::MDNode *node : *nodes) {
402 FailureOr<SmallVector<AccessGroupAttr>> accessGroups =
403 loopAnnotationImporter.lookupAccessGroupAttrs(node);
404 if (failed(accessGroups)) {
405 emitWarning(loc) << "could not lookup access group";
406 continue;
408 llvm::append_range(refs, *accessGroups);
410 return refs;
413 FusedLoc LoopMetadataConversion::convertStartLoc() {
414 if (locations.empty())
415 return {};
416 return dyn_cast<FusedLoc>(
417 loopAnnotationImporter.moduleImport.translateLoc(locations[0]));
420 FailureOr<FusedLoc> LoopMetadataConversion::convertEndLoc() {
421 if (locations.size() < 2)
422 return FusedLoc();
423 if (locations.size() > 2)
424 return emitError(loc)
425 << "expected loop metadata to have at most two DILocations";
426 return dyn_cast<FusedLoc>(
427 loopAnnotationImporter.moduleImport.translateLoc(locations[1]));
430 LoopAnnotationAttr LoopMetadataConversion::convert() {
431 if (failed(initConversionState()))
432 return {};
434 FailureOr<BoolAttr> disableNonForced =
435 lookupUnitNode("llvm.loop.disable_nonforced");
436 FailureOr<LoopVectorizeAttr> vecAttr = convertVectorizeAttr();
437 FailureOr<LoopInterleaveAttr> interleaveAttr = convertInterleaveAttr();
438 FailureOr<LoopUnrollAttr> unrollAttr = convertUnrollAttr();
439 FailureOr<LoopUnrollAndJamAttr> unrollAndJamAttr = convertUnrollAndJamAttr();
440 FailureOr<LoopLICMAttr> licmAttr = convertLICMAttr();
441 FailureOr<LoopDistributeAttr> distributeAttr = convertDistributeAttr();
442 FailureOr<LoopPipelineAttr> pipelineAttr = convertPipelineAttr();
443 FailureOr<LoopPeeledAttr> peeledAttr = convertPeeledAttr();
444 FailureOr<LoopUnswitchAttr> unswitchAttr = convertUnswitchAttr();
445 FailureOr<BoolAttr> mustProgress = lookupUnitNode("llvm.loop.mustprogress");
446 FailureOr<BoolAttr> isVectorized =
447 lookupIntNodeAsBoolAttr("llvm.loop.isvectorized");
448 FailureOr<SmallVector<AccessGroupAttr>> parallelAccesses =
449 convertParallelAccesses();
451 // Drop the metadata if there are parts that cannot be imported.
452 if (!propertyMap.empty()) {
453 for (auto name : propertyMap.keys())
454 emitWarning(loc) << "unknown loop annotation " << name;
455 return {};
458 FailureOr<FusedLoc> startLoc = convertStartLoc();
459 FailureOr<FusedLoc> endLoc = convertEndLoc();
461 return createIfNonNull<LoopAnnotationAttr>(
462 ctx, disableNonForced, vecAttr, interleaveAttr, unrollAttr,
463 unrollAndJamAttr, licmAttr, distributeAttr, pipelineAttr, peeledAttr,
464 unswitchAttr, mustProgress, isVectorized, startLoc, endLoc,
465 parallelAccesses);
468 LoopAnnotationAttr
469 LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode *node,
470 Location loc) {
471 if (!node)
472 return {};
474 // Note: This check is necessary to distinguish between failed translations
475 // and not yet attempted translations.
476 auto it = loopMetadataMapping.find(node);
477 if (it != loopMetadataMapping.end())
478 return it->getSecond();
480 LoopAnnotationAttr attr = LoopMetadataConversion(node, loc, *this).convert();
482 mapLoopMetadata(node, attr);
483 return attr;
486 LogicalResult
487 LoopAnnotationImporter::translateAccessGroup(const llvm::MDNode *node,
488 Location loc) {
489 SmallVector<const llvm::MDNode *> accessGroups;
490 if (!node->getNumOperands())
491 accessGroups.push_back(node);
492 for (const llvm::MDOperand &operand : node->operands()) {
493 auto *childNode = dyn_cast<llvm::MDNode>(operand);
494 if (!childNode)
495 return failure();
496 accessGroups.push_back(cast<llvm::MDNode>(operand.get()));
499 // Convert all entries of the access group list to access group operations.
500 for (const llvm::MDNode *accessGroup : accessGroups) {
501 if (accessGroupMapping.count(accessGroup))
502 continue;
503 // Verify the access group node is distinct and empty.
504 if (accessGroup->getNumOperands() != 0 || !accessGroup->isDistinct())
505 return emitWarning(loc)
506 << "expected an access group node to be empty and distinct";
508 // Add a mapping from the access group node to the newly created attribute.
509 accessGroupMapping[accessGroup] = builder.getAttr<AccessGroupAttr>();
511 return success();
514 FailureOr<SmallVector<AccessGroupAttr>>
515 LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode *node) const {
516 // An access group node is either a single access group or an access group
517 // list.
518 SmallVector<AccessGroupAttr> accessGroups;
519 if (!node->getNumOperands())
520 accessGroups.push_back(accessGroupMapping.lookup(node));
521 for (const llvm::MDOperand &operand : node->operands()) {
522 auto *node = cast<llvm::MDNode>(operand.get());
523 accessGroups.push_back(accessGroupMapping.lookup(node));
525 // Exit if one of the access group node lookups failed.
526 if (llvm::is_contained(accessGroups, nullptr))
527 return failure();
528 return accessGroups;