//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "Utils/CodegenUtils.h"
10 #include "Utils/IterationGraphSorter.h"
12 #include "mlir/Dialect/Affine/IR/AffineOps.h"
13 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
14 #include "mlir/Dialect/Linalg/IR/Linalg.h"
15 #include "mlir/Dialect/Linalg/Utils/Utils.h"
16 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
18 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/AffineExprVisitor.h"
21 #include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP to help implementing a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Demaps non-trivial inputs.
    bool changed = false;
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns) {
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(),
                                               in);
        changed = true;
      }
    }

    // CRTP call to the actual rewriter with the demapped operands.
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
    return changed ? success() : status;
  }
};
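
// Illustrative note (the tensor shapes and encoding names below are
// hypothetical, not taken from the code above): for an operand of type
// tensor<?x?xf64, #BSR> with a non-identity dimToLvl map, the demap step
// above inserts roughly
//   %demapped = sparse_tensor.reinterpret_map %in
//       : tensor<?x?xf64, #BSR> to tensor<?x?x2x2xf64, #DemappedBSR>
// so that the derived `rewriteOp` only ever sees level-space (demapped)
// operands.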

// Flattens an affine expression into a list of AffineDimExprs.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// Checks whether an affine (level) expression is admissible for
// sparsification.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput){};

  // We only allow AffineDimExpr on output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floor div and ceil div on inputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};

// The first BitVector stores levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExprs that are used in the inadmissible expr.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}
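
// Illustrative example (not from the upstream comments): for an input operand
// with
//   map = (d0, d1) -> (d0 floordiv 2, d1, d0 mod 2)
// levels 0 and 2 use inadmissible expressions, so the returned pair has bits
// {0, 2} set in the first BitVector and dimension d0 recorded in the second.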

// Builds the AffineMap to replace the idx in idxMap with lvl such that all the
// inadmissible affine expressions can be eliminated.
// For example, we can rewrite
//   idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
//   idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
//   inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                           -> ((l0 * 2 + l2) floordiv 2,
//                               (l1 * 3 + l3) floordiv 3,
//                               (l0 * 2 + l2) mod 2,
//                               (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with a level, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdLvls.count()`
// dimensions for the replaced levels, and the remaining ones for unused
// dimensions.
// For example, to handle
//   idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// which is a typical map for block_2to4 sparsity, the function returns
//   inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which (l0, l1) together replace `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topological sort.
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap does not equal the dim2Lvl map; it is computed by
  // composing idx2Dim(dim2Lvl). They are only equal when idx2Dim is an
  // identity map.
  // TODO: we might fail here; in those cases we should really return
  // failure instead of asserting.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //  ==>  lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimension at the end.
    //  ==>  lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExprs that are not used in inadmissible
  // level expressions. We use the first inAdLvls.count() dims to represent the
  // replaced levels; the remaining ones are reserved for unchanged dims.
  // Note that the results from the inverse map computed previously do not
  // follow this convention, and we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();

  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l` with
    // the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level; it
    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it needs
      // to be inverted. Get the inversion from the inverse map, and fix the
      // mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }
  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update iterator types.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}

// Translates the index map in the linalg::GenericOp from an idx->dim map to
// an idx->lvl map. Returns failure if the index map can not be translated to
// an admissible form.
// Returns the translated index map array and the iterator type array.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is an idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing idx2dim with dim2lvl, we get an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during
  // dim2lvl translation.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (!ShapedType::isDynamic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };

  unsigned boundedNum = 0;
  // A fixed-point algorithm: keep rewriting the index maps until no more
  // inadmissible expressions remain.
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dims used in the AffineMap are introduced to
        // resolve previous inadmissible expressions. We can not replace them
        // as it might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx map with all index maps to eliminate the
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  }

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}
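
// Illustrative walk-through (not part of the upstream comments): for a 2:4
// block encoding whose dim2lvl map is (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// and an identity idx2dim map, the composed idx2lvl map is inadmissible on a
// sparse operand. The fixed-point loop above then introduces two new loop
// indices (l0, l1) for d1 and, using the constant mapping for the static level
// size 4, arrives at the admissible map (l0, l1, d0) -> (d0, l0, l1).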
338 // Generates a "de"mapping reinterpretation of the map.
339 static Value
genDemap(OpBuilder
&builder
, SparseTensorEncodingAttr enc
,
341 return builder
.create
<ReinterpretMapOp
>(val
.getLoc(), enc
.withoutDimToLvl(),
345 // Generates a "re"mapping reinterpretation of the map.
346 static Value
genRemap(OpBuilder
&builder
, SparseTensorEncodingAttr enc
,
348 return builder
.create
<ReinterpretMapOp
>(val
.getLoc(), enc
, val
);
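
// Illustrative note (types and encoding names are hypothetical): demapping
// reinterprets a tensor under an encoding with the dimToLvl map stripped,
// e.g. roughly
//   %0 = sparse_tensor.reinterpret_map %t
//        : tensor<?x?xf64, #BSR> to tensor<?x?x2x2xf64, #DemappedBSR>
// and remapping is the inverse reinterpretation back to the original
// encoding; neither direction changes the underlying storage.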

static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}

namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};
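
// Illustrative effect of the pattern above (the surrounding IR is
// hypothetical): a linalg.generic whose sparse operands carry non-identity
// dimToLvl maps is rewritten in place to consume the demapped (level-space)
// operands, with its indexing maps and iterator types translated by
// translateMap(), and a reinterpret_map is inserted after the op so that
// later users still see the original result encoding.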

struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all masks that we used for iteration graph
    // computation. Must be ordered from more strict to less strict.
    // Ideally (though it might not be guaranteed), the earlier a constraint
    // mask can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // else try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel can not be scheduled: loop detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be scheduled.");
    }

    // Marks the GenericOp to avoid recursive matching.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return failure();

    assert(order.isPermutation());
    // `order` is the original loop -> sorted loop map.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Inverse `order` to get the sorted loop -> original loop map.
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeOpModification(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible by sparsification.
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      auto itTp =
          cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
      if (linalg::isReductionIterator(itTp.getValue()))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  }

  // Last resort cycle resolution.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute a topological sort while leaving out every sparse input tensor
    // in succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expressions are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling the loops without constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order);
      // sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Found a permutation such that the results in `idxMap` are sorted.
      // For example, with
      //   (d0, d1, d2, d3) -> (d2, d1, d0)
      // loops are scheduled in the order d0->d1->d2->d3. To resolve the cycle,
      // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the
      // transposed tensor's levels are visited in the same order as the loop
      // scheduling order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      llvm::sort(lvlSeq, llvm::less_first());
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The result of the idxMap must be unsorted.
      assert(!dimToLvl.isIdentity());

      // Insert the transpose.
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });

      // Release the transposed form afterwards.
      // TODO: CSE when used in more than one following op?
      rewriter.setInsertionPointAfter(linalgOp);
      rewriter.create<bufferization::DeallocTensorOp>(dst.getLoc(), dst);

      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }
};

//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
    rewriter.startOpModification(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeOpModification(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};
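
// Illustrative note (not part of the upstream comments): the pattern above
// derives dynamic *level* sizes from dynamic *dimension* sizes by translating
// the maximum coordinate of each dimension (size - 1) through dim2lvl and
// adding 1 back on the level side, so the allocation can be expressed directly
// on the demapped (level-space) type.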

struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};

struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AssembleOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseResult(op));
    auto stt = getSparseTensorType(op.getResult());
    rewriter.modifyOpInPlace(
        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
    rewriter.setInsertionPointAfter(op);
    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
    return success();
  }
};

struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseOperandOrResult(op));
    rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
      op.getTensorMutable().assign(adaptor.getTensor());
    });
    return success();
  }
};

struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse input/output with non-identity
    // dim2lvl maps.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startOpModification(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update the results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrds, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrds, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
          stt && !stt->isIdentity()) {
        Value y =
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except the uses in the
    // reinterpret_map ops used to remap the outputs.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};

} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
        patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
                 SparseDisassembleDemapper, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}