diff --git a/src/osekit/public_api/analysis.py b/src/osekit/public_api/analysis.py index 25344b0b..7a96f8d9 100644 --- a/src/osekit/public_api/analysis.py +++ b/src/osekit/public_api/analysis.py @@ -3,12 +3,14 @@ from enum import Flag, auto from typing import TYPE_CHECKING, Literal -from osekit.core_api.frequency_scale import Scale +from scipy.signal import ShortTimeFFT + from osekit.utils.audio_utils import Normalization if TYPE_CHECKING: from pandas import Timedelta, Timestamp - from scipy.signal import ShortTimeFFT + + from osekit.core_api.frequency_scale import Scale class AnalysisType(Flag): @@ -78,6 +80,8 @@ def __init__( colormap: str | None = None, scale: Scale | None = None, nb_ltas_time_bins: int | None = None, + zoom_levels: list[int] | None = None, + zoomed_fft: list[ShortTimeFFT] | None = None, ) -> None: """Initialize an Analysis object. @@ -141,6 +145,19 @@ def __init__( If None, the spectrogram will be computed regularly. If specified, the spectrogram will be computed as LTAS, with the value representing the maximum number of averaged time bins. + zoom_levels: list[int] | None + If specified, additional analyses datasets will be created at the requested + zoom levels. + e.g. with a data_duration of 10s and zoom_levels = [2,4], 3 SpectroDatasets + will be created, with data_duration = 5s and 2.5s. + This will only affect spectral exports, and if AnalysisType.AUDIO is + included in the analysis, zoomed SpectroDatasets will be linked to the + x1 zoom SpectroData. + zoomed_fft: list[ShortTimeFFT | None] + FFT to use for computing the zoomed spectra. + By default, SpectroDatasets with a zoomed factor z will use the + same FFT as the z=1 SpectroDataset, but with a hop that is + divided by z. """ self.analysis_type = analysis_type @@ -153,16 +170,22 @@ def __init__( self.name = name self.normalization = normalization self.subtype = subtype - self.fft = fft self.v_lim = v_lim self.colormap = colormap self.scale = scale self.nb_ltas_time_bins = nb_ltas_time_bins if self.is_spectro and fft is None: - raise ValueError( - "FFT parameter should be given if spectra outputs are selected.", - ) + msg = "FFT parameter should be given if spectra outputs are selected." + raise ValueError(msg) + + self.fft = fft + self.zoom_levels = list({1, *zoom_levels}) if zoom_levels else None + self.zoomed_fft = ( + zoomed_fft + if zoomed_fft + else self._get_zoomed_ffts(x1_fft=fft, zoom_levels=self.zoom_levels) + ) @property def is_spectro(self) -> bool: @@ -175,3 +198,42 @@ def is_spectro(self) -> bool: AnalysisType.WELCH, ) ) + + @staticmethod + def _get_zoomed_ffts( + x1_fft: ShortTimeFFT, + zoom_levels: list[int] | None, + ) -> list[ShortTimeFFT]: + """Compute the default FFTs to use for computing the zoomed spectra. + + By default, SpectroDatasets with a zoomed factor z will use the + same FFT as the z=1 SpectroDataset, but with a hop that is + divided by z. + + Parameters + ---------- + x1_fft: ShortTimeFFT + FFT used for computing the unzoomed spectra. + zoom_levels: list[int] | None + Additional zoom levels used for computing the spectra. + + Returns + ------- + list[ShortTimeFFT] + FFTs used for computing the zoomed spectra. + + """ + if not zoom_levels: + return [] + zoomed_ffts = [] + for zoom_level in zoom_levels: + if zoom_level == 1: + continue + zoomed_ffts.append( + ShortTimeFFT( + win=x1_fft.win, + hop=x1_fft.hop // zoom_level, + fs=x1_fft.fs, + ), + ) + return zoomed_ffts diff --git a/tests/test_public_api.py b/tests/test_public_api.py index 1add226e..70c72e1a 100644 --- a/tests/test_public_api.py +++ b/tests/test_public_api.py @@ -1411,3 +1411,61 @@ def test_spectro_analysis_with_existing_ads( assert ad.begin == sd.begin assert ad.end == sd.end assert sd.audio_data == ad + + +@pytest.mark.parametrize( + ("fft", "zoomed_levels", "expected"), + [ + pytest.param( + ShortTimeFFT(hamming(1024), hop=1024, fs=24_000), + None, + [], + id="no_zoom", + ), + pytest.param( + ShortTimeFFT(hamming(1024), hop=1024, fs=24_000), + [1], + [], + id="x1_zoom_only_equals_no_zoom", + ), + pytest.param( + ShortTimeFFT(hamming(1024), hop=1024, fs=24_000), + [2], + [ + ShortTimeFFT(hamming(1024), hop=512, fs=24_000), + ], + id="x2_zoom_only", + ), + pytest.param( + ShortTimeFFT(hamming(1024), hop=1024, fs=24_000), + [2, 4, 8], + [ + ShortTimeFFT(hamming(1024), hop=512, fs=24_000), + ShortTimeFFT(hamming(1024), hop=256, fs=24_000), + ShortTimeFFT(hamming(1024), hop=128, fs=24_000), + ], + id="multiple_zoom_levels", + ), + pytest.param( + ShortTimeFFT(hamming(1024), hop=1024, fs=24_000), + [3], + [ + ShortTimeFFT(hamming(1024), hop=341, fs=24_000), + ], + id="hop_is_rounded_down", + ), + ], +) +def test_default_zoomed_sft( + fft: ShortTimeFFT, + zoomed_levels: list[int] | None, + expected: list[ShortTimeFFT], +) -> None: + for sft, expected_sft in zip( + Analysis._get_zoomed_ffts(fft, zoomed_levels), + expected, + strict=True, + ): + assert np.array_equal(sft.win, expected_sft.win) + assert sft.hop == expected_sft.hop + assert sft.fs == expected_sft.fs