[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / lib / Bytecode / Writer / IRNumbering.cpp
blob1bc02e17215732bcf69f5aace2ac5fd6a4161495
1 //===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "IRNumbering.h"
10 #include "mlir/Bytecode/BytecodeImplementation.h"
11 #include "mlir/Bytecode/BytecodeOpInterface.h"
12 #include "mlir/Bytecode/BytecodeWriter.h"
13 #include "mlir/Bytecode/Encoding.h"
14 #include "mlir/IR/AsmState.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/OpDefinition.h"
18 using namespace mlir;
19 using namespace mlir::bytecode::detail;
21 //===----------------------------------------------------------------------===//
22 // NumberingDialectWriter
23 //===----------------------------------------------------------------------===//
25 struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter {
26 NumberingDialectWriter(
27 IRNumberingState &state,
28 llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap)
29 : state(state), dialectVersionMap(dialectVersionMap) {}
31 void writeAttribute(Attribute attr) override { state.number(attr); }
32 void writeOptionalAttribute(Attribute attr) override {
33 if (attr)
34 state.number(attr);
36 void writeType(Type type) override { state.number(type); }
37 void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
38 state.number(resource.getDialect(), resource);
41 /// Stubbed out methods that are not used for numbering.
42 void writeVarInt(uint64_t) override {}
43 void writeSignedVarInt(int64_t value) override {}
44 void writeAPIntWithKnownWidth(const APInt &value) override {}
45 void writeAPFloatWithKnownSemantics(const APFloat &value) override {}
46 void writeOwnedString(StringRef) override {
47 // TODO: It might be nice to prenumber strings and sort by the number of
48 // references. This could potentially be useful for optimizing things like
49 // file locations.
51 void writeOwnedBlob(ArrayRef<char> blob) override {}
52 void writeOwnedBool(bool value) override {}
54 int64_t getBytecodeVersion() const override {
55 return state.getDesiredBytecodeVersion();
58 FailureOr<const DialectVersion *>
59 getDialectVersion(StringRef dialectName) const override {
60 auto dialectEntry = dialectVersionMap.find(dialectName);
61 if (dialectEntry == dialectVersionMap.end())
62 return failure();
63 return dialectEntry->getValue().get();
66 /// The parent numbering state that is populated by this writer.
67 IRNumberingState &state;
69 /// A map containing dialect version information for each dialect to emit.
70 llvm::StringMap<std::unique_ptr<DialectVersion>> &dialectVersionMap;
73 //===----------------------------------------------------------------------===//
74 // IR Numbering
75 //===----------------------------------------------------------------------===//
77 /// Group and sort the elements of the given range by their parent dialect. This
78 /// grouping is applied to sub-sections of the ranged defined by how many bytes
79 /// it takes to encode a varint index to that sub-section.
80 template <typename T>
81 static void groupByDialectPerByte(T range) {
82 if (range.empty())
83 return;
85 // A functor used to sort by a given dialect, with a desired dialect to be
86 // ordered first (to better enable sharing of dialects across byte groups).
87 auto sortByDialect = [](unsigned dialectToOrderFirst, const auto &lhs,
88 const auto &rhs) {
89 if (lhs->dialect->number == dialectToOrderFirst)
90 return rhs->dialect->number != dialectToOrderFirst;
91 if (rhs->dialect->number == dialectToOrderFirst)
92 return false;
93 return lhs->dialect->number < rhs->dialect->number;
96 unsigned dialectToOrderFirst = 0;
97 size_t elementsInByteGroup = 0;
98 auto iterRange = range;
99 for (unsigned i = 1; i < 9; ++i) {
100 // Update the number of elements in the current byte grouping. Reminder
101 // that varint encodes 7-bits per byte, so that's how we compute the
102 // number of elements in each byte grouping.
103 elementsInByteGroup = (1ULL << (7ULL * i)) - elementsInByteGroup;
105 // Slice out the sub-set of elements that are in the current byte grouping
106 // to be sorted.
107 auto byteSubRange = iterRange.take_front(elementsInByteGroup);
108 iterRange = iterRange.drop_front(byteSubRange.size());
110 // Sort the sub range for this byte.
111 llvm::stable_sort(byteSubRange, [&](const auto &lhs, const auto &rhs) {
112 return sortByDialect(dialectToOrderFirst, lhs, rhs);
115 // Update the dialect to order first to be the dialect at the end of the
116 // current grouping. This seeks to allow larger dialect groupings across
117 // byte boundaries.
118 dialectToOrderFirst = byteSubRange.back()->dialect->number;
120 // If the data range is now empty, we are done.
121 if (iterRange.empty())
122 break;
125 // Assign the entry numbers based on the sort order.
126 for (auto [idx, value] : llvm::enumerate(range))
127 value->number = idx;
130 IRNumberingState::IRNumberingState(Operation *op,
131 const BytecodeWriterConfig &config)
132 : config(config) {
133 computeGlobalNumberingState(op);
135 // Number the root operation.
136 number(*op);
138 // A worklist of region contexts to number and the next value id before that
139 // region.
140 SmallVector<std::pair<Region *, unsigned>, 8> numberContext;
142 // Functor to push the regions of the given operation onto the numbering
143 // context.
144 auto addOpRegionsToNumber = [&](Operation *op) {
145 MutableArrayRef<Region> regions = op->getRegions();
146 if (regions.empty())
147 return;
149 // Isolated regions don't share value numbers with their parent, so we can
150 // start numbering these regions at zero.
151 unsigned opFirstValueID = isIsolatedFromAbove(op) ? 0 : nextValueID;
152 for (Region &region : regions)
153 numberContext.emplace_back(&region, opFirstValueID);
155 addOpRegionsToNumber(op);
157 // Iteratively process each of the nested regions.
158 while (!numberContext.empty()) {
159 Region *region;
160 std::tie(region, nextValueID) = numberContext.pop_back_val();
161 number(*region);
163 // Traverse into nested regions.
164 for (Operation &op : region->getOps())
165 addOpRegionsToNumber(&op);
168 // Number each of the dialects. For now this is just in the order they were
169 // found, given that the number of dialects on average is small enough to fit
170 // within a singly byte (128). If we ever have real world use cases that have
171 // a huge number of dialects, this could be made more intelligent.
172 for (auto [idx, dialect] : llvm::enumerate(dialects))
173 dialect.second->number = idx;
175 // Number each of the recorded components within each dialect.
177 // First sort by ref count so that the most referenced elements are first. We
178 // try to bias more heavily used elements to the front. This allows for more
179 // frequently referenced things to be encoded using smaller varints.
180 auto sortByRefCountFn = [](const auto &lhs, const auto &rhs) {
181 return lhs->refCount > rhs->refCount;
183 llvm::stable_sort(orderedAttrs, sortByRefCountFn);
184 llvm::stable_sort(orderedOpNames, sortByRefCountFn);
185 llvm::stable_sort(orderedTypes, sortByRefCountFn);
187 // After that, we apply a secondary ordering based on the parent dialect. This
188 // ordering is applied to sub-sections of the element list defined by how many
189 // bytes it takes to encode a varint index to that sub-section. This allows
190 // for more efficiently encoding components of the same dialect (e.g. we only
191 // have to encode the dialect reference once).
192 groupByDialectPerByte(llvm::MutableArrayRef(orderedAttrs));
193 groupByDialectPerByte(llvm::MutableArrayRef(orderedOpNames));
194 groupByDialectPerByte(llvm::MutableArrayRef(orderedTypes));
196 // Finalize the numbering of the dialect resources.
197 finalizeDialectResourceNumberings(op);
200 void IRNumberingState::computeGlobalNumberingState(Operation *rootOp) {
201 // A simple state struct tracking data used when walking operations.
202 struct StackState {
203 /// The operation currently being walked.
204 Operation *op;
206 /// The numbering of the operation.
207 OperationNumbering *numbering;
209 /// A flag indicating if the current state or one of its parents has
210 /// unresolved isolation status. This is tracked separately from the
211 /// isIsolatedFromAbove bit on `numbering` because we need to be able to
212 /// handle the given case:
213 /// top.op {
214 /// %value = ...
215 /// middle.op {
216 /// %value2 = ...
217 /// inner.op {
218 /// // Here we mark `inner.op` as not isolated. Note `middle.op`
219 /// // isn't known not isolated yet.
220 /// use.op %value2
222 /// // Here inner.op is already known to be non-isolated, but
223 /// // `middle.op` is now also discovered to be non-isolated.
224 /// use.op %value
225 /// }
226 /// }
227 /// }
228 bool hasUnresolvedIsolation;
231 // Compute a global operation ID numbering according to the pre-order walk of
232 // the IR. This is used as reference to construct use-list orders.
233 unsigned operationID = 0;
235 // Walk each of the operations within the IR, tracking a stack of operations
236 // as we recurse into nested regions. This walk method hooks in at two stages
237 // during the walk:
239 // BeforeAllRegions:
240 // Here we generate a numbering for the operation and push it onto the
241 // stack if it has regions. We also compute the isolation status of parent
242 // regions at this stage. This is done by checking the parent regions of
243 // operands used by the operation, and marking each region between the
244 // the operand region and the current as not isolated. See
245 // StackState::hasUnresolvedIsolation above for an example.
247 // AfterAllRegions:
248 // Here we pop the operation from the stack, and if it hasn't been marked
249 // as non-isolated, we mark it as so. A non-isolated use would have been
250 // found while walking the regions, so it is safe to mark the operation at
251 // this point.
253 SmallVector<StackState> opStack;
254 rootOp->walk([&](Operation *op, const WalkStage &stage) {
255 // After visiting all nested regions, we pop the operation from the stack.
256 if (op->getNumRegions() && stage.isAfterAllRegions()) {
257 // If no non-isolated uses were found, we can safely mark this operation
258 // as isolated from above.
259 OperationNumbering *numbering = opStack.pop_back_val().numbering;
260 if (!numbering->isIsolatedFromAbove.has_value())
261 numbering->isIsolatedFromAbove = true;
262 return;
265 // When visiting before nested regions, we process "IsolatedFromAbove"
266 // checks and compute the number for this operation.
267 if (!stage.isBeforeAllRegions())
268 return;
269 // Update the isolation status of parent regions if any have yet to be
270 // resolved.
271 if (!opStack.empty() && opStack.back().hasUnresolvedIsolation) {
272 Region *parentRegion = op->getParentRegion();
273 for (Value operand : op->getOperands()) {
274 Region *operandRegion = operand.getParentRegion();
275 if (operandRegion == parentRegion)
276 continue;
277 // We've found a use of an operand outside of the current region,
278 // walk the operation stack searching for the parent operation,
279 // marking every region on the way as not isolated.
280 Operation *operandContainerOp = operandRegion->getParentOp();
281 auto it = std::find_if(
282 opStack.rbegin(), opStack.rend(), [=](const StackState &it) {
283 // We only need to mark up to the container region, or the first
284 // that has an unresolved status.
285 return !it.hasUnresolvedIsolation || it.op == operandContainerOp;
287 assert(it != opStack.rend() && "expected to find the container");
288 for (auto &state : llvm::make_range(opStack.rbegin(), it)) {
289 // If we stopped at a region that knows its isolation status, we can
290 // stop updating the isolation status for the parent regions.
291 state.hasUnresolvedIsolation = it->hasUnresolvedIsolation;
292 state.numbering->isIsolatedFromAbove = false;
297 // Compute the number for this op and push it onto the stack.
298 auto *numbering =
299 new (opAllocator.Allocate()) OperationNumbering(operationID++);
300 if (op->hasTrait<OpTrait::IsIsolatedFromAbove>())
301 numbering->isIsolatedFromAbove = true;
302 operations.try_emplace(op, numbering);
303 if (op->getNumRegions()) {
304 opStack.emplace_back(StackState{
305 op, numbering, !numbering->isIsolatedFromAbove.has_value()});
310 void IRNumberingState::number(Attribute attr) {
311 auto it = attrs.insert({attr, nullptr});
312 if (!it.second) {
313 ++it.first->second->refCount;
314 return;
316 auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr);
317 it.first->second = numbering;
318 orderedAttrs.push_back(numbering);
320 // Check for OpaqueAttr, which is a dialect-specific attribute that didn't
321 // have a registered dialect when it got created. We don't want to encode this
322 // as the builtin OpaqueAttr, we want to encode it as if the dialect was
323 // actually loaded.
324 if (OpaqueAttr opaqueAttr = dyn_cast<OpaqueAttr>(attr)) {
325 numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
326 return;
328 numbering->dialect = &numberDialect(&attr.getDialect());
330 // If this attribute will be emitted using the bytecode format, perform a
331 // dummy writing to number any nested components.
332 // TODO: We don't allow custom encodings for mutable attributes right now.
333 if (!attr.hasTrait<AttributeTrait::IsMutable>()) {
334 // Try overriding emission with callbacks.
335 for (const auto &callback : config.getAttributeWriterCallbacks()) {
336 NumberingDialectWriter writer(*this, config.getDialectVersionMap());
337 // The client has the ability to override the group name through the
338 // callback.
339 std::optional<StringRef> groupNameOverride;
340 if (succeeded(callback->write(attr, groupNameOverride, writer))) {
341 if (groupNameOverride.has_value())
342 numbering->dialect = &numberDialect(*groupNameOverride);
343 return;
347 if (const auto *interface = numbering->dialect->interface) {
348 NumberingDialectWriter writer(*this, config.getDialectVersionMap());
349 if (succeeded(interface->writeAttribute(attr, writer)))
350 return;
353 // If this attribute will be emitted using the fallback, number the nested
354 // dialect resources. We don't number everything (e.g. no nested
355 // attributes/types), because we don't want to encode things we won't decode
356 // (the textual format can't really share much).
357 AsmState tempState(attr.getContext());
358 llvm::raw_null_ostream dummyOS;
359 attr.print(dummyOS, tempState);
361 // Number the used dialect resources.
362 for (const auto &it : tempState.getDialectResources())
363 number(it.getFirst(), it.getSecond().getArrayRef());
366 void IRNumberingState::number(Block &block) {
367 // Number the arguments of the block.
368 for (BlockArgument arg : block.getArguments()) {
369 valueIDs.try_emplace(arg, nextValueID++);
370 number(arg.getLoc());
371 number(arg.getType());
374 // Number the operations in this block.
375 unsigned &numOps = blockOperationCounts[&block];
376 for (Operation &op : block) {
377 number(op);
378 ++numOps;
382 auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
383 DialectNumbering *&numbering = registeredDialects[dialect];
384 if (!numbering) {
385 numbering = &numberDialect(dialect->getNamespace());
386 numbering->interface = dyn_cast<BytecodeDialectInterface>(dialect);
387 numbering->asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
389 return *numbering;
392 auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & {
393 DialectNumbering *&numbering = dialects[dialect];
394 if (!numbering) {
395 numbering = new (dialectAllocator.Allocate())
396 DialectNumbering(dialect, dialects.size() - 1);
398 return *numbering;
401 void IRNumberingState::number(Region &region) {
402 if (region.empty())
403 return;
404 size_t firstValueID = nextValueID;
406 // Number the blocks within this region.
407 size_t blockCount = 0;
408 for (auto it : llvm::enumerate(region)) {
409 blockIDs.try_emplace(&it.value(), it.index());
410 number(it.value());
411 ++blockCount;
414 // Remember the number of blocks and values in this region.
415 regionBlockValueCounts.try_emplace(&region, blockCount,
416 nextValueID - firstValueID);
419 void IRNumberingState::number(Operation &op) {
420 // Number the components of an operation that won't be numbered elsewhere
421 // (e.g. we don't number operands, regions, or successors here).
422 number(op.getName());
423 for (OpResult result : op.getResults()) {
424 valueIDs.try_emplace(result, nextValueID++);
425 number(result.getType());
428 // Prior to a version with native property encoding, or when properties are
429 // not used, we need to number also the merged dictionary containing both the
430 // inherent and discardable attribute.
431 DictionaryAttr dictAttr;
432 if (config.getDesiredBytecodeVersion() >= bytecode::kNativePropertiesEncoding)
433 dictAttr = op.getRawDictionaryAttrs();
434 else
435 dictAttr = op.getAttrDictionary();
436 // Only number the operation's dictionary if it isn't empty.
437 if (!dictAttr.empty())
438 number(dictAttr);
440 // Visit the operation properties (if any) to make sure referenced attributes
441 // are numbered.
442 if (config.getDesiredBytecodeVersion() >=
443 bytecode::kNativePropertiesEncoding &&
444 op.getPropertiesStorageSize()) {
445 if (op.isRegistered()) {
446 // Operation that have properties *must* implement this interface.
447 auto iface = cast<BytecodeOpInterface>(op);
448 NumberingDialectWriter writer(*this, config.getDialectVersionMap());
449 iface.writeProperties(writer);
450 } else {
451 // Unregistered op are storing properties as an optional attribute.
452 if (Attribute prop = *op.getPropertiesStorage().as<Attribute *>())
453 number(prop);
457 number(op.getLoc());
460 void IRNumberingState::number(OperationName opName) {
461 OpNameNumbering *&numbering = opNames[opName];
462 if (numbering) {
463 ++numbering->refCount;
464 return;
466 DialectNumbering *dialectNumber = nullptr;
467 if (Dialect *dialect = opName.getDialect())
468 dialectNumber = &numberDialect(dialect);
469 else
470 dialectNumber = &numberDialect(opName.getDialectNamespace());
472 numbering =
473 new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
474 orderedOpNames.push_back(numbering);
477 void IRNumberingState::number(Type type) {
478 auto it = types.insert({type, nullptr});
479 if (!it.second) {
480 ++it.first->second->refCount;
481 return;
483 auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type);
484 it.first->second = numbering;
485 orderedTypes.push_back(numbering);
487 // Check for OpaqueType, which is a dialect-specific type that didn't have a
488 // registered dialect when it got created. We don't want to encode this as the
489 // builtin OpaqueType, we want to encode it as if the dialect was actually
490 // loaded.
491 if (OpaqueType opaqueType = dyn_cast<OpaqueType>(type)) {
492 numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
493 return;
495 numbering->dialect = &numberDialect(&type.getDialect());
497 // If this type will be emitted using the bytecode format, perform a dummy
498 // writing to number any nested components.
499 // TODO: We don't allow custom encodings for mutable types right now.
500 if (!type.hasTrait<TypeTrait::IsMutable>()) {
501 // Try overriding emission with callbacks.
502 for (const auto &callback : config.getTypeWriterCallbacks()) {
503 NumberingDialectWriter writer(*this, config.getDialectVersionMap());
504 // The client has the ability to override the group name through the
505 // callback.
506 std::optional<StringRef> groupNameOverride;
507 if (succeeded(callback->write(type, groupNameOverride, writer))) {
508 if (groupNameOverride.has_value())
509 numbering->dialect = &numberDialect(*groupNameOverride);
510 return;
514 // If this attribute will be emitted using the bytecode format, perform a
515 // dummy writing to number any nested components.
516 if (const auto *interface = numbering->dialect->interface) {
517 NumberingDialectWriter writer(*this, config.getDialectVersionMap());
518 if (succeeded(interface->writeType(type, writer)))
519 return;
522 // If this type will be emitted using the fallback, number the nested dialect
523 // resources. We don't number everything (e.g. no nested attributes/types),
524 // because we don't want to encode things we won't decode (the textual format
525 // can't really share much).
526 AsmState tempState(type.getContext());
527 llvm::raw_null_ostream dummyOS;
528 type.print(dummyOS, tempState);
530 // Number the used dialect resources.
531 for (const auto &it : tempState.getDialectResources())
532 number(it.getFirst(), it.getSecond().getArrayRef());
535 void IRNumberingState::number(Dialect *dialect,
536 ArrayRef<AsmDialectResourceHandle> resources) {
537 DialectNumbering &dialectNumber = numberDialect(dialect);
538 assert(
539 dialectNumber.asmInterface &&
540 "expected dialect owning a resource to implement OpAsmDialectInterface");
542 for (const auto &resource : resources) {
543 // Check if this is a newly seen resource.
544 if (!dialectNumber.resources.insert(resource))
545 return;
547 auto *numbering =
548 new (resourceAllocator.Allocate()) DialectResourceNumbering(
549 dialectNumber.asmInterface->getResourceKey(resource));
550 dialectNumber.resourceMap.insert({numbering->key, numbering});
551 dialectResources.try_emplace(resource, numbering);
555 int64_t IRNumberingState::getDesiredBytecodeVersion() const {
556 return config.getDesiredBytecodeVersion();
559 namespace {
560 /// A dummy resource builder used to number dialect resources.
561 struct NumberingResourceBuilder : public AsmResourceBuilder {
562 NumberingResourceBuilder(DialectNumbering *dialect, unsigned &nextResourceID)
563 : dialect(dialect), nextResourceID(nextResourceID) {}
564 ~NumberingResourceBuilder() override = default;
566 void buildBlob(StringRef key, ArrayRef<char>, uint32_t) final {
567 numberEntry(key);
569 void buildBool(StringRef key, bool) final { numberEntry(key); }
570 void buildString(StringRef key, StringRef) final {
571 // TODO: We could pre-number the value string here as well.
572 numberEntry(key);
575 /// Number the dialect entry for the given key.
576 void numberEntry(StringRef key) {
577 // TODO: We could pre-number resource key strings here as well.
579 auto *it = dialect->resourceMap.find(key);
580 if (it != dialect->resourceMap.end()) {
581 it->second->number = nextResourceID++;
582 it->second->isDeclaration = false;
586 DialectNumbering *dialect;
587 unsigned &nextResourceID;
589 } // namespace
591 void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) {
592 unsigned nextResourceID = 0;
593 for (DialectNumbering &dialect : getDialects()) {
594 if (!dialect.asmInterface)
595 continue;
596 NumberingResourceBuilder entryBuilder(&dialect, nextResourceID);
597 dialect.asmInterface->buildResources(rootOp, dialect.resources,
598 entryBuilder);
600 // Number any resources that weren't added by the dialect. This can happen
601 // if there was no backing data to the resource, but we still want these
602 // resource references to roundtrip, so we number them and indicate that the
603 // data is missing.
604 for (const auto &it : dialect.resourceMap)
605 if (it.second->isDeclaration)
606 it.second->number = nextResourceID++;