1 # -*- coding: utf-8 -*-
3 """ ArduPilot IMU Filter Test Class
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free Software
Foundation, either version 3 of the License, or (at your option) any later
version.

This program is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
this program. If not, see <http://www.gnu.org/licenses/>.
"""
16 __author__
= "Guglielmo Cassinelli"
17 __contact__
= "gdguglie@gmail.com"
20 import matplotlib
.pyplot
as plt
21 from matplotlib
.widgets
import Slider
22 from matplotlib
.animation
import FuncAnimation
23 from scipy
import signal
24 from BiquadFilter
import BiquadFilterType
, BiquadFilter
26 sliders
= [] # matplotlib sliders must be global
27 anim
= None # matplotlib animations must be global
31 FILTER_DEBOUNCE
= 10 # ms
33 FILT_SHAPE_DT_FACTOR
= 1 # increase to reduce filter shape size
39 def __init__(self
, acc_t
, acc_x
, acc_y
, acc_z
, gyr_t
, gyr_x
, gyr_y
, gyr_z
, acc_freq
, gyr_freq
,
40 acc_lpf_cutoff
, gyr_lpf_cutoff
,
41 acc_notch_freq
, acc_notch_att
, acc_notch_band
,
42 gyr_notch_freq
, gyr_notch_att
, gyr_notch_band
,
43 log_name
, accel_notch
=False, second_notch
=False):
45 self
.filter_color_map
= plt
.get_cmap('summer')
47 self
.filters
["acc"] = [
48 BiquadFilter(acc_lpf_cutoff
, acc_freq
)
52 self
.filters
["acc"].append(
53 BiquadFilter(acc_notch_freq
, acc_freq
, BiquadFilterType
.PEAK
, acc_notch_att
, acc_notch_band
),
56 self
.filters
["gyr"] = [
57 BiquadFilter(gyr_lpf_cutoff
, gyr_freq
),
58 BiquadFilter(gyr_notch_freq
, gyr_freq
, BiquadFilterType
.PEAK
, gyr_notch_att
, gyr_notch_band
)
62 self
.filters
["acc"].append(
63 BiquadFilter(acc_notch_freq
* 2, acc_freq
, BiquadFilterType
.PEAK
, acc_notch_att
, acc_notch_band
)
65 self
.filters
["gyr"].append(
66 BiquadFilter(gyr_notch_freq
* 2, gyr_freq
, BiquadFilterType
.PEAK
, gyr_notch_att
, gyr_notch_band
)
79 self
.GYR_freq
= gyr_freq
80 self
.ACC_freq
= acc_freq
82 self
.gyr_dt
= 1. / gyr_freq
83 self
.acc_dt
= 1. / acc_freq
87 self
.updated_artists
= []
90 self
.init_plot(log_name
)
def test_acc_filters(self):
    """Run the accelerometer filter chain over each recorded axis.

    Returns:
        A 3-tuple ``(x, y, z)`` holding what ``self.test_filters`` returns
        for the ``ACC_x``, ``ACC_y`` and ``ACC_z`` samples, in that order.
    """
    acc_chain = self.filters["acc"]
    timestamps = self.ACC_t
    # One full pass of the chain per axis, processed in x, y, z order.
    filtered_axes = [
        self.test_filters(acc_chain, timestamps, samples)
        for samples in (self.ACC_x, self.ACC_y, self.ACC_z)
    ]
    return tuple(filtered_axes)
def test_gyr_filters(self):
    """Run the gyroscope filter chain over each recorded axis.

    Returns:
        A 3-tuple ``(x, y, z)`` holding what ``self.test_filters`` returns
        for the ``GYR_x``, ``GYR_y`` and ``GYR_z`` samples, in that order.
    """
    gyr_chain = self.filters["gyr"]
    timestamps = self.GYR_t
    # One full pass of the chain per axis, processed in x, y, z order.
    filtered_axes = [
        self.test_filters(gyr_chain, timestamps, samples)
        for samples in (self.GYR_x, self.GYR_y, self.GYR_z)
    ]
    return tuple(filtered_axes)
104 def test_filters(self
, filters
, Ts
, Xs
):
110 for i
, t
in enumerate(Ts
):
115 x_f
= filt
.apply(x_f
)
117 x_filtered
.append(x_f
)
def get_filter_shape(self, filter):
    """Sample the frequency response of ``filter`` for plotting.

    Args:
        filter: object exposing ``get_sample_freq()`` and
            ``freq_response(frequencies)``.

    Returns:
        ``(frequencies, response)`` where ``frequencies`` is evenly spaced
        from 0 to half the integer-truncated sample rate, and ``response``
        is whatever ``filter.freq_response`` returns for those frequencies.
    """
    sample_rate = int(filter.get_sample_freq())
    upper_freq = sample_rate // 2
    # Resolution follows the sample rate; a larger FILT_SHAPE_DT_FACTOR
    # yields fewer points along the curve.
    num_points = sample_rate // int(2 * self.FILT_SHAPE_DT_FACTOR)
    frequencies = np.linspace(0.0, upper_freq, num_points)
    return frequencies, filter.freq_response(frequencies)
def init_signal_plot(self, ax, Ts, Xs, Ys, Zs, Xs_filtered, Ys_filtered, Zs_filtered, label):
    """Draw raw and filtered time-domain traces for three axes on ``ax``.

    Args:
        ax: axes-like object providing ``plot`` and ``legend``.
        Ts: x values shared by every trace.
        Xs, Ys, Zs: raw samples per axis.
        Xs_filtered, Ys_filtered, Zs_filtered: filtered samples per axis.
        label: prefix placed in front of the axis letter in legend entries.

    Returns:
        A 3-tuple with the line objects of the filtered X, Y and Z traces.
    """
    axis_names = ("X", "Y", "Z")

    # Raw traces are issued first and drawn half transparent; the filtered
    # traces follow, fully opaque.
    for axis_name, raw in zip(axis_names, (Xs, Ys, Zs)):
        ax.plot(Ts, raw, linewidth=1, label="{}{}".format(label, axis_name), alpha=0.5)

    filtered_lines = []
    for axis_name, filtered in zip(axis_names, (Xs_filtered, Ys_filtered, Zs_filtered)):
        # ax.plot hands back a one-element sequence; unpack the single line.
        [line] = ax.plot(Ts, filtered, linewidth=1,
                         label="{}{} filtered".format(label, axis_name), alpha=1)
        filtered_lines.append(line)

    ax.legend(prop={'size': 8})
    return tuple(filtered_lines)
136 def fft_to_xdata(self
, fft
):
139 return norm_factor
* np
.abs(fft
[:n
// 2])
141 def plot_fft(self
, ax
, x
, fft
, label
):
142 fft_ax
, = ax
.plot(x
, self
.fft_to_xdata(fft
), label
=label
)
def init_fft(self, ax, Ts, Xs, Ys, Zs, sample_rate, dt, Xs_filtered, Ys_filtered, Zs_filtered, label):
    """Plot time-averaged STFT magnitude spectra of raw and filtered samples.

    Args:
        ax: axes-like object providing ``plot``, ``set_xlabel`` and ``legend``.
        Ts: accepted for signature compatibility; not used here.
        Xs, Ys, Zs: raw samples per axis.
        sample_rate: sampling frequency handed to ``signal.stft``.
        dt: accepted for signature compatibility; not used here.
        Xs_filtered, Ys_filtered, Zs_filtered: filtered samples per axis.
        label: prefix used in the legend entries.

    Returns:
        A 3-tuple with the line objects of the filtered x, y and z spectra.
    """
    def mean_spectrum(samples):
        # Hann-windowed STFT with FFT_N samples per segment; magnitudes are
        # averaged over the time axis to get one value per frequency bin.
        freqs, _times, stft = signal.stft(samples, sample_rate, window='hann', nperseg=self.FFT_N)
        return freqs, np.average(np.abs(stft), axis=1)

    axis_names = ("x", "y", "z")
    raw_spectra = [mean_spectrum(samples) for samples in (Xs, Ys, Zs)]
    filtered_spectra = [mean_spectrum(samples)
                        for samples in (Xs_filtered, Ys_filtered, Zs_filtered)]

    # Raw spectra first, thin and half transparent.
    for axis_name, (freqs, magnitude) in zip(axis_names, raw_spectra):
        ax.plot(freqs, magnitude, alpha=0.5, linewidth=1,
                label="{}{} FFT".format(label, axis_name))

    # Filtered spectra on top; their line objects are returned to the caller.
    filtered_lines = []
    for axis_name, (freqs, magnitude) in zip(axis_names, filtered_spectra):
        [line] = ax.plot(freqs, magnitude, label="filt. {}{} FFT".format(label, axis_name))
        filtered_lines.append(line)

    # NOTE: an earlier, commented-out variant used a single np.fft.fft over
    # the Hanning-windowed data instead of the averaged STFT above.

    ax.set_xlabel("frequency")
    ax.legend(prop={'size': 8})

    return tuple(filtered_lines)
def init_filter_shape(self, ax, filter, color):
    """Draw a filter's response curve and a marker at its center frequency.

    Args:
        ax: axes-like object providing ``plot`` and ``axvline``.
        filter: object exposing ``get_center_freq()``; also handed to
            ``self.get_filter_shape``.
        color: color applied to both the curve and the vertical marker.

    Returns:
        ``(curve, marker)``: the plotted response line and the dashed
        vertical line placed at the filter's center frequency.
    """
    center_freq = filter.get_center_freq()
    frequencies, response = self.get_filter_shape(filter)

    # ax.plot hands back a one-element sequence; unpack the single line.
    [curve] = ax.plot(frequencies, response, c=color, label="LPF shape")
    marker = ax.axvline(x=center_freq, linestyle="--", c=color)

    return curve, marker
200 def create_slider(self
, name
, rect
, max, value
, color
, callback
):
202 ax_slider
= self
.fig
.add_axes(rect
, facecolor
='lightgoldenrodyellow')
203 slider
= Slider(ax_slider
, name
, 0, max, valinit
=np
.sqrt(max * value
), valstep
=1, color
=color
)
204 slider
.valtext
.set_text(value
)
206 # slider.drawon = False
208 def changed(val
, cbk
, max, slider
):
209 # non linear slider to better control small values
210 val
= int(val
** 2 / max)
211 slider
.valtext
.set_text(val
)
214 slider
.on_changed(lambda val
, cbk
=callback
, max=max, s
=slider
: changed(val
, cbk
, max, s
))
215 sliders
.append(slider
)
217 def delay_update(self
, update_cbk
):
218 def _delayed_update(self
, cbk
):
222 # delay actual filtering
226 self
.timer
= self
.fig
.canvas
.new_timer(interval
=self
.FILTER_DEBOUNCE
)
227 self
.timer
.add_callback(lambda self
=self
: _delayed_update(self
, update_cbk
))
230 def update_filter_shape(self
, filter, shape
, center_line
):
231 x_data
, new_shape
= self
.get_filter_shape(filter)
233 shape
.set_ydata(new_shape
)
234 center_line
.set_xdata(filter.get_center_freq())
236 self
.updated_artists
.extend([
241 def update_signal_and_fft_plot(self
, filters_key
, time_list
, sample_lists
, signal_shapes
, fft_shapes
, shape
,
242 center_line
, sample_rate
):
243 # print("update_signal_and_fft_plot", self.filters[filters_key][0].get_center_freq())
244 Xs
, Ys
, Zs
= sample_lists
245 signal_shape_x
, signal_shape_y
, signal_shape_z
= signal_shapes
246 fft_shape_x
, fft_shape_y
, fft_shape_z
= fft_shapes
248 Xs_filtered
= self
.test_filters(self
.filters
[filters_key
], time_list
, Xs
)
249 Ys_filtered
= self
.test_filters(self
.filters
[filters_key
], time_list
, Ys
)
250 Zs_filtered
= self
.test_filters(self
.filters
[filters_key
], time_list
, Zs
)
252 signal_shape_x
.set_ydata(Xs_filtered
)
253 signal_shape_y
.set_ydata(Ys_filtered
)
254 signal_shape_z
.set_ydata(Zs_filtered
)
256 self
.updated_artists
.extend([signal_shape_x
, signal_shape_y
, signal_shape_z
])
258 _freqs_x
, _times_x
, _stft_x
= signal
.stft(Xs_filtered
, sample_rate
, window
='hann', nperseg
=self
.FFT_N
)
259 filtered_fft_x
= np
.average(np
.abs(_stft_x
), axis
=1)
261 _freqs_y
, _times_y
, _stft_y
= signal
.stft(Ys_filtered
, sample_rate
, window
='hann', nperseg
=self
.FFT_N
)
262 filtered_fft_y
= np
.average(np
.abs(_stft_y
), axis
=1)
264 _freqs_z
, _times_z
, _stft_z
= signal
.stft(Zs_filtered
, sample_rate
, window
='hann', nperseg
=self
.FFT_N
)
265 filtered_fft_z
= np
.average(np
.abs(_stft_z
), axis
=1)
267 fft_shape_x
.set_ydata(filtered_fft_x
)
268 fft_shape_y
.set_ydata(filtered_fft_y
)
269 fft_shape_z
.set_ydata(filtered_fft_z
)
271 self
.updated_artists
.extend([
272 fft_shape_x
, fft_shape_y
, fft_shape_z
,
276 # self.fig.canvas.draw()
def animation_update(self):
    """Hand over the artists queued since the previous call.

    Returns:
        A new list holding the contents of ``self.updated_artists`` at call
        time. The attribute is rebound to a fresh empty list afterwards.
    """
    pending = list(self.updated_artists)
    # Start a new queue so each artist is reported only once.
    self.updated_artists = []
    return pending
289 def update_filter(self
, val
, cbk
, filter, shape
, center_line
, filters_key
, time_list
, sample_lists
, signal_shapes
,
291 # this callback sets the parameter controlled by the slider
293 # print("filter update",val)
294 # update filter shape and delay fft update
295 self
.update_filter_shape(filter, shape
, center_line
)
296 sample_freq
= filter.get_sample_freq()
298 lambda self
=self
: self
.update_signal_and_fft_plot(filters_key
, time_list
, sample_lists
, signal_shapes
,
299 fft_shapes
, shape
, center_line
, sample_freq
))
def create_filter_control(self, name, filter, rect, max, default, shape, center_line, cbk, filters_key, time_list,
                          sample_lists, signal_shapes, fft_shapes, filt_color):
    """Create a slider that drives one parameter of ``filter``.

    The slider is built through ``self.create_slider`` using ``name``,
    ``rect``, ``max``, ``default`` and ``filt_color``. Its callback forwards
    the slider value, together with every other argument given here, to
    ``self.update_filter``.
    """
    def on_slider_change(val):
        # All context is captured from the enclosing call; only the slider
        # value varies between invocations.
        return self.update_filter(val, cbk, filter, shape, center_line, filters_key,
                                  time_list, sample_lists, signal_shapes, fft_shapes)

    self.create_slider(name, rect, max, default, filt_color, on_slider_change)
310 def create_controls(self
, filters_key
, base_rect
, padding
, ax_fft
, time_list
, sample_lists
, signal_shapes
,
312 ax_filter
= ax_fft
.twinx()
313 ax_filter
.set_navigate(False)
314 ax_filter
.set_yticks([])
316 num_filters
= len(self
.filters
[filters_key
])
318 for i
, filter in enumerate(self
.filters
[filters_key
]):
319 filt_type
= filter.get_type()
320 filt_color
= self
.filter_color_map(i
/ num_filters
)
321 filt_shape
, filt_cutoff
= self
.init_filter_shape(ax_filter
, filter, filt_color
)
323 if filt_type
== BiquadFilterType
.PEAK
:
328 # control for center freq is common to all filters
329 self
.create_filter_control("{} freq".format(name
), filter, base_rect
, 500, filter.get_center_freq(),
330 filt_shape
, filt_cutoff
,
331 lambda val
, filter=filter: filter.set_center_freq(val
),
332 filters_key
, time_list
, sample_lists
, signal_shapes
, fft_shapes
, filt_color
)
333 # move down of control height + padding
334 base_rect
[1] -= (base_rect
[3] + padding
)
336 if filt_type
== BiquadFilterType
.PEAK
:
337 self
.create_filter_control("{} att (db)".format(name
), filter, base_rect
, 100, filter.get_attenuation(),
338 filt_shape
, filt_cutoff
,
339 lambda val
, filter=filter: filter.set_attenuation(val
),
340 filters_key
, time_list
, sample_lists
, signal_shapes
, fft_shapes
, filt_color
)
341 base_rect
[1] -= (base_rect
[3] + padding
)
342 self
.create_filter_control("{} band".format(name
), filter, base_rect
, 300, filter.get_bandwidth(),
343 filt_shape
, filt_cutoff
,
344 lambda val
, filter=filter: filter.set_bandwidth(val
),
345 filters_key
, time_list
, sample_lists
, signal_shapes
, fft_shapes
, filt_color
)
346 base_rect
[1] -= (base_rect
[3] + padding
)
348 def create_spectrogram(self
, data
, name
, sample_rate
):
349 freqs
, times
, Sx
= signal
.spectrogram(np
.array(data
), fs
=sample_rate
, window
='hanning',
350 nperseg
=self
.FFT_N
, noverlap
=self
.FFT_N
- self
.FFT_N
// 10,
351 detrend
=False, scaling
='spectrum')
353 f
, ax
= plt
.subplots(figsize
=(4.8, 2.4))
354 ax
.pcolormesh(times
, freqs
, 10 * np
.log10(Sx
), cmap
='viridis')
356 ax
.set_ylabel('Frequency (Hz)')
357 ax
.set_xlabel('Time (s)')
359 def init_plot(self
, log_name
):
361 self
.fig
= plt
.figure(figsize
=(14, 9))
362 self
.fig
.canvas
.set_window_title("ArduPilot Filter Test Tool - {}".format(log_name
))
363 self
.fig
.canvas
.draw()
368 fft_acc_index
= raw_acc_index
+ 1
369 raw_gyr_index
= cols
+ 1
370 fft_gyr_index
= raw_gyr_index
+ 1
373 self
.ax_acc
= self
.fig
.add_subplot(rows
, cols
, raw_acc_index
)
374 self
.ax_gyr
= self
.fig
.add_subplot(rows
, cols
, raw_gyr_index
, sharex
=self
.ax_acc
)
376 accx_filtered
, accy_filtered
, accz_filtered
= self
.test_acc_filters()
377 self
.ax_filtered_accx
, self
.ax_filtered_accy
, self
.ax_filtered_accz
= self
.init_signal_plot(self
.ax_acc
,
387 gyrx_filtered
, gyry_filtered
, gyrz_filtered
= self
.test_gyr_filters()
388 self
.ax_filtered_gyrx
, self
.ax_filtered_gyry
, self
.ax_filtered_gyrz
= self
.init_signal_plot(self
.ax_gyr
,
399 self
.ax_acc_fft
= self
.fig
.add_subplot(rows
, cols
, fft_acc_index
)
400 self
.ax_gyr_fft
= self
.fig
.add_subplot(rows
, cols
, fft_gyr_index
)
402 self
.acc_filtered_fft_ax_x
, self
.acc_filtered_fft_ax_y
, self
.acc_filtered_fft_ax_z
= self
.init_fft(
403 self
.ax_acc_fft
, self
.ACC_t
, self
.ACC_x
, self
.ACC_y
, self
.ACC_z
, self
.ACC_freq
, self
.acc_dt
, accx_filtered
,
404 accy_filtered
, accz_filtered
, "AccX")
405 self
.gyr_filtered_fft_ax_x
, self
.gyr_filtered_fft_ax_y
, self
.gyr_filtered_fft_ax_z
= self
.init_fft(
406 self
.ax_gyr_fft
, self
.GYR_t
, self
.GYR_x
, self
.GYR_y
, self
.GYR_z
, self
.GYR_freq
, self
.gyr_dt
, gyrx_filtered
,
407 gyry_filtered
, gyrz_filtered
, "GyrX")
409 self
.fig
.tight_layout()
412 self
.create_controls("acc", [0.75, 0.95, 0.2, 0.02], 0.01, self
.ax_acc_fft
, self
.ACC_t
,
413 (self
.ACC_x
, self
.ACC_y
, self
.ACC_z
),
414 (self
.ax_filtered_accx
, self
.ax_filtered_accy
, self
.ax_filtered_accz
),
415 (self
.acc_filtered_fft_ax_x
, self
.acc_filtered_fft_ax_y
, self
.acc_filtered_fft_ax_z
))
416 self
.create_controls("gyr", [0.75, 0.45, 0.2, 0.02], 0.01, self
.ax_gyr_fft
, self
.GYR_t
,
417 (self
.GYR_x
, self
.GYR_y
, self
.GYR_z
),
418 (self
.ax_filtered_gyrx
, self
.ax_filtered_gyry
, self
.ax_filtered_gyrz
),
419 (self
.gyr_filtered_fft_ax_x
, self
.gyr_filtered_fft_ax_y
, self
.gyr_filtered_fft_ax_z
))
421 # setup animation for continuous update
423 anim
= FuncAnimation(self
.fig
, lambda frame
, self
=self
: self
.animation_update(), interval
=1, blit
=False)
425 # Work in progress here...
426 # self.create_spectrogram(self.GYR_x, "GyrX", self.GYR_freq)
427 # self.create_spectrogram(gyrx_filtered, "GyrX filtered", self.GYR_freq)
428 # self.create_spectrogram(self.ACC_x, "AccX", self.ACC_freq)
429 # self.create_spectrogram(accx_filtered, "AccX filtered", self.ACC_freq)
433 self
.print_filter_param_info()
435 def print_filter_param_info(self
):
436 if len(self
.filters
["acc"]) > 2 or len(self
.filters
["gyr"]) > 2:
437 print("Testing too many filters unsupported from firmware, cannot calculate parameters to set them")
440 print("To have the last filter settings in the graphs set the following parameters:\n")
442 for f
in self
.filters
["acc"]:
443 filt_type
= f
.get_type()
445 if filt_type
== BiquadFilterType
.PEAK
: # NOTCH
446 print("INS_NOTCA_ENABLE,", 1)
447 print("INS_NOTCA_FREQ,", f
.get_center_freq())
448 print("INS_NOTCA_BW,", f
.get_bandwidth())
449 print("INS_NOTCA_ATT,", f
.get_attenuation())
451 print("INS_ACCEL_FILTER,", f
.get_center_freq())
453 for f
in self
.filters
["gyr"]:
454 filt_type
= f
.get_type()
456 if filt_type
== BiquadFilterType
.PEAK
: # NOTCH
457 print("INS_HNTC2_ENABLE,", 1)
458 print("INS_HNTC2_FREQ,", f
.get_center_freq())
459 print("INS_HNTC2_BW,", f
.get_bandwidth())
460 print("INS_HNTC2_ATT,", f
.get_attenuation())
462 print("INS_GYRO_FILTER,", f
.get_center_freq())
464 print("\n+---------+")
467 print("Always check the onboard FFT to setup filters, this tool only simulate effects of filtering.")