//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Detail/DimLvlMapParser.h"

#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"

using namespace mlir;
using namespace mlir::sparse_tensor;
//===----------------------------------------------------------------------===//
// Local convenience methods.
//===----------------------------------------------------------------------===//

static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}
//===----------------------------------------------------------------------===//
// SparseTensorDialect StorageLayout.
//===----------------------------------------------------------------------===//

static constexpr Level kInvalidLevel = -1u;
static constexpr Level kInvalidFieldIndex = -1u;
static constexpr FieldIndex kDataFieldStartingIdx = 0;
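// Walks over every storage field of the encoding, in order: the per-level
// position/coordinate memrefs (with a trailing COO region collapsed into a
// single AOS coordinate memref), then the values memref, and finally the
// storage specifier. Iteration stops early once the callback returns false.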
void StorageLayout::foreachField(
    llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
                            DimLevelType)>
        callback) const {
  const auto lvlTypes = enc.getLvlTypes();
  const Level lvlRank = enc.getLvlRank();
  const Level cooStart = getCOOStart(enc);
  const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
  FieldIndex fieldIdx = kDataFieldStartingIdx;
  // Per-level storage.
  for (Level l = 0; l < end; l++) {
    const auto dlt = lvlTypes[l];
    if (isDLTWithPos(dlt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt)))
        return;
    }
    if (isDLTWithCrd(dlt)) {
      if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt)))
        return;
    }
  }
  // The values array.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
                 DimLevelType::Undef)))
    return;
  // Put metadata at the end.
  if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
                 DimLevelType::Undef)))
    return;
}
void sparse_tensor::foreachFieldAndTypeInSparseTensor(
    SparseTensorType stt,
    llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
                            DimLevelType)>
        callback) {
  assert(stt.hasEncoding());
  // Construct the basic types.
  const Type crdType = stt.getCrdType();
  const Type posType = stt.getPosType();
  const Type eltType = stt.getElementType();

  const Type specType = StorageSpecifierType::get(stt.getEncoding());
  // memref<? x pos>  positions
  const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
  // memref<? x crd>  coordinates
  const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
  // memref<? x eltType> values
  const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);

  StorageLayout(stt).foreachField(
      [specType, posMemType, crdMemType, valMemType,
       callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind,
                 Level lvl, DimLevelType dlt) -> bool {
        switch (fieldKind) {
        case SparseTensorFieldKind::StorageSpec:
          return callback(specType, fieldIdx, fieldKind, lvl, dlt);
        case SparseTensorFieldKind::PosMemRef:
          return callback(posMemType, fieldIdx, fieldKind, lvl, dlt);
        case SparseTensorFieldKind::CrdMemRef:
          return callback(crdMemType, fieldIdx, fieldKind, lvl, dlt);
        case SparseTensorFieldKind::ValMemRef:
          return callback(valMemType, fieldIdx, fieldKind, lvl, dlt);
        }
        llvm_unreachable("unrecognized field kind");
      });
}
unsigned StorageLayout::getNumFields() const {
  unsigned numFields = 0;
  foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
                            DimLevelType) -> bool {
    numFields++;
    return true;
  });
  return numFields;
}

unsigned StorageLayout::getNumDataFields() const {
  unsigned numFields = 0; // one value memref
  foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
                            DimLevelType) -> bool {
    if (fidx >= kDataFieldStartingIdx)
      numFields++;
    return true;
  });
  numFields -= 1; // the last field is StorageSpecifier
  assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
  return numFields;
}
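// Returns the storage field index for the given field kind (at the given
// level), together with the access stride. The stride is greater than one
// only for coordinate memrefs inside a trailing COO region, where all COO
// coordinates share a single AOS buffer.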
std::pair<FieldIndex, unsigned>
StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
                                      std::optional<Level> lvl) const {
  FieldIndex fieldIdx = kInvalidFieldIndex;
  unsigned stride = 1;
  if (kind == SparseTensorFieldKind::CrdMemRef) {
    assert(lvl.has_value());
    const Level cooStart = getCOOStart(enc);
    const Level lvlRank = enc.getLvlRank();
    if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
      lvl = cooStart;
      stride = lvlRank - cooStart;
    }
  }
  foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
                                      SparseTensorFieldKind fKind, Level fLvl,
                                      DimLevelType dlt) -> bool {
    if ((lvl && fLvl == lvl.value() && kind == fKind) ||
        (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
      fieldIdx = fIdx;
      // Returns false to break the iteration.
      return false;
    }
    return true;
  });
  assert(fieldIdx != kInvalidFieldIndex);
  return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
}
//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
  return isDynamic(v) ? std::nullopt
                      : std::make_optional(static_cast<uint64_t>(v));
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
  return getStatic(getOffset());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
  return getStatic(getStride());
}

std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
  return getStatic(getSize());
}

bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
  return isDynamic(getOffset()) && isDynamic(getStride()) &&
         isDynamic(getSize());
}

std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
  return isDynamic(v) ? "?" : std::to_string(v);
}

void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
  assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
  os << '(';
  os << getStaticString(getOffset());
  os << ", ";
  os << getStaticString(getSize());
  os << ", ";
  os << getStaticString(getStride());
  os << ')';
}

void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
  print(printer.getStream());
}
static ParseResult parseOptionalStaticSlice(int64_t &result,
                                            AsmParser &parser) {
  auto parseResult = parser.parseOptionalInteger(result);
  if (parseResult.has_value()) {
    if (parseResult.value().succeeded() && result < 0) {
      parser.emitError(
          parser.getCurrentLocation(),
          "expect positive value or ? for slice offset/size/stride");
      return failure();
    }
    return parseResult.value();
  }

  // Else, a '?' which represents a dynamic slice.
  result = SparseTensorDimSliceAttr::kDynamic;
  return parser.parseQuestion();
}
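// Parses the slice attribute syntax `(offset, size, stride)`, where each
// entry is either a static integer or `?` for a dynamic value.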
Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
  int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;

  if (failed(parser.parseLParen()) ||
      failed(parseOptionalStaticSlice(offset, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(size, parser)) ||
      failed(parser.parseComma()) ||
      failed(parseOptionalStaticSlice(stride, parser)) ||
      failed(parser.parseRParen()))
    return {};

  return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
                                                     offset, size, stride);
}
LogicalResult
SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 int64_t offset, int64_t size, int64_t stride) {
  if (!isDynamic(offset) && offset < 0)
    return emitError() << "expect non-negative value or ? for slice offset";
  if (!isDynamic(size) && size <= 0)
    return emitError() << "expect positive value or ? for slice size";
  if (!isDynamic(stride) && stride <= 0)
    return emitError() << "expect positive value or ? for slice stride";
  return success();
}
SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
                                       AffineMap(), getPosWidth(),
                                       getCrdWidth(), getDimSlices());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
  return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
  return withDimToLvl(AffineMap());
}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
                                        unsigned crdWidth) const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getLvlToDim(), posWidth,
                                       crdWidth, getDimSlices());
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
  return withBitWidths(0, 0);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
                                       getDimToLvl(), getLvlToDim(),
                                       getPosWidth(), getCrdWidth(), dimSlices);
}

SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
  return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
}
bool SparseTensorEncodingAttr::isAllDense() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
}

bool SparseTensorEncodingAttr::isCOO() const {
  return getImpl() && isCOOType(*this, 0, true);
}

bool SparseTensorEncodingAttr::isAllOrdered() const {
  return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedDLT);
}

bool SparseTensorEncodingAttr::isIdentity() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
}

bool SparseTensorEncodingAttr::isPermutation() const {
  return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
}

Dimension SparseTensorEncodingAttr::getDimRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  const auto dimToLvl = getDimToLvl();
  return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
}

Level SparseTensorEncodingAttr::getLvlRank() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return getLvlTypes().size();
}

DimLevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
  if (!getImpl())
    return DimLevelType::Dense;
  assert(l < getLvlRank() && "Level is out of bounds");
  return getLvlTypes()[l];
}

bool SparseTensorEncodingAttr::isSlice() const {
  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
  return !getDimSlices().empty();
}
SparseTensorDimSliceAttr
SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
  assert(isSlice() && "Is not a slice");
  const auto dimSlices = getDimSlices();
  assert(dim < dimSlices.size() && "Dimension is out of bounds");
  return dimSlices[dim];
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
  return getDimSlice(dim).getStaticOffset();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceSize(Dimension dim) const {
  return getDimSlice(dim).getStaticSize();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
  return getDimSlice(dim).getStaticStride();
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
  // FIXME: `toOrigDim` is deprecated.
  return getStaticDimSliceOffset(toOrigDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceSize(Level lvl) const {
  // FIXME: `toOrigDim` is deprecated.
  return getStaticDimSliceSize(toOrigDim(*this, lvl));
}

std::optional<uint64_t>
SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
  // FIXME: `toOrigDim` is deprecated.
  return getStaticDimSliceStride(toOrigDim(*this, lvl));
}
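// Translates a shape between the dimension space and the level space
// according to this encoding (e.g., for a 2x3 block-sparse map, a 6x9
// dimension shape becomes the 3x3x2x3 level shape, and vice versa).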
SmallVector<int64_t>
SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
                                        CrdTransDirectionKind dir) const {
  if (isIdentity())
    return SmallVector<int64_t>(srcShape);

  SmallVector<int64_t> ret;
  unsigned rank =
      dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
  ret.reserve(rank);

  if (isPermutation()) {
    for (unsigned r = 0; r < rank; r++) {
      unsigned trans = dir == CrdTransDirectionKind::dim2lvl
                           ? toOrigDim(*this, r)
                           : toStoredDim(*this, r);
      ret.push_back(srcShape[trans]);
    }
    return ret;
  }

  // Handle non-permutation maps.
  AffineMap transMap =
      dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();

  SmallVector<AffineExpr> dimRep;
  dimRep.reserve(srcShape.size());
  for (int64_t sz : srcShape) {
    if (!ShapedType::isDynamic(sz)) {
      // Push back the max coordinate for the given dimension/level size.
      dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
    } else {
      // A dynamic size, use an AffineDimExpr to symbolize the value.
      dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
    }
  }

  for (AffineExpr exp : transMap.getResults()) {
    // Do constant propagation on the affine map.
    AffineExpr evalExp =
        simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
    if (auto c = evalExp.dyn_cast<AffineConstantExpr>()) {
      ret.push_back(c.getValue() + 1);
    } else {
      if (auto mod = evalExp.dyn_cast<AffineBinaryOpExpr>();
          mod && mod.getKind() == AffineExprKind::Mod) {
        // We can still infer a static bound for expressions in form
        // "d % constant" since d % constant \in [0, constant).
        if (auto bound = mod.getRHS().dyn_cast<AffineConstantExpr>()) {
          ret.push_back(bound.getValue());
          continue;
        }
      }
      ret.push_back(ShapedType::kDynamic);
    }
  }
  assert(ret.size() == rank);
  return ret;
}
ValueRange
SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
                                        ValueRange crds,
                                        CrdTransDirectionKind dir) const {
  if (!getImpl())
    return crds;

  SmallVector<Type> retType(
      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
      builder.getIndexType());
  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
  return transOp.getOutCrds();
}
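// Parses the encoding syntax `<{ map = ..., posWidth = ..., crdWidth = ... }>`,
// where only the recognized keys listed below are accepted and `map` carries
// the dimension-to-level translation.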
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  // Open "<{" part.
  if (failed(parser.parseLess()))
    return {};
  if (failed(parser.parseLBrace()))
    return {};

  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<DimLevelType> lvlTypes;
  SmallVector<SparseTensorDimSliceAttr> dimSlices;
  AffineMap dimToLvl = {};
  AffineMap lvlToDim = {};
  unsigned posWidth = 0;
  unsigned crdWidth = 0;
  StringRef attrName;
  SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"};
  while (succeeded(parser.parseOptionalKeyword(&attrName))) {
    // Detect admissible keyword.
    auto *it = find(keys, attrName);
    if (it == keys.end()) {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
      return {};
    }
    unsigned keyWordIndex = it - keys.begin();
    // Consume the `=` after keys.
    if (failed(parser.parseEqual()))
      return {};
    // Dispatch on keyword.
    switch (keyWordIndex) {
    case 0: { // map
      ir_detail::DimLvlMapParser cParser(parser);
      auto res = cParser.parseDimLvlMap();
      if (failed(res))
        return {};
      const auto &dlm = *res;

      const Level lvlRank = dlm.getLvlRank();
      for (Level lvl = 0; lvl < lvlRank; lvl++)
        lvlTypes.push_back(dlm.getLvlType(lvl));

      const Dimension dimRank = dlm.getDimRank();
      for (Dimension dim = 0; dim < dimRank; dim++)
        dimSlices.push_back(dlm.getDimSlice(dim));
      // NOTE: the old syntax requires an all-or-nothing approach to
      // `dimSlices`; therefore, if any slice actually exists then we need
      // to convert null-DSA into default/nop DSA.
      const auto isDefined = [](SparseTensorDimSliceAttr slice) {
        return static_cast<bool>(slice.getImpl());
      };
      if (llvm::any_of(dimSlices, isDefined)) {
        const auto defaultSlice =
            SparseTensorDimSliceAttr::get(parser.getContext());
        for (Dimension dim = 0; dim < dimRank; dim++)
          if (!isDefined(dimSlices[dim]))
            dimSlices[dim] = defaultSlice;
      } else {
        dimSlices.clear();
      }

      dimToLvl = dlm.getDimToLvlMap(parser.getContext());
      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
      break;
    }
    case 1: { // posWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral position bitwidth");
        return {};
      }
      posWidth = intAttr.getInt();
      break;
    }
    case 2: { // crdWidth
      Attribute attr;
      if (failed(parser.parseAttribute(attr)))
        return {};
      auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      crdWidth = intAttr.getInt();
      break;
    }
    } // switch
    // Only last item can omit the comma.
    if (parser.parseOptionalComma().failed())
      break;
  }

  // Close "}>" part.
  if (failed(parser.parseRBrace()))
    return {};
  if (failed(parser.parseGreater()))
    return {};

  // Construct struct-like storage for attribute.
  if (!lvlToDim || lvlToDim.isEmpty()) {
    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      dimSlices);
}
void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  auto map = static_cast<AffineMap>(getDimToLvl());
  // Empty affine map indicates identity map.
  if (!map)
    map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
  printer << "<{ map = ";
  printSymbols(map, printer);
  printer << '(';
  printDimensions(map, printer, getDimSlices());
  printer << ") -> (";
  printLevels(map, printer, getLvlTypes());
  printer << ')';
  // Print remaining members only for non-default values.
  if (getPosWidth())
    printer << ", posWidth = " << getPosWidth();
  if (getCrdWidth())
    printer << ", crdWidth = " << getCrdWidth();
  printer << " }>";
}

void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
                                            AsmPrinter &printer) const {
  if (map.getNumSymbols() == 0)
    return;
  printer << '[';
  for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
    printer << 's' << i << ", ";
  if (map.getNumSymbols() >= 1)
    printer << 's' << map.getNumSymbols() - 1;
  printer << ']';
}

void SparseTensorEncodingAttr::printDimensions(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
  if (!dimSlices.empty()) {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << " : " << dimSlices[i] << ", ";
    if (map.getNumDims() >= 1) {
      printer << 'd' << map.getNumDims() - 1 << " : "
              << dimSlices[map.getNumDims() - 1];
    }
  } else {
    for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
      printer << 'd' << i << ", ";
    if (map.getNumDims() >= 1)
      printer << 'd' << map.getNumDims() - 1;
  }
}

void SparseTensorEncodingAttr::printLevels(
    AffineMap &map, AsmPrinter &printer,
    ArrayRef<DimLevelType> lvlTypes) const {
  for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
    map.getResult(i).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
  }
  if (map.getNumResults() >= 1) {
    auto lastIndex = map.getNumResults() - 1;
    map.getResult(lastIndex).print(printer.getStream());
    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
  }
}
LogicalResult
SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<DimLevelType> lvlTypes,
                                 AffineMap dimToLvl, AffineMap lvlToDim,
                                 unsigned posWidth, unsigned crdWidth,
                                 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
  if (!acceptBitWidth(posWidth))
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonDLT);
      it != std::end(lvlTypes)) {
    if (it == lvlTypes.begin() ||
        (!isCompressedDLT(*(it - 1)) && !isLooseCompressedDLT(*(it - 1))))
      return emitError() << "expected compressed or loose_compressed level "
                            "before singleton level";
    if (!std::all_of(it, lvlTypes.end(),
                     [](DimLevelType i) { return isSingletonDLT(i); }))
      return emitError() << "expected all singleton lvlTypes "
                            "following a singleton level";
  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
  // since it must always be provided and have full rank; therefore we
  // use that same source-of-truth here.
  const Level lvlRank = lvlTypes.size();
  if (lvlRank == 0)
    return emitError() << "expected a non-empty array for lvlTypes";
  // We save `dimRank` here because we'll also need it to verify `dimSlices`.
  const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
  if (dimToLvl) {
    if (dimToLvl.getNumResults() != lvlRank)
      return emitError()
             << "level-rank mismatch between dimToLvl and lvlTypes: "
             << dimToLvl.getNumResults() << " != " << lvlRank;
    auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
    // Symbols can't be inferred but are acceptable.
    if (!inferRes && dimToLvl.getNumSymbols() == 0)
      return emitError() << "failed to infer lvlToDim from dimToLvl";
    if (lvlToDim && (inferRes != lvlToDim))
      return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
    if (dimRank > lvlRank)
      return emitError() << "unexpected dimToLvl mapping from " << dimRank
                         << " to " << lvlRank;
  }
  if (!dimSlices.empty()) {
    if (dimSlices.size() != dimRank)
      return emitError()
             << "dimension-rank mismatch between dimSlices and dimToLvl: "
             << dimSlices.size() << " != " << dimRank;
    // Compiler support for `dimSlices` currently requires that the two
    // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
    if (dimRank != lvlRank)
      return emitError()
             << "dimSlices expected dimension-rank to match level-rank: "
             << dimRank << " != " << lvlRank;
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<Size> dimShape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity. In particular, this ensures that the
  // level-rank is coherent across all the fields.
  if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
                    getPosWidth(), getCrdWidth(), getDimSlices())))
    return failure();
  // Check integrity with tensor type specifics. In particular, we
  // need only check that the dimension-rank of the tensor agrees with
  // the dimension-rank of the encoding.
  const Dimension dimRank = dimShape.size();
  if (dimRank == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimRank() != dimRank)
    return emitError()
           << "dimension-rank mismatch between encoding and tensor shape: "
           << getDimRank() << " != " << dimRank;
  return success();
}
//===----------------------------------------------------------------------===//
// Convenience methods.
//===----------------------------------------------------------------------===//

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
    return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
  if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
    return mdtp.getEncoding();
  return nullptr;
}

AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
                                             MLIRContext *context) {
  auto map = static_cast<AffineMap>(dimToLvl);
  AffineMap lvlToDim;
  // Return an empty lvlToDim when inference is not successful.
  if (!map || map.getNumSymbols() != 0) {
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    lvlToDim = inversePermutation(map);
  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
}
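// Inverts a block-sparsity dimToLvl map. For example, the 2x2 BSR map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// is inverted to
//   (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 2 + l3).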
AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
                                                    MLIRContext *context) {
  SmallVector<AffineExpr> lvlExprs;
  auto numLvls = dimToLvl.getNumResults();
  lvlExprs.reserve(numLvls);
  // lvlExprComponents stores information of the floordiv and mod operations
  // applied to the same dimension, so as to build the lvlToDim map.
  std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
  for (unsigned i = 0, n = numLvls; i < n; i++) {
    auto result = dimToLvl.getResult(i);
    if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Position of the dimension in dimToLvl.
        auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
        assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
               "expected only one floordiv for each dimension");
        SmallVector<AffineExpr, 3> components;
        // Level variable for floordiv.
        components.push_back(getAffineDimExpr(i, context));
        // Block size, i.e., the constant RHS of the floordiv.
        components.push_back(binOp.getRHS());
        // Map key is the position of the dimension.
        lvlExprComponents[pos] = components;
      } else if (result.getKind() == AffineExprKind::Mod) {
        auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
        assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
               "expected floordiv before mod");
        // Add level variable for mod to the same vector
        // of the corresponding floordiv.
        lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
      } else {
        assert(false && "expected floordiv or mod");
      }
    } else {
      lvlExprs.push_back(getAffineDimExpr(i, context));
    }
  }
  // Build lvlExprs from lvlExprComponents.
  // For example, for il = i floordiv 2 and ii = i mod 2, the components
  // would be [il, 2, ii]. It could be used to build the AffineExpr
  // i = il * 2 + ii in lvlToDim.
  for (auto &components : lvlExprComponents) {
    assert(components.second.size() == 3 &&
           "expected 3 components to build lvlExprs");
    auto mulOp = getAffineBinaryOpExpr(
        AffineExprKind::Mul, components.second[0], components.second[1]);
    auto addOp =
        getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
    lvlExprs.push_back(addOp);
  }
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}
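// Returns the block size along each result of a block-sparsity map, with 0
// entries for results that are not blocked; e.g., the 2x3 map
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// yields {2, 3}.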
SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
  assert(isBlockSparsity(dimToLvl) &&
         "expected dimToLvl to be block sparsity for calling getBlockSize");
  SmallVector<unsigned> blockSize;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
      if (result.getKind() == AffineExprKind::Mod) {
        blockSize.push_back(
            binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue());
      }
    } else {
      blockSize.push_back(0);
    }
  }
  return blockSize;
}

bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
  if (!dimToLvl)
    return false;
  std::map<unsigned, int64_t> coeffientMap;
  for (auto result : dimToLvl.getResults()) {
    if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
      auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
      if (result.getKind() == AffineExprKind::FloorDiv) {
        // Expect only one floordiv for each dimension.
        if (coeffientMap.find(pos) != coeffientMap.end())
          return false;
        coeffientMap[pos] =
            binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue();
      } else if (result.getKind() == AffineExprKind::Mod) {
        // Expect floordiv before mod.
        if (coeffientMap.find(pos) == coeffientMap.end())
          return false;
        // Expect mod to have the same coefficient as floordiv.
        if (binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue() !=
            coeffientMap[pos])
          return false;
      } else {
        return false;
      }
    }
  }
  return !coeffientMap.empty();
}
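// Returns true when the levels starting at `startLvl` form a COO region,
// i.e., a (loose_)compressed level followed only by singleton levels,
// optionally requiring the last level to be unique.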
bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                    Level startLvl, bool isUnique) {
  if (!enc ||
      !(enc.isCompressedLvl(startLvl) || enc.isLooseCompressedLvl(startLvl)))
    return false;
  const Level lvlRank = enc.getLvlRank();
  for (Level l = startLvl + 1; l < lvlRank; ++l)
    if (!enc.isSingletonLvl(l))
      return false;
  // If isUnique is true, then make sure that the last level is unique,
  // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1
  // (unique on the last singleton).
  return !isUnique || enc.isUniqueLvl(lvlRank - 1);
}

bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
  return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
}

Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
  // We only consider COO region with at least two levels for the purpose
  // of AOS storage optimization.
  const Level lvlRank = enc.getLvlRank();
  if (lvlRank > 1)
    for (Level l = 0; l < lvlRank - 1; l++)
      if (isCOOType(enc, l, /*isUnique=*/false))
        return l;
  return lvlRank;
}
// Helpers to setup a COO type.
RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
                                                           AffineMap lvlPerm,
                                                           bool ordered) {
  const SparseTensorType src(rtt);
  const Level lvlRank = src.getLvlRank();
  SmallVector<DimLevelType> lvlTypes;
  lvlTypes.reserve(lvlRank);

  // An unordered and non-unique compressed level at beginning.
  // If this is also the last level, then it is unique.
  lvlTypes.push_back(
      *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
  if (lvlRank > 1) {
    // TODO: it is actually ordered at the level for ordered input.
    // Followed by unordered non-unique n-2 singleton levels.
    std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
                *buildLevelType(LevelFormat::Singleton, ordered, false));
    // Ends by a unique singleton level unless the lvlRank is 1.
    lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
  }

  // TODO: Maybe pick the bitwidth based on input/output tensors (probably the
  // largest one among them) in the original operation instead of using the
  // default value.
  unsigned posWidth = src.getPosWidth();
  unsigned crdWidth = src.getCrdWidth();
  AffineMap invPerm = src.getLvlToDim();
  auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
                                           invPerm, posWidth, crdWidth);
  return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
}
RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
                                               bool ordered) {
  return getCOOFromTypeWithOrdering(
      src, AffineMap::getMultiDimIdentityMap(src.getRank(), src.getContext()),
      ordered);
}

// TODO: Remove this definition once all use-sites have been fixed to
// properly handle non-permutations.
Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
                                         Level l) {
  if (enc) {
    if (const auto dimToLvl = enc.getDimToLvl()) {
      assert(enc.isPermutation());
      return dimToLvl.getDimPosition(l);
    }
  }
  return l;
}

// TODO: Remove this definition once all use-sites have been fixed to
// properly handle non-permutations.
Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
                                       Dimension d) {
  if (enc) {
    if (const auto dimToLvl = enc.getDimToLvl()) {
      assert(enc.isPermutation());
      const auto maybePos =
          dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
      assert(maybePos.has_value());
      return *maybePos;
    }
  }
  return d;
}

// TODO: Remove this definition once all use-sites have been fixed to
// properly handle non-permutations.
Dimension mlir::sparse_tensor::toOrigDim(RankedTensorType type, Level l) {
  const auto enc = getSparseTensorEncoding(type);
  assert(!enc || l < enc.getLvlRank());
  return toOrigDim(enc, l);
}

// TODO: Remove this definition once all use-sites have been fixed to
// properly handle non-permutations.
Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
  assert(d < static_cast<Dimension>(type.getRank()));
  return toStoredDim(getSparseTensorEncoding(type), d);
}
/// We normalize the sparse tensor encoding attribute by always using
/// ordered/unique DLT such that "compressed_nu_no" and "compressed_nu" (as well
/// as other variants) lead to the same storage specifier type, and by stripping
/// irrelevant fields that do not alter the sparse tensor memory layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<DimLevelType> dlts;
  for (auto dlt : enc.getLvlTypes())
    dlts.push_back(*buildLevelType(*getLevelFormat(dlt), true, true));

  return SparseTensorEncodingAttr::get(
      enc.getContext(), dlts,
      AffineMap(), // dimToLvl (irrelevant to storage specifier)
      AffineMap(), // lvlToDim (irrelevant to storage specifier)
      // Always use `index` for memSize and lvlSize instead of reusing
      // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
      // value for different bitwidths, and it also avoids casting between
      // index and integer (returned by DimOp).
      0, 0, enc.getDimSlices());
}

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
  return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}
//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
  return success(lvl < getSparseTensorType(tensor).getLvlRank());
}

static LogicalResult isMatchingWidth(Value mem, unsigned width) {
  const Type etp = getMemRefType(mem).getElementType();
  return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
}
static LogicalResult verifySparsifierGetterSetter(
    StorageSpecifierKind mdKind, std::optional<Level> lvl,
    TypedValue<StorageSpecifierType> md, Operation *op) {
  if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
    return op->emitError(
        "redundant level argument for querying value memory size");
  }

  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

  if (mdKind == StorageSpecifierKind::DimOffset ||
      mdKind == StorageSpecifierKind::DimStride)
    if (!enc.isSlice())
      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
      return op->emitError("missing level argument");

    const Level l = lvl.value();
    if (l >= lvlRank)
      return op->emitError("requested level is out of bounds");

    if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
      return op->emitError(
          "requested position memory size on a singleton level");
  }
  return success();
}

static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
  switch (kind) {
  case SparseTensorFieldKind::CrdMemRef:
    return stt.getCrdType();
  case SparseTensorFieldKind::PosMemRef:
    return stt.getPosType();
  case SparseTensorFieldKind::ValMemRef:
    return stt.getElementType();
  case SparseTensorFieldKind::StorageSpec:
    return nullptr;
  }
  llvm_unreachable("Unrecognizable FieldKind");
}
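// Shared verification for AssembleOp and DisassembleOp: checks that the
// provided value/level buffers agree with the sparse tensor's storage layout
// in number and element type, and that a trailing COO region is passed as a
// single 2-D coordinate buffer.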
static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
                                      SparseTensorType stt,
                                      RankedTensorType valTp,
                                      TypeRange lvlTps) {
  if (requiresStaticShape && !stt.hasStaticDimShape())
    return op->emitError("the sparse-tensor must have static shape");
  if (!stt.hasEncoding())
    return op->emitError("the sparse-tensor must have an encoding attribute");
  if (!stt.isIdentity())
    return op->emitError("the sparse-tensor must have the identity mapping");

  // Verifies the trailing COO.
  Level cooStartLvl = getCOOStart(stt.getEncoding());
  if (cooStartLvl < stt.getLvlRank()) {
    // We only support trailing COO for now, must be the last input.
    auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
    // The coordinates should be in shape of <? x rank>
    unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
    if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
      op->emitError("input/output trailing COO level-ranks don't match");
    }
  }

  // Verifies that all types match.
  StorageLayout layout(stt.getEncoding());
  if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
    return op->emitError("inconsistent number of fields between input/output");

  unsigned idx = 0;
  bool misMatch = false;
  layout.foreachField([&idx, &misMatch, stt, valTp,
                       lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
                               Level lvl, DimLevelType dlt) -> bool {
    if (fKind == SparseTensorFieldKind::StorageSpec)
      return true;

    Type inputTp = nullptr;
    if (fKind == SparseTensorFieldKind::ValMemRef) {
      inputTp = valTp;
    } else {
      assert(fid == idx && stt.getLvlType(lvl) == dlt);
      inputTp = lvlTps[idx++];
    }
    // The input element type and expected element type should match.
    Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
    Type expElemTp = getFieldElemType(stt, fKind);
    if (inpElemTp != expElemTp) {
      misMatch = true;
      return false; // to terminate the iteration
    }
    return true;
  });

  if (misMatch)
    return op->emitError("input/output element-types don't match");
  return success();
}
LogicalResult AssembleOp::verify() {
  const auto valuesTp = getRankedTensorType(getValues());
  const auto lvlsTp = getLevels().getTypes();
  const auto resTp = getSparseTensorType(getResult());
  return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
}

LogicalResult DisassembleOp::verify() {
  if (getOutValues().getType() != getRetValues().getType())
    return emitError("output values and return value type mismatch");

  for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
    if (ot.getType() != rt.getType())
      return emitError("output levels and return levels type mismatch");

  const auto valuesTp = getRankedTensorType(getRetValues());
  const auto lvlsTp = getRetLevels().getTypes();
  const auto srcTp = getSparseTensorType(getTensor());
  return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
}
LogicalResult ConvertOp::verify() {
  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto dstEnc =
          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
      if (dstEnc && dstEnc.isSlice())
        return emitError("cannot convert to a sparse tensor slice");

      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}
OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

bool ConvertOp::needsExtraSort() {
  SparseTensorType srcStt = getSparseTensorType(getSource());
  SparseTensorType dstStt = getSparseTensorType(getDest());

  // We do not need an extra sort when returning unordered sparse tensors or
  // dense tensors, since dense tensors support random access.
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
      srcStt.hasSameDimToLvl(dstStt)) {
    return false;
  }

  // Source and dest tensors are ordered in different ways. We only do direct
  // dense to sparse conversion when the dense input is defined by a sparse
  // constant. Note that we can theoretically always directly convert from dense
  // inputs by rotating dense loops, but it leads to bad cache locality and
  // hurts performance.
  if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
    if (isa<SparseElementsAttr>(constOp.getValue()))
      return false;

  return true;
}
LogicalResult CrdTranslateOp::verify() {
  uint64_t inRank = getEncoder().getLvlRank();
  uint64_t outRank = getEncoder().getDimRank();

  if (getDirection() == CrdTransDirectionKind::dim2lvl)
    std::swap(inRank, outRank);

  if (inRank != getInCrds().size() || outRank != getOutCrds().size())
    return emitError("Coordinate rank mismatch with encoding");

  return success();
}

LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                   SmallVectorImpl<OpFoldResult> &results) {
  if (getEncoder().isIdentity()) {
    results.assign(getInCrds().begin(), getInCrds().end());
    return success();
  }
  if (getEncoder().isPermutation()) {
    AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                         ? getEncoder().getDimToLvl()
                         : getEncoder().getLvlToDim();
    for (AffineExpr exp : perm.getResults())
      results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
    return success();
  }

  // Fuse dim2lvl/lvl2dim pairs.
  auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
  bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
    return v.getDefiningOp() == def;
  });
  if (!sameDef)
    return failure();

  bool oppositeDir = def.getDirection() != getDirection();
  bool sameOracle =
      def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
  bool sameCount = def.getNumResults() == getInCrds().size();
  if (!oppositeDir || !sameOracle || !sameCount)
    return failure();

  // The definition produces the coordinates in the same order as the input
  // coordinates.
  bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
                                [](auto valuePair) {
                                  auto [lhs, rhs] = valuePair;
                                  return lhs == rhs;
                                });
  if (!sameOrder)
    return failure();

  // l1 = dim2lvl (lvl2dim l0)
  results.append(def.getInCrds().begin(), def.getInCrds().end());
  return success();
}
void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
                  int64_t index) {
  Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
  return build(builder, state, source, val);
}

LogicalResult LvlOp::verify() {
  if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
    auto stt = getSparseTensorType(getSource());
    if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
      emitError("Level index exceeds the rank of the input sparse tensor");
  }
  return success();
}

std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability LvlOp::getSpeculatability() {
  auto constantIndex = getConstantLvlIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  assert(constantIndex <
         cast<RankedTensorType>(getSource().getType()).getRank());
  return Speculation::Speculatable;
}

OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
  auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!lvlIndex)
    return {};

  Level lvl = lvlIndex.getAPSInt().getZExtValue();
  auto stt = getSparseTensorType(getSource());
  if (lvl >= stt.getLvlRank()) {
    // Follows the same convention used by tensor.dim operation. Out of bound
    // indices produce undefined behavior but are still valid IR. Don't choke on
    // them.
    return {};
  }

  // Helper lambda to build an IndexAttr.
  auto getIndexAttr = [this](int64_t lvlSz) {
    return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
  };

  SmallVector<Size> lvlShape = stt.getLvlShape();
  if (!ShapedType::isDynamic(lvlShape[lvl]))
    return getIndexAttr(lvlShape[lvl]);

  return {};
}
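// Builds a ReinterpretMapOp whose result type reuses the source's level
// shape, translated back into a dimension shape under the destination
// encoding's lvl2dim map.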
void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                             SparseTensorEncodingAttr dstEnc, Value source) {
  auto srcStt = getSparseTensorType(source);
  SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
  SmallVector<int64_t> dstDimShape =
      dstEnc.tranlateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
  auto dstTp =
      RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
  return build(odsBuilder, odsState, dstTp, source);
}

LogicalResult ReinterpretMapOp::verify() {
  auto srcStt = getSparseTensorType(getSource());
  auto dstStt = getSparseTensorType(getDest());
  ArrayRef<DimLevelType> srcLvlTps = srcStt.getLvlTypes();
  ArrayRef<DimLevelType> dstLvlTps = dstStt.getLvlTypes();

  if (srcLvlTps.size() != dstLvlTps.size())
    return emitError("Level rank mismatch between source/dest tensors");

  for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
    if (srcLvlTp != dstLvlTp)
      return emitError("Level type mismatch between source/dest tensors");

  if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
      srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
    return emitError("Crd/Pos width mismatch between source/dest tensors");
  }

  if (srcStt.getElementType() != dstStt.getElementType())
    return emitError("Element type mismatch between source/dest tensors");

  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
  for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
    if (srcLvlSz != dstLvlSz) {
      // Should we allow one side to be dynamic size, e.g., <?x?> should be
      // compatible to <3x4>? For now, we require all the level sizes to be
      // *exactly* matched for simplicity.
      return emitError("Level size mismatch between source/dest tensors");
    }
  }

  return success();
}

OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
  if (getSource().getType() == getDest().getType())
    return getSource();

  if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
    // A -> B, B -> A ==> A
    if (def.getSource().getType() == getDest().getType())
      return def.getSource();
  }
  return {};
}
LogicalResult ToPositionsOp::verify() {
  auto e = getSparseTensorEncoding(getTensor().getType());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), e.getPosWidth())))
    return emitError("unexpected type for positions");
  return success();
}

LogicalResult ToCoordinatesOp::verify() {
  auto e = getSparseTensorEncoding(getTensor().getType());
  if (failed(lvlIsInBounds(getLevel(), getTensor())))
    return emitError("requested level is out of bounds");
  if (failed(isMatchingWidth(getResult(), e.getCrdWidth())))
    return emitError("unexpected type for coordinates");
  return success();
}

LogicalResult ToCoordinatesBufferOp::verify() {
  auto e = getSparseTensorEncoding(getTensor().getType());
  if (getCOOStart(e) >= e.getLvlRank())
    return emitError("expected sparse tensor with a COO region");
  return success();
}

LogicalResult ToValuesOp::verify() {
  auto ttp = getRankedTensorType(getTensor());
  auto mtp = getMemRefType(getResult());
  if (ttp.getElementType() != mtp.getElementType())
    return emitError("unexpected mismatch in element types");
  return success();
}

LogicalResult ToSliceOffsetOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");
  return success();
}

LogicalResult ToSliceStrideOp::verify() {
  auto rank = getRankedTensorType(getSlice()).getRank();
  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
    return emitError("requested dimension out of bound");
  return success();
}

LogicalResult GetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}

template <typename SpecifierOp>
static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
  return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
}

OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
  const StorageSpecifierKind kind = getSpecifierKind();
  const auto lvl = getLevel();
  for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
    if (kind == op.getSpecifierKind() && lvl == op.getLevel())
      return op.getValue();
  return {};
}

LogicalResult SetStorageSpecifierOp::verify() {
  return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
                                      getSpecifier(), getOperation());
}
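// Checks that an op region (used by the unary/binary/reduce/select semi-ring
// operations) has the expected block arguments and ends with a
// sparse_tensor.yield of the expected type.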
template <class T>
static LogicalResult verifyNumBlockArgs(T *op, Region &region,
                                        const char *regionName,
                                        TypeRange inputTypes, Type outputType) {
  unsigned numArgs = region.getNumArguments();
  unsigned expectedNum = inputTypes.size();
  if (numArgs != expectedNum)
    return op->emitError() << regionName << " region must have exactly "
                           << expectedNum << " arguments";

  for (unsigned i = 0; i < numArgs; i++) {
    Type typ = region.getArgument(i).getType();
    if (typ != inputTypes[i])
      return op->emitError() << regionName << " region argument " << (i + 1)
                             << " type mismatch";
  }
  Operation *term = region.front().getTerminator();
  YieldOp yield = dyn_cast<YieldOp>(term);
  if (!yield)
    return op->emitError() << regionName
                           << " region must end with sparse_tensor.yield";
  if (!yield.getResult() || yield.getResult().getType() != outputType)
    return op->emitError() << regionName << " region yield type mismatch";

  return success();
}
LogicalResult BinaryOp::verify() {
  NamedAttrList attrs = (*this)->getAttrs();
  Type leftType = getX().getType();
  Type rightType = getY().getType();
  Type outputType = getOutput().getType();
  Region &overlap = getOverlapRegion();
  Region &left = getLeftRegion();
  Region &right = getRightRegion();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  if (!overlap.empty()) {
    if (failed(verifyNumBlockArgs(this, overlap, "overlap",
                                  TypeRange{leftType, rightType}, outputType)))
      return failure();
  }
  if (!left.empty()) {
    if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
                                  outputType)))
      return failure();
  } else if (getLeftIdentity()) {
    if (leftType != outputType)
      return emitError("left=identity requires first argument to have the same "
                       "type as the output");
  }
  if (!right.empty()) {
    if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
                                  outputType)))
      return failure();
  } else if (getRightIdentity()) {
    if (rightType != outputType)
      return emitError("right=identity requires second argument to have the "
                       "same type as the output");
  }
  return success();
}
LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    if (failed(verifyNumBlockArgs(this, present, "present",
                                  TypeRange{inputType}, outputType)))
      return failure();
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
                                  outputType)))
      return failure();
    // Absent branch can only yield invariant values.
    Block *absentBlock = &absent.front();
    Block *parent = getOperation()->getBlock();
    Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
    if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
      if (arg.getOwner() == parent)
        return emitError("absent region cannot yield linalg argument");
    } else if (Operation *def = absentVal.getDefiningOp()) {
      if (!isa<arith::ConstantOp>(def) &&
          (def->getBlock() == absentBlock || def->getBlock() == parent))
        return emitError("absent region cannot yield locally computed value");
    }
  }
  return success();
}
bool ConcatenateOp::needsExtraSort() {
  SparseTensorType dstStt = getSparseTensorType(*this);
  if (dstStt.isAllDense() || !dstStt.isAllOrdered())
    return false;

  bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
    return getSparseTensorType(op).hasSameDimToLvl(dstStt);
  });
  // TODO: When conDim != 0, as long as conDim corresponds to the first level
  // in all input/output buffers, and all input/output buffers have the same
  // dimToLvl, the tmp COO buffer is still unnecessary (e.g., concatenate
  // CSC matrices along column).
  bool directLowerable =
      allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
  return !directLowerable;
}
LogicalResult ConcatenateOp::verify() {
  const auto dstTp = getSparseTensorType(*this);
  const Dimension concatDim = getDimension();
  const Dimension dimRank = dstTp.getDimRank();

  if (getInputs().size() <= 1)
    return emitError("Need at least two tensors to concatenate.");

  if (concatDim >= dimRank)
    return emitError(llvm::formatv(
        "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
        concatDim, dimRank));

  for (const auto &it : llvm::enumerate(getInputs())) {
    const auto i = it.index();
    const auto srcTp = getSparseTensorType(it.value());
    if (srcTp.hasDynamicDimShape())
      return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
    const Dimension srcDimRank = srcTp.getDimRank();
    if (srcDimRank != dimRank)
      return emitError(
          llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
                        "from the output tensor (rank={2}).",
                        i, srcDimRank, dimRank));
  }

  for (Dimension d = 0; d < dimRank; d++) {
    const Size dstSh = dstTp.getDimShape()[d];
    if (d == concatDim) {
      if (!ShapedType::isDynamic(dstSh)) {
        // If we reach here, then all inputs have static shapes. So we
        // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
        // to avoid redundant assertions in the loop.
        Size sumSz = 0;
        for (const auto src : getInputs())
          sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimension are statically known, the sum of all the input
        // dimensions should be equal to the output dimension.
        if (sumSz != dstSh)
          return emitError(
              "The concatenation dimension of the output tensor should be the "
              "sum of all the concatenation dimensions of the input tensors.");
      }
    } else {
      Size prev = dstSh;
      for (const auto src : getInputs()) {
        const auto sh = getSparseTensorType(src).getDimShape()[d];
        if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (expect for the concatenating one) "
                           "should be equal.");
        prev = sh;
      }
    }
  }

  return success();
}
LogicalResult InsertOp::verify() {
  const auto stt = getSparseTensorType(getTensor());
  if (stt.getLvlRank() != static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
  return success();
}

void PushBackOp::build(OpBuilder &builder, OperationState &result,
                       Value curSize, Value inBuffer, Value value) {
  build(builder, result, curSize, inBuffer, value, Value());
}

LogicalResult PushBackOp::verify() {
  if (Value n = getN()) {
    std::optional<int64_t> nValue = getConstantIntValue(n);
    if (nValue && nValue.value() < 1)
      return emitOpError("n must be not less than 1");
  }
  return success();
}

LogicalResult CompressOp::verify() {
  const auto stt = getSparseTensorType(getTensor());
  if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
    return emitOpError("incorrect number of coordinates");
  return success();
}
void ForeachOp::build(
    OpBuilder &builder, OperationState &result, Value tensor,
    ValueRange initArgs, AffineMapAttr order,
    function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
        bodyBuilder) {
  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
  // Builds foreach body.
  if (!bodyBuilder)
    return;
  const auto stt = getSparseTensorType(tensor);
  const Dimension dimRank = stt.getDimRank();

  // Starts with `dimRank`-many coordinates.
  SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
  // Followed by one value.
  blockArgTypes.push_back(stt.getElementType());
  // Followed by the reduction variables.
  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());

  SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuilder(builder, result.location,
              bodyBlock->getArguments().slice(0, dimRank),
              bodyBlock->getArguments()[dimRank],
              bodyBlock->getArguments().drop_front(dimRank + 1));
}
LogicalResult ForeachOp::verify() {
  const auto t = getSparseTensorType(getTensor());
  const Dimension dimRank = t.getDimRank();
  const auto args = getBody()->getArguments();

  if (getOrder().has_value() &&
      (t.getEncoding() || !getOrder()->isPermutation()))
    return emitError("Only support permuted order on non encoded dense tensor");

  if (static_cast<size_t>(dimRank) + 1 + getInitArgs().size() != args.size())
    return emitError("Unmatched number of arguments in the block");

  if (getNumResults() != getInitArgs().size())
    return emitError("Mismatch in number of init arguments and results");

  if (getResultTypes() != getInitArgs().getTypes())
    return emitError("Mismatch in types of init arguments and results");

  // Cannot mark this const, because the getters aren't.
  auto yield = cast<YieldOp>(getBody()->getTerminator());
  if (yield.getNumOperands() != getNumResults() ||
      yield.getOperands().getTypes() != getResultTypes())
    return emitError("Mismatch in types of yield values and results");

  const auto iTp = IndexType::get(getContext());
  for (Dimension d = 0; d < dimRank; d++)
    if (args[d].getType() != iTp)
      return emitError(
          llvm::formatv("Expecting Index type for argument at index {0}", d));

  const auto elemTp = t.getElementType();
  const auto valueTp = args[dimRank].getType();
  if (elemTp != valueTp)
    emitError(llvm::formatv("Unmatched element type between input tensor and "
                            "block argument, expected:{0}, got: {1}",
                            elemTp, valueTp));
  return success();
}
OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
  if (getSparseTensorEncoding(getInputCoo().getType()) ==
      getSparseTensorEncoding(getResultCoo().getType()))
    return getInputCoo();

  return {};
}

LogicalResult ReorderCOOOp::verify() {
  SparseTensorType srcStt = getSparseTensorType(getInputCoo());
  SparseTensorType dstStt = getSparseTensorType(getResultCoo());

  if (!srcStt.hasSameDimToLvl(dstStt))
    emitError("Unmatched dim2lvl map between input and result COO");

  if (srcStt.getPosType() != dstStt.getPosType() ||
      srcStt.getCrdType() != dstStt.getCrdType() ||
      srcStt.getElementType() != dstStt.getElementType()) {
    emitError("Unmatched storage format between input and result COO");
  }
  return success();
}

LogicalResult ReduceOp::verify() {
  Type inputType = getX().getType();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "reduce",
                            TypeRange{inputType, inputType}, inputType);
}

LogicalResult SelectOp::verify() {
  Builder b(getContext());
  Type inputType = getX().getType();
  Type boolType = b.getI1Type();
  Region &formula = getRegion();
  return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
                            boolType);
}
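// Verifies that the permutation map is valid and, when n and the buffer
// sizes are compile-time constants, that the xy buffer holds at least
// n * (rank(perm_map) + ny) entries and every y buffer holds at least n.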
LogicalResult SortOp::verify() {
  AffineMap xPerm = getPermMap();
  uint64_t nx = xPerm.getNumDims();
  if (nx < 1)
    emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));

  if (!xPerm.isPermutation())
    emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));

  std::optional<int64_t> cn = getConstantIntValue(getN());
  // We can't check the size of the buffers when n or buffer dimensions aren't
  // compile-time constants.
  if (!cn)
    return success();

  uint64_t n = cn.value();
  uint64_t ny = 0;
  if (auto nyAttr = getNyAttr()) {
    ny = nyAttr.getInt();
  }

  // FIXME: update the types of variables used in expressions passed as
  // the `minSize` argument, to avoid implicit casting at the callsites
  // of this function.
  const auto checkDim = [&](Value v, Size minSize, const char *message) {
    const Size sh = getMemRefType(v).getShape()[0];
    if (!ShapedType::isDynamic(sh) && sh < minSize)
      emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
  };

  checkDim(getXy(), n * (nx + ny),
           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");

  for (Value opnd : getYs()) {
    checkDim(opnd, n, "Expected dimension(y) >= n");
  }

  return success();
}
LogicalResult YieldOp::verify() {
  // Check for compatible parent.
  auto *parentOp = (*this)->getParentOp();
  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
      isa<ForeachOp>(parentOp))
    return success();

  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
                     "reduce, select or foreach");
}
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
                                                    Attribute value, Type type,
                                                    Location loc) {
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  return nullptr;
}

void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"