1 //===- RootOrderingTest.cpp - unit tests for optimal branching ------------===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions. See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "../lib/Conversion/PDLToPDLInterp/RootOrdering.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"

using namespace mlir;
using namespace mlir::arith;
using namespace mlir::pdl_to_pdl_interp;
//===----------------------------------------------------------------------===//
// Test Fixture
//===----------------------------------------------------------------------===//
25 /// The test fixture for constructing root ordering tests and verifying results.
26 /// This fixture constructs the test values v. The test populates the graph
27 /// with the desired costs and then calls check(), passing the expected optimal
28 /// cost and the list of edges in the preorder traversal of the optimal
30 class RootOrderingTest
: public ::testing::Test
{
33 context
.loadDialect
<ArithDialect
>();
37 /// Creates the test values. These values simply act as vertices / vertex IDs
38 /// in the cost graph, rather than being a part of an IR.
40 OpBuilder
builder(&context
);
41 builder
.setInsertionPointToStart(&block
);
42 for (int i
= 0; i
< 4; ++i
)
43 // Ops will be deleted when `block` is destroyed.
44 v
[i
] = builder
.create
<ConstantIntOp
>(builder
.getUnknownLoc(), i
, 32);
47 /// Checks that optimal branching on graph has the given cost and
48 /// its preorder traversal results in the specified edges.
49 void check(unsigned cost
, const OptimalBranching::EdgeList
&edges
) {
50 OptimalBranching
opt(graph
, v
[0]);
51 EXPECT_EQ(opt
.solve(), cost
);
52 EXPECT_EQ(opt
.preOrderTraversal({v
, v
+ edges
.size()}), edges
);
53 for (std::pair
<Value
, Value
> edge
: edges
)
54 EXPECT_EQ(opt
.getRootOrderingParents().lookup(edge
.first
), edge
.second
);
58 /// The context for creating the values.
61 /// Block holding all the operations.
64 /// Values used in the graph definition. We always use leading `n` values.
67 /// The graph being tested on.
68 RootOrderingGraph graph
;
71 //===----------------------------------------------------------------------===//
72 // Simple 3-node graphs
73 //===----------------------------------------------------------------------===//
75 TEST_F(RootOrderingTest
, simpleA
) {
76 graph
[v
[1]][v
[0]].cost
= {1, 10};
77 graph
[v
[2]][v
[0]].cost
= {1, 11};
78 graph
[v
[1]][v
[2]].cost
= {2, 12};
79 graph
[v
[2]][v
[1]].cost
= {2, 13};
80 check(2, {{v
[0], {}}, {v
[1], v
[0]}, {v
[2], v
[0]}});
83 TEST_F(RootOrderingTest
, simpleB
) {
84 graph
[v
[1]][v
[0]].cost
= {1, 10};
85 graph
[v
[2]][v
[0]].cost
= {2, 11};
86 graph
[v
[1]][v
[2]].cost
= {1, 12};
87 graph
[v
[2]][v
[1]].cost
= {1, 13};
88 check(2, {{v
[0], {}}, {v
[1], v
[0]}, {v
[2], v
[1]}});
91 TEST_F(RootOrderingTest
, simpleC
) {
92 graph
[v
[1]][v
[0]].cost
= {2, 10};
93 graph
[v
[2]][v
[0]].cost
= {2, 11};
94 graph
[v
[1]][v
[2]].cost
= {1, 12};
95 graph
[v
[2]][v
[1]].cost
= {1, 13};
96 check(3, {{v
[0], {}}, {v
[1], v
[0]}, {v
[2], v
[1]}});
99 //===----------------------------------------------------------------------===//
100 // Graph for testing contraction
101 //===----------------------------------------------------------------------===//
103 TEST_F(RootOrderingTest
, contraction
) {
104 graph
[v
[1]][v
[0]].cost
= {10, 0};
105 graph
[v
[2]][v
[0]].cost
= {5, 0};
106 graph
[v
[2]][v
[1]].cost
= {1, 0};
107 graph
[v
[3]][v
[2]].cost
= {2, 0};
108 graph
[v
[1]][v
[3]].cost
= {3, 0};
109 check(10, {{v
[0], {}}, {v
[2], v
[0]}, {v
[3], v
[2]}, {v
[1], v
[3]}});