autotest: set parameter value so context pop resets it
[ardupilot.git] / Tools / FilterTestTool / FilterTest.py
bloba9a39fe11de99bb2824224c3aa9be6b7d8425b5b
1 # -*- coding: utf-8 -*-
3 """ ArduPilot IMU Filter Test Class
5 This program is free software: you can redistribute it and/or modify it under
6 the terms of the GNU General Public License as published by the Free Software
7 Foundation, either version 3 of the License, or (at your option) any later
8 version.
9 This program is distributed in the hope that it will be useful, but WITHOUT
10 ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
11 FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
12 You should have received a copy of the GNU General Public License along with
13 this program. If not, see <http://www.gnu.org/licenses/>.
14 """
16 __author__ = "Guglielmo Cassinelli"
17 __contact__ = "gdguglie@gmail.com"
19 import numpy as np
20 import matplotlib.pyplot as plt
21 from matplotlib.widgets import Slider
22 from matplotlib.animation import FuncAnimation
23 from scipy import signal
24 from BiquadFilter import BiquadFilterType, BiquadFilter
# Module-level list that FilterTest.create_slider appends every Slider widget to.
# The original author notes matplotlib sliders must be global (presumably so the
# widgets are not garbage collected while the figure is shown -- TODO confirm).
sliders = []  # matplotlib sliders must be global
# Module-level handle assigned by FilterTest.init_plot to the FuncAnimation it
# creates; kept global for the same stated reason as the sliders above.
anim = None  # matplotlib animations must be global
class FilterTest:
    """Interactive matplotlib tool to preview biquad filtering of IMU logs.

    Constructing an instance builds the accelerometer and gyroscope filter
    chains, opens a figure with raw/filtered signal plots and averaged STFT
    spectra, adds sliders controlling every filter parameter, and blocks in
    ``plt.show()``.  When the window is closed the parameter values matching
    the final slider settings are printed.
    """

    FILTER_DEBOUNCE = 10  # ms

    FILT_SHAPE_DT_FACTOR = 1  # increase to reduce filter shape size

    FFT_N = 512

    def __init__(self, acc_t, acc_x, acc_y, acc_z, gyr_t, gyr_x, gyr_y, gyr_z, acc_freq, gyr_freq,
                 acc_lpf_cutoff, gyr_lpf_cutoff,
                 acc_notch_freq, acc_notch_att, acc_notch_band,
                 gyr_notch_freq, gyr_notch_att, gyr_notch_band,
                 log_name, accel_notch=False, second_notch=False):
        """Build the filter chains, store the samples and open the plot.

        ``acc_*`` / ``gyr_*`` are the time stamps and per-axis samples of the
        accelerometer and gyroscope, ``*_freq`` their sample rates,
        ``*_lpf_cutoff`` the low-pass cutoffs and ``*_notch_*`` the notch
        settings.  ``accel_notch`` adds a notch to the accelerometer chain and
        ``second_notch`` adds a notch at twice the notch frequency to both
        chains.  ``log_name`` is only used in the window title.
        """
        self.filter_color_map = plt.get_cmap('summer')

        # Per-instance mapping: a class-level dict would be shared (and
        # overwritten) by every FilterTest instance.
        self.filters = {}

        self.filters["acc"] = [
            BiquadFilter(acc_lpf_cutoff, acc_freq)
        ]

        if accel_notch:
            self.filters["acc"].append(
                BiquadFilter(acc_notch_freq, acc_freq, BiquadFilterType.PEAK, acc_notch_att, acc_notch_band),
            )

        self.filters["gyr"] = [
            BiquadFilter(gyr_lpf_cutoff, gyr_freq),
            BiquadFilter(gyr_notch_freq, gyr_freq, BiquadFilterType.PEAK, gyr_notch_att, gyr_notch_band)
        ]

        if second_notch:
            self.filters["acc"].append(
                BiquadFilter(acc_notch_freq * 2, acc_freq, BiquadFilterType.PEAK, acc_notch_att, acc_notch_band)
            )
            self.filters["gyr"].append(
                BiquadFilter(gyr_notch_freq * 2, gyr_freq, BiquadFilterType.PEAK, gyr_notch_att, gyr_notch_band)
            )

        self.ACC_t = acc_t
        self.ACC_x = acc_x
        self.ACC_y = acc_y
        self.ACC_z = acc_z

        self.GYR_t = gyr_t
        self.GYR_x = gyr_x
        self.GYR_y = gyr_y
        self.GYR_z = gyr_z

        self.GYR_freq = gyr_freq
        self.ACC_freq = acc_freq

        self.gyr_dt = 1. / gyr_freq
        self.acc_dt = 1. / acc_freq

        self.timer = None

        self.updated_artists = []

        # INIT
        self.init_plot(log_name)

    def test_acc_filters(self):
        """Return the filtered (x, y, z) accelerometer sample lists."""
        chain = self.filters["acc"]
        return (
            self.test_filters(chain, self.ACC_t, self.ACC_x),
            self.test_filters(chain, self.ACC_t, self.ACC_y),
            self.test_filters(chain, self.ACC_t, self.ACC_z),
        )

    def test_gyr_filters(self):
        """Return the filtered (x, y, z) gyroscope sample lists."""
        chain = self.filters["gyr"]
        return (
            self.test_filters(chain, self.GYR_t, self.GYR_x),
            self.test_filters(chain, self.GYR_t, self.GYR_y),
            self.test_filters(chain, self.GYR_t, self.GYR_z),
        )

    def test_filters(self, filters, Ts, Xs):
        """Reset ``filters`` and run every sample of ``Xs`` through the chain.

        One output sample is produced per entry of ``Ts``; each sample is
        passed through the filters in list order.  Returns a list.
        """
        for f in filters:
            f.reset()

        x_filtered = []
        for i, _t in enumerate(Ts):
            sample = Xs[i]
            for filt in filters:
                sample = filt.apply(sample)
            x_filtered.append(sample)

        return x_filtered

    def get_filter_shape(self, filter):
        """Return ``(frequencies, response)`` describing ``filter``.

        The frequency axis spans 0 .. sample_rate // 2; its resolution is
        derived from the sample rate and ``FILT_SHAPE_DT_FACTOR``.
        """
        samples = int(filter.get_sample_freq())  # resolution of filter shape based on sample rate
        x_space = np.linspace(0.0, samples // 2, samples // int(2 * self.FILT_SHAPE_DT_FACTOR))
        return x_space, filter.freq_response(x_space)

    def init_signal_plot(self, ax, Ts, Xs, Ys, Zs, Xs_filtered, Ys_filtered, Zs_filtered, label):
        """Plot raw and filtered x/y/z signals on ``ax``.

        Returns the three line artists of the filtered signals so they can be
        updated later.
        """
        # raw signals are drawn semi-transparent underneath the filtered ones
        for axis_name, samples in (("X", Xs), ("Y", Ys), ("Z", Zs)):
            ax.plot(Ts, samples, linewidth=1, label="{}{}".format(label, axis_name), alpha=0.5)

        filtered_x_ax, = ax.plot(Ts, Xs_filtered, linewidth=1, label="{}X filtered".format(label), alpha=1)
        filtered_y_ax, = ax.plot(Ts, Ys_filtered, linewidth=1, label="{}Y filtered".format(label), alpha=1)
        filtered_z_ax, = ax.plot(Ts, Zs_filtered, linewidth=1, label="{}Z filtered".format(label), alpha=1)
        ax.legend(prop={'size': 8})
        return filtered_x_ax, filtered_y_ax, filtered_z_ax

    def fft_to_xdata(self, fft):
        """Return the normalised magnitude of the first half of ``fft``."""
        n = len(fft)
        norm_factor = 2. / n
        return norm_factor * np.abs(fft[:n // 2])

    def plot_fft(self, ax, x, fft, label):
        """Plot the normalised magnitude of ``fft`` against ``x``; return the line."""
        fft_ax, = ax.plot(x, self.fft_to_xdata(fft), label=label)
        return fft_ax

    def _mean_spectrum(self, samples, sample_rate):
        """Return ``(freqs, magnitude)``: the STFT magnitude averaged over time."""
        freqs, _times, stft = signal.stft(samples, sample_rate, window='hann', nperseg=self.FFT_N)
        return freqs, np.average(np.abs(stft), axis=1)

    def init_fft(self, ax, Ts, Xs, Ys, Zs, sample_rate, dt, Xs_filtered, Ys_filtered, Zs_filtered, label):
        """Plot averaged spectra of the raw and filtered signals on ``ax``.

        ``Ts`` and ``dt`` are accepted for interface compatibility but are not
        used.  Returns the three line artists of the filtered spectra.
        """
        raw_spectra = [self._mean_spectrum(samples, sample_rate) for samples in (Xs, Ys, Zs)]
        freqs_x, filtered_fft_x = self._mean_spectrum(Xs_filtered, sample_rate)
        freqs_y, filtered_fft_y = self._mean_spectrum(Ys_filtered, sample_rate)
        freqs_z, filtered_fft_z = self._mean_spectrum(Zs_filtered, sample_rate)

        # raw spectra first so the artist/colour order matches the signal plot
        for axis_name, (freqs, magnitude) in zip("xyz", raw_spectra):
            ax.plot(freqs, magnitude, alpha=0.5, linewidth=1, label="{}{} FFT".format(label, axis_name))

        filtered_fft_ax_x, = ax.plot(freqs_x, filtered_fft_x, label="filt. {}x FFT".format(label))
        filtered_fft_ax_y, = ax.plot(freqs_y, filtered_fft_y, label="filt. {}y FFT".format(label))
        filtered_fft_ax_z, = ax.plot(freqs_z, filtered_fft_z, label="filt. {}z FFT".format(label))

        ax.set_xlabel("frequency")
        ax.legend(prop={'size': 8})

        return filtered_fft_ax_x, filtered_fft_ax_y, filtered_fft_ax_z

    def init_filter_shape(self, ax, filter, color):
        """Draw the response curve of ``filter`` and a dashed line at its center.

        Returns ``(shape_line, center_line)``.
        """
        center = filter.get_center_freq()
        x_space, lpf_shape = self.get_filter_shape(filter)

        plot_slpf_shape, = ax.plot(x_space, lpf_shape, c=color, label="LPF shape")
        xvline_lpf_cutoff = ax.axvline(x=center, linestyle="--", c=color)  # LPF cutoff freq

        return plot_slpf_shape, xvline_lpf_cutoff

    def create_slider(self, name, rect, max, value, color, callback):
        """Add a slider to the figure and register it in the global list.

        The slider position is mapped quadratically to the reported value
        (``value = position ** 2 / max``) to give finer control over small
        values; ``callback`` receives the mapped integer value.
        """
        global sliders
        ax_slider = self.fig.add_axes(rect, facecolor='lightgoldenrodyellow')
        slider = Slider(ax_slider, name, 0, max, valinit=np.sqrt(max * value), valstep=1, color=color)
        slider.valtext.set_text(value)

        def changed(val):
            # non linear slider to better control small values
            mapped = int(val ** 2 / max)
            slider.valtext.set_text(mapped)
            callback(mapped)

        slider.on_changed(changed)
        sliders.append(slider)

    def delay_update(self, update_cbk):
        """Call ``update_cbk`` after ``FILTER_DEBOUNCE`` ms, cancelling any pending call."""
        if not self.fig:
            return

        def _fire():
            self.timer.stop()
            update_cbk()

        # delay actual filtering: restart the debounce timer on every request
        if self.timer:
            self.timer.stop()
        self.timer = self.fig.canvas.new_timer(interval=self.FILTER_DEBOUNCE)
        self.timer.add_callback(_fire)
        self.timer.start()

    def update_filter_shape(self, filter, shape, center_line):
        """Refresh the response curve and center marker of ``filter``."""
        _x_space, new_shape = self.get_filter_shape(filter)
        center = filter.get_center_freq()

        shape.set_ydata(new_shape)
        # Line2D.set_xdata expects a sequence; the vertical line has two points
        center_line.set_xdata([center, center])

        self.updated_artists.extend([
            shape,
            center_line,
        ])

    def update_signal_and_fft_plot(self, filters_key, time_list, sample_lists, signal_shapes, fft_shapes, shape,
                                   center_line, sample_rate):
        """Re-run the ``filters_key`` chain and refresh signal and spectrum lines."""
        chain = self.filters[filters_key]
        filtered = [self.test_filters(chain, time_list, samples) for samples in sample_lists]

        for line, samples in zip(signal_shapes, filtered):
            line.set_ydata(samples)
        self.updated_artists.extend(list(signal_shapes))

        for line, samples in zip(fft_shapes, filtered):
            _freqs, magnitude = self._mean_spectrum(samples, sample_rate)
            line.set_ydata(magnitude)

        self.updated_artists.extend(list(fft_shapes) + [shape, center_line])

    def animation_update(self):
        """Return the artists changed since the last call and clear the list."""
        updated_artists = self.updated_artists.copy()

        # reset updated artists
        self.updated_artists = []

        return updated_artists

    def update_filter(self, val, cbk, filter, shape, center_line, filters_key, time_list, sample_lists, signal_shapes,
                      fft_shapes):
        """Apply a slider change: set the parameter, redraw the shape, schedule refiltering."""
        # this callback sets the parameter controlled by the slider
        cbk(val)
        # update filter shape and delay fft update
        self.update_filter_shape(filter, shape, center_line)
        sample_freq = filter.get_sample_freq()
        self.delay_update(
            lambda: self.update_signal_and_fft_plot(filters_key, time_list, sample_lists, signal_shapes,
                                                    fft_shapes, shape, center_line, sample_freq))

    def create_filter_control(self, name, filter, rect, max, default, shape, center_line, cbk, filters_key, time_list,
                              sample_lists, signal_shapes, fft_shapes, filt_color):
        """Create one slider whose changes are routed through ``update_filter``."""
        def on_value(val):
            self.update_filter(val, cbk, filter, shape, center_line, filters_key,
                               time_list, sample_lists, signal_shapes, fft_shapes)

        self.create_slider(name, rect, max, default, filt_color, on_value)

    def create_controls(self, filters_key, base_rect, padding, ax_fft, time_list, sample_lists, signal_shapes,
                        fft_shapes):
        """Create sliders and response curves for every filter of ``filters_key``.

        ``base_rect`` is the rectangle of the first slider; it is modified in
        place, moving down by the slider height plus ``padding`` after each
        control.
        """
        ax_filter = ax_fft.twinx()
        ax_filter.set_navigate(False)
        ax_filter.set_yticks([])

        chain = self.filters[filters_key]
        num_filters = len(chain)

        for i, filter in enumerate(chain):
            is_notch = filter.get_type() == BiquadFilterType.PEAK
            filt_color = self.filter_color_map(i / num_filters)
            filt_shape, filt_cutoff = self.init_filter_shape(ax_filter, filter, filt_color)
            name = "Notch" if is_notch else "LPF"

            # control for center freq is common to all filters
            controls = [("{} freq".format(name), 500, filter.get_center_freq(), filter.set_center_freq)]
            if is_notch:
                controls.append(("{} att (db)".format(name), 100, filter.get_attenuation(), filter.set_attenuation))
                controls.append(("{} band".format(name), 300, filter.get_bandwidth(), filter.set_bandwidth))

            for title, upper, current, setter in controls:
                self.create_filter_control(title, filter, base_rect, upper, current,
                                           filt_shape, filt_cutoff, setter,
                                           filters_key, time_list, sample_lists, signal_shapes, fft_shapes,
                                           filt_color)
                # move down of control height + padding
                base_rect[1] -= (base_rect[3] + padding)

    def create_spectrogram(self, data, name, sample_rate):
        """Open a new figure with the spectrogram (in dB) of ``data``."""
        # 'hann' is the window name used by the rest of this class
        freqs, times, Sx = signal.spectrogram(np.array(data), fs=sample_rate, window='hann',
                                              nperseg=self.FFT_N, noverlap=self.FFT_N - self.FFT_N // 10,
                                              detrend=False, scaling='spectrum')

        f, ax = plt.subplots(figsize=(4.8, 2.4))
        ax.pcolormesh(times, freqs, 10 * np.log10(Sx), cmap='viridis')
        ax.set_title(name)
        ax.set_ylabel('Frequency (Hz)')
        ax.set_xlabel('Time (s)')

    def init_plot(self, log_name):
        """Build the figure, plots, controls and animation, then show the window.

        Blocks in ``plt.show()``; afterwards the parameter summary is printed.
        """
        self.fig = plt.figure(figsize=(14, 9))
        title = "ArduPilot Filter Test Tool - {}".format(log_name)
        # canvas.set_window_title was deprecated in favour of the manager
        # method; keep the old call as a fallback for older matplotlib.
        manager = getattr(self.fig.canvas, "manager", None)
        if manager is not None and hasattr(manager, "set_window_title"):
            manager.set_window_title(title)
        else:
            self.fig.canvas.set_window_title(title)
        self.fig.canvas.draw()

        rows = 2
        cols = 3
        raw_acc_index = 1
        fft_acc_index = raw_acc_index + 1
        raw_gyr_index = cols + 1
        fft_gyr_index = raw_gyr_index + 1

        # signal
        self.ax_acc = self.fig.add_subplot(rows, cols, raw_acc_index)
        self.ax_gyr = self.fig.add_subplot(rows, cols, raw_gyr_index, sharex=self.ax_acc)

        # The label is only the sensor prefix: the plotting helpers append the
        # axis letter themselves ("AccX" here produced "AccXX", "AccXY", ...).
        accx_filtered, accy_filtered, accz_filtered = self.test_acc_filters()
        self.ax_filtered_accx, self.ax_filtered_accy, self.ax_filtered_accz = self.init_signal_plot(
            self.ax_acc, self.ACC_t, self.ACC_x, self.ACC_y, self.ACC_z,
            accx_filtered, accy_filtered, accz_filtered, "Acc")

        gyrx_filtered, gyry_filtered, gyrz_filtered = self.test_gyr_filters()
        self.ax_filtered_gyrx, self.ax_filtered_gyry, self.ax_filtered_gyrz = self.init_signal_plot(
            self.ax_gyr, self.GYR_t, self.GYR_x, self.GYR_y, self.GYR_z,
            gyrx_filtered, gyry_filtered, gyrz_filtered, "Gyr")

        # FFT
        self.ax_acc_fft = self.fig.add_subplot(rows, cols, fft_acc_index)
        self.ax_gyr_fft = self.fig.add_subplot(rows, cols, fft_gyr_index)

        self.acc_filtered_fft_ax_x, self.acc_filtered_fft_ax_y, self.acc_filtered_fft_ax_z = self.init_fft(
            self.ax_acc_fft, self.ACC_t, self.ACC_x, self.ACC_y, self.ACC_z, self.ACC_freq, self.acc_dt,
            accx_filtered, accy_filtered, accz_filtered, "Acc")
        self.gyr_filtered_fft_ax_x, self.gyr_filtered_fft_ax_y, self.gyr_filtered_fft_ax_z = self.init_fft(
            self.ax_gyr_fft, self.GYR_t, self.GYR_x, self.GYR_y, self.GYR_z, self.GYR_freq, self.gyr_dt,
            gyrx_filtered, gyry_filtered, gyrz_filtered, "Gyr")

        self.fig.tight_layout()

        self.create_controls("acc", [0.75, 0.95, 0.2, 0.02], 0.01, self.ax_acc_fft, self.ACC_t,
                             (self.ACC_x, self.ACC_y, self.ACC_z),
                             (self.ax_filtered_accx, self.ax_filtered_accy, self.ax_filtered_accz),
                             (self.acc_filtered_fft_ax_x, self.acc_filtered_fft_ax_y, self.acc_filtered_fft_ax_z))
        self.create_controls("gyr", [0.75, 0.45, 0.2, 0.02], 0.01, self.ax_gyr_fft, self.GYR_t,
                             (self.GYR_x, self.GYR_y, self.GYR_z),
                             (self.ax_filtered_gyrx, self.ax_filtered_gyry, self.ax_filtered_gyrz),
                             (self.gyr_filtered_fft_ax_x, self.gyr_filtered_fft_ax_y, self.gyr_filtered_fft_ax_z))

        # setup animation for continuous update
        global anim
        anim = FuncAnimation(self.fig, lambda frame: self.animation_update(), interval=1, blit=False)

        # Work in progress here...
        # self.create_spectrogram(self.GYR_x, "GyrX", self.GYR_freq)
        # self.create_spectrogram(gyrx_filtered, "GyrX filtered", self.GYR_freq)
        # self.create_spectrogram(self.ACC_x, "AccX", self.ACC_freq)
        # self.create_spectrogram(accx_filtered, "AccX filtered", self.ACC_freq)

        plt.show()

        self.print_filter_param_info()

    @staticmethod
    def _print_chain_params(chain, notch_prefix, lpf_param):
        """Print the parameter lines for one filter chain."""
        for f in chain:
            if f.get_type() == BiquadFilterType.PEAK:  # NOTCH
                print("{}_ENABLE,".format(notch_prefix), 1)
                print("{}_FREQ,".format(notch_prefix), f.get_center_freq())
                print("{}_BW,".format(notch_prefix), f.get_bandwidth())
                print("{}_ATT,".format(notch_prefix), f.get_attenuation())
            else:  # LPF
                print("{},".format(lpf_param), f.get_center_freq())

    def print_filter_param_info(self):
        """Print the parameters reproducing the current filter settings.

        Nothing but a notice is printed when either chain holds more than two
        filters.
        """
        if len(self.filters["acc"]) > 2 or len(self.filters["gyr"]) > 2:
            print("Testing too many filters unsupported from firmware, cannot calculate parameters to set them")
            return

        print("To have the last filter settings in the graphs set the following parameters:\n")

        self._print_chain_params(self.filters["acc"], "INS_NOTCA", "INS_ACCEL_FILTER")
        self._print_chain_params(self.filters["gyr"], "INS_HNTC2", "INS_GYRO_FILTER")

        print("\n+---------+")
        print("| WARNING |")
        print("+---------+")
        print("Always check the onboard FFT to setup filters, this tool only simulate effects of filtering.")