1 //===- CodegenEnv.cpp - Code generation environment class ----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "CodegenEnv.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;
21 //===----------------------------------------------------------------------===//
22 // Code generation environment helper functions
23 //===----------------------------------------------------------------------===//
25 /// Returns true if tensor materializes uninitialized into the computation.
26 static bool isMaterializing(Value val
) {
27 return val
.getDefiningOp
<tensor::EmptyOp
>() ||
28 val
.getDefiningOp
<bufferization::AllocTensorOp
>();
31 /// Sorts the dependent loops such that it is ordered in the same sequence in
32 /// which loops will be generated.
33 static void sortDependentLoops(std::vector
<LoopCoeffPair
> &target
) {
34 std::sort(target
.begin(), target
.end(),
35 [](const LoopCoeffPair
&l
, const LoopCoeffPair
&r
) {
36 assert(std::addressof(l
) == std::addressof(r
) || l
!= r
);
37 return l
.first
< r
.first
;
40 //===----------------------------------------------------------------------===//
41 // Code generation environment constructor and general methods
42 //===----------------------------------------------------------------------===//
44 CodegenEnv::CodegenEnv(linalg::GenericOp linop
, SparsificationOptions opts
,
45 unsigned numTensors
, unsigned numLoops
, unsigned maxRank
)
46 : linalgOp(linop
), sparseOptions(opts
),
47 latticeMerger(numTensors
, numLoops
, maxRank
), loopEmitter(),
48 sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
49 expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId
),
50 redCustom(detail::kInvalidId
), redValidLexInsert() {}
52 LogicalResult
CodegenEnv::initTensorExp() {
53 // Builds the tensor expression for the Linalg operation in SSA form.
54 std::optional
<ExprId
> optExp
= latticeMerger
.buildTensorExpFromLinalg(op());
55 if (!optExp
|| !isAdmissibleTensorExp(*optExp
))
62 void CodegenEnv::startEmit() {
63 assert(insChain
== nullptr && "must only start emitting once");
65 insChain
= sparseOut
->get();
66 latticeMerger
.setHasSparseOut(true);
69 // Sort the related loop array such that they are in the same order as they
70 // appears on the topoOrder.
71 // TODO: since we only handle affine addition for slice based codegen, and
72 // addition is assoicative, the order how we evaluate the expression does
73 // not matter. However, to support multiplication, the order of the loop
74 // index should match the evaluation order to the affine expression AST.
76 // Initialize loop emitter.
77 SmallVector
<Value
> tensors
; // input tensors passed to loop emitter
78 for (OpOperand
&t
: linalgOp
->getOpOperands()) {
79 tensors
.push_back(t
.get());
80 const TensorId tid
= makeTensorId(t
.getOperandNumber());
81 const Level lvlRank
= linalgOp
.getMatchingIndexingMap(&t
).getNumResults();
82 const auto enc
= getSparseTensorEncoding(t
.get().getType());
84 assert(!enc
|| lvlRank
== enc
.getLvlRank());
85 for (Level lvl
= 0; lvl
< lvlRank
; lvl
++)
86 sortDependentLoops(latticeMerger
.getDependentLoops(tid
, lvl
));
89 loopEmitter
.initialize(
91 StringAttr::get(linalgOp
.getContext(),
92 linalg::GenericOp::getOperationName()),
94 /*isSparseOut=*/sparseOut
!= nullptr, /*numLoops=*/getLoopNum(),
95 // TODO: compute the map and pass it to loop emitter directly instead of
96 // passing in a callback.
97 /*dependentLvlGetter=*/
99 Level lvl
) -> std::vector
<std::pair
<TensorLevel
, unsigned>> {
100 // Translates from a list of loop indices to a list of [tid, lvl] pair.
101 std::vector
<LoopCoeffPair
> &rLoops
= merger().getDependentLoops(t
, lvl
);
102 std::vector
<std::pair
<TensorLevel
, unsigned>> ret
;
103 ret
.reserve(rLoops
.size());
104 for (auto [loop
, coeff
] : rLoops
) {
105 TensorLevel tl
= makeTensorLevel(merger().getLoopDefiningLvl(loop
));
106 ret
.emplace_back(tl
, coeff
);
112 std::optional
<Operation
*> CodegenEnv::genLoopBoundary(
113 function_ref
<std::optional
<Operation
*>(MutableArrayRef
<Value
> parameters
)>
115 SmallVector
<Value
> params
;
117 params
.push_back(redVal
);
118 if (isValidLexInsert())
119 params
.push_back(redValidLexInsert
);
121 assert(!isValidLexInsert());
124 params
.push_back(expCount
);
125 if (insChain
!= nullptr)
126 params
.push_back(insChain
);
127 auto r
= callback(params
); // may update parameters
130 updateReduc(params
[i
++]);
131 if (isValidLexInsert())
132 updateValidLexInsert(params
[i
++]);
135 updateExpandCount(params
[i
++]);
136 if (insChain
!= nullptr)
137 updateInsertionChain(params
[i
]);
141 //===----------------------------------------------------------------------===//
142 // Code generation environment verify functions.
143 //===----------------------------------------------------------------------===//
145 bool CodegenEnv::isAdmissibleTensorExp(ExprId exp
) {
146 // We reject any expression that makes a reduction from `-outTensor`, as those
147 // expressions create a dependency between the current iteration (i) and the
148 // previous iteration (i-1). It would require iterating over the whole
149 // coordinate space, which prevent exploiting sparsity for faster code.
150 for (utils::IteratorType it
: linalgOp
.getIteratorTypesArray()) {
151 if (it
== utils::IteratorType::reduction
) {
152 if (latticeMerger
.hasNegateOnOut(exp
))
158 OpOperand
*lhs
= linalgOp
.getDpsInitOperand(0);
159 const TensorId tensor
= makeTensorId(lhs
->getOperandNumber());
160 // An non-annotated output tensor is assumed dense, and becomes a random
161 // access n-dim memref. Admissible since insertions cannot occur.
162 if (getSparseTensorType(lhs
->get()).isAllDense())
165 // A tensor expression with a sparse output tensor that changes its values
166 // but not its nonzero structure, an operation called "simply dynamic" in
167 // [Bik96,Ch9], is also admissible without special env.
168 if (latticeMerger
.isSingleCondition(tensor
, exp
))
171 // Accept "truly dynamic" if the output tensor materializes uninitialized
172 // into the computation and insertions occur in lexicographic index order.
175 // Find the outermost parallel nest to determine whether compress/expand is
178 const auto iteratorTypes
= linalgOp
.getIteratorTypesArray();
179 for (unsigned i
= 0, e
= getLoopNum(); i
< e
; i
++) {
180 if (linalg::isReductionIterator(iteratorTypes
[i
]))
181 break; // terminate at first reduction
185 // Inadmissible kernel should have already been rejected by the previous
186 // path during loop scheduling.
187 assert(static_cast<int64_t>(outerParNest
) >=
188 linalgOp
.getRank(linalgOp
.getDpsInitOperand(0)) - 1);
189 return isMaterializing(lhs
->get());
192 //===----------------------------------------------------------------------===//
193 // Code generation environment topological sort methods
194 //===----------------------------------------------------------------------===//
196 Value
CodegenEnv::getLoopVar(LoopId i
) const {
197 return loopEmitter
.getLoopIV(i
);
200 //===----------------------------------------------------------------------===//
201 // Code generation environment sparse tensor output and expansion methods
202 //===----------------------------------------------------------------------===//
204 void CodegenEnv::updateInsertionChain(Value chain
) {
205 assert(sparseOut
!= nullptr && insChain
!= nullptr);
209 bool CodegenEnv::atExpandLevel(OpOperand
*o
, unsigned rank
, LoopId n
) const {
210 return sparseOut
== o
&& outerParNest
== static_cast<LoopId
>(rank
- 1) &&
214 void CodegenEnv::startExpand(Value values
, Value filled
, Value added
,
216 assert(sparseOut
!= nullptr && expValues
== nullptr);
223 void CodegenEnv::updateExpandCount(Value count
) {
224 assert(sparseOut
!= nullptr && expValues
!= nullptr);
228 void CodegenEnv::endExpand() {
229 assert(sparseOut
!= nullptr && expValues
!= nullptr);
230 expValues
= expFilled
= expAdded
= expCount
= Value();
233 //===----------------------------------------------------------------------===//
234 // Code generation environment reduction methods
235 //===----------------------------------------------------------------------===//
237 void CodegenEnv::startReduc(ExprId exp
, Value val
) {
238 assert(!isReduc() && exp
!= detail::kInvalidId
&& val
);
241 latticeMerger
.setExprValue(exp
, val
);
244 void CodegenEnv::updateReduc(Value val
) {
245 assert(isReduc() && val
);
247 latticeMerger
.clearExprValue(redExp
);
248 latticeMerger
.setExprValue(redExp
, val
);
251 Value
CodegenEnv::endReduc() {
255 latticeMerger
.clearExprValue(redExp
);
256 redExp
= detail::kInvalidId
;
260 void CodegenEnv::startValidLexInsert(Value val
) {
261 assert(!isValidLexInsert() && isReduc() && val
);
262 redValidLexInsert
= val
;
265 void CodegenEnv::updateValidLexInsert(Value val
) {
266 assert(redValidLexInsert
&& isReduc() && val
);
267 redValidLexInsert
= val
;
270 void CodegenEnv::endValidLexInsert() {
271 assert(isValidLexInsert() && !isReduc());
272 redValidLexInsert
= Value();
275 void CodegenEnv::startCustomReduc(ExprId exp
) {
276 assert(!isCustomReduc() && exp
!= detail::kInvalidId
);
280 Value
CodegenEnv::getCustomRedId() const {
281 assert(isCustomReduc());
282 return dyn_cast
<sparse_tensor::ReduceOp
>(exp(redCustom
).op
).getIdentity();
285 void CodegenEnv::endCustomReduc() {
286 assert(isCustomReduc());
287 redCustom
= detail::kInvalidId
;