1 //===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewriting rules that are specific to sparse tensors.
11 //===----------------------------------------------------------------------===//
13 #include "Utils/CodegenUtils.h"
14 #include "Utils/LoopEmitter.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
24 #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
25 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
26 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/Dialect/Vector/IR/VectorOps.h"
29 #include "mlir/IR/AffineMap.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/Support/LLVM.h"
33 using namespace mlir;
34 using namespace mlir::bufferization;
35 using namespace mlir::linalg;
36 using namespace mlir::sparse_tensor;
38 //===---------------------------------------------------------------------===//
39 // Helper methods for the actual rewriting rules.
40 //===---------------------------------------------------------------------===//
42 // Helper method to match any typed zero.
43 static bool isZeroValue(Value val) {
44 return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
47 // Helper to detect a sparse tensor type operand.
48 static bool isSparseTensor(Value v) {
49 auto enc = getSparseTensorEncoding(v.getType());
50 return enc && !llvm::all_of(enc.getLvlTypes(),
51 [](auto lt) { return lt == LevelFormat::Dense; });
53 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
55 // Helper method to find zero/uninitialized tensor materialization.
56 static bool isMaterializing(OpOperand *op, bool isZero) {
57 Value val = op->get();
58 // Check allocation, with zero alloc when required.
59 if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
60 Value copy = alloc.getCopy();
61 if (isZero)
62 return copy && isZeroValue(copy);
63 return !copy;
65 // Check for empty tensor materialization.
66 if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
67 return !isZero;
68 // Last resort for zero alloc: the whole value is zero.
69 return isZero && isZeroValue(val);
72 // Helper to detect sampling operation.
73 static bool isSampling(GenericOp op) {
74 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
75 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
76 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
77 // Both scalar input arguments used exactly once.
78 Value s1 = op.getBlock()->getArgument(0);
79 Value s2 = op.getBlock()->getArgument(1);
80 return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
81 (def->getOperand(1) == s1 && def->getOperand(0) == s2);
84 return false;
87 // Helper to detect chain of multiplications that do not involve x.
88 static bool isMulChain(Value val, Value x) {
89 if (auto arg = dyn_cast<BlockArgument>(val))
90 return arg != x;
91 if (auto *def = val.getDefiningOp()) {
92 if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
93 return isMulChain(def->getOperand(0), x) &&
94 isMulChain(def->getOperand(1), x);
96 return false;
99 // Helper to detect x = x + <multiplications>.
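// For example, a body yielding %x + (%a * %b) matches, whereas %x + (%x * %a)
// does not, since the multiplication chain may not involve the accumulator %x
// itself (see isMulChain above).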
100 static bool isSumOfMul(GenericOp op) {
101 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
102 if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
103 if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
104 Value x = op.getBlock()->getArguments().back();
105 return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
106 (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
109 return false;
112 // Helper to detect direct yield of a zero value.
113 static bool isZeroYield(GenericOp op) {
114 auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
115 if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
116 if (arg.getOwner()->getParentOp() == op) {
117 return isZeroValue(op->getOperand(arg.getArgNumber()));
120 return isZeroValue(yieldOp.getOperand(0));
123 /// Populates given sizes array from type (for static sizes) and from
124 /// the tensor (for dynamic sizes).
125 static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
126 Location loc, ShapedType stp, Value tensor) {
127 for (const auto &d : enumerate(stp.getShape())) {
128 Value dim;
129 if (d.value() == ShapedType::kDynamic)
130 dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
131 else
132 dim = constantIndex(builder, loc, d.value());
133 sizes.push_back(dim);
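// Returns the type of an intermediate buffer: an unordered COO variant of the
// given sparse tensor type when a temporary COO is needed, otherwise the
// tensor's own ranked tensor type.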
137 static RankedTensorType getBufferType(const SparseTensorType &stt,
138 bool needTmpCOO) {
139 return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
140 : stt.getRankedTensorType();
143 /// Collects the dynamic dimension sizes for `tp` with the assumption that
144 /// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
145 /// sizes to dynSizes.
146 static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
147 SmallVectorImpl<Value> &dynSizes) {
148 for (const auto &d : enumerate(tp.getShape())) {
149 if (d.value() == ShapedType::kDynamic)
150 dynSizes.push_back(sizes[d.index()]);
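// Rewrites a ForeachOp over a sparse constant by unrolling the loop body once
// per stored element of the SparseElementsAttr, feeding the element's
// coordinates and value (plus the running reduction values) into each clone.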
154 static LogicalResult genForeachOnSparseConstant(ForeachOp op,
155 RewriterBase &rewriter,
156 SparseElementsAttr attr) {
157 auto loc = op.getLoc();
158 SmallVector<Value> reduc = op.getInitArgs();
160 // Foreach on constant.
161 foreachInSparseConstant(
162 rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
163 [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
164 SmallVector<Value> args;
165 args.append(cvs.begin(), cvs.end());
166 args.push_back(v);
167 args.append(reduc);
168 // Clones the foreach op to get a copy of the loop body.
169 auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
170 assert(args.size() == cloned.getBody()->getNumArguments());
171 Operation *yield = cloned.getBody()->getTerminator();
172 rewriter.inlineBlockBefore(cloned.getBody(), op, args);
173 // clean up
174 rewriter.eraseOp(cloned);
175 reduc = yield->getOperands();
176 rewriter.eraseOp(yield);
179 rewriter.replaceOp(op, reduc);
180 return success();
183 /// Populates the given sizes array for concatenation from types (for static
184 /// sizes) and from the source tensors (for dynamic sizes).
185 static void concatSizesFromInputs(OpBuilder &builder,
186 SmallVectorImpl<Value> &sizes, Location loc,
187 ShapedType dstTp, ValueRange srcs,
188 unsigned dim) {
189 auto dstShape = dstTp.getShape();
190 sizesFromSrc(builder, sizes, loc, srcs[0]);
192 // Sum up on the `dim` if the dimension is dynamic.
193 if (dstShape[dim] != ShapedType::kDynamic) {
194 // Faithfully take the static size.
195 sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
196 } else {
197 // Else, compute the shape dynamically.
198 for (const auto &src : srcs.drop_front()) {
199 Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
200 // Sum up all the sizes.
201 sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
206 //===---------------------------------------------------------------------===//
207 // The actual sparse tensor rewriting rules.
208 //===---------------------------------------------------------------------===//
210 namespace {
212 /// TODO: move it to tensor dialect instead.
214 /// Fold `tensor.concat` and `tensor.extract_slice`
216 /// %concat = tensor.concat dim(2) %t0, %t1
217 /// : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
218 /// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
219 /// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
220 /// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
221 /// : tensor<1x64x2xf32> to tensor<1x64x1xf32>
223 /// Becomes
225 /// %extract0, %extract1 = %t0, %t1
226 struct FuseExtractSliceWithConcat
227 : public OpRewritePattern<tensor::ExtractSliceOp> {
228 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
230 LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
231 PatternRewriter &rewriter) const override {
232 auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
233 if (!concatOp)
234 return failure();
236 Location loc = extractOp.getLoc();
237 int64_t dim = concatOp.getDim();
238 int64_t rank = extractOp.getResultType().getRank();
240 SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
241 SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));
243 // Compute the partial sums for the slice offsets.
244 AffineExpr sum = rewriter.getAffineDimExpr(0);
245 SmallVector<AffineExpr> partialSums = {sum};
246 SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
247 for (auto [idx, input] :
248 llvm::enumerate(concatOp.getInputs().drop_back())) {
249 sum = sum + rewriter.getAffineDimExpr(idx + 1);
250 partialSums.push_back(sum);
251 offsetStrides.push_back(
252 rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
254 auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
255 partialSums, rewriter.getContext());
256 SmallVector<OpFoldResult> dimOffsets =
257 affine::makeComposedFoldedMultiResultAffineApply(
258 rewriter, loc, partialSumMap, offsetStrides);
260 auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
261 for (auto [l, r] : llvm::zip(lhs, rhs)) {
262 std::optional<int64_t> staticVal = getConstantIntValue(l);
263 if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
264 return false;
266 return lhs.size() == rhs.size();
269 for (auto [i, input, offset] :
270 llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
271 SmallVector<OpFoldResult> srcSizes =
272 tensor::getMixedSizes(rewriter, loc, input);
273 srcOffsets[dim] = offset;
275 SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
276 SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
277 SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();
279 if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
280 allEqual(srcStrides, dstStrides)) {
281 Value operand = concatOp.getOperand(i);
282 if (operand.getType() == extractOp.getResultType())
283 rewriter.replaceOp(extractOp, operand);
284 break;
288 return success();
292 /// Rewriting rule that fuses sparse_tensor.convert into producer.
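/// For instance (an illustrative sketch only; maps and attributes omitted,
/// #CSR stands for any sparse encoding):
///
///   %0 = tensor.empty() : tensor<4x4xf64>
///   %1 = linalg.generic ... outs(%0 : tensor<4x4xf64>) -> tensor<4x4xf64>
///   %2 = sparse_tensor.convert %1 : tensor<4x4xf64> to tensor<4x4xf64, #CSR>
///
/// is rewritten so that the materialization and the linalg.generic already
/// carry the sparse result type, making the conversion redundant.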
293 struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
294 public:
295 using OpRewritePattern::OpRewritePattern;
297 LogicalResult matchAndRewrite(ConvertOp op,
298 PatternRewriter &rewriter) const override {
299 auto producer = op.getSource().getDefiningOp<GenericOp>();
300 if (!producer || producer.getDpsInits().size() != 1 ||
301 !isMaterializing(producer.getDpsInitOperand(0), false) ||
302 !producer.getResult(0).hasOneUse()) {
303 return failure();
305 // Clone the materialization operation, but update the result to sparse.
306 rewriter.setInsertionPoint(producer);
307 Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
308 Operation *cloned = rewriter.clone(*init);
309 cloned->getResult(0).setType(op.getResult().getType());
311 rewriter.modifyOpInPlace(producer, [&]() {
312 producer.getDpsInitsMutable().assign(cloned->getResults());
313 producer.getResult(0).setType(op.getResult().getType());
316 rewriter.replaceAllOpUsesWith(op, producer);
317 op->erase();
319 return success();
323 /// Rewriting rule that folds a direct yield of zero into the initial allocation (or a constant zero).
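/// For instance (an illustrative sketch only; maps and attributes omitted):
///
///   %0 = bufferization.alloc_tensor() : tensor<8x8xf64, #CSR>
///   %1 = linalg.generic ... outs(%0 : tensor<8x8xf64, #CSR>) {
///          ^bb0(...): linalg.yield %cst_zero
///        } -> tensor<8x8xf64, #CSR>
///
/// folds to %0 itself, since yielding zero everywhere on a freshly
/// materialized sparse tensor leaves it empty; for static-shaped dense
/// outputs the op folds to a constant zero tensor instead.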
324 struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
325 public:
326 using OpRewritePattern<GenericOp>::OpRewritePattern;
328 LogicalResult matchAndRewrite(GenericOp op,
329 PatternRewriter &rewriter) const override {
330 if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
331 !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
332 !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
333 return failure();
334 auto outputType = getRankedTensorType(op.getResult(0));
335 // Yielding zero on newly materialized sparse tensor can be
336 // optimized directly (regardless of dynamic or static size).
337 if (getSparseTensorEncoding(outputType)) {
338 rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
339 return success();
341 // Use static zero value directly instead of materialization.
342 if (!outputType.hasStaticShape())
343 return failure();
344 Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
345 rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
346 rewriter.eraseOp(def);
347 return success();
351 /// Rewriting rule that converts two kernels:
353 /// T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
354 /// X(i,j) = S(i,j) * T(i,j)
356 /// into a single kernel, using distributive law:
358 /// X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
360 /// This kind of fusion (merging two ops into one but using arithmetic
361 /// equalities that may not hold for floating-point computations) would
362 /// be undesirable in the dense case, since we distribute the multiplication
363 /// into the reduction loop. However, for sparse sampling tensor S, such
364 /// a fusion may actually reduce the asymptotic complexity of the kernel,
365 /// since intermediate results may be nullified.
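/// A canonical instance is sampled dense-dense matrix multiplication (SDDMM):
/// when the sampling tensor S is sparse, distributing S(i,j) into the
/// k-reduction means the reduction is only evaluated at the stored entries
/// of S.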
366 struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
367 public:
368 using OpRewritePattern<GenericOp>::OpRewritePattern;
370 LogicalResult matchAndRewrite(GenericOp op,
371 PatternRewriter &rewriter) const override {
372 // Check consumer.
373 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
374 op.getNumResults() != 1 ||
375 op.getNumParallelLoops() != op.getNumLoops() ||
376 !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
377 !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
378 !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
379 return failure();
380 // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
381 // operand can be sparse or dense, since the point of this rewriting rule
382 // is detecting a situation in which *more* sparsity is introduced into
383 // a computation, be it already sparse or still dense.
384 unsigned other = 0;
385 if (isSparseTensor(op.getDpsInputOperand(0)))
386 other = 1;
387 else if (!isSparseTensor(op.getDpsInputOperand(1)))
388 return failure();
389 // Check producer.
390 auto prod = dyn_cast_or_null<GenericOp>(
391 op.getDpsInputOperand(other)->get().getDefiningOp());
392 if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
393 !prod.getResult(0).hasOneUse())
394 return failure();
395 // Sampling consumer and sum of multiplication chain producer.
396 if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
397 !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
398 !isSampling(op) || !isSumOfMul(prod))
399 return failure();
400 // Modify operand structure of producer and consumer.
401 Location loc = prod.getLoc();
402 SmallVector<Value> inputOps = prod.getInputs();
403 SmallVector<Value> outputOps = op.getOutputs();
404 SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
405 inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
406 fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
407 // Fuse producer and consumer into a new generic op.
408 auto fusedOp = rewriter.create<GenericOp>(
409 loc, op.getResult(0).getType(), inputOps, outputOps,
410 rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
411 /*doc=*/nullptr, /*library_call=*/nullptr);
412 Block &prodBlock = prod.getRegion().front();
413 Block &consBlock = op.getRegion().front();
414 IRMapping mapper;
415 Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
416 unsigned num = prodBlock.getNumArguments();
417 for (unsigned i = 0; i < num - 1; i++)
418 addArg(mapper, fusedBlock, prodBlock.getArgument(i));
419 addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
420 addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
421 // Clone bodies of the producer and consumer in new evaluation order.
422 auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
423 auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
424 Value last;
425 for (auto &op : prodBlock.without_terminator())
426 if (&op != acc) {
427 last = op.getResult(0);
428 rewriter.clone(op, mapper);
430 mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
431 mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
432 last = rewriter.clone(*acc, mapper)->getResult(0);
433 rewriter.create<linalg::YieldOp>(loc, last);
434 // Force initial value on merged allocation for dense outputs.
435 // TODO: deal with non-alloc tensors here one day
436 if (!getSparseTensorEncoding(op.getResult(0).getType())) {
437 Value init = prod.getDpsInitOperand(0)
438 ->get()
439 .getDefiningOp<AllocTensorOp>()
440 .getCopy();
441 AllocTensorOp a =
442 op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
443 rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
445 // Replace consumer with fused operation. Old producer
446 // and consumer ops will be removed by DCE.
447 rewriter.replaceOp(op, fusedOp->getResults());
448 return success();
451 private:
452 // Helper to add argument and record the mapping.
453 static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
454 mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
458 // Fuse a tensor cast into its producing operation. Note that a tensor.cast
459 // should really not be used to convert between sparse encodings. Since
460 // the pattern currently appears as a result of some prior rewriting,
461 // we make an attempt to repair very obvious cases.
462 // TODO: audit the pure tensor dialect rewriting rules
463 struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
464 public:
465 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
467 LogicalResult matchAndRewrite(tensor::CastOp op,
468 PatternRewriter &rewriter) const override {
469 Type srcType = op.getSource().getType();
470 Type dstType = op.getDest().getType();
471 // A nop cast simply folds away.
472 if (srcType == dstType) {
473 rewriter.replaceOp(op, op->getResults());
474 return success();
476 // See if a sparsity changing cast can be fused into producer.
477 if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
478 if (Operation *def = op.getSource().getDefiningOp()) {
479 if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
480 rewriter.modifyOpInPlace(def, [&]() {
481 def->getResult(0).setType(op->getResultTypes()[0]);
483 rewriter.replaceOp(op, def->getResult(0));
484 return success();
488 // Repair tensor casts with at least one sparse operand into the
489 // properly supported sparse_tensor.convert.
490 if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
491 rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
492 return success();
494 // Fail otherwise.
495 return failure();
499 /// Rewrites a sequence of operations for sparse tensor selections into
500 /// semi-ring operations such that they can be compiled correctly by the
501 /// sparsifier. E.g., transforming the following sequence
503 /// %sel = arith.select %cond, %sp1, %sp2
505 /// to
507 /// %sel = binary %sp1, %sp2:
508 /// both (%l, %r) {yield select %cond, %l, %r}
509 /// left (%l) {yield select %cond, %l, 0}
510 /// right (%r) {yield select %cond, 0, %r}
512 /// TODO: We require the tensor used for extracting conditions to be dense in
513 /// order to sparsify the code. To support a sparse condition tensor, we need
514 /// a ternary operation.
515 struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
516 public:
517 using OpRewritePattern<GenericOp>::OpRewritePattern;
518 LogicalResult matchAndRewrite(GenericOp op,
519 PatternRewriter &rewriter) const override {
520 // Rejects non-sparse kernels.
521 if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
522 return failure();
524 Location loc = op.getLoc();
525 SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
526 for (Operation &inst : *op.getBody()) {
527 // Matches pattern.
528 auto matched = isRewritablePattern(op, &inst);
529 if (!matched.has_value())
530 continue;
532 rewriter.setInsertionPoint(&inst);
533 auto [c, t, f] = matched.value();
534 assert(t.getType() == f.getType());
535 auto selTp = t.getType();
536 auto c0 = constantZero(rewriter, loc, selTp);
537 auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
538 // Initializes all the blocks.
539 rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
540 {t.getLoc(), f.getLoc()});
541 rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
542 rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());
544 for (auto *r : binOp.getRegions()) {
545 Block *b = &r->front();
546 rewriter.setInsertionPointToStart(b);
548 IRMapping irMap;
549 // Clones the cmp operations into the region to make the binary op
550 // admissible.
551 Value newC = c;
552 if (auto *def = c.getDefiningOp())
553 newC = rewriter.clone(*def, irMap)->getResult(0);
555 irMap.map(c, newC);
556 if (r == &binOp.getLeftRegion()) {
557 irMap.map(t, b->getArgument(0));
558 irMap.map(f, c0);
559 } else if (r == &binOp.getRightRegion()) {
560 irMap.map(t, c0);
561 irMap.map(f, b->getArgument(0));
562 } else {
563 irMap.map(t, b->getArgument(0));
564 irMap.map(f, b->getArgument(1));
566 auto y = rewriter.clone(inst, irMap)->getResult(0);
567 rewriter.create<sparse_tensor::YieldOp>(loc, y);
570 // We successfully rewrote an operation. We cannot do the replacement here
571 // because it would invalidate the iterator used by the current loop to
572 // traverse the instructions.
573 semiRings.emplace_back(&inst, binOp);
576 // Finalizes the replacement.
577 for (auto [sel, semi] : semiRings)
578 rewriter.replaceOp(sel, semi->getResults());
580 return success(!semiRings.empty());
583 private:
584 static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
585 isRewritablePattern(GenericOp op, Operation *v) {
586 auto sel = dyn_cast<arith::SelectOp>(v);
587 if (!sel)
588 return std::nullopt;
590 auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
591 auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
592 // TODO: For simplicity, we only handle cases where both the true/false values
593 // are directly loaded from the input tensor. We can probably admit more cases
594 // in theory.
595 if (!tVal || !fVal)
596 return std::nullopt;
598 // Helper lambda to determine whether the value is loaded from a dense input
599 // or is a loop invariant.
600 auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
601 if (auto bArg = dyn_cast<BlockArgument>(v);
602 bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
603 return true;
604 // If the value is defined outside the loop, it is a loop invariant.
605 return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
608 // If the condition value is loaded directly from a dense tensor or is a
609 // loop invariant, we can sparsify the kernel.
610 auto cond = sel.getCondition();
611 if (isValFromDenseInputOrInvariant(cond))
612 return std::make_tuple(cond, tVal, fVal);
614 Value cmpL, cmpR;
615 if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
616 matchers::m_Any(&cmpR))) ||
617 matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
618 matchers::m_Any(&cmpR)))) {
619 // TODO: we can do it recursively to check whether all the leaf values are
620 // loaded from dense tensors or are loop invariants.
621 if (isValFromDenseInputOrInvariant(cmpL) ||
622 isValFromDenseInputOrInvariant(cmpR))
623 return std::make_tuple(cond, tVal, fVal);
626 return std::nullopt;
630 /// Rewrites a sparse reduction that would not sparsify directly since
631 /// doing so would only iterate over the stored elements, ignoring the
632 /// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
633 /// (note that reductions like add/sub/or/xor can directly be sparsified
634 /// since the implicit zeros do not contribute to the final result).
635 /// Note that prod/and are still included since, even though they often
636 /// are nullified in sparse data, they may still occur for special
637 /// situations in which e.g. some rows in a sparse matrix are fully
638 /// dense. For min/max, including the implicit zeros is a much more
639 /// common situation.
641 /// TODO: this essentially "densifies" the operation; we want to implement
642 /// this much more efficiently by performing the reduction over the
643 /// stored values, and feed in the zero once if there were *any*
644 /// implicit zeros as well; but for now, at least we provide
645 /// the functionality
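///
/// Conceptually (an illustrative sketch only; exact syntax abbreviated), a
/// min-reduction body
///
///   %m = arith.minimumf %in, %acc : f64
///   linalg.yield %m
///
/// becomes
///
///   %u = sparse_tensor.unary %in : f64 to f64
///          present = { ^bb0(%v: f64): sparse_tensor.yield %v }
///          absent  = { sparse_tensor.yield %zero }
///   %r = sparse_tensor.reduce %u, %acc, %identity : f64 {
///          ^bb0(%a: f64, %b: f64):
///            %m = arith.minimumf %a, %b : f64
///            sparse_tensor.yield %m
///        }
///
/// where %identity is the value extracted from the initial (reduction) tensor.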
647 struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
648 public:
649 using OpRewritePattern<GenericOp>::OpRewritePattern;
651 LogicalResult matchAndRewrite(GenericOp op,
652 PatternRewriter &rewriter) const override {
653 // Reject non-reductions.
654 if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
655 op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
656 return failure();
657 auto *inp = op.getDpsInputOperand(0);
658 auto *init = op.getDpsInitOperand(0);
659 if (!isSparseTensor(inp))
660 return failure();
661 // Look for direct x = x OP y for semi-ring ready reductions.
662 auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
663 .getOperand(0)
664 .getDefiningOp();
665 if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
666 arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
667 arith::MaxUIOp>(red))
668 return failure();
669 Value s0 = op.getBlock()->getArgument(0);
670 Value s1 = op.getBlock()->getArgument(1);
671 if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
672 (red->getOperand(0) != s1 || red->getOperand(1) != s0))
673 return failure();
674 // Identity.
675 Location loc = op.getLoc();
676 Value identity =
677 rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
678 // Unary {
679 // present -> value
680 // absent -> zero.
681 // }
682 Type rtp = s0.getType();
683 rewriter.setInsertionPointToStart(&op.getRegion().front());
684 auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
685 Block *present =
686 rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
687 rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
688 rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
689 rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
690 rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
691 auto zero =
692 rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
693 rewriter.create<sparse_tensor::YieldOp>(loc, zero);
694 rewriter.setInsertionPointAfter(semiring);
695 // CustomReduce {
696 // x = x REDUC y, identity
697 // }
698 auto custom = rewriter.create<sparse_tensor::ReduceOp>(
699 loc, rtp, semiring.getResult(), s1, identity);
700 Block *region =
701 rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
702 rewriter.setInsertionPointToStart(&custom.getRegion().front());
703 IRMapping irMap;
704 irMap.map(red->getOperand(0), region->getArgument(0));
705 irMap.map(red->getOperand(1), region->getArgument(1));
706 auto *cloned = rewriter.clone(*red, irMap);
707 rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
708 rewriter.setInsertionPointAfter(custom);
709 rewriter.replaceOp(red, custom.getResult());
710 return success();
714 /// Sparse rewriting rule for the print operator. This operation is mainly used
715 /// for debugging and testing. As such, it lowers to the vector.print operation
716 /// which only requires very lightweight runtime support.
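///
/// For a small CSR matrix, the generated code prints output of roughly this
/// shape (values shown are illustrative only):
///
///   ---- Sparse Tensor ----
///   nse = 3
///   dim = ( 4, 4 )
///   lvl = ( 4, 4 )
///   pos[1] : ( 0, 1, 2, 3, 3 )
///   crd[1] : ( 0, 1, 2 )
///   values : ( 1, 2, 3 )
///   ----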
717 struct PrintRewriter : public OpRewritePattern<PrintOp> {
718 public:
719 using OpRewritePattern::OpRewritePattern;
720 LogicalResult matchAndRewrite(PrintOp op,
721 PatternRewriter &rewriter) const override {
722 Location loc = op.getLoc();
723 auto tensor = op.getTensor();
724 auto stt = getSparseTensorType(tensor);
725 // Header with NSE.
726 auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
727 rewriter.create<vector::PrintOp>(
728 loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
729 rewriter.create<vector::PrintOp>(loc, nse);
730 // Print run-time contents for dim/lvl sizes.
731 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
732 printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
733 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
734 printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
735 // Use the "codegen" foreach loop construct to iterate over
736 // all typical sparse tensor components for printing.
737 foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
738 &stt](Type, FieldIndex,
739 SparseTensorFieldKind kind,
740 Level l, LevelType) {
741 switch (kind) {
742 case SparseTensorFieldKind::StorageSpec: {
743 break;
745 case SparseTensorFieldKind::PosMemRef: {
746 auto lvl = constantIndex(rewriter, loc, l);
747 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
748 rewriter.create<vector::PrintOp>(
749 loc, lvl, vector::PrintPunctuation::NoPunctuation);
750 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
751 auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
752 printContents(rewriter, loc, pos);
753 break;
755 case SparseTensorFieldKind::CrdMemRef: {
756 auto lvl = constantIndex(rewriter, loc, l);
757 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
758 rewriter.create<vector::PrintOp>(
759 loc, lvl, vector::PrintPunctuation::NoPunctuation);
760 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
761 Value crd = nullptr;
762 // For COO AoS storage, we want to print a single, linear view of
763 // the full coordinate storage at this level. For any other storage,
764 // we show the coordinate storage for every individual level.
765 if (stt.getAoSCOOStart() == l)
766 crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
767 else
768 crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
769 printContents(rewriter, loc, crd);
770 break;
772 case SparseTensorFieldKind::ValMemRef: {
773 rewriter.create<vector::PrintOp>(loc,
774 rewriter.getStringAttr("values : "));
775 auto val = rewriter.create<ToValuesOp>(loc, tensor);
776 printContents(rewriter, loc, val);
777 break;
780 return true;
782 rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
783 rewriter.eraseOp(op);
784 return success();
787 private:
788 // Helper to print contents of a single memref. For "push_back" vectors,
789 // we assume that the previous getters for pos/crd/val have added a
790 // slice-to-size view to make sure we just print the size and not the
791 // full capacity.
793 // Generates code to print (1-dim or higher):
794 // ( a0, a1, ... )
795 static void printContents(PatternRewriter &rewriter, Location loc,
796 Value vec) {
797 auto shape = cast<ShapedType>(vec.getType()).getShape();
798 SmallVector<Value> idxs;
799 printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
800 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
803 // Helper to the helper.
804 static void printContentsLevel(PatternRewriter &rewriter, Location loc,
805 Value vec, unsigned i, ArrayRef<int64_t> shape,
806 SmallVectorImpl<Value> &idxs) {
807 // Open bracket.
808 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
809 // Generate for loop.
810 auto zero = constantIndex(rewriter, loc, 0);
811 auto index = constantIndex(rewriter, loc, i);
812 auto size = rewriter.create<memref::DimOp>(loc, vec, index);
813 auto step = constantIndex(rewriter, loc, 1);
814 auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
815 idxs.push_back(forOp.getInductionVar());
816 rewriter.setInsertionPointToStart(forOp.getBody());
817 if (i < shape.size() - 1) {
818 // Enter deeper loop nest.
819 printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
820 } else {
821 // Actual contents printing.
822 auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
823 if (llvm::isa<ComplexType>(val.getType())) {
824 // Since the vector dialect does not support complex types in any op,
825 // we split those into (real, imag) pairs here.
826 Value real = rewriter.create<complex::ReOp>(loc, val);
827 Value imag = rewriter.create<complex::ImOp>(loc, val);
828 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
829 rewriter.create<vector::PrintOp>(loc, real,
830 vector::PrintPunctuation::Comma);
831 rewriter.create<vector::PrintOp>(loc, imag,
832 vector::PrintPunctuation::Close);
833 } else {
834 rewriter.create<vector::PrintOp>(
835 loc, val, vector::PrintPunctuation::NoPunctuation);
837 // Terminating comma (except at end).
838 auto bound = rewriter.create<arith::AddIOp>(loc, idxs.back(), step);
839 Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
840 bound, size);
841 scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
842 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
843 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
845 idxs.pop_back();
846 rewriter.setInsertionPointAfter(forOp);
847 // Close bracket.
848 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
851 // Helper method to print run-time lvl/dim sizes.
852 static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
853 unsigned size, bool isDim) {
854 // Open bracket.
855 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
856 // Print unrolled contents (dimop requires constant value).
857 for (unsigned i = 0; i < size; i++) {
858 auto idx = constantIndex(rewriter, loc, i);
859 Value val;
860 if (isDim)
861 val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
862 else
863 val = rewriter.create<LvlOp>(loc, tensor, idx);
864 rewriter.create<vector::PrintOp>(
865 loc, val,
866 i != size - 1 ? vector::PrintPunctuation::Comma
867 : vector::PrintPunctuation::NoPunctuation);
869 // Close bracket and end of line.
870 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
871 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
875 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
876 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
877 public:
878 using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
880 LogicalResult matchAndRewrite(tensor::ReshapeOp op,
881 PatternRewriter &rewriter) const override {
882 Location loc = op.getLoc();
883 Value srcTensor = op.getSource();
884 const auto srcTp = tryGetSparseTensorType(srcTensor);
885 const auto dstTp = tryGetSparseTensorType(op.getResult());
886 if (!srcTp || !dstTp)
887 return failure();
889 if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
890 !dstTp->hasStaticDimShape())
891 return failure();
893 SmallVector<Value> srcSizes;
894 sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
895 SmallVector<Value> dstSizes;
896 for (Dimension d : dstTp->getDimShape())
897 dstSizes.push_back(constantIndex(rewriter, loc, d));
899 Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
900 // Only need an unordered COO buffer if input and output are not sorted
901 // in the same way.
902 Type bufferTp = getBufferType(
903 dstTp->withoutDimToLvl(),
904 !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
905 SmallVector<Value> dynSizes;
906 Value buffer = rewriter
907 .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
908 nnz, Attribute())
909 .getResult();
911 // Convert src coordinates to dst coordinates by first collapsing them to 1D
912 // and then expanding them to match the rank of the destination tensor.
913 // Implemented as follows:
914 // foreach srcCoords %srcTensor
915 // collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
916 // expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
917 // insert expandedCoords, %buffer
919 // followed by an optional
920 // %t = sparse_tensor.cast %tmp
921 // depending on whether the input/output are sorted in the same way.
922 const auto encSrc = srcTp->getEncoding();
923 ForeachOp foreachOp = rewriter.create<ForeachOp>(
924 loc, srcTensor, buffer,
925 [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
926 ValueRange reduc) {
927 const Dimension srcRank = srcTp->getDimRank();
928 SmallVector<Value> srcDcvs;
929 srcDcvs.reserve(srcRank);
930 for (Dimension d = 0; d < srcRank; d++) {
931 Level lvl = toLvl(encSrc, d);
932 srcDcvs.push_back(srcLcvs[lvl]);
935 Value collapseSize = constantIndex(builder, loc, 1);
936 for (Dimension d = 0; d < srcRank; d++)
937 collapseSize =
938 builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
939 SmallVector<Value, 1> collapsedSizes = {collapseSize};
941 ReassociationIndices collapseIdx;
942 for (Dimension i = 0; i < srcRank; i++)
943 collapseIdx.push_back(i);
944 SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
945 SmallVector<Value, 1> collapsedDcvs;
946 reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
947 collapsedSizes, collapsedDcvs);
949 ReassociationIndices expandIdx;
950 for (Dimension i = 0; i < dstTp->getDimRank(); i++)
951 expandIdx.push_back(i);
952 SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
953 SmallVector<Value> dstDcvs;
954 reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
955 dstSizes, dstDcvs);
957 auto t =
958 builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
959 builder.create<sparse_tensor::YieldOp>(loc, t);
962 Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
963 if (bufferTp != *dstTp) {
964 auto dstRTT = dstTp->getRankedTensorType();
965 Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
966 rewriter.create<DeallocTensorOp>(loc, t);
967 t = converted;
969 rewriter.replaceOp(op, t);
970 return success();
974 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
975 template <typename ReshapeOp>
976 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
977 public:
978 using OpRewritePattern<ReshapeOp>::OpRewritePattern;
980 LogicalResult matchAndRewrite(ReshapeOp op,
981 PatternRewriter &rewriter) const override {
982 Location loc = op.getLoc();
983 Value srcTensor = op.getSrc();
984 const auto srcTp = getSparseTensorType(srcTensor);
985 const auto dstTp = getSparseTensorType(op.getResult());
986 if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
987 return failure();
989 // Generate code to represent the static dimension constants or compute
990 // the dynamic dimension values.
991 SmallVector<Value> srcSizes;
992 sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
993 SmallVector<Value> dstSizes;
994 SmallVector<Value> dstDynSizes;
995 if (dstTp.hasStaticDimShape()) {
996 for (Dimension d : dstTp.getDimShape())
997 dstSizes.push_back(constantIndex(rewriter, loc, d));
998 } else {
999 ArrayRef<Size> dstShape = dstTp.getDimShape();
1000 genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
1001 op.getReassociationIndices());
1002 for (auto [idx, shape] : llvm::enumerate(dstShape)) {
1003 if (shape == ShapedType::kDynamic)
1004 dstDynSizes.push_back(dstSizes[idx]);
1007 Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
1008 // Only need an unordered COO buffer if input and output are not sorted
1009 // in the same way.
1010 Type bufferTp = getBufferType(
1011 dstTp.withoutDimToLvl(),
1012 !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());
1014 Value buffer =
1015 rewriter
1016 .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
1017 /*sizeHint=*/nnz, Attribute())
1018 .getResult();
1020 // Implement the sparse2sparse reshape as follows:
1021 // foreach srcCoords %srcTensor
1022 // insert reshapeCvs(srcCoords), %buffer
1024 // followed by an optional
1025 // %t = sparse_tensor.cast %tmp
1026 // depending on whether the input/output are sorted in the same way.
1027 const auto encSrc = srcTp.getEncoding();
1028 ForeachOp foreachOp = rewriter.create<ForeachOp>(
1029 loc, srcTensor, buffer,
1030 [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
1031 ValueRange reduc) {
1032 const Dimension dimRank = srcTp.getDimRank();
1033 SmallVector<Value> srcDcvs;
1034 srcDcvs.reserve(dimRank);
1035 for (Dimension d = 0; d < dimRank; d++) {
1036 Level lvl = toLvl(encSrc, d);
1037 srcDcvs.push_back(srcLcvs[lvl]);
1039 SmallVector<Value> dstDcvs;
1040 reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
1041 srcDcvs, dstSizes, dstDcvs);
1042 auto t =
1043 builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
1044 builder.create<sparse_tensor::YieldOp>(loc, t);
1047 Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
1048 if (bufferTp != dstTp) {
1049 auto dstRTT = dstTp.getRankedTensorType();
1050 Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
1051 rewriter.create<DeallocTensorOp>(loc, t);
1052 t = converted;
1054 rewriter.replaceOp(op, t);
1055 return success();
1059 /// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
1060 /// operator.
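///
/// For example (an illustrative sketch only; reassociation/output-shape
/// details abbreviated), a sparse-to-dense expansion
///
///   %0 = tensor.expand_shape %sp [[0, 1]]
///        : tensor<12xf64, #SV> into tensor<3x4xf64>
///
/// becomes
///
///   %d = sparse_tensor.convert %sp : tensor<12xf64, #SV> to tensor<12xf64>
///   %0 = tensor.expand_shape %d [[0, 1]] : tensor<12xf64> into tensor<3x4xf64>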
1061 template <typename ReshapeOp>
1062 struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
1063 public:
1064 using OpRewritePattern<ReshapeOp>::OpRewritePattern;
1066 LogicalResult matchAndRewrite(ReshapeOp op,
1067 PatternRewriter &rewriter) const override {
1068 Location loc = op->getLoc();
1069 auto encDst = getSparseTensorEncoding(op.getResult().getType());
1070 auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
1071 // Since a pure dense expansion is very cheap (change of view), for
1072 // a sparse2dense or dense2sparse, we can simply unfuse a sparse
1073 // conversion from the reshape operation itself.
1074 // All other cases are handled elsewhere.
1075 if (encDst && encSrc) {
1076 return failure();
1078 if (encSrc) {
1079 auto rtp = getRankedTensorType(op.getSrc());
1080 auto denseTp =
1081 RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1082 auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
1083 rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
1084 return success();
1086 if (encDst) {
1087 auto rtp = getRankedTensorType(op.getResult());
1088 auto denseTp =
1089 RankedTensorType::get(rtp.getShape(), rtp.getElementType());
1090 ReshapeOp reshape;
1091 if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
1092 reshape = rewriter.create<ReshapeOp>(
1093 loc, denseTp, op.getSrc(), op.getReassociation(),
1094 op.getOutputShape(), op.getStaticOutputShape());
1095 } else {
1096 reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
1097 op.getReassociation());
1099 Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
1100 rewriter.replaceOp(op, convert);
1101 return success();
1103 return failure();
1107 // A trivial wrapper to help generate different operations for dense/sparse
1108 // tensors.
1109 struct TensorLike {
1110 TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
1111 ValueRange sizes) {
1112 SmallVector<Value> dynSzs;
1113 getDynamicSizes(rtt, sizes, dynSzs);
1115 val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
1116 if (!isSparse()) {
1117 Value c0 = constantZero(builder, loc, rtt.getElementType());
1118 val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
1122 void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
1123 val = builder.create<tensor::InsertOp>(loc, v, val, crds);
1126 Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
1127 if (isSparse())
1128 return builder.create<LoadOp>(loc, val, true);
1129 return val;
1132 bool isSparse() const {
1133 return getSparseTensorEncoding(val.getType()) != nullptr;
1136 Value val;
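/// Sparse rewriting rule for tensor.dim on sparse tensors. For permutation
/// maps the dimension size is simply the size of the corresponding level;
/// otherwise it is reconstructed by applying the lvl-to-dim map to the
/// maximal level coordinates (see the in-code comment below).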
1139 struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
1140 using OpRewritePattern::OpRewritePattern;
1141 LogicalResult matchAndRewrite(tensor::DimOp op,
1142 PatternRewriter &rewriter) const override {
1143 std::optional<int64_t> dim = op.getConstantIndex();
1144 auto stt = tryGetSparseTensorType(op.getSource());
1145 if (!dim || !stt || !stt->hasEncoding())
1146 return failure();
1148 if (stt->isPermutation()) {
1149 rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
1150 toLvl(stt->getEncoding(), *dim));
1151 return success();
1154 // Non-permutation dim2lvl/lvl2dim maps.
1155 // Compute as follows:
1156 // affine.apply #map (l0 - 1, l1 - 1, ...) + 1
1157 // Note that this is not the most efficient way (but a more general one) to do
1158 // the lvl-to-dim translation; e.g., for BSR, the dimension size can be
1159 // computed simply as lvl_size * block_size.
1160 Location loc = op.getLoc();
1161 SmallVector<Value> maxLvlCrds;
1162 for (Level l = 0; l < stt->getLvlRank(); l++) {
1163 Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
1164 Value maxLvlCrd = rewriter.create<arith::SubIOp>(
1165 loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
1166 maxLvlCrds.push_back(maxLvlCrd);
1169 AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
1170 Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
1171 op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
1172 maxLvlCrds);
1174 Value dimSz = rewriter.create<arith::AddIOp>(
1175 loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
1176 rewriter.replaceOp(op, dimSz);
1177 return success();
1181 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
1182 using OpRewritePattern::OpRewritePattern;
1183 LogicalResult matchAndRewrite(ConcatenateOp op,
1184 PatternRewriter &rewriter) const override {
1185 if (op.needsExtraSort())
1186 op.emitError("ConcatenateOp not staged");
1188 const Location loc = op.getLoc();
1189 const auto dstTp = getSparseTensorType(op);
1190 const Dimension conDim = op.getDimension();
1191 SmallVector<Value> sizes;
1192 concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
1194 // %t = concatenate %s1, %s2, %s3 {dim = 1}
1195 // ==>
1196 // if (isSparseDst)
1197 // if (allDense)
1198 // %tmp = bufferization.alloc_tensor dstTp
1199 // else
1200 // %tmp = bufferization.alloc_tensor : unordered COO
1201 // else
1202 // %tmp = memref.alloc : dense tensor
1203 // foreach in %s1 : insert d0, d1, %tmp
1204 // foreach in %s2 : insert d0, d1 + size(s1), %tmp
1205 // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp
1207 TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
1208 Value offset = constantIndex(rewriter, loc, 0);
1209 Value iterArg = dstBuf.val;
1211 ForeachOp foreachOp;
1212 for (Value input : op.getInputs()) {
1213 // Builds a for op for each input tensor to append new values into the
1214 // output tensor.
1215 foreachOp = rewriter.create<ForeachOp>(
1216 loc, input, iterArg,
1217 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1218 ValueRange reduc) {
1219 SmallVector<Value> offDimCrd(dcvs);
1220 offDimCrd[conDim] =
1221 builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);
1223 // Enters foreach, updates the SSA chain.
1224 dstBuf.val = reduc.front();
1225 if (!dstTp.isAllDense()) {
1226 Value cond = genIsNonzero(builder, loc, v);
1227 auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1228 /*else*/ true);
1229 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1230 builder.create<scf::YieldOp>(loc, dstBuf.val);
1232 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1233 dstBuf.insert(builder, loc, v, offDimCrd);
1234 builder.create<scf::YieldOp>(loc, dstBuf.val);
1236 // Exits the ifOp, update the sparse tensor SSA value.
1237 builder.setInsertionPointAfter(ifOp);
1238 dstBuf.val = ifOp.getResult(0);
1239 } else {
1240 dstBuf.insert(builder, loc, v, offDimCrd);
1242 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1244 // Accumulates the offset. Note that only static-shaped inputs are allowed
1245 // by the concatenate op verifier, which saves us from computing the offset
1246 // dynamically.
1247 const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
1248 assert(!ShapedType::isDynamic(sz));
1249 offset = rewriter.create<arith::AddIOp>(loc, offset,
1250 constantIndex(rewriter, loc, sz));
1251 iterArg = foreachOp.getResult(0);
1252 dstBuf.val = iterArg;
1255 dstBuf.val = iterArg;
1256 Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
1257 rewriter.replaceOp(op, ret);
1258 return success();
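/// Sparse rewriting rule for the convert operator in cases that codegen does
/// not handle trivially: the conversion is lowered to a foreach loop over the
/// source that inserts each element into a freshly allocated destination.
/// A sketch of the generated structure (the nonzero test is skipped when the
/// source is sparse or a sparse constant):
///
///   foreach srcCoords, v in %src
///     if (v != 0)
///       insert v at srcCoords into %dst
///   %t = sparse_tensor.load %dst hasInserts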
1262 struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1263 using OpRewritePattern::OpRewritePattern;
1264 LogicalResult matchAndRewrite(ConvertOp op,
1265 PatternRewriter &rewriter) const override {
1266 if (op.needsExtraSort())
1267 return op.emitError("ConvertOp not staged.");
1269 // TODO: Maybe we want a different operation for this too.
1270 auto encDst = getSparseTensorEncoding(op.getType());
1271 auto encSrc = getSparseTensorEncoding(op.getSource().getType());
1272 if (encDst && encSrc && !encSrc.isSlice() &&
1273 encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
1274 // Trivial tensor conversion and simple element type conversion are handled
1275 // in codegen.
1276 return failure();
1279 Location loc = op.getLoc();
1280 Value src = op.getSource();
1282 SparseTensorType srcStt = getSparseTensorType(op.getSource());
1283 SparseTensorType dstStt = getSparseTensorType(op.getDest());
1285 bool fromSparseConst = false;
1286 if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
1287 if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
1288 fromSparseConst = true;
1290 const AffineMapAttr foreachOrder =
1291 (!dstStt.isIdentity() && fromSparseConst)
1292 ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
1293 : nullptr;
1295 bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;
1297 SmallVector<Value> sizes;
1298 sizesFromSrc(rewriter, sizes, loc, src);
1299 ValueRange vs;
1300 TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
1302 auto foreachOp = rewriter.create<ForeachOp>(
1303 loc, src, dstBuf.val, foreachOrder,
1304 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1305 ValueRange reduc) {
1306 // Enters the loop, update the SSA value for insertion chain.
1307 dstBuf.val = reduc.front();
1308 if (!skipZeroCheck) {
1309 Value cond = genIsNonzero(builder, loc, v);
1310 auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
1311 /*else*/ true);
1312 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
1313 builder.create<scf::YieldOp>(loc, dstBuf.val);
1315 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1316 dstBuf.insert(builder, loc, v, dcvs);
1317 builder.create<scf::YieldOp>(loc, dstBuf.val);
1319 // Exits the ifOp, update the sparse tensor SSA value.
1320 builder.setInsertionPointAfter(ifOp);
1321 dstBuf.val = ifOp.getResult(0);
1322 } else {
1323 dstBuf.insert(builder, loc, v, dcvs);
1325 builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
1328 rewriter.setInsertionPointAfter(foreachOp);
1330 // Exits the for loop, links the SSA chain.
1331 dstBuf.val = foreachOp.getResult(0);
1333 Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
1334 rewriter.replaceOp(op, ret);
1335 return success();
1339 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
1340 using OpRewritePattern::OpRewritePattern;
1341 LogicalResult matchAndRewrite(CrdTranslateOp op,
1342 PatternRewriter &rewriter) const override {
1343 AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
1344 ? op.getEncoder().getDimToLvl()
1345 : op.getEncoder().getLvlToDim();
1347 SmallVector<Value> outCrds;
1348 for (AffineExpr result : map.getResults()) {
1349 // TODO: we should probably expand the affine map to IR using our own
1350 // rules, since affine.apply assumes signed values, while the coordinates
1351 // we provide must always be signless.
1352 Value trans = rewriter.create<affine::AffineApplyOp>(
1353 op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
1354 op.getInCrds());
1355 outCrds.push_back(trans);
1357 rewriter.replaceOp(op, outCrds);
1358 return success();
1362 /// Sparse rewriting rule for the foreach operator.
1363 struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
1364 public:
1365 using OpRewritePattern::OpRewritePattern;
1367 LogicalResult matchAndRewrite(ForeachOp op,
1368 PatternRewriter &rewriter) const override {
1370 auto loc = op.getLoc();
1371 Value input = op.getTensor();
1372 SmallVector<Value> reduc = op.getInitArgs();
1373 const auto stt = getSparseTensorType(input);
1374 const Level lvlRank = stt.getLvlRank();
1376 // Special case: a foreach over a sparse constant uses its own rewriting
1377 // rule.
1378 if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
1379 if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
1380 return genForeachOnSparseConstant(op, rewriter, attr);
1384 // Otherwise, use loop emitter to generate loops.
1385 const auto enc = stt.getEncoding();
1387 // 1. Generates loop for the sparse input.
1388 LoopEmitter loopEmitter(
1389 ValueRange{input},
1390 StringAttr::get(getContext(), ForeachOp::getOperationName()));
1391 loopEmitter.initializeLoopEmit(rewriter, loc);
1392 for (Level l = 0; l < lvlRank; l++) {
1393 // TODO: provide a utility function for loop sequences that only contain
1394 // one for loop?
1395 const SmallVector<TensorLevel, 1> tidLvls{
1396 loopEmitter.makeTensorLevel(0, l)};
1397 loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
1398 // Note that reduc will be taken care of by loop emitter and get updated
1399 // in place.
1400 loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
1401 reduc);
1404 SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
1405 if (op.getOrder()) {
1406 // TODO: Support it so that we can do direct conversion from CSR->BSR.
1407 llvm_unreachable(
1408 "Level order not yet implemented on non-constant input tensors.");
1411 Value vals = loopEmitter.getValBuffer()[0];
1412 SmallVector<Value> pos = loopEmitter.getValPosits(0);
1413 // Loads the value from sparse tensor using position-index;
1414 // loads the value from dense tensor using coords.
1415 Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
1416 : rewriter.create<memref::LoadOp>(loc, vals, lcvs);
1418 // 2. Inline the block in the foreach operator.
1419 Block *srcBlock = op.getBody();
1421 // Remap coordinates.
1422 SmallVector<Value> args =
1423 enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
1425 // Remap value.
1426 args.push_back(val);
1427 // Remap reduction variables.
1428 args.append(reduc);
1430 // Remove sparse_tensor.yield.
1431 SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
1432 rewriter.eraseOp(srcBlock->getTerminator());
1434 Operation &last = rewriter.getBlock()->back();
1435 if (llvm::isa<scf::YieldOp>(last)) {
1436 // Because `scf.for` inserts an implicit yield op when there is no
1437 // reduction variable upon creation, we reset the insertion point such
1438 // that the block is inlined *before* the yield op.
1439 rewriter.setInsertionPoint(&last);
1442 rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
1443 rewriter.getInsertionPoint(), args);
1444 rewriter.setInsertionPointToEnd(rewriter.getBlock());
1445 for (Level l = 0; l < lvlRank; l++) {
1446 // Link the reduction chain. Note that the loop emitter updates reducValue
1447 // in place.
1448 loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
1449 loopEmitter.exitCurrentLoopSeq(rewriter, loc);
1452 // Replace the foreach operator with the value returned by the outermost
1453 // for loop.
1454 rewriter.replaceOp(op, reducValue);
1455 return success();
1459 /// Sparse rewriting rule for the new operator.
1460 struct NewRewriter : public OpRewritePattern<NewOp> {
1461 using OpRewritePattern::OpRewritePattern;
1462 LogicalResult matchAndRewrite(NewOp op,
1463 PatternRewriter &rewriter) const override {
1464 Location loc = op.getLoc();
1465 auto stt = getSparseTensorType(op.getResult());
1466 if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
1467 return failure();
1469 // Implement the NewOp as follows:
1470 // %orderedCoo = sparse_tensor.new %filename
1471 // %t = sparse_tensor.convert %orderedCoo
1472 // with enveloping reinterpreted_map ops for non-permutations.
1473 RankedTensorType dstTp = stt.getRankedTensorType();
1474 RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
1475 Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
1476 Value convert = cooTensor;
1477 auto enc = stt.getEncoding();
1478 if (!stt.isPermutation()) { // demap coo, demap dstTp
1479 auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
1480 convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
1481 dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
1483 convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
1484 if (!stt.isPermutation()) // remap to original enc
1485 convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
1486 rewriter.replaceOp(op, convert);
1488 // Release the temporary ordered COO tensor.
1489 rewriter.setInsertionPointAfterValue(convert);
1490 rewriter.create<DeallocTensorOp>(loc, cooTensor);
1492 return success();
1496 /// Sparse rewriting rule for the out operator.
1497 struct OutRewriter : public OpRewritePattern<OutOp> {
1498 using OpRewritePattern::OpRewritePattern;
1499 LogicalResult matchAndRewrite(OutOp op,
1500 PatternRewriter &rewriter) const override {
1501 Location loc = op.getLoc();
1502 // Calculate NNZ.
1503 Value src = op.getTensor();
1504 Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
1506 // Allocate a temporary buffer for storing dimension-sizes/coordinates.
1507 const auto srcTp = getSparseTensorType(src);
1508 const Dimension dimRank = srcTp.getDimRank();
1509 Type indexTp = rewriter.getIndexType();
1510 Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);
1512 // Generate code to calculate dimension size values and store the values to
1513 // the buffer.
1514 SmallVector<Value> dims;
1515 sizesForTensor(rewriter, dims, loc, srcTp, src);
1516 for (Dimension d = 0; d < dimRank; d++) {
1517 rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
1518 constantIndex(rewriter, loc, d));
1521 // Create a sparse tensor writer and output meta data.
1522 Type opaqueTp = getOpaquePointerType(rewriter);
1523 Value writer =
1524 createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
1525 {op.getDest()}, EmitCInterface::Off)
1526 .getResult(0);
1527 Value rankValue = constantIndex(rewriter, loc, dimRank);
1528 createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
1529 {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);
1531 Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
1532 Type eltTp = srcTp.getElementType();
1533 SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
1534 primaryTypeFunctionSuffix(eltTp)};
1535 Value value = genAllocaScalar(rewriter, loc, eltTp);
1536 ModuleOp module = op->getParentOfType<ModuleOp>();
1538 // For each element in the source tensor, output the element.
1539 rewriter.create<ForeachOp>(
1540 loc, src, std::nullopt,
1541 [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1542 ValueRange reduc) {
1543 for (Dimension d = 0; d < dimRank; d++) {
1544 rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
1545 constantIndex(builder, loc, d));
1547 rewriter.create<memref::StoreOp>(loc, v, value);
1548 SmallVector<Value> operands{writer, rankValue, dimCoords, value};
1549 FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
1550 EmitCInterface::On);
1551 builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
1552 builder.create<sparse_tensor::YieldOp>(loc);
1555 // Release the writer.
1556 createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
1557 EmitCInterface::Off);
1559 rewriter.eraseOp(op);
1560 return success();
1564 } // namespace
1566 //===---------------------------------------------------------------------===//
1567 // Methods that add patterns described in this file to a pattern list.
1568 //===---------------------------------------------------------------------===//
1570 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
1571 patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
1572 FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
1573 GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
1574 patterns.getContext());
1577 void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
1578 bool enableRT,
1579 bool enableConvert) {
1580 patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
1581 ReshapeRewriter<tensor::CollapseShapeOp>,
1582 Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
1583 Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
1584 SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
1585 patterns.getContext());
1587 if (enableConvert)
1588 patterns.add<DirectConvertRewriter>(patterns.getContext());
1589 if (!enableRT)
1590 patterns.add<NewRewriter>(patterns.getContext());
1593 void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
1594 // Run CrdTranslateRewriter later in the pipeline so that the operation can
1595 // be folded before lowering to affine.apply.
1596 patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());