Restore previous display for sum and product signs in ASCII art mode;
[maxima.git] / share / nelder_mead / neldermead.lisp
blob94199eb5d2203514c0a388da50a8f5f5d9bc59dc
1 (in-package :neldermead)
;; Lazily computed quantities attached to an NM-SIMPLEX through its DATA
;; slot.  Every slot starts out as NIL, meaning "nothing stored yet".
(defclass cached-simplex-data ()
  (;; These four are read and written with SLOT-VALUE only.
   (pseudopivot  :initform nil)
   (q-factor     :initform nil)
   (side-vectors :initform nil)
   (r-factor     :initform nil)
   ;; These three can be supplied at creation time and have accessors.
   (dv
    :accessor dv
    :initarg :dv
    :initform nil)
   (dmin
    :accessor dmin
    :initarg :dmin
    :initform nil)
   (best-before-reshape
    :accessor best-before-reshape
    :initarg :best-before-reshape
    :initform nil)))
;; A simplex.  X holds the vertices, FX their objective values, and PMAP
;; the index permutation through which XK and FK address both arrays.
;; DATA holds cached derived quantities, or NIL when there are none.
(defclass nm-simplex ()
  ((x    :initarg :x    :accessor x)
   (fx   :initarg :fx   :accessor fx)
   (pmap :initarg :pmap :accessor pmap)
   (data :initform nil  :accessor data)))
;; A grid described by an origin Z and a per-coordinate spacing DELTA.
(defclass grid ()
  ((z     :initarg :z     :accessor grid-z)
   (delta :initarg :delta :accessor delta)))
