-
-
Notifications
You must be signed in to change notification settings - Fork 20
Add stateful block-wise processing to OctaveFilterBank and WeightingFilter #42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
53b7b0f
3175a41
986d17e
b483c21
6886891
5d0110b
972319e
4422a9e
18c6f48
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import warnings | ||
| from typing import List, Tuple, cast, overload, Literal | ||
|
|
||
| import numpy as np | ||
|
|
@@ -34,6 +35,9 @@ def __init__( | |
| plot_file: str | None = None, | ||
| calibration_factor: float = 1.0, | ||
| dbfs: bool = False, | ||
| stateful: bool = False, | ||
| steady_ic: bool = False, | ||
| resample: bool = True, | ||
| ) -> None: | ||
| """ | ||
| Initialize the Octave Filter Bank. | ||
|
|
@@ -49,6 +53,9 @@ def __init__( | |
| :param plot_file: Path to save the filter response plot. | ||
| :param calibration_factor: Calibration factor for SPL calculation. | ||
| :param dbfs: If True, calculate SPL in dBFS. | ||
| :param stateful: If True, carry filter state between calls. Useful for block processing. | ||
| :param steady_ic: If True, calculate steady state initial conditions for filter. | ||
| :param resample: If True, resampling is performed. | ||
| """ | ||
| if fs <= 0: | ||
| raise ValueError("Sample rate 'fs' must be positive.") | ||
|
|
@@ -69,6 +76,10 @@ def __init__( | |
| if filter_type not in valid_filters: | ||
| raise ValueError(f"Invalid filter_type. Must be one of {valid_filters}") | ||
|
|
||
| if resample and stateful: | ||
| raise ValueError("Resampling and stateful behaviour (block processing) are not supported.") | ||
| # a stateful resampling algorithm would be required... | ||
|
|
||
| self.fs = fs | ||
| self.fraction = fraction | ||
| self.order = order | ||
|
|
@@ -78,18 +89,39 @@ def __init__( | |
| self.attenuation = attenuation | ||
| self.calibration_factor = calibration_factor | ||
| self.dbfs = dbfs | ||
| self.stateful = stateful | ||
|
|
||
| # Generate frequencies | ||
| self.freq, self.freq_d, self.freq_u = _genfreqs(limits, fraction, fs) | ||
| self.num_bands = len(self.freq) | ||
|
|
||
|
|
||
| # Calculate factors and design SOS | ||
| self.factor = _downsamplingfactor(self.freq_u, fs) | ||
| if resample: | ||
| self.factor = _downsamplingfactor(self.freq_u, fs) | ||
| else: | ||
| self.factor = np.ones(self.num_bands, dtype=int) | ||
|
|
||
| self.sos = _design_sos_filter( | ||
| self.freq, self.freq_d, self.freq_u, fs, order, self.factor, | ||
| filter_type, ripple, attenuation, show, plot_file | ||
| ) | ||
|
|
||
| # Calculate initial conditions for filter state | ||
| if self.stateful: | ||
| self._init_filter_state(steady_ic) | ||
|
|
||
|
|
||
| def _init_filter_state(self, steady_ic: bool) -> None: | ||
| """Initialize filter state (zi) for stateful block-wise processing. | ||
|
|
||
| Uses lazy initialization: zi arrays are allocated on first use in | ||
| _filter_and_resample() so the channel count matches the actual input. | ||
| """ | ||
| self.zi: List[np.ndarray] = [np.array([]) for _ in range(self.num_bands)] | ||
| self._steady_ic = steady_ic | ||
|
|
||
|
|
||
|
sourcery-ai[bot] marked this conversation as resolved.
|
||
| def __repr__(self) -> str: | ||
| return ( | ||
| f"OctaveFilterBank(fs={self.fs}, fraction={self.fraction}, order={self.order}, " | ||
|
|
@@ -115,20 +147,45 @@ def filter( | |
| detrend: bool = True | ||
| ) -> Tuple[np.ndarray, List[float], List[np.ndarray]]: ... | ||
|
|
||
| # New overloads with calculate_level | ||
| @overload | ||
| def filter( | ||
| self, | ||
| x: List[float] | np.ndarray, | ||
| sigbands: Literal[False] = False, | ||
| mode: str = "rms", | ||
| detrend: bool = True, | ||
| calculate_level: Literal[False] = False | ||
| ) -> Tuple[None, List[float]]: | ||
| ... | ||
|
|
||
| @overload | ||
| def filter( | ||
| self, | ||
| x: List[float] | np.ndarray, | ||
| sigbands: Literal[True], | ||
| mode: str = "rms", | ||
| detrend: bool = True, | ||
| calculate_level: Literal[False] = False | ||
| ) -> Tuple[None, List[float], List[np.ndarray]]: | ||
| ... | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| def filter( | ||
| self, | ||
| x: List[float] | np.ndarray, | ||
| sigbands: bool = False, | ||
| mode: str = "rms", | ||
| detrend: bool = True | ||
| ) -> Tuple[np.ndarray, List[float]] | Tuple[np.ndarray, List[float], List[np.ndarray]]: | ||
| detrend: bool = True, | ||
| calculate_level: bool =True | ||
|
sourcery-ai[bot] marked this conversation as resolved.
|
||
| ) -> Tuple[np.ndarray | None, List[float]] | Tuple[np.ndarray | None, List[float], List[np.ndarray]]: | ||
| """ | ||
| Apply the pre-designed filter bank to a signal. | ||
|
|
||
| :param x: Input signal (1D array or 2D array [channels, samples]). | ||
| :param sigbands: If True, also return the signal in the time domain divided into bands. | ||
| :param mode: 'rms' for energy-based level, 'peak' for peak-holding level. | ||
| :param detrend: If True, remove DC offset from signal before filtering (Default: True). | ||
| :param calculate_level: If True, calculate SPL | ||
| :return: A tuple containing (SPL_array, Frequencies_list) or (SPL_array, Frequencies_list, signals). | ||
| """ | ||
|
|
||
|
|
@@ -137,6 +194,13 @@ def filter( | |
|
|
||
| # Handle DC offset removal | ||
| if detrend: | ||
| if self.stateful: | ||
| warnings.warn( | ||
| "Detrending is not recommended during block processing " | ||
| "as it can introduce discontinuities between blocks.", | ||
| UserWarning, | ||
| stacklevel=2, | ||
| ) | ||
| # Axis -1 handles both 1D and 2D arrays correctly | ||
| x_proc = signal.detrend(x_proc, axis=-1, type='constant') | ||
|
|
||
|
|
@@ -148,11 +212,12 @@ def filter( | |
| num_channels = x_proc.shape[0] | ||
|
|
||
| # Process signal across all bands and channels | ||
| spl, xb = self._process_bands(x_proc, num_channels, sigbands, mode=mode) | ||
| spl, xb = self._process_bands(x_proc, num_channels, sigbands, mode=mode, calculate_level=calculate_level) | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🐛 Bug: When This affects any user calling with Fix: Guard with a None check on line 209: if not is_multichannel:
if spl is not None:
spl = spl[0]
if sigbands and xb is not None:
xb = [band[0] for band in xb] |
||
|
|
||
| # Format output based on input dimensionality | ||
| if not is_multichannel: | ||
| spl = spl[0] | ||
| if spl is not None: | ||
| spl = spl[0] | ||
| if sigbands and xb is not None: | ||
| xb = [band[0] for band in xb] | ||
|
|
||
|
|
@@ -166,45 +231,66 @@ def _process_bands( | |
| x_proc: np.ndarray, | ||
| num_channels: int, | ||
| sigbands: bool, | ||
| mode: str = "rms" | ||
| ) -> Tuple[np.ndarray, List[np.ndarray] | None]: | ||
| mode: str = "rms", | ||
| calculate_level: bool = True | ||
| ) -> Tuple[np.ndarray | None, List[np.ndarray] | None]: | ||
| """ | ||
| Process signal through each frequency band. | ||
|
|
||
| :param x_proc: Standardized 2D input signal [channels, samples]. | ||
| :param num_channels: Number of channels. | ||
| :param sigbands: If True, return filtered bands. | ||
| :param mode: 'rms' or 'peak'. | ||
| :param calculate_level: If True, calculate SPL | ||
| :return: A tuple containing (SPL_array, Optional_List_of_filtered_signals). | ||
| """ | ||
| spl = np.zeros([num_channels, self.num_bands]) | ||
| if calculate_level: | ||
| spl = np.zeros([num_channels, self.num_bands]) | ||
| else: | ||
| spl = None | ||
| xb: List[np.ndarray] | None = [np.array([]) for _ in range(self.num_bands)] if sigbands else None | ||
|
|
||
| for idx in range(self.num_bands): | ||
| # Vectorized processing for all channels | ||
| filtered_signal = self._filter_and_resample(x_proc, idx) | ||
|
|
||
| # Sound Level Calculation (returns array of shape [num_channels]) | ||
| spl[:, idx] = self._calculate_level(filtered_signal, mode) | ||
| if calculate_level and spl is not None: | ||
| # Sound Level Calculation (returns array of shape [num_channels]) | ||
| spl[:, idx] = self._calculate_level(filtered_signal, mode) | ||
|
|
||
| if sigbands and xb is not None: | ||
| # Restore original length | ||
| # filtered_signal is [channels, downsampled_samples] | ||
| y_resampled = _resample_to_length(filtered_signal, int(self.factor[idx]), x_proc.shape[1]) | ||
| xb[idx] = y_resampled | ||
|
|
||
| return spl, xb | ||
|
|
||
|
|
||
| def _filter_and_resample(self, x: np.ndarray, idx: int) -> np.ndarray: | ||
| """Resample and filter for a specific band (vectorized).""" | ||
| if self.factor[idx] > 1: | ||
| # axis=-1 is default for resample_poly, but being explicit is good | ||
| sd = signal.resample_poly(x, 1, self.factor[idx], axis=-1) | ||
| else: | ||
| sd = x | ||
|
|
||
|
|
||
| if self.stateful: | ||
| n_channels = sd.shape[0] | ||
| # Lazy init: allocate zi with correct channel count on first use | ||
| if self.zi[idx].ndim < 3 or self.zi[idx].shape[1] != n_channels: | ||
| n_sections = self.sos[idx].shape[0] | ||
| if not self._steady_ic: | ||
| self.zi[idx] = np.zeros((n_sections, n_channels, 2)) | ||
| else: | ||
| zi_base = signal.sosfilt_zi(self.sos[idx]) | ||
| self.zi[idx] = np.tile(zi_base[:, np.newaxis, :], (1, n_channels, 1)) | ||
| y, self.zi[idx] = signal.sosfilt(self.sos[idx], sd, axis=-1, zi=self.zi[idx]) | ||
| else: | ||
| y = signal.sosfilt(self.sos[idx], sd, axis=-1) | ||
|
|
||
| # sosfilt supports axis=-1 by default | ||
| return cast(np.ndarray, signal.sosfilt(self.sos[idx], sd, axis=-1)) | ||
| return cast(np.ndarray, y) | ||
|
|
||
| def _calculate_level(self, y: np.ndarray, mode: str) -> float | np.ndarray: | ||
| """Calculate the level (RMS or Peak) in dB.""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,21 +21,27 @@ class WeightingFilter: | |
| Allows pre-calculating and reusing filter coefficients. | ||
| """ | ||
|
|
||
| def __init__(self, fs: int, curve: str = "A") -> None: | ||
| def __init__(self, fs: int, curve: str = "A", | ||
| stateful: bool = False, steady_ic: bool = False) -> None: | ||
| """ | ||
| Initialize the weighting filter. | ||
|
|
||
| :param fs: Sample rate in Hz. | ||
| :param curve: 'A', 'C' or 'Z'. | ||
| :param stateful: If True, the weighting filter is stateful. Useful for block processing. | ||
| :param steady_ic: If True, calculate steady state initial conditions for filter. | ||
| """ | ||
| if fs <= 0: | ||
| raise ValueError("Sample rate 'fs' must be positive.") | ||
|
|
||
| self.fs = fs | ||
| self.curve = curve.upper() | ||
| self.stateful = stateful | ||
|
|
||
| if self.curve == "Z": | ||
| self.sos = np.array([]) | ||
| if self.stateful: | ||
| self.zi = np.array([]) | ||
| return | ||
|
|
||
| if self.curve not in ["A", "C"]: | ||
|
Comment on lines
+43
to
47
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: Stateful
|
||
|
|
@@ -78,6 +84,13 @@ def __init__(self, fs: int, curve: str = "A") -> None: | |
| zd, pd, kd = signal.bilinear_zpk(z, p, k, fs) | ||
| self.sos = signal.zpk2sos(zd, pd, kd) | ||
|
|
||
| # Calculate initial conditions for filter state | ||
| if self.stateful: | ||
| if not steady_ic: | ||
| self.zi = np.zeros((self.sos.shape[0], 2)) | ||
| else: | ||
| self.zi = signal.sosfilt_zi(self.sos) | ||
|
|
||
| def filter(self, x: List[float] | np.ndarray) -> np.ndarray: | ||
| """ | ||
| Apply the weighting filter to a signal. | ||
|
|
@@ -88,7 +101,13 @@ def filter(self, x: List[float] | np.ndarray) -> np.ndarray: | |
| x_proc = _typesignal(x) | ||
| if self.curve == "Z": | ||
| return x_proc | ||
| return cast(np.ndarray, signal.sosfilt(self.sos, x_proc, axis=-1)) | ||
|
|
||
| if self.stateful: | ||
| y, self.zi = signal.sosfilt(self.sos, x_proc, axis=-1, zi=self.zi) | ||
| else: | ||
| y = signal.sosfilt(self.sos, x_proc, axis=-1) | ||
|
|
||
| return cast(np.ndarray, y) | ||
|
|
||
|
|
||
| def weighting_filter(x: List[float] | np.ndarray, fs: int, curve: str = "A") -> np.ndarray: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -148,3 +148,11 @@ def test_calculate_level_invalid(): | |
| # This should hit core.py:218 | ||
| with pytest.raises(ValueError): | ||
| bank._calculate_level(np.array([1.0]), "invalid_mode") | ||
|
|
||
| def test_dont_calculate_level(): | ||
| from pyoctaveband.core import OctaveFilterBank | ||
| bank = OctaveFilterBank(48000) | ||
| x = np.zeros((bank.num_bands, 100)) | ||
| spl, y = bank._process_bands(x, num_channels=bank.num_bands, calculate_level=False, sigbands=True) | ||
| assert spl is None | ||
| assert y is not None | ||
|
Comment on lines
+152
to
+158
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Add tests for the public calculate_level=False API and strengthen assertions on band outputs This currently only exercises the private
It would also help to strengthen the assertions on |
||
Uh oh!
There was an error while loading. Please reload this page.