diff --git a/python/src/carfac/jax/carfac.py b/python/src/carfac/jax/carfac.py index 65294a6a..2c582cc3 100644 --- a/python/src/carfac/jax/carfac.py +++ b/python/src/carfac/jax/carfac.py @@ -944,6 +944,7 @@ def tree_unflatten(cls, _, children): class EarHypers: """Hyperparameters (tagged as static in `jax.jit`) of 1 ear.""" + input_scale_dbspl: float n_ch: int pole_freqs: jnp.ndarray max_channels_per_octave: float @@ -956,6 +957,7 @@ class EarHypers: # Reference: https://jax.readthedocs.io/en/latest/pytrees.html def tree_flatten(self): # pylint: disable=missing-function-docstring children = ( + self.input_scale_dbspl, self.n_ch, self.pole_freqs, self.max_channels_per_octave, @@ -965,6 +967,7 @@ def tree_flatten(self): # pylint: disable=missing-function-docstring self.syn, ) aux_data = ( + 'input_scale_dbspl', 'n_ch', 'pole_freqs', 'max_channels_per_octave', @@ -986,6 +989,7 @@ def _eq_key(self): # assigned to a different array with exactly the same value. We think such # case should be very rare in usage. return ( + self.input_scale_dbspl, self.n_ch, id(self.pole_freqs), self.max_channels_per_octave, @@ -1055,20 +1059,24 @@ class CarfacDesignParameters: """All the parameters set manually for designing CARFAC.""" fs: float = 22050.0 + input_scale_dbspl: float = 94 ears: List[EarDesignParameters] = dataclasses.field( default_factory=lambda: [EarDesignParameters()] ) - def __init__(self, fs=22050.0, n_ears=1, use_delay_buffer=False): + def __init__(self, fs=22050.0, input_scale_dbspl=94, n_ears=1, use_delay_buffer=False): """Initialize the Design Parameters dataclass. Args: fs: Samples per second. + input_scale_dbspl: scale in dB SPL for input waves (default: 94). The default value expects input in + pascals i.e. 94 dB SPL (for RMS=1), while CARFAC v1 and v2 use an input scale of 107 dB SPL (for RMS=1) n_ears: Number of ears to design for. use_delay_buffer: Whether to use the delay buffer implementation for the car_step. 
""" self.fs = fs + self.input_scale_dbspl = input_scale_dbspl self.ears = [ EarDesignParameters( car=CarDesignParameters(use_delay_buffer=use_delay_buffer) @@ -1763,6 +1771,7 @@ def design_and_init_carfac( state = CarfacState() for ear, ear_params in enumerate(params.ears): + input_scale_dbspl = params.input_scale_dbspl # first figure out how many filter stages (PZFC/CARFAC channels): pole_hz = ear_params.car.first_pole_theta * params.fs / (2 * math.pi) n_ch = 0 @@ -1798,6 +1807,7 @@ def design_and_init_carfac( ) ear_hypers = EarHypers( + input_scale_dbspl=input_scale_dbspl, n_ch=n_ch, pole_freqs=pole_freqs, max_channels_per_octave=max_channels_per_octave, @@ -2414,13 +2424,16 @@ def run_segment( the input_waves are assumed to be sampled at the same rate as the CARFAC is designed for; a resampling may be needed before calling this. + input_waves are considered as being in pascals i.e. their level is + 94db SPL when their RMS equals 1. + The function works as an outer iteration on time, updating all the filters and AGC states concurrently, so that the different channels can interact easily. The inner loops are over filterbank channels, and this level should be kept efficient. Args: - input_waves: the audio input. + input_waves: the audio input in pascals with default input_scale_dbspl. hypers: all the coefficients of the model. It will be passed to all the JIT'ed functions as static variables. weights: all the trainable weights. It will not be changed. @@ -2438,6 +2451,10 @@ def run_segment( """ if len(input_waves.shape) < 2: input_waves = jnp.reshape(input_waves, (-1, 1)) + + # scale input_waves from specified dB SPL @ RMS=1 to 107dB SPL @ RMS=1 + input_waves = input_waves * 10 ** ((hypers.ears[0].input_scale_dbspl-107.)/20.) + n_ears = input_waves.shape[1] n_fibertypes = SynDesignParameters.n_classes @@ -2594,7 +2611,7 @@ def run_segment_jit( the cache when needed. Args: - input_waves: the audio input. 
+ input_waves: the audio input in pascals with default input_scale_dbspl. hypers: all the coefficients of the model. It will be passed to all the JIT'ed functions as static variables. weights: all the trainable weights. It will not be changed. @@ -2636,7 +2653,7 @@ def run_segment_jit_in_chunks_notraceable( 20 percent on a regular CPU. Args: - input_waves: The audio input. + input_waves: The audio input in pascals with default input_scale_dbspl. hypers: All the coefficients of the model. It will be passed to all the JIT'ed functions as static variables. weights: All the trainable weights. It will not be changed. diff --git a/python/src/carfac/jax/carfac_test.py b/python/src/carfac/jax/carfac_test.py index 145191f0..292108fd 100644 --- a/python/src/carfac/jax/carfac_test.py +++ b/python/src/carfac/jax/carfac_test.py @@ -1,3 +1,5 @@ +"""Tests for JAX carfac.""" + import collections import copy import numbers @@ -14,6 +16,7 @@ from carfac.jax import carfac as carfac_jax from carfac.np import carfac as carfac_np +# All stimuli are presented at native CARFAC v1/v2 input scaling of 107 dB SPL for RMS = 1. class CarfacJaxTest(parameterized.TestCase): @@ -101,6 +104,7 @@ def test_hypers_hash(self): hypers = carfac_jax.CarfacHypers() hypers.ears = [ carfac_jax.EarHypers( + input_scale_dbspl=94., # use default value of 94 dB SPL n_ch=0, pole_freqs=jnp.array([]), max_channels_per_octave=0.0, @@ -196,11 +200,11 @@ def container_comparison(self, left_side, right_side, exclude_keys=None): ) def test_equal_design(self, ihc_style): # Test: the designs are similar. - cfp = carfac_np.design_carfac(ihc_style=ihc_style) + cfp = carfac_np.design_carfac(input_scale_dbspl=107., ihc_style=ihc_style) carfac_np.carfac_init(cfp) cfp.ears[0].car_coeffs.linear = False - params_jax = carfac_jax.CarfacDesignParameters() + params_jax = carfac_jax.CarfacDesignParameters(input_scale_dbspl=107.) 
params_jax.ears[0].ihc.ihc_style = ihc_style params_jax.ears[0].car.linear_car = False hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac( @@ -377,7 +381,7 @@ def test_equal_design(self, ihc_style): def test_chunked_naps_same_as_jit(self, random_seed, ihc_style): """Tests whether `run_segment` produces the same results as np version.""" # Inits JAX version - params_jax = carfac_jax.CarfacDesignParameters() + params_jax = carfac_jax.CarfacDesignParameters(input_scale_dbspl=107.) params_jax.ears[0].ihc.ihc_style = ihc_style params_jax.ears[0].car.linear_car = False hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac( @@ -431,7 +435,7 @@ def test_equal_forward_pass( """Tests whether `run_segment` produces the same results as np version.""" # Inits JAX version params_jax = carfac_jax.CarfacDesignParameters( - n_ears=n_ears, use_delay_buffer=delay_buffer + input_scale_dbspl=107., n_ears=n_ears, use_delay_buffer=delay_buffer ) for ear in range(n_ears): params_jax.ears[ear].ihc.ihc_style = ihc_style @@ -441,7 +445,7 @@ def test_equal_forward_pass( ) # Inits numpy version cfp = carfac_np.design_carfac( - ihc_style=ihc_style, n_ears=n_ears, use_delay_buffer=delay_buffer + input_scale_dbspl=107., ihc_style=ihc_style, n_ears=n_ears, use_delay_buffer=delay_buffer ) carfac_np.carfac_init(cfp) diff --git a/python/src/carfac/np/carfac.py b/python/src/carfac/np/carfac.py index c84ddc80..e92f81e5 100644 --- a/python/src/carfac/np/carfac.py +++ b/python/src/carfac/np/carfac.py @@ -1250,6 +1250,7 @@ class CarfacCoeffs: @dataclasses.dataclass class CarfacParams: fs: float + input_scale_dbspl: float max_channels_per_octave: float car_params: CarParams agc_params: AgcParams @@ -1262,6 +1263,7 @@ class CarfacParams: def design_carfac( + input_scale_dbspl: float = 94, n_ears: int = 1, fs: float = 22050, car_params: Optional[CarParams] = None, @@ -1290,6 +1292,8 @@ def design_carfac( make 96 channels at default fs = 22050, 114 channels at 44100. 
Args: + input_scale_dbspl: scale in dB SPL for input waves (default: 94 dB SPL). The default value expects input in + pascals i.e. 94 dB SPL (for RMS=1), while CARFAC v1 and v2 use an input scale of 107 dB SPL (for RMS=1). n_ears: How many ears (1 or 2, in general) in the simulation fs: is sample rate (per second) car_params: bundles all the pole-zero filter cascade parameters @@ -1362,6 +1366,7 @@ def design_carfac( cfp = CarfacParams( fs, + input_scale_dbspl, max_channels_per_octave, car_params, agc_params, @@ -1600,6 +1605,9 @@ def run_segment( the input_waves are assumed to be sampled at the same rate as the CARFAC is designed for; a resampling may be needed before calling this. + With the default input_scale_dbspl, input_waves are taken to be in pascals, + i.e. their level is 94 dB SPL when their RMS equals 1. + The function works as an outer iteration on time, updating all the filters and AGC states concurrently, so that the different channels can interact easily. The inner loops are over filterbank channels, and @@ -1607,7 +1615,7 @@ def run_segment( Args: cfp: a structure that descirbes everything we know about this CARFAC. - input_waves: the audio input + input_waves: the audio input in pascals with default input_scale_dbspl. open_loop: whether to run CARFAC without the feedback. linear_car (new over Matlab): use CAR filters without OHC effects. @@ -1623,6 +1631,10 @@ def run_segment( if len(input_waves.shape) < 2: input_waves = np.reshape(input_waves, (-1, 1)) + + # scale input_waves from cfp.input_scale_dbspl @ RMS=1 to 107 dB SPL @ RMS=1 + input_waves = input_waves * 10 ** ((cfp.input_scale_dbspl-107.)/20.) 
+ [n_samp, n_ears] = input_waves.shape if n_ears != cfp.n_ears: diff --git a/python/src/carfac/np/carfac_test.py b/python/src/carfac/np/carfac_test.py index 17bc91db..9b4ecc95 100644 --- a/python/src/carfac/np/carfac_test.py +++ b/python/src/carfac/np/carfac_test.py @@ -1,4 +1,4 @@ -"""Tests for carfac.""" +"""Tests for Numpy carfac.""" import math from typing import List, Tuple @@ -11,7 +11,8 @@ from carfac.np import carfac # Note some of these tests create plots for easier comparison to the results -# in Dick Lyon's Human and Machine Hearing. The plots are stored in /tmp, and +# in Dick Lyon's Human and Machine Hearing. All stimuli are presented at native CARFAC v1/v2 +# input scaling of 107 dB SPL for RMS = 1. The plots are stored in /tmp, and # the easiest way to see them is to run the test on your machine; or in a # Colab such as google3/third_party/carfac/python/np/CARFAC_Testing.ipynb @@ -167,7 +168,7 @@ def test_design_fir_coeffs(self): carfac.design_fir_coeffs(5, 1, 1, 1) def test_car_freq_response(self): - cfp = carfac.design_carfac() + cfp = carfac.design_carfac(input_scale_dbspl=107.0,) carfac.carfac_init(cfp) # Show impulse response for just the CAR Filter bank. @@ -498,7 +499,7 @@ def test_spatial_smooth( def test_agc_steady_state(self): # Test: Steady state response # Analagous to figure 19.7 - cfp = carfac.design_carfac() + cfp = carfac.design_carfac(input_scale_dbspl=107.0) cf = carfac.carfac_init(cfp) test_channel = 40 @@ -559,7 +560,7 @@ def test_agc_steady_state(self): def test_stage_g_calculation(self): fs = 22050.0 - cfp = carfac.design_carfac(fs=fs) + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0) # Set to true to save a large number of figures. do_plots = False # arange goes to just above 1 to ensure 1.0 is tested. 
@@ -600,7 +601,7 @@ def test_whole_carfac(self, ihc_style): impulse = np.zeros(t.shape) impulse[0] = 1e-4 - cfp = carfac.design_carfac(fs=fs, ihc_style=ihc_style) + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0, ihc_style=ihc_style) cfp = carfac.carfac_init(cfp) _, cfp, bm_initial, _, _ = carfac.run_segment( @@ -817,7 +818,7 @@ def test_delay_buffer(self): impulse = np.zeros(t.shape) impulse[0] = 1e-4 - cfp = carfac.design_carfac(fs=fs) + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0) cfp = carfac.carfac_init(cfp) # Run the linear case with small impulse. _, cfp, bm_initial, _, _ = carfac.run_segment( @@ -847,7 +848,7 @@ def test_delay_buffer(self): self.assertLess(max_max_rel_error, 4e-4) # More tolerance than Matlab. Why? # Run the nonlinear case with a small impulse so not too nonlinear. - cfp = carfac.design_carfac(fs=fs) + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0) cfp = carfac.carfac_init(cfp) cfp.ears[0].car_coeffs.use_delay_buffer = False _, cfp, bm_initial, _, _ = carfac.run_segment(cfp, impulse) @@ -874,7 +875,7 @@ def test_ohc_health(self): t = np.arange(0, 1, 1 / fs) # A second of noise. amplitude = 1e-4 # -80 dBFS, around 20 or 30 dB SPL noise = amplitude * np.random.randn(len(t)) - cfp = carfac.design_carfac(fs=fs) + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0) cfp = carfac.carfac_init(cfp) # Run the healthy case with low-level noise. 
_, cfp, bm_baseline, _, _ = carfac.run_segment(cfp, noise) @@ -918,7 +919,7 @@ def test_multiaural_carfac(self): two_chan_noise = np.zeros((len(t), 2)) two_chan_noise[:, 0] = noise two_chan_noise[:, 1] = noise - cfp = carfac.design_carfac(fs=fs, n_ears=2, ihc_style='one_cap') + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0, n_ears=2, ihc_style='one_cap') cfp = carfac.carfac_init(cfp) naps, _, _, _, _ = carfac.run_segment(cfp, two_chan_noise) max_abs_diff = np.amax(np.abs(naps[:, :, 0] - naps[:, :, 1])) @@ -955,9 +956,9 @@ def test_multiaural_carfac_with_silent_channel(self): two_chan_noise = np.zeros((len(t), 2)) two_chan_noise[:, 0] = c_major_chord # Leave the audio in channel 1 as silence. - cfp = carfac.design_carfac(fs=fs, n_ears=2, ihc_style='one_cap') + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0, n_ears=2, ihc_style='one_cap') cfp = carfac.carfac_init(cfp) - mono_cfp = carfac.design_carfac(fs=fs, n_ears=1, ihc_style='one_cap') + mono_cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0, n_ears=1, ihc_style='one_cap') mono_cfp = carfac.carfac_init(mono_cfp) _, _, bm_binaural, _, _ = carfac.run_segment(cfp, two_chan_noise)