#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

using namespace mlir::sparse_tensor;
11 /// Simple recursive data structure used to match expressions in Mergers.
15 /// Expressions representing tensors simply have a tensor number.
18 /// Tensor operations point to their children.
19 std::shared_ptr
<Pattern
> e0
;
20 std::shared_ptr
<Pattern
> e1
;
23 /// Rather than using these, please use the readable helper constructor
24 /// functions below to make tests more readable.
25 Pattern(unsigned tensorNum
) : kind(Kind::kTensor
), tensorNum(tensorNum
) {}
26 Pattern(Kind kind
, const std::shared_ptr
<Pattern
> &e0
,
27 const std::shared_ptr
<Pattern
> &e1
)
28 : kind(kind
), e0(e0
), e1(e1
) {
29 assert(kind
>= Kind::kMulF
);
35 /// Readable Pattern builder functions.
36 /// These should be preferred over the actual constructors.
39 static std::shared_ptr
<Pattern
> tensorPattern(unsigned tensorNum
) {
40 return std::make_shared
<Pattern
>(tensorNum
);
43 static std::shared_ptr
<Pattern
>
44 addfPattern(const std::shared_ptr
<Pattern
> &e0
,
45 const std::shared_ptr
<Pattern
> &e1
) {
46 return std::make_shared
<Pattern
>(Kind::kAddF
, e0
, e1
);
49 static std::shared_ptr
<Pattern
>
50 mulfPattern(const std::shared_ptr
<Pattern
> &e0
,
51 const std::shared_ptr
<Pattern
> &e1
) {
52 return std::make_shared
<Pattern
>(Kind::kMulF
, e0
, e1
);
55 class MergerTestBase
: public ::testing::Test
{
57 MergerTestBase(unsigned numTensors
, unsigned numLoops
)
58 : numTensors(numTensors
), numLoops(numLoops
),
59 merger(numTensors
, numLoops
) {}
62 /// Expression construction helpers.
65 unsigned tensor(unsigned tensor
) {
66 return merger
.addExp(Kind::kTensor
, tensor
);
69 unsigned addf(unsigned e0
, unsigned e1
) {
70 return merger
.addExp(Kind::kAddF
, e0
, e1
);
73 unsigned mulf(unsigned e0
, unsigned e1
) {
74 return merger
.addExp(Kind::kMulF
, e0
, e1
);
///
/// Comparison helpers.
///

/// Identity function that exists only so tests can write `lat(0)`, making it
/// clear at the call site that the argument is a lattice-point index.
unsigned lat(unsigned index) { return index; }
84 /// Returns true if a lattice point with an expression matching the given
85 /// pattern and bits matching the given bits is present in lattice points
86 /// [p, p+n) of lattice set s. This is useful for testing partial ordering
87 /// constraints between lattice points. We generally know how contiguous
88 /// groups of lattice points should be ordered with respect to other groups,
89 /// but there is no required ordering within groups.
90 bool latPointWithinRange(unsigned s
, unsigned p
, unsigned n
,
91 const std::shared_ptr
<Pattern
> &pattern
,
92 const BitVector
&bits
) {
93 for (unsigned i
= p
; i
< p
+ n
; ++i
) {
94 if (compareExpression(merger
.lat(merger
.set(s
)[i
]).exp
, pattern
) &&
95 compareBits(s
, i
, bits
))
101 /// Wrapper over latPointWithinRange for readability of tests.
102 void expectLatPointWithinRange(unsigned s
, unsigned p
, unsigned n
,
103 const std::shared_ptr
<Pattern
> &pattern
,
104 const BitVector
&bits
) {
105 EXPECT_TRUE(latPointWithinRange(s
, p
, n
, pattern
, bits
));
108 /// Wrapper over expectLatPointWithinRange for a single lat point.
109 void expectLatPoint(unsigned s
, unsigned p
,
110 const std::shared_ptr
<Pattern
> &pattern
,
111 const BitVector
&bits
) {
112 EXPECT_TRUE(latPointWithinRange(s
, p
, 1, pattern
, bits
));
115 /// Converts a vector of (loop, tensor) pairs to a bitvector with the
116 /// corresponding bits set.
118 loopsToBits(const std::vector
<std::pair
<unsigned, unsigned>> &loops
) {
119 BitVector testBits
= BitVector(numTensors
+ 1, false);
120 for (auto l
: loops
) {
121 auto loop
= std::get
<0>(l
);
122 auto tensor
= std::get
<1>(l
);
123 testBits
.set(numTensors
* loop
+ tensor
);
128 /// Returns true if the bits of lattice point p in set s match the given bits.
129 bool compareBits(unsigned s
, unsigned p
, const BitVector
&bits
) {
130 return merger
.lat(merger
.set(s
)[p
]).bits
== bits
;
133 /// Check that there are n lattice points in set s.
134 void expectNumLatPoints(unsigned s
, unsigned n
) {
135 EXPECT_THAT(merger
.set(s
).size(), n
);
138 /// Compares expression `e` of the merger against `pattern`. Equality is
139 /// defined recursively:
140 /// - The expression and the pattern must have the same Kind.
141 /// - Tensor expressions must also refer to the same tensor number
142 /// (tensorExp.tensor vs. pattern->tensorNum) -- not the same expression id.
143 /// - Operations must also have matching children: only e0 in the
/// single-child case below (presumably the unary kinds), both e0 and e1 in
/// the binary case. Invariant expressions are not supported yet (asserted).
144 bool compareExpression(unsigned e
, const std::shared_ptr
<Pattern
> &pattern
) {
145 auto tensorExp
= merger
.exp(e
);
146 if (tensorExp
.kind
!= pattern
->kind
)
// NOTE(review): this `if` has no body of its own in this copy (upstream it is
// presumably `return false;`), so the assert below is parsed as its body and
// a kind mismatch does not return early.
148 assert(tensorExp
.kind
!= Kind::kInvariant
&&
149 "Invariant comparison not yet supported");
150 switch (tensorExp
.kind
) {
// NOTE(review): no case labels survive in this copy, so every statement in
// this switch body is unreachable; restore the labels for the tensor kind,
// the single-child kinds and the binary kinds.
// Tensor leaf: match on tensor number.
152 return tensorExp
.tensor
== pattern
->tensorNum
;
// Single-child operation: only e0 is compared.
158 return compareExpression(tensorExp
.children
.e0
, pattern
->e0
);
// Binary operation: both children must match.
171 return compareExpression(tensorExp
.children
.e0
, pattern
->e0
) &&
172 compareExpression(tensorExp
.children
.e1
, pattern
->e1
);
// Intended to be reached only for a Kind that has no case above.
174 llvm_unreachable("Unhandled Kind");
// NOTE(review): the closing braces of the switch, of this function and of the
// MergerTestBase class, as well as the declarations of its numTensors,
// numLoops and merger members, are missing from this copy.
183 class MergerTest3T1L
: public MergerTestBase
{
185 // Our three tensors (two inputs, one output).
186 const unsigned t0
= 0, t1
= 1, t2
= 2;
189 const unsigned l0
= 0;
191 MergerTest3T1L() : MergerTestBase(3, 1) {
192 // Tensor 0: sparse input vector.
193 merger
.addExp(Kind::kTensor
, t0
, -1u);
194 merger
.setDim(t0
, l0
, Dim::kSparse
);
196 // Tensor 1: sparse input vector.
197 merger
.addExp(Kind::kTensor
, t1
, -1u);
198 merger
.setDim(t1
, l0
, Dim::kSparse
);
200 // Tensor 2: dense output vector.
201 merger
.addExp(Kind::kTensor
, t2
, -1u);
202 merger
.setDim(t2
, l0
, Dim::kDense
);
208 /// Vector addition of 2 vectors, i.e.:
209 /// a(i) = b(i) + c(i)
210 /// which should form the 3 lattice points
212 /// lat( i_00 i_01 / (tensor_0 + tensor_1) )
213 /// lat( i_00 / tensor_0 )
214 /// lat( i_01 / tensor_1 )
216 /// and after optimization, will reduce to the 2 lattice points
218 /// lat( i_00 i_01 / (tensor_0 + tensor_1) )
219 /// lat( i_00 / tensor_0 )
221 TEST_F(MergerTest3T1L
, VectorAdd2
) {
222 // Construct expression.
223 auto e
= addf(tensor(t0
), tensor(t1
));
225 // Build lattices and check.
226 auto s
= merger
.buildLattices(e
, l0
);
227 expectNumLatPoints(s
, 3);
228 expectLatPoint(s
, lat(0), addfPattern(tensorPattern(t0
), tensorPattern(t1
)),
229 loopsToBits({{l0
, t0
}, {l0
, t1
}}));
230 expectLatPointWithinRange(s
, lat(1), 2, tensorPattern(t0
),
231 loopsToBits({{l0
, t0
}}));
232 expectLatPointWithinRange(s
, lat(1), 2, tensorPattern(t1
),
233 loopsToBits({{l0
, t1
}}));
235 // Optimize lattices and check.
236 s
= merger
.optimizeSet(s
);
237 expectNumLatPoints(s
, 3);
238 expectLatPoint(s
, lat(0), addfPattern(tensorPattern(t0
), tensorPattern(t1
)),
239 loopsToBits({{l0
, t0
}, {l0
, t1
}}));
240 expectLatPointWithinRange(s
, lat(1), 2, tensorPattern(t0
),
241 loopsToBits({{l0
, t0
}}));
242 expectLatPointWithinRange(s
, lat(1), 2, tensorPattern(t1
),
243 loopsToBits({{l0
, t1
}}));
246 /// Vector multiplication of 2 vectors, i.e.:
247 /// a(i) = b(i) * c(i)
248 /// which should form the single lattice point
250 /// lat( i_00 i_01 / (tensor_0 * tensor_1) )
252 TEST_F(MergerTest3T1L
, VectorMul2
) {
253 // Construct expression.
254 auto e
= mulf(t0
, t1
);
256 // Build lattices and check.
257 auto s
= merger
.buildLattices(e
, l0
);
258 expectNumLatPoints(s
, 1);
259 expectLatPoint(s
, lat(0), mulfPattern(tensorPattern(t0
), tensorPattern(t1
)),
260 loopsToBits({{l0
, t0
}, {l0
, t1
}}));
262 // Optimize lattices and check.
263 s
= merger
.optimizeSet(s
);
264 expectNumLatPoints(s
, 1);
265 expectLatPoint(s
, lat(0), mulfPattern(tensorPattern(t0
), tensorPattern(t1
)),
266 loopsToBits({{l0
, t0
}, {l0
, t1
}}));