From 067bebb50f6ec3f30ca8c34117c2c964729aca58 Mon Sep 17 00:00:00 2001
From: Aart Bik <39774503+aartbik@users.noreply.github.com>
Date: Tue, 5 Dec 2023 09:31:17 -0800
Subject: [PATCH] [mlir][sparse] minor refactoring of sparsification file
 (#74403)

Removed obsolete TODOs and NOTEs, fixed formatting, and removed an
unused parameter.
---
 .../SparseTensor/Transforms/Sparsification.cpp     | 57 ++++++++--------------
 1 file changed, 19 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index e0d3ce241e45..d171087f56ab 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -34,6 +34,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TensorEncoding.h"
 #include "llvm/ADT/SmallBitVector.h"
+
 #include <optional>
 
 using namespace mlir;
@@ -43,11 +44,6 @@ using namespace mlir::sparse_tensor;
 // Sparsifier analysis methods.
 //===----------------------------------------------------------------------===//
 
-// TODO: the "idx"-vs-"ldx" naming convention is not self-explanatory,
-// and those letters are too easy to confuse visually. We should switch
-// to a more self-explanatory naming convention like "curLoop"-vs-"prevLoop"
-// (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention).
-
 /// Determines if affine expression is invariant.
 static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
                               bool &isAtLoop) {
@@ -56,11 +52,9 @@ static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
     const LoopId i = cast<AffineDimExpr>(a).getPosition();
     if (i == ldx) {
       isAtLoop = true;
-      // Must be invariant if we are at the given loop.
-      return true;
+      return true; // invariant at given loop
     }
-    // The DimExpr is invariant the loop has already been generated.
-    return i < loopDepth;
+    return i < loopDepth; // invariant when already generated
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
@@ -85,7 +79,6 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
     const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
     if (!isUndefLT(merger.getLvlType(tid, idx)))
       return false; // used more than once
-
     if (setLvlFormat)
       merger.setLevelAndType(tid, idx, lvl, lt);
     return true;
@@ -195,7 +188,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
   }
 }
 
-/// Get the total number of compound affine expressions in the
+/// Gets the total number of compound affine expressions in the
 /// `getMatchingIndexingMap` for the given tensor. For the following inputs:
 ///
 /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
@@ -225,7 +218,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
   return num;
 }
 
-/// Get the total number of sparse levels with compound affine
+/// Gets the total number of sparse levels with compound affine
 /// expressions, summed over all operands of the `GenericOp`.
 static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
   unsigned num = 0;
@@ -235,6 +228,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
   return num;
 }
 
+// Returns true iff output has nontrivial affine indices.
 static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
   OpOperand *out = op.getDpsInitOperand(0);
   if (getSparseTensorType(out->get()).isAllDense())
@@ -260,11 +254,9 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
     const auto enc = getSparseTensorEncoding(t.get().getType());
     if (enc)
       annotated = true;
-
     const Level lvlRank = map.getNumResults();
     assert(!enc || lvlRank == enc.getLvlRank());
     assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
-
     // We only need to do index reduction if there is at least one non-trivial
     // index expression on sparse levels.
     // If all non-trivial index expression is on dense levels, we can
@@ -343,9 +335,6 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
 }
 
 /// Generates index for load/store on sparse tensor.
-// FIXME: It's not entirely clear what "index" means here (i.e., is it
-// a "coordinate", or "Ldx", or what). So the function should be renamed
-// and/or the documentation expanded in order to clarify.
 static Value genIndex(CodegenEnv &env, OpOperand *t) {
   const auto map = env.op().getMatchingIndexingMap(t);
   const auto stt = getSparseTensorType(t->get());
@@ -495,7 +484,6 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
   Value val = env.exp(exp).val;
   if (val)
     return val;
-
   // Load during insertion.
   linalg::GenericOp op = env.op();
   OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
@@ -574,7 +562,7 @@ inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
 /// exception of index computations, which need to be relinked to actual
 /// inlined cloned code.
 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
-                          Value e, LoopId ldx) {
+                          Value e) {
   if (auto arg = dyn_cast<BlockArgument>(e)) {
     // Direct arguments of the original linalg op must be converted
     // into dense tensor loads. Note that we should not encounter
@@ -598,7 +586,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
     for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
       rewriter.updateRootInPlace(def, [&]() {
         def->setOperand(
-            i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+            i, relinkBranch(env, rewriter, block, def->getOperand(i)));
       });
     }
   }
