ease the proof of coincidence count
[why3.git] / examples / koda_ruskey.mlw
blobce1c1983c409bd1561da6bd5649741077807d56f
2 (** Koda-Ruskey's algorithm
4     Authors: Mário Pereira (Université Paris Sud)
5              Jean-Christophe Filliâtre (CNRS)
6 *)
8 module KodaRuskey_Spec
10   use map.Map
11   use list.List
12   use list.Append
13   use int.Int
15   type color = White | Black
17   let eq_color (c1 c2:color) : bool
18     ensures { result <-> c1 = c2 }
19   = match c1,c2 with
20     | White,White | Black,Black -> True
21     | _ -> False
22     end
24   type forest =
25     | E
26     | N int forest forest
28   type coloring = map int color
30   function size_forest (f: forest) : int = match f with
31     | E -> 0
32     | N _ f1 f2 -> 1 + size_forest f1 + size_forest f2
33     end
35   lemma size_forest_nonneg : forall f.
36     size_forest f >= 0
38   predicate mem_forest (n: int) (f: forest) = match f with
39     | E -> false
40     | N i f1 f2 -> i = n || mem_forest n f1 || mem_forest n f2
41     end
43   predicate between_range_forest (i j: int) (f: forest) =
44     forall n. mem_forest n f -> i <= n < j
46   predicate disjoint (f1 f2: forest) =
47     forall x. mem_forest x f1 -> mem_forest x f2 -> false
49   predicate no_repeated_forest (f: forest) = match f with
50     | E -> true
51     | N i f1 f2 ->
52       no_repeated_forest f1 && no_repeated_forest f2 &&
53       not (mem_forest i f1) && not (mem_forest i f2) &&
54       disjoint f1 f2
55     end
57   predicate valid_nums_forest (f: forest) (n: int) =
58     between_range_forest 0 n f &&
59     no_repeated_forest f
61   predicate white_forest (f: forest) (c: coloring) = match f with
62     | E -> true
63     | N i f1 f2 ->
64       c[i] = White && white_forest f1 c && white_forest f2 c
65   end
67   predicate valid_coloring (f: forest) (c: coloring) =
68     match f with
69     | E -> true
70     | N i f1 f2 ->
71       valid_coloring f2 c &&
72       match c[i] with
73       | White -> white_forest f1 c
74       | Black -> valid_coloring f1 c
75       end
76     end
78   function count_forest (f: forest) : int = match f with
79     | E         -> 1
80     | N _ f1 f2 -> (1 + count_forest f1) * count_forest f2
81     end
83   lemma count_forest_nonneg:
84     forall f. count_forest f >= 1
86   predicate eq_coloring (n: int) (c1 c2: coloring) =
87     forall i. 0 <= i < n -> c1[i] = c2[i]
89 end
91 module Lemmas
93   use map.Map
94   use list.List
95   use list.Append
96   use int.Int
97   use KodaRuskey_Spec
99   type stack = list forest
101   predicate mem_stack (n: int) (st: stack) = match st with
102     | Nil       -> false
103     | Cons f tl -> mem_forest n f || mem_stack n tl
104     end
106   lemma mem_app: forall n st1 [@induction] st2.
107     mem_stack n (st1 ++ st2) -> mem_stack n st1 || mem_stack n st2
109   function size_stack (st: stack) : int = match st with
110     | Nil -> 0
111     | Cons f st -> size_forest f + size_stack st
112     end
114   lemma size_stack_nonneg : forall st.
115     size_stack st >= 0
117   lemma white_forest_equiv:
118     forall f c.
119     white_forest f c <-> (forall i. mem_forest i f -> c[i] = White)
121   predicate even_forest (f: forest) = match f with
122     | E -> false
123     | N _ f1 f2 -> not (even_forest f1) || even_forest f2
124     end
126   predicate final_forest (f: forest) (c: coloring) = match f with
127     | E -> true
128     | N i f1 f2 ->
129       c[i] = Black && final_forest f1 c &&
130       if not (even_forest f1) then white_forest f2 c
131       else final_forest f2 c
132     end
134   predicate any_forest (f: forest) (c: coloring) =
135     white_forest f c || final_forest f c
137   lemma any_forest_frame:
138     forall f c1 c2.
139     (forall i. mem_forest i f -> c1[i] = c2[i]) ->
140     (final_forest f c1 -> final_forest f c2) &&
141     (white_forest f c1 -> white_forest f c2)
143   predicate unchanged (st: stack) (c1 c2: coloring) =
144     forall i. mem_stack i st -> c1[i] = c2[i]
146   predicate inverse (st: stack) (c1 c2: coloring) =
147     match st with
148     | Nil -> true
149     | Cons f st' ->
150       (white_forest f c1 && final_forest f c2 ||
151        final_forest f c1 && white_forest f c2) &&
152       if even_forest f then
153         unchanged st' c1 c2
154       else
155         inverse st' c1 c2
156     end
158   predicate any_stack (st: stack) (c: coloring) = match st with
159     | Nil -> true
160     | Cons f st -> (white_forest f c || final_forest f c) && any_stack st c
161     end
163   lemma any_stack_frame:
164     forall st c1 c2.
165     unchanged st c1 c2 ->
166     any_stack st c1 -> any_stack st c2
168   lemma inverse_frame:
169     forall st c1 c2 c3.
170     inverse   st c1 c2 ->
171     unchanged st c2 c3 ->
172     inverse   st c1 c3
174   lemma inverse_frame2:
175     forall st c1 c2 c3.
176     unchanged st c1 c2 ->
177     inverse   st c2 c3 ->
178     inverse   st c1 c3
180   let lemma inverse_any (st: stack) (c1 c2: coloring)
181     requires { any_stack st c1 }
182     requires { inverse st c1 c2 }
183     ensures  { any_stack st c2 }
184   = ()
186   lemma inverse_final:
187     forall f st c1 c2.
188     final_forest f c1 ->
189     inverse (Cons f st) c1 c2 ->
190     white_forest f c2
192   lemma inverse_white:
193     forall f st c1 c2.
194     white_forest f c1 ->
195     inverse (Cons f st) c1 c2 ->
196     final_forest f c2
198   let lemma white_final_exclusive (f: forest) (c: coloring)
199     requires { f <> E }
200     ensures  { white_forest f c -> final_forest f c -> false }
201   = match f with E -> () | N _ _ _ -> () end
203   lemma final_unique:
204     forall f c1 c2.
205     final_forest f c1 ->
206     final_forest f c2 ->
207     forall i. mem_forest i f -> c1[i] = c2[i]
209   let rec lemma inverse_inverse
210     (st: stack) (c1 c2 c3: coloring)
211     requires { inverse st c1 c2 }
212     requires { inverse st c2 c3 }
213     ensures  { unchanged st c1 c3 }
214     variant  { st }
215   = match st with
216     | Nil -> ()
217     | Cons E st' -> inverse_inverse st' c1 c2 c3
218     | Cons f st' -> if even_forest f then () else inverse_inverse st' c1 c2 c3
219     end
221   inductive sub stack forest coloring =
222   | Sub_reflex:
223       forall f, c. sub (Cons f Nil) f c
224   | Sub_brother:
225       forall st i f1 f2 c.
226       sub st f2 c -> sub st (N i f1 f2) c
227   | Sub_append:
228       forall st i f1 f2 c.
229       sub st f1 c -> c[i] = Black ->
230       sub (st ++ Cons f2 Nil) (N i f1 f2) c
232   lemma sub_not_nil:
233     forall st f c. sub st f c -> st <> Nil
235   lemma sub_empty:
236     forall st f0 c. st <> Nil -> sub (Cons E st) f0 c ->
237     sub st f0 c
239   lemma sub_mem:
240     forall n st f c.
241     mem_stack n st -> sub st f c -> mem_forest n f
243   lemma sub_weaken1:
244     forall st i f1 f2 f0 c.
245     sub (Cons (N i f1 f2) st) f0 c ->
246     sub (Cons         f2  st) f0 c
248   lemma sub_weaken2:
249     forall st i f1 f2 f0 c.
250     sub (Cons (N i f1 f2) st) f0 c ->
251     c[i] = Black ->
252     sub (Cons f1 (Cons f2 st)) f0 c
254   lemma not_mem_st: forall i f st c.
255     not (mem_forest i f) -> sub st f c -> not (mem_stack i st)
257   lemma sub_frame:
258     forall st f0 c c'.
259     no_repeated_forest f0 ->
260     (forall i. not (mem_stack i st) -> mem_forest i f0 -> c'[i] = c[i]) ->
261     sub st f0 c ->
262     sub st f0 c'
264   predicate disjoint_stack (f: forest) (st: stack) =
265     forall i. mem_forest i f -> mem_stack i st -> false
267   lemma sub_no_rep: forall f st' f0 c.
268     sub (Cons f st') f0 c ->
269     no_repeated_forest f0 ->
270     no_repeated_forest f
272   lemma sub_no_rep2: forall f st' f0 c.
273     sub (Cons f st') f0 c ->
274     no_repeated_forest f0 ->
275     disjoint_stack f st'
277   lemma white_valid: forall f c.
278     white_forest f c -> valid_coloring f c
280   lemma final_valid: forall f c.
281     final_forest f c -> valid_coloring f c
283   lemma valid_coloring_frame:
284     forall f c1 c2.
285     valid_coloring f c1 ->
286     (forall i. mem_forest i f -> c2[i] = c1[i]) ->
287     valid_coloring f c2
289   lemma valid_coloring_set:
290     forall f i c.
291     valid_coloring f c ->
292     not (mem_forest i f) ->
293     valid_coloring f c[i <- Black]
295   lemma head_and_tail:
296     forall f1 f2: 'a, st1 st2: list 'a.
297     Cons f1 st1 = st2 ++ Cons f2 Nil ->
298     st2 <> Nil ->
299     exists st. st1 = st ++ Cons f2 Nil /\ st2 = Cons f1 st
301   lemma sub_valid_coloring_f1:
302     forall i f1 f2 c i1.
303     no_repeated_forest (N i f1 f2) ->
304     valid_coloring (N i f1 f2) c ->
305     c[i] = Black ->
306     mem_forest i1 f1 ->
307     valid_coloring f1 c[i1 <- Black] ->
308     valid_coloring (N i f1 f2) c[i1 <- Black]
310   lemma sub_valid_coloring:
311     forall f0 i f1 f2 st c1.
312     no_repeated_forest f0 ->
313     valid_coloring f0 c1 ->
314     sub (Cons (N i f1 f2) st) f0 c1 ->
315     valid_coloring f0 c1[i <- Black]
317   lemma sub_Cons_N:
318     forall f st i f1 f2 c.
319     sub (Cons f st) (N i f1 f2) c ->
320     f = N i f1 f2 || (exists st'. sub (Cons f st') f1 c) || sub (Cons f st) f2 c
322   lemma white_white:
323     forall f c i.
324     white_forest f c ->
325     white_forest f c[i <- White]
327   let rec lemma sub_valid_coloring_white
328     (f0: forest) (i: int) (f1 f2: forest) (c1: coloring)
329     requires { no_repeated_forest f0 }
330     requires { valid_coloring f0 c1 }
331     requires { white_forest f1 c1 }
332     ensures  { forall st. sub (Cons (N i f1 f2) st) f0 c1 ->
333                valid_coloring f0 c1[i <- White] }
334     variant  { f0 }
335   = match f0 with
336     | E -> ()
337     | N _ f10 f20 ->
338        sub_valid_coloring_white f10 i f1 f2 c1;
339        sub_valid_coloring_white f20 i f1 f2 c1
340     end
342   function count_stack (st: stack) : int = match st with
343     | Nil        -> 1
344     | Cons f st' -> count_forest f * count_stack st'
345     end
347   lemma count_stack_nonneg:
348     forall st. count_stack st >= 1
350   use seq.Seq as S
352   type visited = S.seq coloring
354   predicate stored_solutions
355     (f0: forest) (bits: coloring) (st: stack) (v1 v2: visited) =
356     let n = size_forest f0 in
357     let start = S.length v1 in
358     let stop  = S.length v2 in
359     stop - start = count_stack st &&
360     (forall j. 0 <= j < start ->
361       eq_coloring n (S.get v2 j) (S.get v1 j)) &&
362     forall j. start <= j < stop ->
363       valid_coloring f0 (S.get v2 j) &&
364       (forall i. 0 <= i < n -> not (mem_stack i st) ->
365         (S.get v2 j)[i] = bits[i]) &&
366       forall k. start <= k < stop -> j <> k ->
367         not (eq_coloring n (S.get v2 j) (S.get v2 k))
369   let lemma stored_trans1
370       (f0: forest) (bits1 bits2: coloring) (i: int) (f1 f2: forest) (st: stack)
371       (v1 v2 v3: visited)
372     requires { valid_nums_forest f0 (size_forest f0) }
373     requires { 0 <= i < size_forest f0 }
374     requires { forall j. 0 <= j < size_forest f0 ->
375                not (mem_stack j (Cons (N i f1 f2) st)) -> bits2[j] = bits1[j] }
376     requires { forall j. S.length v1 <= j < S.length v2 ->
377                (S.get v2 j)[i] = White }
378     requires { forall j. S.length v2 <= j < S.length v3 ->
379                (S.get v3 j)[i] = Black }
380     requires { stored_solutions f0 bits1 (Cons f2 st) v1 v2 }
381     requires { stored_solutions f0 bits2 (Cons f1 (Cons f2 st)) v2 v3 }
382     ensures  { stored_solutions f0 bits2 (Cons (N i f1 f2) st) v1 v3 }
383   = ()
386   let lemma stored_trans2
387       (f0: forest) (bits1 bits2: coloring) (i: int) (f1 f2: forest) (st: stack)
388       (v1 v2 v3: visited)
389     requires { valid_nums_forest f0 (size_forest f0) }
390     requires { 0 <= i < size_forest f0 }
391     requires { forall j. 0 <= j < size_forest f0 ->
392        not (mem_stack j (Cons (N i f1 f2) st)) -> bits2[j] = bits1[j] }
393     requires { forall j. S.length v1 <= j < S.length v2 ->
394                (S.get v2 j)[i] = Black }
395     requires { forall j. S.length v2 <= j < S.length v3 ->
396                (S.get v3 j)[i] = White }
397     requires { stored_solutions f0 bits1 (Cons f1 (Cons f2 st)) v1 v2 }
398     requires { stored_solutions f0 bits2 (Cons f2 st) v2 v3 }
399     ensures  { stored_solutions f0 bits2 (Cons (N i f1 f2) st) v1 v3 }
400   = ()
404 module KodaRuskey
406   use seq.Seq as S
407   use list.List
408   use KodaRuskey_Spec
409   use Lemmas
410   use map.Map as M
411   use array.Array
412   use int.Int
413   use ref.Ref
415   val ghost map_of_array (a: array 'a) : M.map int 'a
416     ensures { result = a.elts }
418   val ghost visited: ref visited
420   let rec enum (bits: array color) (ghost f0: forest) (st: list forest) : unit
421     requires { size_forest f0 = length bits }
422     requires { valid_nums_forest f0 (length bits) }
423     requires { sub st f0 bits.elts }
424     requires { st <> Nil }
425     requires { any_stack st bits.elts }
426     requires { valid_coloring f0 bits.elts }
427     variant  { size_stack st, st }
428     ensures  { forall i.
429                  not (mem_stack i st) -> bits[i] = (old bits)[i] }
430     ensures  { inverse st (old bits).elts bits.elts }
431     ensures  { valid_coloring f0 bits.elts }
432     ensures  { stored_solutions f0 bits.elts st (old !visited) !visited }
433   = match st with
434     | Nil ->
435         absurd
436     | Cons E st' ->
437        match st' with
438        | Nil ->
439            (* that's where we visit the next coloring *)
440            assert { valid_coloring f0 bits.elts };
441            ghost visited := S.snoc !visited (map_of_array bits);
442            ()
443        | _ ->
444            enum bits f0 st'
445        end
446     | Cons (N i f1 f2 as f) st' ->
447         assert { disjoint_stack f1 st' };
448         assert { not (mem_stack i st') };
449         let ghost visited1 = !visited in
450         if eq_color bits[i] White then begin
451           label A in
452           enum bits f0 (Cons f2 st');
453           assert { sub st f0 bits.elts };
454           let ghost bits1 = map_of_array bits in
455           let ghost visited2 = !visited in
456           label B in
457           bits[i] <- Black;
458           assert { sub st f0 bits.elts };
459           assert { white_forest f1 bits.elts };
460           assert { unchanged (Cons f2 st') (bits at B).elts bits.elts};
461           assert { inverse (Cons f2 st') (bits at A).elts bits.elts };
462           label C in
463           enum bits f0 (Cons f1 (Cons f2 st'));
464           assert { bits[i] = Black };
465           assert { final_forest f1 bits.elts };
466           assert { if not (even_forest f1)
467                    then inverse (Cons f2 st') (bits at C).elts bits.elts &&
468                         white_forest f2 bits.elts
469                    else unchanged (Cons f2 st') (bits at C).elts bits.elts &&
470                         final_forest f2 bits.elts };
471           ghost stored_trans1 f0 bits1 (map_of_array bits)
472     i f1 f2 st' visited1 visited2 !visited
473         end else begin
474           assert { if not (even_forest f1) then white_forest f2 bits.elts
475                    else final_forest f2 bits.elts };
476           label A in
477           enum bits f0 (Cons f1 (Cons f2 st'));
478           assert { sub st f0 bits.elts };
479           let ghost bits1 = map_of_array bits in
480           let ghost visited2 = !visited in
481           label B in
482           bits[i] <- White;
483           assert { unchanged (Cons f1 (Cons f2 st'))
484                      (bits at B).elts bits.elts };
485           assert { inverse (Cons f1 (Cons f2 st'))
486                      (bits at A).elts bits.elts };
487           assert { sub st f0 bits.elts };
488           assert { if even_forest f1 || even_forest f2
489                    then unchanged st' (bits at A).elts bits.elts
490                    else inverse st' (bits at A).elts bits.elts };
491           enum bits f0 (Cons f2 st');
492           assert { bits[i] = White };
493           assert { white_forest f  bits.elts };
494           ghost stored_trans2 f0 bits1 (map_of_array bits)
495     i f1 f2 st' visited1 visited2 !visited
496        end
497     end
499   let main (bits: array color) (f0: forest)
500     requires { white_forest f0 bits.elts }
501     requires { size_forest f0 = length bits }
502     requires { valid_nums_forest f0 (length bits) }
503     ensures  { S.length !visited = count_forest f0 }
504     ensures  { let n = S.length !visited  in
505                forall j. 0 <= j < n ->
506                  valid_coloring f0 (S.get !visited j) &&
507                  forall k. 0 <= k < n -> j <> k ->
508                    not (eq_coloring (length bits)
509                          (S.get !visited j) (S.get !visited k)) }
510   = visited := S.empty;
511     enum bits f0 (Cons f0 Nil)
515 (** Independently, let's prove that count_forest is indeed the number
516     of colorings. *)
518 (* wip
519 module CountCorrect
521   use seq.Seq as S
522   use map.Map as M
523   use map.Const
524   use list.List
525   use int.Int
526   use KodaRuskey_Spec
527   use Lemmas
528   use ref.Ref
530   predicate id_forest (f: forest) (c1 c2: coloring) =
531     forall j. mem_forest j f -> M.get c1 j = M.get c2 j
533   (* valid coloring, all white outside of f *)
534   predicate solution (f: forest) (c: coloring) =
535     valid_coloring f c &&
536     forall j. not (mem_forest j f) -> M.get c j = White
538   lemma solution_eq:
539     forall n f c1 c2.
540     valid_nums_forest f n ->
541     solution f c1 -> eq_coloring n c1 c2 -> solution f c2
543   predicate is_product (i: int) (f1 f2: forest) (c1 c2 r: coloring) =
544     solution (N i f1 f2) r &&
545     M.get r i = Black &&
546     id_forest f1 r c1 &&
547     id_forest f2 r c2
549   let product (n: int) (i: int) (f1 f2: forest) (c1 c2: coloring) : coloring
550     requires { valid_nums_forest (N i f1 f2) n }
551     requires { solution f1 c1 }
552     requires { solution f2 c2 }
553     ensures  { is_product i f1 f2 c1 c2 result }
554   = let rec copy (acc: coloring) (f: forest)
555       variant { f }
556       ensures { forall i. M.get result i =
557                   if mem_forest i f then M.get c2 i else M.get acc i }
558     = match f with
559       | E -> acc
560       | N i2 left2 right2 ->
561           M.set (copy (copy acc left2) right2) i2 (M.get c2 i2)
562       end
563     in
564     let c = copy c1 f2 in
565     M.set c i Black
567   lemma solution_product:
568     forall n i f1 f2 c1 c2 c.
569     valid_nums_forest (N i f1 f2) n ->
570     solution f1 c1 -> solution f2 c2 ->
571     is_product i f1 f2 c1 c2 c -> solution (N i f1 f2) c
573   predicate solutions (n: int) (f: forest) (s: seq coloring) =
574      (forall j. 0 <= j < length s -> solution f s[j]) &&
575      (* colorings are disjoint *)
576      (forall j. 0 <= j < length s ->
577         forall k. 0 <= k < length s -> j <> k ->
578         not (eq_coloring n s[j] s[k]))
580   let product_all (n: int) (i: int) (f1 f2: forest) (s1 s2: seq coloring)
581     : seq coloring
582     requires { valid_nums_forest (N i f1 f2) n }
583     requires { solutions n f1 s1 }
584     requires { solutions n f2 s2 }
585     ensures  { solutions n (N i f1 f2) result }
586     ensures  { forall j. 0 <= j < length s1 ->
587                forall k. 0 <= k < length s2 ->
588                is_product i f1 f2 s1[j] s2[k] result[j * length s2 + k] }
589     ensures  { length result = length s1 * length s2 }
590   = let s = ref empty in
591     for j = 0 to length s1 - 1 do
592       invariant { length !s = j * length s2 }
593       invariant { solutions n (N i f1 f2) !s }
594       invariant { forall j' k'. 0 <= j' < j -> 0 <= k' < length s2 ->
595                   let c = !s[j' * length s2 + k'] in
596                   is_product i f1 f2 s1[j'] s2[k'] c }
597       for k = 0 to length s2 - 1 do
598         invariant { length !s = j * length s2 + k }
599         invariant { solutions n (N i f1 f2) !s }
600         invariant { forall j' k'. 0 <= j' < j && 0 <= k' < length s2
601                                || j' = j && 0 <= k' < k ->
602                     let c = !s[j' * length s2 + k'] in
603                     is_product i f1 f2 s1[j'] s2[k'] c }
604         let p = product n i f1 f2 s1[j] s2[k] in
605         assert { forall l. 0 <= l < length !s ->
606                   not (eq_coloring n p !s[l]) };
607         s := snoc !s p
608       done
609     done;
610     !s
612   lemma solution_white_or_black:
613     forall n i f1 f2 c.
614     valid_nums_forest (N i f1 f2) n ->
615     solution (N i f1 f2) c ->
616     match M.get c i with
617     | White -> solution f2 c
618     | Black -> exists c1 c2. is_product i f1 f2 c1 c2 c &&
619                              solution f1 c1 && solution f2 c2
620     end
622   let rec enum (n: int) (f: forest) : seq coloring
623     requires { valid_nums_forest f n }
624     ensures  { length result = count_forest f }
625     ensures  { solutions n f result }
626     ensures  { forall c. solution f c <->
627                  exists j. 0 <= j < length result && eq_coloring n c result[j] }
628     variant  { f }
629   = match f with
630     | E ->
631         cons (const White) empty
632     | N i f1 f2 ->
633         let s1 = enum n f1 in
634         let s2 = enum n f2 in
635         s2 ++ product_all n i f1 f2 s1 s2
636    end