Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions python/src/carfac/jax/carfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,7 @@ def tree_unflatten(cls, _, children):
class EarHypers:
"""Hyperparameters (tagged as static in `jax.jit`) of 1 ear."""

input_scale_dbspl: float
n_ch: int
pole_freqs: jnp.ndarray
max_channels_per_octave: float
Expand All @@ -956,6 +957,7 @@ class EarHypers:
# Reference: https://jax.readthedocs.io/en/latest/pytrees.html
def tree_flatten(self): # pylint: disable=missing-function-docstring
children = (
self.input_scale_dbspl,
self.n_ch,
self.pole_freqs,
self.max_channels_per_octave,
Expand All @@ -965,6 +967,7 @@ def tree_flatten(self): # pylint: disable=missing-function-docstring
self.syn,
)
aux_data = (
'input_scale_dbspl',
'n_ch',
'pole_freqs',
'max_channels_per_octave',
Expand All @@ -986,6 +989,7 @@ def _eq_key(self):
# assigned to a different array with exactly the same value. We think such
# case should be very rare in usage.
return (
self.input_scale_dbspl,
self.n_ch,
id(self.pole_freqs),
self.max_channels_per_octave,
Expand Down Expand Up @@ -1055,20 +1059,24 @@ class CarfacDesignParameters:
"""All the parameters set manually for designing CARFAC."""

fs: float = 22050.0
input_scale_dbspl: float = 94
ears: List[EarDesignParameters] = dataclasses.field(
default_factory=lambda: [EarDesignParameters()]
)

def __init__(self, fs=22050.0, n_ears=1, use_delay_buffer=False):
def __init__(self, fs=22050.0, input_scale_dbspl=94, n_ears=1, use_delay_buffer=False):
"""Initialize the Design Parameters dataclass.

Args:
fs: Samples per second.
input_scale_dbspl: scale in dB SPL for input waves (default: 94). The default value expects input in
pascals i.e. 94 dB SPL (for RMS=1), while CARFAC v1 and v2 use an input scale of 107 dB SPL (for RMS=1)
n_ears: Number of ears to design for.
use_delay_buffer: Whether to use the delay buffer implementation for the
car_step.
"""
self.fs = fs
self.input_scale_dbspl = input_scale_dbspl
self.ears = [
EarDesignParameters(
car=CarDesignParameters(use_delay_buffer=use_delay_buffer)
Expand Down Expand Up @@ -1763,6 +1771,7 @@ def design_and_init_carfac(
state = CarfacState()

for ear, ear_params in enumerate(params.ears):
input_scale_dbspl = params.input_scale_dbspl
# first figure out how many filter stages (PZFC/CARFAC channels):
pole_hz = ear_params.car.first_pole_theta * params.fs / (2 * math.pi)
n_ch = 0
Expand Down Expand Up @@ -1798,6 +1807,7 @@ def design_and_init_carfac(
)

ear_hypers = EarHypers(
input_scale_dbspl=input_scale_dbspl,
n_ch=n_ch,
pole_freqs=pole_freqs,
max_channels_per_octave=max_channels_per_octave,
Expand Down Expand Up @@ -2414,13 +2424,16 @@ def run_segment(
the input_waves are assumed to be sampled at the same rate as the
CARFAC is designed for; a resampling may be needed before calling this.

input_waves are considered as being in pascals i.e. their level is
94db SPL when their RMS equals 1.

The function works as an outer iteration on time, updating all the
filters and AGC states concurrently, so that the different channels can
interact easily. The inner loops are over filterbank channels, and
this level should be kept efficient.

Args:
input_waves: the audio input.
input_waves: the audio input in pascals with default input_scale_dbspl.
hypers: all the coefficients of the model. It will be passed to all the
JIT'ed functions as static variables.
weights: all the trainable weights. It will not be changed.
Expand All @@ -2438,6 +2451,10 @@ def run_segment(
"""
if len(input_waves.shape) < 2:
input_waves = jnp.reshape(input_waves, (-1, 1))

# scale input_waves from specified dB SPL @ RMS=1 to 107dB SPL @ RMS=1
input_waves = input_waves * 10 ** ((hypers.ears[0].input_scale_dbspl-107.)/20.)

n_ears = input_waves.shape[1]
n_fibertypes = SynDesignParameters.n_classes

Expand Down Expand Up @@ -2594,7 +2611,7 @@ def run_segment_jit(
the cache when needed.

Args:
input_waves: the audio input.
input_waves: the audio input in pascals with default input_scale_dbspl.
hypers: all the coefficients of the model. It will be passed to all the
JIT'ed functions as static variables.
weights: all the trainable weights. It will not be changed.
Expand Down Expand Up @@ -2636,7 +2653,7 @@ def run_segment_jit_in_chunks_notraceable(
20 percent on a regular CPU.

Args:
input_waves: The audio input.
input_waves: The audio input in pascals with defualt input_scale_dbspl.
hypers: All the coefficients of the model. It will be passed to all the
JIT'ed functions as static variables.
weights: All the trainable weights. It will not be changed.
Expand Down
14 changes: 9 additions & 5 deletions python/src/carfac/jax/carfac_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Tests for JAX carfac."""

import collections
import copy
import numbers
Expand All @@ -14,6 +16,7 @@
from carfac.jax import carfac as carfac_jax
from carfac.np import carfac as carfac_np

# All stimuli are presented at native CARFAC v1/v2 input scaling of 107 dB SPL for RMS = 1.

class CarfacJaxTest(parameterized.TestCase):

Expand Down Expand Up @@ -101,6 +104,7 @@ def test_hypers_hash(self):
hypers = carfac_jax.CarfacHypers()
hypers.ears = [
carfac_jax.EarHypers(
input_scale_dbspl=94., # use default value of 94 dB SPL
n_ch=0,
pole_freqs=jnp.array([]),
max_channels_per_octave=0.0,
Expand Down Expand Up @@ -196,11 +200,11 @@ def container_comparison(self, left_side, right_side, exclude_keys=None):
)
def test_equal_design(self, ihc_style):
# Test: the designs are similar.
cfp = carfac_np.design_carfac(ihc_style=ihc_style)
cfp = carfac_np.design_carfac(input_scale_dbspl=107., ihc_style=ihc_style)
carfac_np.carfac_init(cfp)
cfp.ears[0].car_coeffs.linear = False

params_jax = carfac_jax.CarfacDesignParameters()
params_jax = carfac_jax.CarfacDesignParameters(input_scale_dbspl=107.)
params_jax.ears[0].ihc.ihc_style = ihc_style
params_jax.ears[0].car.linear_car = False
hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac(
Expand Down Expand Up @@ -377,7 +381,7 @@ def test_equal_design(self, ihc_style):
def test_chunked_naps_same_as_jit(self, random_seed, ihc_style):
"""Tests whether `run_segment` produces the same results as np version."""
# Inits JAX version
params_jax = carfac_jax.CarfacDesignParameters()
params_jax = carfac_jax.CarfacDesignParameters(input_scale_dbspl=107.)
params_jax.ears[0].ihc.ihc_style = ihc_style
params_jax.ears[0].car.linear_car = False
hypers_jax, weights_jax, state_jax = carfac_jax.design_and_init_carfac(
Expand Down Expand Up @@ -431,7 +435,7 @@ def test_equal_forward_pass(
"""Tests whether `run_segment` produces the same results as np version."""
# Inits JAX version
params_jax = carfac_jax.CarfacDesignParameters(
n_ears=n_ears, use_delay_buffer=delay_buffer
input_scale_dbspl=107., n_ears=n_ears, use_delay_buffer=delay_buffer
)
for ear in range(n_ears):
params_jax.ears[ear].ihc.ihc_style = ihc_style
Expand All @@ -441,7 +445,7 @@ def test_equal_forward_pass(
)
# Inits numpy version
cfp = carfac_np.design_carfac(
ihc_style=ihc_style, n_ears=n_ears, use_delay_buffer=delay_buffer
input_scale_dbspl=107., ihc_style=ihc_style, n_ears=n_ears, use_delay_buffer=delay_buffer
)

carfac_np.carfac_init(cfp)
Expand Down
14 changes: 13 additions & 1 deletion python/src/carfac/np/carfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,7 @@ class CarfacCoeffs:
@dataclasses.dataclass
class CarfacParams:
fs: float
input_scale_dbspl: float
max_channels_per_octave: float
car_params: CarParams
agc_params: AgcParams
Expand All @@ -1262,6 +1263,7 @@ class CarfacParams:


def design_carfac(
input_scale_dbspl: float = 94,
n_ears: int = 1,
fs: float = 22050,
car_params: Optional[CarParams] = None,
Expand Down Expand Up @@ -1290,6 +1292,8 @@ def design_carfac(
make 96 channels at default fs = 22050, 114 channels at 44100.

Args:
input_scale_dbspl: scale in dB SPL for input waves (default: 94 dB SPL). The default value expects input in
pascals i.e. 94 dB SPL (for RMS=1), while CARFAC v1 and v2 use an input scale of 107 dB SPL (for RMS=1)
n_ears: How many ears (1 or 2, in general) in the simulation
fs: is sample rate (per second)
car_params: bundles all the pole-zero filter cascade parameters
Expand Down Expand Up @@ -1362,6 +1366,7 @@ def design_carfac(

cfp = CarfacParams(
fs,
input_scale_dbspl,
max_channels_per_octave,
car_params,
agc_params,
Expand Down Expand Up @@ -1600,14 +1605,17 @@ def run_segment(
the input_waves are assumed to be sampled at the same rate as the
CARFAC is designed for; a resampling may be needed before calling this.

input_waves are considered as being in pascals i.e. their level is
94db SPL when their RMS equals 1.

The function works as an outer iteration on time, updating all the
filters and AGC states concurrently, so that the different channels can
interact easily. The inner loops are over filterbank channels, and
this level should be kept efficient.

Args:
cfp: a structure that descirbes everything we know about this CARFAC.
input_waves: the audio input
input_waves: the audio input in pascals with default input_scale_dbspl.
open_loop: whether to run CARFAC without the feedback.
linear_car (new over Matlab): use CAR filters without OHC effects.

Expand All @@ -1623,6 +1631,10 @@ def run_segment(

if len(input_waves.shape) < 2:
input_waves = np.reshape(input_waves, (-1, 1))

# scale input_waves from 94dB SPL @ RMS=1 to 107dB SPL @ RMS=1
input_waves = input_waves * 10 ** ((cfp.input_scale_dbspl-107.)/20.)

[n_samp, n_ears] = input_waves.shape

if n_ears != cfp.n_ears:
Expand Down
25 changes: 13 additions & 12 deletions python/src/carfac/np/carfac_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Tests for carfac."""
"""Tests for Numpy carfac."""

import math
from typing import List, Tuple
Expand All @@ -11,7 +11,8 @@
from carfac.np import carfac

# Note some of these tests create plots for easier comparison to the results
# in Dick Lyon's Human and Machine Hearing. The plots are stored in /tmp, and
# in Dick Lyon's Human and Machine Hearing. All stimuli are presented at native CARFAC v1/v2
# input scaling of 107 dB SPL for RMS = 1. The plots are stored in /tmp, and
# the easiest way to see them is to run the test on your machine; or in a
# Colab such as google3/third_party/carfac/python/np/CARFAC_Testing.ipynb

Expand Down Expand Up @@ -167,7 +168,7 @@ def test_design_fir_coeffs(self):
carfac.design_fir_coeffs(5, 1, 1, 1)

def test_car_freq_response(self):
cfp = carfac.design_carfac()
cfp = carfac.design_carfac(input_scale_dbspl=107.0,)
carfac.carfac_init(cfp)

# Show impulse response for just the CAR Filter bank.
Expand Down Expand Up @@ -498,7 +499,7 @@ def test_spatial_smooth(
def test_agc_steady_state(self):
# Test: Steady state response
# Analagous to figure 19.7
cfp = carfac.design_carfac()
cfp = carfac.design_carfac(input_scale_dbspl=107.0)
cf = carfac.carfac_init(cfp)

test_channel = 40
Expand Down Expand Up @@ -559,7 +560,7 @@ def test_agc_steady_state(self):

def test_stage_g_calculation(self):
fs = 22050.0
cfp = carfac.design_carfac(fs=fs)
cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0)
# Set to true to save a large number of figures.
do_plots = False
# arange goes to just above 1 to ensure 1.0 is tested.
Expand Down Expand Up @@ -600,7 +601,7 @@ def test_whole_carfac(self, ihc_style):
impulse = np.zeros(t.shape)
impulse[0] = 1e-4

cfp = carfac.design_carfac(fs=fs, ihc_style=ihc_style)
cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0, ihc_style=ihc_style)
cfp = carfac.carfac_init(cfp)

_, cfp, bm_initial, _, _ = carfac.run_segment(
Expand Down Expand Up @@ -817,7 +818,7 @@ def test_delay_buffer(self):
impulse = np.zeros(t.shape)
impulse[0] = 1e-4

cfp = carfac.design_carfac(fs=fs)
cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0)
cfp = carfac.carfac_init(cfp)
# Run the linear case with small impulse.
_, cfp, bm_initial, _, _ = carfac.run_segment(
Expand Down Expand Up @@ -847,7 +848,7 @@ def test_delay_buffer(self):
self.assertLess(max_max_rel_error, 4e-4) # More tolerance than Matlab. Why?

# Run the nonlinear case with a small impulse so not too nonlinear.
cfp = carfac.design_carfac(fs=fs)
cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0)
cfp = carfac.carfac_init(cfp)
cfp.ears[0].car_coeffs.use_delay_buffer = False
_, cfp, bm_initial, _, _ = carfac.run_segment(cfp, impulse)
Expand All @@ -874,7 +875,7 @@ def test_ohc_health(self):
t = np.arange(0, 1, 1 / fs) # A second of noise.
amplitude = 1e-4 # -80 dBFS, around 20 or 30 dB SPL
noise = amplitude * np.random.randn(len(t))
cfp = carfac.design_carfac(fs=fs)
cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0)
cfp = carfac.carfac_init(cfp)
# Run the healthy case with low-level noise.
_, cfp, bm_baseline, _, _ = carfac.run_segment(cfp, noise)
Expand Down Expand Up @@ -918,7 +919,7 @@ def test_multiaural_carfac(self):
two_chan_noise = np.zeros((len(t), 2))
two_chan_noise[:, 0] = noise
two_chan_noise[:, 1] = noise
cfp = carfac.design_carfac(fs=fs, n_ears=2, ihc_style='one_cap')
cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0, n_ears=2, ihc_style='one_cap')
cfp = carfac.carfac_init(cfp)
naps, _, _, _, _ = carfac.run_segment(cfp, two_chan_noise)
max_abs_diff = np.amax(np.abs(naps[:, :, 0] - naps[:, :, 1]))
Expand Down Expand Up @@ -955,9 +956,9 @@ def test_multiaural_carfac_with_silent_channel(self):
two_chan_noise = np.zeros((len(t), 2))
two_chan_noise[:, 0] = c_major_chord
# Leave the audio in channel 1 as silence.
cfp = carfac.design_carfac(fs=fs, n_ears=2, ihc_style='one_cap')
cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0, n_ears=2, ihc_style='one_cap')
cfp = carfac.carfac_init(cfp)
mono_cfp = carfac.design_carfac(fs=fs, n_ears=1, ihc_style='one_cap')
mono_cfp = carfac.design_carfac(fs=fs, input_scale_dbspl=107.0, n_ears=1, ihc_style='one_cap')
mono_cfp = carfac.carfac_init(mono_cfp)

_, _, bm_binaural, _, _ = carfac.run_segment(cfp, two_chan_noise)
Expand Down