diff --git a/rfg_adc_plotter/gui/pyqtgraph_backend.py b/rfg_adc_plotter/gui/pyqtgraph_backend.py index 24c6ce1..bc23802 100644 --- a/rfg_adc_plotter/gui/pyqtgraph_backend.py +++ b/rfg_adc_plotter/gui/pyqtgraph_backend.py @@ -649,31 +649,41 @@ def decimate_bscan_rows_for_display( data: np.ndarray, *, max_points: int = BSCAN_MAX_POINTS, -) -> Tuple[Optional[np.ndarray], np.ndarray]: + return_indices: bool = False, +) -> ( + Tuple[Optional[np.ndarray], np.ndarray] + | Tuple[Optional[np.ndarray], np.ndarray, np.ndarray] +): """Reduce B-scan rows to keep waterfall rendering responsive.""" data_arr = np.asarray(data, dtype=np.float32) if data_arr.ndim != 2: - return axis, data_arr + if data_arr.ndim > 0: + row_idx = np.arange(data_arr.shape[0], dtype=np.int64) + else: + row_idx = np.zeros((0,), dtype=np.int64) + return (axis, data_arr, row_idx) if return_indices else (axis, data_arr) row_count = int(data_arr.shape[0]) limit = max(1, int(max_points)) if row_count <= limit: - return axis, data_arr + row_idx = np.arange(row_count, dtype=np.int64) + return (axis, data_arr, row_idx) if return_indices else (axis, data_arr) row_idx = np.linspace(0, row_count - 1, limit, dtype=np.int64) row_idx = np.unique(row_idx) decimated = data_arr[row_idx, :] if axis is None: - return None, decimated + return (None, decimated, row_idx) if return_indices else (None, decimated) axis_arr = np.asarray(axis, dtype=np.float64).reshape(-1) if axis_arr.size <= 0: - return None, decimated + return (None, decimated, row_idx) if return_indices else (None, decimated) take = min(axis_arr.size, row_count) axis_arr = axis_arr[:take] valid_idx = row_idx[row_idx < axis_arr.size] if valid_idx.size != row_idx.size: decimated = data_arr[valid_idx, :] - return axis_arr[valid_idx], decimated + row_idx = valid_idx + return (axis_arr[valid_idx], decimated, row_idx) if return_indices else (axis_arr[valid_idx], decimated) def coalesce_packets_for_ui( @@ -3272,21 +3282,39 @@ def run_pyqtgraph(args) -> None: else: disp_fft_lin = runtime.ring.get_display_fft_linear() disp_fft_axis = runtime.ring.distance_axis + active_background = None + try: + active_background = resolve_active_background(disp_fft_lin.shape[0]) + except Exception: + active_background = None if disp_fft_axis is not None: axis_arr = np.asarray(disp_fft_axis, dtype=np.float64).reshape(-1) row_take = min(axis_arr.size, disp_fft_lin.shape[0]) axis_arr = axis_arr[:row_take] disp_fft_lin = disp_fft_lin[:row_take, :] + if active_background is not None: + active_background = active_background[:row_take] fft_cut_start = _active_distance_cut_start() axis_arr, keep_mask = apply_distance_cut_to_axis(axis_arr, fft_cut_start) if keep_mask.size > 0: disp_fft_lin = disp_fft_lin[keep_mask, :] + if active_background is not None and active_background.size == keep_mask.size: + active_background = active_background[keep_mask] disp_fft_axis = axis_arr - disp_fft_axis, disp_fft_lin = decimate_bscan_rows_for_display( + disp_fft_axis, disp_fft_lin, display_row_idx = decimate_bscan_rows_for_display( disp_fft_axis, disp_fft_lin, max_points=BSCAN_MAX_POINTS, + return_indices=True, ) + if active_background is not None: + if ( + active_background.size >= display_row_idx.size + and np.all(display_row_idx < active_background.size) + ): + active_background = active_background[display_row_idx] + else: + active_background = None if spec_mean_sec > 0.0: disp_times = runtime.ring.get_display_times() if disp_times is not None: @@ -3300,11 +3328,6 @@ def run_pyqtgraph(args) -> None: except Exception: pass - active_background = None - try: - active_background = resolve_active_background(disp_fft_lin.shape[0]) - except Exception: - active_background = None if active_background is not None: try: disp_fft_lin = subtract_fft_background(disp_fft_lin, active_background) diff --git a/tests/test_processing.py b/tests/test_processing.py index 45910a9..5493f4d 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -454,6 +454,27 @@ class ProcessingTests(unittest.TestCase): self.assertIsNone(dec_axis) self.assertEqual(dec_data.shape, (3, 4)) + def test_bscan_background_profile_tracks_decimated_rows(self): + rows = (FFT_LEN // 2) + 1 + axis = np.linspace(0.0, 10.0, rows, dtype=np.float64) + background = np.linspace(1.0, 2.0, rows, dtype=np.float32) + residual = np.linspace(0.1, 0.4, rows, dtype=np.float32) + data = background[:, None] + residual[:, None] + + dec_axis, dec_data, row_idx = decimate_bscan_rows_for_display( + axis, + data, + max_points=512, + return_indices=True, + ) + dec_background = background[row_idx] + subtracted = subtract_fft_background(dec_data, dec_background) + + self.assertEqual(dec_axis.shape, (512,)) + self.assertEqual(dec_data.shape, (512, 1)) + self.assertEqual(row_idx.shape, (512,)) + self.assertTrue(np.allclose(subtracted[:, 0], residual[row_idx], atol=1e-6)) + def test_update_expected_sweep_width_initializes_from_first_valid_sweep(self): self.assertEqual(update_expected_sweep_width(0, 411), 411)