[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / unittests / Support / CyclicReplacerCacheTest.cpp
blob26f0709f7d83100924bf14e8409dc4f1bb148cb9
1 //===- CyclicReplacerCacheTest.cpp ----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
#include "mlir/Support/CyclicReplacerCache.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
#include "gmock/gmock.h"
#include <functional>
#include <map>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <vector>
16 using namespace mlir;
18 TEST(CachedCyclicReplacerTest, testNoRecursion) {
19 CachedCyclicReplacer<int, bool> replacer(
20 /*replacer=*/[](int n) { return static_cast<bool>(n); },
21 /*cycleBreaker=*/[](int n) { return std::nullopt; });
23 EXPECT_EQ(replacer(3), true);
24 EXPECT_EQ(replacer(0), false);
27 TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) {
28 // Replacer cycles through ints 0 -> 1 -> 2 -> 0 -> ...
29 std::optional<CachedCyclicReplacer<int, int>> replacer;
30 replacer.emplace(
31 /*replacer=*/[&](int n) { return (*replacer)((n + 1) % 3); },
32 /*cycleBreaker=*/[&](int n) { return -1; });
34 // Starting at 0.
35 EXPECT_EQ((*replacer)(0), -1);
36 // Starting at 2.
37 EXPECT_EQ((*replacer)(2), -1);
40 //===----------------------------------------------------------------------===//
41 // CachedCyclicReplacer: ChainRecursion
42 //===----------------------------------------------------------------------===//
44 /// This set of tests uses a replacer function that maps ints into vectors of
45 /// ints.
46 ///
/// The replacement result for input `n` is the replacement result of `(n+1)%3`
/// appended with an element `42`. Theoretically, this will produce an
/// infinitely long vector. The cycle-breaker function prunes this infinite
/// recursion in the replacer logic by returning an empty vector upon the first
/// re-occurrence of an input value. When the fixture's `baseCase` is set, only
/// a re-occurrence of that specific value is pruned; other repeats recurse on.
52 namespace {
53 class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
54 public:
55 // N ==> (N+1) % 3
56 // This will create a chain of infinite length without recursion pruning.
57 CachedCyclicReplacerChainRecursionPruningTest()
58 : replacer(
59 [&](int n) {
60 ++invokeCount;
61 std::vector<int> result = replacer((n + 1) % 3);
62 result.push_back(42);
63 return result;
65 [&](int n) -> std::optional<std::vector<int>> {
66 return baseCase.value_or(n) == n
67 ? std::make_optional(std::vector<int>{})
68 : std::nullopt;
69 }) {}
71 std::vector<int> getChain(unsigned N) { return std::vector<int>(N, 42); };
73 CachedCyclicReplacer<int, std::vector<int>> replacer;
74 int invokeCount = 0;
75 std::optional<int> baseCase = std::nullopt;
77 } // namespace
79 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) {
80 // Starting at 0. Cycle length is 3.
81 EXPECT_EQ(replacer(0), getChain(3));
82 EXPECT_EQ(invokeCount, 3);
84 // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
85 invokeCount = 0;
86 EXPECT_EQ(replacer(1), getChain(5));
87 EXPECT_EQ(invokeCount, 2);
89 // Starting at 2. Cycle length is 4. Entire result is cached.
90 invokeCount = 0;
91 EXPECT_EQ(replacer(2), getChain(4));
92 EXPECT_EQ(invokeCount, 0);
95 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere1) {
96 // Starting at 1. Cycle length is 3.
97 EXPECT_EQ(replacer(1), getChain(3));
98 EXPECT_EQ(invokeCount, 3);
101 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific0) {
102 baseCase = 0;
104 // Starting at 0. Cycle length is 3.
105 EXPECT_EQ(replacer(0), getChain(3));
106 EXPECT_EQ(invokeCount, 3);
109 TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) {
110 baseCase = 0;
112 // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
113 EXPECT_EQ(replacer(1), getChain(5));
114 EXPECT_EQ(invokeCount, 5);
116 // Starting at 0. Cycle length is 3. Entire result is cached.
117 invokeCount = 0;
118 EXPECT_EQ(replacer(0), getChain(3));
119 EXPECT_EQ(invokeCount, 0);
122 //===----------------------------------------------------------------------===//
123 // CachedCyclicReplacer: GraphReplacement
124 //===----------------------------------------------------------------------===//
/// This set of tests uses a replacer function that maps from cyclic graphs to
/// acyclic graphs (DAGs, not necessarily trees), pruning out cycles in the
/// process.
///
/// It consists of two helper classes:
/// - Graph
///   - A directed graph where nodes are non-negative integers.
/// - PrunedGraph
///   - A Graph where edges that used to cause cycles are now represented with
///     an indirection (a recursionId).
135 namespace {
136 class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
137 public:
138 /// A directed graph where nodes are non-negative integers.
139 struct Graph {
140 using Node = int64_t;
142 /// Use ordered containers for deterministic output.
143 /// Nodes without outgoing edges are considered nonexistent.
144 std::map<Node, std::set<Node>> edges;
146 void addEdge(Node src, Node sink) { edges[src].insert(sink); }
148 bool isCyclic() const {
149 DenseSet<Node> visited;
150 for (Node root : llvm::make_first_range(edges)) {
151 if (visited.contains(root))
152 continue;
154 SetVector<Node> path;
155 SmallVector<Node> workstack;
156 workstack.push_back(root);
157 while (!workstack.empty()) {
158 Node curr = workstack.back();
159 workstack.pop_back();
161 if (curr < 0) {
162 // A negative node signals the end of processing all of this node's
163 // children. Remove self from path.
164 assert(path.back() == -curr && "internal inconsistency");
165 path.pop_back();
166 continue;
169 if (path.contains(curr))
170 return true;
172 visited.insert(curr);
173 auto edgesIter = edges.find(curr);
174 if (edgesIter == edges.end() || edgesIter->second.empty())
175 continue;
177 path.insert(curr);
178 // Push negative node to signify recursion return.
179 workstack.push_back(-curr);
180 workstack.insert(workstack.end(), edgesIter->second.begin(),
181 edgesIter->second.end());
184 return false;
187 /// Deterministic output for testing.
188 std::string serialize() const {
189 std::ostringstream oss;
190 for (const auto &[src, neighbors] : edges) {
191 oss << src << ":";
192 for (Graph::Node neighbor : neighbors)
193 oss << " " << neighbor;
194 oss << "\n";
196 return oss.str();
200 /// A Graph where edges that used to cause cycles (back-edges) are now
201 /// represented with an indirection (a recursionId).
203 /// In addition to each node having an integer ID, each node also tracks the
204 /// original integer ID it had in the original graph. This way for every
205 /// back-edge, we can represent it as pointing to a new instance of the
206 /// original node. Then we mark the original node and the new instance with
207 /// a new unique recursionId to indicate that they're supposed to be the same
208 /// node.
209 struct PrunedGraph {
210 using Node = Graph::Node;
211 struct NodeInfo {
212 Graph::Node originalId;
213 /// A negative recursive index means not recursive. Otherwise nodes with
214 /// the same originalId & recursionId are the same node in the original
215 /// graph.
216 int64_t recursionId;
219 /// Add a regular non-recursive-self node.
220 Node addNode(Graph::Node originalId, int64_t recursionIndex = -1) {
221 Node id = nextConnectionId++;
222 info[id] = {originalId, recursionIndex};
223 return id;
225 /// Add a recursive-self-node, i.e. a duplicate of the original node that is
226 /// meant to represent an indirection to it.
227 std::pair<Node, int64_t> addRecursiveSelfNode(Graph::Node originalId) {
228 auto node = addNode(originalId, nextRecursionId);
229 return {node, nextRecursionId++};
231 void addEdge(Node src, Node sink) { connections.addEdge(src, sink); }
233 /// Deterministic output for testing.
234 std::string serialize() const {
235 std::ostringstream oss;
236 oss << "nodes\n";
237 for (const auto &[nodeId, nodeInfo] : info) {
238 oss << nodeId << ": n" << nodeInfo.originalId;
239 if (nodeInfo.recursionId >= 0)
240 oss << '<' << nodeInfo.recursionId << '>';
241 oss << "\n";
243 oss << "edges\n";
244 oss << connections.serialize();
245 return oss.str();
248 bool isCyclic() const { return connections.isCyclic(); }
250 private:
251 Graph connections;
252 int64_t nextRecursionId = 0;
253 int64_t nextConnectionId = 0;
254 /// Use ordered map for deterministic output.
255 std::map<Graph::Node, NodeInfo> info;
258 PrunedGraph breakCycles(const Graph &input) {
259 assert(input.isCyclic() && "input graph is not cyclic");
261 PrunedGraph output;
263 DenseMap<Graph::Node, int64_t> recMap;
264 auto cycleBreaker = [&](Graph::Node inNode) -> std::optional<Graph::Node> {
265 auto [node, recId] = output.addRecursiveSelfNode(inNode);
266 recMap[inNode] = recId;
267 return node;
270 CyclicReplacerCache<Graph::Node, Graph::Node> cache(cycleBreaker);
272 std::function<Graph::Node(Graph::Node)> replaceNode =
273 [&](Graph::Node inNode) {
274 auto cacheEntry = cache.lookupOrInit(inNode);
275 if (std::optional<Graph::Node> result = cacheEntry.get())
276 return *result;
278 // Recursively replace its neighbors.
279 SmallVector<Graph::Node> neighbors;
280 if (auto it = input.edges.find(inNode); it != input.edges.end())
281 neighbors = SmallVector<Graph::Node>(
282 llvm::map_range(it->second, replaceNode));
284 // Create a new node in the output graph.
285 int64_t recursionIndex =
286 cacheEntry.wasRepeated() ? recMap.lookup(inNode) : -1;
287 Graph::Node result = output.addNode(inNode, recursionIndex);
289 for (Graph::Node neighbor : neighbors)
290 output.addEdge(result, neighbor);
292 cacheEntry.resolve(result);
293 return result;
296 /// Translate starting from each node.
297 for (Graph::Node root : llvm::make_first_range(input.edges))
298 replaceNode(root);
300 return output;
303 /// Helper for serialization tests that allow putting comments in the
304 /// serialized format. Every line that begins with a `;` is considered a
305 /// comment. The entire line, incl. the terminating `\n` is removed.
306 std::string trimComments(StringRef input) {
307 std::ostringstream oss;
308 bool isNewLine = false;
309 bool isComment = false;
310 for (char c : input) {
311 // Lines beginning with ';' are comments.
312 if (isNewLine && c == ';')
313 isComment = true;
315 if (!isComment)
316 oss << c;
318 if (c == '\n') {
319 isNewLine = true;
320 isComment = false;
323 return oss.str();
326 } // namespace
328 TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) {
329 // 0 -> 1 -> 2
330 // ^ |
331 // +---------+
332 Graph input = {{{0, {1}}, {1, {2}}, {2, {0}}}};
333 PrunedGraph output = breakCycles(input);
334 ASSERT_FALSE(output.isCyclic()) << output.serialize();
335 EXPECT_EQ(output.serialize(), trimComments(R"(nodes
336 ; root 0
337 0: n0<0>
338 1: n2
339 2: n1
340 3: n0<0>
341 ; root 1
342 4: n2
343 ; root 2
344 5: n1
345 edges
346 1: 0
347 2: 1
348 3: 2
349 4: 3
350 5: 4
351 )"));
354 TEST_F(CachedCyclicReplacerGraphReplacement, testDualLoop) {
355 // +----> 1 -----+
356 // | v
357 // 0 <---------- 3
358 // | ^
359 // +----> 2 -----+
361 // Two loops:
362 // 0 -> 1 -> 3 -> 0
363 // 0 -> 2 -> 3 -> 0
364 Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0}}}};
365 PrunedGraph output = breakCycles(input);
366 ASSERT_FALSE(output.isCyclic()) << output.serialize();
367 EXPECT_EQ(output.serialize(), trimComments(R"(nodes
368 ; root 0
369 0: n0<0>
370 1: n3
371 2: n1
372 3: n2
373 4: n0<0>
374 ; root 1
375 5: n3
376 6: n1
377 ; root 2
378 7: n2
379 edges
380 1: 0
381 2: 1
382 3: 1
383 4: 2 3
384 5: 4
385 6: 5
386 7: 5
387 )"));
390 TEST_F(CachedCyclicReplacerGraphReplacement, testNestedLoops) {
391 // +----> 1 -----+
392 // | ^ v
393 // 0 <----+----- 2
395 // Two nested loops:
396 // 0 -> 1 -> 2 -> 0
397 // 1 -> 2 -> 1
398 Graph input = {{{0, {1}}, {1, {2}}, {2, {0, 1}}}};
399 PrunedGraph output = breakCycles(input);
400 ASSERT_FALSE(output.isCyclic()) << output.serialize();
401 EXPECT_EQ(output.serialize(), trimComments(R"(nodes
402 ; root 0
403 0: n0<0>
404 1: n1<1>
405 2: n2
406 3: n1<1>
407 4: n0<0>
408 ; root 1
409 5: n1<2>
410 6: n2
411 7: n1<2>
412 ; root 2
413 8: n2
414 edges
415 2: 0 1
416 3: 2
417 4: 3
418 6: 4 5
419 7: 6
420 8: 4 7
421 )"));
424 TEST_F(CachedCyclicReplacerGraphReplacement, testDualNestedLoops) {
425 // +----> 1 -----+
426 // | ^ v
427 // 0 <----+----- 3
428 // | v ^
429 // +----> 2 -----+
431 // Two sets of nested loops:
432 // 0 -> 1 -> 3 -> 0
433 // 1 -> 3 -> 1
434 // 0 -> 2 -> 3 -> 0
435 // 2 -> 3 -> 2
436 Graph input = {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0, 1, 2}}}};
437 PrunedGraph output = breakCycles(input);
438 ASSERT_FALSE(output.isCyclic()) << output.serialize();
439 EXPECT_EQ(output.serialize(), trimComments(R"(nodes
440 ; root 0
441 0: n0<0>
442 1: n1<1>
443 2: n3<2>
444 3: n2
445 4: n3<2>
446 5: n1<1>
447 6: n2<3>
448 7: n3
449 8: n2<3>
450 9: n0<0>
451 ; root 1
452 10: n1<4>
453 11: n3<5>
454 12: n2
455 13: n3<5>
456 14: n1<4>
457 ; root 2
458 15: n2<6>
459 16: n3
460 17: n2<6>
461 ; root 3
462 18: n3
463 edges
464 ; root 0
465 3: 2
466 4: 0 1 3
467 5: 4
468 7: 0 5 6
469 8: 7
470 9: 5 8
471 ; root 1
472 12: 11
473 13: 9 10 12
474 14: 13
475 ; root 2
476 16: 9 14 15
477 17: 16
478 ; root 3
479 18: 9 14 17
480 )"));