[mlir][sparse] cleanup of CodegenEnv reduction API (#75243)
[llvm-project.git] / mlir / lib / Dialect / SparseTensor / Transforms / CodegenEnv.cpp
blob4bd3af2d3f2f6a34b965bba4566f6ef7c2e7077b
1 //===- CodegenEnv.cpp - Code generation environment class ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "CodegenEnv.h"
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/Linalg/Utils/Utils.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
16 #include <optional>
18 using namespace mlir;
19 using namespace mlir::sparse_tensor;
21 //===----------------------------------------------------------------------===//
22 // Code generation environment helper functions
23 //===----------------------------------------------------------------------===//
25 /// Returns true if tensor materializes uninitialized into the computation.
26 static bool isMaterializing(Value val) {
27 return val.getDefiningOp<tensor::EmptyOp>() ||
28 val.getDefiningOp<bufferization::AllocTensorOp>();
31 /// Sorts the dependent loops such that it is ordered in the same sequence in
32 /// which loops will be generated.
33 static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
34 std::sort(target.begin(), target.end(),
35 [](const LoopCoeffPair &l, const LoopCoeffPair &r) {
36 assert(std::addressof(l) == std::addressof(r) || l != r);
37 return l.first < r.first;
38 });
40 //===----------------------------------------------------------------------===//
41 // Code generation environment constructor and general methods
42 //===----------------------------------------------------------------------===//
44 CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
45 unsigned numTensors, unsigned numLoops, unsigned maxRank)
46 : linalgOp(linop), sparseOptions(opts),
47 latticeMerger(numTensors, numLoops, maxRank), loopEmitter(),
48 sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
49 expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
50 redCustom(detail::kInvalidId), redValidLexInsert() {}
52 LogicalResult CodegenEnv::initTensorExp() {
53 // Builds the tensor expression for the Linalg operation in SSA form.
54 std::optional<ExprId> optExp = latticeMerger.buildTensorExpFromLinalg(op());
55 if (!optExp || !isAdmissibleTensorExp(*optExp))
56 return failure();
58 tensorExp = *optExp;
59 return success();
62 void CodegenEnv::startEmit() {
63 assert(insChain == nullptr && "must only start emitting once");
64 if (sparseOut) {
65 insChain = sparseOut->get();
66 latticeMerger.setHasSparseOut(true);
69 // Sort the related loop array such that they are in the same order as they
70 // appears on the topoOrder.
71 // TODO: since we only handle affine addition for slice based codegen, and
72 // addition is assoicative, the order how we evaluate the expression does
73 // not matter. However, to support multiplication, the order of the loop
74 // index should match the evaluation order to the affine expression AST.
76 // Initialize loop emitter.
77 SmallVector<Value> tensors; // input tensors passed to loop emitter
78 for (OpOperand &t : linalgOp->getOpOperands()) {
79 tensors.push_back(t.get());
80 const TensorId tid = makeTensorId(t.getOperandNumber());
81 const Level lvlRank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
82 const auto enc = getSparseTensorEncoding(t.get().getType());
83 (void)enc;
84 assert(!enc || lvlRank == enc.getLvlRank());
85 for (Level lvl = 0; lvl < lvlRank; lvl++)
86 sortDependentLoops(latticeMerger.getDependentLoops(tid, lvl));
89 loopEmitter.initialize(
90 tensors,
91 StringAttr::get(linalgOp.getContext(),
92 linalg::GenericOp::getOperationName()),
93 /*hasOutput=*/true,
94 /*isSparseOut=*/sparseOut != nullptr, /*numLoops=*/getLoopNum(),
95 // TODO: compute the map and pass it to loop emitter directly instead of
96 // passing in a callback.
97 /*dependentLvlGetter=*/
98 [this](TensorId t,
99 Level lvl) -> std::vector<std::pair<TensorLevel, unsigned>> {
100 // Translates from a list of loop indices to a list of [tid, lvl] pair.
101 std::vector<LoopCoeffPair> &rLoops = merger().getDependentLoops(t, lvl);
102 std::vector<std::pair<TensorLevel, unsigned>> ret;
103 ret.reserve(rLoops.size());
104 for (auto [loop, coeff] : rLoops) {
105 TensorLevel tl = makeTensorLevel(merger().getLoopDefiningLvl(loop));
106 ret.emplace_back(tl, coeff);
108 return ret;
112 std::optional<Operation *> CodegenEnv::genLoopBoundary(
113 function_ref<std::optional<Operation *>(MutableArrayRef<Value> parameters)>
114 callback) {
115 SmallVector<Value> params;
116 if (isReduc()) {
117 params.push_back(redVal);
118 if (isValidLexInsert())
119 params.push_back(redValidLexInsert);
120 } else {
121 assert(!isValidLexInsert());
123 if (isExpand())
124 params.push_back(expCount);
125 if (insChain != nullptr)
126 params.push_back(insChain);
127 auto r = callback(params); // may update parameters
128 unsigned i = 0;
129 if (isReduc()) {
130 updateReduc(params[i++]);
131 if (isValidLexInsert())
132 updateValidLexInsert(params[i++]);
134 if (isExpand())
135 updateExpandCount(params[i++]);
136 if (insChain != nullptr)
137 updateInsertionChain(params[i]);
138 return r;
141 //===----------------------------------------------------------------------===//
142 // Code generation environment verify functions.
143 //===----------------------------------------------------------------------===//
145 bool CodegenEnv::isAdmissibleTensorExp(ExprId exp) {
146 // We reject any expression that makes a reduction from `-outTensor`, as those
147 // expressions create a dependency between the current iteration (i) and the
148 // previous iteration (i-1). It would require iterating over the whole
149 // coordinate space, which prevent exploiting sparsity for faster code.
150 for (utils::IteratorType it : linalgOp.getIteratorTypesArray()) {
151 if (it == utils::IteratorType::reduction) {
152 if (latticeMerger.hasNegateOnOut(exp))
153 return false;
154 break;
158 OpOperand *lhs = linalgOp.getDpsInitOperand(0);
159 const TensorId tensor = makeTensorId(lhs->getOperandNumber());
160 // An non-annotated output tensor is assumed dense, and becomes a random
161 // access n-dim memref. Admissible since insertions cannot occur.
162 if (getSparseTensorType(lhs->get()).isAllDense())
163 return true;
165 // A tensor expression with a sparse output tensor that changes its values
166 // but not its nonzero structure, an operation called "simply dynamic" in
167 // [Bik96,Ch9], is also admissible without special env.
168 if (latticeMerger.isSingleCondition(tensor, exp))
169 return true;
171 // Accept "truly dynamic" if the output tensor materializes uninitialized
172 // into the computation and insertions occur in lexicographic index order.
173 sparseOut = lhs;
175 // Find the outermost parallel nest to determine whether compress/expand is
176 // needed.
177 outerParNest = 0;
178 const auto iteratorTypes = linalgOp.getIteratorTypesArray();
179 for (unsigned i = 0, e = getLoopNum(); i < e; i++) {
180 if (linalg::isReductionIterator(iteratorTypes[i]))
181 break; // terminate at first reduction
182 outerParNest++;
185 // Inadmissible kernel should have already been rejected by the previous
186 // path during loop scheduling.
187 assert(static_cast<int64_t>(outerParNest) >=
188 linalgOp.getRank(linalgOp.getDpsInitOperand(0)) - 1);
189 return isMaterializing(lhs->get());
192 //===----------------------------------------------------------------------===//
193 // Code generation environment topological sort methods
194 //===----------------------------------------------------------------------===//
196 Value CodegenEnv::getLoopVar(LoopId i) const {
197 return loopEmitter.getLoopIV(i);
200 //===----------------------------------------------------------------------===//
201 // Code generation environment sparse tensor output and expansion methods
202 //===----------------------------------------------------------------------===//
204 void CodegenEnv::updateInsertionChain(Value chain) {
205 assert(sparseOut != nullptr && insChain != nullptr);
206 insChain = chain;
209 bool CodegenEnv::atExpandLevel(OpOperand *o, unsigned rank, LoopId n) const {
210 return sparseOut == o && outerParNest == static_cast<LoopId>(rank - 1) &&
211 outerParNest == n;
214 void CodegenEnv::startExpand(Value values, Value filled, Value added,
215 Value count) {
216 assert(sparseOut != nullptr && expValues == nullptr);
217 expValues = values;
218 expFilled = filled;
219 expAdded = added;
220 expCount = count;
223 void CodegenEnv::updateExpandCount(Value count) {
224 assert(sparseOut != nullptr && expValues != nullptr);
225 expCount = count;
228 void CodegenEnv::endExpand() {
229 assert(sparseOut != nullptr && expValues != nullptr);
230 expValues = expFilled = expAdded = expCount = Value();
233 //===----------------------------------------------------------------------===//
234 // Code generation environment reduction methods
235 //===----------------------------------------------------------------------===//
237 void CodegenEnv::startReduc(ExprId exp, Value val) {
238 assert(!isReduc() && exp != detail::kInvalidId && val);
239 redExp = exp;
240 redVal = val;
241 latticeMerger.setExprValue(exp, val);
244 void CodegenEnv::updateReduc(Value val) {
245 assert(isReduc() && val);
246 redVal = val;
247 latticeMerger.clearExprValue(redExp);
248 latticeMerger.setExprValue(redExp, val);
251 Value CodegenEnv::endReduc() {
252 assert(isReduc());
253 Value val = redVal;
254 redVal = val;
255 latticeMerger.clearExprValue(redExp);
256 redExp = detail::kInvalidId;
257 return val;
260 void CodegenEnv::startValidLexInsert(Value val) {
261 assert(!isValidLexInsert() && isReduc() && val);
262 redValidLexInsert = val;
265 void CodegenEnv::updateValidLexInsert(Value val) {
266 assert(redValidLexInsert && isReduc() && val);
267 redValidLexInsert = val;
270 void CodegenEnv::endValidLexInsert() {
271 assert(isValidLexInsert() && !isReduc());
272 redValidLexInsert = Value();
275 void CodegenEnv::startCustomReduc(ExprId exp) {
276 assert(!isCustomReduc() && exp != detail::kInvalidId);
277 redCustom = exp;
280 Value CodegenEnv::getCustomRedId() const {
281 assert(isCustomReduc());
282 return dyn_cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
285 void CodegenEnv::endCustomReduc() {
286 assert(isCustomReduc());
287 redCustom = detail::kInvalidId;