//===- LoopAnnotationTranslation.cpp - Loop annotation export -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "LoopAnnotationTranslation.h"
#include "llvm/IR/DebugInfoMetadata.h"

using namespace mlir;
using namespace mlir::LLVM;
using namespace mlir::LLVM::detail;
17 /// Helper class that keeps the state of one attribute to metadata conversion.
18 struct LoopAnnotationConversion
{
19 LoopAnnotationConversion(LoopAnnotationAttr attr
, Operation
*op
,
20 LoopAnnotationTranslation
&loopAnnotationTranslation
,
21 llvm::LLVMContext
&ctx
)
23 loopAnnotationTranslation(loopAnnotationTranslation
), ctx(ctx
) {}
25 /// Converts this struct's loop annotation into a corresponding LLVMIR
26 /// metadata representation.
27 llvm::MDNode
*convert();
29 /// Conversion functions for different payload attribute kinds.
30 void addUnitNode(StringRef name
);
31 void addUnitNode(StringRef name
, BoolAttr attr
);
32 void addI32NodeWithVal(StringRef name
, uint32_t val
);
33 void convertBoolNode(StringRef name
, BoolAttr attr
, bool negated
= false);
34 void convertI32Node(StringRef name
, IntegerAttr attr
);
35 void convertFollowupNode(StringRef name
, LoopAnnotationAttr attr
);
36 void convertLocation(FusedLoc attr
);
38 /// Conversion functions for each for each loop annotation sub-attribute.
39 void convertLoopOptions(LoopVectorizeAttr options
);
40 void convertLoopOptions(LoopInterleaveAttr options
);
41 void convertLoopOptions(LoopUnrollAttr options
);
42 void convertLoopOptions(LoopUnrollAndJamAttr options
);
43 void convertLoopOptions(LoopLICMAttr options
);
44 void convertLoopOptions(LoopDistributeAttr options
);
45 void convertLoopOptions(LoopPipelineAttr options
);
46 void convertLoopOptions(LoopPeeledAttr options
);
47 void convertLoopOptions(LoopUnswitchAttr options
);
49 LoopAnnotationAttr attr
;
51 LoopAnnotationTranslation
&loopAnnotationTranslation
;
52 llvm::LLVMContext
&ctx
;
53 llvm::SmallVector
<llvm::Metadata
*> metadataNodes
;
57 void LoopAnnotationConversion::addUnitNode(StringRef name
) {
58 metadataNodes
.push_back(
59 llvm::MDNode::get(ctx
, {llvm::MDString::get(ctx
, name
)}));
62 void LoopAnnotationConversion::addUnitNode(StringRef name
, BoolAttr attr
) {
63 if (attr
&& attr
.getValue())
67 void LoopAnnotationConversion::addI32NodeWithVal(StringRef name
, uint32_t val
) {
68 llvm::Constant
*cstValue
= llvm::ConstantInt::get(
69 llvm::IntegerType::get(ctx
, /*NumBits=*/32), val
, /*isSigned=*/false);
70 metadataNodes
.push_back(
71 llvm::MDNode::get(ctx
, {llvm::MDString::get(ctx
, name
),
72 llvm::ConstantAsMetadata::get(cstValue
)}));
75 void LoopAnnotationConversion::convertBoolNode(StringRef name
, BoolAttr attr
,
79 bool val
= negated
^ attr
.getValue();
80 llvm::Constant
*cstValue
= llvm::ConstantInt::getBool(ctx
, val
);
81 metadataNodes
.push_back(
82 llvm::MDNode::get(ctx
, {llvm::MDString::get(ctx
, name
),
83 llvm::ConstantAsMetadata::get(cstValue
)}));
86 void LoopAnnotationConversion::convertI32Node(StringRef name
,
90 addI32NodeWithVal(name
, attr
.getInt());
93 void LoopAnnotationConversion::convertFollowupNode(StringRef name
,
94 LoopAnnotationAttr attr
) {
99 loopAnnotationTranslation
.translateLoopAnnotation(attr
, op
);
101 metadataNodes
.push_back(
102 llvm::MDNode::get(ctx
, {llvm::MDString::get(ctx
, name
), node
}));
105 void LoopAnnotationConversion::convertLoopOptions(LoopVectorizeAttr options
) {
106 convertBoolNode("llvm.loop.vectorize.enable", options
.getDisable(), true);
107 convertBoolNode("llvm.loop.vectorize.predicate.enable",
108 options
.getPredicateEnable());
109 convertBoolNode("llvm.loop.vectorize.scalable.enable",
110 options
.getScalableEnable());
111 convertI32Node("llvm.loop.vectorize.width", options
.getWidth());
112 convertFollowupNode("llvm.loop.vectorize.followup_vectorized",
113 options
.getFollowupVectorized());
114 convertFollowupNode("llvm.loop.vectorize.followup_epilogue",
115 options
.getFollowupEpilogue());
116 convertFollowupNode("llvm.loop.vectorize.followup_all",
117 options
.getFollowupAll());
120 void LoopAnnotationConversion::convertLoopOptions(LoopInterleaveAttr options
) {
121 convertI32Node("llvm.loop.interleave.count", options
.getCount());
124 void LoopAnnotationConversion::convertLoopOptions(LoopUnrollAttr options
) {
125 if (auto disable
= options
.getDisable())
126 addUnitNode(disable
.getValue() ? "llvm.loop.unroll.disable"
127 : "llvm.loop.unroll.enable");
128 convertI32Node("llvm.loop.unroll.count", options
.getCount());
129 convertBoolNode("llvm.loop.unroll.runtime.disable",
130 options
.getRuntimeDisable());
131 addUnitNode("llvm.loop.unroll.full", options
.getFull());
132 convertFollowupNode("llvm.loop.unroll.followup_unrolled",
133 options
.getFollowupUnrolled());
134 convertFollowupNode("llvm.loop.unroll.followup_remainder",
135 options
.getFollowupRemainder());
136 convertFollowupNode("llvm.loop.unroll.followup_all",
137 options
.getFollowupAll());
140 void LoopAnnotationConversion::convertLoopOptions(
141 LoopUnrollAndJamAttr options
) {
142 if (auto disable
= options
.getDisable())
143 addUnitNode(disable
.getValue() ? "llvm.loop.unroll_and_jam.disable"
144 : "llvm.loop.unroll_and_jam.enable");
145 convertI32Node("llvm.loop.unroll_and_jam.count", options
.getCount());
146 convertFollowupNode("llvm.loop.unroll_and_jam.followup_outer",
147 options
.getFollowupOuter());
148 convertFollowupNode("llvm.loop.unroll_and_jam.followup_inner",
149 options
.getFollowupInner());
150 convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_outer",
151 options
.getFollowupRemainderOuter());
152 convertFollowupNode("llvm.loop.unroll_and_jam.followup_remainder_inner",
153 options
.getFollowupRemainderInner());
154 convertFollowupNode("llvm.loop.unroll_and_jam.followup_all",
155 options
.getFollowupAll());
158 void LoopAnnotationConversion::convertLoopOptions(LoopLICMAttr options
) {
159 addUnitNode("llvm.licm.disable", options
.getDisable());
160 addUnitNode("llvm.loop.licm_versioning.disable",
161 options
.getVersioningDisable());
164 void LoopAnnotationConversion::convertLoopOptions(LoopDistributeAttr options
) {
165 convertBoolNode("llvm.loop.distribute.enable", options
.getDisable(), true);
166 convertFollowupNode("llvm.loop.distribute.followup_coincident",
167 options
.getFollowupCoincident());
168 convertFollowupNode("llvm.loop.distribute.followup_sequential",
169 options
.getFollowupSequential());
170 convertFollowupNode("llvm.loop.distribute.followup_fallback",
171 options
.getFollowupFallback());
172 convertFollowupNode("llvm.loop.distribute.followup_all",
173 options
.getFollowupAll());
176 void LoopAnnotationConversion::convertLoopOptions(LoopPipelineAttr options
) {
177 convertBoolNode("llvm.loop.pipeline.disable", options
.getDisable());
178 convertI32Node("llvm.loop.pipeline.initiationinterval",
179 options
.getInitiationinterval());
182 void LoopAnnotationConversion::convertLoopOptions(LoopPeeledAttr options
) {
183 convertI32Node("llvm.loop.peeled.count", options
.getCount());
186 void LoopAnnotationConversion::convertLoopOptions(LoopUnswitchAttr options
) {
187 addUnitNode("llvm.loop.unswitch.partial.disable",
188 options
.getPartialDisable());
191 void LoopAnnotationConversion::convertLocation(FusedLoc location
) {
192 auto localScopeAttr
=
193 dyn_cast_or_null
<DILocalScopeAttr
>(location
.getMetadata());
196 auto *localScope
= dyn_cast
<llvm::DILocalScope
>(
197 loopAnnotationTranslation
.moduleTranslation
.translateDebugInfo(
201 llvm::Metadata
*loc
=
202 loopAnnotationTranslation
.moduleTranslation
.translateLoc(location
,
204 metadataNodes
.push_back(loc
);
207 llvm::MDNode
*LoopAnnotationConversion::convert() {
208 // Reserve operand 0 for loop id self reference.
209 auto dummy
= llvm::MDNode::getTemporary(ctx
, std::nullopt
);
210 metadataNodes
.push_back(dummy
.get());
212 if (FusedLoc startLoc
= attr
.getStartLoc())
213 convertLocation(startLoc
);
215 if (FusedLoc endLoc
= attr
.getEndLoc())
216 convertLocation(endLoc
);
218 addUnitNode("llvm.loop.disable_nonforced", attr
.getDisableNonforced());
219 addUnitNode("llvm.loop.mustprogress", attr
.getMustProgress());
220 // "isvectorized" is encoded as an i32 value.
221 if (BoolAttr isVectorized
= attr
.getIsVectorized())
222 addI32NodeWithVal("llvm.loop.isvectorized", isVectorized
.getValue());
224 if (auto options
= attr
.getVectorize())
225 convertLoopOptions(options
);
226 if (auto options
= attr
.getInterleave())
227 convertLoopOptions(options
);
228 if (auto options
= attr
.getUnroll())
229 convertLoopOptions(options
);
230 if (auto options
= attr
.getUnrollAndJam())
231 convertLoopOptions(options
);
232 if (auto options
= attr
.getLicm())
233 convertLoopOptions(options
);
234 if (auto options
= attr
.getDistribute())
235 convertLoopOptions(options
);
236 if (auto options
= attr
.getPipeline())
237 convertLoopOptions(options
);
238 if (auto options
= attr
.getPeeled())
239 convertLoopOptions(options
);
240 if (auto options
= attr
.getUnswitch())
241 convertLoopOptions(options
);
243 ArrayRef
<AccessGroupAttr
> parallelAccessGroups
= attr
.getParallelAccesses();
244 if (!parallelAccessGroups
.empty()) {
245 SmallVector
<llvm::Metadata
*> parallelAccess
;
246 parallelAccess
.push_back(
247 llvm::MDString::get(ctx
, "llvm.loop.parallel_accesses"));
248 for (AccessGroupAttr accessGroupAttr
: parallelAccessGroups
)
249 parallelAccess
.push_back(
250 loopAnnotationTranslation
.getAccessGroup(accessGroupAttr
));
251 metadataNodes
.push_back(llvm::MDNode::get(ctx
, parallelAccess
));
254 // Create loop options and set the first operand to itself.
255 llvm::MDNode
*loopMD
= llvm::MDNode::get(ctx
, metadataNodes
);
256 loopMD
->replaceOperandWith(0, loopMD
);
262 LoopAnnotationTranslation::translateLoopAnnotation(LoopAnnotationAttr attr
,
267 llvm::MDNode
*loopMD
= lookupLoopMetadata(attr
);
272 LoopAnnotationConversion(attr
, op
, *this, this->llvmModule
.getContext())
274 // Store a map from this Attribute to the LLVM metadata in case we
275 // encounter it again.
276 mapLoopMetadata(attr
, loopMD
);
281 LoopAnnotationTranslation::getAccessGroup(AccessGroupAttr accessGroupAttr
) {
282 auto [result
, inserted
] =
283 accessGroupMetadataMapping
.insert({accessGroupAttr
, nullptr});
285 result
->second
= llvm::MDNode::getDistinct(llvmModule
.getContext(), {});
286 return result
->second
;
290 LoopAnnotationTranslation::getAccessGroups(AccessGroupOpInterface op
) {
291 ArrayAttr accessGroups
= op
.getAccessGroupsOrNull();
292 if (!accessGroups
|| accessGroups
.empty())
295 SmallVector
<llvm::Metadata
*> groupMDs
;
296 for (AccessGroupAttr group
: accessGroups
.getAsRange
<AccessGroupAttr
>())
297 groupMDs
.push_back(getAccessGroup(group
));
298 if (groupMDs
.size() == 1)
299 return llvm::cast
<llvm::MDNode
>(groupMDs
.front());
300 return llvm::MDNode::get(llvmModule
.getContext(), groupMDs
);