[mlir][sparse] simplify some header code (#70989)
[llvm-project.git] / mlir / lib / Dialect / SparseTensor / Transforms / SparseTensorCodegen.cpp
blob8c6312150f4c832e0565e9ddfecc49997fa7376d
1 //===- SparseTensorCodegen.cpp - Sparse tensor primitives conversion ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A pass that converts sparse tensor types and primitives to actual compiler
10 // visible buffers and actual compiler IR that implements these primitives on
11 // the selected sparse tensor storage schemes. This pass provides an alternative
12 // to the SparseTensorConversion pass, eliminating the dependence on a runtime
13 // support library, and providing much more opportunities for subsequent
14 // compiler optimization of the generated code.
16 //===----------------------------------------------------------------------===//
18 #include "CodegenUtils.h"
19 #include "SparseTensorDescriptor.h"
21 #include "llvm/Support/FormatVariadic.h"
23 #include "mlir/Dialect/Arith/Utils/Utils.h"
24 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"
26 #include "mlir/Dialect/Linalg/Utils/Utils.h"
27 #include "mlir/Dialect/MemRef/IR/MemRef.h"
28 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
29 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
30 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
31 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
32 #include "mlir/Dialect/Tensor/IR/Tensor.h"
33 #include "mlir/Transforms/DialectConversion.h"
35 #include <optional>
37 using namespace mlir;
38 using namespace mlir::sparse_tensor;
40 namespace {
42 using FuncGeneratorType =
43 function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, RankedTensorType)>;
45 //===----------------------------------------------------------------------===//
46 // Helper methods.
47 //===----------------------------------------------------------------------===//
49 /// Flatten a list of operands that may contain sparse tensors.
50 static void flattenOperands(ValueRange operands,
51 SmallVectorImpl<Value> &flattened) {
52 // In case of
53 // sparse_tensor, c, sparse_tensor
54 // ==>
55 // memref ..., c, memref ...
56 for (auto operand : operands) {
57 if (getSparseTensorEncoding(operand.getType())) {
58 auto tuple = getTuple(operand);
59 // An unrealized_conversion_cast will be inserted by type converter to
60 // inter-mix the gap between 1:N conversion between sparse tensors and
61 // fields. In this case, take the operands in the cast and replace the
62 // sparse tensor output with the flattened type array.
63 flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
64 } else {
65 flattened.push_back(operand);
70 /// Generates a load with proper `index` typing.
71 static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
72 idx = genCast(builder, loc, idx, builder.getIndexType());
73 return builder.create<memref::LoadOp>(loc, mem, idx);
76 /// Generates a store with proper `index` typing and proper value.
77 static void genStore(OpBuilder &builder, Location loc, Value val, Value mem,
78 Value idx) {
79 idx = genCast(builder, loc, idx, builder.getIndexType());
80 val = genCast(builder, loc, val,
81 cast<ShapedType>(mem.getType()).getElementType());
82 builder.create<memref::StoreOp>(loc, val, mem, idx);
85 /// Creates a straightforward counting for-loop.
86 static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
87 MutableArrayRef<Value> fields,
88 Value lower = Value()) {
89 Type indexType = builder.getIndexType();
90 if (!lower)
91 lower = constantZero(builder, loc, indexType);
92 Value one = constantOne(builder, loc, indexType);
93 scf::ForOp forOp = builder.create<scf::ForOp>(loc, lower, upper, one, fields);
94 for (unsigned i = 0, e = fields.size(); i < e; i++)
95 fields[i] = forOp.getRegionIterArg(i);
96 builder.setInsertionPointToStart(forOp.getBody());
97 return forOp;
100 static void createPushback(OpBuilder &builder, Location loc,
101 MutSparseTensorDescriptor desc,
102 SparseTensorFieldKind kind, std::optional<Level> lvl,
103 Value value, Value repeat = Value()) {
104 Type etp = desc.getMemRefElementType(kind, lvl);
105 Value field = desc.getMemRefField(kind, lvl);
106 StorageSpecifierKind specFieldKind = toSpecifierKind(kind);
108 auto pushBackOp = builder.create<PushBackOp>(
109 loc, desc.getSpecifierField(builder, loc, specFieldKind, lvl), field,
110 genCast(builder, loc, value, etp), repeat);
112 desc.setMemRefField(kind, lvl, pushBackOp.getOutBuffer());
113 desc.setSpecifierField(builder, loc, specFieldKind, lvl,
114 pushBackOp.getNewSize());
117 /// Generates code that allocates a sparse storage scheme for given rank.
118 static void allocSchemeForRank(OpBuilder &builder, Location loc,
119 MutSparseTensorDescriptor desc, Level startLvl) {
120 const SparseTensorType stt(desc.getRankedTensorType());
121 Value linear = constantIndex(builder, loc, 1);
122 const Level lvlRank = stt.getLvlRank();
123 for (Level l = startLvl; l < lvlRank; l++) {
124 const auto dlt = stt.getLvlType(l);
125 if (isCompressedDLT(dlt)) {
126 // Append linear x positions, initialized to zero. Since each compressed
127 // dimension initially already has a single zero entry, this maintains
128 // the desired "linear + 1" length property at all times.
129 Value posZero = constantZero(builder, loc, stt.getPosType());
130 createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, l,
131 posZero, linear);
132 return;
134 if (isSingletonDLT(dlt)) {
135 return; // nothing to do
137 // Keep compounding the size, but nothing needs to be initialized
138 // at this level. We will eventually reach a compressed level or
139 // otherwise the values array for the from-here "all-dense" case.
140 assert(isDenseDLT(dlt));
141 Value size = desc.getLvlSize(builder, loc, l);
142 linear = builder.create<arith::MulIOp>(loc, linear, size);
144 // Reached values array so prepare for an insertion.
145 Value valZero = constantZero(builder, loc, stt.getElementType());
146 createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
147 std::nullopt, valZero, linear);
150 /// Creates allocation operation.
151 static Value createAllocation(OpBuilder &builder, Location loc,
152 MemRefType memRefType, Value sz,
153 bool enableInit) {
154 Value buffer = builder.create<memref::AllocOp>(loc, memRefType, sz);
155 Type elemType = memRefType.getElementType();
156 if (enableInit) {
157 Value fillValue = constantZero(builder, loc, elemType);
158 builder.create<linalg::FillOp>(loc, fillValue, buffer);
160 return buffer;
163 /// Creates allocation for each field in sparse tensor type. Note that
164 /// for all dynamic memrefs, the memory size is really the capacity of
165 /// the "vector", while the actual size resides in the sizes array.
167 /// TODO: for efficiency, we will need heuristics to make educated guesses
168 /// on the required capacities (see heuristic variable).
170 static void createAllocFields(OpBuilder &builder, Location loc,
171 SparseTensorType stt, ValueRange dynSizes,
172 bool enableInit, SmallVectorImpl<Value> &fields,
173 Value sizeHint) {
174 // Build original sizes.
175 assert((dynSizes.size() == static_cast<size_t>(stt.getNumDynamicDims())) &&
176 "Got wrong number of dynamic sizes");
177 const Dimension dimRank = stt.getDimRank();
178 SmallVector<Value> dimSizes;
179 dimSizes.reserve(dimRank);
180 unsigned i = 0; // cumulative index into `dynSizes`.
181 for (const Size sh : stt.getDimShape())
182 dimSizes.push_back(ShapedType::isDynamic(sh)
183 ? dynSizes[i++]
184 : constantIndex(builder, loc, sh));
186 // Set up some heuristic sizes. We try to set the initial
187 // size based on available information. Otherwise we just
188 // initialize a few elements to start the reallocation chain.
189 // TODO: refine this
190 Value posHeuristic, crdHeuristic, valHeuristic;
191 if (stt.isAllDense()) {
192 valHeuristic = dimSizes[0];
193 for (const Value sz : ArrayRef<Value>{dimSizes}.drop_front())
194 valHeuristic = builder.create<arith::MulIOp>(loc, valHeuristic, sz);
195 } else if (sizeHint) {
196 if (getCOOStart(stt.getEncoding()) == 0) {
197 posHeuristic = constantIndex(builder, loc, 2);
198 crdHeuristic = builder.create<arith::MulIOp>(
199 loc, constantIndex(builder, loc, dimRank), sizeHint); // AOS
200 } else if (dimRank == 2 && stt.isDenseLvl(0) && stt.isCompressedLvl(1)) {
201 posHeuristic = builder.create<arith::AddIOp>(
202 loc, sizeHint, constantIndex(builder, loc, 1));
203 crdHeuristic = sizeHint;
204 } else {
205 posHeuristic = crdHeuristic = constantIndex(builder, loc, 16);
207 valHeuristic = sizeHint;
208 } else {
209 posHeuristic = crdHeuristic = valHeuristic =
210 constantIndex(builder, loc, 16);
213 foreachFieldAndTypeInSparseTensor(
214 stt,
215 [&builder, &fields, stt, loc, posHeuristic, crdHeuristic, valHeuristic,
216 enableInit](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
217 Level /*lvl*/, DimLevelType /*dlt*/) -> bool {
218 assert(fields.size() == fIdx);
219 Value field;
220 switch (fKind) {
221 case SparseTensorFieldKind::StorageSpec:
222 field = SparseTensorSpecifier::getInitValue(builder, loc, stt);
223 break;
224 case SparseTensorFieldKind::PosMemRef:
225 case SparseTensorFieldKind::CrdMemRef:
226 case SparseTensorFieldKind::ValMemRef:
227 field = createAllocation(
228 builder, loc, cast<MemRefType>(fType),
229 (fKind == SparseTensorFieldKind::PosMemRef) ? posHeuristic
230 : (fKind == SparseTensorFieldKind::CrdMemRef) ? crdHeuristic
231 : valHeuristic,
232 enableInit);
233 break;
235 assert(field);
236 fields.push_back(field);
237 // Returns true to continue the iteration.
238 return true;
241 MutSparseTensorDescriptor desc(stt, fields);
243 // Initialize the storage scheme to an empty tensor. Initialized memSizes
244 // to all zeros, sets the dimSizes to known values and gives all position
245 // fields an initial zero entry, so that it is easier to maintain the
246 // "linear + 1" length property.
247 Value posZero = constantZero(builder, loc, stt.getPosType());
248 for (Level lvlRank = stt.getLvlRank(), l = 0; l < lvlRank; l++) {
249 // Fills dim sizes array.
250 // FIXME: `toOrigDim` is deprecated.
251 desc.setLvlSize(builder, loc, l, dimSizes[toOrigDim(stt, l)]);
252 // Pushes a leading zero to positions memref.
253 if (stt.isCompressedLvl(l))
254 createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, l,
255 posZero);
257 allocSchemeForRank(builder, loc, desc, /*rank=*/0);
260 /// Helper method that generates block specific to compressed case:
262 /// // given: parentPos = posCursor[lvl-1]
263 /// pstart = desc.positions[lvl][parentPos]
264 /// pstop = desc.positions[lvl][parentPos+1]
265 /// plast = pstop - 1
266 /// msz = desc.coordinates[lvl].size()
267 /// if (pstart < pstop) {
268 /// isPresent = (desc.coordinates[lvl][plast] == lvlCoords[lvl])
269 /// } else { // first insertion
270 /// isPresent = false
271 /// desc.positions[lvl][parentPos] = msz
272 /// }
273 /// if (isPresent) { // coordinate is already present
274 /// pnext = plast
275 /// } else {
276 /// desc.coordinates[lvl].push_back(lvlCoords[lvl])
277 /// desc.positions[lvl][parentPos+1] = msz+1
278 /// pnext = msz
279 /// <prepare level lvl+1>
280 /// }
281 /// posCursor[lvl] = pnext
282 static Value genCompressed(OpBuilder &builder, Location loc,
283 MutSparseTensorDescriptor desc, ValueRange lvlCoords,
284 Value /*unused*/, Value parentPos, Level lvl) {
285 const SparseTensorType stt(desc.getRankedTensorType());
286 const Level lvlRank = stt.getLvlRank();
287 assert(lvl < lvlRank && "Level is out of bounds");
288 assert(lvlCoords.size() == static_cast<size_t>(lvlRank) &&
289 "Level-rank mismatch");
290 SmallVector<Type> types;
291 Type indexType = builder.getIndexType();
292 Type boolType = builder.getIntegerType(1);
293 unsigned crdFidx;
294 unsigned crdStride;
295 std::tie(crdFidx, crdStride) = desc.getCrdMemRefIndexAndStride(lvl);
296 const Value one = constantIndex(builder, loc, 1);
297 const Value pp1 = builder.create<arith::AddIOp>(loc, parentPos, one);
298 const Value positionsAtLvl = desc.getPosMemRef(lvl);
299 const Value pstart = genLoad(builder, loc, positionsAtLvl, parentPos);
300 const Value pstop = genLoad(builder, loc, positionsAtLvl, pp1);
301 const Value crdMsz = desc.getCrdMemSize(builder, loc, lvl);
302 const Value crdStrideC =
303 crdStride > 1 ? constantIndex(builder, loc, crdStride) : Value();
304 const Value msz =
305 crdStrideC ? builder.create<arith::DivUIOp>(loc, crdMsz, crdStrideC)
306 : crdMsz;
307 const Value plast = builder.create<arith::SubIOp>(
308 loc, genCast(builder, loc, pstop, indexType), one);
309 // Conditional expression.
310 Value lt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
311 pstart, pstop);
312 types.push_back(boolType);
313 scf::IfOp ifOp1 = builder.create<scf::IfOp>(loc, types, lt, /*else*/ true);
314 types.pop_back();
315 builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
316 Value crd =
317 genLoad(builder, loc, desc.getMemRefField(crdFidx),
318 crdStrideC ? builder.create<arith::MulIOp>(loc, plast, crdStrideC)
319 : plast);
320 Value eq = builder.create<arith::CmpIOp>(
321 loc, arith::CmpIPredicate::eq, genCast(builder, loc, crd, indexType),
322 lvlCoords[lvl]);
323 builder.create<scf::YieldOp>(loc, eq);
324 builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
325 if (lvl > 0)
326 genStore(builder, loc, msz, positionsAtLvl, parentPos);
327 builder.create<scf::YieldOp>(loc, constantI1(builder, loc, false));
328 builder.setInsertionPointAfter(ifOp1);
329 // If present construct. Note that for a non-unique dimension level, we
330 // simply set the condition to false and rely on CSE/DCE to clean up the IR.
332 // TODO: generate less temporary IR?
334 for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
335 types.push_back(desc.getField(i).getType());
336 types.push_back(indexType);
337 const Value p = stt.isUniqueLvl(lvl) ? ifOp1.getResult(0)
338 : constantI1(builder, loc, false);
339 scf::IfOp ifOp2 = builder.create<scf::IfOp>(loc, types, p, /*else*/ true);
340 // If present (fields unaffected, update pnext to plast).
341 builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
343 // FIXME: This does not looks like a clean way, but probably the most
344 // efficient way.
345 desc.getFields().push_back(plast);
346 builder.create<scf::YieldOp>(loc, desc.getFields());
347 desc.getFields().pop_back();
349 // If !present (changes fields, update pnext).
350 builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
351 Value mszp1 = builder.create<arith::AddIOp>(loc, msz, one);
352 genStore(builder, loc, mszp1, positionsAtLvl, pp1);
353 createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, lvl,
354 lvlCoords[lvl]);
355 // Prepare the next level "as needed".
356 if ((lvl + 1) < lvlRank)
357 allocSchemeForRank(builder, loc, desc, lvl + 1);
359 desc.getFields().push_back(msz);
360 builder.create<scf::YieldOp>(loc, desc.getFields());
361 desc.getFields().pop_back();
363 // Update fields and return next pos.
364 builder.setInsertionPointAfter(ifOp2);
365 unsigned o = 0;
366 for (unsigned i = 0, e = desc.getNumFields(); i < e; i++)
367 desc.setField(i, ifOp2.getResult(o++));
368 return ifOp2.getResult(o);
371 /// Helper class to help lowering sparse_tensor.insert operation.
372 class SparseInsertGenerator
373 : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
374 public:
375 SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
376 bool genCall)
377 : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};
379 /// Generates code along an insertion path without the need for a "cursor".
380 /// This current insertion strategy comes at the expense of some testing
381 /// overhead for each insertion. The strategy will be optimized later for
382 /// common insertion patterns. The current insertion strategy also assumes
383 /// insertions occur in "a reasonable order" that enables building the
384 /// storage scheme in an appending/inserting kind of fashion (i.e. no
385 /// in-between insertions that need data movement). The implementation
386 /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
388 /// TODO: better unord/not-unique; also generalize, optimize, specialize!
389 SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
390 OpBuilder &builder, Location loc) {
391 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
392 const Level lvlRank = stt.getLvlRank();
393 // Extract fields and coordinates from args.
394 SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
395 MutSparseTensorDescriptor desc(stt, fields);
396 const SmallVector<Value> coords =
397 llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
398 Value value = args.back();
399 Value parentPos = constantZero(builder, loc, builder.getIndexType());
400 // Generate code for every level.
401 for (Level l = 0; l < lvlRank; l++) {
402 const auto dlt = stt.getLvlType(l);
403 if (isCompressedDLT(dlt)) {
404 // Create:
405 // if (!present) {
406 // coordinates[l].push_back(coords[l])
407 // <update positions and prepare level l + 1>
408 // }
409 // positions[l] = coordinates.size() - 1
410 // <insert @ positions[l] at next level l + 1>
411 parentPos =
412 genCompressed(builder, loc, desc, coords, value, parentPos, l);
413 } else if (isSingletonDLT(dlt)) {
414 // Create:
415 // coordinates[l].push_back(coords[l])
416 // positions[l] = positions[l-1]
417 // <insert @ positions[l] at next level l + 1>
418 createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, l,
419 coords[l]);
420 } else {
421 assert(isDenseDLT(dlt));
422 // Construct the new position as:
423 // positions[l] = size * positions[l-1] + coords[l]
424 // <insert @ positions[l] at next level l + 1>
425 Value size = desc.getLvlSize(builder, loc, l);
426 Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
427 parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
430 // Reached the actual value append/insert.
431 if (!stt.isDenseLvl(lvlRank - 1))
432 createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
433 std::nullopt, value);
434 else
435 genStore(builder, loc, value, desc.getValMemRef(), parentPos);
436 return fields;
439 std::string getMangledFuncName() {
440 // The mangled name of the function has this format:
441 // <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
442 constexpr const char kInsertFuncNamePrefix[] = "_insert_";
443 const SparseTensorType stt(llvm::cast<RankedTensorType>(rtp));
445 SmallString<32> nameBuffer;
446 llvm::raw_svector_ostream nameOstream(nameBuffer);
447 nameOstream << kInsertFuncNamePrefix;
448 const Level lvlRank = stt.getLvlRank();
449 for (Level l = 0; l < lvlRank; l++) {
450 std::string lvlType = toMLIRString(stt.getLvlType(l));
451 // Replace/remove punctuations in level properties.
452 std::replace_if(
453 lvlType.begin(), lvlType.end(),
454 [](char c) { return c == '(' || c == ','; }, '_');
455 llvm::erase_if(lvlType, [](char c) { return c == ')' || c == ' '; });
456 nameOstream << lvlType << "_";
458 // Static dim sizes are used in the generated code while dynamic sizes are
459 // loaded from the dimSizes buffer. This is the reason for adding the shape
460 // to the function name.
461 for (const auto sh : stt.getDimShape())
462 nameOstream << sh << "_";
463 // Permutation information is also used in generating insertion.
464 if (!stt.isIdentity())
465 nameOstream << stt.getDimToLvl() << "_";
466 nameOstream << stt.getElementType() << "_";
467 nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
468 return nameOstream.str().str();
471 private:
472 TensorType rtp;
475 /// Generations insertion finalization code.
476 static void genEndInsert(OpBuilder &builder, Location loc,
477 SparseTensorDescriptor desc) {
478 const SparseTensorType stt(desc.getRankedTensorType());
479 const Level lvlRank = stt.getLvlRank();
480 for (Level l = 0; l < lvlRank; l++) {
481 const auto dlt = stt.getLvlType(l);
482 if (isLooseCompressedDLT(dlt))
483 llvm_unreachable("TODO: Not yet implemented");
484 if (isCompressedDLT(dlt)) {
485 // Compressed dimensions need a position cleanup for all entries
486 // that were not visited during the insertion pass.
488 // TODO: avoid cleanup and keep compressed scheme consistent at all
489 // times?
491 if (l > 0) {
492 Type posType = stt.getPosType();
493 Value posMemRef = desc.getPosMemRef(l);
494 Value hi = desc.getPosMemSize(builder, loc, l);
495 Value zero = constantIndex(builder, loc, 0);
496 Value one = constantIndex(builder, loc, 1);
497 // Vector of only one, but needed by createFor's prototype.
498 SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
499 scf::ForOp loop = createFor(builder, loc, hi, inits, one);
500 Value i = loop.getInductionVar();
501 Value oldv = loop.getRegionIterArg(0);
502 Value newv = genLoad(builder, loc, posMemRef, i);
503 Value posZero = constantZero(builder, loc, posType);
504 Value cond = builder.create<arith::CmpIOp>(
505 loc, arith::CmpIPredicate::eq, newv, posZero);
506 scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
507 cond, /*else*/ true);
508 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
509 genStore(builder, loc, oldv, posMemRef, i);
510 builder.create<scf::YieldOp>(loc, oldv);
511 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
512 builder.create<scf::YieldOp>(loc, newv);
513 builder.setInsertionPointAfter(ifOp);
514 builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
515 builder.setInsertionPointAfter(loop);
517 } else {
518 assert(isDenseDLT(dlt) || isSingletonDLT(dlt));
523 static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
524 Value sz) {
525 auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
526 return builder
527 .create<memref::SubViewOp>(
528 loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
529 ValueRange{}, ValueRange{sz}, ValueRange{},
530 ArrayRef<int64_t>{0}, // static offset
531 ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
532 ArrayRef<int64_t>{1}) // static stride
533 .getResult();
536 static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
537 ReassociationIndices reassociation;
538 for (int i = 0, e = srcTp.getRank(); i < e; i++)
539 reassociation.push_back(i);
540 return reassociation;
543 static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
544 Type dstTp) {
545 if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
546 // Scalars can only be converted to 0-ranked tensors.
547 if (rtp.getRank() != 0)
548 return nullptr;
549 elem = genCast(builder, loc, elem, rtp.getElementType());
550 return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
552 return genCast(builder, loc, elem, dstTp);
555 //===----------------------------------------------------------------------===//
556 // Codegen rules.
557 //===----------------------------------------------------------------------===//
559 /// Sparse tensor storage conversion rule for returns.
560 class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
561 public:
562 using OpConversionPattern::OpConversionPattern;
563 LogicalResult
564 matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
565 ConversionPatternRewriter &rewriter) const override {
566 SmallVector<Value> flattened;
567 flattenOperands(adaptor.getOperands(), flattened);
568 // Create a return with the flattened value extracted from sparse tensors.
569 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
570 return success();
574 /// Sparse tensor storage conversion rule for calls.
575 class SparseCallConverter : public OpConversionPattern<func::CallOp> {
576 public:
577 // The default CallOp converter can not handle 1:N type conversion.
578 using OpConversionPattern::OpConversionPattern;
579 LogicalResult
580 matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
581 ConversionPatternRewriter &rewriter) const override {
582 Location loc = op.getLoc();
583 // In case of:
584 // sparse_tensor, f, sparse_tensor = call @foo(...)
585 // ==>
586 // memref..., f, memref = call @foo(...) replace with
587 // cast(memref...)->sparse_tensor, f, cast(memref...)->sparse_tensor
588 SmallVector<Type> finalRetTy;
589 if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
590 return failure();
592 // (1) Generates new call with flattened return value.
593 SmallVector<Value> flattened;
594 flattenOperands(adaptor.getOperands(), flattened);
595 auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
596 finalRetTy, flattened);
597 // (2) Create cast operation for sparse tensor returns.
598 SmallVector<Value> castedRet;
599 // Tracks the offset of current return value (of the original call)
600 // relative to the new call (after sparse tensor flattening);
601 unsigned retOffset = 0;
602 // Temporal buffer to hold the flattened list of type for
603 // a sparse tensor.
604 SmallVector<Type> sparseFlat;
605 for (auto ret : op.getResults()) {
606 assert(retOffset < newCall.getNumResults());
607 auto retType = ret.getType();
608 if (failed(typeConverter->convertType(retType, sparseFlat)))
609 // This should never happen.
610 llvm_unreachable("Failed to convert type in sparse tensor codegen");
612 // Converted types can not be empty when the type conversion succeed.
613 assert(!sparseFlat.empty());
614 if (sparseFlat.size() > 1) {
615 auto flatSize = sparseFlat.size();
616 ValueRange fields(iterator_range<ResultRange::iterator>(
617 newCall.result_begin() + retOffset,
618 newCall.result_begin() + retOffset + flatSize));
619 castedRet.push_back(genTuple(rewriter, loc, retType, fields));
620 retOffset += flatSize;
621 } else {
622 // If this is an 1:1 conversion, no need for casting.
623 castedRet.push_back(newCall.getResult(retOffset));
624 retOffset++;
626 sparseFlat.clear();
629 assert(castedRet.size() == op.getNumResults());
630 rewriter.replaceOp(op, castedRet);
631 return success();
635 /// Sparse codegen rule for level accesses.
636 class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
637 public:
638 using OpConversionPattern::OpConversionPattern;
639 LogicalResult
640 matchAndRewrite(LvlOp op, OpAdaptor adaptor,
641 ConversionPatternRewriter &rewriter) const override {
642 std::optional<int64_t> lvl = op.getConstantLvlIndex();
643 if (!lvl || !getSparseTensorEncoding(adaptor.getSource().getType()))
644 return failure();
646 auto desc = getDescriptorFromTensorTuple(adaptor.getSource());
647 auto sz = desc.getLvlSize(rewriter, op.getLoc(), *lvl);
649 rewriter.replaceOp(op, sz);
650 return success();
654 // TODO: use a new SortCOO operation here instead of reusing convert op.
655 struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
656 using OpConversionPattern::OpConversionPattern;
657 LogicalResult
658 matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
659 ConversionPatternRewriter &rewriter) const override {
660 Location loc = op.getLoc();
661 MLIRContext *ctx = op.getContext();
663 SparseTensorType srcStt = getSparseTensorType(op.getInputCoo());
664 SparseTensorType dstStt = getSparseTensorType(op.getResultCoo());
666 // Should have been verified.
667 assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
668 isUniqueCOOType(srcStt.getRankedTensorType()) &&
669 isUniqueCOOType(dstStt.getRankedTensorType()));
670 assert(dstStt.hasSameDimToLvl(srcStt));
672 // We don't need a mutable descriptor here as we perform sorting in-place.
673 auto nnz = genValMemSize(rewriter, op.getLoc(), adaptor.getInputCoo());
674 auto desc = getDescriptorFromTensorTuple(adaptor.getInputCoo());
675 auto crd = desc.getAOSMemRef();
676 auto val = desc.getValMemRef();
678 // Otherwise we need another data shuffle and a non-identity map.
679 assert(dstStt.hasSameDimToLvl(srcStt));
680 (void)dstStt; // to silence warning when assertion is disabled
682 auto id = AffineMap::getMultiDimIdentityMap(srcStt.getLvlRank(), ctx);
684 rewriter.create<SortOp>(loc, nnz, crd, ValueRange{val}, id,
685 rewriter.getIndexAttr(0), op.getAlgorithm());
687 // Since we do in-place sorting, the destinate tensor will have the same set
688 // of memrefs as the source tensor.
689 rewriter.replaceOp(op, adaptor.getInputCoo());
690 return success();
694 template <typename Op, StorageSpecifierKind kind>
695 class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
696 public:
697 using OpConversionPattern<Op>::OpConversionPattern;
698 LogicalResult
699 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
700 ConversionPatternRewriter &rewriter) const override {
701 // Simply lowers to specifer.get <field> operation.
702 auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
703 auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
704 op.getDim().getZExtValue());
706 rewriter.replaceOp(op, v);
707 return success();
711 /// Sparse codegen rule for trivial tensor casts.
712 class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
713 public:
714 using OpConversionPattern::OpConversionPattern;
715 LogicalResult
716 matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
717 ConversionPatternRewriter &rewriter) const override {
718 // Only rewrite identically annotated source/dest.
719 auto encDst = getSparseTensorEncoding(op.getType());
720 auto encSrc = getSparseTensorEncoding(op.getSource().getType());
721 if (!encDst || encDst != encSrc)
722 return failure();
723 rewriter.replaceOp(op, adaptor.getOperands());
724 return success();
728 class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
729 public:
730 using OpConversionPattern::OpConversionPattern;
731 LogicalResult
732 matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
733 ConversionPatternRewriter &rewriter) const override {
734 // Simply fold the operation.
735 rewriter.replaceOp(op, adaptor.getSource());
736 return success();
740 /// Sparse codegen rule for the alloc operator.
741 /// TODO(springerm): remove when bufferization.alloc_tensor is gone
742 class SparseTensorAllocConverter
743 : public OpConversionPattern<bufferization::AllocTensorOp> {
744 public:
745 using OpConversionPattern::OpConversionPattern;
746 SparseTensorAllocConverter(TypeConverter &typeConverter, MLIRContext *context,
747 bool enableInit)
748 : OpConversionPattern(typeConverter, context),
749 enableBufferInitialization(enableInit) {}
751 LogicalResult
752 matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
753 ConversionPatternRewriter &rewriter) const override {
754 const auto resType = getSparseTensorType(op);
755 if (!resType.hasEncoding())
756 return failure();
758 // Construct allocation for each field.
759 const Location loc = op.getLoc();
760 if (op.getCopy()) {
761 auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
762 SmallVector<Value> fields;
763 fields.reserve(desc.getNumFields());
764 // Memcpy on memref fields.
765 for (auto field : desc.getMemRefFields()) {
766 auto memrefTp = cast<MemRefType>(field.getType());
767 auto size = rewriter.create<memref::DimOp>(loc, field, 0);
768 auto copied =
769 rewriter.create<memref::AllocOp>(loc, memrefTp, ValueRange{size});
770 rewriter.create<memref::CopyOp>(loc, field, copied);
771 fields.push_back(copied);
773 // Reuses specifier.
774 fields.push_back(desc.getSpecifier());
775 assert(fields.size() == desc.getNumFields());
776 rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
777 return success();
780 const Value sizeHint = op.getSizeHint();
781 const ValueRange dynSizes = adaptor.getDynamicSizes();
782 const size_t found = dynSizes.size();
783 const int64_t expected = resType.getNumDynamicDims();
784 if (found != static_cast<size_t>(expected))
785 return rewriter.notifyMatchFailure(
786 op, llvm::formatv(
787 "Got wrong number of dynamic sizes: Found={0}, Expected={1}",
788 found, expected));
789 SmallVector<Value> fields;
790 createAllocFields(rewriter, loc, resType, dynSizes,
791 enableBufferInitialization, fields, sizeHint);
792 // Replace operation with resulting memrefs.
793 rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
794 return success();
797 private:
798 bool enableBufferInitialization;
801 /// Sparse codegen rule for the empty tensor operator.
802 /// TODO(springerm): remove when bufferization.alloc_tensor is gone
803 class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
804 public:
805 using OpConversionPattern::OpConversionPattern;
806 SparseTensorEmptyConverter(TypeConverter &typeConverter, MLIRContext *context,
807 bool enableInit)
808 : OpConversionPattern(typeConverter, context),
809 enableBufferInitialization(enableInit) {}
811 LogicalResult
812 matchAndRewrite(tensor::EmptyOp op, OpAdaptor adaptor,
813 ConversionPatternRewriter &rewriter) const override {
814 const auto resType = getSparseTensorType(op);
815 if (!resType.hasEncoding())
816 return failure();
818 // Construct allocation for each field.
819 const Location loc = op.getLoc();
820 const Value sizeHint; // none
821 const ValueRange dynSizes = adaptor.getDynamicSizes();
822 const size_t found = dynSizes.size();
823 const int64_t expected = resType.getNumDynamicDims();
824 if (found != static_cast<size_t>(expected))
825 return rewriter.notifyMatchFailure(
826 op, llvm::formatv(
827 "Got wrong number of dynamic sizes: Found={0}, Expected={1}",
828 found, expected));
829 SmallVector<Value> fields;
830 createAllocFields(rewriter, loc, resType, dynSizes,
831 enableBufferInitialization, fields, sizeHint);
832 // Replace operation with resulting memrefs.
833 rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
834 return success();
837 private:
838 bool enableBufferInitialization;
841 /// Sparse codegen rule for the dealloc operator.
842 class SparseTensorDeallocConverter
843 : public OpConversionPattern<bufferization::DeallocTensorOp> {
844 public:
845 using OpConversionPattern::OpConversionPattern;
846 SparseTensorDeallocConverter(TypeConverter &typeConverter,
847 MLIRContext *context, bool createDeallocs)
848 : OpConversionPattern(typeConverter, context),
849 createDeallocs(createDeallocs) {}
851 LogicalResult
852 matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
853 ConversionPatternRewriter &rewriter) const override {
854 auto enc = getSparseTensorEncoding(op.getTensor().getType());
855 if (!enc)
856 return failure();
858 // If user requests not to deallocate sparse tensors, simply erase the
859 // operation.
860 if (createDeallocs) {
861 // Replace the sparse tensor deallocation with field deallocations.
862 Location loc = op.getLoc();
863 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
864 for (auto input : desc.getMemRefFields())
865 // Deallocate every buffer used to store the sparse tensor handler.
866 rewriter.create<memref::DeallocOp>(loc, input);
868 rewriter.eraseOp(op);
869 return success();
872 private:
873 const bool createDeallocs;
876 /// Sparse codegen rule for tensor rematerialization.
877 class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
878 public:
879 using OpConversionPattern::OpConversionPattern;
880 LogicalResult
881 matchAndRewrite(LoadOp op, OpAdaptor adaptor,
882 ConversionPatternRewriter &rewriter) const override {
883 // Prepare descriptor.
884 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
885 // Generate optional insertion finalization code.
886 if (op.getHasInserts())
887 genEndInsert(rewriter, op.getLoc(), desc);
888 // Replace operation with resulting memrefs.
889 rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
890 return success();
894 /// Sparse codegen rule for the expand op.
895 class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
896 public:
897 using OpConversionPattern::OpConversionPattern;
898 LogicalResult
899 matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
900 ConversionPatternRewriter &rewriter) const override {
901 if (!getSparseTensorEncoding(op.getTensor().getType()))
902 return failure();
903 Location loc = op->getLoc();
904 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
905 const auto srcType = getSparseTensorType(op.getTensor());
906 Type eltType = srcType.getElementType();
907 Type boolType = rewriter.getIntegerType(1);
908 Type idxType = rewriter.getIndexType();
909 // All initialization should be done on entry of the loop nest.
910 rewriter.setInsertionPointAfter(op.getTensor().getDefiningOp());
912 // Determine the size for access expansion (always the innermost stored
913 // level size).
914 const auto sz = desc.getLvlSize(rewriter, loc, srcType.getLvlRank() - 1);
915 // Generate a memref for `sz` elements of type `t`.
916 const auto genAlloc = [&](Type t) {
917 const auto memTp = MemRefType::get({ShapedType::kDynamic}, t);
918 return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
920 // Allocate temporary buffers for values/filled-switch and added.
921 // We do not use stack buffers for this, since the expanded size may
922 // be rather large (as it envelops a single expanded dense dimension).
923 Value values = genAlloc(eltType);
924 Value filled = genAlloc(boolType);
925 Value added = genAlloc(idxType);
926 Value zero = constantZero(rewriter, loc, idxType);
927 // Reset the values/filled-switch to all-zero/false. Note that this
928 // introduces an O(N) operation into the computation, but this reset
929 // operation is amortized over the innermost loops for the access
930 // pattern expansion. As noted in the operation doc, we would like
931 // to amortize this setup cost even between kernels.
932 rewriter.create<linalg::FillOp>(
933 loc, ValueRange{constantZero(rewriter, loc, eltType)},
934 ValueRange{values});
935 rewriter.create<linalg::FillOp>(
936 loc, ValueRange{constantZero(rewriter, loc, boolType)},
937 ValueRange{filled});
938 // Replace expansion op with these buffers and initial coordinate.
939 assert(op.getNumResults() == 4);
940 rewriter.replaceOp(op, {values, filled, added, zero});
941 return success();
945 /// Sparse codegen rule for the compress operator.
946 class SparseCompressConverter : public OpConversionPattern<CompressOp> {
947 public:
948 using OpConversionPattern::OpConversionPattern;
949 LogicalResult
950 matchAndRewrite(CompressOp op, OpAdaptor adaptor,
951 ConversionPatternRewriter &rewriter) const override {
952 Location loc = op->getLoc();
953 SmallVector<Value> fields;
954 auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
955 Value values = adaptor.getValues();
956 Value filled = adaptor.getFilled();
957 Value added = adaptor.getAdded();
958 Value count = adaptor.getCount();
959 const SparseTensorType dstType(desc.getRankedTensorType());
960 Type eltType = dstType.getElementType();
962 // If the innermost level is ordered, we need to sort the coordinates
963 // in the "added" array prior to applying the compression.
964 if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
965 rewriter.create<SortOp>(
966 loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
967 rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
968 // While performing the insertions, we also need to reset the elements
969 // of the values/filled-switch by only iterating over the set elements,
970 // to ensure that the runtime complexity remains proportional to the
971 // sparsity of the expanded access pattern.
973 // Generate
974 // out_memrefs = for (i = 0; i < count; i++)(in_memrefs) {
975 // crd = added[i];
976 // value = values[crd];
977 // insert({lvlCoords, crd}, value);
978 // new_memrefs = insert(in_memrefs, {lvlCoords, crd}, value);
979 // values[crd] = 0;
980 // filled[crd] = false;
981 // yield new_memrefs
982 // }
983 scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
984 Value i = loop.getInductionVar();
986 Value crd = genLoad(rewriter, loc, added, i);
987 Value value = genLoad(rewriter, loc, values, crd);
988 SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
989 SmallVector<Type> flatSpTensorTps = llvm::to_vector(
990 llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
991 params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
992 params.push_back(crd);
993 params.push_back(value);
994 SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
995 params, /*genCall=*/true);
996 SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
997 genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
998 genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
999 rewriter.create<scf::YieldOp>(loc, insertRet);
1001 rewriter.setInsertionPointAfter(loop);
1002 Value result = genTuple(rewriter, loc, dstType, loop->getResults());
1003 // Deallocate the buffers on exit of the full loop nest.
1004 Operation *parent = getTop(op);
1005 rewriter.setInsertionPointAfter(parent);
1006 rewriter.create<memref::DeallocOp>(loc, values);
1007 rewriter.create<memref::DeallocOp>(loc, filled);
1008 rewriter.create<memref::DeallocOp>(loc, added);
1009 // Replace operation with resulting memrefs.
1010 rewriter.replaceOp(op, result);
1011 return success();
1015 /// Sparse codegen rule for the insert operator.
1016 class SparseInsertConverter : public OpConversionPattern<InsertOp> {
1017 public:
1018 using OpConversionPattern::OpConversionPattern;
1019 LogicalResult
1020 matchAndRewrite(InsertOp op, OpAdaptor adaptor,
1021 ConversionPatternRewriter &rewriter) const override {
1022 Location loc = op.getLoc();
1023 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1024 TypeRange flatSpTensorTps = desc.getFields().getTypes();
1025 SmallVector<Value> params = llvm::to_vector(desc.getFields());
1026 params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
1027 params.push_back(adaptor.getValue());
1028 SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1029 params, /*genCall=*/true);
1030 SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
1031 // Replace operation with resulting memrefs.
1032 rewriter.replaceOp(op,
1033 genTuple(rewriter, loc, op.getTensor().getType(), ret));
1034 return success();
1038 /// Sparse codegen rule for position accesses.
1039 class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
1040 public:
1041 using OpAdaptor = typename ToPositionsOp::Adaptor;
1042 using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
1043 LogicalResult
1044 matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
1045 ConversionPatternRewriter &rewriter) const override {
1046 // Replace the requested position access with corresponding field.
1047 // The cast_op is inserted by type converter to intermix 1:N type
1048 // conversion.
1049 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1050 rewriter.replaceOp(op, desc.getPosMemRef(op.getLevel()));
1051 return success();
1055 /// Sparse codegen rule for accessing the coordinates arrays.
1056 class SparseToCoordinatesConverter
1057 : public OpConversionPattern<ToCoordinatesOp> {
1058 public:
1059 using OpAdaptor = typename ToCoordinatesOp::Adaptor;
1060 using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
1061 LogicalResult
1062 matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
1063 ConversionPatternRewriter &rewriter) const override {
1064 // Replace the requested coordinates access with corresponding field.
1065 // The cast_op is inserted by type converter to intermix 1:N type
1066 // conversion.
1067 Location loc = op.getLoc();
1068 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1069 Value field = desc.getCrdMemRefOrView(rewriter, loc, op.getLevel());
1071 // Insert a cast to bridge the actual type to the user expected type. If the
1072 // actual type and the user expected type aren't compatible, the compiler or
1073 // the runtime will issue an error.
1074 Type resType = op.getResult().getType();
1075 if (resType != field.getType())
1076 field = rewriter.create<memref::CastOp>(loc, resType, field);
1077 rewriter.replaceOp(op, field);
1079 return success();
1083 /// Sparse codegen rule for accessing the linear coordinates buffer.
1084 class SparseToCoordinatesBufferConverter
1085 : public OpConversionPattern<ToCoordinatesBufferOp> {
1086 public:
1087 using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
1088 using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
1089 LogicalResult
1090 matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
1091 ConversionPatternRewriter &rewriter) const override {
1092 // Replace the requested coordinates access with corresponding field.
1093 // The cast_op is inserted by type converter to intermix 1:N type
1094 // conversion.
1095 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1096 rewriter.replaceOp(op, desc.getAOSMemRef());
1098 return success();
1102 /// Sparse codegen rule for value accesses.
1103 class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
1104 public:
1105 using OpAdaptor = typename ToValuesOp::Adaptor;
1106 using OpConversionPattern<ToValuesOp>::OpConversionPattern;
1107 LogicalResult
1108 matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
1109 ConversionPatternRewriter &rewriter) const override {
1110 // Replace the requested values access with corresponding field.
1111 // The cast_op is inserted by type converter to intermix 1:N type
1112 // conversion.
1113 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1114 rewriter.replaceOp(op, desc.getValMemRef());
1115 return success();
1119 /// Sparse codegen rule for the convert operator.
1120 class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1121 public:
1122 using OpConversionPattern::OpConversionPattern;
1123 LogicalResult
1124 matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
1125 ConversionPatternRewriter &rewriter) const override {
1126 SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
1127 SparseTensorEncodingAttr encSrc =
1128 getSparseTensorEncoding(op.getSource().getType());
1129 // The output tensor can not be a slice and those cases should have been
1130 // rejected by ConvertOp::verify() already.
1131 assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slices.");
1132 // Different encoding (except for different bitwidth) should be handled by
1133 // rewriting.
1134 // We need further rewrites if the input tensor is a slice too.
1135 if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
1136 encSrc.isSlice()) {
1137 return failure();
1140 Type retElemTp = op.getResult().getType().getElementType();
1141 Type srcElemTp = op.getSource().getType().getElementType();
1142 // Fold the trivial cases.
1143 if (retElemTp == srcElemTp && encDst == encSrc) {
1144 rewriter.replaceOp(op, adaptor.getSource());
1145 return success();
1148 // Do element-wise type conversion without using InsertOp.
1150 // for each memref in srcTensor:
1151 // dst = memref.alloc
1152 // if srcMemRefType != dstMemRefType:
1153 // for every dst[i] = cast(src[i])
1154 // else:
1155 // dst = memref.copy(src)
1156 Location loc = op.getLoc();
1157 auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource());
1158 SmallVector<Value> fields;
1159 foreachFieldAndTypeInSparseTensor(
1160 SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
1161 [&rewriter, &fields, srcDesc,
1162 loc](Type fTp, FieldIndex fIdx, SparseTensorFieldKind fKind, Level lvl,
1163 DimLevelType /*dlt*/) -> bool {
1164 // Simply reuses the storage specifier as it is an SSA value.
1165 if (fKind == SparseTensorFieldKind::StorageSpec) {
1166 fields.push_back(srcDesc.getSpecifier());
1167 } else {
1168 // Allocates new memrefs
1169 Value srcMem = srcDesc.getMemRefField(fIdx);
1170 // TODO: We can instead use the actual memSize in specifier, that
1171 // would require a subViewOp to avoid overflow when copying
1172 // values.
1173 Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0);
1174 auto dstMem = rewriter.create<memref::AllocOp>(
1175 loc, cast<MemRefType>(fTp), sz);
1176 if (fTp != srcMem.getType()) {
1177 // Converts elements type.
1178 scf::buildLoopNest(
1179 rewriter, loc, constantIndex(rewriter, loc, 0), sz,
1180 constantIndex(rewriter, loc, 1),
1181 [srcMem, &dstMem](OpBuilder &builder, Location loc,
1182 ValueRange ivs) {
1183 Value v = builder.create<memref::LoadOp>(loc, srcMem, ivs);
1184 Value casted = genCast(builder, loc, v,
1185 dstMem.getType().getElementType());
1186 builder.create<memref::StoreOp>(loc, casted, dstMem, ivs);
1188 } else {
1189 // TODO: We can even reuse the same memref for the new tensor,
1190 // but that requires a `ref-counting` based memory management
1191 // for shared memrefs between multiple sparse tensors.
1192 rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1194 fields.push_back(dstMem);
1196 return true;
1199 rewriter.replaceOp(
1200 op, genTuple(rewriter, loc, op.getResult().getType(), fields));
1201 return success();
1205 class SparseExtractSliceConverter
1206 : public OpConversionPattern<tensor::ExtractSliceOp> {
1207 public:
1208 using OpConversionPattern::OpConversionPattern;
1209 LogicalResult
1210 matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
1211 ConversionPatternRewriter &rewriter) const override {
1212 Location loc = op.getLoc();
1213 MLIRContext *ctx = op.getContext();
1214 auto srcEnc = getSparseTensorEncoding(op.getSourceType());
1215 auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
1216 // TODO: We should check these in ExtractSliceOp::verify.
1217 if (!srcEnc || !dstEnc || !dstEnc.isSlice())
1218 return failure();
1219 assert(srcEnc.withoutDimSlices() == dstEnc.withoutDimSlices());
1221 SmallVector<Value> fields;
1222 auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
1224 auto newSpec = rewriter.create<StorageSpecifierInitOp>(
1225 loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
1226 desc.setSpecifier(newSpec);
1228 // Fills in slice information.
1229 for (auto [idx, offset, size, stride] : llvm::enumerate(
1230 op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
1231 Dimension dim = idx;
1233 Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
1234 Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1235 Value strideV = getValueOrCreateConstantIndexOp(rewriter, loc, stride);
1236 // TODO: We could probably only set dynamic value here. But it would
1237 // requires us to fill the hole when casting a static slice to dynamic
1238 // slice.
1239 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
1240 dim, offsetV);
1242 // FIXME: we need to distinguish level sizes and dimension size for slices
1243 // here. Maybe we should store slice level sizes in a different array
1244 // instead of reusing it.
1245 assert(srcEnc.isIdentity());
1246 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
1247 sizeV);
1248 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
1249 dim, strideV);
1252 // NOTE: we can not generate tuples directly from descriptor here, as the
1253 // descriptor is holding the original type, yet we want the slice type
1254 // here (they shared every memref but with an updated specifier).
1255 rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
1256 desc.getFields()));
1257 return success();
1261 /// Sparse codegen rule for number of entries operator.
1262 class SparseNumberOfEntriesConverter
1263 : public OpConversionPattern<NumberOfEntriesOp> {
1264 public:
1265 using OpConversionPattern::OpConversionPattern;
1266 LogicalResult
1267 matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
1268 ConversionPatternRewriter &rewriter) const override {
1269 // Query memSizes for the actually stored values.
1270 // FIXME: the nse value computed in this way might be wrong when there is
1271 // any "loose_compressed" level.
1272 rewriter.replaceOp(
1273 op, genValMemSize(rewriter, op.getLoc(), adaptor.getTensor()));
1274 return success();
1278 struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
1279 using OpConversionPattern::OpConversionPattern;
1280 LogicalResult
1281 matchAndRewrite(AssembleOp op, OpAdaptor adaptor,
1282 ConversionPatternRewriter &rewriter) const override {
1283 Location loc = op.getLoc();
1284 const auto stt = getSparseTensorType(op.getResult());
1286 SmallVector<Value> fields;
1288 foreachFieldAndTypeInSparseTensor(
1289 stt,
1290 [&rewriter, &fields, &op, &stt,
1291 loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind,
1292 Level /*lvl*/, DimLevelType dlt) -> bool {
1293 assert(fields.size() == fIdx);
1294 if (fKind == SparseTensorFieldKind::StorageSpec) {
1295 fields.push_back(
1296 SparseTensorSpecifier::getInitValue(rewriter, loc, stt));
1297 } else {
1298 // Else simply takes the inputs.
1299 Value tensor = fKind == SparseTensorFieldKind::ValMemRef
1300 ? op.getValues()
1301 : op.getLevels()[fIdx];
1303 TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
1304 if (mem.getType().getRank() > 1) {
1305 // Flattens the buffer to rank 1.
1306 auto reassoc = getReassociationForFlattening(mem.getType());
1307 mem = rewriter.create<memref::CastOp>(
1308 loc, fType,
1309 rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
1310 } else {
1311 mem = rewriter.create<memref::CastOp>(loc, fType, mem);
1313 fields.push_back(mem);
1315 return true;
1318 MutSparseTensorDescriptor desc(stt, fields);
1319 Value c0 = constantIndex(rewriter, loc, 0);
1320 Value c1 = constantIndex(rewriter, loc, 1);
1321 Value c2 = constantIndex(rewriter, loc, 2);
1322 Value posBack = c0; // index to the last value in the position array
1323 Value memSize = c1; // memory size for current array
1325 Level trailCOOStart = getCOOStart(stt.getEncoding());
1326 Level trailCOORank = stt.getLvlRank() - trailCOOStart;
1327 // Sets up SparseTensorSpecifier.
1328 for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
1329 assert(!ShapedType::isDynamic(stt.getDimShape()[lvl]));
1331 // FIXME: dim/lvl confusion!
1332 // Sets up the level size.
1333 auto lvlSize = constantIndex(rewriter, loc, stt.getDimShape()[lvl]);
1334 desc.setLvlSize(rewriter, loc, lvl, lvlSize);
1335 // We use a single AOS array to store the trailing COO, so there is only
1336 // one memory size to set for the entire COO section.
1337 if (lvl > trailCOOStart)
1338 continue;
1340 // Sets up the memory size by reading the last value in position array.
1341 DimLevelType dlt = stt.getLvlType(lvl);
1342 // Simply forwards the position index when this is a dense level.
1343 if (isDenseDLT(dlt)) {
1344 memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
1345 posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
1346 continue;
1349 if (isDLTWithPos(dlt)) {
1350 assert(isCompressedDLT(dlt) || isLooseCompressedDLT(dlt));
1351 if (isLooseCompressedDLT(dlt)) {
1352 memSize = rewriter.create<arith::MulIOp>(loc, memSize, c2);
1353 posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
1354 } else {
1355 assert(isCompressedDLT(dlt));
1356 posBack = memSize;
1357 memSize = rewriter.create<arith::AddIOp>(loc, memSize, c1);
1359 desc.setPosMemSize(rewriter, loc, lvl, memSize);
1360 // The last value in position array is the memory size for next level.
1361 memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
1362 posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
1364 assert(isDLTWithCrd(dlt) && lvl <= trailCOOStart);
1365 // FIXME: This seems to be unnecessarily complex, can we simplify it?
1366 if (lvl == trailCOOStart) {
1367 Value cooSz = rewriter.create<arith::MulIOp>(
1368 loc, memSize, constantIndex(rewriter, loc, trailCOORank));
1369 desc.setCrdMemSize(rewriter, loc, lvl, cooSz);
1370 } else {
1371 desc.setCrdMemSize(rewriter, loc, lvl, memSize);
1374 desc.setValMemSize(rewriter, loc, memSize);
1376 rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
1377 return success();
1381 struct SparseDisassembleOpConverter
1382 : public OpConversionPattern<DisassembleOp> {
1383 using OpConversionPattern::OpConversionPattern;
1384 SparseDisassembleOpConverter(TypeConverter &typeConverter,
1385 MLIRContext *context)
1386 : OpConversionPattern(typeConverter, context) {}
1388 LogicalResult
1389 matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
1390 ConversionPatternRewriter &rewriter) const override {
1391 auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1392 Location loc = op.getLoc();
1393 SmallVector<Value> retMem;
1394 SmallVector<Value> retLen;
1395 desc.getLayout().foreachField([desc, loc, &rewriter, &op, &retMem, &retLen](
1396 FieldIndex fid,
1397 SparseTensorFieldKind fKind, Level lvl,
1398 DimLevelType dlt) -> bool {
1399 if (fKind == SparseTensorFieldKind::StorageSpec)
1400 return true;
1401 SparseTensorType stt(desc.getRankedTensorType());
1402 Value sz, src;
1403 TypedValue<BaseMemRefType> dst;
1404 if (fKind == SparseTensorFieldKind::ValMemRef) {
1405 sz = desc.getValMemSize(rewriter, loc);
1406 src = desc.getValMemRef();
1407 dst = genToMemref(rewriter, loc, op.getOutValues());
1408 // Values is the last field in descriptor, but it is the first
1409 // operand in unpack operation.
1410 // TODO: maybe change unpack/pack operation instead to be
1411 // consistent.
1412 retMem.insert(retMem.begin(), dst);
1413 Type valLenTp = op.getValLen().getType();
1414 retLen.insert(retLen.begin(),
1415 genScalarToTensor(rewriter, loc, sz, valLenTp));
1416 } else {
1417 assert(fKind == SparseTensorFieldKind::PosMemRef ||
1418 fKind == SparseTensorFieldKind::CrdMemRef);
1420 sz = fKind == SparseTensorFieldKind::PosMemRef
1421 ? desc.getPosMemSize(rewriter, loc, lvl)
1422 : desc.getCrdMemSize(rewriter, loc, lvl);
1423 src = desc.getMemRefField(fid);
1424 dst = genToMemref(rewriter, loc, op.getOutLevels()[fid]);
1425 retMem.push_back(dst);
1426 // Retrieves the corresponding level length type.
1427 Type lvlLenTp = op.getLvlLens().getTypes()[retLen.size()];
1428 retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
1430 Value flatOut = dst;
1431 if (dst.getType().getRank() != 1) {
1432 auto reassoc = getReassociationForFlattening(dst.getType());
1433 flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
1435 Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
1436 Value srcMem = genSliceToSize(rewriter, loc, src, sz);
1437 rewriter.create<memref::CopyOp>(loc, srcMem, dstMem);
1438 return true;
1441 // Converts MemRefs back to Tensors.
1442 SmallVector<Value> retValues = llvm::to_vector(
1443 llvm::map_range(retMem, [&rewriter, loc](Value v) -> Value {
1444 return rewriter.create<bufferization::ToTensorOp>(loc, v);
1445 }));
1446 // Appends the actual memory length used in each buffer returned.
1447 retValues.append(retLen.begin(), retLen.end());
1448 rewriter.replaceOp(op, retValues);
1449 return success();
1453 struct SparseNewConverter : public OpConversionPattern<NewOp> {
1454 using OpConversionPattern::OpConversionPattern;
1455 LogicalResult
1456 matchAndRewrite(NewOp op, OpAdaptor adaptor,
1457 ConversionPatternRewriter &rewriter) const override {
1458 Location loc = op.getLoc();
1459 const auto dstTp = getSparseTensorType(op.getResult());
1460 // Creating COO with NewOp is handled by direct IR codegen. All other cases
1461 // are handled by rewriting.
1462 if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
1463 return failure();
1465 // Implement as follows:
1466 // %reader = @createCheckedSparseTensorReader(%filename)
1467 // %nse = @getSparseTensorNSE(%reader)
1468 // %coo = bufferization.alloc_tensor an ordered COO with
1469 // dst dim ordering, size_hint = %nse
1470 // %coordinates = sparse_tensor.coordinates_buffer(%coo)
1471 // %values = sparse_tensor.values(%coo)
1472 // %isSorted = @sparseTensorReaderReadToBuffers(%coordinates, %values)
1473 // if (! %isSorted) sparse_tensor.sort_coo(%nse, %coordinates, %values)
1474 // update storage specifier
1475 // @delSparseTensorReader(%reader)
1476 SmallVector<Value> dimShapesValues;
1477 Value dimSizesBuffer;
1478 Value reader = genReader(rewriter, loc, dstTp, adaptor.getOperands()[0],
1479 dimShapesValues, dimSizesBuffer);
1481 // Get the number of stored entries.
1482 const Type indexTp = rewriter.getIndexType();
1483 Value nse = createFuncCall(rewriter, loc, "getSparseTensorReaderNSE",
1484 {indexTp}, {reader}, EmitCInterface::Off)
1485 .getResult(0);
1487 // Construct allocation for each field.
1488 SmallVector<Value> dynSizes;
1489 if (dstTp.hasDynamicDimShape()) {
1490 for (const auto &d : llvm::enumerate(dstTp.getDimShape()))
1491 if (ShapedType::isDynamic(d.value()))
1492 dynSizes.push_back(rewriter.create<memref::LoadOp>(
1493 loc, dimSizesBuffer, constantIndex(rewriter, loc, d.index())));
1495 SmallVector<Value> fields;
1496 createAllocFields(rewriter, loc, dstTp, dynSizes, /*enableInit=*/false,
1497 fields, nse);
1498 MutSparseTensorDescriptor desc(dstTp, fields);
1500 // Now construct the dim2lvl and lvl2dim buffers.
1501 Value dim2lvlBuffer;
1502 Value lvl2dimBuffer;
1503 genMapBuffers(rewriter, loc, dstTp, dimShapesValues, dimSizesBuffer,
1504 dim2lvlBuffer, lvl2dimBuffer);
1506 // Read the COO tensor data.
1507 Value xs = desc.getAOSMemRef();
1508 Value ys = desc.getValMemRef();
1509 const Type boolTp = rewriter.getIntegerType(1);
1510 const Type elemTp = dstTp.getElementType();
1511 const Type crdTp = dstTp.getCrdType();
1512 SmallString<32> readToBuffersFuncName{"getSparseTensorReaderReadToBuffers",
1513 overheadTypeFunctionSuffix(crdTp),
1514 primaryTypeFunctionSuffix(elemTp)};
1515 Value isSorted =
1516 createFuncCall(rewriter, loc, readToBuffersFuncName, {boolTp},
1517 {reader, dim2lvlBuffer, lvl2dimBuffer, xs, ys},
1518 EmitCInterface::On)
1519 .getResult(0);
1521 // If the destination tensor is a sorted COO, we need to sort the COO tensor
1522 // data if the input elements aren't sorted yet.
1523 const Level lvlRank = dstTp.getLvlRank();
1524 if (dstTp.isOrderedLvl(lvlRank - 1)) {
1525 Value kFalse = constantI1(rewriter, loc, false);
1526 Value notSorted = rewriter.create<arith::CmpIOp>(
1527 loc, arith::CmpIPredicate::eq, isSorted, kFalse);
1528 scf::IfOp ifOp =
1529 rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
1530 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
1531 auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
1532 rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
1533 rewriter.getIndexAttr(0),
1534 SparseTensorSortKind::HybridQuickSort);
1535 rewriter.setInsertionPointAfter(ifOp);
1538 // Set PosMemRef0[1] = nse.
1539 const Value c1 = constantIndex(rewriter, loc, 1);
1540 const Value posMemref0 = desc.getPosMemRef(0);
1541 const Type posTp = dstTp.getPosType();
1542 const Value posNse = genCast(rewriter, loc, nse, posTp);
1543 rewriter.create<memref::StoreOp>(loc, posNse, posMemref0, c1);
1545 // Update storage specifier.
1546 Value coordinatesSize = rewriter.create<arith::MulIOp>(
1547 loc, nse, constantIndex(rewriter, loc, lvlRank));
1548 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::CrdMemSize, 0,
1549 coordinatesSize);
1550 desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::ValMemSize,
1551 std::nullopt, nse);
1553 // Release the sparse tensor reader.
1554 createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
1555 EmitCInterface::Off);
1557 // Replace operation with resulting memrefs.
1558 rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
1559 return success();
1563 } // namespace
1565 //===----------------------------------------------------------------------===//
1566 // Public method for populating conversion rules.
1567 //===----------------------------------------------------------------------===//
1569 /// Populates the given patterns list with conversion rules required for
1570 /// the sparsification of linear algebra operations.
1571 void mlir::populateSparseTensorCodegenPatterns(
1572 TypeConverter &typeConverter, RewritePatternSet &patterns,
1573 bool createSparseDeallocs, bool enableBufferInitialization) {
1574 patterns.add<SparseAssembleOpConverter, SparseDisassembleOpConverter,
1575 SparseReturnConverter, SparseCallConverter, SparseLvlOpConverter,
1576 SparseCastConverter, SparseExtractSliceConverter,
1577 SparseTensorLoadConverter, SparseExpandConverter,
1578 SparseCompressConverter, SparseInsertConverter,
1579 SparseReorderCOOConverter, SparseReMapConverter,
1580 SparseSliceGetterOpConverter<ToSliceOffsetOp,
1581 StorageSpecifierKind::DimOffset>,
1582 SparseSliceGetterOpConverter<ToSliceStrideOp,
1583 StorageSpecifierKind::DimStride>,
1584 SparseToPositionsConverter, SparseToCoordinatesConverter,
1585 SparseToCoordinatesBufferConverter, SparseToValuesConverter,
1586 SparseConvertConverter, SparseNewConverter,
1587 SparseNumberOfEntriesConverter>(typeConverter,
1588 patterns.getContext());
1589 patterns.add<SparseTensorDeallocConverter>(
1590 typeConverter, patterns.getContext(), createSparseDeallocs);
1591 patterns.add<SparseTensorAllocConverter, SparseTensorEmptyConverter>(
1592 typeConverter, patterns.getContext(), enableBufferInitialization);