1 //===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "IRNumbering.h"
10 #include "mlir/Bytecode/BytecodeImplementation.h"
11 #include "mlir/Bytecode/BytecodeOpInterface.h"
12 #include "mlir/Bytecode/BytecodeWriter.h"
13 #include "mlir/Bytecode/Encoding.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/OpDefinition.h"
19 using namespace mlir::bytecode::detail
;
21 //===----------------------------------------------------------------------===//
22 // NumberingDialectWriter
23 //===----------------------------------------------------------------------===//
25 struct IRNumberingState::NumberingDialectWriter
: public DialectBytecodeWriter
{
26 NumberingDialectWriter(
27 IRNumberingState
&state
,
28 llvm::StringMap
<std::unique_ptr
<DialectVersion
>> &dialectVersionMap
)
29 : state(state
), dialectVersionMap(dialectVersionMap
) {}
31 void writeAttribute(Attribute attr
) override
{ state
.number(attr
); }
32 void writeOptionalAttribute(Attribute attr
) override
{
36 void writeType(Type type
) override
{ state
.number(type
); }
37 void writeResourceHandle(const AsmDialectResourceHandle
&resource
) override
{
38 state
.number(resource
.getDialect(), resource
);
41 /// Stubbed out methods that are not used for numbering.
42 void writeVarInt(uint64_t) override
{}
43 void writeSignedVarInt(int64_t value
) override
{}
44 void writeAPIntWithKnownWidth(const APInt
&value
) override
{}
45 void writeAPFloatWithKnownSemantics(const APFloat
&value
) override
{}
46 void writeOwnedString(StringRef
) override
{
47 // TODO: It might be nice to prenumber strings and sort by the number of
48 // references. This could potentially be useful for optimizing things like
51 void writeOwnedBlob(ArrayRef
<char> blob
) override
{}
52 void writeOwnedBool(bool value
) override
{}
54 int64_t getBytecodeVersion() const override
{
55 return state
.getDesiredBytecodeVersion();
58 FailureOr
<const DialectVersion
*>
59 getDialectVersion(StringRef dialectName
) const override
{
60 auto dialectEntry
= dialectVersionMap
.find(dialectName
);
61 if (dialectEntry
== dialectVersionMap
.end())
63 return dialectEntry
->getValue().get();
66 /// The parent numbering state that is populated by this writer.
67 IRNumberingState
&state
;
69 /// A map containing dialect version information for each dialect to emit.
70 llvm::StringMap
<std::unique_ptr
<DialectVersion
>> &dialectVersionMap
;
//===----------------------------------------------------------------------===//
// IR Numbering
//===----------------------------------------------------------------------===//
77 /// Group and sort the elements of the given range by their parent dialect. This
78 /// grouping is applied to sub-sections of the ranged defined by how many bytes
79 /// it takes to encode a varint index to that sub-section.
81 static void groupByDialectPerByte(T range
) {
85 // A functor used to sort by a given dialect, with a desired dialect to be
86 // ordered first (to better enable sharing of dialects across byte groups).
87 auto sortByDialect
= [](unsigned dialectToOrderFirst
, const auto &lhs
,
89 if (lhs
->dialect
->number
== dialectToOrderFirst
)
90 return rhs
->dialect
->number
!= dialectToOrderFirst
;
91 if (rhs
->dialect
->number
== dialectToOrderFirst
)
93 return lhs
->dialect
->number
< rhs
->dialect
->number
;
96 unsigned dialectToOrderFirst
= 0;
97 size_t elementsInByteGroup
= 0;
98 auto iterRange
= range
;
99 for (unsigned i
= 1; i
< 9; ++i
) {
100 // Update the number of elements in the current byte grouping. Reminder
101 // that varint encodes 7-bits per byte, so that's how we compute the
102 // number of elements in each byte grouping.
103 elementsInByteGroup
= (1ULL << (7ULL * i
)) - elementsInByteGroup
;
105 // Slice out the sub-set of elements that are in the current byte grouping
107 auto byteSubRange
= iterRange
.take_front(elementsInByteGroup
);
108 iterRange
= iterRange
.drop_front(byteSubRange
.size());
110 // Sort the sub range for this byte.
111 llvm::stable_sort(byteSubRange
, [&](const auto &lhs
, const auto &rhs
) {
112 return sortByDialect(dialectToOrderFirst
, lhs
, rhs
);
115 // Update the dialect to order first to be the dialect at the end of the
116 // current grouping. This seeks to allow larger dialect groupings across
118 dialectToOrderFirst
= byteSubRange
.back()->dialect
->number
;
120 // If the data range is now empty, we are done.
121 if (iterRange
.empty())
125 // Assign the entry numbers based on the sort order.
126 for (auto [idx
, value
] : llvm::enumerate(range
))
130 IRNumberingState::IRNumberingState(Operation
*op
,
131 const BytecodeWriterConfig
&config
)
133 computeGlobalNumberingState(op
);
135 // Number the root operation.
138 // A worklist of region contexts to number and the next value id before that
140 SmallVector
<std::pair
<Region
*, unsigned>, 8> numberContext
;
142 // Functor to push the regions of the given operation onto the numbering
144 auto addOpRegionsToNumber
= [&](Operation
*op
) {
145 MutableArrayRef
<Region
> regions
= op
->getRegions();
149 // Isolated regions don't share value numbers with their parent, so we can
150 // start numbering these regions at zero.
151 unsigned opFirstValueID
= isIsolatedFromAbove(op
) ? 0 : nextValueID
;
152 for (Region
®ion
: regions
)
153 numberContext
.emplace_back(®ion
, opFirstValueID
);
155 addOpRegionsToNumber(op
);
157 // Iteratively process each of the nested regions.
158 while (!numberContext
.empty()) {
160 std::tie(region
, nextValueID
) = numberContext
.pop_back_val();
163 // Traverse into nested regions.
164 for (Operation
&op
: region
->getOps())
165 addOpRegionsToNumber(&op
);
168 // Number each of the dialects. For now this is just in the order they were
169 // found, given that the number of dialects on average is small enough to fit
170 // within a singly byte (128). If we ever have real world use cases that have
171 // a huge number of dialects, this could be made more intelligent.
172 for (auto [idx
, dialect
] : llvm::enumerate(dialects
))
173 dialect
.second
->number
= idx
;
175 // Number each of the recorded components within each dialect.
177 // First sort by ref count so that the most referenced elements are first. We
178 // try to bias more heavily used elements to the front. This allows for more
179 // frequently referenced things to be encoded using smaller varints.
180 auto sortByRefCountFn
= [](const auto &lhs
, const auto &rhs
) {
181 return lhs
->refCount
> rhs
->refCount
;
183 llvm::stable_sort(orderedAttrs
, sortByRefCountFn
);
184 llvm::stable_sort(orderedOpNames
, sortByRefCountFn
);
185 llvm::stable_sort(orderedTypes
, sortByRefCountFn
);
187 // After that, we apply a secondary ordering based on the parent dialect. This
188 // ordering is applied to sub-sections of the element list defined by how many
189 // bytes it takes to encode a varint index to that sub-section. This allows
190 // for more efficiently encoding components of the same dialect (e.g. we only
191 // have to encode the dialect reference once).
192 groupByDialectPerByte(llvm::MutableArrayRef(orderedAttrs
));
193 groupByDialectPerByte(llvm::MutableArrayRef(orderedOpNames
));
194 groupByDialectPerByte(llvm::MutableArrayRef(orderedTypes
));
196 // Finalize the numbering of the dialect resources.
197 finalizeDialectResourceNumberings(op
);
200 void IRNumberingState::computeGlobalNumberingState(Operation
*rootOp
) {
201 // A simple state struct tracking data used when walking operations.
203 /// The operation currently being walked.
206 /// The numbering of the operation.
207 OperationNumbering
*numbering
;
209 /// A flag indicating if the current state or one of its parents has
210 /// unresolved isolation status. This is tracked separately from the
211 /// isIsolatedFromAbove bit on `numbering` because we need to be able to
212 /// handle the given case:
218 /// // Here we mark `inner.op` as not isolated. Note `middle.op`
219 /// // isn't known not isolated yet.
222 /// // Here inner.op is already known to be non-isolated, but
223 /// // `middle.op` is now also discovered to be non-isolated.
228 bool hasUnresolvedIsolation
;
231 // Compute a global operation ID numbering according to the pre-order walk of
232 // the IR. This is used as reference to construct use-list orders.
233 unsigned operationID
= 0;
235 // Walk each of the operations within the IR, tracking a stack of operations
236 // as we recurse into nested regions. This walk method hooks in at two stages
240 // Here we generate a numbering for the operation and push it onto the
241 // stack if it has regions. We also compute the isolation status of parent
242 // regions at this stage. This is done by checking the parent regions of
243 // operands used by the operation, and marking each region between the
244 // the operand region and the current as not isolated. See
245 // StackState::hasUnresolvedIsolation above for an example.
248 // Here we pop the operation from the stack, and if it hasn't been marked
249 // as non-isolated, we mark it as so. A non-isolated use would have been
250 // found while walking the regions, so it is safe to mark the operation at
253 SmallVector
<StackState
> opStack
;
254 rootOp
->walk([&](Operation
*op
, const WalkStage
&stage
) {
255 // After visiting all nested regions, we pop the operation from the stack.
256 if (op
->getNumRegions() && stage
.isAfterAllRegions()) {
257 // If no non-isolated uses were found, we can safely mark this operation
258 // as isolated from above.
259 OperationNumbering
*numbering
= opStack
.pop_back_val().numbering
;
260 if (!numbering
->isIsolatedFromAbove
.has_value())
261 numbering
->isIsolatedFromAbove
= true;
265 // When visiting before nested regions, we process "IsolatedFromAbove"
266 // checks and compute the number for this operation.
267 if (!stage
.isBeforeAllRegions())
269 // Update the isolation status of parent regions if any have yet to be
271 if (!opStack
.empty() && opStack
.back().hasUnresolvedIsolation
) {
272 Region
*parentRegion
= op
->getParentRegion();
273 for (Value operand
: op
->getOperands()) {
274 Region
*operandRegion
= operand
.getParentRegion();
275 if (operandRegion
== parentRegion
)
277 // We've found a use of an operand outside of the current region,
278 // walk the operation stack searching for the parent operation,
279 // marking every region on the way as not isolated.
280 Operation
*operandContainerOp
= operandRegion
->getParentOp();
281 auto it
= std::find_if(
282 opStack
.rbegin(), opStack
.rend(), [=](const StackState
&it
) {
283 // We only need to mark up to the container region, or the first
284 // that has an unresolved status.
285 return !it
.hasUnresolvedIsolation
|| it
.op
== operandContainerOp
;
287 assert(it
!= opStack
.rend() && "expected to find the container");
288 for (auto &state
: llvm::make_range(opStack
.rbegin(), it
)) {
289 // If we stopped at a region that knows its isolation status, we can
290 // stop updating the isolation status for the parent regions.
291 state
.hasUnresolvedIsolation
= it
->hasUnresolvedIsolation
;
292 state
.numbering
->isIsolatedFromAbove
= false;
297 // Compute the number for this op and push it onto the stack.
299 new (opAllocator
.Allocate()) OperationNumbering(operationID
++);
300 if (op
->hasTrait
<OpTrait::IsIsolatedFromAbove
>())
301 numbering
->isIsolatedFromAbove
= true;
302 operations
.try_emplace(op
, numbering
);
303 if (op
->getNumRegions()) {
304 opStack
.emplace_back(StackState
{
305 op
, numbering
, !numbering
->isIsolatedFromAbove
.has_value()});
310 void IRNumberingState::number(Attribute attr
) {
311 auto it
= attrs
.insert({attr
, nullptr});
313 ++it
.first
->second
->refCount
;
316 auto *numbering
= new (attrAllocator
.Allocate()) AttributeNumbering(attr
);
317 it
.first
->second
= numbering
;
318 orderedAttrs
.push_back(numbering
);
320 // Check for OpaqueAttr, which is a dialect-specific attribute that didn't
321 // have a registered dialect when it got created. We don't want to encode this
322 // as the builtin OpaqueAttr, we want to encode it as if the dialect was
324 if (OpaqueAttr opaqueAttr
= dyn_cast
<OpaqueAttr
>(attr
)) {
325 numbering
->dialect
= &numberDialect(opaqueAttr
.getDialectNamespace());
328 numbering
->dialect
= &numberDialect(&attr
.getDialect());
330 // If this attribute will be emitted using the bytecode format, perform a
331 // dummy writing to number any nested components.
332 // TODO: We don't allow custom encodings for mutable attributes right now.
333 if (!attr
.hasTrait
<AttributeTrait::IsMutable
>()) {
334 // Try overriding emission with callbacks.
335 for (const auto &callback
: config
.getAttributeWriterCallbacks()) {
336 NumberingDialectWriter
writer(*this, config
.getDialectVersionMap());
337 // The client has the ability to override the group name through the
339 std::optional
<StringRef
> groupNameOverride
;
340 if (succeeded(callback
->write(attr
, groupNameOverride
, writer
))) {
341 if (groupNameOverride
.has_value())
342 numbering
->dialect
= &numberDialect(*groupNameOverride
);
347 if (const auto *interface
= numbering
->dialect
->interface
) {
348 NumberingDialectWriter
writer(*this, config
.getDialectVersionMap());
349 if (succeeded(interface
->writeAttribute(attr
, writer
)))
353 // If this attribute will be emitted using the fallback, number the nested
354 // dialect resources. We don't number everything (e.g. no nested
355 // attributes/types), because we don't want to encode things we won't decode
356 // (the textual format can't really share much).
357 AsmState
tempState(attr
.getContext());
358 llvm::raw_null_ostream dummyOS
;
359 attr
.print(dummyOS
, tempState
);
361 // Number the used dialect resources.
362 for (const auto &it
: tempState
.getDialectResources())
363 number(it
.getFirst(), it
.getSecond().getArrayRef());
366 void IRNumberingState::number(Block
&block
) {
367 // Number the arguments of the block.
368 for (BlockArgument arg
: block
.getArguments()) {
369 valueIDs
.try_emplace(arg
, nextValueID
++);
370 number(arg
.getLoc());
371 number(arg
.getType());
374 // Number the operations in this block.
375 unsigned &numOps
= blockOperationCounts
[&block
];
376 for (Operation
&op
: block
) {
382 auto IRNumberingState::numberDialect(Dialect
*dialect
) -> DialectNumbering
& {
383 DialectNumbering
*&numbering
= registeredDialects
[dialect
];
385 numbering
= &numberDialect(dialect
->getNamespace());
386 numbering
->interface
= dyn_cast
<BytecodeDialectInterface
>(dialect
);
387 numbering
->asmInterface
= dyn_cast
<OpAsmDialectInterface
>(dialect
);
392 auto IRNumberingState::numberDialect(StringRef dialect
) -> DialectNumbering
& {
393 DialectNumbering
*&numbering
= dialects
[dialect
];
395 numbering
= new (dialectAllocator
.Allocate())
396 DialectNumbering(dialect
, dialects
.size() - 1);
401 void IRNumberingState::number(Region
®ion
) {
404 size_t firstValueID
= nextValueID
;
406 // Number the blocks within this region.
407 size_t blockCount
= 0;
408 for (auto it
: llvm::enumerate(region
)) {
409 blockIDs
.try_emplace(&it
.value(), it
.index());
414 // Remember the number of blocks and values in this region.
415 regionBlockValueCounts
.try_emplace(®ion
, blockCount
,
416 nextValueID
- firstValueID
);
419 void IRNumberingState::number(Operation
&op
) {
420 // Number the components of an operation that won't be numbered elsewhere
421 // (e.g. we don't number operands, regions, or successors here).
422 number(op
.getName());
423 for (OpResult result
: op
.getResults()) {
424 valueIDs
.try_emplace(result
, nextValueID
++);
425 number(result
.getType());
428 // Prior to a version with native property encoding, or when properties are
429 // not used, we need to number also the merged dictionary containing both the
430 // inherent and discardable attribute.
431 DictionaryAttr dictAttr
;
432 if (config
.getDesiredBytecodeVersion() >= bytecode::kNativePropertiesEncoding
)
433 dictAttr
= op
.getRawDictionaryAttrs();
435 dictAttr
= op
.getAttrDictionary();
436 // Only number the operation's dictionary if it isn't empty.
437 if (!dictAttr
.empty())
440 // Visit the operation properties (if any) to make sure referenced attributes
442 if (config
.getDesiredBytecodeVersion() >=
443 bytecode::kNativePropertiesEncoding
&&
444 op
.getPropertiesStorageSize()) {
445 if (op
.isRegistered()) {
446 // Operation that have properties *must* implement this interface.
447 auto iface
= cast
<BytecodeOpInterface
>(op
);
448 NumberingDialectWriter
writer(*this, config
.getDialectVersionMap());
449 iface
.writeProperties(writer
);
451 // Unregistered op are storing properties as an optional attribute.
452 if (Attribute prop
= *op
.getPropertiesStorage().as
<Attribute
*>())
460 void IRNumberingState::number(OperationName opName
) {
461 OpNameNumbering
*&numbering
= opNames
[opName
];
463 ++numbering
->refCount
;
466 DialectNumbering
*dialectNumber
= nullptr;
467 if (Dialect
*dialect
= opName
.getDialect())
468 dialectNumber
= &numberDialect(dialect
);
470 dialectNumber
= &numberDialect(opName
.getDialectNamespace());
473 new (opNameAllocator
.Allocate()) OpNameNumbering(dialectNumber
, opName
);
474 orderedOpNames
.push_back(numbering
);
477 void IRNumberingState::number(Type type
) {
478 auto it
= types
.insert({type
, nullptr});
480 ++it
.first
->second
->refCount
;
483 auto *numbering
= new (typeAllocator
.Allocate()) TypeNumbering(type
);
484 it
.first
->second
= numbering
;
485 orderedTypes
.push_back(numbering
);
487 // Check for OpaqueType, which is a dialect-specific type that didn't have a
488 // registered dialect when it got created. We don't want to encode this as the
489 // builtin OpaqueType, we want to encode it as if the dialect was actually
491 if (OpaqueType opaqueType
= dyn_cast
<OpaqueType
>(type
)) {
492 numbering
->dialect
= &numberDialect(opaqueType
.getDialectNamespace());
495 numbering
->dialect
= &numberDialect(&type
.getDialect());
497 // If this type will be emitted using the bytecode format, perform a dummy
498 // writing to number any nested components.
499 // TODO: We don't allow custom encodings for mutable types right now.
500 if (!type
.hasTrait
<TypeTrait::IsMutable
>()) {
501 // Try overriding emission with callbacks.
502 for (const auto &callback
: config
.getTypeWriterCallbacks()) {
503 NumberingDialectWriter
writer(*this, config
.getDialectVersionMap());
504 // The client has the ability to override the group name through the
506 std::optional
<StringRef
> groupNameOverride
;
507 if (succeeded(callback
->write(type
, groupNameOverride
, writer
))) {
508 if (groupNameOverride
.has_value())
509 numbering
->dialect
= &numberDialect(*groupNameOverride
);
514 // If this attribute will be emitted using the bytecode format, perform a
515 // dummy writing to number any nested components.
516 if (const auto *interface
= numbering
->dialect
->interface
) {
517 NumberingDialectWriter
writer(*this, config
.getDialectVersionMap());
518 if (succeeded(interface
->writeType(type
, writer
)))
522 // If this type will be emitted using the fallback, number the nested dialect
523 // resources. We don't number everything (e.g. no nested attributes/types),
524 // because we don't want to encode things we won't decode (the textual format
525 // can't really share much).
526 AsmState
tempState(type
.getContext());
527 llvm::raw_null_ostream dummyOS
;
528 type
.print(dummyOS
, tempState
);
530 // Number the used dialect resources.
531 for (const auto &it
: tempState
.getDialectResources())
532 number(it
.getFirst(), it
.getSecond().getArrayRef());
535 void IRNumberingState::number(Dialect
*dialect
,
536 ArrayRef
<AsmDialectResourceHandle
> resources
) {
537 DialectNumbering
&dialectNumber
= numberDialect(dialect
);
539 dialectNumber
.asmInterface
&&
540 "expected dialect owning a resource to implement OpAsmDialectInterface");
542 for (const auto &resource
: resources
) {
543 // Check if this is a newly seen resource.
544 if (!dialectNumber
.resources
.insert(resource
))
548 new (resourceAllocator
.Allocate()) DialectResourceNumbering(
549 dialectNumber
.asmInterface
->getResourceKey(resource
));
550 dialectNumber
.resourceMap
.insert({numbering
->key
, numbering
});
551 dialectResources
.try_emplace(resource
, numbering
);
555 int64_t IRNumberingState::getDesiredBytecodeVersion() const {
556 return config
.getDesiredBytecodeVersion();
560 /// A dummy resource builder used to number dialect resources.
561 struct NumberingResourceBuilder
: public AsmResourceBuilder
{
562 NumberingResourceBuilder(DialectNumbering
*dialect
, unsigned &nextResourceID
)
563 : dialect(dialect
), nextResourceID(nextResourceID
) {}
564 ~NumberingResourceBuilder() override
= default;
566 void buildBlob(StringRef key
, ArrayRef
<char>, uint32_t) final
{
569 void buildBool(StringRef key
, bool) final
{ numberEntry(key
); }
570 void buildString(StringRef key
, StringRef
) final
{
571 // TODO: We could pre-number the value string here as well.
575 /// Number the dialect entry for the given key.
576 void numberEntry(StringRef key
) {
577 // TODO: We could pre-number resource key strings here as well.
579 auto *it
= dialect
->resourceMap
.find(key
);
580 if (it
!= dialect
->resourceMap
.end()) {
581 it
->second
->number
= nextResourceID
++;
582 it
->second
->isDeclaration
= false;
586 DialectNumbering
*dialect
;
587 unsigned &nextResourceID
;
591 void IRNumberingState::finalizeDialectResourceNumberings(Operation
*rootOp
) {
592 unsigned nextResourceID
= 0;
593 for (DialectNumbering
&dialect
: getDialects()) {
594 if (!dialect
.asmInterface
)
596 NumberingResourceBuilder
entryBuilder(&dialect
, nextResourceID
);
597 dialect
.asmInterface
->buildResources(rootOp
, dialect
.resources
,
600 // Number any resources that weren't added by the dialect. This can happen
601 // if there was no backing data to the resource, but we still want these
602 // resource references to roundtrip, so we number them and indicate that the
604 for (const auto &it
: dialect
.resourceMap
)
605 if (it
.second
->isDeclaration
)
606 it
.second
->number
= nextResourceID
++;