//===- LoopAnnotationImporter.cpp - Loop annotation import ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "LoopAnnotationImporter.h"
#include "llvm/IR/Constants.h"

using namespace mlir;
using namespace mlir::LLVM;
using namespace mlir::LLVM::detail;
17 /// Helper class that keeps the state of one metadata to attribute conversion.
18 struct LoopMetadataConversion
{
19 LoopMetadataConversion(const llvm::MDNode
*node
, Location loc
,
20 LoopAnnotationImporter
&loopAnnotationImporter
)
21 : node(node
), loc(loc
), loopAnnotationImporter(loopAnnotationImporter
),
22 ctx(loc
->getContext()){};
23 /// Converts this structs loop metadata node into a LoopAnnotationAttr.
24 LoopAnnotationAttr
convert();
26 /// Initializes the shared state for the conversion member functions.
27 LogicalResult
initConversionState();
29 /// Helper function to get and erase a property.
30 const llvm::MDNode
*lookupAndEraseProperty(StringRef name
);
32 /// Helper functions to lookup and convert MDNodes into a specifc attribute
33 /// kind. These functions return null-attributes if there is no node with the
34 /// specified name, or failure, if the node is ill-formatted.
35 FailureOr
<BoolAttr
> lookupUnitNode(StringRef name
);
36 FailureOr
<BoolAttr
> lookupBoolNode(StringRef name
, bool negated
= false);
37 FailureOr
<BoolAttr
> lookupIntNodeAsBoolAttr(StringRef name
);
38 FailureOr
<IntegerAttr
> lookupIntNode(StringRef name
);
39 FailureOr
<llvm::MDNode
*> lookupMDNode(StringRef name
);
40 FailureOr
<SmallVector
<llvm::MDNode
*>> lookupMDNodes(StringRef name
);
41 FailureOr
<LoopAnnotationAttr
> lookupFollowupNode(StringRef name
);
42 FailureOr
<BoolAttr
> lookupBooleanUnitNode(StringRef enableName
,
43 StringRef disableName
,
44 bool negated
= false);
46 /// Conversion functions for sub-attributes.
47 FailureOr
<LoopVectorizeAttr
> convertVectorizeAttr();
48 FailureOr
<LoopInterleaveAttr
> convertInterleaveAttr();
49 FailureOr
<LoopUnrollAttr
> convertUnrollAttr();
50 FailureOr
<LoopUnrollAndJamAttr
> convertUnrollAndJamAttr();
51 FailureOr
<LoopLICMAttr
> convertLICMAttr();
52 FailureOr
<LoopDistributeAttr
> convertDistributeAttr();
53 FailureOr
<LoopPipelineAttr
> convertPipelineAttr();
54 FailureOr
<LoopPeeledAttr
> convertPeeledAttr();
55 FailureOr
<LoopUnswitchAttr
> convertUnswitchAttr();
56 FailureOr
<SmallVector
<AccessGroupAttr
>> convertParallelAccesses();
57 FusedLoc
convertStartLoc();
58 FailureOr
<FusedLoc
> convertEndLoc();
60 llvm::SmallVector
<llvm::DILocation
*, 2> locations
;
61 llvm::StringMap
<const llvm::MDNode
*> propertyMap
;
62 const llvm::MDNode
*node
;
64 LoopAnnotationImporter
&loopAnnotationImporter
;
69 LogicalResult
LoopMetadataConversion::initConversionState() {
70 // Check if it's a valid node.
71 if (node
->getNumOperands() == 0 ||
72 dyn_cast
<llvm::MDNode
>(node
->getOperand(0)) != node
)
73 return emitWarning(loc
) << "invalid loop node";
75 for (const llvm::MDOperand
&operand
: llvm::drop_begin(node
->operands())) {
76 if (auto *diLoc
= dyn_cast
<llvm::DILocation
>(operand
)) {
77 locations
.push_back(diLoc
);
81 auto *property
= dyn_cast
<llvm::MDNode
>(operand
);
83 return emitWarning(loc
) << "expected all loop properties to be either "
84 "debug locations or metadata nodes";
86 if (property
->getNumOperands() == 0)
87 return emitWarning(loc
) << "cannot import empty loop property";
89 auto *nameNode
= dyn_cast
<llvm::MDString
>(property
->getOperand(0));
91 return emitWarning(loc
) << "cannot import loop property without a name";
92 StringRef name
= nameNode
->getString();
94 bool succ
= propertyMap
.try_emplace(name
, property
).second
;
96 return emitWarning(loc
)
97 << "cannot import loop properties with duplicated names " << name
;
104 LoopMetadataConversion::lookupAndEraseProperty(StringRef name
) {
105 auto it
= propertyMap
.find(name
);
106 if (it
== propertyMap
.end())
108 const llvm::MDNode
*property
= it
->getValue();
109 propertyMap
.erase(it
);
113 FailureOr
<BoolAttr
> LoopMetadataConversion::lookupUnitNode(StringRef name
) {
114 const llvm::MDNode
*property
= lookupAndEraseProperty(name
);
116 return BoolAttr(nullptr);
118 if (property
->getNumOperands() != 1)
119 return emitWarning(loc
)
120 << "expected metadata node " << name
<< " to hold no value";
122 return BoolAttr::get(ctx
, true);
125 FailureOr
<BoolAttr
> LoopMetadataConversion::lookupBooleanUnitNode(
126 StringRef enableName
, StringRef disableName
, bool negated
) {
127 auto enable
= lookupUnitNode(enableName
);
128 auto disable
= lookupUnitNode(disableName
);
129 if (failed(enable
) || failed(disable
))
132 if (*enable
&& *disable
)
133 return emitWarning(loc
)
134 << "expected metadata nodes " << enableName
<< " and " << disableName
135 << " to be mutually exclusive.";
138 return BoolAttr::get(ctx
, !negated
);
141 return BoolAttr::get(ctx
, negated
);
142 return BoolAttr(nullptr);
145 FailureOr
<BoolAttr
> LoopMetadataConversion::lookupBoolNode(StringRef name
,
147 const llvm::MDNode
*property
= lookupAndEraseProperty(name
);
149 return BoolAttr(nullptr);
151 auto emitNodeWarning
= [&]() {
152 return emitWarning(loc
)
153 << "expected metadata node " << name
<< " to hold a boolean value";
156 if (property
->getNumOperands() != 2)
157 return emitNodeWarning();
158 llvm::ConstantInt
*val
=
159 llvm::mdconst::dyn_extract
<llvm::ConstantInt
>(property
->getOperand(1));
160 if (!val
|| val
->getBitWidth() != 1)
161 return emitNodeWarning();
163 return BoolAttr::get(ctx
, val
->getValue().getLimitedValue(1) ^ negated
);
167 LoopMetadataConversion::lookupIntNodeAsBoolAttr(StringRef name
) {
168 const llvm::MDNode
*property
= lookupAndEraseProperty(name
);
170 return BoolAttr(nullptr);
172 auto emitNodeWarning
= [&]() {
173 return emitWarning(loc
)
174 << "expected metadata node " << name
<< " to hold an integer value";
177 if (property
->getNumOperands() != 2)
178 return emitNodeWarning();
179 llvm::ConstantInt
*val
=
180 llvm::mdconst::dyn_extract
<llvm::ConstantInt
>(property
->getOperand(1));
181 if (!val
|| val
->getBitWidth() != 32)
182 return emitNodeWarning();
184 return BoolAttr::get(ctx
, val
->getValue().getLimitedValue(1));
187 FailureOr
<IntegerAttr
> LoopMetadataConversion::lookupIntNode(StringRef name
) {
188 const llvm::MDNode
*property
= lookupAndEraseProperty(name
);
190 return IntegerAttr(nullptr);
192 auto emitNodeWarning
= [&]() {
193 return emitWarning(loc
)
194 << "expected metadata node " << name
<< " to hold an i32 value";
197 if (property
->getNumOperands() != 2)
198 return emitNodeWarning();
200 llvm::ConstantInt
*val
=
201 llvm::mdconst::dyn_extract
<llvm::ConstantInt
>(property
->getOperand(1));
202 if (!val
|| val
->getBitWidth() != 32)
203 return emitNodeWarning();
205 return IntegerAttr::get(IntegerType::get(ctx
, 32),
206 val
->getValue().getLimitedValue());
209 FailureOr
<llvm::MDNode
*> LoopMetadataConversion::lookupMDNode(StringRef name
) {
210 const llvm::MDNode
*property
= lookupAndEraseProperty(name
);
214 auto emitNodeWarning
= [&]() {
215 return emitWarning(loc
)
216 << "expected metadata node " << name
<< " to hold an MDNode";
219 if (property
->getNumOperands() != 2)
220 return emitNodeWarning();
222 auto *node
= dyn_cast
<llvm::MDNode
>(property
->getOperand(1));
224 return emitNodeWarning();
229 FailureOr
<SmallVector
<llvm::MDNode
*>>
230 LoopMetadataConversion::lookupMDNodes(StringRef name
) {
231 const llvm::MDNode
*property
= lookupAndEraseProperty(name
);
232 SmallVector
<llvm::MDNode
*> res
;
236 auto emitNodeWarning
= [&]() {
237 return emitWarning(loc
) << "expected metadata node " << name
238 << " to hold one or multiple MDNodes";
241 if (property
->getNumOperands() < 2)
242 return emitNodeWarning();
244 for (unsigned i
= 1, e
= property
->getNumOperands(); i
< e
; ++i
) {
245 auto *node
= dyn_cast
<llvm::MDNode
>(property
->getOperand(i
));
247 return emitNodeWarning();
254 FailureOr
<LoopAnnotationAttr
>
255 LoopMetadataConversion::lookupFollowupNode(StringRef name
) {
256 auto node
= lookupMDNode(name
);
259 if (*node
== nullptr)
260 return LoopAnnotationAttr(nullptr);
262 return loopAnnotationImporter
.translateLoopAnnotation(*node
, loc
);
265 static bool isEmptyOrNull(const Attribute attr
) { return !attr
; }
267 template <typename T
>
268 static bool isEmptyOrNull(const SmallVectorImpl
<T
> &vec
) {
272 /// Helper function that only creates and attribute of type T if all argument
273 /// conversion were successfull and at least one of them holds a non-null value.
274 template <typename T
, typename
... P
>
275 static T
createIfNonNull(MLIRContext
*ctx
, const P
&...args
) {
276 bool anyFailed
= (failed(args
) || ...);
280 bool allEmpty
= (isEmptyOrNull(*args
) && ...);
284 return T::get(ctx
, *args
...);
287 FailureOr
<LoopVectorizeAttr
> LoopMetadataConversion::convertVectorizeAttr() {
288 FailureOr
<BoolAttr
> enable
=
289 lookupBoolNode("llvm.loop.vectorize.enable", true);
290 FailureOr
<BoolAttr
> predicateEnable
=
291 lookupBoolNode("llvm.loop.vectorize.predicate.enable");
292 FailureOr
<BoolAttr
> scalableEnable
=
293 lookupBoolNode("llvm.loop.vectorize.scalable.enable");
294 FailureOr
<IntegerAttr
> width
= lookupIntNode("llvm.loop.vectorize.width");
295 FailureOr
<LoopAnnotationAttr
> followupVec
=
296 lookupFollowupNode("llvm.loop.vectorize.followup_vectorized");
297 FailureOr
<LoopAnnotationAttr
> followupEpi
=
298 lookupFollowupNode("llvm.loop.vectorize.followup_epilogue");
299 FailureOr
<LoopAnnotationAttr
> followupAll
=
300 lookupFollowupNode("llvm.loop.vectorize.followup_all");
302 return createIfNonNull
<LoopVectorizeAttr
>(ctx
, enable
, predicateEnable
,
303 scalableEnable
, width
, followupVec
,
304 followupEpi
, followupAll
);
307 FailureOr
<LoopInterleaveAttr
> LoopMetadataConversion::convertInterleaveAttr() {
308 FailureOr
<IntegerAttr
> count
= lookupIntNode("llvm.loop.interleave.count");
309 return createIfNonNull
<LoopInterleaveAttr
>(ctx
, count
);
312 FailureOr
<LoopUnrollAttr
> LoopMetadataConversion::convertUnrollAttr() {
313 FailureOr
<BoolAttr
> disable
= lookupBooleanUnitNode(
314 "llvm.loop.unroll.enable", "llvm.loop.unroll.disable", /*negated=*/true);
315 FailureOr
<IntegerAttr
> count
= lookupIntNode("llvm.loop.unroll.count");
316 FailureOr
<BoolAttr
> runtimeDisable
=
317 lookupUnitNode("llvm.loop.unroll.runtime.disable");
318 FailureOr
<BoolAttr
> full
= lookupUnitNode("llvm.loop.unroll.full");
319 FailureOr
<LoopAnnotationAttr
> followupUnrolled
=
320 lookupFollowupNode("llvm.loop.unroll.followup_unrolled");
321 FailureOr
<LoopAnnotationAttr
> followupRemainder
=
322 lookupFollowupNode("llvm.loop.unroll.followup_remainder");
323 FailureOr
<LoopAnnotationAttr
> followupAll
=
324 lookupFollowupNode("llvm.loop.unroll.followup_all");
326 return createIfNonNull
<LoopUnrollAttr
>(ctx
, disable
, count
, runtimeDisable
,
327 full
, followupUnrolled
,
328 followupRemainder
, followupAll
);
331 FailureOr
<LoopUnrollAndJamAttr
>
332 LoopMetadataConversion::convertUnrollAndJamAttr() {
333 FailureOr
<BoolAttr
> disable
= lookupBooleanUnitNode(
334 "llvm.loop.unroll_and_jam.enable", "llvm.loop.unroll_and_jam.disable",
336 FailureOr
<IntegerAttr
> count
=
337 lookupIntNode("llvm.loop.unroll_and_jam.count");
338 FailureOr
<LoopAnnotationAttr
> followupOuter
=
339 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_outer");
340 FailureOr
<LoopAnnotationAttr
> followupInner
=
341 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_inner");
342 FailureOr
<LoopAnnotationAttr
> followupRemainderOuter
=
343 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer");
344 FailureOr
<LoopAnnotationAttr
> followupRemainderInner
=
345 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner");
346 FailureOr
<LoopAnnotationAttr
> followupAll
=
347 lookupFollowupNode("llvm.loop.unroll_and_jam.followup_all");
348 return createIfNonNull
<LoopUnrollAndJamAttr
>(
349 ctx
, disable
, count
, followupOuter
, followupInner
, followupRemainderOuter
,
350 followupRemainderInner
, followupAll
);
353 FailureOr
<LoopLICMAttr
> LoopMetadataConversion::convertLICMAttr() {
354 FailureOr
<BoolAttr
> disable
= lookupUnitNode("llvm.licm.disable");
355 FailureOr
<BoolAttr
> versioningDisable
=
356 lookupUnitNode("llvm.loop.licm_versioning.disable");
357 return createIfNonNull
<LoopLICMAttr
>(ctx
, disable
, versioningDisable
);
360 FailureOr
<LoopDistributeAttr
> LoopMetadataConversion::convertDistributeAttr() {
361 FailureOr
<BoolAttr
> disable
=
362 lookupBoolNode("llvm.loop.distribute.enable", true);
363 FailureOr
<LoopAnnotationAttr
> followupCoincident
=
364 lookupFollowupNode("llvm.loop.distribute.followup_coincident");
365 FailureOr
<LoopAnnotationAttr
> followupSequential
=
366 lookupFollowupNode("llvm.loop.distribute.followup_sequential");
367 FailureOr
<LoopAnnotationAttr
> followupFallback
=
368 lookupFollowupNode("llvm.loop.distribute.followup_fallback");
369 FailureOr
<LoopAnnotationAttr
> followupAll
=
370 lookupFollowupNode("llvm.loop.distribute.followup_all");
371 return createIfNonNull
<LoopDistributeAttr
>(ctx
, disable
, followupCoincident
,
373 followupFallback
, followupAll
);
376 FailureOr
<LoopPipelineAttr
> LoopMetadataConversion::convertPipelineAttr() {
377 FailureOr
<BoolAttr
> disable
= lookupBoolNode("llvm.loop.pipeline.disable");
378 FailureOr
<IntegerAttr
> initiationinterval
=
379 lookupIntNode("llvm.loop.pipeline.initiationinterval");
380 return createIfNonNull
<LoopPipelineAttr
>(ctx
, disable
, initiationinterval
);
383 FailureOr
<LoopPeeledAttr
> LoopMetadataConversion::convertPeeledAttr() {
384 FailureOr
<IntegerAttr
> count
= lookupIntNode("llvm.loop.peeled.count");
385 return createIfNonNull
<LoopPeeledAttr
>(ctx
, count
);
388 FailureOr
<LoopUnswitchAttr
> LoopMetadataConversion::convertUnswitchAttr() {
389 FailureOr
<BoolAttr
> partialDisable
=
390 lookupUnitNode("llvm.loop.unswitch.partial.disable");
391 return createIfNonNull
<LoopUnswitchAttr
>(ctx
, partialDisable
);
394 FailureOr
<SmallVector
<AccessGroupAttr
>>
395 LoopMetadataConversion::convertParallelAccesses() {
396 FailureOr
<SmallVector
<llvm::MDNode
*>> nodes
=
397 lookupMDNodes("llvm.loop.parallel_accesses");
400 SmallVector
<AccessGroupAttr
> refs
;
401 for (llvm::MDNode
*node
: *nodes
) {
402 FailureOr
<SmallVector
<AccessGroupAttr
>> accessGroups
=
403 loopAnnotationImporter
.lookupAccessGroupAttrs(node
);
404 if (failed(accessGroups
)) {
405 emitWarning(loc
) << "could not lookup access group";
408 llvm::append_range(refs
, *accessGroups
);
413 FusedLoc
LoopMetadataConversion::convertStartLoc() {
414 if (locations
.empty())
416 return dyn_cast
<FusedLoc
>(
417 loopAnnotationImporter
.moduleImport
.translateLoc(locations
[0]));
420 FailureOr
<FusedLoc
> LoopMetadataConversion::convertEndLoc() {
421 if (locations
.size() < 2)
423 if (locations
.size() > 2)
424 return emitError(loc
)
425 << "expected loop metadata to have at most two DILocations";
426 return dyn_cast
<FusedLoc
>(
427 loopAnnotationImporter
.moduleImport
.translateLoc(locations
[1]));
430 LoopAnnotationAttr
LoopMetadataConversion::convert() {
431 if (failed(initConversionState()))
434 FailureOr
<BoolAttr
> disableNonForced
=
435 lookupUnitNode("llvm.loop.disable_nonforced");
436 FailureOr
<LoopVectorizeAttr
> vecAttr
= convertVectorizeAttr();
437 FailureOr
<LoopInterleaveAttr
> interleaveAttr
= convertInterleaveAttr();
438 FailureOr
<LoopUnrollAttr
> unrollAttr
= convertUnrollAttr();
439 FailureOr
<LoopUnrollAndJamAttr
> unrollAndJamAttr
= convertUnrollAndJamAttr();
440 FailureOr
<LoopLICMAttr
> licmAttr
= convertLICMAttr();
441 FailureOr
<LoopDistributeAttr
> distributeAttr
= convertDistributeAttr();
442 FailureOr
<LoopPipelineAttr
> pipelineAttr
= convertPipelineAttr();
443 FailureOr
<LoopPeeledAttr
> peeledAttr
= convertPeeledAttr();
444 FailureOr
<LoopUnswitchAttr
> unswitchAttr
= convertUnswitchAttr();
445 FailureOr
<BoolAttr
> mustProgress
= lookupUnitNode("llvm.loop.mustprogress");
446 FailureOr
<BoolAttr
> isVectorized
=
447 lookupIntNodeAsBoolAttr("llvm.loop.isvectorized");
448 FailureOr
<SmallVector
<AccessGroupAttr
>> parallelAccesses
=
449 convertParallelAccesses();
451 // Drop the metadata if there are parts that cannot be imported.
452 if (!propertyMap
.empty()) {
453 for (auto name
: propertyMap
.keys())
454 emitWarning(loc
) << "unknown loop annotation " << name
;
458 FailureOr
<FusedLoc
> startLoc
= convertStartLoc();
459 FailureOr
<FusedLoc
> endLoc
= convertEndLoc();
461 return createIfNonNull
<LoopAnnotationAttr
>(
462 ctx
, disableNonForced
, vecAttr
, interleaveAttr
, unrollAttr
,
463 unrollAndJamAttr
, licmAttr
, distributeAttr
, pipelineAttr
, peeledAttr
,
464 unswitchAttr
, mustProgress
, isVectorized
, startLoc
, endLoc
,
469 LoopAnnotationImporter::translateLoopAnnotation(const llvm::MDNode
*node
,
474 // Note: This check is necessary to distinguish between failed translations
475 // and not yet attempted translations.
476 auto it
= loopMetadataMapping
.find(node
);
477 if (it
!= loopMetadataMapping
.end())
478 return it
->getSecond();
480 LoopAnnotationAttr attr
= LoopMetadataConversion(node
, loc
, *this).convert();
482 mapLoopMetadata(node
, attr
);
487 LoopAnnotationImporter::translateAccessGroup(const llvm::MDNode
*node
,
489 SmallVector
<const llvm::MDNode
*> accessGroups
;
490 if (!node
->getNumOperands())
491 accessGroups
.push_back(node
);
492 for (const llvm::MDOperand
&operand
: node
->operands()) {
493 auto *childNode
= dyn_cast
<llvm::MDNode
>(operand
);
496 accessGroups
.push_back(cast
<llvm::MDNode
>(operand
.get()));
499 // Convert all entries of the access group list to access group operations.
500 for (const llvm::MDNode
*accessGroup
: accessGroups
) {
501 if (accessGroupMapping
.count(accessGroup
))
503 // Verify the access group node is distinct and empty.
504 if (accessGroup
->getNumOperands() != 0 || !accessGroup
->isDistinct())
505 return emitWarning(loc
)
506 << "expected an access group node to be empty and distinct";
508 // Add a mapping from the access group node to the newly created attribute.
509 accessGroupMapping
[accessGroup
] = builder
.getAttr
<AccessGroupAttr
>();
514 FailureOr
<SmallVector
<AccessGroupAttr
>>
515 LoopAnnotationImporter::lookupAccessGroupAttrs(const llvm::MDNode
*node
) const {
516 // An access group node is either a single access group or an access group
518 SmallVector
<AccessGroupAttr
> accessGroups
;
519 if (!node
->getNumOperands())
520 accessGroups
.push_back(accessGroupMapping
.lookup(node
));
521 for (const llvm::MDOperand
&operand
: node
->operands()) {
522 auto *node
= cast
<llvm::MDNode
>(operand
.get());
523 accessGroups
.push_back(accessGroupMapping
.lookup(node
));
525 // Exit if one of the access group node lookups failed.
526 if (llvm::is_contained(accessGroups
, nullptr))