[clang][extract-api] Emit "navigator" property of "name" in SymbolGraph
[llvm-project.git] / mlir / unittests / Dialect / SparseTensor / MergerTest.cpp
blob4bdfa71d8bc49f94efa68bf5162d9acbbfc58ae1
1 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
2 #include "gmock/gmock.h"
3 #include "gtest/gtest.h"
4 #include <memory>
6 using namespace mlir;
7 using namespace mlir::sparse_tensor;
9 namespace {
11 /// Simple recursive data structure used to match expressions in Mergers.
12 struct Pattern {
13 Kind kind;
15 /// Expressions representing tensors simply have a tensor number.
16 unsigned tensorNum;
18 /// Tensor operations point to their children.
19 std::shared_ptr<Pattern> e0;
20 std::shared_ptr<Pattern> e1;
22 /// Constructors.
23 /// Rather than using these, please use the readable helper constructor
24 /// functions below to make tests more readable.
25 Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
26 Pattern(Kind kind, const std::shared_ptr<Pattern> &e0,
27 const std::shared_ptr<Pattern> &e1)
28 : kind(kind), e0(e0), e1(e1) {
29 assert(kind >= Kind::kMulF);
30 assert(e0 && e1);
34 ///
35 /// Readable Pattern builder functions.
36 /// These should be preferred over the actual constructors.
37 ///
39 static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
40 return std::make_shared<Pattern>(tensorNum);
43 static std::shared_ptr<Pattern>
44 addfPattern(const std::shared_ptr<Pattern> &e0,
45 const std::shared_ptr<Pattern> &e1) {
46 return std::make_shared<Pattern>(Kind::kAddF, e0, e1);
49 static std::shared_ptr<Pattern>
50 mulfPattern(const std::shared_ptr<Pattern> &e0,
51 const std::shared_ptr<Pattern> &e1) {
52 return std::make_shared<Pattern>(Kind::kMulF, e0, e1);
55 class MergerTestBase : public ::testing::Test {
56 protected:
57 MergerTestBase(unsigned numTensors, unsigned numLoops)
58 : numTensors(numTensors), numLoops(numLoops),
59 merger(numTensors, numLoops) {}
61 ///
62 /// Expression construction helpers.
63 ///
65 unsigned tensor(unsigned tensor) {
66 return merger.addExp(Kind::kTensor, tensor);
69 unsigned addf(unsigned e0, unsigned e1) {
70 return merger.addExp(Kind::kAddF, e0, e1);
73 unsigned mulf(unsigned e0, unsigned e1) {
74 return merger.addExp(Kind::kMulF, e0, e1);
77 ///
78 /// Comparison helpers.
79 ///
81 /// For readability of tests.
82 unsigned lat(unsigned lat) { return lat; }
84 /// Returns true if a lattice point with an expression matching the given
85 /// pattern and bits matching the given bits is present in lattice points
86 /// [p, p+n) of lattice set s. This is useful for testing partial ordering
87 /// constraints between lattice points. We generally know how contiguous
88 /// groups of lattice points should be ordered with respect to other groups,
89 /// but there is no required ordering within groups.
90 bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
91 const std::shared_ptr<Pattern> &pattern,
92 const BitVector &bits) {
93 for (unsigned i = p; i < p + n; ++i) {
94 if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
95 compareBits(s, i, bits))
96 return true;
98 return false;
101 /// Wrapper over latPointWithinRange for readability of tests.
102 void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
103 const std::shared_ptr<Pattern> &pattern,
104 const BitVector &bits) {
105 EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits));
108 /// Wrapper over expectLatPointWithinRange for a single lat point.
109 void expectLatPoint(unsigned s, unsigned p,
110 const std::shared_ptr<Pattern> &pattern,
111 const BitVector &bits) {
112 EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits));
115 /// Converts a vector of (loop, tensor) pairs to a bitvector with the
116 /// corresponding bits set.
117 BitVector
118 loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
119 BitVector testBits = BitVector(numTensors + 1, false);
120 for (auto l : loops) {
121 auto loop = std::get<0>(l);
122 auto tensor = std::get<1>(l);
123 testBits.set(numTensors * loop + tensor);
125 return testBits;
128 /// Returns true if the bits of lattice point p in set s match the given bits.
129 bool compareBits(unsigned s, unsigned p, const BitVector &bits) {
130 return merger.lat(merger.set(s)[p]).bits == bits;
133 /// Check that there are n lattice points in set s.
134 void expectNumLatPoints(unsigned s, unsigned n) {
135 EXPECT_THAT(merger.set(s).size(), n);
138 /// Compares expressions for equality. Equality is defined recursively as:
139 /// - Two expressions can only be equal if they have the same Kind.
140 /// - Two binary expressions are equal if they have the same Kind and their
141 /// children are equal.
142 /// - Expressions with Kind invariant or tensor are equal if they have the
143 /// same expression id.
144 bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
145 auto tensorExp = merger.exp(e);
146 if (tensorExp.kind != pattern->kind)
147 return false;
148 assert(tensorExp.kind != Kind::kInvariant &&
149 "Invariant comparison not yet supported");
150 switch (tensorExp.kind) {
151 case Kind::kTensor:
152 return tensorExp.tensor == pattern->tensorNum;
153 case Kind::kAbsF:
154 case Kind::kCeilF:
155 case Kind::kFloorF:
156 case Kind::kNegF:
157 case Kind::kNegI:
158 return compareExpression(tensorExp.children.e0, pattern->e0);
159 case Kind::kMulF:
160 case Kind::kMulI:
161 case Kind::kDivF:
162 case Kind::kDivS:
163 case Kind::kDivU:
164 case Kind::kAddF:
165 case Kind::kAddI:
166 case Kind::kSubF:
167 case Kind::kSubI:
168 case Kind::kAndI:
169 case Kind::kOrI:
170 case Kind::kXorI:
171 return compareExpression(tensorExp.children.e0, pattern->e0) &&
172 compareExpression(tensorExp.children.e1, pattern->e1);
173 default:
174 llvm_unreachable("Unhandled Kind");
178 unsigned numTensors;
179 unsigned numLoops;
180 Merger merger;
183 class MergerTest3T1L : public MergerTestBase {
184 protected:
185 // Our three tensors (two inputs, one output).
186 const unsigned t0 = 0, t1 = 1, t2 = 2;
188 // Our single loop.
189 const unsigned l0 = 0;
191 MergerTest3T1L() : MergerTestBase(3, 1) {
192 // Tensor 0: sparse input vector.
193 merger.addExp(Kind::kTensor, t0, -1u);
194 merger.setDim(t0, l0, Dim::kSparse);
196 // Tensor 1: sparse input vector.
197 merger.addExp(Kind::kTensor, t1, -1u);
198 merger.setDim(t1, l0, Dim::kSparse);
200 // Tensor 2: dense output vector.
201 merger.addExp(Kind::kTensor, t2, -1u);
202 merger.setDim(t2, l0, Dim::kDense);
206 } // namespace
208 /// Vector addition of 2 vectors, i.e.:
209 /// a(i) = b(i) + c(i)
210 /// which should form the 3 lattice points
211 /// {
212 /// lat( i_00 i_01 / (tensor_0 + tensor_1) )
213 /// lat( i_00 / tensor_0 )
214 /// lat( i_01 / tensor_1 )
215 /// }
/// and after optimization, the set is expected to be unchanged: no lattice
/// point is a duplicate and both input vectors are sparse, so the same 3
/// lattice points remain.
221 TEST_F(MergerTest3T1L, VectorAdd2) {
222 // Construct expression.
223 auto e = addf(tensor(t0), tensor(t1));
225 // Build lattices and check.
226 auto s = merger.buildLattices(e, l0);
227 expectNumLatPoints(s, 3);
228 expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
229 loopsToBits({{l0, t0}, {l0, t1}}));
230 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
231 loopsToBits({{l0, t0}}));
232 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
233 loopsToBits({{l0, t1}}));
235 // Optimize lattices and check.
236 s = merger.optimizeSet(s);
237 expectNumLatPoints(s, 3);
238 expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
239 loopsToBits({{l0, t0}, {l0, t1}}));
240 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
241 loopsToBits({{l0, t0}}));
242 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
243 loopsToBits({{l0, t1}}));
246 /// Vector multiplication of 2 vectors, i.e.:
247 /// a(i) = b(i) * c(i)
248 /// which should form the single lattice point
249 /// {
250 /// lat( i_00 i_01 / (tensor_0 * tensor_1) )
251 /// }
252 TEST_F(MergerTest3T1L, VectorMul2) {
253 // Construct expression.
254 auto e = mulf(t0, t1);
256 // Build lattices and check.
257 auto s = merger.buildLattices(e, l0);
258 expectNumLatPoints(s, 1);
259 expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
260 loopsToBits({{l0, t0}, {l0, t1}}));
262 // Optimize lattices and check.
263 s = merger.optimizeSet(s);
264 expectNumLatPoints(s, 1);
265 expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
266 loopsToBits({{l0, t0}, {l0, t1}}));