From 2fab85576f4266db44bb13d8662b026ce8ebfb64 Mon Sep 17 00:00:00 2001 From: JasonMH17 Date: Fri, 27 Feb 2026 13:15:52 +1100 Subject: [PATCH 1/5] Updated python CARFAC to receive input in Pascals --- python/src/carfac/jax/carfac.py | 13 ++++++++++--- python/src/carfac/jax/carfac_test.py | 6 ++++-- python/src/carfac/np/carfac.py | 9 ++++++++- python/src/carfac/np/carfac_test.py | 12 ++++++------ 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/python/src/carfac/jax/carfac.py b/python/src/carfac/jax/carfac.py index 65294a6a..6aba4562 100644 --- a/python/src/carfac/jax/carfac.py +++ b/python/src/carfac/jax/carfac.py @@ -2414,13 +2414,16 @@ def run_segment( the input_waves are assumed to be sampled at the same rate as the CARFAC is designed for; a resampling may be needed before calling this. + input_waves are considered as being in Pascals and therefore should equal + 94db SPL where their RMS is 1. + The function works as an outer iteration on time, updating all the filters and AGC states concurrently, so that the different channels can interact easily. The inner loops are over filterbank channels, and this level should be kept efficient. Args: - input_waves: the audio input. + input_waves: the audio input in Pascals. hypers: all the coefficients of the model. It will be passed to all the JIT'ed functions as static variables. weights: all the trainable weights. It will not be changed. @@ -2438,6 +2441,10 @@ def run_segment( """ if len(input_waves.shape) < 2: input_waves = jnp.reshape(input_waves, (-1, 1)) + + # scale input_waves from 94dB SPL @ RMS=1 to 107dB SPL @ RMS=1 + input_waves = input_waves * 10 ** ((94-107)/20) + n_ears = input_waves.shape[1] n_fibertypes = SynDesignParameters.n_classes @@ -2594,7 +2601,7 @@ def run_segment_jit( the cache when needed. Args: - input_waves: the audio input. + input_waves: the audio input in Pascals. hypers: all the coefficients of the model. It will be passed to all the JIT'ed functions as static variables. weights: all the trainable weights. It will not be changed. @@ -2636,7 +2643,7 @@ def run_segment_jit_in_chunks_notraceable( 20 percent on a regular CPU. Args: - input_waves: The audio input. + input_waves: The audio input in Pascals. hypers: All the coefficients of the model. It will be passed to all the JIT'ed functions as static variables. weights: All the trainable weights. It will not be changed. diff --git a/python/src/carfac/jax/carfac_test.py b/python/src/carfac/jax/carfac_test.py index 145191f0..c6c71b50 100644 --- a/python/src/carfac/jax/carfac_test.py +++ b/python/src/carfac/jax/carfac_test.py @@ -391,7 +391,8 @@ def test_chunked_naps_same_as_jit(self, random_seed, ihc_style): n_samp = 200 n_ears = 1 random_generator = jax.random.PRNGKey(random_seed) - run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) + run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) * 10 ** ((107-94)/20) + # scale input to 94 SPL @ RMS=1 # Copy the state first. state_jax_copied = copy.deepcopy(state_jax) @@ -454,7 +455,8 @@ def test_equal_forward_pass( # should be bigger than 64 (i.e. `prod(AgcDesignParameters.decimation)`). n_samp = 200 random_generator = jax.random.PRNGKey(random_seed) - run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) + run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) * 10 ** ((107-94)/20) + # scale input to 94 dB SPL @ RMS=1 # Only tests the JITted version because this is what we will use. naps_jax, _, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = ( diff --git a/python/src/carfac/np/carfac.py b/python/src/carfac/np/carfac.py index c84ddc80..e562e147 100644 --- a/python/src/carfac/np/carfac.py +++ b/python/src/carfac/np/carfac.py @@ -1600,6 +1600,9 @@ def run_segment( the input_waves are assumed to be sampled at the same rate as the CARFAC is designed for; a resampling may be needed before calling this. + input_waves are considered as being in Pascals and therefore should equal + 94db SPL where their RMS is 1. + The function works as an outer iteration on time, updating all the filters and AGC states concurrently, so that the different channels can interact easily. The inner loops are over filterbank channels, and @@ -1607,7 +1610,7 @@ def run_segment( Args: cfp: a structure that descirbes everything we know about this CARFAC. - input_waves: the audio input + input_waves: the audio input in Pascals. open_loop: whether to run CARFAC without the feedback. linear_car (new over Matlab): use CAR filters without OHC effects. @@ -1623,6 +1626,10 @@ def run_segment( if len(input_waves.shape) < 2: input_waves = np.reshape(input_waves, (-1, 1)) + + # scale input_waves from 94dB SPL @ RMS=1 to 107dB SPL @ RMS=1 + input_waves = input_waves * 10 ** ((94-107)/20) + [n_samp, n_ears] = input_waves.shape if n_ears != cfp.n_ears: diff --git a/python/src/carfac/np/carfac_test.py b/python/src/carfac/np/carfac_test.py index 17bc91db..39f20068 100644 --- a/python/src/carfac/np/carfac_test.py +++ b/python/src/carfac/np/carfac_test.py @@ -594,11 +594,11 @@ def test_whole_carfac(self, ihc_style): fs = 22050.0 fp = 1000.0 # Probe tone t = np.arange(0, 2, 1 / fs) # 2s of tone - sinusoid = 1e-1 * np.sin(2 * np.pi * t * fp) + sinusoid = 1e-1 * np.sin(2 * np.pi * t * fp) * 10 ** ((107-94)/20) # scale tone to 94dB SPL @RMS=1 t = np.arange(0, 0.5, 1 / fs) impulse = np.zeros(t.shape) - impulse[0] = 1e-4 + impulse[0] = 1e-4 * 10 ** ((107-94)/20) # scale impulse to 94dB SPL @RMS=1 cfp = carfac.design_carfac(fs=fs, ihc_style=ihc_style) cfp = carfac.carfac_init(cfp) @@ -815,7 +815,7 @@ def test_delay_buffer(self): fs = 22050.0 t = np.arange(0, 0.1, 1 / fs) # Short impulse input. impulse = np.zeros(t.shape) - impulse[0] = 1e-4 + impulse[0] = 1e-4 * 10 ** ((107-94)/20) # scale impulse to 94dB SPL @ RMS = 1 cfp = carfac.design_carfac(fs=fs) cfp = carfac.carfac_init(cfp) @@ -873,7 +873,7 @@ def test_ohc_health(self): fs = 22050.0 t = np.arange(0, 1, 1 / fs) # A second of noise. amplitude = 1e-4 # -80 dBFS, around 20 or 30 dB SPL - noise = amplitude * np.random.randn(len(t)) + noise = amplitude * np.random.randn(len(t)) * 10 ** ((107-94)/20) # scale noise to 94dB SPL @ RMS=1 cfp = carfac.design_carfac(fs=fs) cfp = carfac.carfac_init(cfp) # Run the healthy case with low-level noise. @@ -914,7 +914,7 @@ def test_multiaural_carfac(self): fs = 22050.0 t = np.arange(0, 1, 1 / fs) # A second of noise. amplitude = 1e-3 # -70 dBFS, around 30 or 40 dB SPL - noise = amplitude * np.random.randn(len(t)) + noise = amplitude * np.random.randn(len(t)) * 10 ** ((107-94)/20) # scale noise to 94 dB SPL @ RMS=1 two_chan_noise = np.zeros((len(t), 2)) two_chan_noise[:, 0] = noise two_chan_noise[:, 1] = noise @@ -950,7 +950,7 @@ def test_multiaural_carfac_with_silent_channel(self): freqs = freqs.reshape(len(freqs), 1) c_major_chord = amplitude * np.sum( np.sin(2 * np.pi * np.matmul(freqs, t_prime)), 0 - ) + ) * 10 ** ((107-94)/20) # scale chord to 94 dB SPL @ RMS=1 two_chan_noise = np.zeros((len(t), 2)) two_chan_noise[:, 0] = c_major_chord From b8710c67f1eeb7814800aa5492f778da04b22128 Mon Sep 17 00:00:00 2001 From: JasonMH17 Date: Mon, 2 Mar 2026 09:04:55 +1100 Subject: [PATCH 2/5] Added parameter in python CARFAC to control input_wave scaling. Also added appropriate change to JAX carfac_test.py --- python/src/carfac/jax/carfac.py | 16 +++++++++++++--- python/src/carfac/jax/carfac_test.py | 1 + python/src/carfac/np/carfac.py | 7 ++++++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/python/src/carfac/jax/carfac.py b/python/src/carfac/jax/carfac.py index 6aba4562..69060113 100644 --- a/python/src/carfac/jax/carfac.py +++ b/python/src/carfac/jax/carfac.py @@ -944,6 +944,7 @@ def tree_unflatten(cls, _, children): class EarHypers: """Hyperparameters (tagged as static in `jax.jit`) of 1 ear.""" + input_scale: int n_ch: int pole_freqs: jnp.ndarray max_channels_per_octave: float @@ -956,6 +957,7 @@ class EarHypers: # Reference: https://jax.readthedocs.io/en/latest/pytrees.html def tree_flatten(self): # pylint: disable=missing-function-docstring children = ( + self.input_scale, self.n_ch, self.pole_freqs, self.max_channels_per_octave, @@ -965,6 +967,7 @@ def tree_flatten(self): # pylint: disable=missing-function-docstring self.syn, ) aux_data = ( + 'input_scale', 'n_ch', 'pole_freqs', 'max_channels_per_octave', @@ -986,6 +989,7 @@ def _eq_key(self): # assigned to a different array with exactly the same value. We think such # case should be very rare in usage. return ( + self.input_scale, self.n_ch, id(self.pole_freqs), self.max_channels_per_octave, @@ -1055,20 +1059,24 @@ class CarfacDesignParameters: """All the parameters set manually for designing CARFAC.""" fs: float = 22050.0 + input_scale: int = 94 ears: List[EarDesignParameters] = dataclasses.field( default_factory=lambda: [EarDesignParameters()] ) - def __init__(self, fs=22050.0, n_ears=1, use_delay_buffer=False): + def __init__(self, fs=22050.0, input_scale=94, n_ears=1, use_delay_buffer=False): """Initialize the Design Parameters dataclass. Args: fs: Samples per second. + input_scale: scale for input waves. By default, input is expected in Pascals i.e. 94 dB SPL @ RMS=1, + while CARFAC input is considered as 107 dB SPL @ RMS=1 n_ears: Number of ears to design for. use_delay_buffer: Whether to use the delay buffer implementation for the car_step. """ self.fs = fs + self.input_scale = input_scale self.ears = [ EarDesignParameters( car=CarDesignParameters(use_delay_buffer=use_delay_buffer) @@ -1763,6 +1771,7 @@ def design_and_init_carfac( state = CarfacState() for ear, ear_params in enumerate(params.ears): + input_scale = params.input_scale # first figure out how many filter stages (PZFC/CARFAC channels): pole_hz = ear_params.car.first_pole_theta * params.fs / (2 * math.pi) n_ch = 0 @@ -1798,6 +1807,7 @@ def design_and_init_carfac( ) ear_hypers = EarHypers( + input_scale=input_scale, n_ch=n_ch, pole_freqs=pole_freqs, max_channels_per_octave=max_channels_per_octave, @@ -2442,8 +2452,8 @@ def run_segment( if len(input_waves.shape) < 2: input_waves = jnp.reshape(input_waves, (-1, 1)) - # scale input_waves from 94dB SPL @ RMS=1 to 107dB SPL @ RMS=1 - input_waves = input_waves * 10 ** ((94-107)/20) + # scale input_waves from specified dB SPL @ RMS=1 to 107dB SPL @ RMS=1 + input_waves = input_waves * 10 ** ((hypers.ears[0].input_scale-107)/20) n_ears = input_waves.shape[1] n_fibertypes = SynDesignParameters.n_classes diff --git a/python/src/carfac/jax/carfac_test.py b/python/src/carfac/jax/carfac_test.py index c6c71b50..9ae731f7 100644 --- a/python/src/carfac/jax/carfac_test.py +++ b/python/src/carfac/jax/carfac_test.py @@ -101,6 +101,7 @@ def test_hypers_hash(self): hypers = carfac_jax.CarfacHypers() hypers.ears = [ carfac_jax.EarHypers( + input_scale=94, n_ch=0, pole_freqs=jnp.array([]), max_channels_per_octave=0.0, diff --git a/python/src/carfac/np/carfac.py b/python/src/carfac/np/carfac.py index e562e147..2ffddc78 100644 --- a/python/src/carfac/np/carfac.py +++ b/python/src/carfac/np/carfac.py @@ -1250,6 +1250,7 @@ class CarfacCoeffs: @dataclasses.dataclass class CarfacParams: fs: float + input_scale: int max_channels_per_octave: float car_params: CarParams agc_params: AgcParams @@ -1262,6 +1263,7 @@ class CarfacParams: def design_carfac( + input_scale: int = 94, n_ears: int = 1, fs: float = 22050, car_params: Optional[CarParams] = None, @@ -1290,6 +1292,8 @@ def design_carfac( make 96 channels at default fs = 22050, 114 channels at 44100. Args: + input_scale: scale for input waves. By default, input is expected in Pascals i.e. 94 dB SPL @ RMS=1, + while CARFAC input is considered as 107 dB SPL @ RMS=1 n_ears: How many ears (1 or 2, in general) in the simulation fs: is sample rate (per second) car_params: bundles all the pole-zero filter cascade parameters @@ -1362,6 +1366,7 @@ def design_carfac( cfp = CarfacParams( fs, + input_scale, max_channels_per_octave, car_params, agc_params, @@ -1628,7 +1633,7 @@ def run_segment( input_waves = np.reshape(input_waves, (-1, 1)) # scale input_waves from 94dB SPL @ RMS=1 to 107dB SPL @ RMS=1 - input_waves = input_waves * 10 ** ((94-107)/20) + input_waves = input_waves * 10 ** ((cfp.input_scale-107)/20) [n_samp, n_ears] = input_waves.shape From c6dc21d90205b72fa14a2c652033fbdd060722ef Mon Sep 17 00:00:00 2001 From: JasonMH17 Date: Wed, 4 Mar 2026 09:33:32 +1100 Subject: [PATCH 3/5] Updated numpy CARFAC and its test to explain input_scale_dbspl more effectively. Also reverted to using 107 dB SPL scaling for carfac_test.py by implementing input_scale_dbspl parameter in carfac_design function --- python/src/carfac/np/carfac.py | 18 +++++++-------- python/src/carfac/np/carfac_test.py | 35 +++++++++++++++-------------- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/python/src/carfac/np/carfac.py b/python/src/carfac/np/carfac.py index 2ffddc78..d527e229 100644 --- a/python/src/carfac/np/carfac.py +++ b/python/src/carfac/np/carfac.py @@ -1250,7 +1250,7 @@ class CarfacCoeffs: @dataclasses.dataclass class CarfacParams: fs: float - input_scale: int + input_scale_dbspl: float max_channels_per_octave: float car_params: CarParams agc_params: AgcParams @@ -1263,7 +1263,7 @@ class CarfacParams: def design_carfac( - input_scale: int = 94, + input_scale_dbspl: float = 94, n_ears: int = 1, fs: float = 22050, car_params: Optional[CarParams] = None, @@ -1292,8 +1292,8 @@ def design_carfac( make 96 channels at default fs = 22050, 114 channels at 44100. Args: - input_scale: scale for input waves. By default, input is expected in Pascals i.e. 94 dB SPL @ RMS=1, - while CARFAC input is considered as 107 dB SPL @ RMS=1 + input_scale_dbspl: scale in dB SPL for input waves (default: 94 dB SPL). The default value expects input in + pascals i.e. 94 dB SPL (for RMS=1), while CARFAC v1 and v2 use an input scale of 107 dB SPL (for RMS=1) n_ears: How many ears (1 or 2, in general) in the simulation fs: is sample rate (per second) car_params: bundles all the pole-zero filter cascade parameters @@ -1366,7 +1366,7 @@ def design_carfac( cfp = CarfacParams( fs, - input_scale, + input_scale_dbspl, max_channels_per_octave, car_params, agc_params, @@ -1605,8 +1605,8 @@ def run_segment( the input_waves are assumed to be sampled at the same rate as the CARFAC is designed for; a resampling may be needed before calling this. - input_waves are considered as being in Pascals and therefore should equal - 94db SPL where their RMS is 1. + input_waves are considered as being in pascals i.e. their level is + 94db SPL when their RMS equals 1. The function works as an outer iteration on time, updating all the filters and AGC states concurrently, so that the different channels can @@ -1615,7 +1615,7 @@ def run_segment( Args: cfp: a structure that descirbes everything we know about this CARFAC. - input_waves: the audio input in Pascals. + input_waves: the audio input in pascals. open_loop: whether to run CARFAC without the feedback. linear_car (new over Matlab): use CAR filters without OHC effects. @@ -1633,7 +1633,7 @@ def run_segment( input_waves = np.reshape(input_waves, (-1, 1)) # scale input_waves from 94dB SPL @ RMS=1 to 107dB SPL @ RMS=1 - input_waves = input_waves * 10 ** ((cfp.input_scale-107)/20) + input_waves = input_waves * 10 ** ((cfp.input_scale_dbspl-107.)/20.) [n_samp, n_ears] = input_waves.shape diff --git a/python/src/carfac/np/carfac_test.py b/python/src/carfac/np/carfac_test.py index 39f20068..ec00bd53 100644 --- a/python/src/carfac/np/carfac_test.py +++ b/python/src/carfac/np/carfac_test.py @@ -11,7 +11,8 @@ from carfac.np import carfac # Note some of these tests create plots for easier comparison to the results -# in Dick Lyon's Human and Machine Hearing. The plots are stored in /tmp, and +# in Dick Lyon's Human and Machine Hearing. All stimuli are presented at native CARFAC v1/v2 +# input scaling of 107 dB SPL for RMS = 1. The plots are stored in /tmp, and # the easiest way to see them is to run the test on your machine; or in a # Colab such as google3/third_party/carfac/python/np/CARFAC_Testing.ipynb @@ -167,7 +168,7 @@ def test_design_fir_coeffs(self): carfac.design_fir_coeffs(5, 1, 1, 1) def test_car_freq_response(self): - cfp = carfac.design_carfac() + cfp = carfac.design_carfac(input_scale_dbspl=107.0,) carfac.carfac_init(cfp) # Show impulse response for just the CAR Filter bank. @@ -498,7 +499,7 @@ def test_spatial_smooth( def test_agc_steady_state(self): # Test: Steady state response # Analagous to figure 19.7 - cfp = carfac.design_carfac() + cfp = carfac.design_carfac(input_scale_dbspl=107.0) cf = carfac.carfac_init(cfp) test_channel = 40 @@ -559,7 +560,7 @@ def test_agc_steady_state(self): def test_stage_g_calculation(self): fs = 22050.0 - cfp = carfac.design_carfac(fs=fs) + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0) # Set to true to save a large number of figures. do_plots = False # arange goes to just above 1 to ensure 1.0 is tested. @@ -594,13 +595,13 @@ def test_whole_carfac(self, ihc_style): fs = 22050.0 fp = 1000.0 # Probe tone t = np.arange(0, 2, 1 / fs) # 2s of tone - sinusoid = 1e-1 * np.sin(2 * np.pi * t * fp) * 10 ** ((107-94)/20) # scale tone to 94dB SPL @RMS=1 + sinusoid = 1e-1 * np.sin(2 * np.pi * t * fp) t = np.arange(0, 0.5, 1 / fs) impulse = np.zeros(t.shape) - impulse[0] = 1e-4 * 10 ** ((107-94)/20) # scale impulse to 94dB SPL @RMS=1 + impulse[0] = 1e-4 - cfp = carfac.design_carfac(fs=fs, ihc_style=ihc_style) + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0, ihc_style=ihc_style) cfp = carfac.carfac_init(cfp) _, cfp, bm_initial, _, _ = carfac.run_segment( @@ -815,9 +816,9 @@ def test_delay_buffer(self): fs = 22050.0 t = np.arange(0, 0.1, 1 / fs) # Short impulse input. impulse = np.zeros(t.shape) - impulse[0] = 1e-4 * 10 ** ((107-94)/20) # scale impulse to 94dB SPL @ RMS = 1 + impulse[0] = 1e-4 - cfp = carfac.design_carfac(fs=fs) + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0) cfp = carfac.carfac_init(cfp) # Run the linear case with small impulse. _, cfp, bm_initial, _, _ = carfac.run_segment( @@ -847,7 +848,7 @@ def test_delay_buffer(self): self.assertLess(max_max_rel_error, 4e-4) # More tolerance than Matlab. Why? # Run the nonlinear case with a small impulse so not too nonlinear. - cfp = carfac.design_carfac(fs=fs) + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0) cfp = carfac.carfac_init(cfp) cfp.ears[0].car_coeffs.use_delay_buffer = False _, cfp, bm_initial, _, _ = carfac.run_segment(cfp, impulse) @@ -873,8 +874,8 @@ def test_ohc_health(self): fs = 22050.0 t = np.arange(0, 1, 1 / fs) # A second of noise. amplitude = 1e-4 # -80 dBFS, around 20 or 30 dB SPL - noise = amplitude * np.random.randn(len(t)) * 10 ** ((107-94)/20) # scale noise to 94dB SPL @ RMS=1 - cfp = carfac.design_carfac(fs=fs) + noise = amplitude * np.random.randn(len(t)) + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0) cfp = carfac.carfac_init(cfp) # Run the healthy case with low-level noise. _, cfp, bm_baseline, _, _ = carfac.run_segment(cfp, noise) @@ -914,11 +915,11 @@ def test_multiaural_carfac(self): fs = 22050.0 t = np.arange(0, 1, 1 / fs) # A second of noise. amplitude = 1e-3 # -70 dBFS, around 30 or 40 dB SPL - noise = amplitude * np.random.randn(len(t)) * 10 ** ((107-94)/20) # scale noise to 94 dB SPL @ RMS=1 + noise = amplitude * np.random.randn(len(t)) two_chan_noise = np.zeros((len(t), 2)) two_chan_noise[:, 0] = noise two_chan_noise[:, 1] = noise - cfp = carfac.design_carfac(fs=fs, n_ears=2, ihc_style='one_cap') + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0, n_ears=2, ihc_style='one_cap') cfp = carfac.carfac_init(cfp) naps, _, _, _, _ = carfac.run_segment(cfp, two_chan_noise) max_abs_diff = np.amax(np.abs(naps[:, :, 0] - naps[:, :, 1])) @@ -950,14 +951,14 @@ def test_multiaural_carfac_with_silent_channel(self): freqs = freqs.reshape(len(freqs), 1) c_major_chord = amplitude * np.sum( np.sin(2 * np.pi * np.matmul(freqs, t_prime)), 0 - ) * 10 ** ((107-94)/20) # scale chord to 94 dB SPL @ RMS=1 + ) two_chan_noise = np.zeros((len(t), 2)) two_chan_noise[:, 0] = c_major_chord # Leave the audio in channel 1 as silence. - cfp = carfac.design_carfac(fs=fs, n_ears=2, ihc_style='one_cap') + cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0, n_ears=2, ihc_style='one_cap') cfp = carfac.carfac_init(cfp) - mono_cfp = carfac.design_carfac(fs=fs, n_ears=1, ihc_style='one_cap') + mono_cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0, n_ears=1, ihc_style='one_cap') mono_cfp = carfac.carfac_init(mono_cfp) _, _, bm_binaural, _, _ = carfac.run_segment(cfp, two_chan_noise) From 22b02c0fdf0f09dd23353e4719aa54302ee22838 Mon Sep 17 00:00:00 2001 From: JasonMH17 Date: Wed, 4 Mar 2026 09:42:37 +1100 Subject: [PATCH 4/5] Updated description of input_waves in numpy carfac.py --- python/src/carfac/np/carfac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/src/carfac/np/carfac.py b/python/src/carfac/np/carfac.py index d527e229..e92f81e5 100644 --- a/python/src/carfac/np/carfac.py +++ b/python/src/carfac/np/carfac.py @@ -1615,7 +1615,7 @@ def run_segment( Args: cfp: a structure that descirbes everything we know about this CARFAC. - input_waves: the audio input in pascals. + input_waves: the audio input in pascals with default input_scale_dbspl. open_loop: whether to run CARFAC without the feedback. linear_car (new over Matlab): use CAR filters without OHC effects. From a5956f8ab7e7ce6fbe38a6ae6a54d38fbb8ceb86 Mon Sep 17 00:00:00 2001 From: JasonMH17 Date: Wed, 4 Mar 2026 10:01:55 +1100 Subject: [PATCH 5/5] Updated JAX CARFAC and its test to explain input_scale_dbspl more effectively. As in Numpy carfac_test.py, input scaling is reverted to native CARFAC v1/v2 107 dB SPL. Title added for Numpy/JAX carfac_test.py files to more easily differentiate --- python/src/carfac/jax/carfac.py | 34 ++++++++++++++-------------- python/src/carfac/jax/carfac_test.py | 21 +++++++++-------- python/src/carfac/np/carfac_test.py | 2 +- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/python/src/carfac/jax/carfac.py b/python/src/carfac/jax/carfac.py index 69060113..2c582cc3 100644 --- a/python/src/carfac/jax/carfac.py +++ b/python/src/carfac/jax/carfac.py @@ -944,7 +944,7 @@ def tree_unflatten(cls, _, children): class EarHypers: """Hyperparameters (tagged as static in `jax.jit`) of 1 ear.""" - input_scale: int + input_scale_dbspl: float n_ch: int pole_freqs: jnp.ndarray max_channels_per_octave: float @@ -957,7 +957,7 @@ class EarHypers: # Reference: https://jax.readthedocs.io/en/latest/pytrees.html def tree_flatten(self): # pylint: disable=missing-function-docstring children = ( - self.input_scale, + self.input_scale_dbspl, self.n_ch, self.pole_freqs, self.max_channels_per_octave, @@ -967,7 +967,7 @@ def tree_flatten(self): # pylint: disable=missing-function-docstring self.syn, ) aux_data = ( - 'input_scale', + 'input_scale_dbspl', 'n_ch', 'pole_freqs', 'max_channels_per_octave', @@ -989,7 +989,7 @@ def _eq_key(self): # assigned to a different array with exactly the same value. We think such # case should be very rare in usage. return ( - self.input_scale, + self.input_scale_dbspl, self.n_ch, id(self.pole_freqs), self.max_channels_per_octave, @@ -1059,24 +1059,24 @@ class CarfacDesignParameters: """All the parameters set manually for designing CARFAC.""" fs: float = 22050.0 - input_scale: int = 94 + input_scale_dbspl: float = 94 ears: List[EarDesignParameters] = dataclasses.field( default_factory=lambda: [EarDesignParameters()] ) - def __init__(self, fs=22050.0, input_scale=94, n_ears=1, use_delay_buffer=False): + def __init__(self, fs=22050.0, input_scale_dbspl=94, n_ears=1, use_delay_buffer=False): """Initialize the Design Parameters dataclass. Args: fs: Samples per second. - input_scale: scale for input waves. By default, input is expected in Pascals i.e. 94 dB SPL @ RMS=1, - while CARFAC input is considered as 107 dB SPL @ RMS=1 + input_scale_dbspl: scale in dB SPL for input waves (default: 94). The default value expects input in + pascals i.e. 94 dB SPL (for RMS=1), while CARFAC v1 and v2 use an input scale of 107 dB SPL (for RMS=1) n_ears: Number of ears to design for. use_delay_buffer: Whether to use the delay buffer implementation for the car_step. """ self.fs = fs - self.input_scale = input_scale + self.input_scale_dbspl = input_scale_dbspl self.ears = [ EarDesignParameters( car=CarDesignParameters(use_delay_buffer=use_delay_buffer) @@ -1771,7 +1771,7 @@ def design_and_init_carfac( state = CarfacState() for ear, ear_params in enumerate(params.ears): - input_scale = params.input_scale + input_scale_dbspl = params.input_scale_dbspl # first figure out how many filter stages (PZFC/CARFAC channels): pole_hz = ear_params.car.first_pole_theta * params.fs / (2 * math.pi) n_ch = 0 @@ -1807,7 +1807,7 @@ def design_and_init_carfac( ) ear_hypers = EarHypers( - input_scale=input_scale, + input_scale_dbspl=input_scale_dbspl, n_ch=n_ch, pole_freqs=pole_freqs, max_channels_per_octave=max_channels_per_octave, @@ -2424,8 +2424,8 @@ def run_segment( the input_waves are assumed to be sampled at the same rate as the CARFAC is designed for; a resampling may be needed before calling this. - input_waves are considered as being in Pascals and therefore should equal - 94db SPL where their RMS is 1. + input_waves are considered as being in pascals i.e. their level is + 94db SPL when their RMS equals 1. The function works as an outer iteration on time, updating all the filters and AGC states concurrently, so that the different channels can @@ -2433,7 +2433,7 @@ def run_segment( this level should be kept efficient. Args: - input_waves: the audio input in Pascals. + input_waves: the audio input in pascals with default input_scale_dbspl. hypers: all the coefficients of the model. It will be passed to all the JIT'ed functions as static variables. weights: all the trainable weights. It will not be changed. @@ -2453,7 +2453,7 @@ def run_segment( input_waves = jnp.reshape(input_waves, (-1, 1)) # scale input_waves from specified dB SPL @ RMS=1 to 107dB SPL @ RMS=1 - input_waves = input_waves * 10 ** ((hypers.ears[0].input_scale-107)/20) + input_waves = input_waves * 10 ** ((hypers.ears[0].input_scale_dbspl-107.)/20.) n_ears = input_waves.shape[1] n_fibertypes = SynDesignParameters.n_classes @@ -2611,7 +2611,7 @@ def run_segment_jit( the cache when needed. Args: - input_waves: the audio input in Pascals. + input_waves: the audio input in pascals with default input_scale_dbspl. hypers: all the coefficients of the model. It will be passed to all the JIT'ed functions as static variables. weights: all the trainable weights. It will not be changed. @@ -2653,7 +2653,7 @@ def run_segment_jit_in_chunks_notraceable( 20 percent on a regular CPU. Args: - input_waves: The audio input in Pascals. + input_waves: The audio input in pascals with defualt input_scale_dbspl. hypers: All the coefficients of the model. It will be passed to all the JIT'ed functions as static variables. weights: All the trainable weights. It will not be changed. diff --git a/python/src/carfac/jax/carfac_test.py b/python/src/carfac/jax/carfac_test.py index 9ae731f7..292108fd 100644 --- a/python/src/carfac/jax/carfac_test.py +++ b/python/src/carfac/jax/carfac_test.py @@ -1,3 +1,5 @@ +"""Tests for JAX carfac.""" + import collections import copy import numbers @@ -14,6 +16,7 @@ from carfac.jax import carfac as carfac_jax from carfac.np import carfac as carfac_np +# All stimuli are presented at native CARFAC v1/v2 input scaling of 107 dB SPL for RMS = 1. class CarfacJaxTest(parameterized.TestCase): @@ -101,7 +104,7 @@ def test_hypers_hash(self): hypers = carfac_jax.CarfacHypers() hypers.ears = [ carfac_jax.EarHypers( - input_scale=94, + input_scale_dbspl=94., # use default value of 94 dB SPL n_ch=0, pole_freqs=jnp.array([]), max_channels_per_octave=0.0, @@ -197,11 +200,11 @@ def container_comparison(self, left_side, right_side, exclude_keys=None): ) def test_equal_design(self, ihc_style): # Test: the designs are similar. - cfp = carfac_np.design_carfac(ihc_style=ihc_style) + cfp = carfac_np.design_carfac(input_scale_dbspl=107., ihc_style=ihc_style) carfac_np.carfac_init(cfp) cfp.ears[0].car_coeffs.linear = False - params_jax = carfac_jax.CarfacDesignParameters() + params_jax = carfac_jax.CarfacDesignParameters(input_scale_dbspl=107.) params_jax.ears[0].ihc.ihc_style = ihc_style params_jax.ears[0].car.linear_car = False hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac( @@ -378,7 +381,7 @@ def test_equal_design(self, ihc_style): def test_chunked_naps_same_as_jit(self, random_seed, ihc_style): """Tests whether `run_segment` produces the same results as np version.""" # Inits JAX version - params_jax = carfac_jax.CarfacDesignParameters() + params_jax = carfac_jax.CarfacDesignParameters(input_scale_dbspl=107.) params_jax.ears[0].ihc.ihc_style = ihc_style params_jax.ears[0].car.linear_car = False hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac( @@ -392,8 +395,7 @@ def test_chunked_naps_same_as_jit(self, random_seed, ihc_style): n_samp = 200 n_ears = 1 random_generator = jax.random.PRNGKey(random_seed) - run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) * 10 ** ((107-94)/20) - # scale input to 94 SPL @ RMS=1 + run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) # Copy the state first. state_jax_copied = copy.deepcopy(state_jax) @@ -433,7 +435,7 @@ def test_equal_forward_pass( """Tests whether `run_segment` produces the same results as np version.""" # Inits JAX version params_jax = carfac_jax.CarfacDesignParameters( - n_ears=n_ears, use_delay_buffer=delay_buffer + input_scale_dbspl=107., n_ears=n_ears, use_delay_buffer=delay_buffer ) for ear in range(n_ears): params_jax.ears[ear].ihc.ihc_style = ihc_style @@ -443,7 +445,7 @@ def test_equal_forward_pass( ) # Inits numpy version cfp = carfac_np.design_carfac( - ihc_style=ihc_style, n_ears=n_ears, use_delay_buffer=delay_buffer + input_scale_dbspl=107., ihc_style=ihc_style, n_ears=n_ears, use_delay_buffer=delay_buffer ) carfac_np.carfac_init(cfp) @@ -456,8 +458,7 @@ def test_equal_forward_pass( # should be bigger than 64 (i.e. `prod(AgcDesignParameters.decimation)`). n_samp = 200 random_generator = jax.random.PRNGKey(random_seed) - run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) * 10 ** ((107-94)/20) - # scale input to 94 dB SPL @ RMS=1 + run_seg_input = jax.random.normal(random_generator, (n_samp, n_ears)) # Only tests the JITted version because this is what we will use. naps_jax, _, state_jax, bm_jax, seg_ohc_jax, seg_agc_jax = ( diff --git a/python/src/carfac/np/carfac_test.py b/python/src/carfac/np/carfac_test.py index ec00bd53..9b4ecc95 100644 --- a/python/src/carfac/np/carfac_test.py +++ b/python/src/carfac/np/carfac_test.py @@ -1,4 +1,4 @@ -"""Tests for carfac.""" +"""Tests for Numpy carfac.""" import math from typing import List, Tuple