[mlir][sparse] fixed naming consistency (#73053)
[llvm-project.git] / mlir / lib / Dialect / SparseTensor / IR / SparseTensorDialect.cpp
blobfb2e70482a1978b2e082fff7d46ac677bf419eb8
1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include <utility>
11 #include "Detail/DimLvlMapParser.h"
13 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/Utils/StaticValueUtils.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/FormatVariadic.h"
28 #define GET_ATTRDEF_CLASSES
29 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
30 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
35 using namespace mlir;
36 using namespace mlir::sparse_tensor;
38 //===----------------------------------------------------------------------===//
39 // Local convenience methods.
40 //===----------------------------------------------------------------------===//
/// Returns true if `bitWidth` is an admissible overhead bitwidth for
/// positions or coordinates. Zero is the default value (it is what the
/// encoding parser uses when the width key is omitted); otherwise only the
/// power-of-two widths 8, 16, 32, and 64 are accepted.
static constexpr bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
55 //===----------------------------------------------------------------------===//
56 // SparseTensorDialect StorageLayout.
57 //===----------------------------------------------------------------------===//
59 static constexpr Level kInvalidLevel = -1u;
60 static constexpr Level kInvalidFieldIndex = -1u;
61 static constexpr FieldIndex kDataFieldStartingIdx = 0;
63 void StorageLayout::foreachField(
64 llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
65 DimLevelType)>
66 callback) const {
67 const auto lvlTypes = enc.getLvlTypes();
68 const Level lvlRank = enc.getLvlRank();
69 const Level cooStart = getCOOStart(enc);
70 const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
71 FieldIndex fieldIdx = kDataFieldStartingIdx;
72 // Per-level storage.
73 for (Level l = 0; l < end; l++) {
74 const auto dlt = lvlTypes[l];
75 if (isWithPosDLT(dlt)) {
76 if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt)))
77 return;
79 if (isWithCrdDLT(dlt)) {
80 if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt)))
81 return;
84 // The values array.
85 if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
86 DimLevelType::Undef)))
87 return;
88 // Put metadata at the end.
89 if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
90 DimLevelType::Undef)))
91 return;
94 void sparse_tensor::foreachFieldAndTypeInSparseTensor(
95 SparseTensorType stt,
96 llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
97 DimLevelType)>
98 callback) {
99 assert(stt.hasEncoding());
100 // Construct the basic types.
101 const Type crdType = stt.getCrdType();
102 const Type posType = stt.getPosType();
103 const Type eltType = stt.getElementType();
105 const Type specType = StorageSpecifierType::get(stt.getEncoding());
106 // memref<? x pos> positions
107 const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
108 // memref<? x crd> coordinates
109 const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
110 // memref<? x eltType> values
111 const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
113 StorageLayout(stt).foreachField(
114 [specType, posMemType, crdMemType, valMemType,
115 callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind,
116 Level lvl, DimLevelType dlt) -> bool {
117 switch (fieldKind) {
118 case SparseTensorFieldKind::StorageSpec:
119 return callback(specType, fieldIdx, fieldKind, lvl, dlt);
120 case SparseTensorFieldKind::PosMemRef:
121 return callback(posMemType, fieldIdx, fieldKind, lvl, dlt);
122 case SparseTensorFieldKind::CrdMemRef:
123 return callback(crdMemType, fieldIdx, fieldKind, lvl, dlt);
124 case SparseTensorFieldKind::ValMemRef:
125 return callback(valMemType, fieldIdx, fieldKind, lvl, dlt);
127 llvm_unreachable("unrecognized field kind");
131 unsigned StorageLayout::getNumFields() const {
132 unsigned numFields = 0;
133 foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
134 DimLevelType) -> bool {
135 numFields++;
136 return true;
138 return numFields;
141 unsigned StorageLayout::getNumDataFields() const {
142 unsigned numFields = 0; // one value memref
143 foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
144 DimLevelType) -> bool {
145 if (fidx >= kDataFieldStartingIdx)
146 numFields++;
147 return true;
149 numFields -= 1; // the last field is StorageSpecifier
150 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
151 return numFields;
154 std::pair<FieldIndex, unsigned>
155 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
156 std::optional<Level> lvl) const {
157 FieldIndex fieldIdx = kInvalidFieldIndex;
158 unsigned stride = 1;
159 if (kind == SparseTensorFieldKind::CrdMemRef) {
160 assert(lvl.has_value());
161 const Level cooStart = getCOOStart(enc);
162 const Level lvlRank = enc.getLvlRank();
163 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
164 lvl = cooStart;
165 stride = lvlRank - cooStart;
168 foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
169 SparseTensorFieldKind fKind, Level fLvl,
170 DimLevelType dlt) -> bool {
171 if ((lvl && fLvl == lvl.value() && kind == fKind) ||
172 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
173 fieldIdx = fIdx;
174 // Returns false to break the iteration.
175 return false;
177 return true;
179 assert(fieldIdx != kInvalidFieldIndex);
180 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
183 //===----------------------------------------------------------------------===//
184 // SparseTensorDialect Attribute Methods.
185 //===----------------------------------------------------------------------===//
187 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
188 return isDynamic(v) ? std::nullopt
189 : std::make_optional(static_cast<uint64_t>(v));
192 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
193 return getStatic(getOffset());
196 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
197 return getStatic(getStride());
200 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
201 return getStatic(getSize());
204 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
205 return isDynamic(getOffset()) && isDynamic(getStride()) &&
206 isDynamic(getSize());
209 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
210 return isDynamic(v) ? "?" : std::to_string(v);
213 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
214 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
215 os << '(';
216 os << getStaticString(getOffset());
217 os << ", ";
218 os << getStaticString(getSize());
219 os << ", ";
220 os << getStaticString(getStride());
221 os << ')';
224 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
225 print(printer.getStream());
228 static ParseResult parseOptionalStaticSlice(int64_t &result,
229 AsmParser &parser) {
230 auto parseResult = parser.parseOptionalInteger(result);
231 if (parseResult.has_value()) {
232 if (parseResult.value().succeeded() && result < 0) {
233 parser.emitError(
234 parser.getCurrentLocation(),
235 "expect positive value or ? for slice offset/size/stride");
236 return failure();
238 return parseResult.value();
241 // Else, and '?' which represented dynamic slice
242 result = SparseTensorDimSliceAttr::kDynamic;
243 return parser.parseQuestion();
246 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
247 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
249 if (failed(parser.parseLParen()) ||
250 failed(parseOptionalStaticSlice(offset, parser)) ||
251 failed(parser.parseComma()) ||
252 failed(parseOptionalStaticSlice(size, parser)) ||
253 failed(parser.parseComma()) ||
254 failed(parseOptionalStaticSlice(stride, parser)) ||
255 failed(parser.parseRParen()))
256 return {};
258 return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
259 offset, size, stride);
262 LogicalResult
263 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
264 int64_t offset, int64_t size, int64_t stride) {
265 if (!isDynamic(offset) && offset < 0)
266 return emitError() << "expect non-negative value or ? for slice offset";
267 if (!isDynamic(size) && size <= 0)
268 return emitError() << "expect positive value or ? for slice size";
269 if (!isDynamic(stride) && stride <= 0)
270 return emitError() << "expect positive value or ? for slice stride";
271 return success();
274 SparseTensorEncodingAttr
275 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
276 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
277 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
278 AffineMap(), getPosWidth(),
279 getCrdWidth());
282 SparseTensorEncodingAttr
283 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
284 return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
287 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
288 return withDimToLvl(AffineMap());
291 SparseTensorEncodingAttr
292 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
293 unsigned crdWidth) const {
294 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
295 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
296 getDimToLvl(), getLvlToDim(), posWidth,
297 crdWidth);
300 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
301 return withBitWidths(0, 0);
304 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
305 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
306 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
307 getDimToLvl(), getLvlToDim(),
308 getPosWidth(), getCrdWidth(), dimSlices);
311 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
312 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
315 bool SparseTensorEncodingAttr::isAllDense() const {
316 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
319 bool SparseTensorEncodingAttr::isCOO() const {
320 return getImpl() && isCOOType(*this, 0, true);
323 bool SparseTensorEncodingAttr::isAllOrdered() const {
324 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedDLT);
327 bool SparseTensorEncodingAttr::isIdentity() const {
328 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
331 bool SparseTensorEncodingAttr::isPermutation() const {
332 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
335 Dimension SparseTensorEncodingAttr::getDimRank() const {
336 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
337 const auto dimToLvl = getDimToLvl();
338 return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
341 Level SparseTensorEncodingAttr::getLvlRank() const {
342 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
343 return getLvlTypes().size();
346 DimLevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
347 if (!getImpl())
348 return DimLevelType::Dense;
349 assert(l < getLvlRank() && "Level is out of bounds");
350 return getLvlTypes()[l];
353 bool SparseTensorEncodingAttr::isSlice() const {
354 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
355 return !getDimSlices().empty();
358 SparseTensorDimSliceAttr
359 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
360 assert(isSlice() && "Is not a slice");
361 const auto dimSlices = getDimSlices();
362 assert(dim < dimSlices.size() && "Dimension is out of bounds");
363 return dimSlices[dim];
366 std::optional<uint64_t>
367 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
368 return getDimSlice(dim).getStaticOffset();
371 std::optional<uint64_t>
372 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
373 return getDimSlice(dim).getStaticStride();
376 std::optional<uint64_t>
377 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
378 // FIXME: `toOrigDim` is deprecated.
379 return getStaticDimSliceOffset(toOrigDim(*this, lvl));
382 std::optional<uint64_t>
383 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
384 // FIXME: `toOrigDim` is deprecated.
385 return getStaticDimSliceStride(toOrigDim(*this, lvl));
388 SmallVector<int64_t>
389 SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
390 CrdTransDirectionKind dir) const {
391 if (isIdentity())
392 return SmallVector<int64_t>(srcShape);
394 SmallVector<int64_t> ret;
395 unsigned rank =
396 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
397 ret.reserve(rank);
399 if (isPermutation()) {
400 for (unsigned r = 0; r < rank; r++) {
401 // FIXME: `toOrigDim` and `toStoredDim` are deprecated.
402 unsigned trans = dir == CrdTransDirectionKind::dim2lvl
403 ? toOrigDim(*this, r)
404 : toStoredDim(*this, r);
405 ret.push_back(srcShape[trans]);
407 return ret;
410 // Handle non-permutation maps.
411 AffineMap transMap =
412 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
414 SmallVector<AffineExpr> dimRep;
415 dimRep.reserve(srcShape.size());
416 for (int64_t sz : srcShape) {
417 if (!ShapedType::isDynamic(sz)) {
418 // Push back the max coordinate for the given dimension/level size.
419 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
420 } else {
421 // A dynamic size, use a AffineDimExpr to symbolize the value.
422 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
426 for (AffineExpr exp : transMap.getResults()) {
427 // Do constant propagation on the affine map.
428 AffineExpr evalExp =
429 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
430 // use llvm namespace here to avoid ambiguity
431 if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) {
432 ret.push_back(c.getValue() + 1);
433 } else {
434 if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp);
435 mod && mod.getKind() == AffineExprKind::Mod) {
436 // We can still infer a static bound for expressions in form
437 // "d % constant" since d % constant \in [0, constant).
438 if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) {
439 ret.push_back(bound.getValue());
440 continue;
443 ret.push_back(ShapedType::kDynamic);
446 assert(ret.size() == rank);
447 return ret;
450 ValueRange
451 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
452 ValueRange crds,
453 CrdTransDirectionKind dir) const {
454 if (!getImpl())
455 return crds;
457 SmallVector<Type> retType(
458 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
459 builder.getIndexType());
460 auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
461 return transOp.getOutCrds();
464 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
465 // Open "<{" part.
466 if (failed(parser.parseLess()))
467 return {};
468 if (failed(parser.parseLBrace()))
469 return {};
471 // Process the data from the parsed dictionary value into struct-like data.
472 SmallVector<DimLevelType> lvlTypes;
473 SmallVector<SparseTensorDimSliceAttr> dimSlices;
474 AffineMap dimToLvl = {};
475 AffineMap lvlToDim = {};
476 unsigned posWidth = 0;
477 unsigned crdWidth = 0;
478 StringRef attrName;
479 SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"};
480 while (succeeded(parser.parseOptionalKeyword(&attrName))) {
481 // Detect admissible keyword.
482 auto *it = find(keys, attrName);
483 if (it == keys.end()) {
484 parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
485 return {};
487 unsigned keyWordIndex = it - keys.begin();
488 // Consume the `=` after keys
489 if (failed(parser.parseEqual()))
490 return {};
491 // Dispatch on keyword.
492 switch (keyWordIndex) {
493 case 0: { // map
494 ir_detail::DimLvlMapParser cParser(parser);
495 auto res = cParser.parseDimLvlMap();
496 if (failed(res))
497 return {};
498 const auto &dlm = *res;
500 const Level lvlRank = dlm.getLvlRank();
501 for (Level lvl = 0; lvl < lvlRank; lvl++)
502 lvlTypes.push_back(dlm.getLvlType(lvl));
504 const Dimension dimRank = dlm.getDimRank();
505 for (Dimension dim = 0; dim < dimRank; dim++)
506 dimSlices.push_back(dlm.getDimSlice(dim));
507 // NOTE: the old syntax requires an all-or-nothing approach to
508 // `dimSlices`; therefore, if any slice actually exists then we need
509 // to convert null-DSA into default/nop DSA.
510 const auto isDefined = [](SparseTensorDimSliceAttr slice) {
511 return static_cast<bool>(slice.getImpl());
513 if (llvm::any_of(dimSlices, isDefined)) {
514 const auto defaultSlice =
515 SparseTensorDimSliceAttr::get(parser.getContext());
516 for (Dimension dim = 0; dim < dimRank; dim++)
517 if (!isDefined(dimSlices[dim]))
518 dimSlices[dim] = defaultSlice;
519 } else {
520 dimSlices.clear();
523 dimToLvl = dlm.getDimToLvlMap(parser.getContext());
524 lvlToDim = dlm.getLvlToDimMap(parser.getContext());
525 break;
527 case 1: { // posWidth
528 Attribute attr;
529 if (failed(parser.parseAttribute(attr)))
530 return {};
531 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
532 if (!intAttr) {
533 parser.emitError(parser.getNameLoc(),
534 "expected an integral position bitwidth");
535 return {};
537 posWidth = intAttr.getInt();
538 break;
540 case 2: { // crdWidth
541 Attribute attr;
542 if (failed(parser.parseAttribute(attr)))
543 return {};
544 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
545 if (!intAttr) {
546 parser.emitError(parser.getNameLoc(),
547 "expected an integral index bitwidth");
548 return {};
550 crdWidth = intAttr.getInt();
551 break;
553 } // switch
554 // Only last item can omit the comma.
555 if (parser.parseOptionalComma().failed())
556 break;
559 // Close "}>" part.
560 if (failed(parser.parseRBrace()))
561 return {};
562 if (failed(parser.parseGreater()))
563 return {};
565 // Construct struct-like storage for attribute.
566 if (!lvlToDim || lvlToDim.isEmpty()) {
567 lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
569 return parser.getChecked<SparseTensorEncodingAttr>(
570 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
571 dimSlices);
574 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
575 auto map = static_cast<AffineMap>(getDimToLvl());
576 // Empty affine map indicates identity map
577 if (!map)
578 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
579 printer << "<{ map = ";
580 printSymbols(map, printer);
581 printer << '(';
582 printDimensions(map, printer, getDimSlices());
583 printer << ") -> (";
584 printLevels(map, printer, getLvlTypes());
585 printer << ')';
586 // Print remaining members only for non-default values.
587 if (getPosWidth())
588 printer << ", posWidth = " << getPosWidth();
589 if (getCrdWidth())
590 printer << ", crdWidth = " << getCrdWidth();
591 printer << " }>";
594 void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
595 AsmPrinter &printer) const {
596 if (map.getNumSymbols() == 0)
597 return;
598 printer << '[';
599 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
600 printer << 's' << i << ", ";
601 if (map.getNumSymbols() >= 1)
602 printer << 's' << map.getNumSymbols() - 1;
603 printer << ']';
606 void SparseTensorEncodingAttr::printDimensions(
607 AffineMap &map, AsmPrinter &printer,
608 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
609 if (!dimSlices.empty()) {
610 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
611 printer << 'd' << i << " : " << dimSlices[i] << ", ";
612 if (map.getNumDims() >= 1) {
613 printer << 'd' << map.getNumDims() - 1 << " : "
614 << dimSlices[map.getNumDims() - 1];
616 } else {
617 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
618 printer << 'd' << i << ", ";
619 if (map.getNumDims() >= 1)
620 printer << 'd' << map.getNumDims() - 1;
624 void SparseTensorEncodingAttr::printLevels(
625 AffineMap &map, AsmPrinter &printer,
626 ArrayRef<DimLevelType> lvlTypes) const {
627 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
628 map.getResult(i).print(printer.getStream());
629 printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
631 if (map.getNumResults() >= 1) {
632 auto lastIndex = map.getNumResults() - 1;
633 map.getResult(lastIndex).print(printer.getStream());
634 printer << " : " << toMLIRString(lvlTypes[lastIndex]);
638 LogicalResult
639 SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
640 ArrayRef<DimLevelType> lvlTypes,
641 AffineMap dimToLvl, AffineMap lvlToDim,
642 unsigned posWidth, unsigned crdWidth,
643 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
644 if (!acceptBitWidth(posWidth))
645 return emitError() << "unexpected position bitwidth: " << posWidth;
646 if (!acceptBitWidth(crdWidth))
647 return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
648 if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonDLT);
649 it != std::end(lvlTypes)) {
650 if (it == lvlTypes.begin() ||
651 (!isCompressedDLT(*(it - 1)) && !isLooseCompressedDLT(*(it - 1))))
652 return emitError() << "expected compressed or loose_compressed level "
653 "before singleton level";
654 if (!std::all_of(it, lvlTypes.end(),
655 [](DimLevelType i) { return isSingletonDLT(i); }))
656 return emitError() << "expected all singleton lvlTypes "
657 "following a singleton level";
659 // Before we can check that the level-rank is consistent/coherent
660 // across all fields, we need to define it. The source-of-truth for
661 // the `getLvlRank` method is the length of the level-types array,
662 // since it must always be provided and have full rank; therefore we
663 // use that same source-of-truth here.
664 const Level lvlRank = lvlTypes.size();
665 if (lvlRank == 0)
666 return emitError() << "expected a non-empty array for lvlTypes";
667 // We save `dimRank` here because we'll also need it to verify `dimSlices`.
668 const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
669 if (dimToLvl) {
670 if (dimToLvl.getNumResults() != lvlRank)
671 return emitError()
672 << "level-rank mismatch between dimToLvl and lvlTypes: "
673 << dimToLvl.getNumResults() << " != " << lvlRank;
674 auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
675 // Symbols can't be inferred but are acceptable.
676 if (!inferRes && dimToLvl.getNumSymbols() == 0)
677 return emitError() << "failed to infer lvlToDim from dimToLvl";
678 if (lvlToDim && (inferRes != lvlToDim))
679 return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
680 if (dimRank > lvlRank)
681 return emitError() << "unexpected dimToLvl mapping from " << dimRank
682 << " to " << lvlRank;
684 if (!dimSlices.empty()) {
685 if (dimSlices.size() != dimRank)
686 return emitError()
687 << "dimension-rank mismatch between dimSlices and dimToLvl: "
688 << dimSlices.size() << " != " << dimRank;
689 // Compiler support for `dimSlices` currently requires that the two
690 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
691 if (dimRank != lvlRank)
692 return emitError()
693 << "dimSlices expected dimension-rank to match level-rank: "
694 << dimRank << " != " << lvlRank;
696 return success();
699 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
700 ArrayRef<Size> dimShape, Type elementType,
701 function_ref<InFlightDiagnostic()> emitError) const {
702 // Check structural integrity. In particular, this ensures that the
703 // level-rank is coherent across all the fields.
704 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
705 getPosWidth(), getCrdWidth(), getDimSlices())))
706 return failure();
707 // Check integrity with tensor type specifics. In particular, we
708 // need only check that the dimension-rank of the tensor agrees with
709 // the dimension-rank of the encoding.
710 const Dimension dimRank = dimShape.size();
711 if (dimRank == 0)
712 return emitError() << "expected non-scalar sparse tensor";
713 if (getDimRank() != dimRank)
714 return emitError()
715 << "dimension-rank mismatch between encoding and tensor shape: "
716 << getDimRank() << " != " << dimRank;
717 return success();
720 //===----------------------------------------------------------------------===//
721 // Convenience methods.
722 //===----------------------------------------------------------------------===//
724 SparseTensorEncodingAttr
725 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
726 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
727 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
728 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
729 return mdtp.getEncoding();
730 return nullptr;
733 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
734 MLIRContext *context) {
735 auto map = static_cast<AffineMap>(dimToLvl);
736 AffineMap lvlToDim;
737 // Return an empty lvlToDim when inference is not successful.
738 if (!map || map.getNumSymbols() != 0) {
739 lvlToDim = AffineMap();
740 } else if (map.isPermutation()) {
741 lvlToDim = inversePermutation(map);
742 } else if (isBlockSparsity(map)) {
743 lvlToDim = inverseBlockSparsity(map, context);
745 return lvlToDim;
748 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
749 MLIRContext *context) {
750 SmallVector<AffineExpr> lvlExprs;
751 auto numLvls = dimToLvl.getNumResults();
752 lvlExprs.reserve(numLvls);
753 // lvlExprComponents stores information of the floordiv and mod operations
754 // applied to the same dimension, so as to build the lvlToDim map.
755 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
756 for (unsigned i = 0, n = numLvls; i < n; i++) {
757 auto result = dimToLvl.getResult(i);
758 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
759 if (result.getKind() == AffineExprKind::FloorDiv) {
760 // Position of the dimension in dimToLvl.
761 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
762 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
763 "expected only one floordiv for each dimension");
764 SmallVector<AffineExpr, 3> components;
765 // Level variable for floordiv.
766 components.push_back(getAffineDimExpr(i, context));
767 // Multiplier.
768 components.push_back(binOp.getRHS());
769 // Map key is the position of the dimension.
770 lvlExprComponents[pos] = components;
771 } else if (result.getKind() == AffineExprKind::Mod) {
772 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
773 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
774 "expected floordiv before mod");
775 // Add level variable for mod to the same vector
776 // of the corresponding floordiv.
777 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
778 } else {
779 assert(false && "expected floordiv or mod");
781 } else {
782 lvlExprs.push_back(getAffineDimExpr(i, context));
785 // Build lvlExprs from lvlExprComponents.
786 // For example, for il = i floordiv 2 and ii = i mod 2, the components
787 // would be [il, 2, ii]. It could be used to build the AffineExpr
788 // i = il * 2 + ii in lvlToDim.
789 for (auto &components : lvlExprComponents) {
790 assert(components.second.size() == 3 &&
791 "expected 3 components to build lvlExprs");
792 auto mulOp = getAffineBinaryOpExpr(
793 AffineExprKind::Mul, components.second[0], components.second[1]);
794 auto addOp =
795 getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
796 lvlExprs.push_back(addOp);
798 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
801 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
802 assert(isBlockSparsity(dimToLvl) &&
803 "expected dimToLvl to be block sparsity for calling getBlockSize");
804 SmallVector<unsigned> blockSize;
805 for (auto result : dimToLvl.getResults()) {
806 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
807 if (result.getKind() == AffineExprKind::Mod) {
808 blockSize.push_back(
809 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue());
811 } else {
812 blockSize.push_back(0);
815 return blockSize;
818 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
819 if (!dimToLvl)
820 return false;
821 std::map<unsigned, int64_t> coeffientMap;
822 for (auto result : dimToLvl.getResults()) {
823 if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
824 auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
825 if (result.getKind() == AffineExprKind::FloorDiv) {
826 // Expect only one floordiv for each dimension.
827 if (coeffientMap.find(pos) != coeffientMap.end())
828 return false;
829 coeffientMap[pos] =
830 dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue();
831 } else if (result.getKind() == AffineExprKind::Mod) {
832 // Expect floordiv before mod.
833 if (coeffientMap.find(pos) == coeffientMap.end())
834 return false;
835 // Expect mod to have the same coefficient as floordiv.
836 if (dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue() !=
837 coeffientMap[pos]) {
838 return false;
840 } else {
841 return false;
845 return !coeffientMap.empty();
848 bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
849 Level startLvl, bool isUnique) {
850 if (!enc ||
851 !(enc.isCompressedLvl(startLvl) || enc.isLooseCompressedLvl(startLvl)))
852 return false;
853 const Level lvlRank = enc.getLvlRank();
854 for (Level l = startLvl + 1; l < lvlRank; ++l)
855 if (!enc.isSingletonLvl(l))
856 return false;
857 // If isUnique is true, then make sure that the last level is unique,
858 // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1
859 // (unique on the last singleton).
860 return !isUnique || enc.isUniqueLvl(lvlRank - 1);
863 bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
864 return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
867 bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
868 auto hasNonIdentityMap = [](Value v) {
869 auto stt = tryGetSparseTensorType(v);
870 return stt && !stt->isIdentity();
873 return llvm::any_of(op->getOperands(), hasNonIdentityMap) ||
874 llvm::any_of(op->getResults(), hasNonIdentityMap);
877 Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
878 // We only consider COO region with at least two levels for the purpose
879 // of AOS storage optimization.
880 const Level lvlRank = enc.getLvlRank();
881 if (lvlRank > 1)
882 for (Level l = 0; l < lvlRank - 1; l++)
883 if (isCOOType(enc, l, /*isUnique=*/false))
884 return l;
885 return lvlRank;
888 // Helpers to setup a COO type.
889 RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
890 AffineMap lvlPerm,
891 bool ordered) {
892 const SparseTensorType src(rtt);
893 const Level lvlRank = src.getLvlRank();
894 SmallVector<DimLevelType> lvlTypes;
895 lvlTypes.reserve(lvlRank);
897 // An unordered and non-unique compressed level at beginning.
898 // If this is also the last level, then it is unique.
899 lvlTypes.push_back(
900 *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
901 if (lvlRank > 1) {
902 // TODO: it is actually ordered at the level for ordered input.
903 // Followed by unordered non-unique n-2 singleton levels.
904 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
905 *buildLevelType(LevelFormat::Singleton, ordered, false));
906 // Ends by a unique singleton level unless the lvlRank is 1.
907 lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
910 // TODO: Maybe pick the bitwidth based on input/output tensors (probably the
911 // largest one among them) in the original operation instead of using the
912 // default value.
913 unsigned posWidth = src.getPosWidth();
914 unsigned crdWidth = src.getCrdWidth();
915 AffineMap invPerm = src.getLvlToDim();
916 auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
917 invPerm, posWidth, crdWidth);
918 return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
921 RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
922 bool ordered) {
923 return getCOOFromTypeWithOrdering(
924 src, AffineMap::getMultiDimIdentityMap(src.getRank(), src.getContext()),
925 ordered);
928 // TODO: Remove this definition once all use-sites have been fixed to
929 // properly handle non-permutations.
930 Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
931 Level l) {
932 if (enc) {
933 if (const auto dimToLvl = enc.getDimToLvl()) {
934 assert(enc.isPermutation());
935 return dimToLvl.getDimPosition(l);
938 return l;
941 // TODO: Remove this definition once all use-sites have been fixed to
942 // properly handle non-permutations.
943 Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
944 Dimension d) {
945 if (enc) {
946 if (const auto dimToLvl = enc.getDimToLvl()) {
947 assert(enc.isPermutation());
948 auto maybePos =
949 dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
950 assert(maybePos.has_value());
951 return *maybePos;
954 return d;
957 /// We normalized sparse tensor encoding attribute by always using
958 /// ordered/unique DLT such that "compressed_nu_no" and "compressed_nu" (as well
959 /// as other variants) lead to the same storage specifier type, and stripping
960 /// irrelevant fields that do not alter the sparse tensor memory layout.
961 static SparseTensorEncodingAttr
962 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
963 SmallVector<DimLevelType> dlts;
964 for (auto dlt : enc.getLvlTypes())
965 dlts.push_back(*buildLevelType(*getLevelFormat(dlt), true, true));
967 return SparseTensorEncodingAttr::get(
968 enc.getContext(), dlts,
969 AffineMap(), // dimToLvl (irrelevant to storage specifier)
970 AffineMap(), // lvlToDim (irrelevant to storage specifier)
971 // Always use `index` for memSize and lvlSize instead of reusing
972 // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
973 // value for different bitwidth, it also avoids casting between index and
974 // integer (returned by DimOp)
975 0, 0, enc.getDimSlices());
978 StorageSpecifierType
979 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
980 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
983 //===----------------------------------------------------------------------===//
984 // SparseTensorDialect Operations.
985 //===----------------------------------------------------------------------===//
987 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
988 return success(lvl < getSparseTensorType(tensor).getLvlRank());
991 static LogicalResult isMatchingWidth(Value mem, unsigned width) {
992 const Type etp = getMemRefType(mem).getElementType();
993 return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
996 static LogicalResult verifySparsifierGetterSetter(
997 StorageSpecifierKind mdKind, std::optional<Level> lvl,
998 TypedValue<StorageSpecifierType> md, Operation *op) {
999 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1000 return op->emitError(
1001 "redundant level argument for querying value memory size");
1004 const auto enc = md.getType().getEncoding();
1005 const Level lvlRank = enc.getLvlRank();
1007 if (mdKind == StorageSpecifierKind::DimOffset ||
1008 mdKind == StorageSpecifierKind::DimStride)
1009 if (!enc.isSlice())
1010 return op->emitError("requested slice data on non-slice tensor");
1012 if (mdKind != StorageSpecifierKind::ValMemSize) {
1013 if (!lvl)
1014 return op->emitError("missing level argument");
1016 const Level l = lvl.value();
1017 if (l >= lvlRank)
1018 return op->emitError("requested level is out of bounds");
1020 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1021 return op->emitError(
1022 "requested position memory size on a singleton level");
1024 return success();
1027 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
1028 switch (kind) {
1029 case SparseTensorFieldKind::CrdMemRef:
1030 return stt.getCrdType();
1031 case SparseTensorFieldKind::PosMemRef:
1032 return stt.getPosType();
1033 case SparseTensorFieldKind::ValMemRef:
1034 return stt.getElementType();
1035 case SparseTensorFieldKind::StorageSpec:
1036 return nullptr;
1038 llvm_unreachable("Unrecognizable FieldKind");
1041 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1042 SparseTensorType stt,
1043 RankedTensorType valTp,
1044 TypeRange lvlTps) {
1045 if (requiresStaticShape && !stt.hasStaticDimShape())
1046 return op->emitError("the sparse-tensor must have static shape");
1047 if (!stt.hasEncoding())
1048 return op->emitError("the sparse-tensor must have an encoding attribute");
1049 if (!stt.isIdentity())
1050 return op->emitError("the sparse-tensor must have the identity mapping");
1052 // Verifies the trailing COO.
1053 Level cooStartLvl = getCOOStart(stt.getEncoding());
1054 if (cooStartLvl < stt.getLvlRank()) {
1055 // We only supports trailing COO for now, must be the last input.
1056 auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1057 // The coordinates should be in shape of <? x rank>
1058 unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1059 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1060 op->emitError("input/output trailing COO level-ranks don't match");
1064 // Verifies that all types match.
1065 StorageLayout layout(stt.getEncoding());
1066 if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
1067 return op->emitError("inconsistent number of fields between input/output");
1069 unsigned idx = 0;
1070 bool misMatch = false;
1071 layout.foreachField([&idx, &misMatch, stt, valTp,
1072 lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
1073 Level lvl, DimLevelType dlt) -> bool {
1074 if (fKind == SparseTensorFieldKind::StorageSpec)
1075 return true;
1077 Type inputTp = nullptr;
1078 if (fKind == SparseTensorFieldKind::ValMemRef) {
1079 inputTp = valTp;
1080 } else {
1081 assert(fid == idx && stt.getLvlType(lvl) == dlt);
1082 inputTp = lvlTps[idx++];
1084 // The input element type and expected element type should match.
1085 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1086 Type expElemTp = getFieldElemType(stt, fKind);
1087 if (inpElemTp != expElemTp) {
1088 misMatch = true;
1089 return false; // to terminate the iteration
1091 return true;
1094 if (misMatch)
1095 return op->emitError("input/output element-types don't match");
1096 return success();
1099 LogicalResult AssembleOp::verify() {
1100 const auto valuesTp = getRankedTensorType(getValues());
1101 const auto lvlsTp = getLevels().getTypes();
1102 const auto resTp = getSparseTensorType(getResult());
1103 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1106 LogicalResult DisassembleOp::verify() {
1107 if (getOutValues().getType() != getRetValues().getType())
1108 return emitError("output values and return value type mismatch");
1110 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1111 if (ot.getType() != rt.getType())
1112 return emitError("output levels and return levels type mismatch");
1114 const auto valuesTp = getRankedTensorType(getRetValues());
1115 const auto lvlsTp = getRetLevels().getTypes();
1116 const auto srcTp = getSparseTensorType(getTensor());
1117 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1120 LogicalResult ConvertOp::verify() {
1121 if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
1122 if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
1123 if (tp1.getRank() != tp2.getRank())
1124 return emitError("unexpected conversion mismatch in rank");
1125 auto dstEnc =
1126 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1127 if (dstEnc && dstEnc.isSlice())
1128 return emitError("cannot convert to a sparse tensor slice");
1130 auto shape1 = tp1.getShape();
1131 auto shape2 = tp2.getShape();
1132 // Accept size matches between the source and the destination type
1133 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1134 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1135 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1136 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1137 return emitError("unexpected conversion mismatch in dimension ") << d;
1138 return success();
1141 return emitError("unexpected type in convert");
1144 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1145 if (getType() == getSource().getType())
1146 return getSource();
1147 return {};
1150 bool ConvertOp::needsExtraSort() {
1151 SparseTensorType srcStt = getSparseTensorType(getSource());
1152 SparseTensorType dstStt = getSparseTensorType(getDest());
1154 // We do not need an extra sort when returning unordered sparse tensors or
1155 // dense tensor since dense tensor support random access.
1156 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1157 return false;
1159 if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1160 srcStt.hasSameDimToLvl(dstStt)) {
1161 return false;
1164 // Source and dest tensors are ordered in different ways. We only do direct
1165 // dense to sparse conversion when the dense input is defined by a sparse
1166 // constant. Note that we can theoretically always directly convert from dense
1167 // inputs by rotating dense loops but it leads to bad cache locality and hurt
1168 // performance.
1169 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1170 if (isa<SparseElementsAttr>(constOp.getValue()))
1171 return false;
1173 return true;
1176 LogicalResult CrdTranslateOp::verify() {
1177 uint64_t inRank = getEncoder().getLvlRank();
1178 uint64_t outRank = getEncoder().getDimRank();
1180 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1181 std::swap(inRank, outRank);
1183 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1184 return emitError("Coordinate rank mismatch with encoding");
1186 return success();
1189 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1190 SmallVectorImpl<OpFoldResult> &results) {
1191 if (getEncoder().isIdentity()) {
1192 results.assign(getInCrds().begin(), getInCrds().end());
1193 return success();
1195 if (getEncoder().isPermutation()) {
1196 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1197 ? getEncoder().getDimToLvl()
1198 : getEncoder().getLvlToDim();
1199 for (AffineExpr exp : perm.getResults())
1200 results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]);
1201 return success();
1204 // Fuse dim2lvl/lvl2dim pairs.
1205 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1206 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1207 return v.getDefiningOp() == def;
1209 if (!sameDef)
1210 return failure();
1212 bool oppositeDir = def.getDirection() != getDirection();
1213 bool sameOracle =
1214 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1215 bool sameCount = def.getNumResults() == getInCrds().size();
1216 if (!oppositeDir || !sameOracle || !sameCount)
1217 return failure();
1219 // The definition produces the coordinates in the same order as the input
1220 // coordinates.
1221 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1222 [](auto valuePair) {
1223 auto [lhs, rhs] = valuePair;
1224 return lhs == rhs;
1227 if (!sameOrder)
1228 return failure();
1229 // l1 = dim2lvl (lvl2dim l0)
1230 // ==> l0
1231 results.append(def.getInCrds().begin(), def.getInCrds().end());
1232 return success();
1235 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1236 int64_t index) {
1237 Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
1238 return build(builder, state, source, val);
1241 LogicalResult LvlOp::verify() {
1242 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1243 auto stt = getSparseTensorType(getSource());
1244 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1245 emitError("Level index exceeds the rank of the input sparse tensor");
1247 return success();
1250 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1251 return getConstantIntValue(getIndex());
1254 Speculation::Speculatability LvlOp::getSpeculatability() {
1255 auto constantIndex = getConstantLvlIndex();
1256 if (!constantIndex)
1257 return Speculation::NotSpeculatable;
1259 assert(constantIndex <
1260 cast<RankedTensorType>(getSource().getType()).getRank());
1261 return Speculation::Speculatable;
1264 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1265 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1266 if (!lvlIndex)
1267 return {};
1269 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1270 auto stt = getSparseTensorType(getSource());
1271 if (lvl >= stt.getLvlRank()) {
1272 // Follows the same convention used by tensor.dim operation. Out of bound
1273 // indices produce undefined behavior but are still valid IR. Don't choke on
1274 // them.
1275 return {};
1278 // Helper lambda to build an IndexAttr.
1279 auto getIndexAttr = [this](int64_t lvlSz) {
1280 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1283 SmallVector<Size> lvlShape = stt.getLvlShape();
1284 if (!ShapedType::isDynamic(lvlShape[lvl]))
1285 return getIndexAttr(lvlShape[lvl]);
1287 return {};
1290 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1291 SparseTensorEncodingAttr dstEnc, Value source) {
1292 auto srcStt = getSparseTensorType(source);
1293 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1294 SmallVector<int64_t> dstDimShape =
1295 dstEnc.tranlateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1296 auto dstTp =
1297 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1298 return build(odsBuilder, odsState, dstTp, source);
1301 LogicalResult ReinterpretMapOp::verify() {
1302 auto srcStt = getSparseTensorType(getSource());
1303 auto dstStt = getSparseTensorType(getDest());
1304 ArrayRef<DimLevelType> srcLvlTps = srcStt.getLvlTypes();
1305 ArrayRef<DimLevelType> dstLvlTps = dstStt.getLvlTypes();
1307 if (srcLvlTps.size() != dstLvlTps.size())
1308 return emitError("Level rank mismatch between source/dest tensors");
1310 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1311 if (srcLvlTp != dstLvlTp)
1312 return emitError("Level type mismatch between source/dest tensors");
1314 if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1315 srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1316 return emitError("Crd/Pos width mismatch between source/dest tensors");
1319 if (srcStt.getElementType() != dstStt.getElementType())
1320 return emitError("Element type mismatch between source/dest tensors");
1322 SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1323 SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1324 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1325 if (srcLvlSz != dstLvlSz) {
1326 // Should we allow one side to be dynamic size, e.g., <?x?> should be
1327 // compatible to <3x4>? For now, we require all the level sizes to be
1328 // *exactly* matched for simplicity.
1329 return emitError("Level size mismatch between source/dest tensors");
1333 return success();
1336 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1337 if (getSource().getType() == getDest().getType())
1338 return getSource();
1340 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1341 // A -> B, B -> A ==> A
1342 if (def.getSource().getType() == getDest().getType())
1343 return def.getSource();
1345 return {};
1348 LogicalResult ToPositionsOp::verify() {
1349 auto e = getSparseTensorEncoding(getTensor().getType());
1350 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1351 return emitError("requested level is out of bounds");
1352 if (failed(isMatchingWidth(getResult(), e.getPosWidth())))
1353 return emitError("unexpected type for positions");
1354 return success();
1357 LogicalResult ToCoordinatesOp::verify() {
1358 auto e = getSparseTensorEncoding(getTensor().getType());
1359 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1360 return emitError("requested level is out of bounds");
1361 if (failed(isMatchingWidth(getResult(), e.getCrdWidth())))
1362 return emitError("unexpected type for coordinates");
1363 return success();
1366 LogicalResult ToCoordinatesBufferOp::verify() {
1367 auto e = getSparseTensorEncoding(getTensor().getType());
1368 if (getCOOStart(e) >= e.getLvlRank())
1369 return emitError("expected sparse tensor with a COO region");
1370 return success();
1373 LogicalResult ToValuesOp::verify() {
1374 auto ttp = getRankedTensorType(getTensor());
1375 auto mtp = getMemRefType(getResult());
1376 if (ttp.getElementType() != mtp.getElementType())
1377 return emitError("unexpected mismatch in element types");
1378 return success();
1381 LogicalResult ToSliceOffsetOp::verify() {
1382 auto rank = getRankedTensorType(getSlice()).getRank();
1383 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1384 return emitError("requested dimension out of bound");
1385 return success();
1388 LogicalResult ToSliceStrideOp::verify() {
1389 auto rank = getRankedTensorType(getSlice()).getRank();
1390 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1391 return emitError("requested dimension out of bound");
1392 return success();
1395 LogicalResult GetStorageSpecifierOp::verify() {
1396 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1397 getSpecifier(), getOperation());
1400 template <typename SpecifierOp>
1401 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1402 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1405 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1406 const StorageSpecifierKind kind = getSpecifierKind();
1407 const auto lvl = getLevel();
1408 for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1409 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1410 return op.getValue();
1411 return {};
1414 LogicalResult SetStorageSpecifierOp::verify() {
1415 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1416 getSpecifier(), getOperation());
1419 template <class T>
1420 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1421 const char *regionName,
1422 TypeRange inputTypes, Type outputType) {
1423 unsigned numArgs = region.getNumArguments();
1424 unsigned expectedNum = inputTypes.size();
1425 if (numArgs != expectedNum)
1426 return op->emitError() << regionName << " region must have exactly "
1427 << expectedNum << " arguments";
1429 for (unsigned i = 0; i < numArgs; i++) {
1430 Type typ = region.getArgument(i).getType();
1431 if (typ != inputTypes[i])
1432 return op->emitError() << regionName << " region argument " << (i + 1)
1433 << " type mismatch";
1435 Operation *term = region.front().getTerminator();
1436 YieldOp yield = dyn_cast<YieldOp>(term);
1437 if (!yield)
1438 return op->emitError() << regionName
1439 << " region must end with sparse_tensor.yield";
1440 if (!yield.getResult() || yield.getResult().getType() != outputType)
1441 return op->emitError() << regionName << " region yield type mismatch";
1443 return success();
1446 LogicalResult BinaryOp::verify() {
1447 NamedAttrList attrs = (*this)->getAttrs();
1448 Type leftType = getX().getType();
1449 Type rightType = getY().getType();
1450 Type outputType = getOutput().getType();
1451 Region &overlap = getOverlapRegion();
1452 Region &left = getLeftRegion();
1453 Region &right = getRightRegion();
1455 // Check correct number of block arguments and return type for each
1456 // non-empty region.
1457 if (!overlap.empty()) {
1458 if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1459 TypeRange{leftType, rightType}, outputType)))
1460 return failure();
1462 if (!left.empty()) {
1463 if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1464 outputType)))
1465 return failure();
1466 } else if (getLeftIdentity()) {
1467 if (leftType != outputType)
1468 return emitError("left=identity requires first argument to have the same "
1469 "type as the output");
1471 if (!right.empty()) {
1472 if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1473 outputType)))
1474 return failure();
1475 } else if (getRightIdentity()) {
1476 if (rightType != outputType)
1477 return emitError("right=identity requires second argument to have the "
1478 "same type as the output");
1480 return success();
1483 LogicalResult UnaryOp::verify() {
1484 Type inputType = getX().getType();
1485 Type outputType = getOutput().getType();
1487 // Check correct number of block arguments and return type for each
1488 // non-empty region.
1489 Region &present = getPresentRegion();
1490 if (!present.empty()) {
1491 if (failed(verifyNumBlockArgs(this, present, "present",
1492 TypeRange{inputType}, outputType)))
1493 return failure();
1495 Region &absent = getAbsentRegion();
1496 if (!absent.empty()) {
1497 if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1498 outputType)))
1499 return failure();
1500 // Absent branch can only yield invariant values.
1501 Block *absentBlock = &absent.front();
1502 Block *parent = getOperation()->getBlock();
1503 Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
1504 if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1505 if (arg.getOwner() == parent)
1506 return emitError("absent region cannot yield linalg argument");
1507 } else if (Operation *def = absentVal.getDefiningOp()) {
1508 if (!isa<arith::ConstantOp>(def) &&
1509 (def->getBlock() == absentBlock || def->getBlock() == parent))
1510 return emitError("absent region cannot yield locally computed value");
1513 return success();
1516 bool ConcatenateOp::needsExtraSort() {
1517 SparseTensorType dstStt = getSparseTensorType(*this);
1518 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1519 return false;
1521 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1522 return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1524 // TODO: When conDim != 0, as long as conDim corresponding to the first level
1525 // in all input/output buffers, and all input/output buffers have the same
1526 // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
1527 // CSC matrices along column).
1528 bool directLowerable =
1529 allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1530 return !directLowerable;
/// Verifies a sparse tensor concatenation:
/// - at least two inputs;
/// - the concatenation dimension is below the output dimension rank;
/// - every input is statically shaped and has the output's dimension rank;
/// - along the concatenation dimension, a static output size equals the sum
///   of the input sizes;
/// - along every other dimension, all inputs (and a static output) agree.
LogicalResult ConcatenateOp::verify() {
  const auto dstTp = getSparseTensorType(*this);
  const Dimension concatDim = getDimension();
  const Dimension dimRank = dstTp.getDimRank();

  if (getInputs().size() <= 1)
    return emitError("Need at least two tensors to concatenate.");

  if (concatDim >= dimRank)
    return emitError(llvm::formatv(
        "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
        concatDim, dimRank));

  // Every input must be statically shaped and have the output's rank.
  for (const auto &it : llvm::enumerate(getInputs())) {
    const auto i = it.index();
    const auto srcTp = getSparseTensorType(it.value());
    if (srcTp.hasDynamicDimShape())
      return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
    const Dimension srcDimRank = srcTp.getDimRank();
    if (srcDimRank != dimRank)
      return emitError(
          llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
                        "from the output tensor (rank={2}).",
                        i, srcDimRank, dimRank));
  }

  for (Dimension d = 0; d < dimRank; d++) {
    const Size dstSh = dstTp.getDimShape()[d];
    if (d == concatDim) {
      if (!ShapedType::isDynamic(dstSh)) {
        // If we reach here, then all inputs have static shapes. So we
        // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
        // to avoid redundant assertions in the loop.
        Size sumSz = 0;
        for (const auto src : getInputs())
          sumSz += getSparseTensorType(src).getDimShape()[d];
        // If all dimension are statically known, the sum of all the input
        // dimensions should be equal to the output dimension.
        if (sumSz != dstSh)
          return emitError(
              "The concatenation dimension of the output tensor should be the "
              "sum of all the concatenation dimensions of the input tensors.");
      }
    } else {
      // Each input size is compared with the previous one, starting from the
      // output size. A dynamic output size is skipped as a reference;
      // the first (static) input size then becomes the reference.
      Size prev = dstSh;
      for (const auto src : getInputs()) {
        const auto sh = getSparseTensorType(src).getDimShape()[d];
        // NOTE(review): "expect" in the message below is presumably a typo
        // for "except"; left unchanged since tests may match the exact text.
        if (!ShapedType::isDynamic(prev) && sh != prev)
          return emitError("All dimensions (expect for the concatenating one) "
                           "should be equal.");
        prev = sh;
      }
    }
  }

  return success();
}
1591 LogicalResult InsertOp::verify() {
1592 const auto stt = getSparseTensorType(getTensor());
1593 if (stt.getLvlRank() != static_cast<Level>(getLvlCoords().size()))
1594 return emitOpError("incorrect number of coordinates");
1595 return success();
1598 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1599 Value curSize, Value inBuffer, Value value) {
1600 build(builder, result, curSize, inBuffer, value, Value());
1603 LogicalResult PushBackOp::verify() {
1604 if (Value n = getN()) {
1605 std::optional<int64_t> nValue = getConstantIntValue(n);
1606 if (nValue && nValue.value() < 1)
1607 return emitOpError("n must be not less than 1");
1609 return success();
1612 LogicalResult CompressOp::verify() {
1613 const auto stt = getSparseTensorType(getTensor());
1614 if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1615 return emitOpError("incorrect number of coordinates");
1616 return success();
1619 void ForeachOp::build(
1620 OpBuilder &builder, OperationState &result, Value tensor,
1621 ValueRange initArgs, AffineMapAttr order,
1622 function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1623 bodyBuilder) {
1624 build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1625 // Builds foreach body.
1626 if (!bodyBuilder)
1627 return;
1628 const auto stt = getSparseTensorType(tensor);
1629 const Dimension dimRank = stt.getDimRank();
1631 // Starts with `dimRank`-many coordinates.
1632 SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1633 // Followed by one value.
1634 blockArgTypes.push_back(stt.getElementType());
1635 // Followed by the reduction variables.
1636 blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1638 SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1640 OpBuilder::InsertionGuard guard(builder);
1641 auto &region = *result.regions.front();
1642 Block *bodyBlock =
1643 builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1644 bodyBuilder(builder, result.location,
1645 bodyBlock->getArguments().slice(0, dimRank),
1646 bodyBlock->getArguments()[dimRank],
1647 bodyBlock->getArguments().drop_front(dimRank + 1));
1650 LogicalResult ForeachOp::verify() {
1651 const auto t = getSparseTensorType(getTensor());
1652 const Dimension dimRank = t.getDimRank();
1653 const auto args = getBody()->getArguments();
1655 if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank())
1656 return emitError("Level traverse order does not match tensor's level rank");
1658 if (dimRank + 1 + getInitArgs().size() != args.size())
1659 return emitError("Unmatched number of arguments in the block");
1661 if (getNumResults() != getInitArgs().size())
1662 return emitError("Mismatch in number of init arguments and results");
1664 if (getResultTypes() != getInitArgs().getTypes())
1665 return emitError("Mismatch in types of init arguments and results");
1667 // Cannot mark this const, because the getters aren't.
1668 auto yield = cast<YieldOp>(getBody()->getTerminator());
1669 if (yield.getNumOperands() != getNumResults() ||
1670 yield.getOperands().getTypes() != getResultTypes())
1671 return emitError("Mismatch in types of yield values and results");
1673 const auto iTp = IndexType::get(getContext());
1674 for (Dimension d = 0; d < dimRank; d++)
1675 if (args[d].getType() != iTp)
1676 emitError(
1677 llvm::formatv("Expecting Index type for argument at index {0}", d));
1679 const auto elemTp = t.getElementType();
1680 const auto valueTp = args[dimRank].getType();
1681 if (elemTp != valueTp)
1682 emitError(llvm::formatv("Unmatched element type between input tensor and "
1683 "block argument, expected:{0}, got: {1}",
1684 elemTp, valueTp));
1685 return success();
1688 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1689 if (getSparseTensorEncoding(getInputCoo().getType()) ==
1690 getSparseTensorEncoding(getResultCoo().getType()))
1691 return getInputCoo();
1693 return {};
1696 LogicalResult ReorderCOOOp::verify() {
1697 SparseTensorType srcStt = getSparseTensorType(getInputCoo());
1698 SparseTensorType dstStt = getSparseTensorType(getResultCoo());
1700 if (!srcStt.hasSameDimToLvl(dstStt))
1701 emitError("Unmatched dim2lvl map between input and result COO");
1703 if (srcStt.getPosType() != dstStt.getPosType() ||
1704 srcStt.getCrdType() != dstStt.getCrdType() ||
1705 srcStt.getElementType() != dstStt.getElementType()) {
1706 emitError("Unmatched storage format between input and result COO");
1708 return success();
1711 LogicalResult ReduceOp::verify() {
1712 Type inputType = getX().getType();
1713 Region &formula = getRegion();
1714 return verifyNumBlockArgs(this, formula, "reduce",
1715 TypeRange{inputType, inputType}, inputType);
1718 LogicalResult SelectOp::verify() {
1719 Builder b(getContext());
1720 Type inputType = getX().getType();
1721 Type boolType = b.getI1Type();
1722 Region &formula = getRegion();
1723 return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
1724 boolType);
1727 LogicalResult SortOp::verify() {
1728 AffineMap xPerm = getPermMap();
1729 uint64_t nx = xPerm.getNumDims();
1730 if (nx < 1)
1731 emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
1733 if (!xPerm.isPermutation())
1734 emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
1736 // We can't check the size of the buffers when n or buffer dimensions aren't
1737 // compile-time constants.
1738 std::optional<int64_t> cn = getConstantIntValue(getN());
1739 if (!cn)
1740 return success();
1742 // Verify dimensions.
1743 const auto checkDim = [&](Value v, Size minSize, const char *message) {
1744 const Size sh = getMemRefType(v).getShape()[0];
1745 if (!ShapedType::isDynamic(sh) && sh < minSize)
1746 emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
1748 uint64_t n = cn.value();
1749 uint64_t ny = 0;
1750 if (auto nyAttr = getNyAttr())
1751 ny = nyAttr.getInt();
1752 checkDim(getXy(), n * (nx + ny),
1753 "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
1754 for (Value opnd : getYs())
1755 checkDim(opnd, n, "Expected dimension(y) >= n");
1757 return success();
1760 LogicalResult YieldOp::verify() {
1761 // Check for compatible parent.
1762 auto *parentOp = (*this)->getParentOp();
1763 if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
1764 isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
1765 isa<ForeachOp>(parentOp))
1766 return success();
1768 return emitOpError("expected parent op to be sparse_tensor unary, binary, "
1769 "reduce, select or foreach");
1772 /// Materialize a single constant operation from a given attribute value with
1773 /// the desired resultant type.
1774 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
1775 Attribute value, Type type,
1776 Location loc) {
1777 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
1778 return op;
1779 return nullptr;
1782 namespace {
1783 struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface {
1784 using OpAsmDialectInterface::OpAsmDialectInterface;
1786 AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
1787 if (attr.isa<SparseTensorEncodingAttr>()) {
1788 os << "sparse";
1789 return AliasResult::OverridableAlias;
1791 return AliasResult::NoAlias;
1794 } // namespace
/// Registers the dialect's assembly-alias interface together with all
/// attributes, types, and operations listed in the TableGen-generated
/// `.cpp.inc` files.
void SparseTensorDialect::initialize() {
  addInterface<SparseTensorAsmDialectInterface>();
  // Each GET_*_LIST macro makes the following include expand to a
  // comma-separated list of the generated classes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}
1812 #define GET_OP_CLASSES
1813 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
1815 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"