[mlir][sparse] cleanup of enums header (#71090)
[llvm-project.git] / mlir / include / mlir / Dialect / SparseTensor / IR / Enums.h
blob9c277a0b23633d8d89bbad7a19338abbc07af48a
1 //===- Enums.h - Enums for the SparseTensor dialect -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Typedefs and enums shared between MLIR code for manipulating the
// IR, and the lightweight runtime support library for sparse tensor
// manipulations.  That is, all the enums are used to define the API
// of the runtime library and hence are also needed when generating
// calls into the runtime library.  Moreover, the `DimLevelType` enum
// is also used as the internal IR encoding of dimension level types,
// to avoid code duplication (e.g., for the predicates).
17 // This file also defines x-macros <https://en.wikipedia.org/wiki/X_Macro>
18 // so that we can generate variations of the public functions for each
19 // supported primary- and/or overhead-type.
21 // Because this file defines a library which is a dependency of the
22 // runtime library itself, this file must not depend on any MLIR internals
23 // (e.g., operators, attributes, ArrayRefs, etc) lest the runtime library
24 // inherit those dependencies.
26 //===----------------------------------------------------------------------===//
28 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
29 #define MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H
31 // NOTE: Client code will need to include "mlir/ExecutionEngine/Float16bits.h"
32 // if they want to use the `MLIR_SPARSETENSOR_FOREVERY_V` macro.
34 #include <cassert>
35 #include <cinttypes>
36 #include <complex>
37 #include <optional>
39 namespace mlir {
40 namespace sparse_tensor {
/// This type is used in the public API at all places where MLIR expects
/// values with the built-in type "index".  For now, we simply assume that
/// type is 64-bit, but targets with different "index" bitwidths should
/// link with an alternatively built runtime support library.
using index_type = uint64_t;
/// Encoding of overhead types (both position overhead and coordinate
/// overhead), for "overloading" @newSparseTensor.
enum class OverheadType : uint32_t {
  kIndex = 0,
  kU64 = 1,
  kU32 = 2,
  kU16 = 3,
  kU8 = 4
};
// This x-macro calls its argument on every overhead type which has
// fixed-width.  It excludes `index_type` because that type is often
// handled specially (e.g., by translating it into the architecture-dependent
// equivalent fixed-width overhead type).
//
// `DO` is invoked as `DO(ONAME, O)`, where `ONAME` is the bit-width used as
// a name suffix and `O` is the corresponding C++ type.
#define MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DO)                                 \
  DO(64, uint64_t)                                                             \
  DO(32, uint32_t)                                                             \
  DO(16, uint16_t)                                                             \
  DO(8, uint8_t)
// This x-macro calls its argument on every overhead type, including
// `index_type`.  The fixed-width types are visited first (via
// `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`), followed by `DO(0, index_type)`.
#define MLIR_SPARSETENSOR_FOREVERY_O(DO)                                       \
  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DO)                                       \
  DO(0, index_type)
// These are not just shorthands but indicate the particular
// implementation used (e.g., as opposed to C99's `complex double`,
// or MLIR's `ComplexType`).
using complex64 = std::complex<double>;
using complex32 = std::complex<float>;
/// Encoding of the elemental type, for "overloading" @newSparseTensor.
/// The numbering starts at 1 and groups the encodings contiguously
/// (kF64..kBF16, kI64..kI8, kC64..kC32); the range-based predicates on
/// this enum rely on that ordering.
enum class PrimaryType : uint32_t {
  kF64 = 1,
  kF32 = 2,
  kF16 = 3,
  kBF16 = 4,
  kI64 = 5,
  kI32 = 6,
  kI16 = 7,
  kI8 = 8,
  kC64 = 9,
  kC32 = 10
};
// This x-macro includes all `V` types.
//
// `DO` is invoked as `DO(VNAME, V)`, where `VNAME` is the name suffix and
// `V` is the C++ type.  The `f16` and `bf16` types are not declared by
// this file (see the note about "mlir/ExecutionEngine/Float16bits.h").
#define MLIR_SPARSETENSOR_FOREVERY_V(DO)                                       \
  DO(F64, double)                                                              \
  DO(F32, float)                                                               \
  DO(F16, f16)                                                                 \
  DO(BF16, bf16)                                                               \
  DO(I64, int64_t)                                                             \
  DO(I32, int32_t)                                                             \
  DO(I16, int16_t)                                                             \
  DO(I8, int8_t)                                                               \
  DO(C64, complex64)                                                           \
  DO(C32, complex32)
// This x-macro includes all `V` types and supports variadic arguments.
//
// `DO` is invoked as `DO(VNAME, V, <extra args>)`, with the extra
// arguments forwarded unchanged to every invocation.
#define MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, ...)                              \
  DO(F64, double, __VA_ARGS__)                                                 \
  DO(F32, float, __VA_ARGS__)                                                  \
  DO(F16, f16, __VA_ARGS__)                                                    \
  DO(BF16, bf16, __VA_ARGS__)                                                  \
  DO(I64, int64_t, __VA_ARGS__)                                                \
  DO(I32, int32_t, __VA_ARGS__)                                                \
  DO(I16, int16_t, __VA_ARGS__)                                                \
  DO(I8, int8_t, __VA_ARGS__)                                                  \
  DO(C64, complex64, __VA_ARGS__)                                              \
  DO(C32, complex32, __VA_ARGS__)
// This x-macro calls its argument on every pair of overhead and `V` types.
//
// `DO` is invoked as `DO(VNAME, V, ONAME, O)`: for each overhead type (in the
// order 64, 32, 16, 8, then `index_type` with name 0) it walks all `V` types
// via `MLIR_SPARSETENSOR_FOREVERY_V_VAR`.
#define MLIR_SPARSETENSOR_FOREVERY_V_O(DO)                                     \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 64, uint64_t)                           \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 32, uint32_t)                           \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 16, uint16_t)                           \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 8, uint8_t)                             \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 0, index_type)
128 constexpr bool isFloatingPrimaryType(PrimaryType valTy) {
129 return PrimaryType::kF64 <= valTy && valTy <= PrimaryType::kBF16;
132 constexpr bool isIntegralPrimaryType(PrimaryType valTy) {
133 return PrimaryType::kI64 <= valTy && valTy <= PrimaryType::kI8;
136 constexpr bool isRealPrimaryType(PrimaryType valTy) {
137 return PrimaryType::kF64 <= valTy && valTy <= PrimaryType::kI8;
140 constexpr bool isComplexPrimaryType(PrimaryType valTy) {
141 return PrimaryType::kC64 <= valTy && valTy <= PrimaryType::kC32;
/// The actions performed by @newSparseTensor.
/// Note that the numbering is not contiguous (3 and 6 are unused here).
enum class Action : uint32_t {
  kEmpty = 0,
  kEmptyForward = 1,
  kFromCOO = 2,
  kFromReader = 4,
  kToCOO = 5,
  kPack = 7,
  kSortCOOInPlace = 8,
};
/// This enum defines all the sparse representations supportable by
/// the SparseTensor dialect.  We use a lightweight encoding to encode
/// both the "format" per se (dense, compressed, singleton, loose_compressed,
/// two-out-of-four) as well as the "properties" (ordered, unique).  The
/// encoding is chosen for performance of the runtime library, and thus may
/// change in future versions; consequently, client code should use the
/// predicate functions defined below, rather than relying on knowledge
/// about the particular binary encoding.
///
/// The `Undef` "format" is a special value used internally for cases
/// where we need to store an undefined or indeterminate `DimLevelType`.
/// It should not be used externally, since it does not indicate an
/// actual/representable format.
///
/// In the current encoding the two low bits hold the properties
/// (bit 0: nonunique, bit 1: nonordered) and the remaining high bits hold
/// a one-hot format tag.
enum class DimLevelType : uint8_t {
  Undef = 0,                // 0b00000_00
  Dense = 4,                // 0b00001_00
  Compressed = 8,           // 0b00010_00
  CompressedNu = 9,         // 0b00010_01
  CompressedNo = 10,        // 0b00010_10
  CompressedNuNo = 11,      // 0b00010_11
  Singleton = 16,           // 0b00100_00
  SingletonNu = 17,         // 0b00100_01
  SingletonNo = 18,         // 0b00100_10
  SingletonNuNo = 19,       // 0b00100_11
  LooseCompressed = 32,     // 0b01000_00
  LooseCompressedNu = 33,   // 0b01000_01
  LooseCompressedNo = 34,   // 0b01000_10
  LooseCompressedNuNo = 35, // 0b01000_11
  TwoOutOfFour = 64,        // 0b10000_00
};
/// This enum defines all supported storage format without the level
/// properties.  Each value equals the property-free `DimLevelType` value
/// of the same name.
enum class LevelFormat : uint8_t {
  Dense = 4,            // 0b00001_00
  Compressed = 8,       // 0b00010_00
  Singleton = 16,       // 0b00100_00
  LooseCompressed = 32, // 0b01000_00
  TwoOutOfFour = 64,    // 0b10000_00
};
/// This enum defines all the nondefault properties for storage formats.
/// The values are single-bit masks occupying the two low bits.
enum class LevelNondefaultProperty : uint8_t {
  Nonunique = 1,  // 0b00000_01
  Nonordered = 2, // 0b00000_10
};
201 /// Returns string representation of the given dimension level type.
202 constexpr const char *toMLIRString(DimLevelType dlt) {
203 switch (dlt) {
204 case DimLevelType::Undef:
205 return "undef";
206 case DimLevelType::Dense:
207 return "dense";
208 case DimLevelType::Compressed:
209 return "compressed";
210 case DimLevelType::CompressedNu:
211 return "compressed(nonunique)";
212 case DimLevelType::CompressedNo:
213 return "compressed(nonordered)";
214 case DimLevelType::CompressedNuNo:
215 return "compressed(nonunique, nonordered)";
216 case DimLevelType::Singleton:
217 return "singleton";
218 case DimLevelType::SingletonNu:
219 return "singleton(nonunique)";
220 case DimLevelType::SingletonNo:
221 return "singleton(nonordered)";
222 case DimLevelType::SingletonNuNo:
223 return "singleton(nonunique, nonordered)";
224 case DimLevelType::LooseCompressed:
225 return "loose_compressed";
226 case DimLevelType::LooseCompressedNu:
227 return "loose_compressed(nonunique)";
228 case DimLevelType::LooseCompressedNo:
229 return "loose_compressed(nonordered)";
230 case DimLevelType::LooseCompressedNuNo:
231 return "loose_compressed(nonunique, nonordered)";
232 case DimLevelType::TwoOutOfFour:
233 return "block2_4";
235 return "";
238 /// Check that the `DimLevelType` contains a valid (possibly undefined) value.
239 constexpr bool isValidDLT(DimLevelType dlt) {
240 const uint8_t formatBits = static_cast<uint8_t>(dlt) >> 2;
241 const uint8_t propertyBits = static_cast<uint8_t>(dlt) & 3;
242 // If undefined or dense, then must be unique and ordered.
243 // Otherwise, the format must be one of the known ones.
244 return (formatBits <= 1 || formatBits == 16)
245 ? (propertyBits == 0)
246 : (formatBits == 2 || formatBits == 4 || formatBits == 8);
249 /// Check if the `DimLevelType` is the special undefined value.
250 constexpr bool isUndefDLT(DimLevelType dlt) {
251 return dlt == DimLevelType::Undef;
254 /// Check if the `DimLevelType` is dense (regardless of properties).
255 constexpr bool isDenseDLT(DimLevelType dlt) {
256 return (static_cast<uint8_t>(dlt) & ~3) ==
257 static_cast<uint8_t>(DimLevelType::Dense);
260 /// Check if the `DimLevelType` is compressed (regardless of properties).
261 constexpr bool isCompressedDLT(DimLevelType dlt) {
262 return (static_cast<uint8_t>(dlt) & ~3) ==
263 static_cast<uint8_t>(DimLevelType::Compressed);
266 /// Check if the `DimLevelType` is singleton (regardless of properties).
267 constexpr bool isSingletonDLT(DimLevelType dlt) {
268 return (static_cast<uint8_t>(dlt) & ~3) ==
269 static_cast<uint8_t>(DimLevelType::Singleton);
272 /// Check if the `DimLevelType` is loose compressed (regardless of properties).
273 constexpr bool isLooseCompressedDLT(DimLevelType dlt) {
274 return (static_cast<uint8_t>(dlt) & ~3) ==
275 static_cast<uint8_t>(DimLevelType::LooseCompressed);
278 /// Check if the `DimLevelType` is 2OutOf4 (regardless of properties).
279 constexpr bool is2OutOf4DLT(DimLevelType dlt) {
280 return (static_cast<uint8_t>(dlt) & ~3) ==
281 static_cast<uint8_t>(DimLevelType::TwoOutOfFour);
284 /// Check if the `DimLevelType` needs positions array.
285 constexpr bool isDLTWithPos(DimLevelType dlt) {
286 return isCompressedDLT(dlt) || isLooseCompressedDLT(dlt);
289 /// Check if the `DimLevelType` needs coordinates array.
290 constexpr bool isDLTWithCrd(DimLevelType dlt) {
291 return isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
292 isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt);
295 /// Check if the `DimLevelType` is ordered (regardless of storage format).
296 constexpr bool isOrderedDLT(DimLevelType dlt) {
297 return !(static_cast<uint8_t>(dlt) & 2);
300 /// Check if the `DimLevelType` is unique (regardless of storage format).
301 constexpr bool isUniqueDLT(DimLevelType dlt) {
302 return !(static_cast<uint8_t>(dlt) & 1);
305 /// Convert a DimLevelType to its corresponding LevelFormat.
306 /// Returns std::nullopt when input dlt is Undef.
307 constexpr std::optional<LevelFormat> getLevelFormat(DimLevelType dlt) {
308 if (dlt == DimLevelType::Undef)
309 return std::nullopt;
310 return static_cast<LevelFormat>(static_cast<uint8_t>(dlt) & ~3);
313 /// Convert a LevelFormat to its corresponding DimLevelType with the given
314 /// properties. Returns std::nullopt when the properties are not applicable for
315 /// the input level format.
316 /// TODO: factor out a new LevelProperties type so we can add new properties
317 /// without changing this function's signature
318 constexpr std::optional<DimLevelType>
319 buildLevelType(LevelFormat lf, bool ordered, bool unique) {
320 auto dlt = static_cast<DimLevelType>(static_cast<uint8_t>(lf) |
321 (ordered ? 0 : 2) | (unique ? 0 : 1));
322 return isValidDLT(dlt) ? std::optional(dlt) : std::nullopt;
326 // Ensure the above methods work as indended.
329 static_assert(
330 (getLevelFormat(DimLevelType::Undef) == std::nullopt &&
331 *getLevelFormat(DimLevelType::Dense) == LevelFormat::Dense &&
332 *getLevelFormat(DimLevelType::Compressed) == LevelFormat::Compressed &&
333 *getLevelFormat(DimLevelType::CompressedNu) == LevelFormat::Compressed &&
334 *getLevelFormat(DimLevelType::CompressedNo) == LevelFormat::Compressed &&
335 *getLevelFormat(DimLevelType::CompressedNuNo) == LevelFormat::Compressed &&
336 *getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
337 *getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
338 *getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
339 *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton &&
340 *getLevelFormat(DimLevelType::LooseCompressed) ==
341 LevelFormat::LooseCompressed &&
342 *getLevelFormat(DimLevelType::LooseCompressedNu) ==
343 LevelFormat::LooseCompressed &&
344 *getLevelFormat(DimLevelType::LooseCompressedNo) ==
345 LevelFormat::LooseCompressed &&
346 *getLevelFormat(DimLevelType::LooseCompressedNuNo) ==
347 LevelFormat::LooseCompressed &&
348 *getLevelFormat(DimLevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
349 "getLevelFormat conversion is broken");
351 static_assert(
352 (buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
353 buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
354 buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
355 *buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
356 *buildLevelType(LevelFormat::Compressed, true, true) ==
357 DimLevelType::Compressed &&
358 *buildLevelType(LevelFormat::Compressed, true, false) ==
359 DimLevelType::CompressedNu &&
360 *buildLevelType(LevelFormat::Compressed, false, true) ==
361 DimLevelType::CompressedNo &&
362 *buildLevelType(LevelFormat::Compressed, false, false) ==
363 DimLevelType::CompressedNuNo &&
364 *buildLevelType(LevelFormat::Singleton, true, true) ==
365 DimLevelType::Singleton &&
366 *buildLevelType(LevelFormat::Singleton, true, false) ==
367 DimLevelType::SingletonNu &&
368 *buildLevelType(LevelFormat::Singleton, false, true) ==
369 DimLevelType::SingletonNo &&
370 *buildLevelType(LevelFormat::Singleton, false, false) ==
371 DimLevelType::SingletonNuNo &&
372 *buildLevelType(LevelFormat::LooseCompressed, true, true) ==
373 DimLevelType::LooseCompressed &&
374 *buildLevelType(LevelFormat::LooseCompressed, true, false) ==
375 DimLevelType::LooseCompressedNu &&
376 *buildLevelType(LevelFormat::LooseCompressed, false, true) ==
377 DimLevelType::LooseCompressedNo &&
378 *buildLevelType(LevelFormat::LooseCompressed, false, false) ==
379 DimLevelType::LooseCompressedNuNo &&
380 buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
381 buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
382 buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
383 *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
384 DimLevelType::TwoOutOfFour),
385 "buildLevelType conversion is broken");
387 static_assert((isValidDLT(DimLevelType::Undef) &&
388 isValidDLT(DimLevelType::Dense) &&
389 isValidDLT(DimLevelType::Compressed) &&
390 isValidDLT(DimLevelType::CompressedNu) &&
391 isValidDLT(DimLevelType::CompressedNo) &&
392 isValidDLT(DimLevelType::CompressedNuNo) &&
393 isValidDLT(DimLevelType::Singleton) &&
394 isValidDLT(DimLevelType::SingletonNu) &&
395 isValidDLT(DimLevelType::SingletonNo) &&
396 isValidDLT(DimLevelType::SingletonNuNo) &&
397 isValidDLT(DimLevelType::LooseCompressed) &&
398 isValidDLT(DimLevelType::LooseCompressedNu) &&
399 isValidDLT(DimLevelType::LooseCompressedNo) &&
400 isValidDLT(DimLevelType::LooseCompressedNuNo) &&
401 isValidDLT(DimLevelType::TwoOutOfFour)),
402 "isValidDLT definition is broken");
404 static_assert((isDenseDLT(DimLevelType::Dense) &&
405 !isDenseDLT(DimLevelType::Compressed) &&
406 !isDenseDLT(DimLevelType::CompressedNu) &&
407 !isDenseDLT(DimLevelType::CompressedNo) &&
408 !isDenseDLT(DimLevelType::CompressedNuNo) &&
409 !isDenseDLT(DimLevelType::Singleton) &&
410 !isDenseDLT(DimLevelType::SingletonNu) &&
411 !isDenseDLT(DimLevelType::SingletonNo) &&
412 !isDenseDLT(DimLevelType::SingletonNuNo) &&
413 !isDenseDLT(DimLevelType::LooseCompressed) &&
414 !isDenseDLT(DimLevelType::LooseCompressedNu) &&
415 !isDenseDLT(DimLevelType::LooseCompressedNo) &&
416 !isDenseDLT(DimLevelType::LooseCompressedNuNo) &&
417 !isDenseDLT(DimLevelType::TwoOutOfFour)),
418 "isDenseDLT definition is broken");
420 static_assert((!isCompressedDLT(DimLevelType::Dense) &&
421 isCompressedDLT(DimLevelType::Compressed) &&
422 isCompressedDLT(DimLevelType::CompressedNu) &&
423 isCompressedDLT(DimLevelType::CompressedNo) &&
424 isCompressedDLT(DimLevelType::CompressedNuNo) &&
425 !isCompressedDLT(DimLevelType::Singleton) &&
426 !isCompressedDLT(DimLevelType::SingletonNu) &&
427 !isCompressedDLT(DimLevelType::SingletonNo) &&
428 !isCompressedDLT(DimLevelType::SingletonNuNo) &&
429 !isCompressedDLT(DimLevelType::LooseCompressed) &&
430 !isCompressedDLT(DimLevelType::LooseCompressedNu) &&
431 !isCompressedDLT(DimLevelType::LooseCompressedNo) &&
432 !isCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
433 !isCompressedDLT(DimLevelType::TwoOutOfFour)),
434 "isCompressedDLT definition is broken");
436 static_assert((!isSingletonDLT(DimLevelType::Dense) &&
437 !isSingletonDLT(DimLevelType::Compressed) &&
438 !isSingletonDLT(DimLevelType::CompressedNu) &&
439 !isSingletonDLT(DimLevelType::CompressedNo) &&
440 !isSingletonDLT(DimLevelType::CompressedNuNo) &&
441 isSingletonDLT(DimLevelType::Singleton) &&
442 isSingletonDLT(DimLevelType::SingletonNu) &&
443 isSingletonDLT(DimLevelType::SingletonNo) &&
444 isSingletonDLT(DimLevelType::SingletonNuNo) &&
445 !isSingletonDLT(DimLevelType::LooseCompressed) &&
446 !isSingletonDLT(DimLevelType::LooseCompressedNu) &&
447 !isSingletonDLT(DimLevelType::LooseCompressedNo) &&
448 !isSingletonDLT(DimLevelType::LooseCompressedNuNo) &&
449 !isSingletonDLT(DimLevelType::TwoOutOfFour)),
450 "isSingletonDLT definition is broken");
452 static_assert((!isLooseCompressedDLT(DimLevelType::Dense) &&
453 !isLooseCompressedDLT(DimLevelType::Compressed) &&
454 !isLooseCompressedDLT(DimLevelType::CompressedNu) &&
455 !isLooseCompressedDLT(DimLevelType::CompressedNo) &&
456 !isLooseCompressedDLT(DimLevelType::CompressedNuNo) &&
457 !isLooseCompressedDLT(DimLevelType::Singleton) &&
458 !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
459 !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
460 !isLooseCompressedDLT(DimLevelType::SingletonNuNo) &&
461 isLooseCompressedDLT(DimLevelType::LooseCompressed) &&
462 isLooseCompressedDLT(DimLevelType::LooseCompressedNu) &&
463 isLooseCompressedDLT(DimLevelType::LooseCompressedNo) &&
464 isLooseCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
465 !isLooseCompressedDLT(DimLevelType::TwoOutOfFour)),
466 "isLooseCompressedDLT definition is broken");
468 static_assert((!is2OutOf4DLT(DimLevelType::Dense) &&
469 !is2OutOf4DLT(DimLevelType::Compressed) &&
470 !is2OutOf4DLT(DimLevelType::CompressedNu) &&
471 !is2OutOf4DLT(DimLevelType::CompressedNo) &&
472 !is2OutOf4DLT(DimLevelType::CompressedNuNo) &&
473 !is2OutOf4DLT(DimLevelType::Singleton) &&
474 !is2OutOf4DLT(DimLevelType::SingletonNu) &&
475 !is2OutOf4DLT(DimLevelType::SingletonNo) &&
476 !is2OutOf4DLT(DimLevelType::SingletonNuNo) &&
477 !is2OutOf4DLT(DimLevelType::LooseCompressed) &&
478 !is2OutOf4DLT(DimLevelType::LooseCompressedNu) &&
479 !is2OutOf4DLT(DimLevelType::LooseCompressedNo) &&
480 !is2OutOf4DLT(DimLevelType::LooseCompressedNuNo) &&
481 is2OutOf4DLT(DimLevelType::TwoOutOfFour)),
482 "is2OutOf4DLT definition is broken");
484 static_assert((isOrderedDLT(DimLevelType::Dense) &&
485 isOrderedDLT(DimLevelType::Compressed) &&
486 isOrderedDLT(DimLevelType::CompressedNu) &&
487 !isOrderedDLT(DimLevelType::CompressedNo) &&
488 !isOrderedDLT(DimLevelType::CompressedNuNo) &&
489 isOrderedDLT(DimLevelType::Singleton) &&
490 isOrderedDLT(DimLevelType::SingletonNu) &&
491 !isOrderedDLT(DimLevelType::SingletonNo) &&
492 !isOrderedDLT(DimLevelType::SingletonNuNo) &&
493 isOrderedDLT(DimLevelType::LooseCompressed) &&
494 isOrderedDLT(DimLevelType::LooseCompressedNu) &&
495 !isOrderedDLT(DimLevelType::LooseCompressedNo) &&
496 !isOrderedDLT(DimLevelType::LooseCompressedNuNo) &&
497 isOrderedDLT(DimLevelType::TwoOutOfFour)),
498 "isOrderedDLT definition is broken");
500 static_assert((isUniqueDLT(DimLevelType::Dense) &&
501 isUniqueDLT(DimLevelType::Compressed) &&
502 !isUniqueDLT(DimLevelType::CompressedNu) &&
503 isUniqueDLT(DimLevelType::CompressedNo) &&
504 !isUniqueDLT(DimLevelType::CompressedNuNo) &&
505 isUniqueDLT(DimLevelType::Singleton) &&
506 !isUniqueDLT(DimLevelType::SingletonNu) &&
507 isUniqueDLT(DimLevelType::SingletonNo) &&
508 !isUniqueDLT(DimLevelType::SingletonNuNo) &&
509 isUniqueDLT(DimLevelType::LooseCompressed) &&
510 !isUniqueDLT(DimLevelType::LooseCompressedNu) &&
511 isUniqueDLT(DimLevelType::LooseCompressedNo) &&
512 !isUniqueDLT(DimLevelType::LooseCompressedNuNo) &&
513 isUniqueDLT(DimLevelType::TwoOutOfFour)),
514 "isUniqueDLT definition is broken");
/// Bit manipulations for affine encoding.
///
/// Note that because the indices in the mappings refer to dimensions
/// and levels (and *not* the sizes of these dimensions and levels), the
/// 64-bit encoding gives ample room for a compact encoding of affine
/// operations in the higher bits.  Pure permutations still allow for
/// 60-bit indices.  But non-permutations reserve 20-bits for the
/// potential three components (index i, constant, index ii).
///
/// The compact encoding is as follows:
///
///  0xffffffffffffffff
/// |0000      |                        60-bit idx| e.g. i
/// |0001 floor|           20-bit const|20-bit idx| e.g. i floor c
/// |0010 mod  |           20-bit const|20-bit idx| e.g. i mod c
/// |0011 mul  |20-bit idx|20-bit const|20-bit idx| e.g. i + c * ii
///
/// This encoding provides sufficient generality for currently supported
/// sparse tensor types.  To generalize this more, we will need to provide
/// a broader encoding scheme for affine functions.  Also, the library
/// encoding may be replaced with pure "direct-IR" code in the future.
///
/// Encodes a dimension expression into the compact 64-bit form.
///
/// - `cf != 0`: encodes "i floor cf" as tag 0b0001 in the top four bits,
///   `cf` in bits [20, 40) and `i` in bits [0, 20); `cm` must be zero.
/// - otherwise `cm != 0`: encodes "i mod cm" as tag 0b0010 with the same
///   field layout.
/// - otherwise: returns `i` unchanged (it must fit in 60 bits).
///
/// Range violations are caught by `assert` only.
constexpr uint64_t encodeDim(uint64_t i, uint64_t cf, uint64_t cm) {
  if (cf != 0) {
    assert(cf <= 0xfffff && cm == 0 && i <= 0xfffff);
    // The tag literal must be 64-bit wide: `long` is only 32 bits on LLP64
    // targets, where shifting it by 60 would be undefined behavior.
    return (0x01ULL << 60) | (cf << 20) | i;
  }
  if (cm != 0) {
    assert(cm <= 0xfffff && i <= 0xfffff);
    return (0x02ULL << 60) | (cm << 20) | i;
  }
  assert(i <= 0x0fffffffffffffffu);
  return i;
}
/// Encodes a level expression into the compact 64-bit form.
///
/// - `c != 0`: encodes "i + c * ii" as tag 0b0011 in the top four bits,
///   `ii` in bits [40, 60), `c` in bits [20, 40) and `i` in bits [0, 20).
/// - otherwise: returns `i` unchanged (it must fit in 60 bits); `ii` is
///   ignored in that case.
///
/// Range violations are caught by `assert` only.
constexpr uint64_t encodeLvl(uint64_t i, uint64_t c, uint64_t ii) {
  if (c != 0) {
    assert(c <= 0xfffff && ii <= 0xfffff && i <= 0xfffff);
    // The tag literal must be 64-bit wide: `long` is only 32 bits on LLP64
    // targets, where shifting it by 60 would be undefined behavior.
    return (0x03ULL << 60) | (c << 20) | (ii << 40) | i;
  }
  assert(i <= 0x0fffffffffffffffu);
  return i;
}
/// Returns true iff the top four bits of `v` hold the "floor" tag (0b0001).
constexpr bool isEncodedFloor(uint64_t v) { return (v >> 60) == 0x01; }
/// Returns true iff the top four bits of `v` hold the "mod" tag (0b0010).
constexpr bool isEncodedMod(uint64_t v) { return (v >> 60) == 0x02; }
/// Returns true iff the top four bits of `v` hold the "mul" tag (0b0011).
constexpr bool isEncodedMul(uint64_t v) { return (v >> 60) == 0x03; }
/// Extracts the low 20-bit index field (bits [0, 20)) of `v`.  Note that
/// this truncates a plain (untagged) index wider than 20 bits.
constexpr uint64_t decodeIndex(uint64_t v) { return v & 0xfffffu; }
/// Extracts the 20-bit constant field (bits [20, 40)) of `v`.
constexpr uint64_t decodeConst(uint64_t v) { return (v >> 20) & 0xfffffu; }
/// Extracts the 20-bit multiplier constant (bits [20, 40)) of `v`.
constexpr uint64_t decodeMulc(uint64_t v) { return (v >> 20) & 0xfffffu; }
/// Extracts the second 20-bit index field (bits [40, 60)) of `v`.
constexpr uint64_t decodeMuli(uint64_t v) { return (v >> 40) & 0xfffffu; }
566 } // namespace sparse_tensor
567 } // namespace mlir
569 #endif // MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H