1 //===- CyclicReplacerCacheTest.cpp ----------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Support/CyclicReplacerCache.h"
10 #include "mlir/Support/LLVM.h"
11 #include "llvm/ADT/SetVector.h"
12 #include "gmock/gmock.h"
// Baseline case without recursion: the replacement callback converts its int
// argument to bool and never calls back into the replacer.
18 TEST(CachedCyclicReplacerTest
, testNoRecursion
) {
19 CachedCyclicReplacer
<int, bool> replacer(
// Replacement callback: the result is the input cast to bool.
20 /*replacer=*/[](int n
) { return static_cast<bool>(n
); },
// Cycle breaker: answers std::nullopt for every input.
21 /*cycleBreaker=*/[](int n
) { return std::nullopt
; });
// A non-zero input is expected to map to true, zero to false.
23 EXPECT_EQ(replacer(3), true);
24 EXPECT_EQ(replacer(0), false);
// Recursion through the replacer itself: the replacement callback re-enters
// the replacer, and both lookups below are expected to return the value
// produced by the cycle breaker.
27 TEST(CachedCyclicReplacerTest
, testInPlaceRecursionPruneAnywhere
) {
28 // Replacer cycles through ints 0 -> 1 -> 2 -> 0 -> ...
// Held in an optional; the callbacks below capture by reference and
// dereference it to call the replacer.
29 std::optional
<CachedCyclicReplacer
<int, int>> replacer
;
// Replacement callback: delegates to the replacer for (n + 1) % 3.
31 /*replacer=*/[&](int n
) { return (*replacer
)((n
+ 1) % 3); },
// Cycle breaker: answers -1 for every input.
32 /*cycleBreaker=*/[&](int n
) { return -1; });
// Starting from 0 and from 2, the expected result is the cycle breaker's -1.
35 EXPECT_EQ((*replacer
)(0), -1);
37 EXPECT_EQ((*replacer
)(2), -1);
40 //===----------------------------------------------------------------------===//
41 // CachedCyclicReplacer: ChainRecursion
42 //===----------------------------------------------------------------------===//
44 /// This set of tests uses a replacer function that maps ints into vectors of ints.
47 /// The replacement result for input `n` is the replacement result of `(n+1)%3`
48 /// appended with an element `42`. Theoretically, this will produce an
49 /// infinitely long vector. The cycle-breaker function prunes this infinite
50 /// recursion in the replacer logic by returning an empty vector upon the first
51 /// re-occurrence of an input value.
/// Fixture for the chain-recursion tests. It holds a CachedCyclicReplacer
/// from int to std::vector<int>, an optional `baseCase` consulted by the
/// cycle breaker, and a `getChain` helper that builds expected results.
53 class CachedCyclicReplacerChainRecursionPruningTest
: public ::testing::Test
{
56 // This will create a chain of infinite length without recursion pruning.
57 CachedCyclicReplacerChainRecursionPruningTest()
// The replacement for `n` starts from the replacement of (n + 1) % 3.
61 std::vector
<int> result
= replacer((n
+ 1) % 3);
// Cycle breaker: `baseCase.value_or(n) == n` holds when `baseCase` is unset
// or equal to `n`; in that case an empty vector is produced.
65 [&](int n
) -> std::optional
<std::vector
<int>> {
66 return baseCase
.value_or(n
) == n
67 ? std::make_optional(std::vector
<int>{})
/// Builds a vector of `N` elements, each equal to 42.
71 std::vector
<int> getChain(unsigned N
) { return std::vector
<int>(N
, 42); };
/// The replacer under test: maps an int to a vector of ints.
73 CachedCyclicReplacer
<int, std::vector
<int>> replacer
;
/// Defaults to std::nullopt, in which case the cycle-breaker comparison
/// against `n` is satisfied for every input.
75 std::optional
<int> baseCase
= std::nullopt
;
// Replaces 0, 1 and 2 in turn on the same replacer. After each call the test
// checks the returned chain (N copies of 42, see getChain) and the expected
// value of `invokeCount`.
79 TEST_F(CachedCyclicReplacerChainRecursionPruningTest
, testPruneAnywhere0
) {
80 // Starting at 0. Cycle length is 3.
81 EXPECT_EQ(replacer(0), getChain(3));
82 EXPECT_EQ(invokeCount
, 3);
84 // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
86 EXPECT_EQ(replacer(1), getChain(5));
87 EXPECT_EQ(invokeCount
, 2);
89 // Starting at 2. Cycle length is 4. Entire result is cached.
// An `invokeCount` of 0 is expected here.
91 EXPECT_EQ(replacer(2), getChain(4));
92 EXPECT_EQ(invokeCount
, 0);
// Single lookup starting at 1: expects a chain of three 42s and an
// `invokeCount` of 3.
95 TEST_F(CachedCyclicReplacerChainRecursionPruningTest
, testPruneAnywhere1
) {
96 // Starting at 1. Cycle length is 3.
97 EXPECT_EQ(replacer(1), getChain(3));
98 EXPECT_EQ(invokeCount
, 3);
// Single lookup starting at 0: expects a chain of three 42s and an
// `invokeCount` of 3.
// NOTE(review): the test name suggests a specific `baseCase` is selected
// before the lookup — confirm against the fixture setup.
101 TEST_F(CachedCyclicReplacerChainRecursionPruningTest
, testPruneSpecific0
) {
104 // Starting at 0. Cycle length is 3.
105 EXPECT_EQ(replacer(0), getChain(3));
106 EXPECT_EQ(invokeCount
, 3);
// Looks up 1 first and then 0 on the same replacer, checking the returned
// chain and the expected `invokeCount` after each call.
// NOTE(review): the test name suggests a specific `baseCase` is selected
// before the lookups — confirm against the fixture setup.
109 TEST_F(CachedCyclicReplacerChainRecursionPruningTest
, testPruneSpecific1
) {
112 // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
113 EXPECT_EQ(replacer(1), getChain(5));
114 EXPECT_EQ(invokeCount
, 5);
116 // Starting at 0. Cycle length is 3. Entire result is cached.
// An `invokeCount` of 0 is expected here.
118 EXPECT_EQ(replacer(0), getChain(3));
119 EXPECT_EQ(invokeCount
, 0);
122 //===----------------------------------------------------------------------===//
123 // CachedCyclicReplacer: GraphReplacement
124 //===----------------------------------------------------------------------===//
126 /// This set of tests uses a replacer function that maps from cyclic graphs to
127 /// trees, pruning out cycles in the process.
129 /// It consists of two helper classes:
131 /// - `Graph`: a directed graph where nodes are non-negative integers.
133 /// - `PrunedGraph`: a Graph where edges that used to cause cycles are now
134 /// represented with an indirection (a recursionId).
/// Fixture for the graph-replacement tests; the graph helpers used by those
/// tests are declared inside it.
136 class CachedCyclicReplacerGraphReplacement
: public ::testing::Test
{
138 /// A directed graph where nodes are non-negative integers.
/// Node identifiers are stored as signed 64-bit integers.
140 using Node
= int64_t;
142 /// Use ordered containers for deterministic output.
143 /// Nodes without outgoing edges are considered nonexistent.
/// Adjacency map: each key is a source node and each value is the set of
/// its sink nodes.
144 std::map
<Node
, std::set
<Node
>> edges
;
/// Records a directed edge from `src` to `sink`. Sinks are kept in a set, so
/// inserting the same edge twice leaves a single entry.
146 void addEdge(Node src
, Node sink
) { edges
[src
].insert(sink
); }
/// Cycle check over `edges`. Every key of the adjacency map is used as a
/// starting root, and the exploration from a root uses an explicit work stack
/// rather than recursion.
148 bool isCyclic() const {
// Nodes already reached by some exploration.
149 DenseSet
<Node
> visited
;
150 for (Node root
: llvm::make_first_range(edges
)) {
151 if (visited
.contains(root
))
// Nodes on the path currently being explored.
154 SetVector
<Node
> path
;
// Pending entries; the most recently pushed one is processed first.
155 SmallVector
<Node
> workstack
;
156 workstack
.push_back(root
);
157 while (!workstack
.empty()) {
158 Node curr
= workstack
.back();
159 workstack
.pop_back();
162 // A negative node signals the end of processing all of this node's
163 // children. Remove self from path.
164 assert(path
.back() == -curr
&& "internal inconsistency");
// Test whether `curr` already lies on the current path.
169 if (path
.contains(curr
))
172 visited
.insert(curr
);
// Look up the outgoing edges of `curr`; the check below covers both a
// missing entry and an empty sink set.
173 auto edgesIter
= edges
.find(curr
);
174 if (edgesIter
== edges
.end() || edgesIter
->second
.empty())
178 // Push negative node to signify recursion return.
// NOTE(review): the marker is the arithmetic negation of the node id, so for
// node 0 the marker equals the node itself (-0 == 0) — confirm that node 0
// never reaches this point with outgoing edges.
179 workstack
.push_back(-curr
);
// The sinks are pushed after the marker, so they are popped before it.
180 workstack
.insert(workstack
.end(), edgesIter
->second
.begin(),
181 edgesIter
->second
.end());
187 /// Deterministic output for testing.
/// Iterates `edges` (an ordered std::map) and streams the sinks of every
/// source into a string stream.
188 std::string
serialize() const {
189 std::ostringstream oss
;
190 for (const auto &[src
, neighbors
] : edges
) {
// Each sink is written with a single leading space.
192 for (Graph::Node neighbor
: neighbors
)
193 oss
<< " " << neighbor
;
200 /// A Graph where edges that used to cause cycles (back-edges) are now
201 /// represented with an indirection (a recursionId).
203 /// In addition to each node having an integer ID, each node also tracks the
204 /// original integer ID it had in the original graph. This way for every
205 /// back-edge, we can represent it as pointing to a new instance of the
206 /// original node. Then we mark the original node and the new instance with
207 /// a new unique recursionId to indicate that they're supposed to be the same
/// node in the original graph.
/// Node ids of the pruned graph reuse the node type of Graph.
210 using Node
= Graph::Node
;
/// The id this node had in the original graph.
212 Graph::Node originalId
;
213 /// A negative recursive index means not recursive. Otherwise nodes with
214 /// the same originalId & recursionId are the same node in the original
/// graph.
219 /// Add a regular non-recursive-self node.
/// `recursionIndex` defaults to -1. The new node takes its id from
/// `nextConnectionId`, which is incremented afterwards, and the pair
/// (originalId, recursionIndex) is stored in `info` under that id.
220 Node
addNode(Graph::Node originalId
, int64_t recursionIndex
= -1) {
221 Node id
= nextConnectionId
++;
222 info
[id
] = {originalId
, recursionIndex
};
225 /// Add a recursive-self-node, i.e. a duplicate of the original node that is
226 /// meant to represent an indirection to it.
/// Returns the new node paired with the recursion id given to it.
/// `nextRecursionId` is incremented as part of the return, so the next call
/// receives a different id.
227 std::pair
<Node
, int64_t> addRecursiveSelfNode(Graph::Node originalId
) {
228 auto node
= addNode(originalId
, nextRecursionId
);
229 return {node
, nextRecursionId
++};
/// Adds an edge from `src` to `sink` by forwarding to the `connections` graph.
231 void addEdge(Node src
, Node sink
) { connections
.addEdge(src
, sink
); }
233 /// Deterministic output for testing.
/// Writes the entries of `info` followed by the serialization of the
/// `connections` graph.
234 std::string
serialize() const {
235 std::ostringstream oss
;
237 for (const auto &[nodeId
, nodeInfo
] : info
) {
// Node entry: the node id, then ": n" and the original id.
238 oss
<< nodeId
<< ": n" << nodeInfo
.originalId
;
// The recursion id is appended in angle brackets only when non-negative.
239 if (nodeInfo
.recursionId
>= 0)
240 oss
<< '<' << nodeInfo
.recursionId
<< '>';
// The edge listing comes from the underlying connections graph.
244 oss
<< connections
.serialize();
/// Delegates the cycle check to the `connections` graph.
248 bool isCyclic() const { return connections
.isCyclic(); }
/// Next recursion id to hand out; starts at 0.
252 int64_t nextRecursionId
= 0;
/// Next node id to hand out; starts at 0.
253 int64_t nextConnectionId
= 0;
254 /// Use ordered map for deterministic output.
/// Per-node information, keyed by node id.
255 std::map
<Graph::Node
, NodeInfo
> info
;
/// Builds a PrunedGraph from `input` with the help of a CyclicReplacerCache.
/// `input` is asserted to be cyclic.
258 PrunedGraph
breakCycles(const Graph
&input
) {
259 assert(input
.isCyclic() && "input graph is not cyclic");
// Recursion id recorded for each input node the cycle breaker was called on.
263 DenseMap
<Graph::Node
, int64_t> recMap
;
// Cycle breaker: adds a recursive-self node for `inNode` to the output and
// remembers the recursion id it was given.
264 auto cycleBreaker
= [&](Graph::Node inNode
) -> std::optional
<Graph::Node
> {
265 auto [node
, recId
] = output
.addRecursiveSelfNode(inNode
);
266 recMap
[inNode
] = recId
;
// The cache used by `replaceNode` below is constructed with the cycle breaker.
270 CyclicReplacerCache
<Graph::Node
, Graph::Node
> cache(cycleBreaker
);
// Declared as std::function so the lambda can pass itself to map_range when
// replacing neighbors.
272 std::function
<Graph::Node(Graph::Node
)> replaceNode
=
273 [&](Graph::Node inNode
) {
// Consult the cache entry for `inNode` before doing any work.
274 auto cacheEntry
= cache
.lookupOrInit(inNode
);
275 if (std::optional
<Graph::Node
> result
= cacheEntry
.get())
278 // Recursively replace its neighbors.
279 SmallVector
<Graph::Node
> neighbors
;
280 if (auto it
= input
.edges
.find(inNode
); it
!= input
.edges
.end())
281 neighbors
= SmallVector
<Graph::Node
>(
282 llvm::map_range(it
->second
, replaceNode
));
284 // Create a new node in the output graph.
// When the cache entry reports a repeat, use the recursion id recorded by
// the cycle breaker; otherwise pass -1.
285 int64_t recursionIndex
=
286 cacheEntry
.wasRepeated() ? recMap
.lookup(inNode
) : -1;
287 Graph::Node result
= output
.addNode(inNode
, recursionIndex
);
// Connect the new node to each of its replaced neighbors.
289 for (Graph::Node neighbor
: neighbors
)
290 output
.addEdge(result
, neighbor
);
// Hand the finished node to the cache entry.
292 cacheEntry
.resolve(result
);
296 /// Translate starting from each node.
297 for (Graph::Node root
: llvm::make_first_range(input
.edges
))
303 /// Helper for serialization tests that allow putting comments in the
304 /// serialized format. Every line that begins with a `;` is considered a
305 /// comment. The entire line, incl. the terminating `\n` is removed.
306 std::string
trimComments(StringRef input
) {
307 std::ostringstream oss
;
// NOTE(review): `isNewLine` starts out false, so a ';' that is the very first
// character of `input` does not satisfy the check below — confirm this is
// intended.
308 bool isNewLine
= false;
309 bool isComment
= false;
// The input is scanned one character at a time.
310 for (char c
: input
) {
311 // Lines beginning with ';' are comments.
312 if (isNewLine
&& c
== ';')
// Breaks the cycles of a three-node loop and compares the serialized output
// against an expected text passed through trimComments.
328 TEST_F(CachedCyclicReplacerGraphReplacement
, testSingleLoop
) {
// Edges: 0 -> 1, 1 -> 2, 2 -> 0.
332 Graph input
= {{{0, {1}}, {1, {2}}, {2, {0}}}};
333 PrunedGraph output
= breakCycles(input
);
// The output must be acyclic; its serialization is attached to the failure
// message.
334 ASSERT_FALSE(output
.isCyclic()) << output
.serialize();
335 EXPECT_EQ(output
.serialize(), trimComments(R
"(nodes
// Breaks the cycles of a graph in which two paths from 0 meet at 3 before
// returning to 0, and compares the serialized output against an expected text
// passed through trimComments.
354 TEST_F(CachedCyclicReplacerGraphReplacement
, testDualLoop
) {
// Edges: 0 -> {1, 2}, 1 -> 3, 2 -> 3, 3 -> 0.
364 Graph input
= {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0}}}};
365 PrunedGraph output
= breakCycles(input
);
// The output must be acyclic; its serialization is attached to the failure
// message.
366 ASSERT_FALSE(output
.isCyclic()) << output
.serialize();
367 EXPECT_EQ(output
.serialize(), trimComments(R
"(nodes
// Breaks the cycles of a graph where node 2 points back to both 0 and 1, and
// compares the serialized output against an expected text passed through
// trimComments.
390 TEST_F(CachedCyclicReplacerGraphReplacement
, testNestedLoops
) {
// Edges: 0 -> 1, 1 -> 2, 2 -> {0, 1}.
398 Graph input
= {{{0, {1}}, {1, {2}}, {2, {0, 1}}}};
399 PrunedGraph output
= breakCycles(input
);
// The output must be acyclic; its serialization is attached to the failure
// message.
400 ASSERT_FALSE(output
.isCyclic()) << output
.serialize();
401 EXPECT_EQ(output
.serialize(), trimComments(R
"(nodes
424 TEST_F(CachedCyclicReplacerGraphReplacement
, testDualNestedLoops
) {
431 // Two sets of nested loops:
436 Graph input
= {{{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0, 1, 2}}}};
437 PrunedGraph output
= breakCycles(input
);
438 ASSERT_FALSE(output
.isCyclic()) << output
.serialize();
439 EXPECT_EQ(output
.serialize(), trimComments(R
"(nodes