@@ -607,8 +595,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
 }
 
 /// Recursively generates tensor expression.
-static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
-                    LoopId ldx) {
+static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
   if (e == ::mlir::sparse_tensor::detail::kInvalidId)
     return Value();
 
@@ -631,15 +618,15 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
   // based on the type of the other operand.
   if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
       env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
-    v1 = genExp(env, rewriter, exp.children.e1, ldx);
+    v1 = genExp(env, rewriter, exp.children.e1);
     v0 = constantZero(rewriter, loc, v1.getType());
   } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
              env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
-    v0 = genExp(env, rewriter, exp.children.e0, ldx);
+    v0 = genExp(env, rewriter, exp.children.e0);
     v1 = constantZero(rewriter, loc, v0.getType());
   } else {
-    v0 = genExp(env, rewriter, exp.children.e0, ldx);
-    v1 = genExp(env, rewriter, exp.children.e1, ldx);
+    v0 = genExp(env, rewriter, exp.children.e0);
+    v1 = genExp(env, rewriter, exp.children.e1);
   }
 
   Value ee;
@@ -653,7 +640,7 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
                 kind == TensorExp::Kind::kReduce ||
                 kind == TensorExp::Kind::kSelect)) {
       OpBuilder::InsertionGuard guard(rewriter);
-      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
     }
   }
 
@@ -806,7 +793,6 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
     const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
     return isCompressedLT(lt) || isSingletonLT(lt);
   });
-
   return isParallelFor(env, isOuter, isSparse);
 }
 
@@ -1112,11 +1098,6 @@ static bool translateBitsToTidLvlPairs(
         // level. We need to generate the address according to the
        // affine expression. This is also the best place we can do it
        // to avoid putting it inside inner loops.
-        // NOTE: It assumes that the levels of the input tensor are
-        // initialized in order (and it is also currently guaranteed by
-        // computeIterationGraph), another more admissible approach
-        // might be accepting out-of-order access between consecutive
-        // dense levels.
         affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
       }
     }
@@ -1221,7 +1202,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                     LoopOrd at) {
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
   if (at == env.getLoopNum()) {
-    Value rhs = genExp(env, rewriter, exp, at - 1);
+    Value rhs = genExp(env, rewriter, exp);
     genTensorStore(env, rewriter, exp, rhs);
     return;
   }
@@ -1235,8 +1216,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
   bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
 
   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
-  //
-  // NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))`
+  // We cannot change this to `for (const LatPointId li : env.set(lts))`
   // because the loop body causes data-movement which invalidates
   // the iterator.
   const unsigned lsize = env.set(lts).size();
@@ -1251,7 +1231,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
       Value cntInput = env.getExpandCount();
       Value insInput = env.getInsertionChain();
       Value validIns = env.getValidLexInsert();
-      // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
+      // We cannot change this to `for (const LatPointId lj : env.set(lts))`
       // because the loop body causes data-movement which invalidates the
       // iterator.
       for (unsigned j = 0; j < lsize; j++) {
@@ -1323,6 +1303,7 @@ public:
     if (hasNonTrivialAffineOnSparseOut(op))
      return failure();
 
+    // Only accept scheduled loops.
if (!op->hasAttr("sorted")) { return rewriter.notifyMatchFailure( op, "Loops not yet scheduled, try run --sparse-reinterpret-map " @@ -1348,9 +1329,9 @@ public: } } - CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank); // Detects sparse annotations and translates the per-level sparsity // information for all tensors to loop indices in the kernel. + CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank); if (!findSparseAnnotations(env, needIdxRed)) return failure(); -- 2.11.4.GIT