1 //===- BalancedPartitioningTest.cpp - BalancedPartitioning tests ----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Support/BalancedPartitioning.h"
10 #include "llvm/Testing/Support/SupportHelpers.h"
11 #include "gmock/gmock.h"
12 #include "gtest/gtest.h"
17 using testing::UnorderedElementsAre
;
18 using testing::UnorderedElementsAreArray
;
// PrintTo overload for BPFunctionNode: takes the node by const reference and
// a pointer to the destination std::ostream.
// NOTE(review): presumably this is the hook googletest uses to print
// BPFunctionNode values in failure messages -- confirm.
22 void PrintTo(const BPFunctionNode
&Node
, std::ostream
*OS
) {
// Wrap *OS in a raw_os_ostream named ROS. The statement that writes Node to
// ROS and the closing brace are not visible in this excerpt.
23 raw_os_ostream
ROS(*OS
);
// Test fixture shared by the TEST_F cases in this file. It derives from
// ::testing::Test and owns a BalancedPartitioningConfig plus a
// BalancedPartitioning instance built from it.
27 class BalancedPartitioningTest
: public ::testing::Test
{
// Config is declared before Bp, so it is initialized first and is already
// constructed when the constructor below passes it to Bp.
29 BalancedPartitioningConfig Config
;
30 BalancedPartitioning Bp
;
31 BalancedPartitioningTest() : Bp(Config
) {}
// Static helper: takes a vector of BPFunctionNode by value and returns a
// vector of BPFunctionNode::IDT. Only the declaration of the result vector
// (Ids) is visible in this excerpt.
// NOTE(review): presumably it collects each node's Id in input order --
// confirm against the full file.
33 static std::vector
<BPFunctionNode::IDT
>
34 getIds(std::vector
<BPFunctionNode
> Nodes
) {
35 std::vector
<BPFunctionNode::IDT
> Ids
;
// Basic: builds five BPFunctionNode values and checks the (Id, Bucket) pairs
// of the resulting nodes.
// NOTE(review): the call that runs the partitioner on Nodes and the opening of
// the final expectation are not visible in this excerpt -- confirm against the
// full file.
42 TEST_F(BalancedPartitioningTest
, Basic
) {
// Five nodes constructed as BPFunctionNode(first, {list}); the first
// arguments are 0, 2, 1, 3, 4 in that order.
// NOTE(review): presumably (Id, utility nodes) -- confirm.
43 std::vector
<BPFunctionNode
> Nodes
= {
44 BPFunctionNode(0, {1, 2}), BPFunctionNode(2, {3, 4}),
45 BPFunctionNode(1, {1, 2}), BPFunctionNode(3, {3, 4}),
46 BPFunctionNode(4, {4}),
// NodeIs(Id, Bucket) builds a matcher that requires both the Id field and the
// Bucket field of a BPFunctionNode to match the given values.
51 auto NodeIs
= [](BPFunctionNode::IDT Id
, std::optional
<uint32_t> Bucket
) {
52 return AllOf(Field("Id", &BPFunctionNode::Id
, Id
),
53 Field("Bucket", &BPFunctionNode::Bucket
, Bucket
));
// Expected contents, in any order: for each Id 0..4 the Bucket equals the Id.
57 UnorderedElementsAre(NodeIs(0, 0), NodeIs(1, 1), NodeIs(2, 2),
58 NodeIs(3, 3), NodeIs(4, 4)));
// Large: builds ProblemSize nodes, each given a sample of utility nodes drawn
// from AllUNs, then checks that every node has a Bucket set and that the set
// of node Ids matches the one captured in OrigIds.
// NOTE(review): the RNG declaration, the sample-count argument and the call
// that runs the partitioner are not visible in this excerpt -- confirm against
// the full file.
61 TEST_F(BalancedPartitioningTest
, Large
) {
62 const int ProblemSize
= 1000;
// AllUNs holds one utility node per index 0..ProblemSize-1.
63 std::vector
<BPFunctionNode::UtilityNodeT
> AllUNs
;
64 for (int i
= 0; i
< ProblemSize
; i
++)
65 AllUNs
.emplace_back(i
);
// One node per index i, constructed from i and its own UNs vector.
68 std::vector
<BPFunctionNode
> Nodes
;
69 for (int i
= 0; i
< ProblemSize
; i
++) {
70 std::vector
<BPFunctionNode::UtilityNodeT
> UNs
;
// Draws an int uniformly from [0, AllUNs.size() - 1] using RNG.
// NOTE(review): presumably this value is the sample count passed to
// std::sample below -- the receiving variable is not visible here.
72 std::uniform_int_distribution
<int>(0, AllUNs
.size() - 1)(RNG
);
// Samples elements of AllUNs into UNs via back_inserter.
73 std::sample(AllUNs
.begin(), AllUNs
.end(), std::back_inserter(UNs
),
75 Nodes
.emplace_back(i
, UNs
);
// Capture the Ids of Nodes at this point; they are compared at the end.
78 auto OrigIds
= getIds(Nodes
);
// Every element of Nodes must have a Bucket that is not std::nullopt.
83 Nodes
, Each(Not(Field("Bucket", &BPFunctionNode::Bucket
, std::nullopt
))));
// The Ids now in Nodes must be the same collection as OrigIds, in any order.
84 EXPECT_THAT(getIds(Nodes
), UnorderedElementsAreArray(OrigIds
));
// MoveGain: calls Bp.moveGain on hand-built nodes against a fixed three-entry
// signature table and compares the result with EXPECT_FLOAT_EQ.
// NOTE(review): the expected values of the second and third checks are not
// visible in this excerpt -- confirm against the full file.
87 TEST_F(BalancedPartitioningTest
, MoveGain
) {
// Three signature entries (indices 0..2 per the trailing comments), each with
// five initializers.
// NOTE(review): the meaning of each field is defined by
// BalancedPartitioning::SignaturesT, which is not visible here -- confirm.
88 BalancedPartitioning::SignaturesT Signatures
= {
89 {10, 10, 10.f
, 0.f
, true}, // 0
90 {10, 10, 0.f
, 10.f
, true}, // 1
91 {10, 10, 0.f
, 20.f
, true}, // 2
// A node with an empty utility-node list is expected to yield a gain of 0.
93 EXPECT_FLOAT_EQ(Bp
.moveGain(BPFunctionNode(0, {}), true, Signatures
), 0.f
);
// Node referencing signature indices 0 and 1, with the bool argument true.
94 EXPECT_FLOAT_EQ(Bp
.moveGain(BPFunctionNode(0, {0, 1}), true, Signatures
),
// Node referencing signature indices 1 and 2, with the bool argument false.
96 EXPECT_FLOAT_EQ(Bp
.moveGain(BPFunctionNode(0, {1, 2}), false, Signatures
),
100 } // end namespace llvm