//===- SparseTensorRewriting.cpp - Sparse tensor rewriting rules ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewriting rules that are specific to sparse tensors.
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/LoopEmitter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;
using namespace mlir::sparse_tensor;

//===---------------------------------------------------------------------===//
// Helper methods for the actual rewriting rules.
//===---------------------------------------------------------------------===//

// Helper method to match any typed zero.
static bool isZeroValue(Value val) {
  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
}

// Helper to detect a sparse tensor type operand.
static bool isSparseTensor(Value v) {
  auto enc = getSparseTensorEncoding(v.getType());
  return enc && !llvm::all_of(enc.getLvlTypes(),
                              [](auto lt) { return lt == LevelFormat::Dense; });
}
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }

// Helper method to find zero/uninitialized tensor materialization.
static bool isMaterializing(OpOperand *op, bool isZero) {
  Value val = op->get();
  // Check allocation, with zero alloc when required.
  if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
    Value copy = alloc.getCopy();
    if (isZero)
      return copy && isZeroValue(copy);
    return !copy;
  }
  // Check for empty tensor materialization.
  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
    return !isZero;
  // Last resort for zero alloc: the whole value is zero.
  return isZero && isZeroValue(val);
}

// Helper to detect sampling operation.
static bool isSampling(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
      // Both scalar input arguments used exactly once.
      Value s1 = op.getBlock()->getArgument(0);
      Value s2 = op.getBlock()->getArgument(1);
      return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
             (def->getOperand(1) == s1 && def->getOperand(0) == s2);
    }
  }
  return false;
}

// Helper to detect chain of multiplications that do not involve x.
static bool isMulChain(Value val, Value x) {
  if (auto arg = dyn_cast<BlockArgument>(val))
    return arg != x;
  if (auto *def = val.getDefiningOp()) {
    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
      return isMulChain(def->getOperand(0), x) &&
             isMulChain(def->getOperand(1), x);
  }
  return false;
}

// Helper to detect x = x + <multiplications>.
static bool isSumOfMul(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
      Value x = op.getBlock()->getArguments().back();
      return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
             (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
    }
  }
  return false;
}

// Helper to detect direct yield of a zero value.
static bool isZeroYield(GenericOp op) {
  auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
  if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
    if (arg.getOwner()->getParentOp() == op) {
      return isZeroValue(op->getOperand(arg.getArgNumber()));
    }
  }
  return isZeroValue(yieldOp.getOperand(0));
}

/// Populates given sizes array from type (for static sizes) and from
/// the tensor (for dynamic sizes).
static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
                           Location loc, ShapedType stp, Value tensor) {
  for (const auto &d : enumerate(stp.getShape())) {
    Value dim;
    if (d.value() == ShapedType::kDynamic)
      dim = builder.create<tensor::DimOp>(loc, tensor, d.index());
    else
      dim = constantIndex(builder, loc, d.value());
    sizes.push_back(dim);
  }
}

static RankedTensorType getBufferType(const SparseTensorType &stt,
                                      bool needTmpCOO) {
  return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
                    : stt.getRankedTensorType();
}

/// Collects the dynamic dimension sizes for `tp` with the assumption that
/// `sizes` are the dimension sizes for the type. Stores the dynamic dimension
/// sizes to dynSizes.
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes,
                            SmallVectorImpl<Value> &dynSizes) {
  for (const auto &d : enumerate(tp.getShape())) {
    if (d.value() == ShapedType::kDynamic)
      dynSizes.push_back(sizes[d.index()]);
  }
}

static LogicalResult genForeachOnSparseConstant(ForeachOp op,
                                                RewriterBase &rewriter,
                                                SparseElementsAttr attr) {
  auto loc = op.getLoc();
  SmallVector<Value> reduc = op.getInitArgs();

  // Foreach on constant.
  foreachInSparseConstant(
      rewriter, loc, attr, op.getOrder().value_or(AffineMap()),
      [&reduc, &rewriter, op](ArrayRef<Value> cvs, Value v) mutable {
        SmallVector<Value> args;
        args.append(cvs.begin(), cvs.end());
        args.push_back(v);
        args.append(reduc);
        // Clones the foreach op to get a copy of the loop body.
        auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
        assert(args.size() == cloned.getBody()->getNumArguments());
        Operation *yield = cloned.getBody()->getTerminator();
        rewriter.inlineBlockBefore(cloned.getBody(), op, args);
        // Clean up the cloned op and the terminator.
        rewriter.eraseOp(cloned);
        reduc = yield->getOperands();
        rewriter.eraseOp(yield);
      });

  rewriter.replaceOp(op, reduc);
  return success();
}

/// Populates the given sizes array for concatenation from types (for static
/// sizes) and from the source tensors (for dynamic sizes).
static void concatSizesFromInputs(OpBuilder &builder,
                                  SmallVectorImpl<Value> &sizes, Location loc,
                                  ShapedType dstTp, ValueRange srcs,
                                  unsigned dim) {
  auto dstShape = dstTp.getShape();
  sizesFromSrc(builder, sizes, loc, srcs[0]);

  // Sum up on the `dim` if the dimension is dynamic.
  if (dstShape[dim] != ShapedType::kDynamic) {
    // Faithfully take the static size.
    sizes[dim] = constantIndex(builder, loc, dstShape[dim]);
  } else {
    // Else, compute the shape dynamically.
    for (const auto &src : srcs.drop_front()) {
      Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim);
      // Sum up all the sizes.
      sizes[dim] = builder.create<arith::AddIOp>(loc, sizes[dim], srcSz);
    }
  }
}

//===---------------------------------------------------------------------===//
// The actual sparse tensor rewriting rules.
//===---------------------------------------------------------------------===//

namespace {

/// TODO: move it to tensor dialect instead.
///
/// Fold `tensor.concat` and `tensor.extract_slice`
///
/// %concat = tensor.concat dim(2) %t0, %t1
///   : (tensor<1x64x1xf32>, tensor<1x64x1xf32>) -> tensor<1x64x2xf32>
/// %extracted0 = tensor.extract_slice %concat[0, 0, 0][1, 64, 1][1, 1, 1]
///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
/// %extracted1 = tensor.extract_slice %concat[0, 0, 1][1, 64, 1][1, 1, 1]
///   : tensor<1x64x2xf32> to tensor<1x64x1xf32>
///
/// into
///
/// %extract0, %extract1 = %t0, %t1
struct FuseExtractSliceWithConcat
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto concatOp = extractOp.getSource().getDefiningOp<tensor::ConcatOp>();
    if (!concatOp)
      return failure();

    Location loc = extractOp.getLoc();
    int64_t dim = concatOp.getDim();
    int64_t rank = extractOp.getResultType().getRank();

    SmallVector<OpFoldResult> srcStrides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> srcOffsets(rank, rewriter.getIndexAttr(0));

    // Compute the partial sums for the slice offsets.
    AffineExpr sum = rewriter.getAffineDimExpr(0);
    SmallVector<AffineExpr> partialSums = {sum};
    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
    for (auto [idx, input] :
         llvm::enumerate(concatOp.getInputs().drop_back())) {
      sum = sum + rewriter.getAffineDimExpr(idx + 1);
      partialSums.push_back(sum);
      offsetStrides.push_back(
          rewriter.createOrFold<tensor::DimOp>(loc, input, dim));
    }
    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
                                        partialSums, rewriter.getContext());
    SmallVector<OpFoldResult> dimOffsets =
        affine::makeComposedFoldedMultiResultAffineApply(
            rewriter, loc, partialSumMap, offsetStrides);

    auto allEqual = [](ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
      for (auto [l, r] : llvm::zip(lhs, rhs)) {
        std::optional<int64_t> staticVal = getConstantIntValue(l);
        if (!staticVal.has_value() || staticVal != getConstantIntValue(r))
          return false;
      }
      return lhs.size() == rhs.size();
    };

    for (auto [i, input, offset] :
         llvm::enumerate(concatOp.getInputs(), dimOffsets)) {
      SmallVector<OpFoldResult> srcSizes =
          tensor::getMixedSizes(rewriter, loc, input);
      srcOffsets[dim] = offset;

      SmallVector<OpFoldResult> dstSizes = extractOp.getMixedSizes();
      SmallVector<OpFoldResult> dstOffsets = extractOp.getMixedOffsets();
      SmallVector<OpFoldResult> dstStrides = extractOp.getMixedStrides();

      if (allEqual(srcSizes, dstSizes) && allEqual(srcOffsets, dstOffsets) &&
          allEqual(srcStrides, dstStrides)) {
        Value operand = concatOp.getOperand(i);
        if (operand.getType() == extractOp.getResultType())
          rewriter.replaceOp(extractOp, operand);
        break;
      }
    }

    return success();
  }
};

/// Rewriting rule that fuses sparse_tensor.convert into producer.
struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    auto producer = op.getSource().getDefiningOp<GenericOp>();
    if (!producer || producer.getDpsInits().size() != 1 ||
        !isMaterializing(producer.getDpsInitOperand(0), false) ||
        !producer.getResult(0).hasOneUse()) {
      return failure();
    }
    // Clone the materialization operation, but update the result to sparse.
    rewriter.setInsertionPoint(producer);
    Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
    Operation *cloned = rewriter.clone(*init);
    cloned->getResult(0).setType(op.getResult().getType());

    rewriter.modifyOpInPlace(producer, [&]() {
      producer.getDpsInitsMutable().assign(cloned->getResults());
      producer.getResult(0).setType(op.getResult().getType());
    });

    rewriter.replaceAllOpUsesWith(op, producer);
    op->erase();

    return success();
  }
};

/// Rewriting rule that converts direct yield of zero with initial allocation.
struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasPureTensorSemantics() || op.getNumResults() != 1 ||
        !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
        !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
      return failure();
    auto outputType = getRankedTensorType(op.getResult(0));
    // Yielding zero on newly materialized sparse tensor can be
    // optimized directly (regardless of dynamic or static size).
    if (getSparseTensorEncoding(outputType)) {
      rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
      return success();
    }
    // Use static zero value directly instead of materialization.
    if (!outputType.hasStaticShape())
      return failure();
    Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
    rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
    rewriter.eraseOp(def);
    return success();
  }
};

/// Rewriting rule that converts two kernels:
///
///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
///      X(i,j) = S(i,j) * T(i,j)
///
/// into a single kernel, using distributive law:
///
///      X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
///
/// This kind of fusion (merging two ops into one but using arithmetic
/// equalities that may not hold for floating-point computations) would
/// be undesirable in the dense case, since we distribute the multiplication
/// into the reduction loop. However, for sparse sampling tensor S, such
/// a fusion may actually reduce the asymptotic complexity of the kernel,
/// since intermediate results may be nullified.
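/// For illustration (a rough sketch, not lifted from an actual test), one
/// well-known instance of this shape is SDDMM-like sampling,
///
///      X(i,j) = S(i,j) * SUM(k, A(i,k) * B(k,j)),
///
/// where fusing the elementwise sampling by sparse S into the producer's
/// reduction lets the generated loops skip the entire sum over k whenever
/// S(i,j) is an implicit zero.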
struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Check consumer.
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 2 ||
        op.getNumResults() != 1 ||
        op.getNumParallelLoops() != op.getNumLoops() ||
        !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() ||
        !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() ||
        !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity())
      return failure();
    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
    // operand can be sparse or dense, since the point of this rewriting rule
    // is detecting a situation in which *more* sparsity is introduced into
    // a computation, be it already sparse or still dense.
    unsigned other = 0;
    if (isSparseTensor(op.getDpsInputOperand(0)))
      other = 1;
    else if (!isSparseTensor(op.getDpsInputOperand(1)))
      return failure();
    // Check producer.
    auto prod = dyn_cast_or_null<GenericOp>(
        op.getDpsInputOperand(other)->get().getDefiningOp());
    if (!prod || !prod.hasPureTensorSemantics() || prod.getNumResults() != 1 ||
        !prod.getResult(0).hasOneUse())
      return failure();
    // Sampling consumer and sum of multiplication chain producer.
    if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
        !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
        !isSampling(op) || !isSumOfMul(prod))
      return failure();
    // Modify operand structure of producer and consumer.
    Location loc = prod.getLoc();
    SmallVector<Value> inputOps = prod.getInputs();
    SmallVector<Value> outputOps = op.getOutputs();
    SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
    inputOps.push_back(op.getDpsInputOperand(1 - other)->get());
    fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
    // Fuse producer and consumer into a new generic op.
    auto fusedOp = rewriter.create<GenericOp>(
        loc, op.getResult(0).getType(), inputOps, outputOps,
        rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.getIteratorTypes(),
        /*doc=*/nullptr, /*library_call=*/nullptr);
    Block &prodBlock = prod.getRegion().front();
    Block &consBlock = op.getRegion().front();
    IRMapping mapper;
    Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
    unsigned num = prodBlock.getNumArguments();
    for (unsigned i = 0; i < num - 1; i++)
      addArg(mapper, fusedBlock, prodBlock.getArgument(i));
    addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
    addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
    // Clone bodies of the producer and consumer in new evaluation order.
    auto *acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
    auto *sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
    Value last;
    for (auto &op : prodBlock.without_terminator())
      if (&op != acc) {
        last = op.getResult(0);
        rewriter.clone(op, mapper);
      }
    mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
    mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
    last = rewriter.clone(*acc, mapper)->getResult(0);
    rewriter.create<linalg::YieldOp>(loc, last);
    // Force initial value on merged allocation for dense outputs.
    // TODO: deal with non alloc tensor here one day
    if (!getSparseTensorEncoding(op.getResult(0).getType())) {
      Value init = prod.getDpsInitOperand(0)
                       ->get()
                       .getDefiningOp<AllocTensorOp>()
                       .getCopy();
      AllocTensorOp a =
          op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
      rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
    }
    // Replace consumer with fused operation. Old producer
    // and consumer ops will be removed by DCE.
    rewriter.replaceOp(op, fusedOp->getResults());
    return success();
  }

private:
  // Helper to add argument and record the mapping.
  static void addArg(IRMapping &mapper, Block *b, BlockArgument a) {
    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
  }
};

// Fuse a tensor cast into producing operation. Note that a tensor.cast
// should really not be used to convert between sparse encodings. Since
// the pattern currently appears as a result of some prior rewriting
// we make an attempt to repair very obvious cases.
// TODO: audit the pure tensor dialect rewriting rules
struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
public:
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    Type srcType = op.getSource().getType();
    Type dstType = op.getDest().getType();
    // A nop cast simply folds away.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op->getResults());
      return success();
    }
    // See if a sparsity changing cast can be fused into producer.
    if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
      if (Operation *def = op.getSource().getDefiningOp()) {
        if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
          rewriter.modifyOpInPlace(def, [&]() {
            def->getResult(0).setType(op->getResultTypes()[0]);
          });
          rewriter.replaceOp(op, def->getResult(0));
          return success();
        }
      }
    }
    // Repair tensor casts with at least one sparse operand into the
    // properly supported sparse_tensor.convert.
    if (getSparseTensorEncoding(srcType) || getSparseTensorEncoding(dstType)) {
      rewriter.replaceOpWithNewOp<ConvertOp>(op, dstType, op.getSource());
      return success();
    }
    // Fail otherwise.
    return failure();
  }
};

/// Rewrites a sequence of operations for sparse tensor selections into
/// semi-ring operations such that they can be compiled correctly by the
/// sparsifier. E.g., transforming the following sequence
///
///   %sel = arith.select %cond, %sp1, %sp2
///
/// to
///
///   %sel = binary %sp1, %sp2:
///            both  (%l, %r) {yield select %cond, %l, %r}
///            left  (%l)     {yield select %cond, %l, 0}
///            right (%r)     {yield select %cond, 0, %r}
///
/// TODO: We require that the tensor used for extracting conditions to be dense
/// to sparsify the code. To support a sparse condition tensor, we need a
/// tri-nary operation.
struct GenSemiRingSelect : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Rejects non sparse kernels.
    if (!op.hasPureTensorSemantics() || !hasAnySparseOperand(op))
      return failure();

    Location loc = op.getLoc();
    SmallVector<std::pair<Operation *, sparse_tensor::BinaryOp>> semiRings;
    for (Operation &inst : *op.getBody()) {
      // Matches the pattern.
      auto matched = isRewritablePattern(op, &inst);
      if (!matched.has_value())
        continue;

      rewriter.setInsertionPoint(&inst);
      auto [c, t, f] = matched.value();
      assert(t.getType() == f.getType());
      auto selTp = t.getType();
      auto c0 = constantZero(rewriter, loc, selTp);
      auto binOp = rewriter.create<sparse_tensor::BinaryOp>(loc, selTp, t, f);
      // Initializes all the blocks.
      rewriter.createBlock(&binOp.getOverlapRegion(), {}, {selTp, selTp},
                           {t.getLoc(), f.getLoc()});
      rewriter.createBlock(&binOp.getRightRegion(), {}, selTp, f.getLoc());
      rewriter.createBlock(&binOp.getLeftRegion(), {}, selTp, t.getLoc());

      for (auto *r : binOp.getRegions()) {
        Block *b = &r->front();
        rewriter.setInsertionPointToStart(b);

        IRMapping irMap;
        // Clones the cmp operations into the region to make the binary op
        // admissible.
        Value newC = c;
        if (auto *def = c.getDefiningOp())
          newC = rewriter.clone(*def, irMap)->getResult(0);

        irMap.map(c, newC);
        if (r == &binOp.getLeftRegion()) {
          irMap.map(t, b->getArgument(0));
          irMap.map(f, c0);
        } else if (r == &binOp.getRightRegion()) {
          irMap.map(t, c0);
          irMap.map(f, b->getArgument(0));
        } else {
          irMap.map(t, b->getArgument(0));
          irMap.map(f, b->getArgument(1));
        }
        auto y = rewriter.clone(inst, irMap)->getResult(0);
        rewriter.create<sparse_tensor::YieldOp>(loc, y);
      }

      // We successfully rewrote an operation. We can not do the replacement
      // here because it would invalidate the iterator for the current loop
      // over the instructions.
      semiRings.emplace_back(&inst, binOp);
    }

    // Finalizes the replacement.
    for (auto [sel, semi] : semiRings)
      rewriter.replaceOp(sel, semi->getResults());

    return success(!semiRings.empty());
  }

private:
  static std::optional<std::tuple<Value, BlockArgument, BlockArgument>>
  isRewritablePattern(GenericOp op, Operation *v) {
    auto sel = dyn_cast<arith::SelectOp>(v);
    if (!sel)
      return std::nullopt;

    auto tVal = dyn_cast<BlockArgument>(sel.getTrueValue());
    auto fVal = dyn_cast<BlockArgument>(sel.getFalseValue());
    // TODO: For simplicity, we only handle cases where both true/false value
    // are directly loaded from the input tensor. We can probably admit more
    // cases in theory.
    if (!tVal || !fVal)
      return std::nullopt;

    // Helper lambda to determine whether the value is loaded from a dense
    // input or is a loop invariant.
    auto isValFromDenseInputOrInvariant = [&op](Value v) -> bool {
      if (auto bArg = dyn_cast<BlockArgument>(v);
          bArg && !isSparseTensor(op.getDpsInputOperand(bArg.getArgNumber())))
        return true;
      // If the value is defined outside the loop, it is a loop invariant.
      return v.getDefiningOp() && v.getDefiningOp()->getBlock() != op.getBody();
    };

    // If the condition value is loaded directly from a dense tensor or is a
    // loop invariant, we can sparsify the kernel.
    auto cond = sel.getCondition();
    if (isValFromDenseInputOrInvariant(cond))
      return std::make_tuple(cond, tVal, fVal);

    Value cmpL, cmpR;
    if (matchPattern(cond, m_Op<arith::CmpIOp>(matchers::m_Any(&cmpL),
                                               matchers::m_Any(&cmpR))) ||
        matchPattern(cond, m_Op<arith::CmpFOp>(matchers::m_Any(&cmpL),
                                               matchers::m_Any(&cmpR)))) {
      // TODO: we can do it recursively to check whether all the leaf values
      // are loaded from dense tensors or are loop invariants.
      if (isValFromDenseInputOrInvariant(cmpL) ||
          isValFromDenseInputOrInvariant(cmpR))
        return std::make_tuple(cond, tVal, fVal);
    }

    return std::nullopt;
  }
};

/// Rewrites a sparse reduction that would not sparsify directly since
/// doing so would only iterate over the stored elements, ignoring the
/// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
/// (note that reductions like add/sub/or/xor can directly be sparsified
/// since the implicit zeros do not contribute to the final result).
/// Note that prod/and are still included since, even though they often
/// are nullified in sparse data, they may still occur for special
/// situations in which e.g. some rows in a sparse matrix are fully
/// dense. For min/max, including the implicit zeros is a much more
/// common situation.
///
/// TODO: this essentially "densifies" the operation; we want to implement
///       this much more efficiently by performing the reduction over the
///       stored values, and feed in the zero once if there were *any*
///       implicit zeros as well; but for now, at least we provide
///       the functionality
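///
/// For illustration (a rough sketch; block-argument names are made up), a
/// min-reduction body such as
///
///      ^bb0(%a: f64, %x: f64):
///        %m = arith.minimumf %x, %a : f64
///        linalg.yield %m : f64
///
/// is rewritten so that %a first passes through a sparse_tensor.unary
/// (present -> %a, absent -> 0.0) and the minimum itself is moved into the
/// region of a sparse_tensor.reduce that is seeded with the initial output
/// value as its identity.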
struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
public:
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Reject non-reductions.
    if (!op.hasPureTensorSemantics() || op.getNumDpsInputs() != 1 ||
        op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
      return failure();
    auto *inp = op.getDpsInputOperand(0);
    auto *init = op.getDpsInitOperand(0);
    if (!isSparseTensor(inp))
      return failure();
    // Look for direct x = x OP y for semi-ring ready reductions.
    auto *red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
                    .getOperand(0)
                    .getDefiningOp();
    if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinimumFOp,
             arith::MinSIOp, arith::MinUIOp, arith::MaximumFOp, arith::MaxSIOp,
             arith::MaxUIOp>(red))
      return failure();
    Value s0 = op.getBlock()->getArgument(0);
    Value s1 = op.getBlock()->getArgument(1);
    if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
        (red->getOperand(0) != s1 || red->getOperand(1) != s0))
      return failure();
    // Identity.
    Location loc = op.getLoc();
    Value identity =
        rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
    // Unary {
    //    present -> value
    //    absent  -> zero
    // }
    Type rtp = s0.getType();
    rewriter.setInsertionPointToStart(&op.getRegion().front());
    auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
    Block *present =
        rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
    rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
    rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
    rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
    rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
    rewriter.create<sparse_tensor::YieldOp>(loc, zero);
    rewriter.setInsertionPointAfter(semiring);
    // Reduce {
    //    x = x REDUC y, identity
    // }
    auto custom = rewriter.create<sparse_tensor::ReduceOp>(
        loc, rtp, semiring.getResult(), s1, identity);
    Block *region =
        rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
    rewriter.setInsertionPointToStart(&custom.getRegion().front());
    IRMapping irMap;
    irMap.map(red->getOperand(0), region->getArgument(0));
    irMap.map(red->getOperand(1), region->getArgument(1));
    auto *cloned = rewriter.clone(*red, irMap);
    rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
    rewriter.setInsertionPointAfter(custom);
    rewriter.replaceOp(red, custom.getResult());
    return success();
  }
};

/// Sparse rewriting rule for the print operator. This operation is mainly used
/// for debugging and testing. As such, it lowers to the vector.print operation
/// which only requires very light-weight runtime support.
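///
/// For a small 2-D CSR tensor, the emitted vector.print ops produce output
/// roughly of the form (an illustrative sketch, not an exact transcript):
///
///   ---- Sparse Tensor ----
///   nse = 3
///   dim = ( 4, 4 )
///   lvl = ( 4, 4 )
///   pos[1] : ( 0, 1, 2, 3, 3 )
///   crd[1] : ( 0, 2, 3 )
///   values : ( 1, 2, 3 )
///   ----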
struct PrintRewriter : public OpRewritePattern<PrintOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(PrintOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto tensor = op.getTensor();
    auto stt = getSparseTensorType(tensor);
    // Header with NSE.
    auto nse = rewriter.create<NumberOfEntriesOp>(loc, tensor);
    rewriter.create<vector::PrintOp>(
        loc, rewriter.getStringAttr("---- Sparse Tensor ----\nnse = "));
    rewriter.create<vector::PrintOp>(loc, nse);
    // Print run-time contents for dim/lvl sizes.
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("dim = "));
    printSizes(rewriter, loc, tensor, stt.getDimRank(), /*isDim=*/true);
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("lvl = "));
    printSizes(rewriter, loc, tensor, stt.getLvlRank(), /*isDim=*/false);
    // Use the "codegen" foreach loop construct to iterate over
    // all typical sparse tensor components for printing.
    foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
                                            &stt](Type, FieldIndex,
                                                  SparseTensorFieldKind kind,
                                                  Level l, LevelType) {
      switch (kind) {
      case SparseTensorFieldKind::StorageSpec: {
        break;
      }
      case SparseTensorFieldKind::PosMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("pos["));
        rewriter.create<vector::PrintOp>(
            loc, lvl, vector::PrintPunctuation::NoPunctuation);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
        auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
        printContents(rewriter, loc, pos);
        break;
      }
      case SparseTensorFieldKind::CrdMemRef: {
        auto lvl = constantIndex(rewriter, loc, l);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("crd["));
        rewriter.create<vector::PrintOp>(
            loc, lvl, vector::PrintPunctuation::NoPunctuation);
        rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
        Value crd = nullptr;
        // For COO AoS storage, we want to print a single, linear view of
        // the full coordinate storage at this level. For any other storage,
        // we show the coordinate storage for every individual level.
        if (stt.getAoSCOOStart() == l)
          crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
        else
          crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
        printContents(rewriter, loc, crd);
        break;
      }
      case SparseTensorFieldKind::ValMemRef: {
        rewriter.create<vector::PrintOp>(loc,
                                         rewriter.getStringAttr("values : "));
        auto val = rewriter.create<ToValuesOp>(loc, tensor);
        printContents(rewriter, loc, val);
        break;
      }
      }
      return true;
    });
    rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("----\n"));
    rewriter.eraseOp(op);
    return success();
  }

private:
  // Helper to print contents of a single memref. For "push_back" vectors,
  // we assume that the previous getters for pos/crd/val have added a
  // slice-to-size view to make sure we just print the size and not the
  // capacity.
  //
  // Generates code to print (1-dim or higher):
  //    ( a0, a1, ... )
  static void printContents(PatternRewriter &rewriter, Location loc,
                            Value vec) {
    auto shape = cast<ShapedType>(vec.getType()).getShape();
    SmallVector<Value> idxs;
    printContentsLevel(rewriter, loc, vec, 0, shape, idxs);
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
  }

  // Helper to the helper.
  static void printContentsLevel(PatternRewriter &rewriter, Location loc,
                                 Value vec, unsigned i, ArrayRef<int64_t> shape,
                                 SmallVectorImpl<Value> &idxs) {
    // Open bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
    // Generate for loop.
    auto zero = constantIndex(rewriter, loc, 0);
    auto index = constantIndex(rewriter, loc, i);
    auto size = rewriter.create<memref::DimOp>(loc, vec, index);
    auto step = constantIndex(rewriter, loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, zero, size, step);
    idxs.push_back(forOp.getInductionVar());
    rewriter.setInsertionPointToStart(forOp.getBody());
    if (i < shape.size() - 1) {
      // Enter deeper loop nest.
      printContentsLevel(rewriter, loc, vec, i + 1, shape, idxs);
    } else {
      // Actual contents printing.
      auto val = rewriter.create<memref::LoadOp>(loc, vec, idxs);
      if (llvm::isa<ComplexType>(val.getType())) {
        // Since the vector dialect does not support complex types in any op,
        // we split those into (real, imag) pairs here.
        Value real = rewriter.create<complex::ReOp>(loc, val);
        Value imag = rewriter.create<complex::ImOp>(loc, val);
        rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
        rewriter.create<vector::PrintOp>(loc, real,
                                         vector::PrintPunctuation::Comma);
        rewriter.create<vector::PrintOp>(loc, imag,
                                         vector::PrintPunctuation::Close);
      } else {
        rewriter.create<vector::PrintOp>(
            loc, val, vector::PrintPunctuation::NoPunctuation);
      }
      // Terminating comma (except at end).
      auto bound = rewriter.create<arith::AddIOp>(loc, idxs.back(), step);
      Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                                  bound, size);
      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Comma);
    }
    idxs.pop_back();
    rewriter.setInsertionPointAfter(forOp);
    // Close bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
  }

  // Helper method to print run-time lvl/dim sizes.
  static void printSizes(PatternRewriter &rewriter, Location loc, Value tensor,
                         unsigned size, bool isDim) {
    // Open bracket.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
    // Print unrolled contents (dimop requires constant value).
    for (unsigned i = 0; i < size; i++) {
      auto idx = constantIndex(rewriter, loc, i);
      Value val;
      if (isDim)
        val = rewriter.create<tensor::DimOp>(loc, tensor, idx);
      else
        val = rewriter.create<LvlOp>(loc, tensor, idx);
      rewriter.create<vector::PrintOp>(
          loc, val,
          i != size - 1 ? vector::PrintPunctuation::Comma
                        : vector::PrintPunctuation::NoPunctuation);
    }
    // Close bracket and end of line.
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Close);
    rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::NewLine);
  }
};

/// Sparse rewriting rule for sparse-to-sparse reshape operator.
struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
public:
  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSource();
    const auto srcTp = tryGetSparseTensorType(srcTensor);
    const auto dstTp = tryGetSparseTensorType(op.getResult());
    if (!srcTp || !dstTp)
      return failure();

    if (!srcTp->hasEncoding() || !dstTp->hasEncoding() ||
        !dstTp->hasStaticDimShape())
      return failure();

    SmallVector<Value> srcSizes;
    sizesForTensor(rewriter, srcSizes, loc, *srcTp, srcTensor);
    SmallVector<Value> dstSizes;
    for (Dimension d : dstTp->getDimShape())
      dstSizes.push_back(constantIndex(rewriter, loc, d));

    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
    // Only need an unordered COO buffer if input and output are not sorted
    // in the same way.
    Type bufferTp = getBufferType(
        dstTp->withoutDimToLvl(),
        !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity());
    SmallVector<Value> dynSizes;
    Value buffer = rewriter
                       .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
                                              /*sizeHint=*/nnz, Attribute())
                       .getResult();

    // Convert src coordinates to dst coordinates by first collapsing it to 1D
    // and then expand it to match the rank of the destination tensor.
    // Implemented as follows:
    //   foreach srcCoords %srcTensor
    //     collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
    //     expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
    //     insert expandedCoords, %buffer
    //
    // followed by an optional
    //   %t = sparse_tensor.cast %tmp
    // depending on whether the input/output are sorted in the same way.
    const auto encSrc = srcTp->getEncoding();
    ForeachOp foreachOp = rewriter.create<ForeachOp>(
        loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          const Dimension srcRank = srcTp->getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(srcRank);
          for (Dimension d = 0; d < srcRank; d++) {
            Level lvl = toLvl(encSrc, d);
            srcDcvs.push_back(srcLcvs[lvl]);
          }

          Value collapseSize = constantIndex(builder, loc, 1);
          for (Dimension d = 0; d < srcRank; d++)
            collapseSize =
                builder.create<arith::MulIOp>(loc, collapseSize, srcSizes[d]);
          SmallVector<Value, 1> collapsedSizes = {collapseSize};

          ReassociationIndices collapseIdx;
          for (Dimension i = 0; i < srcRank; i++)
            collapseIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> collapseReass = {collapseIdx};
          SmallVector<Value, 1> collapsedDcvs;
          reshapeCvs(builder, loc, collapseReass, srcSizes, srcDcvs,
                     collapsedSizes, collapsedDcvs);

          ReassociationIndices expandIdx;
          for (Dimension i = 0; i < dstTp->getDimRank(); i++)
            expandIdx.push_back(i);
          SmallVector<ReassociationIndices, 1> expandReass = {expandIdx};
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
                     dstSizes, dstDcvs);

          auto t =
              builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
          builder.create<sparse_tensor::YieldOp>(loc, t);
        });

    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
    if (bufferTp != *dstTp) {
      auto dstRTT = dstTp->getRankedTensorType();
      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
      rewriter.create<DeallocTensorOp>(loc, t);
      t = converted;
    }
    rewriter.replaceOp(op, t);
    return success();
  }
};

/// Sparse rewriting rule for sparse-to-sparse reshape operator.
template <typename ReshapeOp>
struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value srcTensor = op.getSrc();
    const auto srcTp = getSparseTensorType(srcTensor);
    const auto dstTp = getSparseTensorType(op.getResult());
    if (!srcTp.hasEncoding() || !dstTp.hasEncoding())
      return failure();

    // Generate code to represent the static dimension constants or compute
    // the dynamic dimension values.
    SmallVector<Value> srcSizes;
    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
    SmallVector<Value> dstSizes;
    SmallVector<Value> dstDynSizes;
    if (dstTp.hasStaticDimShape()) {
      for (Dimension d : dstTp.getDimShape())
        dstSizes.push_back(constantIndex(rewriter, loc, d));
    } else {
      ArrayRef<Size> dstShape = dstTp.getDimShape();
      genReshapeDstShape(rewriter, loc, dstSizes, srcSizes, dstShape,
                         op.getReassociationIndices());
      for (auto [idx, shape] : llvm::enumerate(dstShape)) {
        if (shape == ShapedType::kDynamic)
          dstDynSizes.push_back(dstSizes[idx]);
      }
    }
    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
    // Only need an unordered COO buffer if input and output are not sorted
    // in the same way.
    Type bufferTp = getBufferType(
        dstTp.withoutDimToLvl(),
        !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity());

    Value buffer =
        rewriter
            .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(),
                                   /*sizeHint=*/nnz, Attribute())
            .getResult();

    // Implement the sparse2sparse reshape as follows:
    //   foreach srcCoords %srcTensor
    //     insert reshapeCvs(srcCoords), %buffer
    //
    // followed by an optional
    //   %t = sparse_tensor.cast %tmp
    // depending on whether the input/output are sorted in the same way.
    const auto encSrc = srcTp.getEncoding();
    ForeachOp foreachOp = rewriter.create<ForeachOp>(
        loc, srcTensor, buffer,
        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
            ValueRange reduc) {
          const Dimension dimRank = srcTp.getDimRank();
          SmallVector<Value> srcDcvs;
          srcDcvs.reserve(dimRank);
          for (Dimension d = 0; d < dimRank; d++) {
            Level lvl = toLvl(encSrc, d);
            srcDcvs.push_back(srcLcvs[lvl]);
          }
          SmallVector<Value> dstDcvs;
          reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
                     srcDcvs, dstSizes, dstDcvs);
          auto t =
              builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
          builder.create<sparse_tensor::YieldOp>(loc, t);
        });

    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
    if (bufferTp != dstTp) {
      auto dstRTT = dstTp.getRankedTensorType();
      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
      rewriter.create<DeallocTensorOp>(loc, t);
      t = converted;
    }
    rewriter.replaceOp(op, t);
    return success();
  }
};

/// Sparse rewriting rule for sparse-to-dense and dense-to-sparse reshape
/// operator.
template <typename ReshapeOp>
struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
public:
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto encDst = getSparseTensorEncoding(op.getResult().getType());
    auto encSrc = getSparseTensorEncoding(op.getSrc().getType());
    // Since a pure dense expansion is very cheap (change of view), for
    // a sparse2dense or dense2sparse, we can simply unfuse a sparse
    // conversion from the reshape operation itself.
    // All other cases are handled elsewhere.
    if (encDst && encSrc) {
      return failure();
    }
    if (encSrc) {
      auto rtp = getRankedTensorType(op.getSrc());
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
      rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
      return success();
    }
    if (encDst) {
      auto rtp = getRankedTensorType(op.getResult());
      auto denseTp =
          RankedTensorType::get(rtp.getShape(), rtp.getElementType());
      ReshapeOp reshape;
      if constexpr (std::is_same<ReshapeOp, tensor::ExpandShapeOp>::value) {
        reshape = rewriter.create<ReshapeOp>(
            loc, denseTp, op.getSrc(), op.getReassociation(),
            op.getOutputShape(), op.getStaticOutputShape());
      } else {
        reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
                                             op.getReassociation());
      }
      Value convert = rewriter.create<ConvertOp>(loc, rtp, reshape);
      rewriter.replaceOp(op, convert);
      return success();
    }
    return failure();
  }
};

// A trivial wrapper to help generate different operations for dense/sparse
// tensors.
struct TensorLike {
  TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
             ValueRange sizes) {
    SmallVector<Value> dynSzs;
    getDynamicSizes(rtt, sizes, dynSzs);

    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
    if (!isSparse()) {
      Value c0 = constantZero(builder, loc, rtt.getElementType());
      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
    }
  }

  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
    val = builder.create<tensor::InsertOp>(loc, v, val, crds);
  }

  Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
    if (isSparse())
      return builder.create<LoadOp>(loc, val, true);
    return val;
  }

  bool isSparse() const {
    return getSparseTensorEncoding(val.getType()) != nullptr;
  }

  Value val;
};
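
// Usage note (a brief sketch, not from the original comments): construct a
// TensorLike with the destination type and sizes, thread `val` through the
// foreach/scf.if regions as the SSA insertion chain, call insert() once per
// nonzero, and call finalize() to obtain either the loaded sparse tensor or
// the zero-filled dense tensor.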

struct SparseTensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::DimOp op,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> dim = op.getConstantIndex();
    auto stt = tryGetSparseTensorType(op.getSource());
    if (!dim || !stt || !stt->hasEncoding())
      return failure();

    if (stt->isPermutation()) {
      rewriter.replaceOpWithNewOp<LvlOp>(op, op.getSource(),
                                         toLvl(stt->getEncoding(), *dim));
      return success();
    }

    // Non-permutation dim2lvl/lvl2dim maps.
    // Compute as follows:
    //   affine.apply #map (l0 - 1, l1 - 1, ...) + 1
    // Note that it is not the most efficient way (but a more general one) for
    // the lvl to dim translation, e.g., for BSR, the dimension size can be
    // computed simply by lvl_size * block_size.
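    //
    // For illustration (a rough sketch), a 2x2 BSR encoding with
    //   lvlToDim = (i, j, ii, jj) -> (i * 2 + ii, j * 2 + jj)
    // yields, for dim 0,
    //   dim_size0 = ((l0 - 1) * 2 + (l2 - 1)) + 1
    // where l0 and l2 are the corresponding level sizes.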
    Location loc = op.getLoc();
    SmallVector<Value> maxLvlCrds;
    for (Level l = 0; l < stt->getLvlRank(); l++) {
      Value lvlSz = rewriter.create<LvlOp>(loc, op.getSource(), l);
      Value maxLvlCrd = rewriter.create<arith::SubIOp>(
          loc, lvlSz, constantOne(rewriter, loc, rewriter.getIndexType()));
      maxLvlCrds.push_back(maxLvlCrd);
    }

    AffineExpr lvl2DimExp = stt->getLvlToDim().getResult(*dim);
    Value maxDimCrd = rewriter.create<affine::AffineApplyOp>(
        op.getLoc(), AffineMap::get(stt->getLvlRank(), 0, lvl2DimExp),
        maxLvlCrds);

    Value dimSz = rewriter.create<arith::AddIOp>(
        loc, maxDimCrd, constantOne(rewriter, loc, rewriter.getIndexType()));
    rewriter.replaceOp(op, dimSz);
    return success();
  }
};

struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ConcatenateOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      op.emitError("ConcatenateOp not staged");

    const Location loc = op.getLoc();
    const auto dstTp = getSparseTensorType(op);
    const Dimension conDim = op.getDimension();
    SmallVector<Value> sizes;
    concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);

    // %t = concatenate %s1, %s2, %s3 {dim = 1}
    // ==>
    // if (isSparseDst)
    //   if (allDense)
    //     %tmp = bufferization.alloc_tensor dstTp
    //   else
    //     %tmp = bufferization.alloc_tensor : unordered COO
    // else
    //   %tmp = memref.alloc : dense tensor
    // foreach in %s1 : insert d0, d1, %tmp
    // foreach in %s2 : insert d0, d1 + size(s1), %tmp
    // foreach in %s3 : insert d0, d1 + size(s1) + size(s2), %tmp

    TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
    Value offset = constantIndex(rewriter, loc, 0);
    Value iterArg = dstBuf.val;

    ForeachOp foreachOp;
    for (Value input : op.getInputs()) {
      // Builds a for op for each input tensor to append new values into the
      // output tensor.
      foreachOp = rewriter.create<ForeachOp>(
          loc, input, iterArg,
          [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
              ValueRange reduc) {
            SmallVector<Value> offDimCrd(dcvs);
            offDimCrd[conDim] =
                builder.create<arith::AddIOp>(loc, offDimCrd[conDim], offset);

            // Enters foreach, updates the SSA chain.
            dstBuf.val = reduc.front();
            if (!dstTp.isAllDense()) {
              Value cond = genIsNonzero(builder, loc, v);
              auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                    /*else*/ true);
              builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
              builder.create<scf::YieldOp>(loc, dstBuf.val);

              builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
              dstBuf.insert(builder, loc, v, offDimCrd);
              builder.create<scf::YieldOp>(loc, dstBuf.val);

              // Exits the ifOp, update the sparse tensor SSA value.
              builder.setInsertionPointAfter(ifOp);
              dstBuf.val = ifOp.getResult(0);
            } else {
              dstBuf.insert(builder, loc, v, offDimCrd);
            }
            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
          });
      // Accumulates the offset. Note that only static-shaped inputs are
      // allowed by the concatenate op verifier, which saves us from computing
      // the offset dynamically.
      const Size sz = getSparseTensorType(input).getDynamicDimSize(conDim);
      assert(!ShapedType::isDynamic(sz));
      offset = rewriter.create<arith::AddIOp>(loc, offset,
                                              constantIndex(rewriter, loc, sz));
      iterArg = foreachOp.getResult(0);
      dstBuf.val = iterArg;
    }

    dstBuf.val = iterArg;
    Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
  }
};

struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ConvertOp op,
                                PatternRewriter &rewriter) const override {
    if (op.needsExtraSort())
      return op.emitError("ConvertOp not staged.");

    // TODO: Maybe we want a different operation for this too.
    auto encDst = getSparseTensorEncoding(op.getType());
    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
    if (encDst && encSrc && !encSrc.isSlice() &&
        encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
      // Trivial tensor conversion and simple element type conversion is
      // handled in codegen.
      return failure();
    }

    Location loc = op.getLoc();
    Value src = op.getSource();

    SparseTensorType srcStt = getSparseTensorType(op.getSource());
    SparseTensorType dstStt = getSparseTensorType(op.getDest());

    bool fromSparseConst = false;
    if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>())
      if (dyn_cast<SparseElementsAttr>(constOp.getValue()))
        fromSparseConst = true;

    const AffineMapAttr foreachOrder =
        (!dstStt.isIdentity() && fromSparseConst)
            ? AffineMapAttr::get(dstStt.getExpandedDimToLvl())
            : nullptr;

    bool skipZeroCheck = srcStt.hasEncoding() || fromSparseConst;

    SmallVector<Value> sizes;
    sizesFromSrc(rewriter, sizes, loc, src);

    TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);

    auto foreachOp = rewriter.create<ForeachOp>(
        loc, src, dstBuf.val, foreachOrder,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          // Enters the loop, update the SSA value for insertion chain.
          dstBuf.val = reduc.front();
          if (!skipZeroCheck) {
            Value cond = genIsNonzero(builder, loc, v);
            auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                  /*else*/ true);
            builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
            builder.create<scf::YieldOp>(loc, dstBuf.val);

            builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
            dstBuf.insert(builder, loc, v, dcvs);
            builder.create<scf::YieldOp>(loc, dstBuf.val);

            // Exits the ifOp, update the sparse tensor SSA value.
            builder.setInsertionPointAfter(ifOp);
            dstBuf.val = ifOp.getResult(0);
          } else {
            dstBuf.insert(builder, loc, v, dcvs);
          }
          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
        });

    rewriter.setInsertionPointAfter(foreachOp);

    // Exits the for loop, links the SSA chain.
    dstBuf.val = foreachOp.getResult(0);

    Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
  }
};

struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CrdTranslateOp op,
                                PatternRewriter &rewriter) const override {
    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
                        ? op.getEncoder().getDimToLvl()
                        : op.getEncoder().getLvlToDim();

    SmallVector<Value> outCrds;
    for (AffineExpr result : map.getResults()) {
      // TODO: we should probably expand the affine map to IR using our own
      // rules, since affine.apply assumes signed values, while the coordinates
      // we provide must always be signless.
      Value trans = rewriter.create<affine::AffineApplyOp>(
          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
          op.getInCrds());
      outCrds.push_back(trans);
    }
    rewriter.replaceOp(op, outCrds);
    return success();
  }
};

/// Sparse rewriting rule for the foreach operator.
struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ForeachOp op,
                                PatternRewriter &rewriter) const override {

    auto loc = op.getLoc();
    Value input = op.getTensor();
    SmallVector<Value> reduc = op.getInitArgs();
    const auto stt = getSparseTensorType(input);
    const Level lvlRank = stt.getLvlRank();

    // Special-case: for each over a sparse constant uses its own rewriting
    // rule.
    if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue())) {
        return genForeachOnSparseConstant(op, rewriter, attr);
      }
    }

    // Otherwise, use loop emitter to generate loops.
    const auto enc = stt.getEncoding();

    // 1. Generates loop for the sparse input.
    LoopEmitter loopEmitter(
        ValueRange{input},
        StringAttr::get(getContext(), ForeachOp::getOperationName()));
    loopEmitter.initializeLoopEmit(rewriter, loc);
    for (Level l = 0; l < lvlRank; l++) {
      // TODO: provide utility function for loop sequences that only contains
      // one for loop?
      const SmallVector<TensorLevel, 1> tidLvls{
          loopEmitter.makeTensorLevel(0, l)};
      loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
      // Note that reduc will be taken care of by loop emitter and get updated
      // in place.
      loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
                                                    reduc);
    }

    SmallVector<Value> lcvs = loopEmitter.getLoopIVs();
    if (op.getOrder()) {
      // TODO: Support it so that we can do direct conversion from CSR->BSR.
      llvm_unreachable(
          "Level order not yet implemented on non-constant input tensors.");
    }

    Value vals = loopEmitter.getValBuffer()[0];
    SmallVector<Value> pos = loopEmitter.getValPosits(0);
    // Loads the value from sparse tensor using position-index;
    // loads the value from dense tensor using coords.
    Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
                    : rewriter.create<memref::LoadOp>(loc, vals, lcvs);

    // 2. Inline the block in the foreach operator.
    Block *srcBlock = op.getBody();

    // Remap coordinates.
    SmallVector<Value> args =
        enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);

    // Remap value.
    args.push_back(val);
    // Remap reduction variables.
    args.append(reduc);

    // Remove sparse_tensor.yield.
    SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
    rewriter.eraseOp(srcBlock->getTerminator());

    Operation &last = rewriter.getBlock()->back();
    if (llvm::isa<scf::YieldOp>(last)) {
      // Because `scf.for` inserts an implicit yield op when there is no
      // reduction variable upon creation, we reset the insertion point such
      // that the block is inlined before the yield op.
      rewriter.setInsertionPoint(&last);
    }

    rewriter.inlineBlockBefore(srcBlock, rewriter.getBlock(),
                               rewriter.getInsertionPoint(), args);
    rewriter.setInsertionPointToEnd(rewriter.getBlock());
    for (Level l = 0; l < lvlRank; l++) {
      // Link the reduction chain. Note that the loop emitter updates the
      // reducValue in place.
      loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
      loopEmitter.exitCurrentLoopSeq(rewriter, loc);
    }

    // Replace the foreach operator with the value returned by the outermost
    // for loop.
    rewriter.replaceOp(op, reducValue);
    return success();
  }
};

/// Sparse rewriting rule for the new operator.
struct NewRewriter : public OpRewritePattern<NewOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(NewOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    if (!stt.hasEncoding() || stt.getAoSCOOStart() == 0)
      return failure();

    // Implement the NewOp as follows:
    //   %orderedCoo = sparse_tensor.new %filename
    //   %t = sparse_tensor.convert %orderedCoo
    // with enveloping reinterpreted_map ops for non-permutations.
    RankedTensorType dstTp = stt.getRankedTensorType();
    RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
    Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
    Value convert = cooTensor;
    auto enc = stt.getEncoding();
    if (!stt.isPermutation()) { // demap coo, demap dstTp
      auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
      convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
      dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
    }
    convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
    if (!stt.isPermutation()) // remap to original enc
      convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
    rewriter.replaceOp(op, convert);

    // Release the temporary ordered COO tensor.
    rewriter.setInsertionPointAfterValue(convert);
    rewriter.create<DeallocTensorOp>(loc, cooTensor);

    return success();
  }
};

/// Sparse rewriting rule for the out operator.
struct OutRewriter : public OpRewritePattern<OutOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(OutOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Calculate NNZ.
    Value src = op.getTensor();
    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);

    // Allocate a temporary buffer for storing dimension-sizes/coordinates.
    const auto srcTp = getSparseTensorType(src);
    const Dimension dimRank = srcTp.getDimRank();
    Type indexTp = rewriter.getIndexType();
    Value dimSizes = genAlloca(rewriter, loc, dimRank, indexTp);

    // Generate code to calculate dimension size values and store the values to
    // the buffer.
    SmallVector<Value> dims;
    sizesForTensor(rewriter, dims, loc, srcTp, src);
    for (Dimension d = 0; d < dimRank; d++) {
      rewriter.create<memref::StoreOp>(loc, dims[d], dimSizes,
                                       constantIndex(rewriter, loc, d));
    }

    // Create a sparse tensor writer and output meta data.
    Type opaqueTp = getOpaquePointerType(rewriter);
    Value writer =
        createFuncCall(rewriter, loc, "createSparseTensorWriter", {opaqueTp},
                       {op.getDest()}, EmitCInterface::Off)
            .getResult(0);
    Value rankValue = constantIndex(rewriter, loc, dimRank);
    createFuncCall(rewriter, loc, "outSparseTensorWriterMetaData", {},
                   {writer, rankValue, nnz, dimSizes}, EmitCInterface::On);

    Value dimCoords = dimSizes; // Reuse the dimSizes buffer for dimCoords.
    Type eltTp = srcTp.getElementType();
    SmallString<29> outNextFuncName{"outSparseTensorWriterNext",
                                    primaryTypeFunctionSuffix(eltTp)};
    Value value = genAllocaScalar(rewriter, loc, eltTp);
    ModuleOp module = op->getParentOfType<ModuleOp>();

    // For each element in the source tensor, output the element.
    rewriter.create<ForeachOp>(
        loc, src, std::nullopt,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          for (Dimension d = 0; d < dimRank; d++) {
            rewriter.create<memref::StoreOp>(loc, dcvs[d], dimCoords,
                                             constantIndex(builder, loc, d));
          }
          rewriter.create<memref::StoreOp>(loc, v, value);
          SmallVector<Value> operands{writer, rankValue, dimCoords, value};
          FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands,
                                         EmitCInterface::On);
          builder.create<func::CallOp>(loc, TypeRange(), fn, operands);
          builder.create<sparse_tensor::YieldOp>(loc);
        });

    // Release the writer.
    createFuncCall(rewriter, loc, "delSparseTensorWriter", {}, {writer},
                   EmitCInterface::Off);

    rewriter.eraseOp(op);
    return success();
  }
};

} // namespace

//===---------------------------------------------------------------------===//
// Methods that add patterns described in this file to a pattern list.
//===---------------------------------------------------------------------===//

void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
  patterns.add<FuseExtractSliceWithConcat, FoldConvertIntoProducer,
               FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
               GenSemiRingReduction, GenSemiRingSelect, PrintRewriter>(
      patterns.getContext());
}

void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
                                                   bool enableRT,
                                                   bool enableConvert) {
  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
               ReshapeRewriter<tensor::CollapseShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
               Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
               SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
      patterns.getContext());

  if (enableConvert)
    patterns.add<DirectConvertRewriter>(patterns.getContext());
  if (!enableRT)
    patterns.add<NewRewriter>(patterns.getContext());
}

void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
  // Run CrdTranslateRewriter later in the pipeline so that the operation can
  // be folded before lowering to affine.apply.
  patterns.add<CrdTranslateRewriter, ForeachRewriter>(patterns.getContext());
}