fix sessions and CE oracles
[why3.git] / examples_in_progress / union_find.mlw
bloba5a7fccf584103dcc7b6b012b0f1e10df05e91ed
1 module Intf
3   use int.Int
5   type t = private {
6                   size: int;        (* elements are 0,1,...,size-1 *)
7     ghost mutable repr: int -> int;
8   } invariant {
9     0 <= size /\
10     (forall i. 0 <= i < size -> 0 <= repr i < size) /\
11     (forall i. 0 <= i < size -> repr (repr i) = repr i)
12   } by { size = 0; repr = fun i -> i }
14   val create (n: int) : t
15     requires { 0 <= n }
16     ensures  { result.size = n }
17     ensures  { forall i. 0 <= i < result.size -> result.repr i = i }
19   val find (uf: t) (x: int) : int
20     requires { 0 <= x < uf.size }
21     writes   { uf }
22     ensures  { result = uf.repr x }
23     ensures  { uf.repr = old uf.repr }
25   predicate same (repr: int -> int) (x y: int) =
26     repr x = repr y
28   val union (uf: t) (x y: int) : unit
29     requires { 0 <= x < uf.size }
30     requires { 0 <= y < uf.size }
31     writes   { uf.repr }
32     ensures  { same uf.repr x y }
33     ensures  { forall i j. 0 <= i < uf.size -> 0 <= j < uf.size ->
34                same uf.repr i j <->
35                  same (old uf.repr) i j \/
36                  same (old uf.repr) i x /\ same (old uf.repr) y j \/
37                  same (old uf.repr) i y /\ same (old uf.repr) x j }
39 end
41 module Impl
43   use int.Int
44   use array.Array
46   (* there is a path from x to y of length at most d *)
47   inductive path (size: int) (link: array int) (x d y: int) =
48   | path_zero: forall size link x d.
49                0 <= x < size -> link[x] = x -> 0 <= d ->
50                path size link x d x
51   | path_succ: forall size link.
52                forall x d z. 0 <= x < size -> link[x] <> x ->
53                path size link link[x] d z ->
54                path size link x (d + 1) z
56   lemma path_dist_nonneg:
57     forall size link x d y. path size link x d y -> 0 <= d
59   lemma path_src:
60     forall size link x d y. path size link x d y -> 0 <= x < size
62   lemma path_dst:
63     forall size link x d y. path size link x d y -> 0 <= y < size
65   let rec lemma path_unique (size: int) (link: array int) (x d1 d2 y1 y2: int)
66     requires { length link = size }
67     requires { path size link x d1 y1 }
68     requires { path size link x d2 y2 }
69     variant  { d1 + d2 }
70     ensures  { y1 = y2 }
71   = if x <> link[x] then begin
72       path_unique size link link[x] (d1-1) (d2-1) y1 y2;
73     end
75   let rec lemma path_last (size: int) (link: array int) (x d y: int)
76     requires { length link = size }
77     requires { path size link x d y }
78     variant  { d }
79     ensures  { link[y] = y }
80   = if x <> link[x] then path_last size link link[x] (d-1) y
82   type t = {
83                   size: int;        (* elements are 0,1,...,size-1 *)
84                   link: array int;
85                   rank: array int;
86     ghost mutable repr: int -> int;
87     ghost mutable dist: int -> int;
88   } invariant {
89     0 <= size = length link = length rank /\
90     (forall i. 0 <= i < size -> 0 <= repr i < size) /\
91     (forall i. 0 <= i < size -> repr (repr i) = repr i) /\
92     (forall i. 0 <= i < size -> link[i] <> i -> dist link[i] < dist i) /\
93     (forall i. 0 <= i < size -> path size link i (dist i) (repr i))
94   } by {
95     size = 0; link = Array.make 0 0; rank = Array.make 0 0;
96     repr = (fun i -> i); dist = (fun _i -> 0)
97   }
99   let create (n: int) : t
100     requires { 0 <= n }
101     ensures  { result.size = n }
102     ensures  { forall i. 0 <= i < result.size -> result.repr i = i }
103   = let link = Array.make n 0 in
104     for i = 0 to n - 1 do
105       invariant { forall j. 0 <= j < i -> link[j] = j }
106       link[i] <- i
107     done;
108     let rank = Array.make n 0 in
109     { size = n; link = link; rank = rank;
110       repr = (fun i -> i); dist = (fun _i -> 0) }
112   let rec lemma path_dist (size: int) (link: array int) (dist: int -> int)
113                           (x d y: int)
114     requires { length link = size }
115     requires { path size link x d y }
116     requires { forall i. 0 <= i < size -> link[i] <> i -> dist link[i] < dist i}
117     requires { x <> y }
118     variant  { d }
119     ensures  { dist y < dist x }
120   = if x <> link[x] && link[x] <> y then
121       path_dist size link dist link[x] (d-1) y
123   let rec lemma path_compression
124     (size: int) (link: array int) (x dx rx: int) (i di ri: int)
125     requires { length link = size }
126     requires { path size link x dx rx }
127     requires { x <> rx }
128     requires { path size link i di ri }
129     variant  { di }
130     ensures  { path size link[x <- rx] i di ri }
131   = if i = x then ()
132     else if link[i] = i then ()
133     else path_compression size link x dx rx link[i] (di-1) ri
135   let rec find (uf: t) (x: int) : int
136     requires { 0 <= x < uf.size }
137     writes   { uf.link(* , uf.dist *) }
138     variant  { uf.dist x }
139     ensures  { result = uf.repr x }
140     ensures  { path uf.size uf.link x (uf.dist x) result }
141   = let y = uf.link[x] in
142     if y <> x then begin
143       assert { path uf.size uf.link y (uf.dist x - 1) (uf.repr x) };
144       let r = find uf y in
145       assert { x <> r };
146       uf.link[x] <- r; (* path compression *)
147       r
148     end else
149       x
151   predicate same (repr: int -> int) (x y: int) =
152     repr x = repr y
154   let union (uf: t) (x y: int) : unit
155     requires { 0 <= x < uf.size }
156     requires { 0 <= y < uf.size }
157     writes   { uf.link, uf.rank, uf.repr, uf.dist }
158     ensures  { same uf.repr x y }
159     ensures  { forall i j. 0 <= i < uf.size -> 0 <= j < uf.size ->
160                same uf.repr i j <->
161                  same (old uf.repr) i j \/
162                  same (old uf.repr) i x /\ same (old uf.repr) y j \/
163                  same (old uf.repr) i y /\ same (old uf.repr) x j }
164   = let rx = find uf x in
165     let ry = find uf y in
166     if rx <> ry then
167       let oldr = uf.repr in
168       let oldd = uf.dist in
169       if uf.rank[rx] <= uf.rank[ry] then begin
170         uf.link[rx] <- ry;
171         uf.repr <- (fun i -> if oldr i = rx then ry else oldr i);
172         uf.dist <- pure { fun i -> if oldr i = rx then oldd i + 1 else oldd i };
173         assert { forall i. 0 <= i < uf.size ->
174              if oldr i = rx then path uf.size uf.link i (oldd i + 1) ry
175                             else path uf.size uf.link i (oldd i)     (oldr i)};
176         if uf.rank[rx] = uf.rank[ry] then
177           uf.rank[ry] <- uf.rank[ry] + 1
178       end else begin
179         uf.link[ry] <- rx;
180         uf.repr <- (fun i -> if oldr i = ry then rx else oldr i);
181         uf.dist <-
182           pure { fun i -> if oldr i = ry then oldd i + 1 else oldd i };
184       end
186   clone Intf with
187     type t = t,
188     val  create = create,
189     val  find = find,
190     val  union = union