//===- BalancedPartitioning.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements BalancedPartitioning, a recursive balanced graph
// partitioning algorithm.
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/BalancedPartitioning.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ThreadPool.h"

using namespace llvm;

#define DEBUG_TYPE "balanced-partitioning"

void BPFunctionNode::dump(raw_ostream &OS) const {
  OS << formatv("{{ID={0} Utilities={{{1:$[,]}} Bucket={2}}", Id,
                make_range(UtilityNodes.begin(), UtilityNodes.end()), Bucket);
}

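// BPThreadPool keeps NumActiveThreads, a count of scheduled tasks that may
// still spawn nested subtasks. async() increments it before scheduling and
// decrements it when the task body returns; once the count reaches zero, no
// new tasks can appear, so wait() can safely delegate to ThreadPool::wait().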
template <typename Func>
void BalancedPartitioning::BPThreadPool::async(Func &&F) {
#if LLVM_ENABLE_THREADS
  // This new thread could spawn more threads, so mark it as active
  ++NumActiveThreads;
  TheThreadPool.async([=]() {
    // Run the task
    F();

    // This thread will no longer spawn new threads, so mark it as inactive
    if (--NumActiveThreads == 0) {
      // There are no more active threads, so mark as finished and notify
      {
        std::unique_lock<std::mutex> lock(mtx);
        assert(!IsFinishedSpawning);
        IsFinishedSpawning = true;
      }
      cv.notify_one();
    }
  });
#else
  llvm_unreachable("threads are disabled");
#endif
}

void BalancedPartitioning::BPThreadPool::wait() {
#if LLVM_ENABLE_THREADS
  // TODO: We could remove the mutex and condition variable and use
  // std::atomic::wait() instead, but that isn't available until C++20
  {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]() { return IsFinishedSpawning; });
  }
  assert(IsFinishedSpawning && NumActiveThreads == 0);

  // Now we can call ThreadPool::wait() since all tasks have been submitted
  TheThreadPool.wait();
#else
  llvm_unreachable("threads are disabled");
#endif
}

BalancedPartitioning::BalancedPartitioning(
    const BalancedPartitioningConfig &Config)
    : Config(Config) {
  // Pre-computing log2 values
  Log2Cache[0] = 0.f;
  for (unsigned I = 1; I < LOG_CACHE_SIZE; I++)
    Log2Cache[I] = std::log2(I);
}

void BalancedPartitioning::run(std::vector<BPFunctionNode> &Nodes) const {
  LLVM_DEBUG(
      dbgs() << format(
          "Partitioning %d nodes using depth %d and %d iterations per split\n",
          Nodes.size(), Config.SplitDepth, Config.IterationsPerSplit));
  std::optional<BPThreadPool> TP;
#if LLVM_ENABLE_THREADS
  ThreadPool TheThreadPool;
  if (Config.TaskSplitDepth > 1)
    TP.emplace(TheThreadPool);
#endif

  // Record the input order
  for (unsigned I = 0; I < Nodes.size(); I++)
    Nodes[I].InputOrderIndex = I;

  auto NodesRange = llvm::make_range(Nodes.begin(), Nodes.end());
  auto BisectTask = [=, &TP]() {
    bisect(NodesRange, /*RecDepth=*/0, /*RootBucket=*/1, /*Offset=*/0, TP);
  };
  if (TP) {
    TP->async(std::move(BisectTask));
    TP->wait();
  } else {
    BisectTask();
  }

  llvm::stable_sort(NodesRange, [](const auto &L, const auto &R) {
    return L.Bucket < R.Bucket;
  });

  LLVM_DEBUG(dbgs() << "Balanced partitioning completed\n");
}

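// Buckets form an implicit binary tree: the initial call uses RootBucket = 1,
// and bisecting bucket B hands its halves the buckets 2 * B (left) and
// 2 * B + 1 (right). For example, the first split produces buckets 2 and 3,
// the next level 4..7, and at the bottom of the recursion the leaves are
// relabeled to consecutive final buckets starting from Offset.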
void BalancedPartitioning::bisect(const FunctionNodeRange Nodes,
                                  unsigned RecDepth, unsigned RootBucket,
                                  unsigned Offset,
                                  std::optional<BPThreadPool> &TP) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  if (NumNodes <= 1 || RecDepth >= Config.SplitDepth) {
    // We've reached the lowest level of the recursion tree. Fall back to the
    // original order and assign to buckets.
    llvm::sort(Nodes, [](const auto &L, const auto &R) {
      return L.InputOrderIndex < R.InputOrderIndex;
    });
    for (auto &N : Nodes)
      N.Bucket = Offset++;
    return;
  }

  LLVM_DEBUG(dbgs() << format("Bisect with %d nodes and root bucket %d\n",
                              NumNodes, RootBucket));

  std::mt19937 RNG(RootBucket);

  unsigned LeftBucket = 2 * RootBucket;
  unsigned RightBucket = 2 * RootBucket + 1;

  // Split into two and assign to the left and right buckets
  split(Nodes, LeftBucket);

  runIterations(Nodes, RecDepth, LeftBucket, RightBucket, RNG);

  // Split nodes wrt the resulting buckets
  auto NodesMid =
      llvm::partition(Nodes, [&](auto &N) { return N.Bucket == LeftBucket; });
  unsigned MidOffset = Offset + std::distance(Nodes.begin(), NodesMid);

  auto LeftNodes = llvm::make_range(Nodes.begin(), NodesMid);
  auto RightNodes = llvm::make_range(NodesMid, Nodes.end());

  auto LeftRecTask = [=, &TP]() {
    bisect(LeftNodes, RecDepth + 1, LeftBucket, Offset, TP);
  };
  auto RightRecTask = [=, &TP]() {
    bisect(RightNodes, RecDepth + 1, RightBucket, MidOffset, TP);
  };

  if (TP && RecDepth < Config.TaskSplitDepth && NumNodes >= 4) {
    TP->async(std::move(LeftRecTask));
    TP->async(std::move(RightRecTask));
  } else {
    LeftRecTask();
    RightRecTask();
  }
}

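// Before the local-search iterations, utility nodes that cannot affect the
// outcome are pruned: a utility node with a single edge simply follows its
// one function, and a utility node connected to every function has left/right
// counts determined by the bucket sizes alone. The survivors are renumbered
// densely so they can index directly into Signatures.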
void BalancedPartitioning::runIterations(const FunctionNodeRange Nodes,
                                         unsigned RecDepth, unsigned LeftBucket,
                                         unsigned RightBucket,
                                         std::mt19937 &RNG) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  DenseMap<BPFunctionNode::UtilityNodeT, unsigned> UtilityNodeIndex;
  for (auto &N : Nodes)
    for (auto &UN : N.UtilityNodes)
      ++UtilityNodeIndex[UN];
  // Remove utility nodes if they have just one edge or are connected to all
  // functions
  for (auto &N : Nodes)
    llvm::erase_if(N.UtilityNodes, [&](auto &UN) {
      return UtilityNodeIndex[UN] == 1 || UtilityNodeIndex[UN] == NumNodes;
    });

  // Renumber utility nodes so they can be used to index into Signatures
  UtilityNodeIndex.clear();
  for (auto &N : Nodes)
    for (auto &UN : N.UtilityNodes)
      UN = UtilityNodeIndex.insert({UN, UtilityNodeIndex.size()}).first->second;

  // Initialize signatures
  SignaturesT Signatures(/*Size=*/UtilityNodeIndex.size());
  for (auto &N : Nodes) {
    for (auto &UN : N.UtilityNodes) {
      assert(UN < Signatures.size());
      if (N.Bucket == LeftBucket) {
        Signatures[UN].LeftCount++;
      } else {
        Signatures[UN].RightCount++;
      }
    }
  }

  for (unsigned I = 0; I < Config.IterationsPerSplit; I++) {
    unsigned NumMovedNodes =
        runIteration(Nodes, LeftBucket, RightBucket, Signatures, RNG);
    if (NumMovedNodes == 0)
      break;
  }
}

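// A single local-search pass: gains are computed for all nodes, the best
// left-to-right candidates are paired greedily with the best right-to-left
// candidates, and pairs are exchanged while the combined gain stays positive.
// Swapping in pairs keeps the two buckets the same size throughout.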
unsigned BalancedPartitioning::runIteration(const FunctionNodeRange Nodes,
                                            unsigned LeftBucket,
                                            unsigned RightBucket,
                                            SignaturesT &Signatures,
                                            std::mt19937 &RNG) const {
  // Init signature cost caches
  for (auto &Signature : Signatures) {
    if (Signature.CachedGainIsValid)
      continue;
    unsigned L = Signature.LeftCount;
    unsigned R = Signature.RightCount;
    assert((L > 0 || R > 0) && "incorrect signature");
    float Cost = logCost(L, R);
    Signature.CachedGainLR = 0.f;
    Signature.CachedGainRL = 0.f;
    if (L > 0)
      Signature.CachedGainLR = Cost - logCost(L - 1, R + 1);
    if (R > 0)
      Signature.CachedGainRL = Cost - logCost(L + 1, R - 1);
    Signature.CachedGainIsValid = true;
  }

  // Compute move gains
  typedef std::pair<float, BPFunctionNode *> GainPair;
  std::vector<GainPair> Gains;
  for (auto &N : Nodes) {
    bool FromLeftToRight = (N.Bucket == LeftBucket);
    float Gain = moveGain(N, FromLeftToRight, Signatures);
    Gains.push_back(std::make_pair(Gain, &N));
  }

  // Collect left and right gains
  auto LeftEnd = llvm::partition(
      Gains, [&](const auto &GP) { return GP.second->Bucket == LeftBucket; });
  auto LeftRange = llvm::make_range(Gains.begin(), LeftEnd);
  auto RightRange = llvm::make_range(LeftEnd, Gains.end());

  // Sort gains in descending order
  auto LargerGain = [](const auto &L, const auto &R) {
    return L.first > R.first;
  };
  llvm::stable_sort(LeftRange, LargerGain);
  llvm::stable_sort(RightRange, LargerGain);

  unsigned NumMovedDataVertices = 0;
  for (auto [LeftPair, RightPair] : llvm::zip(LeftRange, RightRange)) {
    auto &[LeftGain, LeftNode] = LeftPair;
    auto &[RightGain, RightNode] = RightPair;
    // Stop when the gain is no longer beneficial
    if (LeftGain + RightGain <= 0.f)
      break;
    // Try to exchange the nodes between buckets
    if (moveFunctionNode(*LeftNode, LeftBucket, RightBucket, Signatures, RNG))
      ++NumMovedDataVertices;
    if (moveFunctionNode(*RightNode, LeftBucket, RightBucket, Signatures, RNG))
      ++NumMovedDataVertices;
  }
  return NumMovedDataVertices;
}

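// Applies a single node move. With probability Config.SkipProbability the
// move is rejected outright, which injects randomness and helps the search
// escape local optima; otherwise the node's bucket is flipped, the signatures
// of its utility nodes are updated, and their gain caches are invalidated.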
bool BalancedPartitioning::moveFunctionNode(BPFunctionNode &N,
                                            unsigned LeftBucket,
                                            unsigned RightBucket,
                                            SignaturesT &Signatures,
                                            std::mt19937 &RNG) const {
  // Sometimes we skip the move. This helps to escape local optima
  if (std::uniform_real_distribution<float>(0.f, 1.f)(RNG) <=
      Config.SkipProbability)
    return false;

  bool FromLeftToRight = (N.Bucket == LeftBucket);
  // Update the current bucket
  N.Bucket = (FromLeftToRight ? RightBucket : LeftBucket);

  // Update signatures and invalidate gain cache
  if (FromLeftToRight) {
    for (auto &UN : N.UtilityNodes) {
      auto &Signature = Signatures[UN];
      Signature.LeftCount--;
      Signature.RightCount++;
      Signature.CachedGainIsValid = false;
    }
  } else {
    for (auto &UN : N.UtilityNodes) {
      auto &Signature = Signatures[UN];
      Signature.LeftCount++;
      Signature.RightCount--;
      Signature.CachedGainIsValid = false;
    }
  }
  return true;
}

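// The initial split is deterministic: nodes are partitioned around the median
// InputOrderIndex, so the first half of the input order seeds the left bucket
// and the second half the right. std::nth_element does this in linear time
// on average without fully sorting the range.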
void BalancedPartitioning::split(const FunctionNodeRange Nodes,
                                 unsigned StartBucket) const {
  unsigned NumNodes = std::distance(Nodes.begin(), Nodes.end());
  auto NodesMid = Nodes.begin() + (NumNodes + 1) / 2;

  std::nth_element(Nodes.begin(), NodesMid, Nodes.end(), [](auto &L, auto &R) {
    return L.InputOrderIndex < R.InputOrderIndex;
  });

  for (auto &N : llvm::make_range(Nodes.begin(), NodesMid))
    N.Bucket = StartBucket;
  for (auto &N : llvm::make_range(NodesMid, Nodes.end()))
    N.Bucket = StartBucket + 1;
}

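// The gain of moving node N is the sum of the cached per-utility-node gains:
//   gain(N, L->R) = sum over UN in N.UtilityNodes of CachedGainLR[UN]
// (and symmetrically CachedGainRL for right-to-left moves), where each cached
// value is the cost delta computed in runIteration().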
float BalancedPartitioning::moveGain(const BPFunctionNode &N,
                                     bool FromLeftToRight,
                                     const SignaturesT &Signatures) {
  float Gain = 0.f;
  for (auto &UN : N.UtilityNodes)
    Gain += (FromLeftToRight ? Signatures[UN].CachedGainLR
                             : Signatures[UN].CachedGainRL);
  return Gain;
}

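// The cost of a utility node with X functions in the left bucket and Y in the
// right is
//   logCost(X, Y) = -(X * log2(X + 1) + Y * log2(Y + 1)),
// so the cached move gains are logCost(L, R) - logCost(L - 1, R + 1) and
// logCost(L, R) - logCost(L + 1, R - 1). Lower cost is better: for example,
// logCost(2, 0) = -2 * log2(3) ~= -3.17 beats logCost(1, 1) = -2, so the cost
// intuitively rewards concentrating a utility node's functions in one bucket.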
float BalancedPartitioning::logCost(unsigned X, unsigned Y) const {
  return -(X * log2Cached(X + 1) + Y * log2Cached(Y + 1));
}

float BalancedPartitioning::log2Cached(unsigned i) const {
  return (i < LOG_CACHE_SIZE) ? Log2Cache[i] : std::log2(i);
}