//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;
//===----------------------------------------------------------------------===//
// Local convenience methods.
//===----------------------------------------------------------------------===//

// Only the zero ("index") and standard integer bitwidths are accepted.
static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}
//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;
void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            DimLevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  const Level cooStart = getCOOStart(enc);
  const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
  FieldIndex fieldIdx = kDataFieldStartingIdx;
  // Per-level storage fields.
  for (Level l = 0; l < end; l++) {
    const auto dlt = lvlTypes[l];
    if (isWithPosDLT(dlt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt)))
        return;
    }
    if (isWithCrdDLT(dlt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt)))
        return;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 DimLevelType::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 DimLevelType::Undef)))
    return;
}
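// Illustrative example (not part of the callback contract above): for a CSR
// matrix, i.e. lvlTypes = [dense, compressed] with no trailing COO region,
// the callbacks are issued in the order
//   (0, PosMemRef,   lvl=1, compressed)   -- positions of level 1
//   (1, CrdMemRef,   lvl=1, compressed)   -- coordinates of level 1
//   (2, ValMemRef,   kInvalidLevel)       -- the values array
//   (3, StorageSpec, kInvalidLevel)       -- the storage specifier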
void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            DimLevelType)>
        callback) {
  assert(stt.hasEncoding());
  // Construct the basic types.
  const Type crdType = stt.getCrdType();
  const Type posType = stt.getPosType();
  const Type eltType = stt.getElementType();

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<? x pos>  positions
  const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
  // memref<? x crd>  coordinates
  const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
  // memref<? x eltType> values
  const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);

  StorageLayout(stt).foreachField(
      [specType, posMemType, crdMemType, valMemType,
       callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind,
                 Level lvl, DimLevelType dlt) -> bool {
        switch (fieldKind) {
        case SparseTensorFieldKind::StorageSpec:
          return callback(specType, fieldIdx, fieldKind, lvl, dlt);
        case SparseTensorFieldKind::PosMemRef:
          return callback(posMemType, fieldIdx, fieldKind, lvl, dlt);
        case SparseTensorFieldKind::CrdMemRef:
          return callback(crdMemType, fieldIdx, fieldKind, lvl, dlt);
        case SparseTensorFieldKind::ValMemRef:
          return callback(valMemType, fieldIdx, fieldKind, lvl, dlt);
        }
        llvm_unreachable("unrecognized field kind");
      });
}
unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            DimLevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            DimLevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}
std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = getCOOStart(enc);
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      DimLevelType dlt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Returns false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}
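// Illustrative example: for lvlTypes = [compressed(nonunique), singleton],
// i.e. a two-level COO region starting at level 0, both coordinate levels
// share a single AoS coordinate memref, so a CrdMemRef query for level 0 or
// level 1 yields the same field index with stride lvlRank - cooStart = 2.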
//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//
std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}
void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}
static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, the token is '?', which represents a dynamic slice value.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}
Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}
LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
                                       AffineMap(), getPosWidth(),
                                       getCrdWidth(), getDimSlices());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getLvlToDim(), posWidth,
                                       crdWidth, getDimSlices());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getLvlToDim(),
                                       getPosWidth(), getCrdWidth(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}
bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
}

bool SparseTensorEncodingAttr::isCOO() const {
  return getImpl() && isCOOType(*this, 0, true);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedDLT);
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

DimLevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return DimLevelType::Dense;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}
bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}

SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  // FIXME: `toOrigDim` is deprecated.
  return getStaticDimSliceOffset(toOrigDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  // FIXME: `toOrigDim` is deprecated.
  return getStaticDimSliceStride(toOrigDim(*this, lvl));
}
SmallVector<int64_t>
SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
                                        CrdTransDirectionKind dir) const {
  if (isIdentity())
    return SmallVector<int64_t>(srcShape);

  SmallVector<int64_t> ret;
  unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
  ret.reserve(rank);

  if (isPermutation()) {
    for (unsigned r = 0; r < rank; r++) {
      // FIXME: `toOrigDim` and `toStoredDim` are deprecated.
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl
                           ? toOrigDim(*this, r)
                           : toStoredDim(*this, r);
      ret.push_back(srcShape[trans]);
    }
    return ret;
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  SmallVector<AffineExpr> dimRep;
  dimRep.reserve(srcShape.size());
  for (int64_t sz : srcShape) {
    if (!ShapedType::isDynamic(sz)) {
      // Push back the max coordinate for the given dimension/level size.
      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
    } else {
      // A dynamic size; use an AffineDimExpr to symbolize the value.
      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
    }
  }

  for (AffineExpr exp : transMap.getResults()) {
    // Do constant propagation on the affine map.
    AffineExpr evalExp =
        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
    // Use llvm namespace here to avoid ambiguity.
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
      ret.push_back(c.getValue() + 1);
    } else {
      if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
          mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions in form
        // "d % constant" since d % constant \in [0, constant).
        if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
          ret.push_back(bound.getValue());
          continue;
        }
      }
      ret.push_back(ShapedType::kDynamic);
    }
  }

  assert(ret.size() == rank);
  return ret;
}
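// Illustrative example: with dimToLvl
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// a dim2lvl translation of the static shape [6, 9] yields the level shape
// [3, 3, 2, 3]; a dynamic input size propagates as kDynamic unless a
// "d % constant" bound can be inferred.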
ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  if (!getImpl())
    return crds;

  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<DimLevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  StringRef attrName;
  SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after keys.
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    }
    // Only the last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      dimSlices);
}
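// Illustrative example of the textual form handled above (CSR with explicit
// bitwidths):
//
//   #CSR = #sparse_tensor.encoding<{
//     map = (d0, d1) -> (d0 : dense, d1 : compressed),
//     posWidth = 64,
//     crdWidth = 32
//   }>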
void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  auto map = static_cast<AffineMap>(getDimToLvl());
  // An empty affine map indicates the identity map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  printer << " }>";
}
void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {
  if (map.getNumSymbols() == 0)
    return;
  printer << '[';
  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
    printer << 's' << i << ", ";
  if (map.getNumSymbols() >= 1)
    printer << 's' << map.getNumSymbols() - 1;
  printer << ']';
}
void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    if (map.getNumDims() >= 1) {
      printer << 'd' << map.getNumDims() - 1 << " : "
              << dimSlices[map.getNumDims() - 1];
    }
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    if (map.getNumDims() >= 1)
      printer << 'd' << map.getNumDims() - 1;
  }
}
void SparseTensorEncodingAttr::printLevels(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<DimLevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
    map.getResult(i).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
  }
  if (map.getNumResults() >= 1) {
    auto lastIndex = map.getNumResults() - 1;
    map.getResult(lastIndex).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
  }
}
LogicalResult
SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<DimLevelType> lvlTypes,
                                 AffineMap dimToLvl, AffineMap lvlToDim,
                                 unsigned posWidth, unsigned crdWidth,
                                 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonDLT);
      it != std::end(lvlTypes)) {
    if (it == lvlTypes.begin() ||
        (!isCompressedDLT(*(it - 1)) && !isLooseCompressedDLT(*(it - 1))))
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";
    if (!std::all_of(it, lvlTypes.end(),
                     [](DimLevelType i) { return isSingletonDLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it.  The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree.  (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity.  In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics.  In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  return success();
}
//===----------------------------------------------------------------------===//
// Convenience methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}
AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                             MLIRContext *context) {
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}
AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores information of the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Multiplier (the block size).
        components.push_back(binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add level variable for mod to the same vector
        // of the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. It could be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}
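// Illustrative example: for the 2x3 block-sparse map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// the components collected per dimension are [l0, 2, l2] and [l1, 3, l3],
// which reconstruct the inverse map
//   (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3).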
SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      if (result.getKind() == AffineExprKind::Mod) {
        blockSize.push_back(
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
      }
    } else {
      blockSize.push_back(0);
    }
  }
  return blockSize;
}
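// Illustrative example: for the 2x3 block-sparse map above, getBlockSize
// returns [2, 3]; a plain dimension expression (no floordiv/mod) contributes
// a 0 to mark an unblocked level.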
bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
  if (!dimToLvl)
    return false;
  std::map<unsigned, int64_t> coefficientMap;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
      auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        if (coefficientMap.find(pos) != coefficientMap.end())
          return false;
        coefficientMap[pos] =
            dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue();
      } else if (result.getKind() == AffineExprKind::Mod) {
        // Expect floordiv before mod.
        if (coefficientMap.find(pos) == coefficientMap.end())
          return false;
        // Expect mod to have the same coefficient as floordiv.
        if (dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue() !=
            coefficientMap[pos])
          return false;
      } else {
        return false;
      }
    }
  }
  return !coefficientMap.empty();
}
bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                    Level startLvl, bool isUnique) {
  if (!enc ||
      !(enc.isCompressedLvl(startLvl) || enc.isLooseCompressedLvl(startLvl)))
    return false;
  const Level lvlRank = enc.getLvlRank();
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!enc.isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique,
  // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1
  // (unique on the last singleton).
  return !isUnique || enc.isUniqueLvl(lvlRank - 1);
}

bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
  return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
}
bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
  auto hasNonIdentityMap = [](Value v) {
    auto stt = tryGetSparseTensorType(v);
    return stt && !stt->isIdentity();
  };

  return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
         llvm::any_of(op->getResults(), hasNonIdentityMap);
}
Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
  // We only consider a COO region with at least two levels for the purpose
  // of AOS storage optimization.
  const Level lvlRank = enc.getLvlRank();
  if (lvlRank > 1)
    for (Level l = 0; l < lvlRank - 1; l++)
      if (isCOOType(enc, l, /*isUnique=*/false))
        return l;
  return lvlRank;
}
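// Illustrative example: for lvlTypes = [dense, compressed(nonunique),
// singleton], getCOOStart returns 1; for a purely dense or CSR encoding it
// returns lvlRank, i.e. "no COO region".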
// Helpers to setup a COO type.
RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
                                                           AffineMap lvlPerm,
                                                           bool ordered) {
  const SparseTensorType src(rtt);
  const Level lvlRank = src.getLvlRank();
  SmallVector<DimLevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);

  // An unordered and non-unique compressed level at beginning.
  // If this is also the last level, then it is unique.
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // TODO: it is actually ordered at the level for ordered input.
    // Followed by unordered non-unique n-2 singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends by a unique singleton level unless the lvlRank is 1.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }

  // TODO: Maybe pick the bitwidth based on input/output tensors (probably the
  // largest one among them) in the original operation instead of using the
  // default value.
  unsigned posWidth = src.getPosWidth();
  unsigned crdWidth = src.getCrdWidth();
  AffineMap invPerm = src.getLvlToDim();
  auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
                                           invPerm, posWidth, crdWidth);
  return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
}

RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
                                               bool ordered) {
  return getCOOFromTypeWithOrdering(
      src, AffineMap::getMultiDimIdentityMap(src.getRank(), src.getContext()),
      ordered);
}
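// Illustrative example: for a rank-3 source tensor, the COO encoding built
// above uses the level types
//   [compressed(nonunique), singleton(nonunique), singleton]
// with the ordering property of every level taken from the `ordered` flag.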
// TODO: Remove this definition once all use-sites have been fixed to
// properly handle non-permutations.
Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
                                         Level l) {
  if (enc) {
    if (const auto dimToLvl = enc.getDimToLvl()) {
      assert(enc.isPermutation());
      return dimToLvl.getDimPosition(l);
    }
  }
  return l;
}

// TODO: Remove this definition once all use-sites have been fixed to
// properly handle non-permutations.
Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
                                       Dimension d) {
  if (enc) {
    if (const auto dimToLvl = enc.getDimToLvl()) {
      assert(enc.isPermutation());
      const auto maybePos =
          dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
      assert(maybePos.has_value());
      return *maybePos;
    }
  }
  return d;
}
/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique DLT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by stripping
/// irrelevant fields that do not alter the sparse tensor memory layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<DimLevelType> dlts;
  for (auto dlt : enc.getLvlTypes())
    dlts.push_back(*buildLevelType(*getLevelFormat(dlt), true, true));

  return SparseTensorEncodingAttr::get(
      enc.getContext(), dlts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
      // value for different bitwidths; it also avoids casting between index
      // and integer (returned by DimOp).
      0, 0, enc.getDimSlices());
}

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}
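// Illustrative example: encodings that differ only in level properties or
// bitwidths, e.g. CSR with crdWidth = 32 versus CSR with
// compressed(nonunique, nonordered) levels, normalize to the same
// [dense, compressed] encoding here and therefore share one specifier type.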
//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}
static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}
static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");
  if (!stt.isIdentity())
    return op->emitError("the sparse-tensor must have the identity mapping");

  // Verifies the trailing COO.
  Level cooStartLvl = getCOOStart(stt.getEncoding());
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now; it must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in shape of <? x rank>.
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      op->emitError("input/output trailing COO level-ranks don't match");
      return failure();
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, DimLevelType dlt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == dlt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}
LogicalResult AssembleOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}
LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}
LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}
OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}
bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from dense
  // inputs by rotating dense loops, but it leads to bad cache locality and
  // hurts performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

  return true;
}
LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");

  return success();
}
LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
    return v.getDefiningOp() == def;
  });
  if (!sameDef)
    return failure();

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });
  if (!sameOrder)
    return failure();
  // The composed translation is the identity, e.g.,
  //   l1 = dim2lvl (lvl2dim l0)  ==>  l1 == l0.
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}
void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {
  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
  return build(builder, state, source, val);
}
LogicalResult LvlOp::verify() {
  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      emitError("Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
  return getConstantIntValue(getIndex());
}
Speculation::Speculatability LvlOp::getSpeculatability() {
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
  return Speculation::Speculatable;
}
OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by tensor.dim operation. Out of bound
    // indices produce undefined behavior but are still valid IR. Don't choke on
    // them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
  auto getIndexAttr = [this](int64_t lvlSz) {
    return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
  };

  SmallVector<Size> lvlShape = stt.getLvlShape();
  if (!ShapedType::isDynamic(lvlShape[lvl]))
    return getIndexAttr(lvlShape[lvl]);

  return {};
}
void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                             SparseTensorEncodingAttr dstEnc, Value source) {
  auto srcStt = getSparseTensorType(source);
  SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
  SmallVector<int64_t> dstDimShape =
      dstEnc.tranlateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
  auto dstTp =
      RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
  return build(odsBuilder, odsState, dstTp, source);
}
LogicalResult ReinterpretMapOp::verify() {
  auto srcStt = getSparseTensorType(getSource());
  auto dstStt = getSparseTensorType(getDest());
  ArrayRef<DimLevelType> srcLvlTps = srcStt.getLvlTypes();
  ArrayRef<DimLevelType> dstLvlTps = dstStt.getLvlTypes();

  if (srcLvlTps.size() != dstLvlTps.size())
    return emitError("Level rank mismatch between source/dest tensors");

  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
    if (srcLvlTp != dstLvlTp)
      return emitError("Level type mismatch between source/dest tensors");

  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
      srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
    return emitError("Crd/Pos width mismatch between source/dest tensors");
  }

  if (srcStt.getElementType() != dstStt.getElementType())
    return emitError("Element type mismatch between source/dest tensors");

  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
    if (srcLvlSz != dstLvlSz) {
      // Should we allow one side to be dynamic size, e.g., <?x?> should be
      // compatible to <3x4>? For now, we require all the level sizes to be
      // *exactly* matched for simplicity.
      return emitError("Level size mismatch between source/dest tensors");
    }
  }

  return success();
}
OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
  if (getSource().getType() == getDest().getType())
    return getSource();

  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
    // A -> B, B -> A ==> A
    if (def.getSource().getType() == getDest().getType())
      return def.getSource();
  }
  return {};
}
LogicalResult ToPositionsOp::verify() {
  auto e = getSparseTensorEncoding(getTensor().getType());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), e.getPosWidth())))
    return emitError("unexpected type for positions");
  return success();
}

LogicalResult ToCoordinatesOp::verify() {
  auto e = getSparseTensorEncoding(getTensor().getType());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), e.getCrdWidth())))
    return emitError("unexpected type for coordinates");
  return success();
}

LogicalResult ToCoordinatesBufferOp::verify() {
  auto e = getSparseTensorEncoding(getTensor().getType());
  if (getCOOStart(e) >= e.getLvlRank())
    return emitError("expected sparse tensor with a COO region");
  return success();
}

LogicalResult ToValuesOp::verify() {
  auto ttp = getRankedTensorType(getTensor());
  auto mtp = getMemRefType(getResult());
  if (ttp.getElementType() != mtp.getElementType())
    return emitError("unexpected mismatch in element types");
  return success();
}
LogicalResult ToSliceOffsetOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");
  return success();
}

LogicalResult ToSliceStrideOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");
  return success();
}

LogicalResult GetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}
template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
}

OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
  const StorageSpecifierKind kind = getSpecifierKind();
  const auto lvl = getLevel();
  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
    if (kind == op.getSpecifierKind() && lvl == op.getLevel())
      return op.getValue();
  return {};
}

LogicalResult SetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}
template <class T>
static LogicalResult verifyNumBlockArgs(T *op, Region &region,
                                        const char *regionName,
                                        TypeRange inputTypes, Type outputType) {
  unsigned numArgs = region.getNumArguments();
  unsigned expectedNum = inputTypes.size();
  if (numArgs != expectedNum)
    return op->emitError() << regionName << " region must have exactly "
                           << expectedNum << " arguments";

  for (unsigned i = 0; i < numArgs; i++) {
    Type typ = region.getArgument(i).getType();
    if (typ != inputTypes[i])
      return op->emitError() << regionName << " region argument " << (i + 1)
                             << " type mismatch";
  }
  Operation *term = region.front().getTerminator();
  YieldOp yield = dyn_cast<YieldOp>(term);
  if (!yield)
    return op->emitError() << regionName
                           << " region must end with sparse_tensor.yield";
  if (!yield.getResult() || yield.getResult().getType() != outputType)
    return op->emitError() << regionName << " region yield type mismatch";

  return success();
}
LogicalResult BinaryOp::verify() {
  NamedAttrList attrs = (*this)->getAttrs();
  Type leftType = getX().getType();
  Type rightType = getY().getType();
  Type outputType = getOutput().getType();
  Region &overlap = getOverlapRegion();
  Region &left = getLeftRegion();
  Region &right = getRightRegion();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  if (!overlap.empty()) {
    if (failed(verifyNumBlockArgs(this, overlap, "overlap",
                                  TypeRange{leftType, rightType}, outputType)))
      return failure();
  }
  if (!left.empty()) {
    if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
                                  outputType)))
      return failure();
  } else if (getLeftIdentity()) {
    if (leftType != outputType)
      return emitError("left=identity requires first argument to have the same "
                       "type as the output");
  }
  if (!right.empty()) {
    if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
                                  outputType)))
      return failure();
  } else if (getRightIdentity()) {
    if (rightType != outputType)
      return emitError("right=identity requires second argument to have the "
                       "same type as the output");
  }
  return success();
}
LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(verifyNumBlockArgs(this, present, "present",
                                  TypeRange{inputType}, outputType)))
      return failure();
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
                                  outputType)))
      return failure();
    // Absent branch can only yield invariant values.
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
  return success();
}
bool ConcatenateOp::needsExtraSort() {
  SparseTensorType dstStt = getSparseTensorType(*this);
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
    return getSparseTensorType(op).hasSameDimToLvl(dstStt);
  });
  // TODO: When conDim != 0, as long as conDim corresponds to the first level
  // in all input/output buffers, and all input/output buffers have the same
  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenate
  // CSC matrices along column).
  bool directLowerable =
      allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
  return !directLowerable;
}
LogicalResult ConcatenateOp::verify() {
  const auto dstTp = getSparseTensorType(*this);
  const Dimension concatDim = getDimension();
  const Dimension dimRank = dstTp.getDimRank();

  if (getInputs().size() <= 1)
    return emitError("Need at least two tensors to concatenate.");

  if (concatDim >= dimRank)
    return emitError(llvm::formatv(
        "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
        concatDim, dimRank));

  for (const auto &it : llvm::enumerate(getInputs())) {
    const auto i = it.index();
    const auto srcTp = getSparseTensorType(it.value());
    if (srcTp.hasDynamicDimShape())
      return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
    const Dimension srcDimRank = srcTp.getDimRank();
    if (srcDimRank != dimRank)
      return emitError(
          llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
                        "from the output tensor (rank={2}).",
                        i, srcDimRank, dimRank));
  }

  for (Dimension d = 0; d < dimRank; d++) {
    const Size dstSh = dstTp.getDimShape()[d];
    if (d == concatDim) {
      if (!ShapedType::isDynamic(dstSh)) {
        // If we reach here, then all inputs have static shapes.  So we
        // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
        // to avoid redundant assertions in the loop.
        Size sumSz = 0;
        for (const auto src : getInputs())
          sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimensions are statically known, the sum of all the input
        // dimensions should be equal to the output dimension.
        if (sumSz != dstSh)
          return emitError(
              "The concatenation dimension of the output tensor should be the "
              "sum of all the concatenation dimensions of the input tensors.");
      }
    } else {
      Size prev = dstSh;
      for (const auto src : getInputs()) {
        const auto sh = getSparseTensorType(src).getDimShape()[d];
        if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (except for the concatenating one) "
                           "should be equal.");
        prev = sh;
      }
    }
  }

  return success();
}
LogicalResult InsertOp::verify() {
  const auto stt = getSparseTensorType(getTensor());
  if (stt.getLvlRank() != static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
  return success();
}

void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  build(builder, result, curSize, inBuffer, value, Value());
}

LogicalResult PushBackOp::verify() {
  if (Value n = getN()) {
    std::optional<int64_t> nValue = getConstantIntValue(n);
    if (nValue && nValue.value() < 1)
      return emitOpError("n must be not less than 1");
  }
  return success();
}

LogicalResult CompressOp::verify() {
  const auto stt = getSparseTensorType(getTensor());
  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
  return success();
}
void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}
LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
    return emitError("Level traverse order does not match tensor's level rank");

  if (dimRank + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    emitError(llvm::formatv("Unmatched element type between input tensor and "
                            "block argument, expected:{0}, got: {1}",
                            elemTp, valueTp));
  return success();
}
OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
  if (getSparseTensorEncoding(getInputCoo().getType()) ==
      getSparseTensorEncoding(getResultCoo().getType()))
    return getInputCoo();

  return {};
}

LogicalResult ReorderCOOOp::verify() {
  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
  SparseTensorType dstStt = getSparseTensorType(getResultCoo());

  if (!srcStt.hasSameDimToLvl(dstStt))
    emitError("Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType()) {
    emitError("Unmatched storage format between input and result COO");
  }
  return success();
}
LogicalResult ReduceOp::verify() {
  Type inputType = getX().getType();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "reduce",
                            TypeRange{inputType, inputType}, inputType);
}

LogicalResult SelectOp::verify() {
  Builder b(getContext());
  Type inputType = getX().getType();
  Type boolType = b.getI1Type();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
                            boolType);
}
LogicalResult SortOp::verify() {
  AffineMap xPerm = getPermMap();
  uint64_t nx = xPerm.getNumDims();
  if (nx < 1)
    emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));

  if (!xPerm.isPermutation())
    emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));

  // We can't check the size of the buffers when n or buffer dimensions aren't
  // compile-time constants.
  std::optional<int64_t> cn = getConstantIntValue(getN());
  if (!cn)
    return success();

  // Verify dimensions.
  const auto checkDim = [&](Value v, Size minSize, const char *message) {
    const Size sh = getMemRefType(v).getShape()[0];
    if (!ShapedType::isDynamic(sh) && sh < minSize)
      emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
  };
  uint64_t n = cn.value();
  uint64_t ny = 0;
  if (auto nyAttr = getNyAttr())
    ny = nyAttr.getInt();
  checkDim(getXy(), n * (nx + ny),
           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
  for (Value opnd : getYs())
    checkDim(opnd, n, "Expected dimension(y) >= n");

  return success();
}
LogicalResult YieldOp::verify() {
  // Check for compatible parent.
  auto *parentOp = (*this)->getParentOp();
  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
      isa<ForeachOp>(parentOp))
    return success();

  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
                     "reduce, select or foreach");
}
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
                                                    Attribute value, Type type,
                                                    Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  return nullptr;
}
namespace {
struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
  using OpAsmDialectInterface::OpAsmDialectInterface;

  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
    if (attr.isa<SparseTensorEncodingAttr>()) {
      os << "sparse";
      return AliasResult::OverridableAlias;
    }
    return AliasResult::NoAlias;
  }
};
} // namespace
void SparseTensorDialect::initialize() {
  addInterface<SparseTensorAsmDialectInterface>();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"