3 {1 VerifyThis @ ETAPS 2021 competition
4 Challenge 3: Shearsort}
5 See https://www.pm.inf.ethz.ch/research/verifythis.html
7 Author: Martin Clochard (ETH Zurich)
16 use int.ComputerDivision
20 use matrix.Matrix as M
24 function column (m: M.matrix 'a) (j:int) : int -> 'a = fun i -> m.M.elts i j
26 function moccf (x:'a) (e:int -> int -> 'a) (c:int) : int -> int =
27 fun (i:int) -> occ x (e i) 0 c
28 function mocc (x:'a) (e:int -> int -> 'a) (r c:int) : int =
31 let rec ghost transpose_count (x:'a) (e1 e2:int -> int -> 'a) (r c:int) : unit
32 requires { 0 <= r && 0 <= c }
33 requires { forall i j:int. 0 <= i < r /\ 0 <= j < c ->
35 ensures { mocc x e1 r c = mocc x e2 c r }
38 assert { mocc x e2 c r = 0
39 by forall j:int. 0 <= j && j < c ->
44 let f = pure{moccf x e2 r} in
45 let g = pure{moccf x e2 rm} in
46 let rec ghost scan (j:int) : unit
47 requires { 0 <= j <= c }
48 ensures { sum f 0 j = sum g 0 j + occ x (e1 rm) 0 j }
50 = if j <> 0 then scan (j-1)
53 transpose_count x e1 e2 rm c
56 val sort(a: array int) : unit
58 ensures { forall i j:int. 0 <= i <= j < a.length ->
60 ensures { permut a.elts (old a).elts 0 a.length }
62 let ghost permut_swap (a b:int -> 'a)(x y l u:int)
63 requires { l <= x < u && l <= y < u }
64 requires { forall i:int. l <= i <= u && i <> x && i <> y ->
66 requires { a x = b y && a y = b x }
67 ensures { permut a b l u }
68 = let c = MP.(a[x <- a y][y <- a x]) in
69 assert { MP.(a == a[x <- a x][y <- a y]) };
70 assert { permut a c l u };
71 assert { forall i: int. l <= i < u -> b i = c i };
72 assert { permut b c l u }
74 function compose (g:'b -> 'c) (f:'a -> 'b) : 'a -> 'c =
77 let rec ghost numoff_occ (a:int -> 'a) (l u:int) (p q:'a -> bool) (x:'a)
80 requires { forall y:'a. q y <-> p y /\ y <> x }
81 ensures { numof (compose p a) l u = numof (compose q a) l u + occ x a l u }
83 = if l <> u then numoff_occ a (l+1) u p q x
85 let rec ghost permut_numoff (a b:int -> 'a) (l u:int) (p:'a -> bool)
87 requires { permut a b l u }
88 ensures { numof (compose p a) l u = numof (compose p b) l u }
89 variant { numof (compose p b) l u, numof (compose p a) l u}
90 = if pure {numof (compose p a) l u } = 0 then begin
91 if pure {numof (compose p b) l u} <> 0 then permut_numoff b a l u p
93 let rec find (i:int) : int
95 requires { numof (compose p a) i u > 0 }
96 ensures { i <= result <= u }
97 ensures { compose p a result }
98 ensures { occ (a result) a i u > 0 }
100 = if compose p a i then i else find (i+1)
104 let q = MP.(p[v <- false]) in
105 numoff_occ a l u p q v;
106 numoff_occ b l u p q v;
107 permut_numoff a b l u q
110 let sort_row(m : M.matrix int) (i: int) (ascending: bool) : unit
111 requires { 0 <= i < m.M.rows }
113 ensures { forall j k:int.
114 0 <= j /\ j < m.M.rows /\ 0 <= k < m.M.columns /\ j <> i ->
115 m.M.elts j k = (old m).M.elts j k }
116 ensures { forall j k: int. 0 <= j <= k < m.M.columns ->
117 let a = m.M.elts i j in let b = m.M.elts i k in
118 if ascending then a <= b else b <= a }
119 ensures { permut (m.M.elts i) ((old m).M.elts i) 0 m.M.columns }
121 let a = make m.M.columns 0 in
122 for j = 0 to m.M.columns - 1 do
123 invariant { forall k:int. 0 <= k < j ->
124 a.elts k = m.M.elts i k }
127 assert { permut a.elts (m.M.elts i) 0 a.length };
130 if not(ascending) then begin
132 let ref v = a.length - 1 in
134 invariant { 0 <= u <= v + 1 <= a.length }
135 invariant { u + v = a.length - 1 }
136 invariant { forall j:int. u <= j <= v ->
137 a.elts j = (a.elts at L) j
139 invariant { forall j:int. 0 <= j < u \/ v < j < a.length ->
140 a.elts j = (a.elts at L) (a.length - 1 - j)
142 invariant { permut (a.elts at L) a.elts 0 a.length }
148 permut_swap a.elts e u v 0 a.length;
153 for j = 0 to m.M.columns - 1 do
154 invariant { forall k:int. 0 <= k < j ->
155 a.elts k = m.M.elts i k }
156 invariant { forall k l:int.
157 0 <= k < m.M.rows /\ 0 <= l < m.M.columns /\ k <> i ->
158 m.M.elts k l = (m at L).M.elts k l }
161 assert { permut a.elts (m.M.elts i) 0 a.length }
163 let sort_column(m : M.matrix int) (j: int) : unit
164 requires { 0 <= j < m.M.columns }
166 ensures { forall i k:int.
167 0 <= i /\ i < m.M.rows /\ 0 <= k < m.M.columns /\ k <> j ->
168 m.M.elts i k = (old m).M.elts i k }
169 ensures { forall i k: int. 0 <= i <= k < m.M.rows ->
170 m.M.elts i j <= m.M.elts k j }
171 ensures { permut (column m j) (column (old m) j) 0 m.M.rows }
172 = let a = make m.M.rows 0 in
173 for i = 0 to m.M.rows - 1 do
174 invariant { forall k:int. 0 <= k < i ->
175 a.elts k = m.M.elts k j }
178 assert { permut a.elts (column m j) 0 a.length };
181 for i = 0 to m.M.rows - 1 do
182 invariant { forall k:int. 0 <= k < i ->
183 a.elts k = m.M.elts k j }
184 invariant { forall k l:int.
185 0 <= k < m.M.rows /\ 0 <= l < m.M.columns /\ l <> j ->
186 m.M.elts k l = (m at L).M.elts k l }
189 assert { permut a.elts (column m j) 0 a.length }
191 (* Isolate the loops sorting all rows/columns in isolated functions
192 for modular verification (also makes easier to verify alternative
193 implementation with transpose) *)
194 let sort_all_rows(m : M.matrix int) : unit
196 ensures { forall i:int. 0 <= i < m.M.rows ->
197 permut (m.M.elts i) ((old m).M.elts i) 0 m.M.columns }
198 ensures { forall i j k:int. 0 <= i < m.M.rows /\
199 0 <= j <= k < m.M.columns ->
200 let a = m.M.elts i j in let b = m.M.elts i k in
201 if mod i 2 = 0 then a <= b else b <= a }
203 for tid = 0 to m.M.rows - 1 do
204 invariant { forall i:int. 0 <= i < m.M.rows ->
205 permut (m.M.elts i) ((m at L).M.elts i) 0 m.M.columns }
206 invariant { forall i j k:int. 0 <= i < tid /\
207 0 <= j <= k < m.M.columns ->
208 let a = m.M.elts i j in let b = m.M.elts i k in
209 if mod i 2 = 0 then a <= b else b <= a }
211 let ascending = mod tid 2 = 0 in
212 sort_row m tid ascending;
213 assert { forall i:int. 0 <= i < m.M.rows ->
214 permut (m.M.elts i) ((m at L2).M.elts i) 0 m.M.columns }
217 val transpose (m: M.matrix int) : unit
218 requires { m.M.rows = m.M.columns }
220 ensures { forall i j:int. 0 <= i < m.M.rows /\
221 0 <= j < m.M.columns -> (old m).M.elts i j = m.M.elts j i }
223 let sort_all_columns (m : M.matrix int) : unit
225 ensures { forall j:int. 0 <= j < m.M.columns ->
226 permut (column m j) (column (old m) j) 0 m.M.rows }
227 ensures { forall i j k:int. 0 <= i <= k < m.M.rows /\
228 0 <= j < m.M.columns ->
229 m.M.elts i j <= m.M.elts k j }
231 if any bool ensures { result \/ m.M.rows = m.M.columns } then begin
232 for tid = 0 to m.M.columns - 1 do
233 invariant { forall j:int. 0 <= j < m.M.columns ->
234 permut (column m j) (column (old m) j) 0 m.M.rows }
235 invariant { forall i j k:int. 0 <= i <= k < m.M.rows /\
236 0 <= j < tid -> m.M.elts i j <= m.M.elts k j }
239 assert { forall j:int. 0 <= j < m.M.columns ->
240 permut (column m j) (column (m at L2) j) 0 m.M.rows }
245 assert { forall j:int. 0 <= j < n ->
246 permut (column (old m) j) (m.M.elts j) 0 n
247 by forall i:int. 0 <= i < n ->
248 column (old m) j i = m.M.elts j i
250 for tid = 0 to m.M.columns - 1 do
251 invariant { forall j:int. 0 <= j < n ->
252 permut (m.M.elts j) (column (old m) j) 0 n }
253 invariant { forall i j k:int. 0 <= i <= k < n /\
254 0 <= j < tid -> m.M.elts j i <= m.M.elts j k }
257 assert { forall j:int. 0 <= j < n ->
258 permut (m.M.elts j) ((m.M.elts at L2) j) 0 n }
260 let et2 = pure{m.M.elts} in
263 forall j:int. 0 <= j < n ->
264 permut (column m j) (et2 j) 0 n
265 by forall i:int. 0 <= i < n ->
266 column m j i = et2 j i
270 predicate below_column (e:int -> int -> int) (col v row:int) =
272 predicate above_column (e:int -> int -> int) (col v row:int) =
275 let shear_sort(m: M.matrix int) : unit
277 forall i j1 j2 k:int.
278 0 <= i < k < m.M.rows &&
279 0 <= j1 < m.M.columns && 0 <= j2 < m.M.columns ->
280 m.M.elts i j1 <= m.M.elts k j2
284 0 <= i < m.M.rows && 0 <= j <= k < m.M.columns ->
285 if mod i 2 = 0 then m.M.elts i j <= m.M.elts i k else
286 m.M.elts i k <= m.M.elts i j
290 mocc x (old m.M.elts) m.M.rows m.M.columns
291 = mocc x m.M.elts m.M.rows m.M.columns
296 let ghost c = m.M.columns in
297 (* FIX: n need to be non-zero for the log to be computable ! *)
299 (* Compute log_2(n). *)
303 invariant { l >= 0 && p >= 0 }
304 invariant { l <> 0 -> power 2 (l-1) <= n }
305 invariant { p * power 2 l < n <= power 2 l * (p+1) }
310 assert { 2 * p <= q <= 2 * p + 1 };
311 assert { p * power 2 l
312 = (2 * p) * power 2 (l-1)
313 <= q * power 2 (l-1) < n };
314 assert { (p+1) * power 2 l
315 = (2 * p + 2) * power 2 (l-1)
316 >= (q + 1) * power 2 (l-1) >= n }
318 (* Check against defining property of ceil(log2(n)) *)
319 assert { power 2 l >= n };
320 assert { l <> 0 -> power 2 (l-1) <= n };
321 (* Maximum width of the gap between zero rows and one rows
322 under 0-1 abstractions. *)
323 let ghost ref k = n in
324 let ghost ref zeros = pure { fun (_:int) -> 0 } in
325 let ghost ref ones = pure { fun (_:int) -> n } in
326 let ghost column_sorted () : unit
327 requires { forall v:int. 0 <= zeros v <= ones v <= n }
328 requires { forall v:int. ones v <= zeros v + 1 }
329 requires { forall v i j:int.
330 0 <= i < zeros v /\ 0 <= j < c ->
332 requires { forall v i j:int.
333 ones v <= i < n /\ 0 <= j < c ->
335 ensures { forall i j1 j2 k:int.
336 0 <= i < k < n && 0 <= j1 < c && 0 <= j2 < c ->
337 m.M.elts i j1 <= m.M.elts k j2 }
339 let rec lemma aux (i j1 j2 k:int)
340 requires { 0 <= i < k < n && 0 <= j1 < c && 0 <= j2 < c }
341 ensures { m.M.elts i j1 <= m.M.elts k j2 }
342 = let v = m.M.elts k j2 in
343 if m.M.elts i j1 > v then begin
344 assert { ones v >= i &&
350 (* Repeat l+1 times. *)
352 invariant { forall v:int. 0 <= zeros v <= ones v <= n }
353 invariant { forall v:int. ones v <= zeros v + k }
354 invariant { forall v i j:int.
355 0 <= i < zeros v /\ 0 <= j < c ->
356 m.M.elts i j <= v (* p(v)(m.M.elts i j) = 0 *) }
357 invariant { forall v i j:int.
358 ones v <= i < n /\ 0 <= j < c ->
359 m.M.elts i j > v (* p(v)(m.M.elts i j) = 1 *) }
361 invariant { (k-1) * power 2 ind < n <= k * power 2 ind }
362 invariant { ind > l ->
364 0 <= i < n /\ 0 <= j <= k < c ->
365 let a = m.M.elts i j in
366 let b = m.M.elts i k in
367 if mod i 2 = 0 then a <= b else b <= a }
368 invariant { forall x:int.
369 mocc x (m.M.elts at L0) n c
370 = mocc x m.M.elts n c }
374 assert { forall v i j:int. 0 <= i < zeros v /\ 0 <= j < c ->
376 by occ (e i j) (e i) 0 c > 0
377 so occ (e i j) (e0 i) 0 c > 0
378 so exists k. 0 <= k /\ k < c /\ e0 i k = e i j
381 assert { forall v i j:int.
382 ones v <= i < n /\ 0 <= j < c ->
384 by occ (e i j) (e i) 0 c > 0
385 so occ (e i j) (e0 i) 0 c > 0
386 so exists k. 0 <= k /\ k < c /\ e0 i k = e i j
389 assert { forall x:int.
390 mocc x e0 n c = mocc x e n c
391 by forall i:int. 0 <= i < n ->
392 moccf x e0 c i = moccf x e c i
396 let kd = div (k+1) 2 in
397 let ghost function nzo (v:int) : (int,int)
398 ensures { match result with nz, no ->
399 0 <= nz <= no <= n &&
401 forall j:int. 0 <= j < c ->
402 numof (below_column e j v) 0 n >= nz /\
403 numof (above_column e j v) 0 n >= n - no
408 let rec lemma fillz (i:int) (j:int)
409 requires { 0 <= i <= z /\ 0 <= j < c }
410 ensures { numof (below_column e j v) 0 i >= i }
412 = if i <> 0 then begin
413 assert { below_column e j v (i-1) };
414 assert { numof (below_column e j v) (i-1) i = 1 };
419 let rec lemma fillo (i:int) (j:int)
420 requires { o <= i <= n /\ 0 <= j < c }
421 ensures { numof (above_column e j v) i n >= n - i }
423 = if i <> n then begin
424 assert { above_column e j v i };
425 assert { numof (above_column e j v) i (i+1) = 1 };
432 while index + 1 < o do
433 invariant { z <= index /\ index <= o }
434 invariant { nz >= z /\ no <= o }
435 invariant { index - z = 2 * (nz - z + o - no) }
436 invariant { forall j:int. 0 <= j < c ->
437 numof (below_column e j v) 0 index >= nz /\
438 numof (above_column e j v) 0 index
439 + numof (above_column e j v) o n >= n - no }
440 variant { o - index }
441 let rec select (r1 r2:int -> int) (b:bool) : bool
442 requires { forall j k:int. 0 <= j <= k < c ->
443 if b then r1 j <= r1 k else r1 k <= r1 j }
444 requires { forall j k:int. 0 <= j <= k < c ->
445 if b then r2 k <= r2 j else r2 j <= r2 k }
446 ensures { result -> forall i:int. 0 <= i < c ->
447 r1 i <= v || r2 i <= v }
448 ensures { not(result) -> forall i:int. 0 <= i < c ->
449 r1 i > v || r2 i > v }
450 variant { if b then 0 else 1 }
452 if not(b) then select r2 r1 true else
454 while i <> c && r1 i <= v && r2 i > v do
455 invariant { 0 <= i <= c }
456 invariant { forall j:int. 0 <= j < i ->
457 r1 j <= v && r2 j > v }
463 if select (e index) (e (index + 1)) (mod index 2 = 0) then begin
465 assert { forall j:int. 0 <= j < c ->
466 numof (below_column e j v) index (index+2) > 0 }
469 assert { forall j:int. 0 <= j < c ->
470 numof (above_column e j v) index (index+2) > 0 }
476 let ghost function newz (v:int) : int
477 ensures { result = match nzo v with (x,_) -> x end }
478 = match nzo v with (x,_) -> x end
480 let ghost function newo (v:int) : int
481 ensures { result = match nzo v with (_,y) -> y end }
482 = match nzo v with (_,y) -> y end
484 let lemma newzo (v:int)
485 ensures { let nz = newz v in let no = newo v in
486 0 <= nz <= no <= n &&
488 forall j:int. 0 <= j < c ->
489 numof (below_column e j v) 0 n >= nz /\
490 numof (above_column e j v) 0 n >= n - no
492 = let (x,y) = nzo v in assert { newz v = x /\ newo v = y }
494 let cl1 = pure {column m} in
495 assert { forall v j:int. 0 <= j < c ->
496 MP.(below_column e j v == compose ((>=) v) (cl1 j))
497 /\ MP.(above_column e j v == compose ((<) v) (cl1 j)) };
498 if ind = l then column_sorted ();
503 let cl2 = pure {column m} in
504 let rec lemma column_permutation (x:int) : unit
505 ensures { mocc x (m.M.elts at L) n c = mocc x e n c }
506 = transpose_count x e cl1 n c;
507 transpose_count x (pure{m.M.elts at L}) cl2 n c;
508 assert { mocc x cl1 c n = mocc x cl2 c n
509 by forall j:int. 0 <= j < c ->
510 moccf x cl1 n j = moccf x cl2 n j }
512 if ind = l then begin
513 let lemma column_preserved (j:int)
514 requires { 0 <= j < c }
515 ensures { forall i:int. 0 <= i < n ->
516 m.M.elts i j = e i j }
517 = let rec ghost no_occ (m:int -> int) (i:int) (x:int) : unit
518 requires { 0 <= i <= n }
519 requires { forall u v:int. i <= u <= v < n ->
521 requires { forall u:int. i <= u < n -> m u > x }
522 ensures { occ x m i n = 0 }
524 = if i <> n then no_occ m (i+1) x
526 let rec ghost scan (i:int) : unit
527 requires { 0 <= i <= n }
528 requires { permut (cl1 j) (cl2 j) i n }
529 ensures { forall k:int. i <= k < n ->
533 let u = cl1 j i in let v = cl2 j i in
534 if u < v then no_occ (cl2 j) i u else
535 if v < u then no_occ (cl1 j) i v else
543 assert { forall v j:int. 0 <= j < c ->
544 MP.(below_column e j v == compose ((>=) v) (cl2 j))
545 /\ MP.(above_column e j v == compose ((<) v) (cl2 j)) };
546 let rec ghost auxT (v i j:int) : unit
547 requires { 0 <= i < n }
548 requires { numof (below_column e j v) (i+1) n > 0 /\ 0 <= j < c }
549 requires { cl2 j i > v }
555 let lemma auxT1 (v i j:int) : unit
556 requires { 0 <= i < zeros v }
557 requires { 0 <= j < c }
558 ensures { cl2 j i <= v }
559 = if cl2 j i > v then begin
560 permut_numoff (cl1 j) (cl2 j) 0 n ((>=) v);
562 let a = numof (below_column e j v) 0 i in
563 let b = numof (below_column e j v) i n in
564 numof (below_column e j v) 0 n = a + b
571 let rec ghost auxB (v i j:int) : unit
572 requires { 0 <= i < n }
573 requires { numof (above_column e j v) 0 i > 0 /\ 0 <= j < c }
574 requires { cl2 j i <= v }
580 let lemma auxB1 (v i j:int) : unit
581 requires { ones v <= i < n }
582 requires { 0 <= j < c }
583 ensures { cl2 j i > v }
584 = if cl2 j i <= v then begin
585 permut_numoff (cl1 j) (cl2 j) 0 n ((<) v);
587 let a = numof (above_column e j v) 0 (i+1) in
588 let b = numof (above_column e j v) (i+1) n in
589 numof (above_column e j v) 0 n = a + b
596 assert { 2 * kd <= k+1 <= 2 * kd + 1 };
597 assert { (kd-1) * power 2 (ind+1)
598 = (2 * kd - 2) * power 2 ind
599 <= (k-1) * power 2 ind < n };
600 assert { kd * power 2 (ind+1)
601 = 2 * kd * power 2 ind
602 >= k * power 2 ind >= n };
611 (* Duplicated from Why3's gallery. *)
617 use array.IntArraySorted
619 use array.ArrayPermut
622 predicate qs_partition (a1 a2: array int) (l m r: int) (v: int) =
623 permut_sub a1 a2 l r /\
624 (forall j: int. l <= j < m -> a2[j] < v) /\
625 (forall j: int. m < j < r -> v <= a2[j]) /\
628 (* partitioning à la Bentley, that is
631 +-+----------+----------+----------+
633 +-+----------+----------+----------+ *)
635 let rec quick_rec (a: array int) (l: int) (r: int) : unit
636 requires { 0 <= l <= r <= length a }
637 ensures { sorted_sub a l r }
638 ensures { permut_sub (old a) a l r }
640 = if l + 1 < r then begin
644 for i = l + 1 to r - 1 do
645 invariant { a[l] = v /\ l <= !m < i }
646 invariant { forall j:int. l < j <= !m -> a[j] < v }
647 invariant { forall j:int. !m < j < i -> a[j] >= v }
648 invariant { permut_sub (a at L) a l r }
650 if a[i] < v then begin
653 assert { permut_sub (a at K) a l r }
658 assert { qs_partition (a at M) a l !m r v };
661 assert { qs_partition (a at N) a l !m r v };
663 quick_rec a (!m + 1) r;
664 assert { qs_partition (a at O) a l !m r v };
665 assert { qs_partition (a at N) a l !m r v };
668 let quicksort (a: array int) =
670 ensures { permut_all (old a) a }
671 quick_rec a 0 (length a)
675 (* Instantiate leftover sort routine. *)
676 module ShearSortComplete
679 clone ShearSort with val sort = quicksort