Make “sublibrary” standard terminology in docs
[cabal.git] / Cabal / src / Distribution / Utils / UnionFind.hs
blobb22f07c0e43d2b4944fff4e10da44ff32baf91a6
1 {-# LANGUAGE NondecreasingIndentation #-}
3 -- | A simple mutable union-find data structure.
4 --
5 -- It is used in a unification algorithm for backpack mix-in linking.
6 --
-- This implementation is based on the one in \"The Essence of ML Type
-- Inference\". (N.B. the union-find package is also based on it.)
9 module Distribution.Utils.UnionFind
10 ( Point
11 , fresh
12 , find
13 , union
14 , equivalent
15 ) where
17 import Control.Monad
18 import Control.Monad.ST
19 import Data.STRef
-- | A unifiable variable. Equivalently, a handle on an equivalence
-- class that has a distinguished representative.
--
-- A 'Point' is a mutable reference to a 'Link'. Two 'Point's compare
-- equal exactly when they wrap the same 'STRef'.
newtype Point s a = Point (STRef s (Link s a)) deriving Eq
-- | Overwrite the 'Link' stored in a 'Point' in place.
writePoint :: Point s a -> Link s a -> ST s ()
writePoint (Point ref) newLink = writeSTRef ref newLink
-- | Fetch the 'Link' currently stored in a 'Point'.
readPoint :: Point s a -> ST s (Link s a)
readPoint (Point ref) = readSTRef ref
-- | What a 'Point' holds. A 'Point' is either the representative (root)
-- of its equivalence class, in which case it carries the class's
-- bookkeeping ('Info'), or it defers to another 'Point' of the same
-- class ('Link').
data Link s a
  = Info
      -- Weight of the class. 'fresh' starts it at 1 and 'union' adds the
      -- weights of the merged roots. NB: it is too bad we can't say
      -- STRef Int#; the weights remain boxed.
      {-# UNPACK #-} !(STRef s Int)
      -- Descriptor (the value 'find' returns) for the whole class.
      {-# UNPACK #-} !(STRef s a)
  | Link {-# UNPACK #-} !(Point s a)
-- | Make a new singleton equivalence class whose descriptor is the
-- given value. The new 'Point' is its own representative and starts
-- with weight 1.
fresh :: a -> ST s (Point s a)
fresh desc = do
  sizeRef <- newSTRef 1
  descRef <- newSTRef desc
  rootRef <- newSTRef (Info sizeRef descRef)
  return (Point rootRef)
-- | Locate the representative (root) 'Point' of the given point's class.
--
-- This also compresses the path: once the recursion returns, every
-- 'Point' walked through links straight to the root.
repr :: Point s a -> ST s (Point s a)
repr point = do
  current <- readPoint point
  case current of
    Info _ _ -> return point
    Link parent -> do
      root <- repr parent
      -- If parent was not the root itself, the recursive call has
      -- already re-pointed it at the root; copy that link here too.
      when (root /= parent) $
        readPoint parent >>= writePoint point
      return root
-- | Read the descriptor of the equivalence class that a 'Point'
-- belongs to.
find :: Point s a -> ST s a
find point = do
  -- Fast paths: the point is the root, or links directly to it.
  -- Only longer chains fall back to 'repr', which also compresses them.
  current <- readPoint point
  case current of
    Info _ descRef -> readSTRef descRef
    Link parent -> do
      parentLink <- readPoint parent
      case parentLink of
        Info _ descRef -> readSTRef descRef
        Link _ -> repr point >>= find
-- | Unify two equivalence classes, so that they share a canonical
-- element. Keeps the descriptor of @refpoint2@.
--
-- This is union by weight. The root of the lighter class is linked
-- under the root of the heavier one; ties go to @refpoint1@'s root.
-- The surviving root's weight becomes the sum of both weights. Every
-- class starts with weight 1 (see 'fresh'), so a root's weight is the
-- number of points in its class.
union :: Point s a -> Point s a -> ST s ()
union refpoint1 refpoint2 = do
  point1 <- repr refpoint1
  point2 <- repr refpoint2
  when (point1 /= point2) $ do
    l1 <- readPoint point1
    l2 <- readPoint point2
    case (l1, l2) of
      (Info wref1 dref1, Info wref2 dref2) -> do
        weight1 <- readSTRef wref1
        weight2 <- readSTRef wref2
        -- Weights are written with '$!' so the refs hold evaluated Ints
        -- rather than lazily built 'weight1 + weight2' thunks.
        if weight1 >= weight2
          then do
            writePoint point2 (Link point1)
            writeSTRef wref1 $! weight1 + weight2
            -- point1 survives as the root, so copy point2's descriptor
            -- into it to honour the "keeps point2's descriptor" contract.
            writeSTRef dref1 =<< readSTRef dref2
          else do
            writePoint point1 (Link point2)
            writeSTRef wref2 $! weight1 + weight2
      -- 'repr' only ever returns a point holding 'Info', so this is
      -- unreachable unless that invariant has been broken.
      _ -> error "UnionFind.union: repr invariant broken"
-- | Decide whether two points belong to the same equivalence class,
-- i.e. whether they share a representative. Finding the
-- representatives also compresses both paths.
equivalent :: Point s a -> Point s a -> ST s Bool
equivalent point1 point2 = do
  root1 <- repr point1
  root2 <- repr point2
  return (root1 == root2)