1 {-# LANGUAGE NondecreasingIndentation #-}
-- | A simple mutable union-find data structure.
--
-- It is used in a unification algorithm for Backpack mix-in linking.
--
-- This implementation is based on the one in \"The Essence of ML Type
-- Inference\". (N.B. the union-find package is also based on this.)
9 module Distribution
.Utils
.UnionFind
18 import Control
.Monad
.ST
-- | A variable which can be unified; alternately, this can be thought
-- of as an equivalence class with a distinguished representative.
--
-- Equality is identity of the underlying 'STRef' (pointer equality).
-- Root detection in 'repr', 'union' and 'equivalent' relies on this
-- when comparing points.
newtype Point s a = Point (STRef s (Link s a))
  deriving (Eq)
-- | Overwrite the 'Link' stored in a 'Point'.
writePoint :: Point s a -> Link s a -> ST s ()
writePoint (Point ref) newLink = writeSTRef ref newLink
-- | Fetch the 'Link' currently stored in a 'Point'.
readPoint :: Point s a -> ST s (Link s a)
readPoint (Point ref) = readSTRef ref
-- | The internal data structure for a 'Point'. It either marks the point
-- as the root (representative) of its equivalence class, or links it to
-- another 'Point' of the same class.
data Link s a
  = -- | A root: a reference to the class weight (compared and summed by
    -- 'union') and a reference to the class descriptor (read by 'find').
    -- NB: it is too bad we can't say STRef Int#; the weights remain boxed
    Info {-# UNPACK #-} !(STRef s Int) {-# UNPACK #-} !(STRef s a)
  | -- | A non-root: the next 'Point' on the way to the root.
    Link {-# UNPACK #-} !(Point s a)
-- | Create a fresh equivalence class with one element, whose descriptor
-- is the given value. The new point is the root of its own class.
fresh :: a -> ST s (Point s a)
fresh desc = do
  -- A singleton class starts with weight 1; 'union' adds weights on merge.
  weightRef <- newSTRef 1
  descRef <- newSTRef desc
  Point <$> newSTRef (Info weightRef descRef)
-- | Find the root of the equivalence class containing a 'Point' and
-- return it.
--
-- Path compression: on the way back from the recursion, every non-root
-- point visited is rewritten to link directly to the root.
repr :: Point s a -> ST s (Point s a)
repr point = do
  link <- readPoint point
  case link of
    Info _ _ -> return point
    Link point' -> do
      point'' <- repr point'
      -- If point' is itself the root, point already links to it and
      -- there is nothing to compress. Otherwise repoint straight at the
      -- root. (The recursive call has already left point' holding
      -- 'Link point''', so there is no need to read it back.)
      when (point'' /= point') $
        writePoint point (Link point'')
      return point''
-- | Return the descriptor of the equivalence class containing a 'Point'.
find :: Point s a -> ST s a
find point = do
  -- Optimize length 0 and 1 case at expense of general case: a root, or
  -- a point whose parent is the root, is answered without calling 'repr'.
  link <- readPoint point
  case link of
    Info _ descRef -> readSTRef descRef
    Link parent -> do
      parentLink <- readPoint parent
      case parentLink of
        Info _ descRef -> readSTRef descRef
        Link _ -> repr point >>= find
-- | Unify two equivalence classes, so that they share a canonical element.
-- The merged class keeps the descriptor of @refpoint2@'s class.
--
-- This is union by weight. The root with the smaller weight is linked
-- under the root with the larger one (ties go to @refpoint1@'s root), and
-- the surviving root's weight becomes the sum of both weights.
-- Does nothing if the two points are already in the same class.
union :: Point s a -> Point s a -> ST s ()
union refpoint1 refpoint2 = do
  point1 <- repr refpoint1
  point2 <- repr refpoint2
  when (point1 /= point2) $ do
    l1 <- readPoint point1
    l2 <- readPoint point2
    case (l1, l2) of
      (Info wref1 dref1, Info wref2 dref2) -> do
        weight1 <- readSTRef wref1
        weight2 <- readSTRef wref2
        -- Should be able to optimize the == case separately
        if weight1 >= weight2
          then do
            -- point1's root survives; copy point2's descriptor into it so
            -- the merged class keeps point2's descriptor.
            writePoint point2 (Link point1)
            -- Force the sum so no thunk is left sitting in the ref.
            writeSTRef wref1 $! weight1 + weight2
            writeSTRef dref1 =<< readSTRef dref2
          else do
            -- point2's root survives and already holds the right descriptor.
            writePoint point1 (Link point2)
            writeSTRef wref2 $! weight1 + weight2
      -- 'repr' returns a point whose link is 'Info', so this is unreachable.
      _ -> error "UnionFind.union: repr invariant broken"
-- | Test whether two points belong to the same equivalence class, i.e.
-- whether they have the same root. Both points are passed through 'repr',
-- so their paths are compressed as a side effect.
equivalent :: Point s a -> Point s a -> ST s Bool
equivalent point1 point2 = do
  root1 <- repr point1
  root2 <- repr point2
  return (root1 == root2)