//===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// A pass that converts sparse tensor types and primitives to actual compiler
// visible buffers and actual compiler IR that implements these primitives on
// the selected sparse tensor storage schemes. This pass provides an alternative
// to the SparseTensorConversion pass, eliminating the dependence on a runtime
// support library, and providing many more opportunities for subsequent
// compiler optimization of the generated code.
//
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
#include "SparseTensorDescriptor.h"

#include "llvm/Support/FormatVariadic.h"

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::sparse_tensor;

using FuncGeneratorType =
    function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, RankedTensorType)>;
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

/// Flatten a list of operands that may contain sparse tensors.
static void flattenOperands(ValueRange operands,
                            SmallVectorImpl<Value> &flattened) {
  // In case of
  //  sparse_tensor, c, sparse_tensor
  // ==>
  //  memref ..., c, memref ...
  for (auto operand : operands) {
    if (getSparseTensorEncoding(operand.getType())) {
      auto tuple = getTuple(operand);
      // An unrealized_conversion_cast is inserted by the type converter to
      // bridge the gap of the 1:N conversion between sparse tensors and their
      // fields. In this case, take the operands of the cast and replace the
      // sparse tensor operand with the flattened list of fields.
      flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
    } else {
      flattened.push_back(operand);
    }
  }
}
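// Illustrative note (added, not in the original code): assuming a CSR operand
// %a and a scalar %c, a flattened operand list such as
//   (%a, %c)
// becomes
//   (%a_positions, %a_coordinates, %a_values, %a_specifier, %c)
// where the field values are taken from the unrealized_conversion_cast tuple
// that the 1:N type conversion created for %a.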
/// Generates a load with proper `index` typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  return builder.create<memref::LoadOp>(loc, mem, idx);
}
/// Generates a store with proper `index` typing and proper value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
                     Value idx) {
  idx = genCast(builder, loc, idx, builder.getIndexType());
  val = genCast(builder, loc, val,
                cast<ShapedType>(mem.getType()).getElementType());
  builder.create<memref::StoreOp>(loc, val, mem, idx);
}
/// Creates a straightforward counting for-loop.
static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
                            MutableArrayRef<Value> fields,
                            Value lower = Value()) {
  Type indexType = builder.getIndexType();
  if (!lower)
    lower = constantZero(builder, loc, indexType);
  Value one = constantOne(builder, loc, indexType);
  scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
  for (unsigned i = 0, e = fields.size(); i < e; i++)
    fields[i] = forOp.getRegionIterArg(i);
  builder.setInsertionPointToStart(forOp.getBody());
  return forOp;
}
/// Creates a push back operation.
static void createPushback(OpBuilder &builder, Location loc,
                           MutSparseTensorDescriptor desc,
                           SparseTensorFieldKind kind, std::optional<Level> lvl,
                           Value value, Value repeat = Value()) {
  Type etp = desc.getMemRefElementType(kind, lvl);
  Value field = desc.getMemRefField(kind, lvl);
  StorageSpecifierKind specFieldKind = toSpecifierKind(kind);

  auto pushBackOp = builder.create<PushBackOp>(
      loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
      genCast(builder, loc, value, etp), repeat);

  desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
  desc.setSpecifierField(builder, loc, specFieldKind, lvl,
                         pushBackOp.getNewSize());
}
/// Generates code that allocates a sparse storage scheme for given rank.
static void allocSchemeForRank(OpBuilder &builder, Location loc,
                               MutSparseTensorDescriptor desc, Level startLvl) {
  const SparseTensorType stt(desc.getRankedTensorType());
  Value linear = constantIndex(builder, loc, 1);
  const Level lvlRank = stt.getLvlRank();
  for (Level l = startLvl; l < lvlRank; l++) {
    const auto dlt = stt.getLvlType(l);
    if (isCompressedDLT(dlt)) {
      // Append linear x positions, initialized to zero. Since each compressed
      // dimension initially already has a single zero entry, this maintains
      // the desired "linear + 1" length property at all times.
      Value posZero = constantZero(builder, loc, stt.getPosType());
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, l,
                     posZero, linear);
      return;
    }
    if (isSingletonDLT(dlt)) {
      return; // nothing to do
    }
    // Keep compounding the size, but nothing needs to be initialized
    // at this level. We will eventually reach a compressed level or
    // otherwise the values array for the from-here "all-dense" case.
    assert(isDenseDLT(dlt));
    Value size = desc.getLvlSize(builder, loc, l);
    linear = builder.create<arith::MulIOp>(loc, linear, size);
  }
  // Reached values array so prepare for an insertion.
  Value valZero = constantZero(builder, loc, stt.getElementType());
  createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                 std::nullopt, valZero, linear);
}
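// Illustrative note (added): for a dense-over-compressed scheme with a dense
// level of size 4, the call above pushes 4 zero entries onto the compressed
// level's positions array. Together with the single leading zero pushed in
// createAllocFields below, the array then holds 4 + 1 entries, which is the
// "linear + 1" length that a compressed level with 4 parent positions needs.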
/// Creates allocation operation.
static Value createAllocation(OpBuilder &builder, Location loc,
                              MemRefType memRefType, Value sz,
                              bool enableInit) {
  Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
  Type elemType = memRefType.getElementType();
  if (enableInit) {
    Value fillValue = constantZero(builder, loc, elemType);
    builder.create<linalg::FillOp>(loc, fillValue, buffer);
  }
  return buffer;
}
/// Creates allocation for each field in sparse tensor type. Note that
/// for all dynamic memrefs, the memory size is really the capacity of
/// the "vector", while the actual size resides in the sizes array.
///
/// TODO: for efficiency, we will need heuristics to make educated guesses
///       on the required capacities (see heuristic variable).
///
static void createAllocFields(OpBuilder &builder, Location loc,
                              SparseTensorType stt, ValueRange dynSizes,
                              bool enableInit, SmallVectorImpl<Value> &fields,
                              Value sizeHint) {
  // Build original sizes.
  assert((dynSizes.size() == static_cast<size_t>(stt.getNumDynamicDims())) &&
         "Got wrong number of dynamic sizes");
  const Dimension dimRank = stt.getDimRank();
  SmallVector<Value> dimSizes;
  dimSizes.reserve(dimRank);
  unsigned i = 0; // cumulative index into `dynSizes`.
  for (const Size sh : stt.getDimShape())
    dimSizes.push_back(ShapedType::isDynamic(sh)
                           ? dynSizes[i++]
                           : constantIndex(builder, loc, sh));

  // Set up some heuristic sizes. We try to set the initial
  // size based on available information. Otherwise we just
  // initialize a few elements to start the reallocation chain.
  Value posHeuristic, crdHeuristic, valHeuristic;
  if (stt.isAllDense()) {
    valHeuristic = dimSizes[0];
    for (const Value sz : ArrayRef<Value>{dimSizes}.drop_front())
      valHeuristic = builder.create<arith::MulIOp>(loc, valHeuristic, sz);
  } else if (sizeHint) {
    if (getCOOStart(stt.getEncoding()) == 0) {
      posHeuristic = constantIndex(builder, loc, 2);
      crdHeuristic = builder.create<arith::MulIOp>(
          loc, constantIndex(builder, loc, dimRank), sizeHint); // AOS
    } else if (dimRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
      posHeuristic = builder.create<arith::AddIOp>(
          loc, sizeHint, constantIndex(builder, loc, 1));
      crdHeuristic = sizeHint;
    } else {
      posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
    }
    valHeuristic = sizeHint;
  } else {
    posHeuristic = crdHeuristic = valHeuristic =
        constantIndex(builder, loc, 16);
  }

  foreachFieldAndTypeInSparseTensor(
      stt,
      [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
       enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
                   Level /*lvl*/, DimLevelType /*dlt*/) -> bool {
        assert(fields.size() == fIdx);
        Value field;
        switch (fKind) {
        case SparseTensorFieldKind::StorageSpec:
          field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
          break;
        case SparseTensorFieldKind::PosMemRef:
        case SparseTensorFieldKind::CrdMemRef:
        case SparseTensorFieldKind::ValMemRef:
          field = createAllocation(
              builder, loc, cast<MemRefType>(fType),
              (fKind == SparseTensorFieldKind::PosMemRef)   ? posHeuristic
              : (fKind == SparseTensorFieldKind::CrdMemRef) ? crdHeuristic
                                                            : valHeuristic,
              enableInit);
          break;
        }
        fields.push_back(field);
        // Returns true to continue the iteration.
        return true;
      });

  MutSparseTensorDescriptor desc(stt, fields);

  // Initialize the storage scheme to an empty tensor. Initializes the
  // memSizes to all zeros, sets the dimSizes to known values, and gives all
  // position fields an initial zero entry, so that it is easier to maintain
  // the "linear + 1" length property.
  Value posZero = constantZero(builder, loc, stt.getPosType());
  for (Level lvlRank = stt.getLvlRank(), l = 0; l < lvlRank; l++) {
    // Fills dim sizes array.
    // FIXME: `toOrigDim` is deprecated.
    desc.setLvlSize(builder, loc, l, dimSizes[toOrigDim(stt, l)]);
    // Pushes a leading zero to positions memref.
    if (stt.isCompressedLvl(l))
      createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, l,
                     posZero);
  }
  allocSchemeForRank(builder, loc, desc, /*rank=*/0);
}
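// Illustrative note (added): for a statically shaped CSR matrix this routine
// produces four fields: a positions memref and a coordinates memref for the
// compressed level, a values memref, and the storage specifier. The memrefs
// are allocated with the heuristic capacities computed above, while their
// logical sizes are tracked in the specifier.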
/// Helper method that generates block specific to compressed case:
///
///  // given: parentPos = posCursor[lvl-1]
///  pstart = desc.positions[lvl][parentPos]
///  pstop = desc.positions[lvl][parentPos+1]
///  plast = pstop - 1
///  msz = desc.coordinates[lvl].size()
///  if (pstart < pstop) {
///    isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
///  } else { // first insertion
///    isPresent = false
///    desc.positions[lvl][parentPos] = msz
///  }
///  if (isPresent) { // coordinate is already present
///    pnext = plast
///  } else {
///    desc.coordinates[lvl].push_back(lvlCoords[lvl])
///    desc.positions[lvl][parentPos+1] = msz+1
///    pnext = msz
///    <prepare level lvl+1>
///  }
///  posCursor[lvl] = pnext
static Value genCompressed(OpBuilder &builder, Location loc,
                           MutSparseTensorDescriptor desc, ValueRange lvlCoords,
                           Value /*unused*/, Value parentPos, Level lvl) {
  const SparseTensorType stt(desc.getRankedTensorType());
  const Level lvlRank = stt.getLvlRank();
  assert(lvl < lvlRank && "Level is out of bounds");
  assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
         "Level-rank mismatch");
  SmallVector<Type> types;
  Type indexType = builder.getIndexType();
  Type boolType = builder.getIntegerType(1);
  unsigned crdFidx;
  unsigned crdStride;
  std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
  const Value one = constantIndex(builder, loc, 1);
  const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
  const Value positionsAtLvl = desc.getPosMemRef(lvl);
  const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
  const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
  const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
  const Value crdStrideC =
      crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
  const Value msz =
      crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
                 : crdMsz;
  const Value plast = builder.create<arith::SubIOp>(
      loc, genCast(builder, loc, pstop, indexType), one);
  // Conditional expression.
  Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                           pstart, pstop);
  types.push_back(boolType);
  scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
  types.pop_back();
  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
  Value crd =
      genLoad(builder, loc, desc.getMemRefField(crdFidx),
              crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
                         : plast);
  Value eq = builder.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
      lvlCoords[lvl]);
  builder.create<scf::YieldOp>(loc, eq);
  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
  genStore(builder, loc, msz, positionsAtLvl, parentPos);
  builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
  builder.setInsertionPointAfter(ifOp1);
  // If present construct. Note that for a non-unique dimension level, we
  // simply set the condition to false and rely on CSE/DCE to clean up the IR.
  //
  // TODO: generate less temporary IR?
  //
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    types.push_back(desc.getField(i).getType());
  types.push_back(indexType);
  const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
                                       : constantI1(builder, loc, false);
  scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
  // If present (fields unaffected, update pnext to plast).
  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());

  // FIXME: This does not look like a clean way, but it is probably the most
  // efficient way.
  desc.getFields().push_back(plast);
  builder.create<scf::YieldOp>(loc, desc.getFields());
  desc.getFields().pop_back();

  // If !present (changes fields, update pnext).
  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
  Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
  genStore(builder, loc, mszp1, positionsAtLvl, pp1);
  createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
                 lvlCoords[lvl]);
  // Prepare the next level "as needed".
  if ((lvl + 1) < lvlRank)
    allocSchemeForRank(builder, loc, desc, lvl + 1);

  desc.getFields().push_back(msz);
  builder.create<scf::YieldOp>(loc, desc.getFields());
  desc.getFields().pop_back();

  // Update fields and return next pos.
  builder.setInsertionPointAfter(ifOp2);
  unsigned o = 0;
  for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
    desc.setField(i, ifOp2.getResult(o++));
  return ifOp2.getResult(o);
}
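// Illustrative note (added): the value returned above is the position cursor
// for level `lvl`, i.e. the index at which the insertion continues in the
// next level; genImplementation below feeds it back in as `parentPos` when
// processing level lvl + 1.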
/// Helper class to help lowering sparse_tensor.insert operation.
class SparseInsertGenerator
    : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
public:
  SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
                        bool genCall)
      : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};

  /// Generates code along an insertion path without the need for a "cursor".
  /// This current insertion strategy comes at the expense of some testing
  /// overhead for each insertion. The strategy will be optimized later for
  /// common insertion patterns. The current insertion strategy also assumes
  /// insertions occur in "a reasonable order" that enables building the
  /// storage scheme in an appending/inserting kind of fashion (i.e. no
  /// in-between insertions that need data movement). The implementation
  /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
  ///
  /// TODO: better unord/not-unique; also generalize, optimize, specialize!
  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
                                       OpBuilder &builder, Location loc) {
    const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
    const Level lvlRank = stt.getLvlRank();
    // Extract fields and coordinates from args.
    SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
    MutSparseTensorDescriptor desc(stt, fields);
    const SmallVector<Value> coords =
        llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
    Value value = args.back();
    Value parentPos = constantZero(builder, loc, builder.getIndexType());
    // Generate code for every level.
    for (Level l = 0; l < lvlRank; l++) {
      const auto dlt = stt.getLvlType(l);
      if (isCompressedDLT(dlt)) {
        // Create:
        //   if (!present) {
        //     coordinates[l].push_back(coords[l])
        //     <update positions and prepare level l + 1>
        //   }
        //   positions[l] = coordinates.size() - 1
        //   <insert @ positions[l] at next level l + 1>
        parentPos =
            genCompressed(builder, loc, desc, coords, value, parentPos, l);
      } else if (isSingletonDLT(dlt)) {
        // Create:
        //   coordinates[l].push_back(coords[l])
        //   positions[l] = positions[l-1]
        //   <insert @ positions[l] at next level l + 1>
        createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, l,
                       coords[l]);
      } else {
        assert(isDenseDLT(dlt));
        // Construct the new position as:
        //   positions[l] = size * positions[l-1] + coords[l]
        //   <insert @ positions[l] at next level l + 1>
        Value size = desc.getLvlSize(builder, loc, l);
        Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
        parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
      }
    }
    // Reached the actual value append/insert.
    if (!stt.isDenseLvl(lvlRank - 1))
      createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
                     std::nullopt, value);
    else
      genStore(builder, loc, value, desc.getValMemRef(), parentPos);
    return fields;
  }

  std::string getMangledFuncName() {
    // The mangled name of the function has this format:
    //   <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
    constexpr const char kInsertFuncNamePrefix[] = "_insert_";
    const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));

    SmallString<32> nameBuffer;
    llvm::raw_svector_ostream nameOstream(nameBuffer);
    nameOstream << kInsertFuncNamePrefix;
    const Level lvlRank = stt.getLvlRank();
    for (Level l = 0; l < lvlRank; l++) {
      std::string lvlType = toMLIRString(stt.getLvlType(l));
      // Replace/remove punctuations in level properties.
      std::replace_if(
          lvlType.begin(), lvlType.end(),
          [](char c) { return c == '(' || c == ','; }, '_');
      llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
      nameOstream << lvlType << "_";
    }
    // Static dim sizes are used in the generated code while dynamic sizes are
    // loaded from the dimSizes buffer. This is the reason for adding the shape
    // to the function name.
    for (const auto sh : stt.getDimShape())
      nameOstream << sh << "_";
    // Permutation information is also used in generating insertion.
    if (!stt.isIdentity())
      nameOstream << stt.getDimToLvl() << "_";
    nameOstream << stt.getElementType() << "_";
    nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
    return nameOstream.str().str();
  }

private:
  TensorType rtp;
};
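// Illustrative note (added): for a 10x20 CSR matrix of f64 with the default
// position/coordinate bit widths, the mangling scheme above yields a name of
// roughly the form "_insert_dense_compressed_10_20_f64_0_0" (hypothetical
// example; the exact level spellings come from toMLIRString).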
/// Generates insertion finalization code.
static void genEndInsert(OpBuilder &builder, Location loc,
                         SparseTensorDescriptor desc) {
  const SparseTensorType stt(desc.getRankedTensorType());
  const Level lvlRank = stt.getLvlRank();
  for (Level l = 0; l < lvlRank; l++) {
    const auto dlt = stt.getLvlType(l);
    if (isLooseCompressedDLT(dlt))
      llvm_unreachable("TODO: Not yet implemented");
    if (isCompressedDLT(dlt)) {
      // Compressed dimensions need a position cleanup for all entries
      // that were not visited during the insertion pass.
      //
      // TODO: avoid cleanup and keep compressed scheme consistent at all
      //       times?
      //
      if (l > 0) {
        Type posType = stt.getPosType();
        Value posMemRef = desc.getPosMemRef(l);
        Value hi = desc.getPosMemSize(builder, loc, l);
        Value zero = constantIndex(builder, loc, 0);
        Value one = constantIndex(builder, loc, 1);
        // Vector of only one, but needed by createFor's prototype.
        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
        scf::ForOp loop = createFor(builder, loc, hi, inits, one);
        Value i = loop.getInductionVar();
        Value oldv = loop.getRegionIterArg(0);
        Value newv = genLoad(builder, loc, posMemRef, i);
        Value posZero = constantZero(builder, loc, posType);
        Value cond = builder.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::eq, newv, posZero);
        scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
                                                   cond, /*else*/ true);
        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
        genStore(builder, loc, oldv, posMemRef, i);
        builder.create<scf::YieldOp>(loc, oldv);
        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
        builder.create<scf::YieldOp>(loc, newv);
        builder.setInsertionPointAfter(ifOp);
        builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
        builder.setInsertionPointAfter(loop);
      }
    } else {
      assert(isDenseDLT(dlt) || isSingletonDLT(dlt));
    }
  }
}
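// Illustrative note (added): the cleanup above turns a positions array such
// as [0, 3, 0, 0, 5] into [0, 3, 3, 3, 5], propagating the last seen value
// into the entries of parents that never received an insertion, so that every
// (pstart, pstop) pair denotes a valid, possibly empty, range.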
static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
                            Value sz) {
  auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
  return builder
      .create<memref::SubViewOp>(
          loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
          ValueRange{}, ValueRange{sz}, ValueRange{},
          ArrayRef<int64_t>{0},                    // static offset
          ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
          ArrayRef<int64_t>{1})                    // static stride
      .getResult();
}
static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
  ReassociationIndices reassociation;
  for (int i = 0, e = srcTp.getRank(); i < e; i++)
    reassociation.push_back(i);
  return reassociation;
}
static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
                               Type dstTp) {
  if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
    // Scalars can only be converted to 0-ranked tensors.
    if (rtp.getRank() != 0)
      return nullptr;
    elem = genCast(builder, loc, elem, rtp.getElementType());
    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
  }
  return genCast(builder, loc, elem, dstTp);
}
//===----------------------------------------------------------------------===//
// Codegen rules.
//===----------------------------------------------------------------------===//

/// Sparse tensor storage conversion rule for returns.
class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> flattened;
    flattenOperands(adaptor.getOperands(), flattened);
    // Create a return with the flattened value extracted from sparse tensors.
    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
    return success();
  }
};
/// Sparse tensor storage conversion rule for calls.
class SparseCallConverter : public OpConversionPattern<func::CallOp> {
public:
  // The default CallOp converter can not handle 1:N type conversion.
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // In case of:
    //  sparse_tensor, f, sparse_tensor = call @foo(...)
    // ==>
    //  memref..., f, memref = call @foo(...) replace with
    //  cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
    SmallVector<Type> finalRetTy;
    if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
      return failure();

    // (1) Generates new call with flattened return value.
    SmallVector<Value> flattened;
    flattenOperands(adaptor.getOperands(), flattened);
    auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
                                                 finalRetTy, flattened);
    // (2) Create cast operations for sparse tensor returns.
    SmallVector<Value> castedRet;
    // Tracks the offset of current return value (of the original call)
    // relative to the new call (after sparse tensor flattening).
    unsigned retOffset = 0;
    // Temporary buffer to hold the flattened list of types for
    // a sparse tensor.
    SmallVector<Type> sparseFlat;
    for (auto ret : op.getResults()) {
      assert(retOffset < newCall.getNumResults());
      auto retType = ret.getType();
      if (failed(typeConverter->convertType(retType, sparseFlat)))
        // This should never happen.
        llvm_unreachable("Failed to convert type in sparse tensor codegen");

      // Converted types can not be empty when the type conversion succeeds.
      assert(!sparseFlat.empty());
      if (sparseFlat.size() > 1) {
        auto flatSize = sparseFlat.size();
        ValueRange fields(iterator_range<ResultRange::iterator>(
            newCall.result_begin() + retOffset,
            newCall.result_begin() + retOffset + flatSize));
        castedRet.push_back(genTuple(rewriter, loc, retType, fields));
        retOffset += flatSize;
      } else {
        // If this is a 1:1 conversion, no need for casting.
        castedRet.push_back(newCall.getResult(retOffset));
        retOffset++;
      }
      sparseFlat.clear();
    }

    assert(castedRet.size() == op.getNumResults());
    rewriter.replaceOp(op, castedRet);
    return success();
  }
};
/// Sparse codegen rule for level accesses.
class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    std::optional<int64_t> lvl = op.getConstantLvlIndex();
    if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
      return failure();

    auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
    auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);

    rewriter.replaceOp(op, sz);
    return success();
  }
};
// TODO: use a new SortCOO operation here instead of reusing convert op.
struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();

    SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
    SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());

    // Should have been verified.
    assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
           isUniqueCOOType(srcStt.getRankedTensorType()) &&
           isUniqueCOOType(dstStt.getRankedTensorType()));
    assert(dstStt.hasSameDimToLvl(srcStt));

    // We don't need a mutable descriptor here as we perform sorting in-place.
    auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
    auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
    auto crd = desc.getAOSMemRef();
    auto val = desc.getValMemRef();

    // Otherwise we need another data shuffle and a non-identity map.
    assert(dstStt.hasSameDimToLvl(srcStt));
    (void)dstStt; // to silence warning when assertion is disabled

    auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);

    rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
                            rewriter.getIndexAttr(0), op.getAlgorithm());

    // Since we do in-place sorting, the destination tensor will have the same
    // set of memrefs as the source tensor.
    rewriter.replaceOp(op, adaptor.getInputCoo());
    return success();
  }
};
template <typename Op, StorageSpecifierKind kind>
class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
public:
  using OpConversionPattern<Op>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply lowers to a specifier.get <field> operation.
    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
    auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
                                    op.getDim().getZExtValue());

    rewriter.replaceOp(op, v);
    return success();
  }
};
/// Sparse codegen rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only rewrite identically annotated source/dest.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (!encDst || encDst != encSrc)
      return failure();
    rewriter.replaceOp(op, adaptor.getOperands());
    return success();
  }
};
class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Simply fold the operation.
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};
/// Sparse codegen rule for the alloc operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorAllocConverter
    : public OpConversionPattern<bufferization::AllocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorAllocConverter(TypeConverter &typeConverter, MLIRContext *context,
                             bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}

  LogicalResult
  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto resType = getSparseTensorType(op);
    if (!resType.hasEncoding())
      return failure();

    // Construct allocation for each field.
    const Location loc = op.getLoc();
    if (op.getCopy()) {
      auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
      SmallVector<Value> fields;
      fields.reserve(desc.getNumFields());
      // Memcpy on memref fields.
      for (auto field : desc.getMemRefFields()) {
        auto memrefTp = cast<MemRefType>(field.getType());
        auto size = rewriter.create<memref::DimOp>(loc, field, 0);
        auto copied =
            rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
        rewriter.create<memref::CopyOp>(loc, field, copied);
        fields.push_back(copied);
      }
      fields.push_back(desc.getSpecifier());
      assert(fields.size() == desc.getNumFields());
      rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
      return success();
    }

    const Value sizeHint = op.getSizeHint();
    const ValueRange dynSizes = adaptor.getDynamicSizes();
    const size_t found = dynSizes.size();
    const int64_t expected = resType.getNumDynamicDims();
    if (found != static_cast<size_t>(expected))
      return rewriter.notifyMatchFailure(
          op, llvm::formatv(
                  "Got wrong number of dynamic sizes: Found={0}, Expected={1}",
                  found, expected));
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, resType, dynSizes,
                      enableBufferInitialization, fields, sizeHint);
    // Replace operation with resulting memrefs.
    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
    return success();
  }

private:
  bool enableBufferInitialization;
};
/// Sparse codegen rule for the empty tensor operator.
/// TODO(springerm): remove when bufferization.alloc_tensor is gone
class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorEmptyConverter(TypeConverter &typeConverter, MLIRContext *context,
                             bool enableInit)
      : OpConversionPattern(typeConverter, context),
        enableBufferInitialization(enableInit) {}

  LogicalResult
  matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    const auto resType = getSparseTensorType(op);
    if (!resType.hasEncoding())
      return failure();

    // Construct allocation for each field.
    const Location loc = op.getLoc();
    const Value sizeHint; // none
    const ValueRange dynSizes = adaptor.getDynamicSizes();
    const size_t found = dynSizes.size();
    const int64_t expected = resType.getNumDynamicDims();
    if (found != static_cast<size_t>(expected))
      return rewriter.notifyMatchFailure(
          op, llvm::formatv(
                  "Got wrong number of dynamic sizes: Found={0}, Expected={1}",
                  found, expected));
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, resType, dynSizes,
                      enableBufferInitialization, fields, sizeHint);
    // Replace operation with resulting memrefs.
    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
    return success();
  }

private:
  bool enableBufferInitialization;
};
/// Sparse codegen rule for the dealloc operator.
class SparseTensorDeallocConverter
    : public OpConversionPattern<bufferization::DeallocTensorOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  SparseTensorDeallocConverter(TypeConverter &typeConverter,
                               MLIRContext *context, bool createDeallocs)
      : OpConversionPattern(typeConverter, context),
        createDeallocs(createDeallocs) {}

  LogicalResult
  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto enc = getSparseTensorEncoding(op.getTensor().getType());
    if (!enc)
      return failure();

    // If the user requests not to deallocate sparse tensors, simply erase the
    // operation.
    if (createDeallocs) {
      // Replace the sparse tensor deallocation with field deallocations.
      Location loc = op.getLoc();
      auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
      for (auto input : desc.getMemRefFields())
        // Deallocate every buffer used to store the sparse tensor handler.
        rewriter.create<memref::DeallocOp>(loc, input);
    }
    rewriter.eraseOp(op);
    return success();
  }

private:
  const bool createDeallocs;
};
/// Sparse codegen rule for tensor rematerialization.
class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Prepare descriptor.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
    // Generate optional insertion finalization code.
    if (op.getHasInserts())
      genEndInsert(rewriter, op.getLoc(), desc);
    // Replace operation with resulting memrefs.
    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
    return success();
  }
};
/// Sparse codegen rule for the expand op.
class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!getSparseTensorEncoding(op.getTensor().getType()))
      return failure();
    Location loc = op->getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
    const auto srcType = getSparseTensorType(op.getTensor());
    Type eltType = srcType.getElementType();
    Type boolType = rewriter.getIntegerType(1);
    Type idxType = rewriter.getIndexType();
    // All initialization should be done on entry of the loop nest.
    rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());

    // Determine the size for access expansion (always the innermost stored
    // level size).
    const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
    // Generate a memref for `sz` elements of type `t`.
    const auto genAlloc = [&](Type t) {
      const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
      return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
    };
    // Allocate temporary buffers for values/filled-switch and added.
    // We do not use stack buffers for this, since the expanded size may
    // be rather large (as it envelops a single expanded dense dimension).
    Value values = genAlloc(eltType);
    Value filled = genAlloc(boolType);
    Value added = genAlloc(idxType);
    Value zero = constantZero(rewriter, loc, idxType);
    // Reset the values/filled-switch to all-zero/false. Note that this
    // introduces an O(N) operation into the computation, but this reset
    // operation is amortized over the innermost loops for the access
    // pattern expansion. As noted in the operation doc, we would like
    // to amortize this setup cost even between kernels.
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, eltType)},
        ValueRange{values});
    rewriter.create<linalg::FillOp>(
        loc, ValueRange{constantZero(rewriter, loc, boolType)},
        ValueRange{filled});
    // Replace expansion op with these buffers and initial coordinate.
    assert(op.getNumResults() == 4);
    rewriter.replaceOp(op, {values, filled, added, zero});
    return success();
  }
};
/// Sparse codegen rule for the compress operator.
class SparseCompressConverter : public OpConversionPattern<CompressOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
    Value values = adaptor.getValues();
    Value filled = adaptor.getFilled();
    Value added = adaptor.getAdded();
    Value count = adaptor.getCount();
    const SparseTensorType dstType(desc.getRankedTensorType());
    Type eltType = dstType.getElementType();

    // If the innermost level is ordered, we need to sort the coordinates
    // in the "added" array prior to applying the compression.
    if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
      rewriter.create<SortOp>(
          loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
          rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
    // While performing the insertions, we also need to reset the elements
    // of the values/filled-switch by only iterating over the set elements,
    // to ensure that the runtime complexity remains proportional to the
    // sparsity of the expanded access pattern.
    //
    // Generate
    //    out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
    //      crd = added[i];
    //      value = values[crd];
    //      insert({lvlCoords, crd}, value);
    //      new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
    //      values[crd] = 0;
    //      filled[crd] = false;
    //      yield new_memrefs
    //    }
    scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
    Value i = loop.getInductionVar();

    Value crd = genLoad(rewriter, loc, added, i);
    Value value = genLoad(rewriter, loc, values, crd);
    SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
    SmallVector<Type> flatSpTensorTps = llvm::to_vector(
        llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
    params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
    params.push_back(crd);
    params.push_back(value);
    SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
    genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
    genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
    rewriter.create<scf::YieldOp>(loc, insertRet);

    rewriter.setInsertionPointAfter(loop);
    Value result = genTuple(rewriter, loc, dstType, loop->getResults());
    // Deallocate the buffers on exit of the full loop nest.
    Operation *parent = getTop(op);
    rewriter.setInsertionPointAfter(parent);
    rewriter.create<memref::DeallocOp>(loc, values);
    rewriter.create<memref::DeallocOp>(loc, filled);
    rewriter.create<memref::DeallocOp>(loc, added);
    // Replace operation with resulting memrefs.
    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Sparse codegen rule for the insert operator.
class SparseInsertConverter : public OpConversionPattern<InsertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
    TypeRange flatSpTensorTps = desc.getFields().getTypes();
    SmallVector<Value> params = llvm::to_vector(desc.getFields());
    params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
    params.push_back(adaptor.getValue());
    SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
                                    params, /*genCall=*/true);
    SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
    // Replace operation with resulting memrefs.
    rewriter.replaceOp(op,
                       genTuple(rewriter, loc, op.getTensor().getType(), ret));
    return success();
  }
};
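// Illustrative note (added): with /*genCall=*/true the pattern above lowers a
// sparse_tensor.insert into a call to a synthesized function (named by
// getMangledFuncName, e.g. something like @"_insert_dense_compressed_...")
// that takes the flattened storage fields, the level coordinates, and the
// value, and returns the updated fields.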
/// Sparse codegen rule for position accesses.
class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
public:
  using OpAdaptor = typename ToPositionsOp::Adaptor;
  using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested position access with the corresponding field.
    // The cast_op is inserted by the type converter to intermix 1:N type
    // conversion.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
    rewriter.replaceOp(op, desc.getPosMemRef(op.getLevel()));
    return success();
  }
};
/// Sparse codegen rule for accessing the coordinates arrays.
class SparseToCoordinatesConverter
    : public OpConversionPattern<ToCoordinatesOp> {
public:
  using OpAdaptor = typename ToCoordinatesOp::Adaptor;
  using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested coordinates access with the corresponding field.
    // The cast_op is inserted by the type converter to intermix 1:N type
    // conversion.
    Location loc = op.getLoc();
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
    Value field = desc.getCrdMemRefOrView(rewriter, loc, op.getLevel());

    // Insert a cast to bridge the actual type to the user expected type. If
    // the actual type and the user expected type aren't compatible, the
    // compiler or the runtime will issue an error.
    Type resType = op.getResult().getType();
    if (resType != field.getType())
      field = rewriter.create<memref::CastOp>(loc, resType, field);
    rewriter.replaceOp(op, field);

    return success();
  }
};
/// Sparse codegen rule for accessing the linear coordinates buffer.
class SparseToCoordinatesBufferConverter
    : public OpConversionPattern<ToCoordinatesBufferOp> {
public:
  using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
  using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested coordinates access with the corresponding field.
    // The cast_op is inserted by the type converter to intermix 1:N type
    // conversion.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
    rewriter.replaceOp(op, desc.getAOSMemRef());
    return success();
  }
};
/// Sparse codegen rule for value accesses.
class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
public:
  using OpAdaptor = typename ToValuesOp::Adaptor;
  using OpConversionPattern<ToValuesOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Replace the requested values access with the corresponding field.
    // The cast_op is inserted by the type converter to intermix 1:N type
    // conversion.
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
    rewriter.replaceOp(op, desc.getValMemRef());
    return success();
  }
};
/// Sparse codegen rule for the convert operator.
class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
    SparseTensorEncodingAttr encSrc =
        getSparseTensorEncoding(op.getSource().getType());
    // The output tensor can not be a slice and those cases should have been
    // rejected by ConvertOp::verify() already.
    assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
    // Different encodings (except for different bitwidths) should be handled
    // by rewriting.
    // We need further rewrites if the input tensor is a slice too.
    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
        encSrc.isSlice())
      return failure();

    Type retElemTp = op.getResult().getType().getElementType();
    Type srcElemTp = op.getSource().getType().getElementType();
    // Fold the trivial cases.
    if (retElemTp == srcElemTp && encDst == encSrc) {
      rewriter.replaceOp(op, adaptor.getSource());
      return success();
    }

    // Do element-wise type conversion without using InsertOp.
    //
    // for each memref in srcTensor:
    //   dst = memref.alloc
    //   if srcMemRefType != dstMemRefType:
    //     for every dst[i] = cast(src[i])
    //   else:
    //     dst = memref.copy(src)
    Location loc = op.getLoc();
    auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
    SmallVector<Value> fields;
    foreachFieldAndTypeInSparseTensor(
        SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
        [&rewriter, &fields, srcDesc,
         loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
              DimLevelType /*dlt*/) -> bool {
          // Simply reuses the storage specifier as it is an SSA value.
          if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(srcDesc.getSpecifier());
          } else {
            // Allocates new memrefs.
            Value srcMem = srcDesc.getMemRefField(fIdx);
            // TODO: We can instead use the actual memSize in the specifier,
            // which would require a subViewOp to avoid overflow when copying
            // values.
            Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
            auto dstMem = rewriter.create<memref::AllocOp>(
                loc, cast<MemRefType>(fTp), sz);
            if (fTp != srcMem.getType()) {
              // Converts elements type.
              scf::buildLoopNest(
                  rewriter, loc, constantIndex(rewriter, loc, 0), sz,
                  constantIndex(rewriter, loc, 1),
                  [srcMem, &dstMem](OpBuilder &builder, Location loc,
                                    ValueRange ivs) {
                    Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
                    Value casted = genCast(builder, loc, v,
                                           dstMem.getType().getElementType());
                    builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
                  });
            } else {
              // TODO: We can even reuse the same memref for the new tensor,
              // but that requires a `ref-counting` based memory management
              // for shared memrefs between multiple sparse tensors.
              rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
            }
            fields.push_back(dstMem);
          }
          return true;
        });

    rewriter.replaceOp(
        op, genTuple(rewriter, loc, op.getResult().getType(), fields));
    return success();
  }
};
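// Illustrative note (added): converting, e.g., a CSR tensor of f32 into the
// same CSR layout with f64 elements reuses the specifier unchanged, copies
// the positions and coordinates buffers with memref.copy, and runs the
// element-wise cast loop above only for the values buffer, whose memref
// element type differs from the destination field type.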
class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    auto srcEnc = getSparseTensorEncoding(op.getSourceType());
    auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
    // TODO: We should check these in ExtractSliceOp::verify.
    if (!srcEnc || !dstEnc || !dstEnc.isSlice())
      return failure();
    assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());

    SmallVector<Value> fields;
    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);

    auto newSpec = rewriter.create<StorageSpecifierInitOp>(
        loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
    desc.setSpecifier(newSpec);

    // Fills in slice information.
    for (auto [idx, offset, size, stride] : llvm::enumerate(
             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
      Dimension dim = idx;

      Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
      Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
      Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
      // TODO: We could probably only set the dynamic value here. But it would
      // require us to fill the hole when casting a static slice to a dynamic
      // slice.
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
                             dim, offsetV);

      // FIXME: we need to distinguish level sizes and dimension size for
      // slices here. Maybe we should store slice level sizes in a different
      // array instead of reusing it.
      assert(srcEnc.isIdentity());
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
                             sizeV);
      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
                             dim, strideV);
    }

    // NOTE: we can not generate tuples directly from the descriptor here, as
    // the descriptor is holding the original type, yet we want the slice type
    // here (they share every memref but with an updated specifier).
    rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
                                    desc.getFields()));
    return success();
  }
};
/// Sparse codegen rule for the number of entries operator.
class SparseNumberOfEntriesConverter
    : public OpConversionPattern<NumberOfEntriesOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Query memSizes for the actually stored values.
    // FIXME: the nse value computed in this way might be wrong when there is
    // any "loose_compressed" level.
    rewriter.replaceOp(
        op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
    return success();
  }
};
struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> fields;

    foreachFieldAndTypeInSparseTensor(
        stt,
        [&rewriter, &fields, &op, &stt,
         loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
              Level /*lvl*/, DimLevelType dlt) -> bool {
          assert(fields.size() == fIdx);
          if (fKind == SparseTensorFieldKind::StorageSpec) {
            fields.push_back(
                SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
          } else {
            // Else simply takes the inputs.
            Value tensor = fKind == SparseTensorFieldKind::ValMemRef
                               ? op.getValues()
                               : op.getLevels()[fIdx];

            TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
            if (mem.getType().getRank() > 1) {
              // Flattens the buffer to rank 1.
              auto reassoc = getReassociationForFlattening(mem.getType());
              mem = rewriter.create<memref::CastOp>(
                  loc, fType,
                  rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
            } else {
              mem = rewriter.create<memref::CastOp>(loc, fType, mem);
            }
            fields.push_back(mem);
          }
          return true;
        });

    MutSparseTensorDescriptor desc(stt, fields);
    Value c0 = constantIndex(rewriter, loc, 0);
    Value c1 = constantIndex(rewriter, loc, 1);
    Value c2 = constantIndex(rewriter, loc, 2);
    Value posBack = c0; // index to the last value in the position array
    Value memSize = c1; // memory size for current array

    Level trailCOOStart = getCOOStart(stt.getEncoding());
    Level trailCOORank = stt.getLvlRank() - trailCOOStart;
    // Sets up SparseTensorSpecifier.
    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
      assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));

      // FIXME: dim/lvl confusion!
      // Sets up the level size.
      auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
      desc.setLvlSize(rewriter, loc, lvl, lvlSize);
      // We use a single AOS array to store the trailing COO, so there is only
      // one memory size to set for the entire COO section.
      if (lvl > trailCOOStart)
        continue;

      // Sets up the memory size by reading the last value in the position
      // array.
      DimLevelType dlt = stt.getLvlType(lvl);
      // Simply forwards the position index when this is a dense level.
      if (isDenseDLT(dlt)) {
        memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
        posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        continue;
      }

      if (isDLTWithPos(dlt)) {
        assert(isCompressedDLT(dlt) || isLooseCompressedDLT(dlt));
        if (isLooseCompressedDLT(dlt)) {
          memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
          posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
        } else {
          assert(isCompressedDLT(dlt));
          posBack = memSize;
          memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
        }
        desc.setPosMemSize(rewriter, loc, lvl, memSize);
        // The last value in the position array is the memory size for the
        // next level.
        memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
        posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
      }
      assert(isDLTWithCrd(dlt) && lvl <= trailCOOStart);
      // FIXME: This seems to be unnecessarily complex, can we simplify it?
      if (lvl == trailCOOStart) {
        Value cooSz = rewriter.create<arith::MulIOp>(
            loc, memSize, constantIndex(rewriter, loc, trailCOORank));
        desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
      } else {
        desc.setCrdMemSize(rewriter, loc, lvl, memSize);
      }
    }
    desc.setValMemSize(rewriter, loc, memSize);

    rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
    return success();
  }
};
struct SparseDisassembleOpConverter
    : public OpConversionPattern<DisassembleOp> {
  using OpConversionPattern::OpConversionPattern;
  SparseDisassembleOpConverter(TypeConverter &typeConverter,
                               MLIRContext *context)
      : OpConversionPattern(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
    Location loc = op.getLoc();
    SmallVector<Value> retMem;
    SmallVector<Value> retLen;
    desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem, &retLen](
                                      FieldIndex fid,
                                      SparseTensorFieldKind fKind, Level lvl,
                                      DimLevelType dlt) -> bool {
      if (fKind == SparseTensorFieldKind::StorageSpec)
        return true;
      SparseTensorType stt(desc.getRankedTensorType());
      Value sz, src;
      TypedValue<BaseMemRefType> dst;
      if (fKind == SparseTensorFieldKind::ValMemRef) {
        sz = desc.getValMemSize(rewriter, loc);
        src = desc.getValMemRef();
        dst = genToMemref(rewriter, loc, op.getOutValues());
        // Values is the last field in the descriptor, but it is the first
        // operand in the unpack operation.
        // TODO: maybe change unpack/pack operation instead to be
        // consistent.
        retMem.insert(retMem.begin(), dst);
        Type valLenTp = op.getValLen().getType();
        retLen.insert(retLen.begin(),
                      genScalarToTensor(rewriter, loc, sz, valLenTp));
      } else {
        assert(fKind == SparseTensorFieldKind::PosMemRef ||
               fKind == SparseTensorFieldKind::CrdMemRef);

        sz = fKind == SparseTensorFieldKind::PosMemRef
                 ? desc.getPosMemSize(rewriter, loc, lvl)
                 : desc.getCrdMemSize(rewriter, loc, lvl);
        src = desc.getMemRefField(fid);
        dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
        retMem.push_back(dst);
        // Retrieves the corresponding level length type.
        Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
        retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
      }
      Value flatOut = dst;
      if (dst.getType().getRank() != 1) {
        auto reassoc = getReassociationForFlattening(dst.getType());
        flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
      }
      Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
      Value srcMem = genSliceToSize(rewriter, loc, src, sz);
      rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
      return true;
    });

    // Converts MemRefs back to Tensors.
    SmallVector<Value> retValues = llvm::to_vector(
        llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
          return rewriter.create<bufferization::ToTensorOp>(loc, v);
        }));
    // Appends the actual memory length used in each buffer returned.
    retValues.append(retLen.begin(), retLen.end());
    rewriter.replaceOp(op, retValues);
    return success();
  }
};
struct SparseNewConverter : public OpConversionPattern<NewOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(NewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op.getResult());
    // Creating COO with NewOp is handled by direct IR codegen. All other cases
    // are handled by rewriting.
    if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
      return failure();

    // Implement as follows:
    //   %reader = @createCheckedSparseTensorReader(%filename)
    //   %nse = @getSparseTensorNSE(%reader)
    //   %coo = bufferization.alloc_tensor an ordered COO with
    //          dst dim ordering, size_hint = %nse
    //   %coordinates = sparse_tensor.coordinates_buffer(%coo)
    //   %values = sparse_tensor.values(%coo)
    //   %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
    //   if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
    //   update storage specifier
    //   @delSparseTensorReader(%reader)
    SmallVector<Value> dimShapesValues;
    Value dimSizesBuffer;
    Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
                             dimShapesValues, dimSizesBuffer);

    // Get the number of stored entries.
    const Type indexTp = rewriter.getIndexType();
    Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
                               {indexTp}, {reader}, EmitCInterface::Off)
                    .getResult(0);

    // Construct allocation for each field.
    SmallVector<Value> dynSizes;
    if (dstTp.hasDynamicDimShape()) {
      for (const auto &d : llvm::enumerate(dstTp.getDimShape()))
        if (ShapedType::isDynamic(d.value()))
          dynSizes.push_back(rewriter.create<memref::LoadOp>(
              loc, dimSizesBuffer, constantIndex(rewriter, loc, d.index())));
    }
    SmallVector<Value> fields;
    createAllocFields(rewriter, loc, dstTp, dynSizes, /*enableInit=*/false,
                      fields, nse);
    MutSparseTensorDescriptor desc(dstTp, fields);

    // Now construct the dim2lvl and lvl2dim buffers.
    Value dim2lvlBuffer;
    Value lvl2dimBuffer;
    genMapBuffers(rewriter, loc, dstTp, dimShapesValues, dimSizesBuffer,
                  dim2lvlBuffer, lvl2dimBuffer);

    // Read the COO tensor data.
    Value xs = desc.getAOSMemRef();
    Value ys = desc.getValMemRef();
    const Type boolTp = rewriter.getIntegerType(1);
    const Type elemTp = dstTp.getElementType();
    const Type crdTp = dstTp.getCrdType();
    SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
                                          overheadTypeFunctionSuffix(crdTp),
                                          primaryTypeFunctionSuffix(elemTp)};
    Value isSorted =
        createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
                       {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
                       EmitCInterface::On)
            .getResult(0);

    // If the destination tensor is a sorted COO, we need to sort the COO
    // tensor data if the input elements aren't sorted yet.
    const Level lvlRank = dstTp.getLvlRank();
    if (dstTp.isOrderedLvl(lvlRank - 1)) {
      Value kFalse = constantI1(rewriter, loc, false);
      Value notSorted = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, isSorted, kFalse);
      scf::IfOp ifOp =
          rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
      rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
                              rewriter.getIndexAttr(0),
                              SparseTensorSortKind::HybridQuickSort);
      rewriter.setInsertionPointAfter(ifOp);
    }

    // Set PosMemRef0[1] = nse.
    const Value c1 = constantIndex(rewriter, loc, 1);
    const Value posMemref0 = desc.getPosMemRef(0);
    const Type posTp = dstTp.getPosType();
    const Value posNse = genCast(rewriter, loc, nse, posTp);
    rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);

    // Update storage specifier.
    Value coordinatesSize = rewriter.create<arith::MulIOp>(
        loc, nse, constantIndex(rewriter, loc, lvlRank));
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
                           coordinatesSize);
    desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
                           std::nullopt, nse);

    // Release the sparse tensor reader.
    createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                   EmitCInterface::Off);

    // Replace operation with resulting memrefs.
    rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
    return success();
  }
};
//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void mlir::populateSparseTensorCodegenPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    bool createSparseDeallocs, bool enableBufferInitialization) {
  patterns.add<SparseAssembleOpConverter, SparseDisassembleOpConverter,
               SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
               SparseCastConverter, SparseExtractSliceConverter,
               SparseTensorLoadConverter, SparseExpandConverter,
               SparseCompressConverter, SparseInsertConverter,
               SparseReorderCOOConverter, SparseReMapConverter,
               SparseSliceGetterOpConverter<ToSliceOffsetOp,
                                            StorageSpecifierKind::DimOffset>,
               SparseSliceGetterOpConverter<ToSliceStrideOp,
                                            StorageSpecifierKind::DimStride>,
               SparseToPositionsConverter, SparseToCoordinatesConverter,
               SparseToCoordinatesBufferConverter, SparseToValuesConverter,
               SparseConvertConverter, SparseNewConverter,
               SparseNumberOfEntriesConverter>(typeConverter,
                                               patterns.getContext());
  patterns.add<SparseTensorDeallocConverter>(
      typeConverter, patterns.getContext(), createSparseDeallocs);
  patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
      typeConverter, patterns.getContext(), enableBufferInitialization);
}