[mlir][sparse] cleanup of enums header (#71090)
[llvm-project.git] / mlir / lib / Dialect / SparseTensor / IR / SparseTensorDialect.cpp
blobc727b8d05c26d7dc046ecb245546de1c9bd390c3
1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include <utility>
11 #include "Detail/DimLvlMapParser.h"
13 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
14 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
15 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Dialect/Utils/StaticValueUtils.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/FormatVariadic.h"
28 #define GET_ATTRDEF_CLASSES
29 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
30 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
32 #define GET_TYPEDEF_CLASSES
33 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
35 using namespace mlir;
36 using namespace mlir::sparse_tensor;
38 //===----------------------------------------------------------------------===//
39 // Local convenience methods.
40 //===----------------------------------------------------------------------===//
/// Returns true iff `bitWidth` is one of the widths this file accepts for
/// position/coordinate storage: 0, 8, 16, 32, or 64.
static constexpr bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}
55 //===----------------------------------------------------------------------===//
56 // SparseTensorDialect StorageLayout.
57 //===----------------------------------------------------------------------===//
59 static constexpr Level kInvalidLevel = -1u;
60 static constexpr Level kInvalidFieldIndex = -1u;
61 static constexpr FieldIndex kDataFieldStartingIdx = 0;
63 void StorageLayout::foreachField(
64 llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
65 DimLevelType)>
66 callback) const {
67 const auto lvlTypes = enc.getLvlTypes();
68 const Level lvlRank = enc.getLvlRank();
69 const Level cooStart = getCOOStart(enc);
70 const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
71 FieldIndex fieldIdx = kDataFieldStartingIdx;
72 // Per-level storage.
73 for (Level l = 0; l < end; l++) {
74 const auto dlt = lvlTypes[l];
75 if (isDLTWithPos(dlt)) {
76 if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt)))
77 return;
79 if (isDLTWithCrd(dlt)) {
80 if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt)))
81 return;
84 // The values array.
85 if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
86 DimLevelType::Undef)))
87 return;
88 // Put metadata at the end.
89 if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
90 DimLevelType::Undef)))
91 return;
94 void sparse_tensor::foreachFieldAndTypeInSparseTensor(
95 SparseTensorType stt,
96 llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level,
97 DimLevelType)>
98 callback) {
99 assert(stt.hasEncoding());
100 // Construct the basic types.
101 const Type crdType = stt.getCrdType();
102 const Type posType = stt.getPosType();
103 const Type eltType = stt.getElementType();
105 const Type specType = StorageSpecifierType::get(stt.getEncoding());
106 // memref<? x pos> positions
107 const Type posMemType = MemRefType::get({ShapedType::kDynamic}, posType);
108 // memref<? x crd> coordinates
109 const Type crdMemType = MemRefType::get({ShapedType::kDynamic}, crdType);
110 // memref<? x eltType> values
111 const Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
113 StorageLayout(stt).foreachField(
114 [specType, posMemType, crdMemType, valMemType,
115 callback](FieldIndex fieldIdx, SparseTensorFieldKind fieldKind,
116 Level lvl, DimLevelType dlt) -> bool {
117 switch (fieldKind) {
118 case SparseTensorFieldKind::StorageSpec:
119 return callback(specType, fieldIdx, fieldKind, lvl, dlt);
120 case SparseTensorFieldKind::PosMemRef:
121 return callback(posMemType, fieldIdx, fieldKind, lvl, dlt);
122 case SparseTensorFieldKind::CrdMemRef:
123 return callback(crdMemType, fieldIdx, fieldKind, lvl, dlt);
124 case SparseTensorFieldKind::ValMemRef:
125 return callback(valMemType, fieldIdx, fieldKind, lvl, dlt);
127 llvm_unreachable("unrecognized field kind");
131 unsigned StorageLayout::getNumFields() const {
132 unsigned numFields = 0;
133 foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level,
134 DimLevelType) -> bool {
135 numFields++;
136 return true;
138 return numFields;
141 unsigned StorageLayout::getNumDataFields() const {
142 unsigned numFields = 0; // one value memref
143 foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level,
144 DimLevelType) -> bool {
145 if (fidx >= kDataFieldStartingIdx)
146 numFields++;
147 return true;
149 numFields -= 1; // the last field is StorageSpecifier
150 assert(numFields == getNumFields() - kDataFieldStartingIdx - 1);
151 return numFields;
154 std::pair<FieldIndex, unsigned>
155 StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
156 std::optional<Level> lvl) const {
157 FieldIndex fieldIdx = kInvalidFieldIndex;
158 unsigned stride = 1;
159 if (kind == SparseTensorFieldKind::CrdMemRef) {
160 assert(lvl.has_value());
161 const Level cooStart = getCOOStart(enc);
162 const Level lvlRank = enc.getLvlRank();
163 if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
164 lvl = cooStart;
165 stride = lvlRank - cooStart;
168 foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx,
169 SparseTensorFieldKind fKind, Level fLvl,
170 DimLevelType dlt) -> bool {
171 if ((lvl && fLvl == lvl.value() && kind == fKind) ||
172 (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) {
173 fieldIdx = fIdx;
174 // Returns false to break the iteration.
175 return false;
177 return true;
179 assert(fieldIdx != kInvalidFieldIndex);
180 return std::pair<FieldIndex, unsigned>(fieldIdx, stride);
183 //===----------------------------------------------------------------------===//
184 // SparseTensorDialect Attribute Methods.
185 //===----------------------------------------------------------------------===//
187 std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
188 return isDynamic(v) ? std::nullopt
189 : std::make_optional(static_cast<uint64_t>(v));
192 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const {
193 return getStatic(getOffset());
196 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const {
197 return getStatic(getStride());
200 std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const {
201 return getStatic(getSize());
204 bool SparseTensorDimSliceAttr::isCompletelyDynamic() const {
205 return isDynamic(getOffset()) && isDynamic(getStride()) &&
206 isDynamic(getSize());
209 std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) {
210 return isDynamic(v) ? "?" : std::to_string(v);
213 void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const {
214 assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr");
215 os << '(';
216 os << getStaticString(getOffset());
217 os << ", ";
218 os << getStaticString(getSize());
219 os << ", ";
220 os << getStaticString(getStride());
221 os << ')';
224 void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const {
225 print(printer.getStream());
228 static ParseResult parseOptionalStaticSlice(int64_t &result,
229 AsmParser &parser) {
230 auto parseResult = parser.parseOptionalInteger(result);
231 if (parseResult.has_value()) {
232 if (parseResult.value().succeeded() && result < 0) {
233 parser.emitError(
234 parser.getCurrentLocation(),
235 "expect positive value or ? for slice offset/size/stride");
236 return failure();
238 return parseResult.value();
241 // Else, and '?' which represented dynamic slice
242 result = SparseTensorDimSliceAttr::kDynamic;
243 return parser.parseQuestion();
246 Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) {
247 int64_t offset = kDynamic, size = kDynamic, stride = kDynamic;
249 if (failed(parser.parseLParen()) ||
250 failed(parseOptionalStaticSlice(offset, parser)) ||
251 failed(parser.parseComma()) ||
252 failed(parseOptionalStaticSlice(size, parser)) ||
253 failed(parser.parseComma()) ||
254 failed(parseOptionalStaticSlice(stride, parser)) ||
255 failed(parser.parseRParen()))
256 return {};
258 return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(),
259 offset, size, stride);
262 LogicalResult
263 SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
264 int64_t offset, int64_t size, int64_t stride) {
265 if (!isDynamic(offset) && offset < 0)
266 return emitError() << "expect non-negative value or ? for slice offset";
267 if (!isDynamic(size) && size <= 0)
268 return emitError() << "expect positive value or ? for slice size";
269 if (!isDynamic(stride) && stride <= 0)
270 return emitError() << "expect positive value or ? for slice stride";
271 return success();
274 SparseTensorEncodingAttr
275 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
276 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
277 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(), dimToLvl,
278 AffineMap(), getPosWidth(),
279 getCrdWidth());
282 SparseTensorEncodingAttr
283 SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const {
284 return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap());
287 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const {
288 return withDimToLvl(AffineMap());
291 SparseTensorEncodingAttr
292 SparseTensorEncodingAttr::withBitWidths(unsigned posWidth,
293 unsigned crdWidth) const {
294 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
295 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
296 getDimToLvl(), getLvlToDim(), posWidth,
297 crdWidth);
300 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const {
301 return withBitWidths(0, 0);
304 SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices(
305 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
306 return SparseTensorEncodingAttr::get(getContext(), getLvlTypes(),
307 getDimToLvl(), getLvlToDim(),
308 getPosWidth(), getCrdWidth(), dimSlices);
311 SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const {
312 return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{});
315 bool SparseTensorEncodingAttr::isAllDense() const {
316 return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
319 bool SparseTensorEncodingAttr::isCOO() const {
320 return getImpl() && isCOOType(*this, 0, true);
323 bool SparseTensorEncodingAttr::isAllOrdered() const {
324 return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedDLT);
327 bool SparseTensorEncodingAttr::isIdentity() const {
328 return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
331 bool SparseTensorEncodingAttr::isPermutation() const {
332 return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation();
335 Dimension SparseTensorEncodingAttr::getDimRank() const {
336 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
337 const auto dimToLvl = getDimToLvl();
338 return dimToLvl ? dimToLvl.getNumDims() : getLvlRank();
341 Level SparseTensorEncodingAttr::getLvlRank() const {
342 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
343 return getLvlTypes().size();
346 DimLevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
347 if (!getImpl())
348 return DimLevelType::Dense;
349 assert(l < getLvlRank() && "Level is out of bounds");
350 return getLvlTypes()[l];
353 bool SparseTensorEncodingAttr::isSlice() const {
354 assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
355 return !getDimSlices().empty();
358 SparseTensorDimSliceAttr
359 SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
360 assert(isSlice() && "Is not a slice");
361 const auto dimSlices = getDimSlices();
362 assert(dim < dimSlices.size() && "Dimension is out of bounds");
363 return dimSlices[dim];
366 std::optional<uint64_t>
367 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
368 return getDimSlice(dim).getStaticOffset();
371 std::optional<uint64_t>
372 SparseTensorEncodingAttr::getStaticDimSliceSize(Dimension dim) const {
373 return getDimSlice(dim).getStaticSize();
376 std::optional<uint64_t>
377 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
378 return getDimSlice(dim).getStaticStride();
381 std::optional<uint64_t>
382 SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const {
383 // FIXME: `toOrigDim` is deprecated.
384 return getStaticDimSliceOffset(toOrigDim(*this, lvl));
387 std::optional<uint64_t>
388 SparseTensorEncodingAttr::getStaticLvlSliceSize(Level lvl) const {
389 // FIXME: `toOrigDim` is deprecated.
390 return getStaticDimSliceSize(toOrigDim(*this, lvl));
393 std::optional<uint64_t>
394 SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
395 // FIXME: `toOrigDim` is deprecated.
396 return getStaticDimSliceStride(toOrigDim(*this, lvl));
399 SmallVector<int64_t>
400 SparseTensorEncodingAttr::tranlateShape(ArrayRef<int64_t> srcShape,
401 CrdTransDirectionKind dir) const {
402 if (isIdentity())
403 return SmallVector<int64_t>(srcShape);
405 SmallVector<int64_t> ret;
406 unsigned rank =
407 dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank();
408 ret.reserve(rank);
410 if (isPermutation()) {
411 for (unsigned r = 0; r < rank; r++) {
412 unsigned trans = dir == CrdTransDirectionKind::dim2lvl
413 ? toOrigDim(*this, r)
414 : toStoredDim(*this, r);
415 ret.push_back(srcShape[trans]);
417 return ret;
420 // Handle non-permutation maps.
421 AffineMap transMap =
422 dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim();
424 SmallVector<AffineExpr> dimRep;
425 dimRep.reserve(srcShape.size());
426 for (int64_t sz : srcShape) {
427 if (!ShapedType::isDynamic(sz)) {
428 // Push back the max coordinate for the given dimension/level size.
429 dimRep.push_back(getAffineConstantExpr(sz - 1, getContext()));
430 } else {
431 // A dynamic size, use a AffineDimExpr to symbolize the value.
432 dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext()));
436 for (AffineExpr exp : transMap.getResults()) {
437 // Do constant propagation on the affine map.
438 AffineExpr evalExp =
439 simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0);
440 if (auto c = evalExp.dyn_cast<AffineConstantExpr>()) {
441 ret.push_back(c.getValue() + 1);
442 } else {
443 if (auto mod = evalExp.dyn_cast<AffineBinaryOpExpr>();
444 mod && mod.getKind() == AffineExprKind::Mod) {
445 // We can still infer a static bound for expressions in form
446 // "d % constant" since d % constant \in [0, constant).
447 if (auto bound = mod.getRHS().dyn_cast<AffineConstantExpr>()) {
448 ret.push_back(bound.getValue());
449 continue;
452 ret.push_back(ShapedType::kDynamic);
455 assert(ret.size() == rank);
456 return ret;
459 ValueRange
460 SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
461 ValueRange crds,
462 CrdTransDirectionKind dir) const {
463 if (!getImpl())
464 return crds;
466 SmallVector<Type> retType(
467 dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
468 builder.getIndexType());
469 auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
470 return transOp.getOutCrds();
473 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
474 // Open "<{" part.
475 if (failed(parser.parseLess()))
476 return {};
477 if (failed(parser.parseLBrace()))
478 return {};
480 // Process the data from the parsed dictionary value into struct-like data.
481 SmallVector<DimLevelType> lvlTypes;
482 SmallVector<SparseTensorDimSliceAttr> dimSlices;
483 AffineMap dimToLvl = {};
484 AffineMap lvlToDim = {};
485 unsigned posWidth = 0;
486 unsigned crdWidth = 0;
487 StringRef attrName;
488 SmallVector<StringRef, 3> keys = {"map", "posWidth", "crdWidth"};
489 while (succeeded(parser.parseOptionalKeyword(&attrName))) {
490 // Detect admissible keyword.
491 auto *it = find(keys, attrName);
492 if (it == keys.end()) {
493 parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName;
494 return {};
496 unsigned keyWordIndex = it - keys.begin();
497 // Consume the `=` after keys
498 if (failed(parser.parseEqual()))
499 return {};
500 // Dispatch on keyword.
501 switch (keyWordIndex) {
502 case 0: { // map
503 ir_detail::DimLvlMapParser cParser(parser);
504 auto res = cParser.parseDimLvlMap();
505 if (failed(res))
506 return {};
507 const auto &dlm = *res;
509 const Level lvlRank = dlm.getLvlRank();
510 for (Level lvl = 0; lvl < lvlRank; lvl++)
511 lvlTypes.push_back(dlm.getLvlType(lvl));
513 const Dimension dimRank = dlm.getDimRank();
514 for (Dimension dim = 0; dim < dimRank; dim++)
515 dimSlices.push_back(dlm.getDimSlice(dim));
516 // NOTE: the old syntax requires an all-or-nothing approach to
517 // `dimSlices`; therefore, if any slice actually exists then we need
518 // to convert null-DSA into default/nop DSA.
519 const auto isDefined = [](SparseTensorDimSliceAttr slice) {
520 return static_cast<bool>(slice.getImpl());
522 if (llvm::any_of(dimSlices, isDefined)) {
523 const auto defaultSlice =
524 SparseTensorDimSliceAttr::get(parser.getContext());
525 for (Dimension dim = 0; dim < dimRank; dim++)
526 if (!isDefined(dimSlices[dim]))
527 dimSlices[dim] = defaultSlice;
528 } else {
529 dimSlices.clear();
532 dimToLvl = dlm.getDimToLvlMap(parser.getContext());
533 lvlToDim = dlm.getLvlToDimMap(parser.getContext());
534 break;
536 case 1: { // posWidth
537 Attribute attr;
538 if (failed(parser.parseAttribute(attr)))
539 return {};
540 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
541 if (!intAttr) {
542 parser.emitError(parser.getNameLoc(),
543 "expected an integral position bitwidth");
544 return {};
546 posWidth = intAttr.getInt();
547 break;
549 case 2: { // crdWidth
550 Attribute attr;
551 if (failed(parser.parseAttribute(attr)))
552 return {};
553 auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
554 if (!intAttr) {
555 parser.emitError(parser.getNameLoc(),
556 "expected an integral index bitwidth");
557 return {};
559 crdWidth = intAttr.getInt();
560 break;
562 } // switch
563 // Only last item can omit the comma.
564 if (parser.parseOptionalComma().failed())
565 break;
568 // Close "}>" part.
569 if (failed(parser.parseRBrace()))
570 return {};
571 if (failed(parser.parseGreater()))
572 return {};
574 // Construct struct-like storage for attribute.
575 if (!lvlToDim || lvlToDim.isEmpty()) {
576 lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
578 return parser.getChecked<SparseTensorEncodingAttr>(
579 parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
580 dimSlices);
583 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
584 auto map = static_cast<AffineMap>(getDimToLvl());
585 // Empty affine map indicates identity map
586 if (!map)
587 map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext());
588 printer << "<{ map = ";
589 printSymbols(map, printer);
590 printer << '(';
591 printDimensions(map, printer, getDimSlices());
592 printer << ") -> (";
593 printLevels(map, printer, getLvlTypes());
594 printer << ')';
595 // Print remaining members only for non-default values.
596 if (getPosWidth())
597 printer << ", posWidth = " << getPosWidth();
598 if (getCrdWidth())
599 printer << ", crdWidth = " << getCrdWidth();
600 printer << " }>";
603 void SparseTensorEncodingAttr::printSymbols(AffineMap &map,
604 AsmPrinter &printer) const {
605 if (map.getNumSymbols() == 0)
606 return;
607 printer << '[';
608 for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++)
609 printer << 's' << i << ", ";
610 if (map.getNumSymbols() >= 1)
611 printer << 's' << map.getNumSymbols() - 1;
612 printer << ']';
615 void SparseTensorEncodingAttr::printDimensions(
616 AffineMap &map, AsmPrinter &printer,
617 ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
618 if (!dimSlices.empty()) {
619 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
620 printer << 'd' << i << " : " << dimSlices[i] << ", ";
621 if (map.getNumDims() >= 1) {
622 printer << 'd' << map.getNumDims() - 1 << " : "
623 << dimSlices[map.getNumDims() - 1];
625 } else {
626 for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++)
627 printer << 'd' << i << ", ";
628 if (map.getNumDims() >= 1)
629 printer << 'd' << map.getNumDims() - 1;
633 void SparseTensorEncodingAttr::printLevels(
634 AffineMap &map, AsmPrinter &printer,
635 ArrayRef<DimLevelType> lvlTypes) const {
636 for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
637 map.getResult(i).print(printer.getStream());
638 printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
640 if (map.getNumResults() >= 1) {
641 auto lastIndex = map.getNumResults() - 1;
642 map.getResult(lastIndex).print(printer.getStream());
643 printer << " : " << toMLIRString(lvlTypes[lastIndex]);
647 LogicalResult
648 SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
649 ArrayRef<DimLevelType> lvlTypes,
650 AffineMap dimToLvl, AffineMap lvlToDim,
651 unsigned posWidth, unsigned crdWidth,
652 ArrayRef<SparseTensorDimSliceAttr> dimSlices) {
653 if (!acceptBitWidth(posWidth))
654 return emitError() << "unexpected position bitwidth: " << posWidth;
655 if (!acceptBitWidth(crdWidth))
656 return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
657 if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonDLT);
658 it != std::end(lvlTypes)) {
659 if (it == lvlTypes.begin() ||
660 (!isCompressedDLT(*(it - 1)) && !isLooseCompressedDLT(*(it - 1))))
661 return emitError() << "expected compressed or loose_compressed level "
662 "before singleton level";
663 if (!std::all_of(it, lvlTypes.end(),
664 [](DimLevelType i) { return isSingletonDLT(i); }))
665 return emitError() << "expected all singleton lvlTypes "
666 "following a singleton level";
668 // Before we can check that the level-rank is consistent/coherent
669 // across all fields, we need to define it. The source-of-truth for
670 // the `getLvlRank` method is the length of the level-types array,
671 // since it must always be provided and have full rank; therefore we
672 // use that same source-of-truth here.
673 const Level lvlRank = lvlTypes.size();
674 if (lvlRank == 0)
675 return emitError() << "expected a non-empty array for lvlTypes";
676 // We save `dimRank` here because we'll also need it to verify `dimSlices`.
677 const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank;
678 if (dimToLvl) {
679 if (dimToLvl.getNumResults() != lvlRank)
680 return emitError()
681 << "level-rank mismatch between dimToLvl and lvlTypes: "
682 << dimToLvl.getNumResults() << " != " << lvlRank;
683 auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
684 // Symbols can't be inferred but are acceptable.
685 if (!inferRes && dimToLvl.getNumSymbols() == 0)
686 return emitError() << "failed to infer lvlToDim from dimToLvl";
687 if (lvlToDim && (inferRes != lvlToDim))
688 return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
689 if (dimRank > lvlRank)
690 return emitError() << "unexpected dimToLvl mapping from " << dimRank
691 << " to " << lvlRank;
693 if (!dimSlices.empty()) {
694 if (dimSlices.size() != dimRank)
695 return emitError()
696 << "dimension-rank mismatch between dimSlices and dimToLvl: "
697 << dimSlices.size() << " != " << dimRank;
698 // Compiler support for `dimSlices` currently requires that the two
699 // ranks agree. (However, it does allow `dimToLvl` to be a permutation.)
700 if (dimRank != lvlRank)
701 return emitError()
702 << "dimSlices expected dimension-rank to match level-rank: "
703 << dimRank << " != " << lvlRank;
705 return success();
708 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
709 ArrayRef<Size> dimShape, Type elementType,
710 function_ref<InFlightDiagnostic()> emitError) const {
711 // Check structural integrity. In particular, this ensures that the
712 // level-rank is coherent across all the fields.
713 if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
714 getPosWidth(), getCrdWidth(), getDimSlices())))
715 return failure();
716 // Check integrity with tensor type specifics. In particular, we
717 // need only check that the dimension-rank of the tensor agrees with
718 // the dimension-rank of the encoding.
719 const Dimension dimRank = dimShape.size();
720 if (dimRank == 0)
721 return emitError() << "expected non-scalar sparse tensor";
722 if (getDimRank() != dimRank)
723 return emitError()
724 << "dimension-rank mismatch between encoding and tensor shape: "
725 << getDimRank() << " != " << dimRank;
726 return success();
729 //===----------------------------------------------------------------------===//
730 // Convenience methods.
731 //===----------------------------------------------------------------------===//
733 SparseTensorEncodingAttr
734 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
735 if (auto ttp = llvm::dyn_cast<RankedTensorType>(type))
736 return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding());
737 if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type))
738 return mdtp.getEncoding();
739 return nullptr;
742 AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
743 MLIRContext *context) {
744 auto map = static_cast<AffineMap>(dimToLvl);
745 AffineMap lvlToDim;
746 // Return an empty lvlToDim when inference is not successful.
747 if (!map || map.getNumSymbols() != 0) {
748 lvlToDim = AffineMap();
749 } else if (map.isPermutation()) {
750 lvlToDim = inversePermutation(map);
751 } else if (isBlockSparsity(map)) {
752 lvlToDim = inverseBlockSparsity(map, context);
754 return lvlToDim;
757 AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
758 MLIRContext *context) {
759 SmallVector<AffineExpr> lvlExprs;
760 auto numLvls = dimToLvl.getNumResults();
761 lvlExprs.reserve(numLvls);
762 // lvlExprComponents stores information of the floordiv and mod operations
763 // applied to the same dimension, so as to build the lvlToDim map.
764 std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents;
765 for (unsigned i = 0, n = numLvls; i < n; i++) {
766 auto result = dimToLvl.getResult(i);
767 if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
768 if (result.getKind() == AffineExprKind::FloorDiv) {
769 // Position of the dimension in dimToLvl.
770 auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
771 assert(lvlExprComponents.find(pos) == lvlExprComponents.end() &&
772 "expected only one floordiv for each dimension");
773 SmallVector<AffineExpr, 3> components;
774 // Level variable for floordiv.
775 components.push_back(getAffineDimExpr(i, context));
776 // Multiplier.
777 components.push_back(binOp.getRHS());
778 // Map key is the position of the dimension.
779 lvlExprComponents[pos] = components;
780 } else if (result.getKind() == AffineExprKind::Mod) {
781 auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
782 assert(lvlExprComponents.find(pos) != lvlExprComponents.end() &&
783 "expected floordiv before mod");
784 // Add level variable for mod to the same vector
785 // of the corresponding floordiv.
786 lvlExprComponents[pos].push_back(getAffineDimExpr(i, context));
787 } else {
788 assert(false && "expected floordiv or mod");
790 } else {
791 lvlExprs.push_back(getAffineDimExpr(i, context));
794 // Build lvlExprs from lvlExprComponents.
795 // For example, for il = i floordiv 2 and ii = i mod 2, the components
796 // would be [il, 2, ii]. It could be used to build the AffineExpr
797 // i = il * 2 + ii in lvlToDim.
798 for (auto &components : lvlExprComponents) {
799 assert(components.second.size() == 3 &&
800 "expected 3 components to build lvlExprs");
801 auto mulOp = getAffineBinaryOpExpr(
802 AffineExprKind::Mul, components.second[0], components.second[1]);
803 auto addOp =
804 getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]);
805 lvlExprs.push_back(addOp);
807 return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
810 SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
811 assert(isBlockSparsity(dimToLvl) &&
812 "expected dimToLvl to be block sparsity for calling getBlockSize");
813 SmallVector<unsigned> blockSize;
814 for (auto result : dimToLvl.getResults()) {
815 if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
816 if (result.getKind() == AffineExprKind::Mod) {
817 blockSize.push_back(
818 binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue());
820 } else {
821 blockSize.push_back(0);
824 return blockSize;
827 bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
828 if (!dimToLvl)
829 return false;
830 std::map<unsigned, int64_t> coeffientMap;
831 for (auto result : dimToLvl.getResults()) {
832 if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
833 auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
834 if (result.getKind() == AffineExprKind::FloorDiv) {
835 // Expect only one floordiv for each dimension.
836 if (coeffientMap.find(pos) != coeffientMap.end())
837 return false;
838 coeffientMap[pos] =
839 binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue();
840 } else if (result.getKind() == AffineExprKind::Mod) {
841 // Expect floordiv before mod.
842 if (coeffientMap.find(pos) == coeffientMap.end())
843 return false;
844 // Expect mod to have the same coefficient as floordiv.
845 if (binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue() !=
846 coeffientMap[pos]) {
847 return false;
849 } else {
850 return false;
854 return !coeffientMap.empty();
857 bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
858 Level startLvl, bool isUnique) {
859 if (!enc ||
860 !(enc.isCompressedLvl(startLvl) || enc.isLooseCompressedLvl(startLvl)))
861 return false;
862 const Level lvlRank = enc.getLvlRank();
863 for (Level l = startLvl + 1; l < lvlRank; ++l)
864 if (!enc.isSingletonLvl(l))
865 return false;
866 // If isUnique is true, then make sure that the last level is unique,
867 // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1
868 // (unique on the last singleton).
869 return !isUnique || enc.isUniqueLvl(lvlRank - 1);
872 bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
873 return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
876 Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
877 // We only consider COO region with at least two levels for the purpose
878 // of AOS storage optimization.
879 const Level lvlRank = enc.getLvlRank();
880 if (lvlRank > 1)
881 for (Level l = 0; l < lvlRank - 1; l++)
882 if (isCOOType(enc, l, /*isUnique=*/false))
883 return l;
884 return lvlRank;
887 // Helpers to setup a COO type.
888 RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
889 AffineMap lvlPerm,
890 bool ordered) {
891 const SparseTensorType src(rtt);
892 const Level lvlRank = src.getLvlRank();
893 SmallVector<DimLevelType> lvlTypes;
894 lvlTypes.reserve(lvlRank);
896 // An unordered and non-unique compressed level at beginning.
897 // If this is also the last level, then it is unique.
898 lvlTypes.push_back(
899 *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
900 if (lvlRank > 1) {
901 // TODO: it is actually ordered at the level for ordered input.
902 // Followed by unordered non-unique n-2 singleton levels.
903 std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
904 *buildLevelType(LevelFormat::Singleton, ordered, false));
905 // Ends by a unique singleton level unless the lvlRank is 1.
906 lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
909 // TODO: Maybe pick the bitwidth based on input/output tensors (probably the
910 // largest one among them) in the original operation instead of using the
911 // default value.
912 unsigned posWidth = src.getPosWidth();
913 unsigned crdWidth = src.getCrdWidth();
914 AffineMap invPerm = src.getLvlToDim();
915 auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
916 invPerm, posWidth, crdWidth);
917 return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
920 RankedTensorType sparse_tensor::getCOOFromType(RankedTensorType src,
921 bool ordered) {
922 return getCOOFromTypeWithOrdering(
923 src, AffineMap::getMultiDimIdentityMap(src.getRank(), src.getContext()),
924 ordered);
927 // TODO: Remove this definition once all use-sites have been fixed to
928 // properly handle non-permutations.
929 Dimension mlir::sparse_tensor::toOrigDim(SparseTensorEncodingAttr enc,
930 Level l) {
931 if (enc) {
932 if (const auto dimToLvl = enc.getDimToLvl()) {
933 assert(enc.isPermutation());
934 return dimToLvl.getDimPosition(l);
937 return l;
940 // TODO: Remove this definition once all use-sites have been fixed to
941 // properly handle non-permutations.
942 Level mlir::sparse_tensor::toStoredDim(SparseTensorEncodingAttr enc,
943 Dimension d) {
944 if (enc) {
945 if (const auto dimToLvl = enc.getDimToLvl()) {
946 assert(enc.isPermutation());
947 auto maybePos =
948 dimToLvl.getResultPosition(getAffineDimExpr(d, enc.getContext()));
949 assert(maybePos.has_value());
950 return *maybePos;
953 return d;
956 // TODO: Remove this definition once all use-sites have been fixed to
957 // properly handle non-permutations.
958 Dimension mlir::sparse_tensor::toOrigDim(RankedTensorType type, Level l) {
959 const auto enc = getSparseTensorEncoding(type);
960 assert(!enc || l < enc.getLvlRank());
961 return toOrigDim(enc, l);
964 // TODO: Remove this definition once all use-sites have been fixed to
965 // properly handle non-permutations.
966 Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
967 assert(d < static_cast<Dimension>(type.getRank()));
968 return toStoredDim(getSparseTensorEncoding(type), d);
971 /// We normalized sparse tensor encoding attribute by always using
972 /// ordered/unique DLT such that "compressed_nu_no" and "compressed_nu" (as well
973 /// as other variants) lead to the same storage specifier type, and stripping
974 /// irrelevant fields that do not alter the sparse tensor memory layout.
975 static SparseTensorEncodingAttr
976 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
977 SmallVector<DimLevelType> dlts;
978 for (auto dlt : enc.getLvlTypes())
979 dlts.push_back(*buildLevelType(*getLevelFormat(dlt), true, true));
981 return SparseTensorEncodingAttr::get(
982 enc.getContext(), dlts,
983 AffineMap(), // dimToLvl (irrelevant to storage specifier)
984 AffineMap(), // lvlToDim (irrelevant to storage specifier)
985 // Always use `index` for memSize and lvlSize instead of reusing
986 // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA
987 // value for different bitwidth, it also avoids casting between index and
988 // integer (returned by DimOp)
989 0, 0, enc.getDimSlices());
992 StorageSpecifierType
993 StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
994 return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
997 //===----------------------------------------------------------------------===//
998 // SparseTensorDialect Operations.
999 //===----------------------------------------------------------------------===//
1001 static LogicalResult lvlIsInBounds(Level lvl, Value tensor) {
1002 return success(lvl < getSparseTensorType(tensor).getLvlRank());
1005 static LogicalResult isMatchingWidth(Value mem, unsigned width) {
1006 const Type etp = getMemRefType(mem).getElementType();
1007 return success(width == 0 ? etp.isIndex() : etp.isInteger(width));
1010 static LogicalResult verifySparsifierGetterSetter(
1011 StorageSpecifierKind mdKind, std::optional<Level> lvl,
1012 TypedValue<StorageSpecifierType> md, Operation *op) {
1013 if (mdKind == StorageSpecifierKind::ValMemSize && lvl) {
1014 return op->emitError(
1015 "redundant level argument for querying value memory size");
1018 const auto enc = md.getType().getEncoding();
1019 const Level lvlRank = enc.getLvlRank();
1021 if (mdKind == StorageSpecifierKind::DimOffset ||
1022 mdKind == StorageSpecifierKind::DimStride)
1023 if (!enc.isSlice())
1024 return op->emitError("requested slice data on non-slice tensor");
1026 if (mdKind != StorageSpecifierKind::ValMemSize) {
1027 if (!lvl)
1028 return op->emitError("missing level argument");
1030 const Level l = lvl.value();
1031 if (l >= lvlRank)
1032 return op->emitError("requested level is out of bounds");
1034 if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l))
1035 return op->emitError(
1036 "requested position memory size on a singleton level");
1038 return success();
1041 static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) {
1042 switch (kind) {
1043 case SparseTensorFieldKind::CrdMemRef:
1044 return stt.getCrdType();
1045 case SparseTensorFieldKind::PosMemRef:
1046 return stt.getPosType();
1047 case SparseTensorFieldKind::ValMemRef:
1048 return stt.getElementType();
1049 case SparseTensorFieldKind::StorageSpec:
1050 return nullptr;
1052 llvm_unreachable("Unrecognizable FieldKind");
1055 static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1056 SparseTensorType stt,
1057 RankedTensorType valTp,
1058 TypeRange lvlTps) {
1059 if (requiresStaticShape && !stt.hasStaticDimShape())
1060 return op->emitError("the sparse-tensor must have static shape");
1061 if (!stt.hasEncoding())
1062 return op->emitError("the sparse-tensor must have an encoding attribute");
1063 if (!stt.isIdentity())
1064 return op->emitError("the sparse-tensor must have the identity mapping");
1066 // Verifies the trailing COO.
1067 Level cooStartLvl = getCOOStart(stt.getEncoding());
1068 if (cooStartLvl < stt.getLvlRank()) {
1069 // We only supports trailing COO for now, must be the last input.
1070 auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
1071 // The coordinates should be in shape of <? x rank>
1072 unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
1073 if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
1074 op->emitError("input/output trailing COO level-ranks don't match");
1078 // Verifies that all types match.
1079 StorageLayout layout(stt.getEncoding());
1080 if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref
1081 return op->emitError("inconsistent number of fields between input/output");
1083 unsigned idx = 0;
1084 bool misMatch = false;
1085 layout.foreachField([&idx, &misMatch, stt, valTp,
1086 lvlTps](FieldIndex fid, SparseTensorFieldKind fKind,
1087 Level lvl, DimLevelType dlt) -> bool {
1088 if (fKind == SparseTensorFieldKind::StorageSpec)
1089 return true;
1091 Type inputTp = nullptr;
1092 if (fKind == SparseTensorFieldKind::ValMemRef) {
1093 inputTp = valTp;
1094 } else {
1095 assert(fid == idx && stt.getLvlType(lvl) == dlt);
1096 inputTp = lvlTps[idx++];
1098 // The input element type and expected element type should match.
1099 Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType();
1100 Type expElemTp = getFieldElemType(stt, fKind);
1101 if (inpElemTp != expElemTp) {
1102 misMatch = true;
1103 return false; // to terminate the iteration
1105 return true;
1108 if (misMatch)
1109 return op->emitError("input/output element-types don't match");
1110 return success();
1113 LogicalResult AssembleOp::verify() {
1114 const auto valuesTp = getRankedTensorType(getValues());
1115 const auto lvlsTp = getLevels().getTypes();
1116 const auto resTp = getSparseTensorType(getResult());
1117 return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
1120 LogicalResult DisassembleOp::verify() {
1121 if (getOutValues().getType() != getRetValues().getType())
1122 return emitError("output values and return value type mismatch");
1124 for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels()))
1125 if (ot.getType() != rt.getType())
1126 return emitError("output levels and return levels type mismatch");
1128 const auto valuesTp = getRankedTensorType(getRetValues());
1129 const auto lvlsTp = getRetLevels().getTypes();
1130 const auto srcTp = getSparseTensorType(getTensor());
1131 return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
1134 LogicalResult ConvertOp::verify() {
1135 if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
1136 if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
1137 if (tp1.getRank() != tp2.getRank())
1138 return emitError("unexpected conversion mismatch in rank");
1139 auto dstEnc =
1140 llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
1141 if (dstEnc && dstEnc.isSlice())
1142 return emitError("cannot convert to a sparse tensor slice");
1144 auto shape1 = tp1.getShape();
1145 auto shape2 = tp2.getShape();
1146 // Accept size matches between the source and the destination type
1147 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
1148 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
1149 for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
1150 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
1151 return emitError("unexpected conversion mismatch in dimension ") << d;
1152 return success();
1155 return emitError("unexpected type in convert");
1158 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
1159 if (getType() == getSource().getType())
1160 return getSource();
1161 return {};
1164 bool ConvertOp::needsExtraSort() {
1165 SparseTensorType srcStt = getSparseTensorType(getSource());
1166 SparseTensorType dstStt = getSparseTensorType(getDest());
1168 // We do not need an extra sort when returning unordered sparse tensors or
1169 // dense tensor since dense tensor support random access.
1170 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1171 return false;
1173 if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
1174 srcStt.hasSameDimToLvl(dstStt)) {
1175 return false;
1178 // Source and dest tensors are ordered in different ways. We only do direct
1179 // dense to sparse conversion when the dense input is defined by a sparse
1180 // constant. Note that we can theoretically always directly convert from dense
1181 // inputs by rotating dense loops but it leads to bad cache locality and hurt
1182 // performance.
1183 if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
1184 if (isa<SparseElementsAttr>(constOp.getValue()))
1185 return false;
1187 return true;
1190 LogicalResult CrdTranslateOp::verify() {
1191 uint64_t inRank = getEncoder().getLvlRank();
1192 uint64_t outRank = getEncoder().getDimRank();
1194 if (getDirection() == CrdTransDirectionKind::dim2lvl)
1195 std::swap(inRank, outRank);
1197 if (inRank != getInCrds().size() || outRank != getOutCrds().size())
1198 return emitError("Coordinate rank mismatch with encoding");
1200 return success();
1203 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
1204 SmallVectorImpl<OpFoldResult> &results) {
1205 if (getEncoder().isIdentity()) {
1206 results.assign(getInCrds().begin(), getInCrds().end());
1207 return success();
1209 if (getEncoder().isPermutation()) {
1210 AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
1211 ? getEncoder().getDimToLvl()
1212 : getEncoder().getLvlToDim();
1213 for (AffineExpr exp : perm.getResults())
1214 results.push_back(getInCrds()[exp.cast<AffineDimExpr>().getPosition()]);
1215 return success();
1218 // Fuse dim2lvl/lvl2dim pairs.
1219 auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>();
1220 bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) {
1221 return v.getDefiningOp() == def;
1223 if (!sameDef)
1224 return failure();
1226 bool oppositeDir = def.getDirection() != getDirection();
1227 bool sameOracle =
1228 def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl();
1229 bool sameCount = def.getNumResults() == getInCrds().size();
1230 if (!oppositeDir || !sameOracle || !sameCount)
1231 return failure();
1233 // The definition produces the coordinates in the same order as the input
1234 // coordinates.
1235 bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()),
1236 [](auto valuePair) {
1237 auto [lhs, rhs] = valuePair;
1238 return lhs == rhs;
1241 if (!sameOrder)
1242 return failure();
1243 // l1 = dim2lvl (lvl2dim l0)
1244 // ==> l0
1245 results.append(def.getInCrds().begin(), def.getInCrds().end());
1246 return success();
1249 void LvlOp::build(OpBuilder &builder, OperationState &state, Value source,
1250 int64_t index) {
1251 Value val = builder.create<arith::ConstantIndexOp>(state.location, index);
1252 return build(builder, state, source, val);
1255 LogicalResult LvlOp::verify() {
1256 if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
1257 auto stt = getSparseTensorType(getSource());
1258 if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
1259 emitError("Level index exceeds the rank of the input sparse tensor");
1261 return success();
1264 std::optional<uint64_t> LvlOp::getConstantLvlIndex() {
1265 return getConstantIntValue(getIndex());
1268 Speculation::Speculatability LvlOp::getSpeculatability() {
1269 auto constantIndex = getConstantLvlIndex();
1270 if (!constantIndex)
1271 return Speculation::NotSpeculatable;
1273 assert(constantIndex <
1274 cast<RankedTensorType>(getSource().getType()).getRank());
1275 return Speculation::Speculatable;
1278 OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
1279 auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1280 if (!lvlIndex)
1281 return {};
1283 Level lvl = lvlIndex.getAPSInt().getZExtValue();
1284 auto stt = getSparseTensorType(getSource());
1285 if (lvl >= stt.getLvlRank()) {
1286 // Follows the same convention used by tensor.dim operation. Out of bound
1287 // indices produce undefined behavior but are still valid IR. Don't choke on
1288 // them.
1289 return {};
1292 // Helper lambda to build an IndexAttr.
1293 auto getIndexAttr = [this](int64_t lvlSz) {
1294 return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz));
1297 SmallVector<Size> lvlShape = stt.getLvlShape();
1298 if (!ShapedType::isDynamic(lvlShape[lvl]))
1299 return getIndexAttr(lvlShape[lvl]);
1301 return {};
1304 void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1305 SparseTensorEncodingAttr dstEnc, Value source) {
1306 auto srcStt = getSparseTensorType(source);
1307 SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape();
1308 SmallVector<int64_t> dstDimShape =
1309 dstEnc.tranlateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim);
1310 auto dstTp =
1311 RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc);
1312 return build(odsBuilder, odsState, dstTp, source);
1315 LogicalResult ReinterpretMapOp::verify() {
1316 auto srcStt = getSparseTensorType(getSource());
1317 auto dstStt = getSparseTensorType(getDest());
1318 ArrayRef<DimLevelType> srcLvlTps = srcStt.getLvlTypes();
1319 ArrayRef<DimLevelType> dstLvlTps = dstStt.getLvlTypes();
1321 if (srcLvlTps.size() != dstLvlTps.size())
1322 return emitError("Level rank mismatch between source/dest tensors");
1324 for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps))
1325 if (srcLvlTp != dstLvlTp)
1326 return emitError("Level type mismatch between source/dest tensors");
1328 if (srcStt.getPosWidth() != dstStt.getPosWidth() ||
1329 srcStt.getCrdWidth() != dstStt.getCrdWidth()) {
1330 return emitError("Crd/Pos width mismatch between source/dest tensors");
1333 if (srcStt.getElementType() != dstStt.getElementType())
1334 return emitError("Element type mismatch between source/dest tensors");
1336 SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
1337 SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
1338 for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
1339 if (srcLvlSz != dstLvlSz) {
1340 // Should we allow one side to be dynamic size, e.g., <?x?> should be
1341 // compatible to <3x4>? For now, we require all the level sizes to be
1342 // *exactly* matched for simplicity.
1343 return emitError("Level size mismatch between source/dest tensors");
1347 return success();
1350 OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1351 if (getSource().getType() == getDest().getType())
1352 return getSource();
1354 if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) {
1355 // A -> B, B -> A ==> A
1356 if (def.getSource().getType() == getDest().getType())
1357 return def.getSource();
1359 return {};
1362 LogicalResult ToPositionsOp::verify() {
1363 auto e = getSparseTensorEncoding(getTensor().getType());
1364 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1365 return emitError("requested level is out of bounds");
1366 if (failed(isMatchingWidth(getResult(), e.getPosWidth())))
1367 return emitError("unexpected type for positions");
1368 return success();
1371 LogicalResult ToCoordinatesOp::verify() {
1372 auto e = getSparseTensorEncoding(getTensor().getType());
1373 if (failed(lvlIsInBounds(getLevel(), getTensor())))
1374 return emitError("requested level is out of bounds");
1375 if (failed(isMatchingWidth(getResult(), e.getCrdWidth())))
1376 return emitError("unexpected type for coordinates");
1377 return success();
1380 LogicalResult ToCoordinatesBufferOp::verify() {
1381 auto e = getSparseTensorEncoding(getTensor().getType());
1382 if (getCOOStart(e) >= e.getLvlRank())
1383 return emitError("expected sparse tensor with a COO region");
1384 return success();
1387 LogicalResult ToValuesOp::verify() {
1388 auto ttp = getRankedTensorType(getTensor());
1389 auto mtp = getMemRefType(getResult());
1390 if (ttp.getElementType() != mtp.getElementType())
1391 return emitError("unexpected mismatch in element types");
1392 return success();
1395 LogicalResult ToSliceOffsetOp::verify() {
1396 auto rank = getRankedTensorType(getSlice()).getRank();
1397 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1398 return emitError("requested dimension out of bound");
1399 return success();
1402 LogicalResult ToSliceStrideOp::verify() {
1403 auto rank = getRankedTensorType(getSlice()).getRank();
1404 if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
1405 return emitError("requested dimension out of bound");
1406 return success();
1409 LogicalResult GetStorageSpecifierOp::verify() {
1410 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1411 getSpecifier(), getOperation());
1414 template <typename SpecifierOp>
1415 static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) {
1416 return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>();
1419 OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1420 const StorageSpecifierKind kind = getSpecifierKind();
1421 const auto lvl = getLevel();
1422 for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op))
1423 if (kind == op.getSpecifierKind() && lvl == op.getLevel())
1424 return op.getValue();
1425 return {};
1428 LogicalResult SetStorageSpecifierOp::verify() {
1429 return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1430 getSpecifier(), getOperation());
1433 template <class T>
1434 static LogicalResult verifyNumBlockArgs(T *op, Region &region,
1435 const char *regionName,
1436 TypeRange inputTypes, Type outputType) {
1437 unsigned numArgs = region.getNumArguments();
1438 unsigned expectedNum = inputTypes.size();
1439 if (numArgs != expectedNum)
1440 return op->emitError() << regionName << " region must have exactly "
1441 << expectedNum << " arguments";
1443 for (unsigned i = 0; i < numArgs; i++) {
1444 Type typ = region.getArgument(i).getType();
1445 if (typ != inputTypes[i])
1446 return op->emitError() << regionName << " region argument " << (i + 1)
1447 << " type mismatch";
1449 Operation *term = region.front().getTerminator();
1450 YieldOp yield = dyn_cast<YieldOp>(term);
1451 if (!yield)
1452 return op->emitError() << regionName
1453 << " region must end with sparse_tensor.yield";
1454 if (!yield.getResult() || yield.getResult().getType() != outputType)
1455 return op->emitError() << regionName << " region yield type mismatch";
1457 return success();
1460 LogicalResult BinaryOp::verify() {
1461 NamedAttrList attrs = (*this)->getAttrs();
1462 Type leftType = getX().getType();
1463 Type rightType = getY().getType();
1464 Type outputType = getOutput().getType();
1465 Region &overlap = getOverlapRegion();
1466 Region &left = getLeftRegion();
1467 Region &right = getRightRegion();
1469 // Check correct number of block arguments and return type for each
1470 // non-empty region.
1471 if (!overlap.empty()) {
1472 if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1473 TypeRange{leftType, rightType}, outputType)))
1474 return failure();
1476 if (!left.empty()) {
1477 if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1478 outputType)))
1479 return failure();
1480 } else if (getLeftIdentity()) {
1481 if (leftType != outputType)
1482 return emitError("left=identity requires first argument to have the same "
1483 "type as the output");
1485 if (!right.empty()) {
1486 if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1487 outputType)))
1488 return failure();
1489 } else if (getRightIdentity()) {
1490 if (rightType != outputType)
1491 return emitError("right=identity requires second argument to have the "
1492 "same type as the output");
1494 return success();
1497 LogicalResult UnaryOp::verify() {
1498 Type inputType = getX().getType();
1499 Type outputType = getOutput().getType();
1501 // Check correct number of block arguments and return type for each
1502 // non-empty region.
1503 Region &present = getPresentRegion();
1504 if (!present.empty()) {
1505 if (failed(verifyNumBlockArgs(this, present, "present",
1506 TypeRange{inputType}, outputType)))
1507 return failure();
1509 Region &absent = getAbsentRegion();
1510 if (!absent.empty()) {
1511 if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1512 outputType)))
1513 return failure();
1514 // Absent branch can only yield invariant values.
1515 Block *absentBlock = &absent.front();
1516 Block *parent = getOperation()->getBlock();
1517 Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
1518 if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1519 if (arg.getOwner() == parent)
1520 return emitError("absent region cannot yield linalg argument");
1521 } else if (Operation *def = absentVal.getDefiningOp()) {
1522 if (!isa<arith::ConstantOp>(def) &&
1523 (def->getBlock() == absentBlock || def->getBlock() == parent))
1524 return emitError("absent region cannot yield locally computed value");
1527 return success();
1530 bool ConcatenateOp::needsExtraSort() {
1531 SparseTensorType dstStt = getSparseTensorType(*this);
1532 if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1533 return false;
1535 bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1536 return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1538 // TODO: When conDim != 0, as long as conDim corresponding to the first level
1539 // in all input/output buffers, and all input/output buffers have the same
1540 // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
1541 // CSC matrices along column).
1542 bool directLowerable =
1543 allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1544 return !directLowerable;
1547 LogicalResult ConcatenateOp::verify() {
1548 const auto dstTp = getSparseTensorType(*this);
1549 const Dimension concatDim = getDimension();
1550 const Dimension dimRank = dstTp.getDimRank();
1552 if (getInputs().size() <= 1)
1553 return emitError("Need at least two tensors to concatenate.");
1555 if (concatDim >= dimRank)
1556 return emitError(llvm::formatv(
1557 "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})",
1558 concatDim, dimRank));
1560 for (const auto &it : llvm::enumerate(getInputs())) {
1561 const auto i = it.index();
1562 const auto srcTp = getSparseTensorType(it.value());
1563 if (srcTp.hasDynamicDimShape())
1564 return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i));
1565 const Dimension srcDimRank = srcTp.getDimRank();
1566 if (srcDimRank != dimRank)
1567 return emitError(
1568 llvm::formatv("Input tensor ${0} has a different rank (rank={1}) "
1569 "from the output tensor (rank={2}).",
1570 i, srcDimRank, dimRank));
1573 for (Dimension d = 0; d < dimRank; d++) {
1574 const Size dstSh = dstTp.getDimShape()[d];
1575 if (d == concatDim) {
1576 if (!ShapedType::isDynamic(dstSh)) {
1577 // If we reach here, then all inputs have static shapes. So we
1578 // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
1579 // to avoid redundant assertions in the loop.
1580 Size sumSz = 0;
1581 for (const auto src : getInputs())
1582 sumSz += getSparseTensorType(src).getDimShape()[d];
1583 // If all dimension are statically known, the sum of all the input
1584 // dimensions should be equal to the output dimension.
1585 if (sumSz != dstSh)
1586 return emitError(
1587 "The concatenation dimension of the output tensor should be the "
1588 "sum of all the concatenation dimensions of the input tensors.");
1590 } else {
1591 Size prev = dstSh;
1592 for (const auto src : getInputs()) {
1593 const auto sh = getSparseTensorType(src).getDimShape()[d];
1594 if (!ShapedType::isDynamic(prev) && sh != prev)
1595 return emitError("All dimensions (expect for the concatenating one) "
1596 "should be equal.");
1597 prev = sh;
1602 return success();
1605 LogicalResult InsertOp::verify() {
1606 const auto stt = getSparseTensorType(getTensor());
1607 if (stt.getLvlRank() != static_cast<Level>(getLvlCoords().size()))
1608 return emitOpError("incorrect number of coordinates");
1609 return success();
1612 void PushBackOp::build(OpBuilder &builder, OperationState &result,
1613 Value curSize, Value inBuffer, Value value) {
1614 build(builder, result, curSize, inBuffer, value, Value());
1617 LogicalResult PushBackOp::verify() {
1618 if (Value n = getN()) {
1619 std::optional<int64_t> nValue = getConstantIntValue(n);
1620 if (nValue && nValue.value() < 1)
1621 return emitOpError("n must be not less than 1");
1623 return success();
1626 LogicalResult CompressOp::verify() {
1627 const auto stt = getSparseTensorType(getTensor());
1628 if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size()))
1629 return emitOpError("incorrect number of coordinates");
1630 return success();
1633 void ForeachOp::build(
1634 OpBuilder &builder, OperationState &result, Value tensor,
1635 ValueRange initArgs, AffineMapAttr order,
1636 function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)>
1637 bodyBuilder) {
1638 build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
1639 // Builds foreach body.
1640 if (!bodyBuilder)
1641 return;
1642 const auto stt = getSparseTensorType(tensor);
1643 const Dimension dimRank = stt.getDimRank();
1645 // Starts with `dimRank`-many coordinates.
1646 SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType());
1647 // Followed by one value.
1648 blockArgTypes.push_back(stt.getElementType());
1649 // Followed by the reduction variables.
1650 blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
1652 SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc());
1654 OpBuilder::InsertionGuard guard(builder);
1655 auto &region = *result.regions.front();
1656 Block *bodyBlock =
1657 builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1658 bodyBuilder(builder, result.location,
1659 bodyBlock->getArguments().slice(0, dimRank),
1660 bodyBlock->getArguments()[dimRank],
1661 bodyBlock->getArguments().drop_front(dimRank + 1));
1664 LogicalResult ForeachOp::verify() {
1665 const auto t = getSparseTensorType(getTensor());
1666 const Dimension dimRank = t.getDimRank();
1667 const auto args = getBody()->getArguments();
1669 if (getOrder().has_value() &&
1670 (t.getEncoding() || !getOrder()->isPermutation()))
1671 return emitError("Only support permuted order on non encoded dense tensor");
1673 if (static_cast<size_t>(dimRank) + 1 + getInitArgs().size() != args.size())
1674 return emitError("Unmatched number of arguments in the block");
1676 if (getNumResults() != getInitArgs().size())
1677 return emitError("Mismatch in number of init arguments and results");
1679 if (getResultTypes() != getInitArgs().getTypes())
1680 return emitError("Mismatch in types of init arguments and results");
1682 // Cannot mark this const, because the getters aren't.
1683 auto yield = cast<YieldOp>(getBody()->getTerminator());
1684 if (yield.getNumOperands() != getNumResults() ||
1685 yield.getOperands().getTypes() != getResultTypes())
1686 return emitError("Mismatch in types of yield values and results");
1688 const auto iTp = IndexType::get(getContext());
1689 for (Dimension d = 0; d < dimRank; d++)
1690 if (args[d].getType() != iTp)
1691 emitError(
1692 llvm::formatv("Expecting Index type for argument at index {0}", d));
1694 const auto elemTp = t.getElementType();
1695 const auto valueTp = args[dimRank].getType();
1696 if (elemTp != valueTp)
1697 emitError(llvm::formatv("Unmatched element type between input tensor and "
1698 "block argument, expected:{0}, got: {1}",
1699 elemTp, valueTp));
1700 return success();
1703 OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1704 if (getSparseTensorEncoding(getInputCoo().getType()) ==
1705 getSparseTensorEncoding(getResultCoo().getType()))
1706 return getInputCoo();
1708 return {};
1711 LogicalResult ReorderCOOOp::verify() {
1712 SparseTensorType srcStt = getSparseTensorType(getInputCoo());
1713 SparseTensorType dstStt = getSparseTensorType(getResultCoo());
1715 if (!srcStt.hasSameDimToLvl(dstStt))
1716 emitError("Unmatched dim2lvl map between input and result COO");
1718 if (srcStt.getPosType() != dstStt.getPosType() ||
1719 srcStt.getCrdType() != dstStt.getCrdType() ||
1720 srcStt.getElementType() != dstStt.getElementType()) {
1721 emitError("Unmatched storage format between input and result COO");
1723 return success();
1726 LogicalResult ReduceOp::verify() {
1727 Type inputType = getX().getType();
1728 Region &formula = getRegion();
1729 return verifyNumBlockArgs(this, formula, "reduce",
1730 TypeRange{inputType, inputType}, inputType);
1733 LogicalResult SelectOp::verify() {
1734 Builder b(getContext());
1735 Type inputType = getX().getType();
1736 Type boolType = b.getI1Type();
1737 Region &formula = getRegion();
1738 return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
1739 boolType);
1742 LogicalResult SortOp::verify() {
1743 AffineMap xPerm = getPermMap();
1744 uint64_t nx = xPerm.getNumDims();
1745 if (nx < 1)
1746 emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
1748 if (!xPerm.isPermutation())
1749 emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
1751 std::optional<int64_t> cn = getConstantIntValue(getN());
1752 // We can't check the size of the buffers when n or buffer dimensions aren't
1753 // compile-time constants.
1754 if (!cn)
1755 return success();
1757 uint64_t n = cn.value();
1758 uint64_t ny = 0;
1759 if (auto nyAttr = getNyAttr()) {
1760 ny = nyAttr.getInt();
1763 // FIXME: update the types of variables used in expressions bassed as
1764 // the `minSize` argument, to avoid implicit casting at the callsites
1765 // of this lambda.
1766 const auto checkDim = [&](Value v, Size minSize, const char *message) {
1767 const Size sh = getMemRefType(v).getShape()[0];
1768 if (!ShapedType::isDynamic(sh) && sh < minSize)
1769 emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
1772 checkDim(getXy(), n * (nx + ny),
1773 "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
1775 for (Value opnd : getYs()) {
1776 checkDim(opnd, n, "Expected dimension(y) >= n");
1779 return success();
1782 LogicalResult YieldOp::verify() {
1783 // Check for compatible parent.
1784 auto *parentOp = (*this)->getParentOp();
1785 if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
1786 isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp) ||
1787 isa<ForeachOp>(parentOp))
1788 return success();
1790 return emitOpError("expected parent op to be sparse_tensor unary, binary, "
1791 "reduce, select or foreach");
1794 /// Materialize a single constant operation from a given attribute value with
1795 /// the desired resultant type.
1796 Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
1797 Attribute value, Type type,
1798 Location loc) {
1799 if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
1800 return op;
1801 return nullptr;
1804 void SparseTensorDialect::initialize() {
1805 addAttributes<
1806 #define GET_ATTRDEF_LIST
1807 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
1808 >();
1809 addTypes<
1810 #define GET_TYPEDEF_LIST
1811 #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
1812 >();
1813 addOperations<
1814 #define GET_OP_LIST
1815 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
1816 >();
1819 #define GET_OP_CLASSES
1820 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
1822 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"