//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Bytecode/BytecodeReader.h"
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Bytecode/Encoding.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"

#include <cassert>
#include <cstdint>
#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
34 #define DEBUG_TYPE "mlir-bytecode-reader"
38 /// Stringify the given section ID.
39 static std::string
toString(bytecode::Section::ID sectionID
) {
41 case bytecode::Section::kString
:
43 case bytecode::Section::kDialect
:
45 case bytecode::Section::kAttrType
:
46 return "AttrType (2)";
47 case bytecode::Section::kAttrTypeOffset
:
48 return "AttrTypeOffset (3)";
49 case bytecode::Section::kIR
:
51 case bytecode::Section::kResource
:
52 return "Resource (5)";
53 case bytecode::Section::kResourceOffset
:
54 return "ResourceOffset (6)";
55 case bytecode::Section::kDialectVersions
:
56 return "DialectVersions (7)";
57 case bytecode::Section::kProperties
:
58 return "Properties (8)";
60 return ("Unknown (" + Twine(static_cast<unsigned>(sectionID
)) + ")").str();
64 /// Returns true if the given top-level section ID is optional.
65 static bool isSectionOptional(bytecode::Section::ID sectionID
, int version
) {
67 case bytecode::Section::kString
:
68 case bytecode::Section::kDialect
:
69 case bytecode::Section::kAttrType
:
70 case bytecode::Section::kAttrTypeOffset
:
71 case bytecode::Section::kIR
:
73 case bytecode::Section::kResource
:
74 case bytecode::Section::kResourceOffset
:
75 case bytecode::Section::kDialectVersions
:
77 case bytecode::Section::kProperties
:
78 return version
< bytecode::kNativePropertiesEncoding
;
80 llvm_unreachable("unknown section ID");
//===----------------------------------------------------------------------===//
// EncodingReader
//===----------------------------------------------------------------------===//
89 class EncodingReader
{
91 explicit EncodingReader(ArrayRef
<uint8_t> contents
, Location fileLoc
)
92 : buffer(contents
), dataIt(buffer
.begin()), fileLoc(fileLoc
) {}
93 explicit EncodingReader(StringRef contents
, Location fileLoc
)
94 : EncodingReader({reinterpret_cast<const uint8_t *>(contents
.data()),
98 /// Returns true if the entire section has been read.
99 bool empty() const { return dataIt
== buffer
.end(); }
101 /// Returns the remaining size of the bytecode.
102 size_t size() const { return buffer
.end() - dataIt
; }
104 /// Align the current reader position to the specified alignment.
105 LogicalResult
alignTo(unsigned alignment
) {
106 if (!llvm::isPowerOf2_32(alignment
))
107 return emitError("expected alignment to be a power-of-two");
109 auto isUnaligned
= [&](const uint8_t *ptr
) {
110 return ((uintptr_t)ptr
& (alignment
- 1)) != 0;
113 // Shift the reader position to the next alignment boundary.
114 while (isUnaligned(dataIt
)) {
116 if (failed(parseByte(padding
)))
118 if (padding
!= bytecode::kAlignmentByte
) {
119 return emitError("expected alignment byte (0xCB), but got: '0x" +
120 llvm::utohexstr(padding
) + "'");
124 // Ensure the data iterator is now aligned. This case is unlikely because we
125 // *just* went through the effort to align the data iterator.
126 if (LLVM_UNLIKELY(isUnaligned(dataIt
))) {
127 return emitError("expected data iterator aligned to ", alignment
,
128 ", but got pointer: '0x" +
129 llvm::utohexstr((uintptr_t)dataIt
) + "'");
135 /// Emit an error using the given arguments.
136 template <typename
... Args
>
137 InFlightDiagnostic
emitError(Args
&&...args
) const {
138 return ::emitError(fileLoc
).append(std::forward
<Args
>(args
)...);
140 InFlightDiagnostic
emitError() const { return ::emitError(fileLoc
); }
142 /// Parse a single byte from the stream.
143 template <typename T
>
144 LogicalResult
parseByte(T
&value
) {
146 return emitError("attempting to parse a byte at the end of the bytecode");
147 value
= static_cast<T
>(*dataIt
++);
150 /// Parse a range of bytes of 'length' into the given result.
151 LogicalResult
parseBytes(size_t length
, ArrayRef
<uint8_t> &result
) {
152 if (length
> size()) {
153 return emitError("attempting to parse ", length
, " bytes when only ",
156 result
= {dataIt
, length
};
160 /// Parse a range of bytes of 'length' into the given result, which can be
161 /// assumed to be large enough to hold `length`.
162 LogicalResult
parseBytes(size_t length
, uint8_t *result
) {
163 if (length
> size()) {
164 return emitError("attempting to parse ", length
, " bytes when only ",
167 memcpy(result
, dataIt
, length
);
172 /// Parse an aligned blob of data, where the alignment was encoded alongside
174 LogicalResult
parseBlobAndAlignment(ArrayRef
<uint8_t> &data
,
175 uint64_t &alignment
) {
177 if (failed(parseVarInt(alignment
)) || failed(parseVarInt(dataSize
)) ||
178 failed(alignTo(alignment
)))
180 return parseBytes(dataSize
, data
);
183 /// Parse a variable length encoded integer from the byte stream. The first
184 /// encoded byte contains a prefix in the low bits indicating the encoded
185 /// length of the value. This length prefix is a bit sequence of '0's followed
186 /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
187 /// (not including the prefix byte). All remaining bits in the first byte,
188 /// along with all of the bits in additional bytes, provide the value of the
189 /// integer encoded in little-endian order.
190 LogicalResult
parseVarInt(uint64_t &result
) {
191 // Parse the first byte of the encoding, which contains the length prefix.
192 if (failed(parseByte(result
)))
195 // Handle the overwhelmingly common case where the value is stored in a
196 // single byte. In this case, the first bit is the `1` marker bit.
197 if (LLVM_LIKELY(result
& 1)) {
202 // Handle the overwhelming uncommon case where the value required all 8
203 // bytes (i.e. a really really big number). In this case, the marker byte is
204 // all zeros: `00000000`.
205 if (LLVM_UNLIKELY(result
== 0)) {
206 llvm::support::ulittle64_t resultLE
;
207 if (failed(parseBytes(sizeof(resultLE
),
208 reinterpret_cast<uint8_t *>(&resultLE
))))
213 return parseMultiByteVarInt(result
);
216 /// Parse a signed variable length encoded integer from the byte stream. A
217 /// signed varint is encoded as a normal varint with zigzag encoding applied,
218 /// i.e. the low bit of the value is used to indicate the sign.
219 LogicalResult
parseSignedVarInt(uint64_t &result
) {
220 if (failed(parseVarInt(result
)))
222 // Essentially (but using unsigned): (x >> 1) ^ -(x & 1)
223 result
= (result
>> 1) ^ (~(result
& 1) + 1);
227 /// Parse a variable length encoded integer whose low bit is used to encode an
228 /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
229 LogicalResult
parseVarIntWithFlag(uint64_t &result
, bool &flag
) {
230 if (failed(parseVarInt(result
)))
237 /// Skip the first `length` bytes within the reader.
238 LogicalResult
skipBytes(size_t length
) {
239 if (length
> size()) {
240 return emitError("attempting to skip ", length
, " bytes when only ",
247 /// Parse a null-terminated string into `result` (without including the NUL
249 LogicalResult
parseNullTerminatedString(StringRef
&result
) {
250 const char *startIt
= (const char *)dataIt
;
251 const char *nulIt
= (const char *)memchr(startIt
, 0, size());
254 "malformed null-terminated string, no null character found");
256 result
= StringRef(startIt
, nulIt
- startIt
);
257 dataIt
= (const uint8_t *)nulIt
+ 1;
261 /// Parse a section header, placing the kind of section in `sectionID` and the
262 /// contents of the section in `sectionData`.
263 LogicalResult
parseSection(bytecode::Section::ID
§ionID
,
264 ArrayRef
<uint8_t> §ionData
) {
265 uint8_t sectionIDAndHasAlignment
;
267 if (failed(parseByte(sectionIDAndHasAlignment
)) ||
268 failed(parseVarInt(length
)))
271 // Extract the section ID and whether the section is aligned. The high bit
272 // of the ID is the alignment flag.
273 sectionID
= static_cast<bytecode::Section::ID
>(sectionIDAndHasAlignment
&
275 bool hasAlignment
= sectionIDAndHasAlignment
& 0b10000000;
277 // Check that the section is actually valid before trying to process its
279 if (sectionID
>= bytecode::Section::kNumSections
)
280 return emitError("invalid section ID: ", unsigned(sectionID
));
282 // Process the section alignment if present.
285 if (failed(parseVarInt(alignment
)) || failed(alignTo(alignment
)))
289 // Parse the actual section data.
290 return parseBytes(static_cast<size_t>(length
), sectionData
);
293 Location
getLoc() const { return fileLoc
; }
296 /// Parse a variable length encoded integer from the byte stream. This method
297 /// is a fallback when the number of bytes used to encode the value is greater
298 /// than 1, but less than the max (9). The provided `result` value can be
299 /// assumed to already contain the first byte of the value.
300 /// NOTE: This method is marked noinline to avoid pessimizing the common case
301 /// of single byte encoding.
302 LLVM_ATTRIBUTE_NOINLINE LogicalResult
parseMultiByteVarInt(uint64_t &result
) {
303 // Count the number of trailing zeros in the marker byte, this indicates the
304 // number of trailing bytes that are part of the value. We use `uint32_t`
305 // here because we only care about the first byte, and so that be actually
306 // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
308 uint32_t numBytes
= llvm::countr_zero
<uint32_t>(result
);
309 assert(numBytes
> 0 && numBytes
<= 7 &&
310 "unexpected number of trailing zeros in varint encoding");
312 // Parse in the remaining bytes of the value.
313 llvm::support::ulittle64_t
resultLE(result
);
315 parseBytes(numBytes
, reinterpret_cast<uint8_t *>(&resultLE
) + 1)))
318 // Shift out the low-order bits that were used to mark how the value was
320 result
= resultLE
>> (numBytes
+ 1);
324 /// The bytecode buffer.
325 ArrayRef
<uint8_t> buffer
;
327 /// The current iterator within the 'buffer'.
328 const uint8_t *dataIt
;
330 /// A location for the bytecode used to report errors.
335 /// Resolve an index into the given entry list. `entry` may either be a
336 /// reference, in which case it is assigned to the corresponding value in
337 /// `entries`, or a pointer, in which case it is assigned to the address of the
338 /// element in `entries`.
339 template <typename RangeT
, typename T
>
340 static LogicalResult
resolveEntry(EncodingReader
&reader
, RangeT
&entries
,
341 uint64_t index
, T
&entry
,
342 StringRef entryStr
) {
343 if (index
>= entries
.size())
344 return reader
.emitError("invalid ", entryStr
, " index: ", index
);
346 // If the provided entry is a pointer, resolve to the address of the entry.
347 if constexpr (std::is_convertible_v
<llvm::detail::ValueOfRange
<RangeT
>, T
>)
348 entry
= entries
[index
];
350 entry
= &entries
[index
];
354 /// Parse and resolve an index into the given entry list.
355 template <typename RangeT
, typename T
>
356 static LogicalResult
parseEntry(EncodingReader
&reader
, RangeT
&entries
,
357 T
&entry
, StringRef entryStr
) {
359 if (failed(reader
.parseVarInt(entryIdx
)))
361 return resolveEntry(reader
, entries
, entryIdx
, entry
, entryStr
);
364 //===----------------------------------------------------------------------===//
365 // StringSectionReader
366 //===----------------------------------------------------------------------===//
369 /// This class is used to read references to the string section from the
371 class StringSectionReader
{
373 /// Initialize the string section reader with the given section data.
374 LogicalResult
initialize(Location fileLoc
, ArrayRef
<uint8_t> sectionData
);
376 /// Parse a shared string from the string section. The shared string is
377 /// encoded using an index to a corresponding string in the string section.
378 LogicalResult
parseString(EncodingReader
&reader
, StringRef
&result
) const {
379 return parseEntry(reader
, strings
, result
, "string");
382 /// Parse a shared string from the string section. The shared string is
383 /// encoded using an index to a corresponding string in the string section.
384 /// This variant parses a flag compressed with the index.
385 LogicalResult
parseStringWithFlag(EncodingReader
&reader
, StringRef
&result
,
388 if (failed(reader
.parseVarIntWithFlag(entryIdx
, flag
)))
390 return parseStringAtIndex(reader
, entryIdx
, result
);
393 /// Parse a shared string from the string section. The shared string is
394 /// encoded using an index to a corresponding string in the string section.
395 LogicalResult
parseStringAtIndex(EncodingReader
&reader
, uint64_t index
,
396 StringRef
&result
) const {
397 return resolveEntry(reader
, strings
, index
, result
, "string");
401 /// The table of strings referenced within the bytecode file.
402 SmallVector
<StringRef
> strings
;
406 LogicalResult
StringSectionReader::initialize(Location fileLoc
,
407 ArrayRef
<uint8_t> sectionData
) {
408 EncodingReader
stringReader(sectionData
, fileLoc
);
410 // Parse the number of strings in the section.
412 if (failed(stringReader
.parseVarInt(numStrings
)))
414 strings
.resize(numStrings
);
416 // Parse each of the strings. The sizes of the strings are encoded in reverse
417 // order, so that's the order we populate the table.
418 size_t stringDataEndOffset
= sectionData
.size();
419 for (StringRef
&string
: llvm::reverse(strings
)) {
421 if (failed(stringReader
.parseVarInt(stringSize
)))
423 if (stringDataEndOffset
< stringSize
) {
424 return stringReader
.emitError(
425 "string size exceeds the available data size");
428 // Extract the string from the data, dropping the null character.
429 size_t stringOffset
= stringDataEndOffset
- stringSize
;
431 reinterpret_cast<const char *>(sectionData
.data() + stringOffset
),
433 stringDataEndOffset
= stringOffset
;
436 // Check that the only remaining data was for the strings, i.e. the reader
437 // should be at the same offset as the first string.
438 if ((sectionData
.size() - stringReader
.size()) != stringDataEndOffset
) {
439 return stringReader
.emitError("unexpected trailing data between the "
440 "offsets for strings and their data");
//===----------------------------------------------------------------------===//
// BytecodeDialect
//===----------------------------------------------------------------------===//
452 /// This struct represents a dialect entry within the bytecode.
453 struct BytecodeDialect
{
454 /// Load the dialect into the provided context if it hasn't been loaded yet.
455 /// Returns failure if the dialect couldn't be loaded *and* the provided
456 /// context does not allow unregistered dialects. The provided reader is used
457 /// for error emission if necessary.
458 LogicalResult
load(const DialectReader
&reader
, MLIRContext
*ctx
);
460 /// Return the loaded dialect, or nullptr if the dialect is unknown. This can
461 /// only be called after `load`.
462 Dialect
*getLoadedDialect() const {
464 "expected `load` to be invoked before `getLoadedDialect`");
468 /// The loaded dialect entry. This field is std::nullopt if we haven't
469 /// attempted to load, nullptr if we failed to load, otherwise the loaded
471 std::optional
<Dialect
*> dialect
;
473 /// The bytecode interface of the dialect, or nullptr if the dialect does not
474 /// implement the bytecode interface. This field should only be checked if the
475 /// `dialect` field is not std::nullopt.
476 const BytecodeDialectInterface
*interface
= nullptr;
478 /// The name of the dialect.
481 /// A buffer containing the encoding of the dialect version parsed.
482 ArrayRef
<uint8_t> versionBuffer
;
484 /// Lazy loaded dialect version from the handle above.
485 std::unique_ptr
<DialectVersion
> loadedVersion
;
488 /// This struct represents an operation name entry within the bytecode.
489 struct BytecodeOperationName
{
490 BytecodeOperationName(BytecodeDialect
*dialect
, StringRef name
,
491 std::optional
<bool> wasRegistered
)
492 : dialect(dialect
), name(name
), wasRegistered(wasRegistered
) {}
494 /// The loaded operation name, or std::nullopt if it hasn't been processed
496 std::optional
<OperationName
> opName
;
498 /// The dialect that owns this operation name.
499 BytecodeDialect
*dialect
;
501 /// The name of the operation, without the dialect prefix.
504 /// Whether this operation was registered when the bytecode was produced.
505 /// This flag is populated when bytecode version >=kNativePropertiesEncoding.
506 std::optional
<bool> wasRegistered
;
510 /// Parse a single dialect group encoded in the byte stream.
511 static LogicalResult
parseDialectGrouping(
512 EncodingReader
&reader
,
513 MutableArrayRef
<std::unique_ptr
<BytecodeDialect
>> dialects
,
514 function_ref
<LogicalResult(BytecodeDialect
*)> entryCallback
) {
515 // Parse the dialect and the number of entries in the group.
516 std::unique_ptr
<BytecodeDialect
> *dialect
;
517 if (failed(parseEntry(reader
, dialects
, dialect
, "dialect")))
520 if (failed(reader
.parseVarInt(numEntries
)))
523 for (uint64_t i
= 0; i
< numEntries
; ++i
)
524 if (failed(entryCallback(dialect
->get())))
//===----------------------------------------------------------------------===//
// ResourceSectionReader
//===----------------------------------------------------------------------===//
534 /// This class is used to read the resource section from the bytecode.
535 class ResourceSectionReader
{
537 /// Initialize the resource section reader with the given section data.
539 initialize(Location fileLoc
, const ParserConfig
&config
,
540 MutableArrayRef
<std::unique_ptr
<BytecodeDialect
>> dialects
,
541 StringSectionReader
&stringReader
, ArrayRef
<uint8_t> sectionData
,
542 ArrayRef
<uint8_t> offsetSectionData
, DialectReader
&dialectReader
,
543 const std::shared_ptr
<llvm::SourceMgr
> &bufferOwnerRef
);
545 /// Parse a dialect resource handle from the resource section.
546 LogicalResult
parseResourceHandle(EncodingReader
&reader
,
547 AsmDialectResourceHandle
&result
) const {
548 return parseEntry(reader
, dialectResources
, result
, "resource handle");
552 /// The table of dialect resources within the bytecode file.
553 SmallVector
<AsmDialectResourceHandle
> dialectResources
;
554 llvm::StringMap
<std::string
> dialectResourceHandleRenamingMap
;
557 class ParsedResourceEntry
: public AsmParsedResourceEntry
{
559 ParsedResourceEntry(StringRef key
, AsmResourceEntryKind kind
,
560 EncodingReader
&reader
, StringSectionReader
&stringReader
,
561 const std::shared_ptr
<llvm::SourceMgr
> &bufferOwnerRef
)
562 : key(key
), kind(kind
), reader(reader
), stringReader(stringReader
),
563 bufferOwnerRef(bufferOwnerRef
) {}
564 ~ParsedResourceEntry() override
= default;
566 StringRef
getKey() const final
{ return key
; }
568 InFlightDiagnostic
emitError() const final
{ return reader
.emitError(); }
570 AsmResourceEntryKind
getKind() const final
{ return kind
; }
572 FailureOr
<bool> parseAsBool() const final
{
573 if (kind
!= AsmResourceEntryKind::Bool
)
574 return emitError() << "expected a bool resource entry, but found a "
575 << toString(kind
) << " entry instead";
578 if (failed(reader
.parseByte(value
)))
582 FailureOr
<std::string
> parseAsString() const final
{
583 if (kind
!= AsmResourceEntryKind::String
)
584 return emitError() << "expected a string resource entry, but found a "
585 << toString(kind
) << " entry instead";
588 if (failed(stringReader
.parseString(reader
, string
)))
593 FailureOr
<AsmResourceBlob
>
594 parseAsBlob(BlobAllocatorFn allocator
) const final
{
595 if (kind
!= AsmResourceEntryKind::Blob
)
596 return emitError() << "expected a blob resource entry, but found a "
597 << toString(kind
) << " entry instead";
599 ArrayRef
<uint8_t> data
;
601 if (failed(reader
.parseBlobAndAlignment(data
, alignment
)))
604 // If we have an extendable reference to the buffer owner, we don't need to
605 // allocate a new buffer for the data, and can use the data directly.
606 if (bufferOwnerRef
) {
607 ArrayRef
<char> charData(reinterpret_cast<const char *>(data
.data()),
610 // Allocate an unmanager buffer which captures a reference to the owner.
611 // For now we just mark this as immutable, but in the future we should
612 // explore marking this as mutable when desired.
613 return UnmanagedAsmResourceBlob::allocateWithAlign(
615 [bufferOwnerRef
= bufferOwnerRef
](void *, size_t, size_t) {});
618 // Allocate memory for the blob using the provided allocator and copy the
620 AsmResourceBlob blob
= allocator(data
.size(), alignment
);
621 assert(llvm::isAddrAligned(llvm::Align(alignment
), blob
.getData().data()) &&
623 "blob allocator did not return a properly aligned address");
624 memcpy(blob
.getMutableData().data(), data
.data(), data
.size());
630 AsmResourceEntryKind kind
;
631 EncodingReader
&reader
;
632 StringSectionReader
&stringReader
;
633 const std::shared_ptr
<llvm::SourceMgr
> &bufferOwnerRef
;
637 template <typename T
>
639 parseResourceGroup(Location fileLoc
, bool allowEmpty
,
640 EncodingReader
&offsetReader
, EncodingReader
&resourceReader
,
641 StringSectionReader
&stringReader
, T
*handler
,
642 const std::shared_ptr
<llvm::SourceMgr
> &bufferOwnerRef
,
643 function_ref
<StringRef(StringRef
)> remapKey
= {},
644 function_ref
<LogicalResult(StringRef
)> processKeyFn
= {}) {
645 uint64_t numResources
;
646 if (failed(offsetReader
.parseVarInt(numResources
)))
649 for (uint64_t i
= 0; i
< numResources
; ++i
) {
651 AsmResourceEntryKind kind
;
652 uint64_t resourceOffset
;
653 ArrayRef
<uint8_t> data
;
654 if (failed(stringReader
.parseString(offsetReader
, key
)) ||
655 failed(offsetReader
.parseVarInt(resourceOffset
)) ||
656 failed(offsetReader
.parseByte(kind
)) ||
657 failed(resourceReader
.parseBytes(resourceOffset
, data
)))
660 // Process the resource key.
661 if ((processKeyFn
&& failed(processKeyFn(key
))))
664 // If the resource data is empty and we allow it, don't error out when
665 // parsing below, just skip it.
666 if (allowEmpty
&& data
.empty())
669 // Ignore the entry if we don't have a valid handler.
673 // Otherwise, parse the resource value.
674 EncodingReader
entryReader(data
, fileLoc
);
676 ParsedResourceEntry
entry(key
, kind
, entryReader
, stringReader
,
678 if (failed(handler
->parseResource(entry
)))
680 if (!entryReader
.empty()) {
681 return entryReader
.emitError(
682 "unexpected trailing bytes in resource entry '", key
, "'");
688 LogicalResult
ResourceSectionReader::initialize(
689 Location fileLoc
, const ParserConfig
&config
,
690 MutableArrayRef
<std::unique_ptr
<BytecodeDialect
>> dialects
,
691 StringSectionReader
&stringReader
, ArrayRef
<uint8_t> sectionData
,
692 ArrayRef
<uint8_t> offsetSectionData
, DialectReader
&dialectReader
,
693 const std::shared_ptr
<llvm::SourceMgr
> &bufferOwnerRef
) {
694 EncodingReader
resourceReader(sectionData
, fileLoc
);
695 EncodingReader
offsetReader(offsetSectionData
, fileLoc
);
697 // Read the number of external resource providers.
698 uint64_t numExternalResourceGroups
;
699 if (failed(offsetReader
.parseVarInt(numExternalResourceGroups
)))
702 // Utility functor that dispatches to `parseResourceGroup`, but implicitly
703 // provides most of the arguments.
704 auto parseGroup
= [&](auto *handler
, bool allowEmpty
= false,
705 function_ref
<LogicalResult(StringRef
)> keyFn
= {}) {
706 auto resolveKey
= [&](StringRef key
) -> StringRef
{
707 auto it
= dialectResourceHandleRenamingMap
.find(key
);
708 if (it
== dialectResourceHandleRenamingMap
.end())
713 return parseResourceGroup(fileLoc
, allowEmpty
, offsetReader
, resourceReader
,
714 stringReader
, handler
, bufferOwnerRef
, resolveKey
,
718 // Read the external resources from the bytecode.
719 for (uint64_t i
= 0; i
< numExternalResourceGroups
; ++i
) {
721 if (failed(stringReader
.parseString(offsetReader
, key
)))
724 // Get the handler for these resources.
725 // TODO: Should we require handling external resources in some scenarios?
726 AsmResourceParser
*handler
= config
.getResourceParser(key
);
728 emitWarning(fileLoc
) << "ignoring unknown external resources for '" << key
732 if (failed(parseGroup(handler
)))
736 // Read the dialect resources from the bytecode.
737 MLIRContext
*ctx
= fileLoc
->getContext();
738 while (!offsetReader
.empty()) {
739 std::unique_ptr
<BytecodeDialect
> *dialect
;
740 if (failed(parseEntry(offsetReader
, dialects
, dialect
, "dialect")) ||
741 failed((*dialect
)->load(dialectReader
, ctx
)))
743 Dialect
*loadedDialect
= (*dialect
)->getLoadedDialect();
744 if (!loadedDialect
) {
745 return resourceReader
.emitError()
746 << "dialect '" << (*dialect
)->name
<< "' is unknown";
748 const auto *handler
= dyn_cast
<OpAsmDialectInterface
>(loadedDialect
);
750 return resourceReader
.emitError()
751 << "unexpected resources for dialect '" << (*dialect
)->name
<< "'";
754 // Ensure that each resource is declared before being processed.
755 auto processResourceKeyFn
= [&](StringRef key
) -> LogicalResult
{
756 FailureOr
<AsmDialectResourceHandle
> handle
=
757 handler
->declareResource(key
);
758 if (failed(handle
)) {
759 return resourceReader
.emitError()
760 << "unknown 'resource' key '" << key
<< "' for dialect '"
761 << (*dialect
)->name
<< "'";
763 dialectResourceHandleRenamingMap
[key
] = handler
->getResourceKey(*handle
);
764 dialectResources
.push_back(*handle
);
768 // Parse the resources for this dialect. We allow empty resources because we
769 // just treat these as declarations.
770 if (failed(parseGroup(handler
, /*allowEmpty=*/true, processResourceKeyFn
)))
//===----------------------------------------------------------------------===//
// Attribute/Type Reader
//===----------------------------------------------------------------------===//
782 /// This class provides support for reading attribute and type entries from the
783 /// bytecode. Attribute and Type entries are read lazily on demand, so we use
784 /// this reader to manage when to actually parse them from the bytecode.
785 class AttrTypeReader
{
786 /// This class represents a single attribute or type entry.
787 template <typename T
>
789 /// The entry, or null if it hasn't been resolved yet.
791 /// The parent dialect of this entry.
792 BytecodeDialect
*dialect
= nullptr;
793 /// A flag indicating if the entry was encoded using a custom encoding,
794 /// instead of using the textual assembly format.
795 bool hasCustomEncoding
= false;
796 /// The raw data of this entry in the bytecode.
797 ArrayRef
<uint8_t> data
;
799 using AttrEntry
= Entry
<Attribute
>;
800 using TypeEntry
= Entry
<Type
>;
803 AttrTypeReader(const StringSectionReader
&stringReader
,
804 const ResourceSectionReader
&resourceReader
,
805 const llvm::StringMap
<BytecodeDialect
*> &dialectsMap
,
806 uint64_t &bytecodeVersion
, Location fileLoc
,
807 const ParserConfig
&config
)
808 : stringReader(stringReader
), resourceReader(resourceReader
),
809 dialectsMap(dialectsMap
), fileLoc(fileLoc
),
810 bytecodeVersion(bytecodeVersion
), parserConfig(config
) {}
812 /// Initialize the attribute and type information within the reader.
814 initialize(MutableArrayRef
<std::unique_ptr
<BytecodeDialect
>> dialects
,
815 ArrayRef
<uint8_t> sectionData
,
816 ArrayRef
<uint8_t> offsetSectionData
);
818 /// Resolve the attribute or type at the given index. Returns nullptr on
820 Attribute
resolveAttribute(size_t index
) {
821 return resolveEntry(attributes
, index
, "Attribute");
823 Type
resolveType(size_t index
) { return resolveEntry(types
, index
, "Type"); }
825 /// Parse a reference to an attribute or type using the given reader.
826 LogicalResult
parseAttribute(EncodingReader
&reader
, Attribute
&result
) {
828 if (failed(reader
.parseVarInt(attrIdx
)))
830 result
= resolveAttribute(attrIdx
);
831 return success(!!result
);
833 LogicalResult
parseOptionalAttribute(EncodingReader
&reader
,
837 if (failed(reader
.parseVarIntWithFlag(attrIdx
, flag
)))
841 result
= resolveAttribute(attrIdx
);
842 return success(!!result
);
845 LogicalResult
parseType(EncodingReader
&reader
, Type
&result
) {
847 if (failed(reader
.parseVarInt(typeIdx
)))
849 result
= resolveType(typeIdx
);
850 return success(!!result
);
853 template <typename T
>
854 LogicalResult
parseAttribute(EncodingReader
&reader
, T
&result
) {
855 Attribute baseResult
;
856 if (failed(parseAttribute(reader
, baseResult
)))
858 if ((result
= dyn_cast
<T
>(baseResult
)))
860 return reader
.emitError("expected attribute of type: ",
861 llvm::getTypeName
<T
>(), ", but got: ", baseResult
);
865 /// Resolve the given entry at `index`.
866 template <typename T
>
867 T
resolveEntry(SmallVectorImpl
<Entry
<T
>> &entries
, size_t index
,
868 StringRef entryType
);
870 /// Parse an entry using the given reader that was encoded using the textual
872 template <typename T
>
873 LogicalResult
parseAsmEntry(T
&result
, EncodingReader
&reader
,
874 StringRef entryType
);
876 /// Parse an entry using the given reader that was encoded using a custom
878 template <typename T
>
879 LogicalResult
parseCustomEntry(Entry
<T
> &entry
, EncodingReader
&reader
,
880 StringRef entryType
);
882 /// The string section reader used to resolve string references when parsing
883 /// custom encoded attribute/type entries.
884 const StringSectionReader
&stringReader
;
886 /// The resource section reader used to resolve resource references when
887 /// parsing custom encoded attribute/type entries.
888 const ResourceSectionReader
&resourceReader
;
890 /// The map of the loaded dialects used to retrieve dialect information, such
891 /// as the dialect version.
892 const llvm::StringMap
<BytecodeDialect
*> &dialectsMap
;
894 /// The set of attribute and type entries.
895 SmallVector
<AttrEntry
> attributes
;
896 SmallVector
<TypeEntry
> types
;
898 /// A location used for error emission.
901 /// Current bytecode version being used.
902 uint64_t &bytecodeVersion
;
904 /// Reference to the parser configuration.
905 const ParserConfig
&parserConfig
;
908 class DialectReader
: public DialectBytecodeReader
{
910 DialectReader(AttrTypeReader
&attrTypeReader
,
911 const StringSectionReader
&stringReader
,
912 const ResourceSectionReader
&resourceReader
,
913 const llvm::StringMap
<BytecodeDialect
*> &dialectsMap
,
914 EncodingReader
&reader
, uint64_t &bytecodeVersion
)
915 : attrTypeReader(attrTypeReader
), stringReader(stringReader
),
916 resourceReader(resourceReader
), dialectsMap(dialectsMap
),
917 reader(reader
), bytecodeVersion(bytecodeVersion
) {}
919 InFlightDiagnostic
emitError(const Twine
&msg
) const override
{
920 return reader
.emitError(msg
);
923 FailureOr
<const DialectVersion
*>
924 getDialectVersion(StringRef dialectName
) const override
{
925 // First check if the dialect is available in the map.
926 auto dialectEntry
= dialectsMap
.find(dialectName
);
927 if (dialectEntry
== dialectsMap
.end())
929 // If the dialect was found, try to load it. This will trigger reading the
930 // bytecode version from the version buffer if it wasn't already processed.
931 // Return failure if either of those two actions could not be completed.
932 if (failed(dialectEntry
->getValue()->load(*this, getLoc().getContext())) ||
933 dialectEntry
->getValue()->loadedVersion
== nullptr)
935 return dialectEntry
->getValue()->loadedVersion
.get();
938 MLIRContext
*getContext() const override
{ return getLoc().getContext(); }
940 uint64_t getBytecodeVersion() const override
{ return bytecodeVersion
; }
942 DialectReader
withEncodingReader(EncodingReader
&encReader
) const {
943 return DialectReader(attrTypeReader
, stringReader
, resourceReader
,
944 dialectsMap
, encReader
, bytecodeVersion
);
947 Location
getLoc() const { return reader
.getLoc(); }
949 //===--------------------------------------------------------------------===//
951 //===--------------------------------------------------------------------===//
953 LogicalResult
readAttribute(Attribute
&result
) override
{
954 return attrTypeReader
.parseAttribute(reader
, result
);
956 LogicalResult
readOptionalAttribute(Attribute
&result
) override
{
957 return attrTypeReader
.parseOptionalAttribute(reader
, result
);
959 LogicalResult
readType(Type
&result
) override
{
960 return attrTypeReader
.parseType(reader
, result
);
963 FailureOr
<AsmDialectResourceHandle
> readResourceHandle() override
{
964 AsmDialectResourceHandle handle
;
965 if (failed(resourceReader
.parseResourceHandle(reader
, handle
)))
970 //===--------------------------------------------------------------------===//
972 //===--------------------------------------------------------------------===//
974 LogicalResult
readVarInt(uint64_t &result
) override
{
975 return reader
.parseVarInt(result
);
978 LogicalResult
readSignedVarInt(int64_t &result
) override
{
979 uint64_t unsignedResult
;
980 if (failed(reader
.parseSignedVarInt(unsignedResult
)))
982 result
= static_cast<int64_t>(unsignedResult
);
986 FailureOr
<APInt
> readAPIntWithKnownWidth(unsigned bitWidth
) override
{
987 // Small values are encoded using a single byte.
990 if (failed(reader
.parseByte(value
)))
992 return APInt(bitWidth
, value
);
995 // Large values up to 64 bits are encoded using a single varint.
996 if (bitWidth
<= 64) {
998 if (failed(reader
.parseSignedVarInt(value
)))
1000 return APInt(bitWidth
, value
);
1003 // Otherwise, for really big values we encode the array of active words in
1005 uint64_t numActiveWords
;
1006 if (failed(reader
.parseVarInt(numActiveWords
)))
1008 SmallVector
<uint64_t, 4> words(numActiveWords
);
1009 for (uint64_t i
= 0; i
< numActiveWords
; ++i
)
1010 if (failed(reader
.parseSignedVarInt(words
[i
])))
1012 return APInt(bitWidth
, words
);
1016 readAPFloatWithKnownSemantics(const llvm::fltSemantics
&semantics
) override
{
1017 FailureOr
<APInt
> intVal
=
1018 readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics
));
1021 return APFloat(semantics
, *intVal
);
1024 LogicalResult
readString(StringRef
&result
) override
{
1025 return stringReader
.parseString(reader
, result
);
1028 LogicalResult
readBlob(ArrayRef
<char> &result
) override
{
1030 ArrayRef
<uint8_t> data
;
1031 if (failed(reader
.parseVarInt(dataSize
)) ||
1032 failed(reader
.parseBytes(dataSize
, data
)))
1034 result
= llvm::ArrayRef(reinterpret_cast<const char *>(data
.data()),
1039 LogicalResult
readBool(bool &result
) override
{
1040 return reader
.parseByte(result
);
1044 AttrTypeReader
&attrTypeReader
;
1045 const StringSectionReader
&stringReader
;
1046 const ResourceSectionReader
&resourceReader
;
1047 const llvm::StringMap
<BytecodeDialect
*> &dialectsMap
;
1048 EncodingReader
&reader
;
1049 uint64_t &bytecodeVersion
;
1052 /// Wraps the properties section and handles reading properties out of it.
1053 class PropertiesSectionReader
{
1055 /// Initialize the properties section reader with the given section data.
1056 LogicalResult
initialize(Location fileLoc
, ArrayRef
<uint8_t> sectionData
) {
1057 if (sectionData
.empty())
1059 EncodingReader
propReader(sectionData
, fileLoc
);
1061 if (failed(propReader
.parseVarInt(count
)))
1063 // Parse the raw properties buffer.
1064 if (failed(propReader
.parseBytes(propReader
.size(), propertiesBuffers
)))
1067 EncodingReader
offsetsReader(propertiesBuffers
, fileLoc
);
1068 offsetTable
.reserve(count
);
1069 for (auto idx
: llvm::seq
<int64_t>(0, count
)) {
1071 offsetTable
.push_back(propertiesBuffers
.size() - offsetsReader
.size());
1072 ArrayRef
<uint8_t> rawProperties
;
1074 if (failed(offsetsReader
.parseVarInt(dataSize
)) ||
1075 failed(offsetsReader
.parseBytes(dataSize
, rawProperties
)))
1078 if (!offsetsReader
.empty())
1079 return offsetsReader
.emitError()
1080 << "Broken properties section: didn't exhaust the offsets table";
1084 LogicalResult
read(Location fileLoc
, DialectReader
&dialectReader
,
1085 OperationName
*opName
, OperationState
&opState
) const {
1086 uint64_t propertiesIdx
;
1087 if (failed(dialectReader
.readVarInt(propertiesIdx
)))
1089 if (propertiesIdx
>= offsetTable
.size())
1090 return dialectReader
.emitError("Properties idx out-of-bound for ")
1091 << opName
->getStringRef();
1092 size_t propertiesOffset
= offsetTable
[propertiesIdx
];
1093 if (propertiesIdx
>= propertiesBuffers
.size())
1094 return dialectReader
.emitError("Properties offset out-of-bound for ")
1095 << opName
->getStringRef();
1097 // Acquire the sub-buffer that represent the requested properties.
1098 ArrayRef
<char> rawProperties
;
1100 // "Seek" to the requested offset by getting a new reader with the right
1102 EncodingReader
reader(propertiesBuffers
.drop_front(propertiesOffset
),
1104 // Properties are stored as a sequence of {size + raw_data}.
1106 dialectReader
.withEncodingReader(reader
).readBlob(rawProperties
)))
1109 // Setup a new reader to read from the `rawProperties` sub-buffer.
1110 EncodingReader
reader(
1111 StringRef(rawProperties
.begin(), rawProperties
.size()), fileLoc
);
1112 DialectReader propReader
= dialectReader
.withEncodingReader(reader
);
1114 auto *iface
= opName
->getInterface
<BytecodeOpInterface
>();
1116 return iface
->readProperties(propReader
, opState
);
1117 if (opName
->isRegistered())
1118 return propReader
.emitError(
1119 "has properties but missing BytecodeOpInterface for ")
1120 << opName
->getStringRef();
1121 // Unregistered op are storing properties as an attribute.
1122 return propReader
.readAttribute(opState
.propertiesAttr
);
1126 /// The properties buffer referenced within the bytecode file.
1127 ArrayRef
<uint8_t> propertiesBuffers
;
1129 /// Table of offset in the buffer above.
1130 SmallVector
<int64_t> offsetTable
;
1134 LogicalResult
AttrTypeReader::initialize(
1135 MutableArrayRef
<std::unique_ptr
<BytecodeDialect
>> dialects
,
1136 ArrayRef
<uint8_t> sectionData
, ArrayRef
<uint8_t> offsetSectionData
) {
1137 EncodingReader
offsetReader(offsetSectionData
, fileLoc
);
1139 // Parse the number of attribute and type entries.
1140 uint64_t numAttributes
, numTypes
;
1141 if (failed(offsetReader
.parseVarInt(numAttributes
)) ||
1142 failed(offsetReader
.parseVarInt(numTypes
)))
1144 attributes
.resize(numAttributes
);
1145 types
.resize(numTypes
);
1147 // A functor used to accumulate the offsets for the entries in the given
1149 uint64_t currentOffset
= 0;
1150 auto parseEntries
= [&](auto &&range
) {
1151 size_t currentIndex
= 0, endIndex
= range
.size();
1153 // Parse an individual entry.
1154 auto parseEntryFn
= [&](BytecodeDialect
*dialect
) -> LogicalResult
{
1155 auto &entry
= range
[currentIndex
++];
1158 if (failed(offsetReader
.parseVarIntWithFlag(entrySize
,
1159 entry
.hasCustomEncoding
)))
1162 // Verify that the offset is actually valid.
1163 if (currentOffset
+ entrySize
> sectionData
.size()) {
1164 return offsetReader
.emitError(
1165 "Attribute or Type entry offset points past the end of section");
1168 entry
.data
= sectionData
.slice(currentOffset
, entrySize
);
1169 entry
.dialect
= dialect
;
1170 currentOffset
+= entrySize
;
1173 while (currentIndex
!= endIndex
)
1174 if (failed(parseDialectGrouping(offsetReader
, dialects
, parseEntryFn
)))
1179 // Process each of the attributes, and then the types.
1180 if (failed(parseEntries(attributes
)) || failed(parseEntries(types
)))
1183 // Ensure that we read everything from the section.
1184 if (!offsetReader
.empty()) {
1185 return offsetReader
.emitError(
1186 "unexpected trailing data in the Attribute/Type offset section");
1192 template <typename T
>
1193 T
AttrTypeReader::resolveEntry(SmallVectorImpl
<Entry
<T
>> &entries
, size_t index
,
1194 StringRef entryType
) {
1195 if (index
>= entries
.size()) {
1196 emitError(fileLoc
) << "invalid " << entryType
<< " index: " << index
;
1200 // If the entry has already been resolved, there is nothing left to do.
1201 Entry
<T
> &entry
= entries
[index
];
1206 EncodingReader
reader(entry
.data
, fileLoc
);
1208 // Parse based on how the entry was encoded.
1209 if (entry
.hasCustomEncoding
) {
1210 if (failed(parseCustomEntry(entry
, reader
, entryType
)))
1212 } else if (failed(parseAsmEntry(entry
.entry
, reader
, entryType
))) {
1216 if (!reader
.empty()) {
1217 reader
.emitError("unexpected trailing bytes after " + entryType
+ " entry");
1223 template <typename T
>
1224 LogicalResult
AttrTypeReader::parseAsmEntry(T
&result
, EncodingReader
&reader
,
1225 StringRef entryType
) {
1227 if (failed(reader
.parseNullTerminatedString(asmStr
)))
1230 // Invoke the MLIR assembly parser to parse the entry text.
1232 MLIRContext
*context
= fileLoc
->getContext();
1233 if constexpr (std::is_same_v
<T
, Type
>)
1235 ::parseType(asmStr
, context
, &numRead
, /*isKnownNullTerminated=*/true);
1237 result
= ::parseAttribute(asmStr
, context
, Type(), &numRead
,
1238 /*isKnownNullTerminated=*/true);
1242 // Ensure there weren't dangling characters after the entry.
1243 if (numRead
!= asmStr
.size()) {
1244 return reader
.emitError("trailing characters found after ", entryType
,
1245 " assembly format: ", asmStr
.drop_front(numRead
));
1250 template <typename T
>
1251 LogicalResult
AttrTypeReader::parseCustomEntry(Entry
<T
> &entry
,
1252 EncodingReader
&reader
,
1253 StringRef entryType
) {
1254 DialectReader
dialectReader(*this, stringReader
, resourceReader
, dialectsMap
,
1255 reader
, bytecodeVersion
);
1256 if (failed(entry
.dialect
->load(dialectReader
, fileLoc
.getContext())))
1259 if constexpr (std::is_same_v
<T
, Type
>) {
1260 // Try parsing with callbacks first if available.
1261 for (const auto &callback
:
1262 parserConfig
.getBytecodeReaderConfig().getTypeCallbacks()) {
1264 callback
->read(dialectReader
, entry
.dialect
->name
, entry
.entry
)))
1266 // Early return if parsing was successful.
1270 // Reset the reader if we failed to parse, so we can fall through the
1271 // other parsing functions.
1272 reader
= EncodingReader(entry
.data
, reader
.getLoc());
1275 // Try parsing with callbacks first if available.
1276 for (const auto &callback
:
1277 parserConfig
.getBytecodeReaderConfig().getAttributeCallbacks()) {
1279 callback
->read(dialectReader
, entry
.dialect
->name
, entry
.entry
)))
1281 // Early return if parsing was successful.
1285 // Reset the reader if we failed to parse, so we can fall through the
1286 // other parsing functions.
1287 reader
= EncodingReader(entry
.data
, reader
.getLoc());
1291 // Ensure that the dialect implements the bytecode interface.
1292 if (!entry
.dialect
->interface
) {
1293 return reader
.emitError("dialect '", entry
.dialect
->name
,
1294 "' does not implement the bytecode interface");
1297 if constexpr (std::is_same_v
<T
, Type
>)
1298 entry
.entry
= entry
.dialect
->interface
->readType(dialectReader
);
1300 entry
.entry
= entry
.dialect
->interface
->readAttribute(dialectReader
);
1302 return success(!!entry
.entry
);
1305 //===----------------------------------------------------------------------===//
1307 //===----------------------------------------------------------------------===//
/// This class is used to read a bytecode buffer and translate it into MLIR.
///
/// It owns the per-section readers (strings, resources, attributes/types,
/// properties) plus the state needed to parse the IR section, including the
/// bookkeeping for operations whose regions are left unparsed until they are
/// materialized.
class mlir::BytecodeReader::Impl {
  struct RegionReadState;
  /// Ordered list of lazily-loadable ops paired with the state needed to
  /// resume parsing their regions.
  using LazyLoadableOpsInfo =
      std::list<std::pair<Operation *, RegionReadState>>;
  /// Lookup from an op to its position in the list above.
  using LazyLoadableOpsMap =
      DenseMap<Operation *, LazyLoadableOpsInfo::iterator>;

public:
  /// Note that `attrTypeReader` is handed references to members declared
  /// further down (`stringReader`, `resourceReader`, `dialectsMap`,
  /// `version`); only the references are bound here.
  Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading,
       llvm::MemoryBufferRef buffer,
       const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
      : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
        attrTypeReader(stringReader, resourceReader, dialectsMap, version,
                       fileLoc, config),
        // Use the builtin unrealized conversion cast operation to represent
        // forward references to values that aren't yet defined.
        forwardRefOpState(UnknownLoc::get(config.getContext()),
                          "builtin.unrealized_conversion_cast", ValueRange(),
                          NoneType::get(config.getContext())),
        buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}

  /// Read the bytecode defined within `buffer` into the given block.
  LogicalResult read(Block *block,
                     llvm::function_ref<bool(Operation *)> lazyOps);

  /// Return the number of ops that haven't been materialized yet.
  int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); }

  /// Return true if `op` is still registered as lazily loadable.
  bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); }

  /// Materialize the provided operation, invoke the lazyOpsCallback on every
  /// newly found lazy operation.
  LogicalResult
  materialize(Operation *op,
              llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
    this->lazyOpsCallback = lazyOpsCallback;
    // The callback is only valid for the duration of this call; clear it on
    // every exit path.
    auto resetlazyOpsCallback =
        llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
    auto it = lazyLoadableOpsMap.find(op);
    assert(it != lazyLoadableOpsMap.end() &&
           "materialize called on non-materializable op");
    return materialize(it);
  }

  /// Materialize all operations.
  LogicalResult materializeAll() {
    // Each successful materialize removes its entry from the map.
    while (!lazyLoadableOpsMap.empty()) {
      if (failed(materialize(lazyLoadableOpsMap.begin())))
        return failure();
    }
    return success();
  }

  /// Finalize the lazy-loading by calling back with every op that hasn't been
  /// materialized to let the client decide if the op should be deleted or
  /// materialized. The op is materialized if the callback returns true, deleted
  /// otherwise.
  LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) {
    while (!lazyLoadableOps.empty()) {
      Operation *op = lazyLoadableOps.begin()->first;
      if (shouldMaterialize(op)) {
        // materialize() removes the op from both containers itself.
        if (failed(materialize(lazyLoadableOpsMap.find(op))))
          return failure();
        continue;
      }
      op->dropAllReferences();
      op->erase();
      lazyLoadableOps.pop_front();
      lazyLoadableOpsMap.erase(op);
    }
    return success();
  }

private:
  /// Materialize the op referenced by `it`: take over its saved region state,
  /// drop it from the lazy-loading containers, and parse its regions until the
  /// region stack is drained.
  LogicalResult materialize(LazyLoadableOpsMap::iterator it) {
    assert(it != lazyLoadableOpsMap.end() &&
           "materialize called on non-materializable op");
    valueScopes.emplace_back();
    std::vector<RegionReadState> regionStack;
    regionStack.push_back(std::move(it->getSecond()->second));
    lazyLoadableOps.erase(it->getSecond());
    lazyLoadableOpsMap.erase(it);

    while (!regionStack.empty())
      if (failed(parseRegions(regionStack, regionStack.back())))
        return failure();
    return success();
  }

  /// Return the context for this config.
  MLIRContext *getContext() const { return config.getContext(); }

  /// Parse the bytecode version.
  LogicalResult parseVersion(EncodingReader &reader);

  //===--------------------------------------------------------------------===//
  // Dialect Section

  LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);

  /// Parse an operation name reference using the given reader, and set the
  /// `wasRegistered` flag that indicates if the bytecode was produced by a
  /// context where opName was registered.
  FailureOr<OperationName> parseOpName(EncodingReader &reader,
                                       std::optional<bool> &wasRegistered);

  //===--------------------------------------------------------------------===//
  // Attribute/Type Section

  /// Parse an attribute or type using the given reader.
  template <typename T>
  LogicalResult parseAttribute(EncodingReader &reader, T &result) {
    return attrTypeReader.parseAttribute(reader, result);
  }
  LogicalResult parseType(EncodingReader &reader, Type &result) {
    return attrTypeReader.parseType(reader, result);
  }

  //===--------------------------------------------------------------------===//
  // Resource Section

  LogicalResult
  parseResourceSection(EncodingReader &reader,
                       std::optional<ArrayRef<uint8_t>> resourceData,
                       std::optional<ArrayRef<uint8_t>> resourceOffsetData);

  //===--------------------------------------------------------------------===//
  // IR Section

  /// This struct represents the current read state of a range of regions. This
  /// struct is used to enable iterative parsing of regions.
  struct RegionReadState {
    RegionReadState(Operation *op, EncodingReader *reader,
                    bool isIsolatedFromAbove)
        : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
    RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
                    bool isIsolatedFromAbove)
        : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
          isIsolatedFromAbove(isIsolatedFromAbove) {}

    /// The current regions being read.
    MutableArrayRef<Region>::iterator curRegion, endRegion;
    /// This is the reader to use for this region, this pointer is pointing to
    /// the parent region reader unless the current region is IsolatedFromAbove,
    /// in which case the pointer is pointing to the `owningReader` which is a
    /// section dedicated to the current region.
    EncodingReader *reader;
    std::unique_ptr<EncodingReader> owningReader;

    /// The number of values defined immediately within this region.
    unsigned numValues = 0;

    /// The current blocks of the region being read.
    SmallVector<Block *> curBlocks;
    Region::iterator curBlock = {};

    /// The number of operations remaining to be read from the current block
    /// being read.
    uint64_t numOpsRemaining = 0;

    /// A flag indicating if the regions being read are isolated from above.
    bool isIsolatedFromAbove = false;
  };

  LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
  LogicalResult parseRegions(std::vector<RegionReadState> &regionStack,
                             RegionReadState &readState);
  FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
                                               RegionReadState &readState,
                                               bool &isIsolatedFromAbove);

  LogicalResult parseRegion(RegionReadState &readState);
  LogicalResult parseBlockHeader(EncodingReader &reader,
                                 RegionReadState &readState);
  LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);

  //===--------------------------------------------------------------------===//
  // Value Processing

  /// Parse an operand reference using the given reader. Returns nullptr in the
  /// case of failure.
  Value parseOperand(EncodingReader &reader);

  /// Sequentially define the given value range.
  LogicalResult defineValues(EncodingReader &reader, ValueRange values);

  /// Create a value to use for a forward reference.
  Value createForwardRef();

  //===--------------------------------------------------------------------===//
  // Use-list order helpers

  /// This struct is a simple storage that contains information required to
  /// reorder the use-list of a value with respect to the pre-order traversal
  /// order of operations.
  struct UseListOrderStorage {
    UseListOrderStorage(bool isIndexPairEncoding,
                        SmallVector<unsigned, 4> &&indices)
        : indices(std::move(indices)),
          isIndexPairEncoding(isIndexPairEncoding){};
    /// The vector containing the information required to reorder the
    /// use-list of a value.
    SmallVector<unsigned, 4> indices;

    /// Whether indices represent a pair of type `(src, dst)` or it is a direct
    /// indexing, such as `dst = order[src]`.
    bool isIndexPairEncoding;
  };

  /// Parse use-list order from bytecode for a range of values if available. The
  /// range is expected to be either a block argument or an op result range. On
  /// success, return a map of the position in the range and the use-list order
  /// encoding. The function assumes to know the size of the range it is
  /// decoding.
  using UseListMapT = DenseMap<unsigned, UseListOrderStorage>;
  FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
                                                   uint64_t rangeSize);

  /// Shuffle the use-chain according to the order parsed.
  LogicalResult sortUseListOrder(Value value);

  /// Recursively visit all the values defined within topLevelOp and sort the
  /// use-list orders according to the indices parsed.
  LogicalResult processUseLists(Operation *topLevelOp);

  //===--------------------------------------------------------------------===//
  // Fields

  /// This class represents a single value scope, in which a value scope is
  /// delimited by isolated from above regions.
  struct ValueScope {
    /// Push a new region state onto this scope, reserving enough values for
    /// those defined within the current region of the provided state.
    void push(RegionReadState &readState) {
      nextValueIDs.push_back(values.size());
      values.resize(values.size() + readState.numValues);
    }

    /// Pop the values defined for the current region within the provided region
    /// state.
    void pop(RegionReadState &readState) {
      values.resize(values.size() - readState.numValues);
      nextValueIDs.pop_back();
    }

    /// The set of values defined in this scope.
    std::vector<Value> values;

    /// The ID for the next defined value for each region current being
    /// processed in this scope.
    SmallVector<unsigned, 4> nextValueIDs;
  };

  /// The configuration of the parser.
  const ParserConfig &config;

  /// A location to use when emitting errors.
  Location fileLoc;

  /// Flag that indicates if lazyloading is enabled.
  bool lazyLoading;

  /// Keep track of operations that have been lazy loaded (their regions haven't
  /// been materialized), along with the `RegionReadState` that allows to
  /// lazy-load the regions nested under the operation.
  LazyLoadableOpsInfo lazyLoadableOps;
  LazyLoadableOpsMap lazyLoadableOpsMap;
  llvm::function_ref<bool(Operation *)> lazyOpsCallback;

  /// The reader used to process attribute and types within the bytecode.
  AttrTypeReader attrTypeReader;

  /// The version of the bytecode being read.
  uint64_t version = 0;

  /// The producer of the bytecode being read.
  StringRef producer;

  /// The table of IR units referenced within the bytecode file.
  SmallVector<std::unique_ptr<BytecodeDialect>> dialects;
  llvm::StringMap<BytecodeDialect *> dialectsMap;
  SmallVector<BytecodeOperationName> opNames;

  /// The reader used to process resources within the bytecode.
  ResourceSectionReader resourceReader;

  /// Worklist of values with custom use-list orders to process before the end
  /// of the parsing.
  DenseMap<void *, UseListOrderStorage> valueToUseListMap;

  /// The table of strings referenced within the bytecode file.
  StringSectionReader stringReader;

  /// The table of properties referenced by the operation in the bytecode file.
  PropertiesSectionReader propertiesReader;

  /// The current set of available IR value scopes.
  std::vector<ValueScope> valueScopes;

  /// The global pre-order operation ordering.
  DenseMap<Operation *, unsigned> operationIDs;

  /// A block containing the set of operations defined to create forward
  /// references.
  Block forwardRefOps;

  /// A block containing previously created, and no longer used, forward
  /// reference operations.
  Block openForwardRefOps;

  /// An operation state used when instantiating forward references.
  OperationState forwardRefOpState;

  /// Reference to the input buffer.
  llvm::MemoryBufferRef buffer;

  /// The optional owning source manager, which when present may be used to
  /// extend the lifetime of the input buffer.
  /// NOTE(review): this is held by reference, so the caller's shared_ptr
  /// object must outlive this reader; confirm at the construction sites.
  const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
};
1631 LogicalResult
BytecodeReader::Impl::read(
1632 Block
*block
, llvm::function_ref
<bool(Operation
*)> lazyOpsCallback
) {
1633 EncodingReader
reader(buffer
.getBuffer(), fileLoc
);
1634 this->lazyOpsCallback
= lazyOpsCallback
;
1635 auto resetlazyOpsCallback
=
1636 llvm::make_scope_exit([&] { this->lazyOpsCallback
= nullptr; });
1638 // Skip over the bytecode header, this should have already been checked.
1639 if (failed(reader
.skipBytes(StringRef("ML\xefR").size())))
1641 // Parse the bytecode version and producer.
1642 if (failed(parseVersion(reader
)) ||
1643 failed(reader
.parseNullTerminatedString(producer
)))
1646 // Add a diagnostic handler that attaches a note that includes the original
1647 // producer of the bytecode.
1648 ScopedDiagnosticHandler
diagHandler(getContext(), [&](Diagnostic
&diag
) {
1649 diag
.attachNote() << "in bytecode version " << version
1650 << " produced by: " << producer
;
1654 // Parse the raw data for each of the top-level sections of the bytecode.
1655 std::optional
<ArrayRef
<uint8_t>>
1656 sectionDatas
[bytecode::Section::kNumSections
];
1657 while (!reader
.empty()) {
1658 // Read the next section from the bytecode.
1659 bytecode::Section::ID sectionID
;
1660 ArrayRef
<uint8_t> sectionData
;
1661 if (failed(reader
.parseSection(sectionID
, sectionData
)))
1664 // Check for duplicate sections, we only expect one instance of each.
1665 if (sectionDatas
[sectionID
]) {
1666 return reader
.emitError("duplicate top-level section: ",
1667 ::toString(sectionID
));
1669 sectionDatas
[sectionID
] = sectionData
;
1671 // Check that all of the required sections were found.
1672 for (int i
= 0; i
< bytecode::Section::kNumSections
; ++i
) {
1673 bytecode::Section::ID sectionID
= static_cast<bytecode::Section::ID
>(i
);
1674 if (!sectionDatas
[i
] && !isSectionOptional(sectionID
, version
)) {
1675 return reader
.emitError("missing data for top-level section: ",
1676 ::toString(sectionID
));
1680 // Process the string section first.
1681 if (failed(stringReader
.initialize(
1682 fileLoc
, *sectionDatas
[bytecode::Section::kString
])))
1685 // Process the properties section.
1686 if (sectionDatas
[bytecode::Section::kProperties
] &&
1687 failed(propertiesReader
.initialize(
1688 fileLoc
, *sectionDatas
[bytecode::Section::kProperties
])))
1691 // Process the dialect section.
1692 if (failed(parseDialectSection(*sectionDatas
[bytecode::Section::kDialect
])))
1695 // Process the resource section if present.
1696 if (failed(parseResourceSection(
1697 reader
, sectionDatas
[bytecode::Section::kResource
],
1698 sectionDatas
[bytecode::Section::kResourceOffset
])))
1701 // Process the attribute and type section.
1702 if (failed(attrTypeReader
.initialize(
1703 dialects
, *sectionDatas
[bytecode::Section::kAttrType
],
1704 *sectionDatas
[bytecode::Section::kAttrTypeOffset
])))
1707 // Finally, process the IR section.
1708 return parseIRSection(*sectionDatas
[bytecode::Section::kIR
], block
);
1711 LogicalResult
BytecodeReader::Impl::parseVersion(EncodingReader
&reader
) {
1712 if (failed(reader
.parseVarInt(version
)))
1715 // Validate the bytecode version.
1716 uint64_t currentVersion
= bytecode::kVersion
;
1717 uint64_t minSupportedVersion
= bytecode::kMinSupportedVersion
;
1718 if (version
< minSupportedVersion
) {
1719 return reader
.emitError("bytecode version ", version
,
1720 " is older than the current version of ",
1721 currentVersion
, ", and upgrade is not supported");
1723 if (version
> currentVersion
) {
1724 return reader
.emitError("bytecode version ", version
,
1725 " is newer than the current version ",
1728 // Override any request to lazy-load if the bytecode version is too old.
1729 if (version
< bytecode::kLazyLoading
)
1730 lazyLoading
= false;
1734 //===----------------------------------------------------------------------===//
/// Load this dialect into `ctx` and, when a version buffer is attached, decode
/// the dialect version through the dialect's bytecode interface.
///
/// Returns immediately with success once `dialect` has been set by an earlier
/// call. Fails when the dialect cannot be loaded and the context does not
/// allow unregistered dialects, when a version buffer exists but the dialect
/// has no bytecode interface, or when the interface fails to read the version.
LogicalResult BytecodeDialect::load(const DialectReader &reader,
                                    MLIRContext *ctx) {
  // Nothing to do if a previous call already resolved the dialect.
  if (dialect)
    return success();
  Dialect *loadedDialect = ctx->getOrLoadDialect(name);
  if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
    return reader.emitError("dialect '")
           << name
           << "' is unknown. If this is intended, please call "
              "allowUnregisteredDialects() on the MLIRContext, or use "
              "-allow-unregistered-dialect with the MLIR tool used.";
  }
  // NOTE(review): `dialect` is recorded before the version is read, so a
  // failure below appears to leave this entry marked as loaded; confirm that
  // callers abort on failure.
  dialect = loadedDialect;

  // If the dialect was actually loaded, check to see if it has a bytecode
  // interface.
  if (loadedDialect)
    interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
  if (!versionBuffer.empty()) {
    if (!interface)
      return reader.emitError("dialect '")
             << name
             << "' does not implement the bytecode interface, "
                "but found a version entry";
    // Decode the version from its own buffer, reusing the section readers of
    // the incoming reader.
    EncodingReader encReader(versionBuffer, reader.getLoc());
    DialectReader versionReader = reader.withEncodingReader(encReader);
    loadedVersion = interface->readVersion(versionReader);
    if (!loadedVersion)
      return failure();
  }
  return success();
}
1771 BytecodeReader::Impl::parseDialectSection(ArrayRef
<uint8_t> sectionData
) {
1772 EncodingReader
sectionReader(sectionData
, fileLoc
);
1774 // Parse the number of dialects in the section.
1775 uint64_t numDialects
;
1776 if (failed(sectionReader
.parseVarInt(numDialects
)))
1778 dialects
.resize(numDialects
);
1780 // Parse each of the dialects.
1781 for (uint64_t i
= 0; i
< numDialects
; ++i
) {
1782 dialects
[i
] = std::make_unique
<BytecodeDialect
>();
1783 /// Before version kDialectVersioning, there wasn't any versioning available
1784 /// for dialects, and the entryIdx represent the string itself.
1785 if (version
< bytecode::kDialectVersioning
) {
1786 if (failed(stringReader
.parseString(sectionReader
, dialects
[i
]->name
)))
1791 // Parse ID representing dialect and version.
1792 uint64_t dialectNameIdx
;
1793 bool versionAvailable
;
1794 if (failed(sectionReader
.parseVarIntWithFlag(dialectNameIdx
,
1797 if (failed(stringReader
.parseStringAtIndex(sectionReader
, dialectNameIdx
,
1798 dialects
[i
]->name
)))
1800 if (versionAvailable
) {
1801 bytecode::Section::ID sectionID
;
1802 if (failed(sectionReader
.parseSection(sectionID
,
1803 dialects
[i
]->versionBuffer
)))
1805 if (sectionID
!= bytecode::Section::kDialectVersions
) {
1806 emitError(fileLoc
, "expected dialect version section");
1810 dialectsMap
[dialects
[i
]->name
] = dialects
[i
].get();
1813 // Parse the operation names, which are grouped by dialect.
1814 auto parseOpName
= [&](BytecodeDialect
*dialect
) {
1816 std::optional
<bool> wasRegistered
;
1817 // Prior to version kNativePropertiesEncoding, the information about wheter
1818 // an op was registered or not wasn't encoded.
1819 if (version
< bytecode::kNativePropertiesEncoding
) {
1820 if (failed(stringReader
.parseString(sectionReader
, opName
)))
1823 bool wasRegisteredFlag
;
1824 if (failed(stringReader
.parseStringWithFlag(sectionReader
, opName
,
1825 wasRegisteredFlag
)))
1827 wasRegistered
= wasRegisteredFlag
;
1829 opNames
.emplace_back(dialect
, opName
, wasRegistered
);
1832 // Avoid re-allocation in bytecode version >=kElideUnknownBlockArgLocation
1833 // where the number of ops are known.
1834 if (version
>= bytecode::kElideUnknownBlockArgLocation
) {
1836 if (failed(sectionReader
.parseVarInt(numOps
)))
1838 opNames
.reserve(numOps
);
1840 while (!sectionReader
.empty())
1841 if (failed(parseDialectGrouping(sectionReader
, dialects
, parseOpName
)))
1846 FailureOr
<OperationName
>
1847 BytecodeReader::Impl::parseOpName(EncodingReader
&reader
,
1848 std::optional
<bool> &wasRegistered
) {
1849 BytecodeOperationName
*opName
= nullptr;
1850 if (failed(parseEntry(reader
, opNames
, opName
, "operation name")))
1852 wasRegistered
= opName
->wasRegistered
;
1853 // Check to see if this operation name has already been resolved. If we
1854 // haven't, load the dialect and build the operation name.
1855 if (!opName
->opName
) {
1856 // If the opName is empty, this is because we use to accept names such as
1857 // `foo` without any `.` separator. We shouldn't tolerate this in textual
1858 // format anymore but for now we'll be backward compatible. This can only
1859 // happen with unregistered dialects.
1860 if (opName
->name
.empty()) {
1861 opName
->opName
.emplace(opName
->dialect
->name
, getContext());
1863 // Load the dialect and its version.
1864 DialectReader
dialectReader(attrTypeReader
, stringReader
, resourceReader
,
1865 dialectsMap
, reader
, version
);
1866 if (failed(opName
->dialect
->load(dialectReader
, getContext())))
1868 opName
->opName
.emplace((opName
->dialect
->name
+ "." + opName
->name
).str(),
1872 return *opName
->opName
;
1875 //===----------------------------------------------------------------------===//
1878 LogicalResult
BytecodeReader::Impl::parseResourceSection(
1879 EncodingReader
&reader
, std::optional
<ArrayRef
<uint8_t>> resourceData
,
1880 std::optional
<ArrayRef
<uint8_t>> resourceOffsetData
) {
1881 // Ensure both sections are either present or not.
1882 if (resourceData
.has_value() != resourceOffsetData
.has_value()) {
1883 if (resourceOffsetData
)
1884 return emitError(fileLoc
, "unexpected resource offset section when "
1885 "resource section is not present");
1888 "expected resource offset section when resource section is present");
1891 // If the resource sections are absent, there is nothing to do.
1895 // Initialize the resource reader with the resource sections.
1896 DialectReader
dialectReader(attrTypeReader
, stringReader
, resourceReader
,
1897 dialectsMap
, reader
, version
);
1898 return resourceReader
.initialize(fileLoc
, config
, dialects
, stringReader
,
1899 *resourceData
, *resourceOffsetData
,
1900 dialectReader
, bufferOwnerRef
);
1903 //===----------------------------------------------------------------------===//
1904 // UseListOrder Helpers
1906 FailureOr
<BytecodeReader::Impl::UseListMapT
>
1907 BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader
&reader
,
1908 uint64_t numResults
) {
1909 BytecodeReader::Impl::UseListMapT map
;
1910 uint64_t numValuesToRead
= 1;
1911 if (numResults
> 1 && failed(reader
.parseVarInt(numValuesToRead
)))
1914 for (size_t valueIdx
= 0; valueIdx
< numValuesToRead
; valueIdx
++) {
1915 uint64_t resultIdx
= 0;
1916 if (numResults
> 1 && failed(reader
.parseVarInt(resultIdx
)))
1920 bool indexPairEncoding
;
1921 if (failed(reader
.parseVarIntWithFlag(numValues
, indexPairEncoding
)))
1924 SmallVector
<unsigned, 4> useListOrders
;
1925 for (size_t idx
= 0; idx
< numValues
; idx
++) {
1927 if (failed(reader
.parseVarInt(index
)))
1929 useListOrders
.push_back(index
);
1932 // Store in a map the result index
1933 map
.try_emplace(resultIdx
, UseListOrderStorage(indexPairEncoding
,
1934 std::move(useListOrders
)));
1940 /// Sorts each use according to the order specified in the use-list parsed. If
1941 /// the custom use-list is not found, this means that the order needs to be
1942 /// consistent with the reverse pre-order walk of the IR. If multiple uses lie
1943 /// on the same operation, the order will follow the reverse operand number
1945 LogicalResult
BytecodeReader::Impl::sortUseListOrder(Value value
) {
1946 // Early return for trivial use-lists.
1947 if (value
.use_empty() || value
.hasOneUse())
1950 bool hasIncomingOrder
=
1951 valueToUseListMap
.contains(value
.getAsOpaquePointer());
1953 // Compute the current order of the use-list with respect to the global
1954 // ordering. Detect if the order is already sorted while doing so.
1955 bool alreadySorted
= true;
1956 auto &firstUse
= *value
.use_begin();
1958 bytecode::getUseID(firstUse
, operationIDs
.at(firstUse
.getOwner()));
1959 llvm::SmallVector
<std::pair
<unsigned, uint64_t>> currentOrder
= {{0, prevID
}};
1960 for (auto item
: llvm::drop_begin(llvm::enumerate(value
.getUses()))) {
1961 uint64_t currentID
= bytecode::getUseID(
1962 item
.value(), operationIDs
.at(item
.value().getOwner()));
1963 alreadySorted
&= prevID
> currentID
;
1964 currentOrder
.push_back({item
.index(), currentID
});
1968 // If the order is already sorted, and there wasn't a custom order to apply
1969 // from the bytecode file, we are done.
1970 if (alreadySorted
&& !hasIncomingOrder
)
1973 // If not already sorted, sort the indices of the current order by descending
1977 currentOrder
.begin(), currentOrder
.end(),
1978 [](auto elem1
, auto elem2
) { return elem1
.second
> elem2
.second
; });
1980 if (!hasIncomingOrder
) {
1981 // If the bytecode file did not contain any custom use-list order, it means
1982 // that the order was descending useID. Hence, shuffle by the first index
1983 // of the `currentOrder` pair.
1984 SmallVector
<unsigned> shuffle
= SmallVector
<unsigned>(
1985 llvm::map_range(currentOrder
, [&](auto item
) { return item
.first
; }));
1986 value
.shuffleUseList(shuffle
);
1990 // Pull the custom order info from the map.
1991 UseListOrderStorage customOrder
=
1992 valueToUseListMap
.at(value
.getAsOpaquePointer());
1993 SmallVector
<unsigned, 4> shuffle
= std::move(customOrder
.indices
);
1995 std::distance(value
.getUses().begin(), value
.getUses().end());
1997 // If the encoding was a pair of indices `(src, dst)` for every permutation,
1998 // reconstruct the shuffle vector for every use. Initialize the shuffle vector
1999 // as identity, and then apply the mapping encoded in the indices.
2000 if (customOrder
.isIndexPairEncoding
) {
2001 // Return failure if the number of indices was not representing pairs.
2002 if (shuffle
.size() & 1)
2005 SmallVector
<unsigned, 4> newShuffle(numUses
);
2007 std::iota(newShuffle
.begin(), newShuffle
.end(), idx
);
2008 for (idx
= 0; idx
< shuffle
.size(); idx
+= 2)
2009 newShuffle
[shuffle
[idx
]] = shuffle
[idx
+ 1];
2011 shuffle
= std::move(newShuffle
);
2014 // Make sure that the indices represent a valid mapping. That is, the sum of
2015 // all the values needs to be equal to (numUses - 1) * numUses / 2, and no
2016 // duplicates are allowed in the list.
2017 DenseSet
<unsigned> set
;
2018 uint64_t accumulator
= 0;
2019 for (const auto &elem
: shuffle
) {
2020 if (!set
.insert(elem
).second
)
2022 accumulator
+= elem
;
2024 if (numUses
!= shuffle
.size() ||
2025 accumulator
!= (((numUses
- 1) * numUses
) >> 1))
2028 // Apply the current ordering map onto the shuffle vector to get the final
2029 // use-list sorting indices before shuffling.
2030 shuffle
= SmallVector
<unsigned, 4>(llvm::map_range(
2031 currentOrder
, [&](auto item
) { return shuffle
[item
.first
]; }));
2032 value
.shuffleUseList(shuffle
);
2036 LogicalResult
BytecodeReader::Impl::processUseLists(Operation
*topLevelOp
) {
2037 // Precompute operation IDs according to the pre-order walk of the IR. We
2038 // can't do this while parsing since parseRegions ordering is not strictly
2039 // equal to the pre-order walk.
2040 unsigned operationID
= 0;
2041 topLevelOp
->walk
<mlir::WalkOrder::PreOrder
>(
2042 [&](Operation
*op
) { operationIDs
.try_emplace(op
, operationID
++); });
2044 auto blockWalk
= topLevelOp
->walk([this](Block
*block
) {
2045 for (auto arg
: block
->getArguments())
2046 if (failed(sortUseListOrder(arg
)))
2047 return WalkResult::interrupt();
2048 return WalkResult::advance();
2051 auto resultWalk
= topLevelOp
->walk([this](Operation
*op
) {
2052 for (auto result
: op
->getResults())
2053 if (failed(sortUseListOrder(result
)))
2054 return WalkResult::interrupt();
2055 return WalkResult::advance();
2058 return failure(blockWalk
.wasInterrupted() || resultWalk
.wasInterrupted());
2061 //===----------------------------------------------------------------------===//
2065 BytecodeReader::Impl::parseIRSection(ArrayRef
<uint8_t> sectionData
,
2067 EncodingReader
reader(sectionData
, fileLoc
);
2069 // A stack of operation regions currently being read from the bytecode.
2070 std::vector
<RegionReadState
> regionStack
;
2072 // Parse the top-level block using a temporary module operation.
2073 OwningOpRef
<ModuleOp
> moduleOp
= ModuleOp::create(fileLoc
);
2074 regionStack
.emplace_back(*moduleOp
, &reader
, /*isIsolatedFromAbove=*/true);
2075 regionStack
.back().curBlocks
.push_back(moduleOp
->getBody());
2076 regionStack
.back().curBlock
= regionStack
.back().curRegion
->begin();
2077 if (failed(parseBlockHeader(reader
, regionStack
.back())))
2079 valueScopes
.emplace_back();
2080 valueScopes
.back().push(regionStack
.back());
2082 // Iteratively parse regions until everything has been resolved.
2083 while (!regionStack
.empty())
2084 if (failed(parseRegions(regionStack
, regionStack
.back())))
2086 if (!forwardRefOps
.empty()) {
2087 return reader
.emitError(
2088 "not all forward unresolved forward operand references");
2091 // Sort use-lists according to what specified in bytecode.
2092 if (failed(processUseLists(*moduleOp
)))
2093 return reader
.emitError(
2094 "parsed use-list orders were invalid and could not be applied");
2096 // Resolve dialect version.
2097 for (const std::unique_ptr
<BytecodeDialect
> &byteCodeDialect
: dialects
) {
2098 // Parsing is complete, give an opportunity to each dialect to visit the
2099 // IR and perform upgrades.
2100 if (!byteCodeDialect
->loadedVersion
)
2102 if (byteCodeDialect
->interface
&&
2103 failed(byteCodeDialect
->interface
->upgradeFromVersion(
2104 *moduleOp
, *byteCodeDialect
->loadedVersion
)))
2108 // Verify that the parsed operations are valid.
2109 if (config
.shouldVerifyAfterParse() && failed(verify(*moduleOp
)))
2112 // Splice the parsed operations over to the provided top-level block.
2113 auto &parsedOps
= moduleOp
->getBody()->getOperations();
2114 auto &destOps
= block
->getOperations();
2115 destOps
.splice(destOps
.end(), parsedOps
, parsedOps
.begin(), parsedOps
.end());
2120 BytecodeReader::Impl::parseRegions(std::vector
<RegionReadState
> ®ionStack
,
2121 RegionReadState
&readState
) {
2122 // Process regions, blocks, and operations until the end or if a nested
2123 // region is encountered. In this case we push a new state in regionStack and
2124 // return, the processing of the current region will resume afterward.
2125 for (; readState
.curRegion
!= readState
.endRegion
; ++readState
.curRegion
) {
2126 // If the current block hasn't been setup yet, parse the header for this
2127 // region. The current block is already setup when this function was
2128 // interrupted to recurse down in a nested region and we resume the current
2129 // block after processing the nested region.
2130 if (readState
.curBlock
== Region::iterator()) {
2131 if (failed(parseRegion(readState
)))
2134 // If the region is empty, there is nothing to more to do.
2135 if (readState
.curRegion
->empty())
2139 // Parse the blocks within the region.
2140 EncodingReader
&reader
= *readState
.reader
;
2142 while (readState
.numOpsRemaining
--) {
2143 // Read in the next operation. We don't read its regions directly, we
2144 // handle those afterwards as necessary.
2145 bool isIsolatedFromAbove
= false;
2146 FailureOr
<Operation
*> op
=
2147 parseOpWithoutRegions(reader
, readState
, isIsolatedFromAbove
);
2151 // If the op has regions, add it to the stack for processing and return:
2152 // we stop the processing of the current region and resume it after the
2153 // inner one is completed. Unless LazyLoading is activated in which case
2154 // nested region parsing is delayed.
2155 if ((*op
)->getNumRegions()) {
2156 RegionReadState
childState(*op
, &reader
, isIsolatedFromAbove
);
2158 // Isolated regions are encoded as a section in version 2 and above.
2159 if (version
>= bytecode::kLazyLoading
&& isIsolatedFromAbove
) {
2160 bytecode::Section::ID sectionID
;
2161 ArrayRef
<uint8_t> sectionData
;
2162 if (failed(reader
.parseSection(sectionID
, sectionData
)))
2164 if (sectionID
!= bytecode::Section::kIR
)
2165 return emitError(fileLoc
, "expected IR section for region");
2166 childState
.owningReader
=
2167 std::make_unique
<EncodingReader
>(sectionData
, fileLoc
);
2168 childState
.reader
= childState
.owningReader
.get();
2170 // If the user has a callback set, they have the opportunity to
2171 // control lazyloading as we go.
2172 if (lazyLoading
&& (!lazyOpsCallback
|| !lazyOpsCallback(*op
))) {
2173 lazyLoadableOps
.emplace_back(*op
, std::move(childState
));
2174 lazyLoadableOpsMap
.try_emplace(*op
,
2175 std::prev(lazyLoadableOps
.end()));
2179 regionStack
.push_back(std::move(childState
));
2181 // If the op is isolated from above, push a new value scope.
2182 if (isIsolatedFromAbove
)
2183 valueScopes
.emplace_back();
2188 // Move to the next block of the region.
2189 if (++readState
.curBlock
== readState
.curRegion
->end())
2191 if (failed(parseBlockHeader(reader
, readState
)))
2195 // Reset the current block and any values reserved for this region.
2196 readState
.curBlock
= {};
2197 valueScopes
.back().pop(readState
);
2200 // When the regions have been fully parsed, pop them off of the read stack. If
2201 // the regions were isolated from above, we also pop the last value scope.
2202 if (readState
.isIsolatedFromAbove
) {
2203 assert(!valueScopes
.empty() && "Expect a valueScope after reading region");
2204 valueScopes
.pop_back();
2206 assert(!regionStack
.empty() && "Expect a regionStack after reading region");
2207 regionStack
.pop_back();
2211 FailureOr
<Operation
*>
2212 BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader
&reader
,
2213 RegionReadState
&readState
,
2214 bool &isIsolatedFromAbove
) {
2215 // Parse the name of the operation.
2216 std::optional
<bool> wasRegistered
;
2217 FailureOr
<OperationName
> opName
= parseOpName(reader
, wasRegistered
);
2221 // Parse the operation mask, which indicates which components of the operation
2224 if (failed(reader
.parseByte(opMask
)))
2227 /// Parse the location.
2229 if (failed(parseAttribute(reader
, opLoc
)))
2232 // With the location and name resolved, we can start building the operation
2234 OperationState
opState(opLoc
, *opName
);
2236 // Parse the attributes of the operation.
2237 if (opMask
& bytecode::OpEncodingMask::kHasAttrs
) {
2238 DictionaryAttr dictAttr
;
2239 if (failed(parseAttribute(reader
, dictAttr
)))
2241 opState
.attributes
= dictAttr
;
2244 if (opMask
& bytecode::OpEncodingMask::kHasProperties
) {
2245 // kHasProperties wasn't emitted in older bytecode, we should never get
2246 // there without also having the `wasRegistered` flag available.
2248 return emitError(fileLoc
,
2249 "Unexpected missing `wasRegistered` opname flag at "
2250 "bytecode version ")
2251 << version
<< " with properties.";
2252 // When an operation is emitted without being registered, the properties are
2253 // stored as an attribute. Otherwise the op must implement the bytecode
2254 // interface and control the serialization.
2255 if (wasRegistered
) {
2256 DialectReader
dialectReader(attrTypeReader
, stringReader
, resourceReader
,
2257 dialectsMap
, reader
, version
);
2259 propertiesReader
.read(fileLoc
, dialectReader
, &*opName
, opState
)))
2262 // If the operation wasn't registered when it was emitted, the properties
2263 // was serialized as an attribute.
2264 if (failed(parseAttribute(reader
, opState
.propertiesAttr
)))
2269 /// Parse the results of the operation.
2270 if (opMask
& bytecode::OpEncodingMask::kHasResults
) {
2271 uint64_t numResults
;
2272 if (failed(reader
.parseVarInt(numResults
)))
2274 opState
.types
.resize(numResults
);
2275 for (int i
= 0, e
= numResults
; i
< e
; ++i
)
2276 if (failed(parseType(reader
, opState
.types
[i
])))
2280 /// Parse the operands of the operation.
2281 if (opMask
& bytecode::OpEncodingMask::kHasOperands
) {
2282 uint64_t numOperands
;
2283 if (failed(reader
.parseVarInt(numOperands
)))
2285 opState
.operands
.resize(numOperands
);
2286 for (int i
= 0, e
= numOperands
; i
< e
; ++i
)
2287 if (!(opState
.operands
[i
] = parseOperand(reader
)))
2291 /// Parse the successors of the operation.
2292 if (opMask
& bytecode::OpEncodingMask::kHasSuccessors
) {
2294 if (failed(reader
.parseVarInt(numSuccs
)))
2296 opState
.successors
.resize(numSuccs
);
2297 for (int i
= 0, e
= numSuccs
; i
< e
; ++i
) {
2298 if (failed(parseEntry(reader
, readState
.curBlocks
, opState
.successors
[i
],
2304 /// Parse the use-list orders for the results of the operation. Use-list
2305 /// orders are available since version 3 of the bytecode.
2306 std::optional
<UseListMapT
> resultIdxToUseListMap
= std::nullopt
;
2307 if (version
>= bytecode::kUseListOrdering
&&
2308 (opMask
& bytecode::OpEncodingMask::kHasUseListOrders
)) {
2309 size_t numResults
= opState
.types
.size();
2310 auto parseResult
= parseUseListOrderForRange(reader
, numResults
);
2311 if (failed(parseResult
))
2313 resultIdxToUseListMap
= std::move(*parseResult
);
2316 /// Parse the regions of the operation.
2317 if (opMask
& bytecode::OpEncodingMask::kHasInlineRegions
) {
2318 uint64_t numRegions
;
2319 if (failed(reader
.parseVarIntWithFlag(numRegions
, isIsolatedFromAbove
)))
2322 opState
.regions
.reserve(numRegions
);
2323 for (int i
= 0, e
= numRegions
; i
< e
; ++i
)
2324 opState
.regions
.push_back(std::make_unique
<Region
>());
2327 // Create the operation at the back of the current block.
2328 Operation
*op
= Operation::create(opState
);
2329 readState
.curBlock
->push_back(op
);
2331 // If the operation had results, update the value references. We don't need to
2332 // do this if the current value scope is empty. That is, the op was not
2333 // encoded within a parent region.
2334 if (readState
.numValues
&& op
->getNumResults() &&
2335 failed(defineValues(reader
, op
->getResults())))
2338 /// Store a map for every value that received a custom use-list order from the
2340 if (resultIdxToUseListMap
.has_value()) {
2341 for (size_t idx
= 0; idx
< op
->getNumResults(); idx
++) {
2342 if (resultIdxToUseListMap
->contains(idx
)) {
2343 valueToUseListMap
.try_emplace(op
->getResult(idx
).getAsOpaquePointer(),
2344 resultIdxToUseListMap
->at(idx
));
2351 LogicalResult
BytecodeReader::Impl::parseRegion(RegionReadState
&readState
) {
2352 EncodingReader
&reader
= *readState
.reader
;
2354 // Parse the number of blocks in the region.
2356 if (failed(reader
.parseVarInt(numBlocks
)))
2359 // If the region is empty, there is nothing else to do.
2363 // Parse the number of values defined in this region.
2365 if (failed(reader
.parseVarInt(numValues
)))
2367 readState
.numValues
= numValues
;
2369 // Create the blocks within this region. We do this before processing so that
2370 // we can rely on the blocks existing when creating operations.
2371 readState
.curBlocks
.clear();
2372 readState
.curBlocks
.reserve(numBlocks
);
2373 for (uint64_t i
= 0; i
< numBlocks
; ++i
) {
2374 readState
.curBlocks
.push_back(new Block());
2375 readState
.curRegion
->push_back(readState
.curBlocks
.back());
2378 // Prepare the current value scope for this region.
2379 valueScopes
.back().push(readState
);
2381 // Parse the entry block of the region.
2382 readState
.curBlock
= readState
.curRegion
->begin();
2383 return parseBlockHeader(reader
, readState
);
2387 BytecodeReader::Impl::parseBlockHeader(EncodingReader
&reader
,
2388 RegionReadState
&readState
) {
2390 if (failed(reader
.parseVarIntWithFlag(readState
.numOpsRemaining
, hasArgs
)))
2393 // Parse the arguments of the block.
2394 if (hasArgs
&& failed(parseBlockArguments(reader
, &*readState
.curBlock
)))
2397 // Uselist orders are available since version 3 of the bytecode.
2398 if (version
< bytecode::kUseListOrdering
)
2401 uint8_t hasUseListOrders
= 0;
2402 if (hasArgs
&& failed(reader
.parseByte(hasUseListOrders
)))
2405 if (!hasUseListOrders
)
2408 Block
&blk
= *readState
.curBlock
;
2409 auto argIdxToUseListMap
=
2410 parseUseListOrderForRange(reader
, blk
.getNumArguments());
2411 if (failed(argIdxToUseListMap
) || argIdxToUseListMap
->empty())
2414 for (size_t idx
= 0; idx
< blk
.getNumArguments(); idx
++)
2415 if (argIdxToUseListMap
->contains(idx
))
2416 valueToUseListMap
.try_emplace(blk
.getArgument(idx
).getAsOpaquePointer(),
2417 argIdxToUseListMap
->at(idx
));
2419 // We don't parse the operations of the block here, that's done elsewhere.
2423 LogicalResult
BytecodeReader::Impl::parseBlockArguments(EncodingReader
&reader
,
2425 // Parse the value ID for the first argument, and the number of arguments.
2427 if (failed(reader
.parseVarInt(numArgs
)))
2430 SmallVector
<Type
> argTypes
;
2431 SmallVector
<Location
> argLocs
;
2432 argTypes
.reserve(numArgs
);
2433 argLocs
.reserve(numArgs
);
2435 Location unknownLoc
= UnknownLoc::get(config
.getContext());
2438 LocationAttr argLoc
= unknownLoc
;
2439 if (version
>= bytecode::kElideUnknownBlockArgLocation
) {
2440 // Parse the type with hasLoc flag to determine if it has type.
2443 if (failed(reader
.parseVarIntWithFlag(typeIdx
, hasLoc
)) ||
2444 !(argType
= attrTypeReader
.resolveType(typeIdx
)))
2446 if (hasLoc
&& failed(parseAttribute(reader
, argLoc
)))
2449 // All args has type and location.
2450 if (failed(parseType(reader
, argType
)) ||
2451 failed(parseAttribute(reader
, argLoc
)))
2454 argTypes
.push_back(argType
);
2455 argLocs
.push_back(argLoc
);
2457 block
->addArguments(argTypes
, argLocs
);
2458 return defineValues(reader
, block
->getArguments());
2461 //===----------------------------------------------------------------------===//
2464 Value
BytecodeReader::Impl::parseOperand(EncodingReader
&reader
) {
2465 std::vector
<Value
> &values
= valueScopes
.back().values
;
2466 Value
*value
= nullptr;
2467 if (failed(parseEntry(reader
, values
, value
, "value")))
2470 // Create a new forward reference if necessary.
2472 *value
= createForwardRef();
2476 LogicalResult
BytecodeReader::Impl::defineValues(EncodingReader
&reader
,
2477 ValueRange newValues
) {
2478 ValueScope
&valueScope
= valueScopes
.back();
2479 std::vector
<Value
> &values
= valueScope
.values
;
2481 unsigned &valueID
= valueScope
.nextValueIDs
.back();
2482 unsigned valueIDEnd
= valueID
+ newValues
.size();
2483 if (valueIDEnd
> values
.size()) {
2484 return reader
.emitError(
2485 "value index range was outside of the expected range for "
2486 "the parent region, got [",
2487 valueID
, ", ", valueIDEnd
, "), but the maximum index was ",
2491 // Assign the values and update any forward references.
2492 for (unsigned i
= 0, e
= newValues
.size(); i
!= e
; ++i
, ++valueID
) {
2493 Value newValue
= newValues
[i
];
2495 // Check to see if a definition for this value already exists.
2496 if (Value oldValue
= std::exchange(values
[valueID
], newValue
)) {
2497 Operation
*forwardRefOp
= oldValue
.getDefiningOp();
2499 // Assert that this is a forward reference operation. Given how we compute
2500 // definition ids (incrementally as we parse), it shouldn't be possible
2501 // for the value to be defined any other way.
2502 assert(forwardRefOp
&& forwardRefOp
->getBlock() == &forwardRefOps
&&
2503 "value index was already defined?");
2505 oldValue
.replaceAllUsesWith(newValue
);
2506 forwardRefOp
->moveBefore(&openForwardRefOps
, openForwardRefOps
.end());
2512 Value
BytecodeReader::Impl::createForwardRef() {
2513 // Check for an available existing operation to use. Otherwise, create a new
2514 // fake operation to use for the reference.
2515 if (!openForwardRefOps
.empty()) {
2516 Operation
*op
= &openForwardRefOps
.back();
2517 op
->moveBefore(&forwardRefOps
, forwardRefOps
.end());
2519 forwardRefOps
.push_back(Operation::create(forwardRefOpState
));
2521 return forwardRefOps
.back().getResult(0);
2524 //===----------------------------------------------------------------------===//
2526 //===----------------------------------------------------------------------===//
2528 BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); }
2530 BytecodeReader::BytecodeReader(
2531 llvm::MemoryBufferRef buffer
, const ParserConfig
&config
, bool lazyLoading
,
2532 const std::shared_ptr
<llvm::SourceMgr
> &bufferOwnerRef
) {
2533 Location sourceFileLoc
=
2534 FileLineColLoc::get(config
.getContext(), buffer
.getBufferIdentifier(),
2535 /*line=*/0, /*column=*/0);
2536 impl
= std::make_unique
<Impl
>(sourceFileLoc
, config
, lazyLoading
, buffer
,
2540 LogicalResult
BytecodeReader::readTopLevel(
2541 Block
*block
, llvm::function_ref
<bool(Operation
*)> lazyOpsCallback
) {
2542 return impl
->read(block
, lazyOpsCallback
);
2545 int64_t BytecodeReader::getNumOpsToMaterialize() const {
2546 return impl
->getNumOpsToMaterialize();
2549 bool BytecodeReader::isMaterializable(Operation
*op
) {
2550 return impl
->isMaterializable(op
);
2553 LogicalResult
BytecodeReader::materialize(
2554 Operation
*op
, llvm::function_ref
<bool(Operation
*)> lazyOpsCallback
) {
2555 return impl
->materialize(op
, lazyOpsCallback
);
2559 BytecodeReader::finalize(function_ref
<bool(Operation
*)> shouldMaterialize
) {
2560 return impl
->finalize(shouldMaterialize
);
2563 bool mlir::isBytecode(llvm::MemoryBufferRef buffer
) {
2564 return buffer
.getBuffer().starts_with("ML\xefR");
2567 /// Read the bytecode from the provided memory buffer reference.
2568 /// `bufferOwnerRef` if provided is the owning source manager for the buffer,
2569 /// and may be used to extend the lifetime of the buffer.
2570 static LogicalResult
2571 readBytecodeFileImpl(llvm::MemoryBufferRef buffer
, Block
*block
,
2572 const ParserConfig
&config
,
2573 const std::shared_ptr
<llvm::SourceMgr
> &bufferOwnerRef
) {
2574 Location sourceFileLoc
=
2575 FileLineColLoc::get(config
.getContext(), buffer
.getBufferIdentifier(),
2576 /*line=*/0, /*column=*/0);
2577 if (!isBytecode(buffer
)) {
2578 return emitError(sourceFileLoc
,
2579 "input buffer is not an MLIR bytecode file");
2582 BytecodeReader::Impl
reader(sourceFileLoc
, config
, /*lazyLoading=*/false,
2583 buffer
, bufferOwnerRef
);
2584 return reader
.read(block
, /*lazyOpsCallback=*/nullptr);
2587 LogicalResult
mlir::readBytecodeFile(llvm::MemoryBufferRef buffer
, Block
*block
,
2588 const ParserConfig
&config
) {
2589 return readBytecodeFileImpl(buffer
, block
, config
, /*bufferOwnerRef=*/{});
2592 mlir::readBytecodeFile(const std::shared_ptr
<llvm::SourceMgr
> &sourceMgr
,
2593 Block
*block
, const ParserConfig
&config
) {
2594 return readBytecodeFileImpl(
2595 *sourceMgr
->getMemoryBuffer(sourceMgr
->getMainFileID()), block
, config
,