//===- llvm/Analysis/MaximumSpanningTree.h - Interface ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This module provides a way to calculate a maximum spanning tree for a
// given set of weighted edges. The type parameter T is the type of a node.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_LIB_TRANSFORMS_INSTRUMENTATION_MAXIMUMSPANNINGTREE_H
#define LLVM_LIB_TRANSFORMS_INSTRUMENTATION_MAXIMUMSPANNINGTREE_H

#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/IR/BasicBlock.h"
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>
24 /// MaximumSpanningTree - A MST implementation.
25 /// The type parameter T determines the type of the nodes of the graph.
27 class MaximumSpanningTree
{
29 typedef std::pair
<const T
*, const T
*> Edge
;
30 typedef std::pair
<Edge
, double> EdgeWeight
;
31 typedef std::vector
<EdgeWeight
> EdgeWeights
;
33 typedef std::vector
<Edge
> MaxSpanTree
;
38 // A comparing class for comparing weighted edges.
39 struct EdgeWeightCompare
{
40 static bool getBlockSize(const T
*X
) {
41 const BasicBlock
*BB
= dyn_cast_or_null
<BasicBlock
>(X
);
42 return BB
? BB
->size() : 0;
45 bool operator()(EdgeWeight X
, EdgeWeight Y
) const {
46 if (X
.second
> Y
.second
) return true;
47 if (X
.second
< Y
.second
) return false;
49 // Equal edge weights: break ties by comparing block sizes.
50 size_t XSizeA
= getBlockSize(X
.first
.first
);
51 size_t YSizeA
= getBlockSize(Y
.first
.first
);
52 if (XSizeA
> YSizeA
) return true;
53 if (XSizeA
< YSizeA
) return false;
55 size_t XSizeB
= getBlockSize(X
.first
.second
);
56 size_t YSizeB
= getBlockSize(Y
.first
.second
);
57 if (XSizeB
> YSizeB
) return true;
58 if (XSizeB
< YSizeB
) return false;
65 static char ID
; // Class identification, replacement for typeinfo
67 /// MaximumSpanningTree() - Takes a vector of weighted edges and returns a
69 MaximumSpanningTree(EdgeWeights
&EdgeVector
) {
71 std::stable_sort(EdgeVector
.begin(), EdgeVector
.end(), EdgeWeightCompare());
73 // Create spanning tree, Forest contains a special data structure
74 // that makes checking if two nodes are already in a common (sub-)tree
76 EquivalenceClasses
<const T
*> Forest
;
77 for (typename
EdgeWeights::iterator EWi
= EdgeVector
.begin(),
78 EWe
= EdgeVector
.end(); EWi
!= EWe
; ++EWi
) {
79 Edge e
= (*EWi
).first
;
81 Forest
.insert(e
.first
);
82 Forest
.insert(e
.second
);
85 // Iterate over the sorted edges, biggest first.
86 for (typename
EdgeWeights::iterator EWi
= EdgeVector
.begin(),
87 EWe
= EdgeVector
.end(); EWi
!= EWe
; ++EWi
) {
88 Edge e
= (*EWi
).first
;
90 if (Forest
.findLeader(e
.first
) != Forest
.findLeader(e
.second
)) {
91 Forest
.unionSets(e
.first
, e
.second
);
92 // So we know now that the edge is not already in a subtree, so we push
93 // the edge to the MST.
99 typename
MaxSpanTree::iterator
begin() {
103 typename
MaxSpanTree::iterator
end() {
108 } // End llvm namespace
110 #endif // LLVM_LIB_TRANSFORMS_INSTRUMENTATION_MAXIMUMSPANNINGTREE_H