(defvar *verbose-level* 1
  "Selects how much PRINT-OBJECT shows for an NM-SIMPLEX: 1 prints the
dimension and best value, 2 prints one line per vertex, any other value
prints only the bare tag.")
;; Prints a simplex as #<NM-SIMPLEX ...>; the amount of detail is chosen
;; by *VERBOSE-LEVEL*.
(defmethod print-object ((s nm-simplex) stream)
  (format stream "#<NM-SIMPLEX")
  (cond
    ;; Level 2: one line per vertex, value first, then the coordinates.
    ((eql *verbose-level* 2)
     (let* ((values-present (fx s))
            (vertex-count (length (pmap s)))
            (number-format "~12,3,3E"))
       (loop :for i :from 0 :below vertex-count :do
         (terpri stream)
         (if values-present
             (format stream number-format (fk s i))
             (format stream " -- "))
         (format stream " :")
         (loop :for j :from 0 :below (- vertex-count 1) :do
           (format stream number-format (aref (xk s i) j))))))
    ;; Level 1: dimension and best value, or "--" when no values exist.
    ((eql *verbose-level* 1)
     (format stream " D=~A Best=~A "
             (dimension s)
             (if (fx s)
                 (format nil "~16,7,3E" (fk s 0))
                 "--"))))
  (format stream ">"))
;; Returns the vertex at position K of the permutation PMAP.
(defmethod xk ((s nm-simplex) k)
  (let ((slot (aref (pmap s) k)))
    (aref (x s) slot)))
;; Stores NV as the vertex at position K of the permutation PMAP.
(defmethod (setf xk) (nv (s nm-simplex) k)
  (let ((slot (aref (pmap s) k)))
    (setf (aref (x s) slot) nv)))
;; Returns the objective value at position K of the permutation PMAP.
(defmethod fk ((s nm-simplex) k)
  (let ((slot (aref (pmap s) k)))
    (aref (fx s) slot)))
;; Stores NV as the objective value at position K of the permutation PMAP.
(defmethod (setf fk) (nv (s nm-simplex) k)
  (let ((slot (aref (pmap s) k)))
    (setf (aref (fx s) slot) nv)))
;; Reorders PMAP so that position 0 addresses the vertex with the
;; smallest objective value and the last position the largest.
;; Returns the sorted permutation vector.
(defmethod sort-simplex ((s nm-simplex))
  ;; SORT is destructive and the standard does not promise that the
  ;; sorted vector is the argument itself, so store the result back.
  (setf (pmap s)
        (sort (pmap s) #'< :key #'(lambda (k) (aref (fx s) k)))))
;; One less than the number of entries in PMAP.
(defmethod dimension ((s nm-simplex))
  (let ((vertex-count (length (pmap s))))
    (- vertex-count 1)))
;; Simple additive simplex generator.
;;
;; Vertex 0 is X0 itself (not a copy).  Vertex K+1 is a copy of X0 with
;; coordinate K increased by DISPLACE, or by element K of DISPLACE when
;; DISPLACE is not a number.  The result has no function values (FX is
;; NIL) and an identity permutation.
(defun initial-simplex (x0 &key (displace 0.1d0))
  (let* ((dim (length x0))
         (npoints (+ dim 1))
         (vertices (make-array npoints))
         (order (make-array npoints :element-type 'fixnum)))
    (setf (aref vertices 0) x0
          (aref order 0) 0)
    (loop :for k :from 0 :below dim
          :do (let ((slot (+ k 1))
                    (vertex (copy-seq x0)))
                (incf (aref vertex k)
                      (if (numberp displace)
                          displace
                          (aref displace k)))
                (setf (aref vertices slot) vertex
                      (aref order slot) slot)))
    (make-instance 'nm-simplex :x vertices :fx nil :pmap order)))
;; The simplex generator from the article, more or less: coordinate K
;; is displaced by 5% of its magnitude, but by at least 0.00025.
(defun default-initial-simplex (x0)
  (flet ((displacement (coordinate)
           (max 0.00025d0
                (abs (* coordinate 0.05d0)))))
    (let ((displacements (map 'vector #'displacement x0)))
      (initial-simplex x0 :displace displacements))))
;; Ensures the simplex has objective values.  If FX is already set the
;; simplex is returned untouched; otherwise F is called once per vertex
;; and the results are installed as FX.  Returns S.  The simplex is not
;; sorted here.
(defmethod maybe-fill-simplex ((s nm-simplex) f)
  (if (fx s)
      s
      (let* ((n (length (x s)))
             (fx (make-array n :element-type 'double-float)))
        (loop :for i :from 0 :below n :do
          ;; XK reads the vertex through PMAP, so the value must be
          ;; stored at the same permuted index; otherwise FK would pair
          ;; vertices with the wrong values whenever PMAP is not the
          ;; identity.
          (setf (aref fx (aref (pmap s) i)) (funcall f (xk s i))))
        (setf (fx s) fx)
        s)))
;; Substitutes the worst point with X / FX, drops the cached data and
;; re-sorts.  Assumes the simplex is sorted on entry; the worst position
;; is taken to be (length X).
(defmethod improve ((s nm-simplex) x fx)
  (let ((worst (length x)))
    (setf (xk s worst) x)
    (setf (fk s worst) fx)
    (setf (data s) nil))
  (sort-simplex s))
;; Returns slot SLOT of the simplex's cache when it holds a non-NIL
;; value.  Otherwise makes sure a cache object exists, calls COMPUTER
;; with no arguments, stores its result in the slot and returns it.
(defmethod cached-slot ((s nm-simplex) slot computer)
  (let ((cache (data s)))
    (cond ((and cache (slot-value cache slot)))
          (t
           (unless cache
             (setf (data s) (make-instance 'cached-simplex-data)))
           (setf (slot-value (data s) slot) (funcall computer))))))
;; Cached average of the vertices at positions 1..n (every vertex except
;; position 0), where n is one less than the number of vertices.
(defmethod pseudopivot ((s nm-simplex))
  (flet ((compute ()
           (let* ((n (- (length (pmap s)) 1))
                  (mean (make-array n
                                    :element-type 'double-float
                                    :initial-element 0.0d0)))
             (loop :for k :from 1 :to n :do
               (setf mean (v+w*c mean (xk s k) (/ 1.0d0 n))))
             mean)))
    (cached-slot s 'pseudopivot #'compute)))
;; Cached array whose element I is built by V+W*C from vertex I+1,
;; vertex 0 and the factor -1.
(defmethod side-vectors ((s nm-simplex))
  (flet ((compute ()
           (let* ((n (- (length (pmap s)) 1))
                  (sides (make-array n)))
             (loop :for k :from 1 :to n :do
               (setf (aref sides (- k 1))
                     (v+w*c (xk s k) (xk s 0) -1)))
             sides)))
    (cached-slot s 'side-vectors #'compute)))
;; Builds a square matrix whose columns are the vectors in SIDEV,
;; ordered by decreasing NORM, and hands it to QR-FACTORIZATION.
;; Returns three values: the second and first values of
;; QR-FACTORIZATION (bound here as Q and R), and the column permutation.
(defun simplex-qr-thing (sidev)
  (let* ((n (length sidev))
         (norms (make-array n :element-type 'double-float))
         (pmap (make-array n :element-type 'fixnum))
         (mat (make-array (list n n) :element-type 'double-float)))
    (dotimes (i n)
      (setf (aref pmap i) i
            (aref norms i) (norm (aref sidev i))))
    ;; SORT is destructive and may return a vector other than its
    ;; argument, so keep the returned value.
    (setf pmap (sort pmap #'> :key #'(lambda (k) (aref norms k))))
    ;; Column J of MAT is the side vector with the J-th largest norm.
    (dotimes (i n)
      (dotimes (j n)
        (setf (aref mat i j)
              (aref (aref sidev (aref pmap j)) i))))
    (multiple-value-bind (r q) (qr-factorization mat)
      (values q r pmap))))
;; Returns a thunk that runs SIMPLEX-QR-THING on the side vectors of S,
;; stores both factors in the cache of S, and returns element N of the
;; list (q r p).
(defun qrthing-closure (s n)
  (lambda ()
    (multiple-value-bind (q r p)
        (simplex-qr-thing (side-vectors s))
      (let ((cache (data s)))
        (setf (slot-value cache 'q-factor) q)
        (setf (slot-value cache 'r-factor) r))
      (elt (list q r p) n))))
;; Cached first value of SIMPLEX-QR-THING for this simplex.
(defmethod q-factor ((s nm-simplex))
  (let ((computer (qrthing-closure s 0)))
    (cached-slot s 'q-factor computer)))
;; Cached second value of SIMPLEX-QR-THING for this simplex.
(defmethod r-factor ((s nm-simplex))
  (let ((computer (qrthing-closure s 1)))
    (cached-slot s 'r-factor computer)))
;; A Nelder-Mead iteration.
;;
;; SIMPLEX is modified in place and returned; F is the objective
;; function, called on candidate points.  The GAMMA_* keywords are the
;; step factors applied along (centroid - worst vertex).  When VERBOSE
;; is true the simplex is printed after the step.
(defun nm-iteration
    (simplex f &key
               verbose
               (gamma_reflect 1.0d0)
               (gamma_expand 2.0d0)
               (gamma_outer_contraction 0.5d0)
               (gamma_inner_contraction -0.5d0)
               (gamma_shrink 0.5d0))
  (let* ((n (- (length (x simplex)) 1))
         (x_cb (make-array n
                           :element-type 'double-float
                           :initial-element 0.0d0))
         (x_cb-x_n (make-array n
                               :element-type 'double-float
                               :initial-element 0.0d0)))
    (labels ((newpoint (gamma)
               ;; Candidate point and its objective value.
               (let ((np (v+w*c x_cb x_cb-x_n gamma)))
                 (values np (funcall f np))))
             (accept (xx ff)
               (improve simplex xx ff))
             (shrink ()
               (let ((x0 (xk simplex 0)))
                 (loop for i from 1 to n do
                   (let* ((newx (v+w*c (v*c x0
                                            (- 1.0d0 gamma_shrink))
                                       (xk simplex i) gamma_shrink))
                          (newf (funcall f newx)))
                     (setf (xk simplex i) newx
                           (fk simplex i) newf)))
                 ;; Sort only once every vertex has been moved.  Sorting
                 ;; inside the loop changes which vertex index I refers
                 ;; to, so some vertices would be shrunk twice and
                 ;; others not at all.
                 (sort-simplex simplex))))
      ;; compute centroid
      (dotimes (i n)
        (setf x_cb (v+w*c x_cb (xk simplex i) (/ 1.0d0 n))))
      (setf x_cb-x_n (v+w*c x_cb (xk simplex n) -1.0d0))
      ;; 2. Reflect
      (multiple-value-bind (xr fr)
          (newpoint gamma_reflect)
        (if (and (<= (fk simplex 0) fr) (< fr (fk simplex (- n 1))))
            (accept xr fr)
            ;; 3. expand
            (if (< fr (fk simplex 0))
                (multiple-value-bind (xe fe)
                    (newpoint gamma_expand)
                  (if (< fe fr)
                      (accept xe fe)
                      (accept xr fr)))
                ;; 4. contract or shrink
                (if (<= (fk simplex (- n 1)) fr)
                    (if (< fr (fk simplex n))
                        ;; outer contraction
                        (multiple-value-bind (xc fc)
                            (newpoint gamma_outer_contraction)
                          ;; 1D shrink is equivalent to acceptance of
                          ;; this point.
                          ;; NOTE(review): N is the dimension here, so
                          ;; (= n 2) fires for 2-D problems although the
                          ;; remark speaks of 1-D -- confirm the intent.
                          (if (or (= n 2)
                                  (<= fc fr))
                              (accept xc fc)
                              (shrink)))
                        ;; inner contraction
                        (multiple-value-bind (xcc fcc)
                            (newpoint gamma_inner_contraction)
                          (if (< fcc (fk simplex n))
                              (accept xcc fcc)
                              (shrink))))
                    (shrink)))))
      (when verbose (format t "~S~%" simplex))
      simplex)))
;; The test returns true if the volume of the parallelepiped spanned by
;; the vertices of the simplex is smaller than that of an n-cube of
;; side cside.  The volume is taken as the absolute product of the
;; diagonal of the simplex's R-FACTOR.
(defun pp-volume-test (cside)
  (lambda (simplex)
    (let* ((r (r-factor simplex))
           (dim (dimension simplex))
           (det (loop :with acc = 1.0d0
                      :for i :from 0 :below dim
                      :do (setf acc (* acc (aref r i i)))
                      :finally (return acc))))
      (< (abs det) (expt cside dim)))))
;; Runs NM-ITERATION until CONVERGENCE-P accepts the simplex or more
;; than MAX-FUNCTION-CALLS objective evaluations have been made.
;;
;; INITIAL-GUESS is either an NM-SIMPLEX, used as is, or a point from
;; which DEFAULT-INITIAL-SIMPLEX builds one.
;;
;; Returns four values: the best vertex, its objective value, the
;; simplex, and the number of objective evaluations.
(defun nm-optimize (objective-function initial-guess &key
                    (max-function-calls 100000)
                    (convergence-p (burmen-et-al-convergence-test))
                    verbose)
  (let ((simplex (if (typep initial-guess 'nm-simplex)
                     initial-guess
                     (default-initial-simplex initial-guess)))
        (fvcount 0))
    (labels ((rigged-f (v)
               ;; Counts every objective evaluation.
               (incf fvcount)
               (funcall objective-function v))
             (converged-p (s)
               (or (funcall convergence-p s)
                   (> fvcount max-function-calls))))
      (when verbose
        (format t "Initial simplex: ~%~A~%---~%" simplex))
      (maybe-fill-simplex simplex #'rigged-f)
      ;; The loop body needs an explicit :DO; a bare compound form after
      ;; the :UNTIL clause is not valid extended LOOP syntax.
      (loop :until (converged-p simplex)
            :do (nm-iteration simplex #'rigged-f :verbose verbose))
      (values (xk simplex 0) (fk simplex 0) simplex fvcount))))
;; Returns a copy of POINT in which every coordinate is replaced by
;; z_i + delta_i * floor((x_i - z_i) / delta_i + 0.5), using the grid's
;; origin Z and spacing DELTA.  POINT itself is not modified.
(defmethod restrict ((grid grid) point)
  (let ((snapped (copy-seq point))
        (spacing (delta grid))
        (origin (grid-z grid)))
    (dotimes (i (length point) snapped)
      (let ((h (aref spacing i))
            (base (aref origin i)))
        (setf (aref snapped i)
              (+ (* h
                    (floor (+ (/ (- (aref snapped i) base) h)
                              0.5d0)))
                 base))))))
;; Some parameters.  Look at Burmen et al for further details.
(defparameter *psi* 1.0d-6
  "Factor in the degeneracy threshold tested by MAYBE-RESHAPE.")

(defparameter *biglambda* (/ 0.5d0 double-float-epsilon)
  "Factor in the upper clamp applied to side lengths in MAYBE-RESHAPE.")

(defparameter *tau-r* (* 2.0d0 double-float-epsilon)
  "Relative floor for the grid spacing computed by REGRID.")

(defparameter *tau-a* (expt least-positive-double-float (/ 1.0d0 3.0d0))
  "Absolute floor for the grid spacing computed by REGRID.")

(defparameter *smalllambda* 2
  "Factor in the lower clamp applied to side lengths in MAYBE-RESHAPE;
also used by REGRID and DEEP-SHRINK.")

;; Set to T by MAYBE-RESHAPE and DEEP-SHRINK when the simplex has
;; collapsed; GRNM-OPTIMIZE binds it and treats it as a stop condition.
(defvar *breakdown*)
;; Rebuilds the simplex S around its best vertex when it is degenerate,
;; or unconditionally when FORCE is true.
;;
;; G is the grid new vertices are restricted to; FF is the objective
;; function, called once per rebuilt vertex.  Returns true if the
;; simplex was reshaped, NIL otherwise.  On reshape the cached data of S
;; is replaced by a fresh object carrying DV, DMIN and
;; BEST-BEFORE-RESHAPE, and the simplex is re-sorted.
(defmethod maybe-reshape ((s nm-simplex) (g grid) ff &key force)
  (let ((n (length (side-vectors s)))
        (biglambda *biglambda*)
        (smalllambda *smalllambda*)
        (reshaped-p nil))
    (labels ((degenerate-p ()
               ;; True when the smallest |r_ii| on the diagonal of the
               ;; R factor drops below psi * |delta| * sqrt(n) / 2.
               (let* ((r (r-factor s))
                      (ax (loop for i from 0 below n
                                minimizing (abs (aref r i i)))))
                 (< ax
                    (/ (* *psi*
                          (norm (delta g))
                          (sqrt (float n 1.0d0)))
                       2)))))
      (when (or force (degenerate-p))
        (setf reshaped-p t)
        (let ((r (r-factor s))
              (q (q-factor s))
              (|det| 1.0d0)
              (dv (make-array (+ n 1)))
              (dmin-norm nil)
              (dmin nil)
              (|Delta| (norm (delta g))))
          ;; Element 0 of DV is the best vertex; elements 1..n are the
          ;; new side vectors built below.
          (setf (aref dv 0) (xk s 0))
          (dotimes (i n)
            (let* ((di (make-array n :element-type 'double-float))
                   (rii (aref r i i))
                   (sgn[rii] (if (>= rii 0) 1 -1))
                   (|rii| (abs rii))
                   (quot (* (sqrt (float n 0.0d0))
                            |Delta|
                            0.5d0))
                   ;; Clamp |r_ii| into
                   ;; [smalllambda * quot, biglambda * quot].
                   (minc (min |rii|
                              (* biglambda quot)))
                   (maxc (max (* smalllambda quot) minc))
                   (dfkt (* sgn[rii] maxc)))
              ;; Accumulate the product of the |r_ii|.
              (setf |det| (* |det| |rii|))
              ;; New side vector: column I of Q scaled by DFKT.
              (dotimes (j n)
                (setf (aref di j)
                      (* dfkt (aref q j i))))
              ;; Remember the side vector with the smallest scale.
              (when (or (not dmin-norm) (< (abs dfkt) dmin-norm))
                (setf dmin-norm (abs dfkt)
                      dmin di))
              (setf (aref dv (+ i 1)) di)
              ;; Replace vertex I+1 by best + DI, restricted to the
              ;; grid, and evaluate it.
              (let* ((nxi (restrict g (v+w*c (xk s 0) di 1.0d0)))
                     (fxi (funcall ff nxi)))
                (setf (xk s (+ i 1)) nxi
                      (fk s (+ i 1)) fxi))))
          ;; Looks like a very extreme situation, but can actually
          ;; happen.
          (when (= |det| 0.0d0)
            (setf *breakdown* t))
          ;; Most cached data is invalid, but better keep the reshape
          ;; data, which might be needed for shrinking the simplex.
          (setf (data s)
                (make-instance 'cached-simplex-data
                               :dv dv
                               :dmin dmin
                               :best-before-reshape (fk s 0)))
          (sort-simplex s))))
    reshaped-p))
;;; Nelder Mead iteration variant from Burmen et al. Does not shrink,
;;; "failing" instead. Acceptance criteria for the contraction points
;;; are stricter.
;;;
;;; Iterates are restricted to a grid.
;;;
;;; Returns two values: the simplex (modified in place and sorted) and
;;; FAILURE, which is true when no candidate point was accepted.
(defun nm-iteration-burmen-et-al
    (simplex f grid &key
                    verbose
                    (gamma_reflect 1.0d0)
                    (gamma_expand 2.0d0)
                    (gamma_outer_contraction 0.5d0)
                    (gamma_inner_contraction -0.5d0))
  (let ((n (- (length (x simplex)) 1))
        (failure t))
    (let ((x_cb (make-array n
                            :element-type 'double-float
                            :initial-element 0.0d0))
          (x_cb-x_n (make-array n
                                :element-type 'double-float
                                :initial-element 0.0d0)))
      (labels ((newpoint (gamma)
                 ;; Candidate point, restricted to the grid, and its
                 ;; objective value.
                 (let ((np (restrict grid
                                     (v+w*c x_cb x_cb-x_n gamma))))
                   (values np (funcall f np))))
               (accept (xx ff)
                 (improve simplex xx ff)
                 (setf failure nil)))
        ;; compute centroid
        (dotimes (i n)
          (setf x_cb (v+w*c x_cb (xk simplex i) (/ 1.0d0 n))))
        (setf x_cb-x_n (v+w*c x_cb (xk simplex n) -1.0d0))
        ;; 2. Reflect
        (multiple-value-bind (xr fr)
            (newpoint gamma_reflect)
          (if (and (<= (fk simplex 0) fr) (< fr (fk simplex (- n 1))))
              (accept xr fr)
              ;; 3. expand
              (if (< fr (fk simplex 0))
                  (multiple-value-bind (xe fe)
                      (newpoint gamma_expand)
                    (if (< fe fr)
                        (accept xe fe)
                        (accept xr fr)))
                  ;; 4. contract or fail.  Burmen & al have swapped
                  ;; inner and outers - maybe a typo. (?) No, just a
                  ;; different variant.
                  (if (<= (fk simplex (- n 1)) fr)
                      (if (< fr (fk simplex n))
                          ;; outer contraction
                          ;; The inner IF and this MULTIPLE-VALUE-BIND
                          ;; are closed here, so inner contraction is
                          ;; the else branch of the (< fr ...) test, as
                          ;; in NM-ITERATION.
                          (multiple-value-bind (xc fc)
                              (newpoint gamma_outer_contraction)
                            (if (<= fc (fk simplex (- n 1)))
                                (accept xc fc)))
                          ;; inner contraction
                          (multiple-value-bind (xcc fcc)
                              (newpoint gamma_inner_contraction)
                            (if (< fcc (fk simplex (- n 1)))
                                (accept xcc fcc))))))))))
    (sort-simplex simplex)
    (when verbose (format t "~S~%" simplex))
    (values simplex failure)))
;; Convergence test used in the article.
;;
;; Returns a predicate on a simplex that is true when both hold:
;;   - the spread between the values at positions 0 and (dimension s)
;;     is below max(TOL-F, REL * |f_best|);
;;   - the largest absolute component of any side vector is below
;;     max(TOL-X, REL * largest absolute coordinate of the best vertex).
(defun burmen-et-al-convergence-test (&key
                                      (tol-x 1.0d-8)
                                      (tol-f 1.0d-15)
                                      (rel 1.0d-15))
  #'(lambda (s)
      (let* ((sv (side-vectors s))
             (fbest (fk s 0))
             (fdiff (abs (- fbest (fk s (dimension s)))))
             (vijmax (loop for v across sv maximizing
                           (loop for x across v maximizing
                                 (abs x))))
             (xbest (xk s 0))
             (|xi|max (loop for z across xbest maximizing (abs z))))
        ;; The relative tolerance scales with the magnitude of f_best.
        ;; Without ABS it turns negative for negative objective values
        ;; and has no effect, unlike the x tolerance below.
        (and (< fdiff (max tol-f (* rel (abs fbest))))
             (< vijmax (max tol-x (* rel |xi|max)))))))
;; Moves the origin of grid G to X1 and recomputes its spacing from the
;; side vector DMIN scaled by FKT.  The new spacing never exceeds the
;; old one in any coordinate, and is bounded below by *TAU-R* times the
;; magnitude of the matching coordinate of X1 and by *TAU-A*.
;; Returns the new spacing vector.
(defun regrid (g x1 dmin fkt)
  (let* ((dmin (v*c dmin fkt))
         (n (length dmin))
         (|dmin| (norm dmin))
         (delta (delta g))
         (newd (copy-seq delta)))
    (setf (grid-z g) x1)
    (dotimes (i n)
      (setf (aref newd i)
            (max (min (max (/ (abs (aref dmin i))
                              (* *smalllambda* 250 n))
                           (/ |dmin|
                              (* *smalllambda* 250 (expt n (/ 3.0d0 2.0d0)))))
                      (aref delta i))
                 ;; The relative floor uses the magnitude of the
                 ;; coordinate; a signed value would cancel the floor
                 ;; for negative coordinates.
                 (* *tau-r* (abs (aref x1 i)))
                 *tau-a*)))
    (setf (delta g) newd)))
;; Repeatedly rebuilds the simplex S from the reshape data (DV, DMIN),
;; scaling the side vectors by factors of shrinking magnitude and
;; alternating sign, until a vertex beats the value the best vertex had
;; on entry or CONVERGENCE-P accepts the simplex.
;;
;; F is the objective function, G the grid, GAMMA_S the shrink factor.
;; When VERBOSE is true each factor is printed.  The cached data of S
;; is cleared before returning.  Returns the last value of CONVERGENCE-P
;; (NIL when the loop stopped because of an improvement).
(defun deep-shrink (f s g gamma_s convergence-p verbose)
  (let ((bbr (fk s 0)))
    ;; The reshape data may be missing, either because there is no cache
    ;; object at all or because it carries no DV.  Test for the cache
    ;; first: DV has no method for NIL.
    (unless (and (data s) (dv (data s)))
      (maybe-reshape s g f :force t))
    (let* ((dv (dv (data s)))
           (n (dimension s))
           (dmin (dmin (data s)))
           (|dmin| (norm dmin))
           (converged-p nil)
           (l 1)) ;; This does the same as in the article
      (loop :until (or (< (fk s 0) bbr)
                       (setf converged-p (funcall convergence-p s)))
            :do (let ((fkt (* (expt gamma_s (floor l 2))
                              (expt -1 l))))
                  (when verbose
                    (format t "Shrinking by factor: ~A~%" fkt))
                  (when (= 0.0d0 fkt)
                    ;; You've got no simplex anymore
                    (setf *breakdown* t))
                  ;; Refine the grid once the scaled side vector gets
                  ;; short relative to the grid spacing.  The square
                  ;; root is taken in double precision, as in
                  ;; MAYBE-RESHAPE.
                  (when (< (* (abs fkt) |dmin|)
                           (* (/ *smalllambda* 2)
                              (sqrt (float n 1.0d0))
                              (norm (delta g))))
                    (regrid g (xk s 0) dmin fkt))
                  ;; Put the reshape centre back at position 0 ...
                  (setf (xk s 0) (aref dv 0)
                        (fk s 0) bbr)
                  ;; ... and rebuild every other vertex around it.
                  (loop for k from 1 below (length (x s)) do
                    (setf (xk s k) (restrict g
                                             (v+w*c (aref dv 0)
                                                    (aref dv k) fkt))
                          (fk s k) (funcall f (xk s k))))
                  (sort-simplex s)
                  (incf l)))
      (setf (data s) nil)
      converged-p)))
;; Grid-restrained Nelder-Mead minimisation.
;;
;; INITIAL-GUESS is either an NM-SIMPLEX, used as is, or a point from
;; which DEFAULT-INITIAL-SIMPLEX builds one.  Iteration stops when
;; CONVERGENCE-P accepts the simplex, when *BREAKDOWN* is set, or when
;; MAX-FUNCTION-CALLS is non-NIL and has been exceeded.
;;
;; Returns four values: the best vertex, its objective value, the
;; simplex, and the number of objective evaluations.
(defun grnm-optimize (objective-function initial-guess &key
                      max-function-calls
                      (convergence-p (burmen-et-al-convergence-test))
                      verbose)
  (let ((simplex (if (typep initial-guess 'nm-simplex)
                     initial-guess
                     (default-initial-simplex initial-guess)))
        (fvcount 0)
        (gamma_s 0.5d0)
        (*breakdown* nil))
    (labels ((rigged-f (v)
               ;; Counts every objective evaluation.
               (incf fvcount)
               (funcall objective-function v))
             (converged-p (s)
               (or (funcall convergence-p s)
                   *breakdown*
                   (and max-function-calls (> fvcount max-function-calls)))))
      (when verbose
        (format t "Initial simplex: ~%~A~%---~%" simplex))
      (maybe-fill-simplex simplex #'rigged-f)
      (prog (failure xbest fbest pbest reshaped-p
             ;; Initial grid: origin at the first vertex, uniform
             ;; spacing of one tenth of the shortest side.
             (grid (make-instance 'grid
                                  :z (xk simplex 0)
                                  :delta
                                  (make-array (dimension simplex)
                                              :initial-element
                                              (/ (loop :for i :from 1 :to
                                                       (dimension simplex)
                                                       :minimizing
                                                       (norm
                                                        (v+w*c
                                                         (xk simplex 0)
                                                         (xk simplex i) -1.0d0)))
                                                 10.0d0)))))
       iterate ;; 1.
         (setf (values simplex failure)
               (nm-iteration-burmen-et-al simplex #'rigged-f grid
                                          :verbose verbose))
         ;; Small variation. We test for convergence here too. Depending on
         ;; the convergence criterion, this might be spurious, so we have to
         ;; reshape, etc.
         (unless (or failure (converged-p simplex))
           (go iterate))
         ;; 2.
         (setf xbest (xk simplex 0)
               fbest (fk simplex 0)
               pbest (aref (pmap simplex) 0)
               reshaped-p (maybe-reshape simplex grid #'rigged-f))
         ;; 4. Here we look at the pseudo expand point
         (let* ((pep (pseudopivot simplex))
                (xx (v+w*c pep
                           (v+w*c pep (xk simplex 0) -1.0d0)
                           1.0d0)) ;; To do: clear this up
                (fpep (rigged-f xx)))
           (if (<= fbest (min (fk simplex 0) fpep))
               (go deep-shrink)
               (progn
                 (when (< fpep fbest)
                   ;; There is something subtle here. The pseudo-expand point
                   ;; is supposed to substitute the old (xk 0), which might
                   ;; have changed *position* during reshape, because of the
                   ;; simplex being sorted.
                   ;;
                   ;; So it is not
                   ;;(improve simplex xx fpep)
                   ;; here, but instead
                   ;;
                   ;; Step 5.
                   (setf (aref (x simplex) pbest) xx
                         (aref (fx simplex) pbest) fpep
                         (data simplex) nil)
                   (sort-simplex simplex))
                 ;; Step 6.  Reaching this point means the reshape or
                 ;; the pseudo-expand point improved on FBEST, so go
                 ;; back to iterating in both cases.  Falling through
                 ;; into the deep shrink when only the reshape improved
                 ;; would overwrite the new best vertex with the old one
                 ;; and label it with the new best value.
                 (go iterate))))
       deep-shrink
         (unless
             (deep-shrink #'rigged-f simplex grid
                          ;;(+ 0.1d0 (random 0.8d0))
                          gamma_s
                          #'converged-p verbose)
           (go iterate))
       end)
      (values (xk simplex 0) (fk simplex 0) simplex fvcount))))
;; Some sample functions to play around with.

;; Sum of the squares of the elements of the vector V.
(defun standard-quadratic (v)
  (loop for coordinate across v
        summing (expt coordinate 2)))
;; Rosenbrock's function of the first two elements of V:
;; (1 - x)^2 + 100 (y - x^2)^2.
(defun rosenbrock (v)
  (let* ((x (aref v 0))
         (y (aref v 1))
         (offset (expt (- 1 x) 2))
         (valley (expt (- y (expt x 2)) 2)))
    (+ offset (* 100 valley))))