//===- SparseReinterpretMap.cpp - reinterpret sparse tensor maps ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "Utils/IterationGraphSorter.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// File Local Helper classes.
//===----------------------------------------------------------------------===//

// CRTP to help implement a rewriter that demaps all its inputs.
template <typename SubClass, typename SourceOp>
struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
  using OpRewritePattern<SourceOp>::OpRewritePattern;
  using OpAdaptor = typename SourceOp::Adaptor;

  LogicalResult matchAndRewrite(SourceOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Demaps non-trivial inputs.
    bool changed = false;
    SmallVector<Value> deMappedIns(op->getOperands());
    for (Value &in : deMappedIns) {
      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity()) {
        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
        changed = true;
      }
    }

    // CRTP call.
    OpAdaptor adaptor(deMappedIns, op);
    LogicalResult status =
        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
    return changed ? success() : status;
  }
};
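// Illustrative sketch (hypothetical FooOp, not part of this file): a subclass
// only provides `rewriteOp` and receives the already demapped operands through
// the adaptor, e.g.
//
//   struct FooOpDemapper : public DemapInsRewriter<FooOpDemapper, FooOp> {
//     using DemapInsRewriter::DemapInsRewriter;
//     LogicalResult rewriteOp(FooOp op, OpAdaptor adaptor,
//                             PatternRewriter &rewriter) const {
//       // Use adaptor.getOperands() (demapped) instead of op->getOperands().
//       return failure();
//     }
//   };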
// Flattens an affine expression into a list of AffineDimExprs.
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
  BitVector dims;
};

// Checks whether an affine expression is admissible for sparsification.
struct AffineExprAdmissibleVisitor
    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
  explicit AffineExprAdmissibleVisitor(bool isOutput)
      : admissible(true), isOutput(isOutput){};

  // We only allow AffineDimExpr on output.
  void visitAddExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }
  void visitMulExpr(AffineBinaryOpExpr expr) {
    if (isOutput)
      admissible = false;
  }

  // We disallow mod, floor div and ceil div on inputs.
  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
  operator bool() { return admissible; }

private:
  bool admissible;
  bool isOutput;
};

// The first BitVector stores levels where inadmissible exprs are used.
// The second BitVector stores the AffineDimExprs that are used by the
// inadmissible expressions.
using InadmissInfo = std::pair<BitVector, BitVector>;

} // namespace

//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//

// Collects the inadmissible affine expressions imposed on levels.
static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
  auto ret = std::make_pair(BitVector(map.getNumResults()),
                            BitVector(map.getNumDims()));
  AffineDimCollector collector(map.getNumDims());
  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
    AffineExprAdmissibleVisitor admissible(isOutput);
    admissible.walkPostOrder(map.getResult(lvl));
    if (!admissible) {
      // Record the inadmissible level.
      ret.first.set(lvl);
      // Record the AffineDimExprs that are used in the inadmissible expr.
      collector.walkPostOrder(map.getResult(lvl));
    }
  }
  ret.second = collector.dims;
  return ret;
}
// Builds the AffineMap that replaces the idx in idxMap with lvl such that all
// the inadmissible affine expressions can be eliminated.
// For example, we can rewrite
//   idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// to
//   idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
// by composing inverse(idxMap), that is
//   inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
//                           -> ((l0 * 2 + l2) floordiv 2,
//                               (l1 * 3 + l3) floordiv 3,
//                               (l0 * 2 + l2) mod 2,
//                               (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
//
// This function builds the inverse(idxMap) that replaces every dimension used
// in `info` with a level, and updates the iterator type array `itTps` for the
// new index variables introduced.
//
// Note that the returned affine map does not retain the order of the input
// affine map. Instead, it always uses the first `info.inAdLvls.count()` dims
// for the replaced levels, and the remaining ones for unused dimensions.
// For example, to handle
//   idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
// which is a typical map for block_2to4, the function returns:
//   inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
// in which (l0, l1) together replace `d1`, yet they appear
// before `d0` in the resulting affine map.
// The index (loop) order can later be canonicalized by a topo sort.
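// As a concrete (illustrative) source of such maps: a 2x3 block-sparse
// encoding with
//   dimToLvl = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// composed with an identity idx2dim map yields exactly the first idxMap shown
// above, so the four levels (l0, l1, l2, l3) become the new loop indices.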
static AffineMap
genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
                      SmallVector<utils::IteratorType> &itTps) {
  MLIRContext *ctx = idxMap.getContext();
  auto [inAdLvls, usedDims] = info;
  // Note that idxMap is not equal to the dim2Lvl map: it is computed by
  // composing idx2Dim(dim2Lvl). They are only equal when idx2Dim is an
  // identity map.
  // TODO: we might fail here; in those cases we should really return
  // failure instead of an assertion error.
  auto lvl2Idx = inferLvlToDim(idxMap, ctx);

  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
    // This could happen when some dimensions are projected.
    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
    //   ==> lvl2Idx = (j, k) -> (j, k)
    // In this case, we append the unused dimension at the end.
    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
    SmallVector<AffineExpr> results;
    AffineDimCollector usedInLvl(idxMap.getNumDims());
    for (auto e : idxMap.getResults())
      usedInLvl.walkPostOrder(e);

    unsigned curUsedDimID = 0;
    unsigned curUnusedDimID = lvl2Idx.getNumDims();

    BitVector unused = usedInLvl.dims.flip();
    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
      if (unused.test(i))
        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
      else
        results.push_back(lvl2Idx.getResult(curUsedDimID++));
    }
    lvl2Idx =
        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
  }
  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());

  // We do not need to replace the DimExpr that is not used in inadmissible
  // level expressions. We use the first inAdLvls.count() dims to represent the
  // replaced levels; the remaining ones are reserved for unchanged dimensions.
  // Note that the results from the inverse map computed previously do not
  // follow this convention, and we need to fix the mismatch below.
  unsigned curRepID = 0;
  unsigned curOriID = inAdLvls.count();
  SmallVector<AffineExpr> results;
  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
  SmallVector<utils::IteratorType> transItTps;

  for (unsigned l : inAdLvls.set_bits()) {
    // By our convention, the inadmissible level `l` always appears in the
    // leading part (accumulated by curRepID) of the affine map's parameter
    // list. Record the mapping so that we can replace all the uses of `l` with
    // the correct position after the translation.
    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
    // A new index variable is introduced for the inadmissible level; it
    // inherits the iterator type. E.g., if l0 = d0 floordiv 2, the
    // iterator type of l0 equals the iterator type of d0.
    AffineExpr lvlExp = idxMap.getResult(l);
    AffineDimCollector collector(idxMap.getNumDims());
    collector.walkPostOrder(lvlExp);
    // We assume a level can only be derived from one dimension.
    assert(collector.dims.count() == 1);
    transItTps.push_back(itTps[collector.dims.find_first()]);
  }

  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
    if (usedDims.test(d)) {
      // The dimension is used in some of the inadmissible levels, and it needs
      // to be inverted. Get the inversion from the inverse map, and fix the
      // mismatch captured by the above loop.
      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
    } else {
      // The dimension is not used in any of the inadmissible levels, and it
      // does not need to be inverted. Fix the mismatch by mapping it to the
      // trailing part of the affine map (accumulated by curOriID).
      results.push_back(getAffineDimExpr(curOriID++, ctx));
      transItTps.push_back(itTps[d]);
    }
  }

  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
  // Update iterator types.
  itTps.assign(transItTps.begin(), transItTps.end());
  return AffineMap::get(numDim, 0, results, ctx);
}
// Translates the index map in the linalg::GenericOp from an idx->dim map to
// an idx->lvl map. Returns failure if the index map cannot be translated to an
// admissible form.
// Returns the translated index map array and the iterator type array.
static std::optional<std::pair<ArrayAttr, ArrayAttr>>
translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
  // idxMap is an idx2dim map before reinterpretation.
  MLIRContext *ctx = op.getContext();
  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
    Value tensor = op->getOpOperand(i).get();
    auto stt = tryGetSparseTensorType(tensor);
    if (stt && !stt->isIdentity()) {
      AffineMap dim2Lvl = stt->getDimToLvl();
      // By composing the idx2dim(dim2lvl), we get an idx2lvl map.
      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
    }
  }

  // A naive way to handle common constant expressions that arise during
  // dim2lvl translation.
  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
                                  unsigned pos, int64_t lvlSz) {
    if (!ShapedType::isDynamic(lvlSz)) {
      auto c0 = getAffineConstantExpr(0, ctx);
      auto lvlExp = getAffineDimExpr(pos, ctx);
      auto szExp = getAffineConstantExpr(lvlSz, ctx);

      // lvl floordiv lvlSz = 0
      auto divExp =
          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
      cstMapping.try_emplace(divExp, c0);

      // lvl mod lvlSz = lvl
      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
      cstMapping.try_emplace(modExp, lvlExp);
    }
  };
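  // For example (illustrative): with pos = 0 and lvlSz = 2, the mapping
  // records (d0 floordiv 2) -> 0 and (d0 mod 2) -> d0, which holds because the
  // level index d0 only ranges over [0, 2).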
  unsigned boundedNum = 0;
  // A fixed-point algorithm.
  bool changed = true;
  while (changed) {
    changed = false;
    for (OpOperand &operand : op->getOpOperands()) {
      auto stt = tryGetSparseTensorType(operand.get());
      // Skip dense operands.
      if (!stt || !stt->getEncoding())
        continue;

      unsigned tid = operand.getOperandNumber();
      bool isOutput = &operand == op.getDpsInitOperand(0);
      AffineMap idxMap = idxMapArray[tid];
      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
      auto [inAdLvls, dimExprs] = inAdInfo;
      for (unsigned d : dimExprs.set_bits()) {
        // The first `boundedNum` dims used in the AffineMap were introduced to
        // resolve previous inadmissible expressions. We cannot replace them,
        // as that might bring back the inadmissible expressions.
        if (d < boundedNum)
          return std::nullopt;
      }

      if (inAdLvls.count() != 0) {
        // Naive constant propagation, should be sufficient to handle block
        // sparsity in our cases.
        SmallVector<int64_t> lvlShape = stt->getLvlShape();
        DenseMap<AffineExpr, AffineExpr> cstMapping;
        unsigned position = 0;
        for (unsigned lvl : inAdLvls.set_bits()) {
          int64_t lvlSz = lvlShape[lvl];
          populateCstMapping(cstMapping, position, lvlSz);
          position++;
        }

        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
        // Compose the lvl2Idx map with all the AffineIdxMaps to eliminate
        // inadmissible expressions.
        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
          idxMapArray[tid] = transMap.replace(
              cstMapping, /*numResultDims=*/transMap.getNumDims(),
              /*numResultSyms=*/0);
        }
        changed = true;
        boundedNum += inAdLvls.count();
      }
    }
  }

  SmallVector<Attribute> iterAttr =
      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
        return linalg::IteratorTypeAttr::get(ctx, itTp);
      });

  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
                        rewriter.getArrayAttr(iterAttr));
}
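// Illustrative end-to-end effect (assumed example, not from a test): for the
// 2x3 block-sparse operand above with an identity idx2dim map, the composed
// idx2lvl map (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
// is inadmissible; after translation the sparse operand gets an identity map
// over four loops (l0, l1, l2, l3), while a dense operand that used
// (d0, d1) -> (d0, d1) is rewritten to
// (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3).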
// Generates a "de"mapping reinterpretation of the map.
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
                                          val);
}

// Generates a "re"mapping reinterpretation of the map.
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
                      Value val) {
  return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
}
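// For example (illustrative MLIR, assuming a block-sparse encoding #BSR with a
// non-identity dimToLvl and #DemappedBSR = #BSR without that map), genDemap
// would produce something like
//   %0 = sparse_tensor.reinterpret_map %t
//          : tensor<?x?xf32, #BSR> to tensor<?x?x2x3xf32, #DemappedBSR>
// and genRemap builds the inverse reinterpretation back to #BSR.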
static SmallVector<Value> remapValueRange(OpBuilder &rewriter, TypeRange types,
                                          ValueRange outs) {
  SmallVector<Value> ret(outs);
  assert(outs.size() == types.size());
  for (auto [r, t] : llvm::zip(ret, types))
    if (r.getType() != t)
      r = rewriter.create<ReinterpretMapOp>(r.getLoc(), t, r);
  return ret;
}
namespace {

//===----------------------------------------------------------------------===//
// Rewriting rules for linalg generic ops.
//===----------------------------------------------------------------------===//

/// Sparse rewriting rule for the generic `linalg` operation.
struct GenericOpReinterpretMap
    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only rewrite single output operations with pure (sparse) tensor
    // semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        !hasAnySparseOperandOrResult(linalgOp) ||
        !hasAnyNonIdentityOperandsOrResults(linalgOp))
      return failure();

    // Try translating the index map.
    auto transMap = translateMap(linalgOp, rewriter);
    if (!transMap)
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be sparsified.");

    // On success, update the linalg operands and maps in place.
    Value res = linalgOp.getResult(0);
    auto stt = tryGetSparseTensorType(res);
    auto [idxMap, itTp] = *transMap;

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(idxMap);
    linalgOp.setIteratorTypesAttr(itTp);
    // Use demapped arguments.
    linalgOp.getInputsMutable().assign(adaptor.getInputs());
    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
    res.setType(adaptor.getOutputs()[0].getType());
    rewriter.finalizeOpModification(linalgOp);

    rewriter.setInsertionPointAfter(linalgOp);
    if (stt && stt->hasEncoding()) {
      Value t = genRemap(rewriter, stt->getEncoding(), res);
      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
    }
    return success();
  }
};
struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
        hasAnyNonIdentityOperandsOrResults(linalgOp) || // need demap first
        !hasAnySparseOperandOrResult(linalgOp)) {
      return failure();
    }

    const StringRef sorted = "sorted";
    if (linalgOp->hasAttr(sorted))
      return failure();

    auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
    bool isAdmissible = false;
    AffineMap order;
    // A const list of all masks that we use for iteration graph
    // computation. Must be ordered from more strict to less strict.
    // Ideally (though it might not be guaranteed), the earlier a constraint
    // mask can be satisfied, the faster the generated kernel will be.
    const auto allMasks = {SortMask::kIncludeAll, SortMask::kIncludeDense,
                           SortMask::kIncludeDenseInput,
                           SortMask::kIncludeDenseOutput,
                           SortMask::kSparseOnly};
    for (const SortMask mask : allMasks) {
      order = scheduler.sort(mask);
      if (order) {
        if (isAdmissibleOrder(linalgOp, order)) {
          isAdmissible = true;
          break;
        }
        // else try a set of less strict constraints.
      }
    }

    if (!order) {
      // Cycles detected.
      if (failed(resolveCycle(scheduler, linalgOp, rewriter))) {
        return rewriter.notifyMatchFailure(
            linalgOp, "the sparse kernel can not be scheduled: loop detected.");
      }
      return success();
    }

    if (!isAdmissible) {
      return rewriter.notifyMatchFailure(
          linalgOp, "the sparse kernel can not be scheduled.");
    }

    // Marks the GenericOp to avoid recursive matching.
    rewriter.modifyOpInPlace(linalgOp, [&]() {
      linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
    });

    // Already sorted.
    if (order.isIdentity())
      return success();

    assert(order.isPermutation());
    // `order` is the original loop -> sorted loop map.
    ArrayAttr preItTypes = linalgOp.getIteratorTypesAttr();
    SmallVector<Attribute> curItTypes;
    curItTypes.reserve(preItTypes.size());
    for (AffineExpr expr : order.getResults()) {
      unsigned loopID = llvm::cast<AffineDimExpr>(expr).getPosition();
      curItTypes.push_back(preItTypes[loopID]);
    }

    // Invert `order` to get the sorted loop -> original loop map.
    order = inversePermutation(order);
    SmallVector<AffineMap> idxMaps = linalgOp.getIndexingMapsArray();
    for (AffineMap &idxMap : idxMaps)
      idxMap = idxMap.compose(order); // sorted loop -> lvl map

    rewriter.startOpModification(linalgOp);
    linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
    linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
    rewriter.finalizeOpModification(linalgOp);

    return success();
  }

private:
  /// Whether the loop order is admissible by sparsification.
  static bool isAdmissibleOrder(linalg::GenericOp linalgOp, AffineMap order) {
    if (!hasAnySparseResult(linalgOp))
      return true;

    OpOperand *lhs = linalgOp.getDpsInitOperand(0);
    unsigned nest = 0;
    const auto iteratorTypes = linalgOp.getIteratorTypesArray();
    for (const AffineExpr l : order.getResults()) {
      unsigned loopId = llvm::cast<AffineDimExpr>(l).getPosition();
      auto itTp =
          cast<linalg::IteratorTypeAttr>(linalgOp.getIteratorTypes()[loopId]);
      if (linalg::isReductionIterator(itTp.getValue()))
        break; // terminate at first reduction
      nest++;
    }
    // Determine admissible dynamic insertion situations:
    // (1) fully injective, since there are no reductions,
    // (2) admissible 1-d expansion in innermost dimension.
    return static_cast<int64_t>(nest) >= linalgOp.getRank(lhs) - 1;
  }
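  // For instance (illustrative): for a matmul-like kernel with loops (i, j, k),
  // iterator types (parallel, parallel, reduction) and a rank-2 sparse output,
  // an order visiting i and j before k gives nest == 2 >= rank - 1 == 1 and is
  // admissible, whereas an order starting with the reduction k gives nest == 0
  // and is rejected.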
  // Last resort cycle resolution.
  static LogicalResult resolveCycle(IterationGraphSorter &scheduler,
                                    linalg::LinalgOp linalgOp,
                                    PatternRewriter &rewriter) {
    // Compute a topological sort while leaving out every sparse input tensor
    // in succession until an acyclic iteration graph results.
    for (OpOperand *t : linalgOp.getDpsInputOperands()) {
      Value tval = t->get();
      auto srcEnc = getSparseTensorEncoding(tval.getType());
      // The constraints introduced by compound index expressions are
      // complicated. Skip them.
      AffineMap idxMap = linalgOp.getMatchingIndexingMap(t);
      bool hasCompExpr = llvm::any_of(idxMap.getResults(), [](AffineExpr exp) {
        return !llvm::isa<AffineDimExpr>(exp);
      });
      if (!srcEnc || hasCompExpr)
        continue;

      // Try scheduling the loops without constraints from `tval`.
      AffineMap order = scheduler.sort(SortMask::kSparseOnly, tval);
      if (!order) // still cyclic
        continue;

      // Found an input tensor that resolves the cycle by inserting a
      // conversion into a sparse tensor that adheres to the iteration
      // graph order.
      auto stt = getSparseTensorType(tval);
      assert(stt.isIdentity());
      order = inversePermutation(order);
      // sorted loop -> lvl map.
      idxMap = idxMap.compose(order);

      // Found a permutation such that the results in `idxMap` are sorted.
      // For example,
      //   (d0, d1, d2, d3) -> (d2, d1, d0)
      // loops are scheduled in the order d0->d1->d2->d3; to resolve the cycle,
      // we find a permutation, perm(d2, d1, d0) -> (d0, d1, d2), such that the
      // transposed tensor's levels are visited in the same order as the loop
      // scheduling order.
      SmallVector<std::pair<unsigned, unsigned>> lvlSeq;
      for (AffineExpr expr : idxMap.getResults()) {
        unsigned lvl = llvm::cast<AffineDimExpr>(expr).getPosition();
        lvlSeq.push_back(std::make_pair(lvl, lvlSeq.size()));
      }
      llvm::sort(lvlSeq, llvm::less_first());
      SmallVector<unsigned> perm =
          llvm::to_vector(llvm::make_second_range(lvlSeq));
      auto dimToLvl = AffineMap::getPermutationMap(perm, linalgOp.getContext());
      // The result of the idxMap must be unsorted.
      assert(!dimToLvl.isIdentity());

      // Insert the transpose.
      rewriter.setInsertionPoint(linalgOp);
      RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
      Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
      rewriter.modifyOpInPlace(linalgOp, [&]() {
        linalgOp->setOperand(t->getOperandNumber(), dst);
      });

      // Release the transposed form afterwards.
      // TODO: CSE when used in more than one following op?
      rewriter.setInsertionPointAfter(linalgOp);
      rewriter.create<bufferization::DeallocTensorOp>(dst.getLoc(), dst);

      return success();
    }
    // Cannot be resolved with a single conversion.
    // TODO: convert more than one?
    return failure();
  }
};
//===----------------------------------------------------------------------===//
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//

template <typename AllocOp>
struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
  using OpRewritePattern<AllocOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());

    SmallVector<Value> maxDimCrds;
    maxDimCrds.reserve(stt.getDimRank());
    ValueRange dynSz = op.getDynamicSizes();
    for (int64_t dimSz : stt.getDimShape()) {
      if (ShapedType::isDynamic(dimSz)) {
        Value maxCrd = rewriter.create<arith::SubIOp>(
            loc, dynSz.front(), constantIndex(rewriter, loc, 1));
        maxDimCrds.push_back(maxCrd);
        dynSz = dynSz.drop_front();
      } else {
        maxDimCrds.push_back(constantIndex(rewriter, loc, dimSz - 1));
      }
    }

    ValueRange maxLvlCrds = stt.translateCrds(rewriter, loc, maxDimCrds,
                                              CrdTransDirectionKind::dim2lvl);
    auto lvlShape = stt.getLvlShape();
    SmallVector<Value> dynLvlSzs;
    for (unsigned i = 0, e = lvlShape.size(); i < e; i++) {
      if (ShapedType::isDynamic(lvlShape[i])) {
        Value sz = rewriter.create<arith::AddIOp>(
            loc, maxLvlCrds[i], constantIndex(rewriter, loc, 1));
        dynLvlSzs.push_back(sz);
      }
    }

    assert(dynSz.empty()); // should have consumed all.
    rewriter.startOpModification(op);
    op->setOperands(dynLvlSzs);
    op.getResult().setType(stt.getDemappedType());
    rewriter.finalizeOpModification(op);
    rewriter.setInsertionPointAfter(op);

    Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op.getResult(), t, t.getDefiningOp());
    return success();
  }
};
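// For instance (illustrative): allocating a 2x3 block-sparse tensor with dim
// shape (8, 12) translates the maximum dim coordinates (7, 11) into the
// maximum level coordinates (3, 3, 1, 2), i.e. level sizes (4, 4, 2, 3); only
// the dynamic level sizes become operands of the demapped allocation.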
struct TensorInsertDemapper
    : public DemapInsRewriter<TensorInsertDemapper, tensor::InsertOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    Location loc = op.getLoc();
    auto stt = getSparseTensorType(op.getResult());
    ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                          CrdTransDirectionKind::dim2lvl);
    auto insertOp = rewriter.create<tensor::InsertOp>(
        loc, op.getScalar(), adaptor.getDest(), lvlCrd);

    Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
    rewriter.replaceOp(op, out);
    return success();
  }
};
struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AssembleOp op,
                                PatternRewriter &rewriter) const override {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseResult(op));
    auto stt = getSparseTensorType(op.getResult());
    rewriter.modifyOpInPlace(
        op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
    rewriter.setInsertionPointAfter(op);
    Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
    rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
    return success();
  }
};
struct SparseDisassembleDemapper
    : public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    assert(hasAnySparseOperandOrResult(op));
    rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
      op.getTensorMutable().assign(adaptor.getTensor());
    });
    return success();
  }
};
struct ForeachOpDemapper
    : public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
  using DemapInsRewriter::DemapInsRewriter;
  LogicalResult rewriteOp(ForeachOp op, OpAdaptor adaptor,
                          PatternRewriter &rewriter) const {
    // Only handle operations with sparse inputs/outputs that have non-identity
    // dim2lvl maps.
    if (!hasAnyNonIdentityOperandsOrResults(op))
      return failure();

    // TODO: demap constant as well.
    if (auto constOp = op.getTensor().getDefiningOp<arith::ConstantOp>())
      if (auto attr = dyn_cast<SparseElementsAttr>(constOp.getValue()))
        return failure();

    Location loc = op.getLoc();
    // Cache the type information since we update the foreach op in-place.
    auto srcStt = getSparseTensorType(op.getTensor());
    SmallVector<Type> prevRetTps(op.getResultTypes());

    rewriter.startOpModification(op);
    op.getTensorMutable().assign(adaptor.getTensor());
    op.getInitArgsMutable().assign(adaptor.getInitArgs());
    // Update results' types.
    for (auto r : op.getResults())
      if (auto stt = tryGetSparseTensorType(r); stt && !stt->isIdentity())
        r.setType(stt->getDemappedType());

    Level lvlRank = getSparseTensorType(adaptor.getTensor()).getLvlRank();
    // Update the foreach body.
    SmallVector<Type> blockArgTps(lvlRank, rewriter.getIndexType());
    blockArgTps.push_back(srcStt.getElementType());
    blockArgTps.append(adaptor.getInitArgs().getTypes().begin(),
                       adaptor.getInitArgs().getTypes().end());
    Block *body = op.getBody();
    // Block Args: [dimCrd, val, initArgs]
    unsigned preArgNum = body->getNumArguments();
    for (Type t : blockArgTps)
      body->addArgument(t, loc);

    // Block Args: [dimCrd, val, initArgs, lvlCrds, val, DemappedArgs]
    rewriter.setInsertionPointToStart(body);
    ValueRange lvlCrds = body->getArguments().slice(preArgNum, lvlRank);

    ValueRange dimCrds = srcStt.translateCrds(rewriter, loc, lvlCrds,
                                              CrdTransDirectionKind::lvl2dim);
    rewriter.replaceAllUsesWith(
        body->getArguments().take_front(srcStt.getDimRank()), dimCrds);
    body->eraseArguments(0, srcStt.getDimRank());
    // Block Args: [val, initArgs, lvlCrds, val, DemappedArgs]
    unsigned numInitArgs = op.getInitArgs().size();
    rewriter.replaceAllUsesWith(body->getArgument(0),
                                body->getArgument(lvlRank + numInitArgs + 1));
    body->eraseArgument(0);
    // Block Args: [initArgs, lvlCrds, val, DemappedArgs]
    ValueRange srcArgs = body->getArguments().take_front(numInitArgs);
    ValueRange dstArgs = body->getArguments().take_back(numInitArgs);
    // Remap back before replacement.
    SmallVector<Value> reMappedArgs =
        remapValueRange(rewriter, srcArgs.getTypes(), dstArgs);
    rewriter.replaceAllUsesWith(srcArgs, reMappedArgs);
    body->eraseArguments(0, numInitArgs);
    // Block Args: [lvlCrds, val, DemappedArgs] and we are done.

    // Update yield operations.
    if (numInitArgs != 0) {
      rewriter.setInsertionPointToEnd(body);
      auto yield = llvm::cast<YieldOp>(body->getTerminator());
      if (auto stt = tryGetSparseTensorType(yield.getSingleResult());
          stt && !stt->isIdentity()) {
        Value y =
            genDemap(rewriter, stt->getEncoding(), yield.getSingleResult());
        rewriter.create<YieldOp>(loc, y);
        rewriter.eraseOp(yield);
      }
    }
    rewriter.finalizeOpModification(op);

    rewriter.setInsertionPointAfter(op);
    SmallVector<Value> outs =
        remapValueRange(rewriter, prevRetTps, op.getResults());

    // Replace all the uses of the foreach results, except for the use in the
    // reinterpret_map used to remap the output.
    for (auto [from, to] : llvm::zip(op.getResults(), outs))
      rewriter.replaceAllUsesExcept(from, to, to.getDefiningOp());

    return success();
  }
};
} // namespace

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kGenericOnly) {
    patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
        patterns.getContext());
  }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
                 TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
                 SparseDisassembleDemapper, TensorInsertDemapper,
                 ForeachOpDemapper>(patterns.getContext());
  }
}
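// Illustrative usage sketch (not part of this file): a pass would typically
// hand these patterns to the greedy rewrite driver, e.g.
//
//   RewritePatternSet patterns(op->getContext());
//   populateSparseReinterpretMap(patterns, ReinterpretMapScope::kAll);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));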