From 838393c35869a4045233a950996168f8e2686a37 Mon Sep 17 00:00:00 2001 From: ammesatyajit Date: Mon, 9 Mar 2026 02:38:08 -0700 Subject: [PATCH 1/4] white noise --- README.md | 4 +++ src/cli.py | 8 ++++++ src/data.py | 78 +++++++++++++++++++++++++++++++++++++++++++++------- src/train.py | 2 -- 4 files changed, 80 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 875a9c3..0f6869f 100644 --- a/README.md +++ b/README.md @@ -172,6 +172,10 @@ needle [command] │ --max-mel-len INT Max mel frames (default: 1024) │ │ --n-mels INT Mel frequency bins (default: 80) │ │ --max-speech-samples INT Max voice-tool-call samples │ + │ --audio-aug-mode STR none|white|full (default: white) │ + │ --white-noise-p FLOAT White-noise apply prob (default: 0.5) │ + │ --white-noise-min-snr-db FLOAT Min SNR dB (default: 8.0) │ + │ --white-noise-max-snr-db FLOAT Max SNR dB (default: 30.0) │ │ │ │ tokenize │ │ --max-samples INT Limit samples per split (dev/test) │ diff --git a/src/cli.py b/src/cli.py index 15a909e..e73659f 100644 --- a/src/cli.py +++ b/src/cli.py @@ -56,6 +56,14 @@ def main(): help="Number of mel frequency bins (default: 80)") p.add_argument("--max-speech-samples", type=int, default=None, help="Max voice-tool-call training samples (default: all)") + p.add_argument("--audio-aug-mode", type=str, default="white", choices=["none", "white", "full"], + help="Waveform augmentation mode: none, white, or full (default: white)") + p.add_argument("--white-noise-p", type=float, default=0.5, + help="Probability of applying white noise per sample (default: 0.5)") + p.add_argument("--white-noise-min-snr-db", type=float, default=8.0, + help="Minimum white-noise SNR in dB (default: 8.0)") + p.add_argument("--white-noise-max-snr-db", type=float, default=30.0, + help="Maximum white-noise SNR in dB (default: 30.0)") p = sub.add_parser("tokenize", add_help=False) p.add_argument("--max-samples", type=int, default=None, diff --git a/src/data.py b/src/data.py index 5d40484..680d4ed 100644 --- a/src/data.py +++ b/src/data.py @@ -988,20 +988,76 @@ def load_prepared_mels(mel_cache_id, mmap=False): return np.load(mel_file, mmap_mode=mmap_mode) -def build_audio_augmenter(sr=16000): - """Build an audiomentations augmentation pipeline for training. - - Returns an augmenter callable or None if audiomentations is unavailable. +class WhiteNoiseAugmenter: + """Simple white Gaussian noise augmenter with random SNR.""" + + def __init__(self, p=0.5, min_snr_db=8.0, max_snr_db=30.0, seed=None): + self.p = float(np.clip(p, 0.0, 1.0)) + self.min_snr_db = float(min(min_snr_db, max_snr_db)) + self.max_snr_db = float(max(min_snr_db, max_snr_db)) + self._rng = np.random.default_rng(seed) + self.name = "white" + self.transforms = ("white_noise",) + + def __call__(self, samples, sample_rate): + del sample_rate + audio = np.asarray(samples, dtype=np.float32) + if audio.size == 0 or self.p <= 0.0: + return audio + if self._rng.random() > self.p: + return audio + + signal_power = float(np.mean(audio * audio)) + if signal_power <= 1e-12: + return audio + + snr_db = float(self._rng.uniform(self.min_snr_db, self.max_snr_db)) + noise = np.asarray(self._rng.standard_normal(audio.shape), dtype=np.float32) + noise_power = float(np.mean(noise * noise)) + if noise_power <= 1e-12: + return audio + + target_noise_power = signal_power / (10.0 ** (snr_db / 10.0)) + scaled_noise = noise * np.sqrt(target_noise_power / noise_power) + mixed = audio + scaled_noise.astype(np.float32) + return np.clip(mixed, -1.0, 1.0).astype(np.float32) + + +def build_audio_augmenter(sr=16000, mode="white", white_noise_p=0.5, + white_noise_min_snr_db=8.0, white_noise_max_snr_db=30.0): + """Build a waveform augmentation pipeline for training. + + Args: + mode: "none", "white", or "full". """ + del sr + mode = (mode or "none").lower() + if mode == "none": + return None + + white_noise = WhiteNoiseAugmenter( + p=white_noise_p, + min_snr_db=white_noise_min_snr_db, + max_snr_db=white_noise_max_snr_db, + ) + if mode == "white": + return white_noise + + if mode != "full": + raise ValueError(f"Unknown audio augmentation mode: {mode}") + try: import audiomentations as A except ImportError: - print(" WARNING: audiomentations not installed — no waveform augmentation") - return None - - return A.Compose([ - A.AddGaussianSNR(min_snr_db=10.0, max_snr_db=35.0, p=0.5), - A.AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.3), + print(" WARNING: audiomentations not installed — falling back to white-noise-only augmentation") + return white_noise + + augmenter = A.Compose([ + A.AddGaussianSNR( + min_snr_db=white_noise.min_snr_db, + max_snr_db=white_noise.max_snr_db, + p=white_noise.p, + ), A.TimeStretch(min_rate=0.9, max_rate=1.1, p=0.3), A.PitchShift(min_semitones=-2, max_semitones=2, p=0.3), A.Gain(min_gain_db=-6, max_gain_db=6, p=0.4), @@ -1009,6 +1065,8 @@ def build_audio_augmenter(sr=16000): A.HighPassFilter(min_cutoff_freq=50, max_cutoff_freq=400, p=0.2), A.ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=5, p=0.1), ]) + augmenter.name = "full" + return augmenter def _load_mel_batch(audio_arrays, n_mels, max_mel_len, augmenter=None, sr=16000): diff --git a/src/train.py b/src/train.py index 0b26e4a..f2f9ec8 100644 --- a/src/train.py +++ b/src/train.py @@ -978,5 +978,3 @@ def _tile_sp(arr): if use_wandb: wandb.finish() print("\nTraining complete.") - - From a8602e97e833049e7e05ced31a8a04dc63498b44 Mon Sep 17 00:00:00 2001 From: ammesatyajit Date: Mon, 9 Mar 2026 04:34:45 -0700 Subject: [PATCH 2/4] add person noise augmentation caveat: it currently does not explicitly exclude the current utterance from being picked as background. --- README.md | 8 ++++- src/cli.py | 16 ++++++++-- src/data.py | 91 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 110 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 0f6869f..38ddc45 100644 --- a/README.md +++ b/README.md @@ -172,10 +172,16 @@ needle [command] │ --max-mel-len INT Max mel frames (default: 1024) │ │ --n-mels INT Mel frequency bins (default: 80) │ │ --max-speech-samples INT Max voice-tool-call samples │ - │ --audio-aug-mode STR none|white|full (default: white) │ + │ --audio-aug-mode STR none|white|person|full (default:white)│ │ --white-noise-p FLOAT White-noise apply prob (default: 0.5) │ │ --white-noise-min-snr-db FLOAT Min SNR dB (default: 8.0) │ │ --white-noise-max-snr-db FLOAT Max SNR dB (default: 30.0) │ + │ --person-noise-n INT Bg speaker clips per sample (def: 10) │ + │ --person-noise-r1 FLOAT Min distance for person noise (3.0) │ + │ --person-noise-r2 FLOAT Max distance for person noise (10.0) │ + │ --person-noise-r-ref FL Reference distance for gain (1.0) │ + │ --person-noise-min-snr-db FL Min target SNR dB (default: 15.0) │ + │ --person-noise-max-snr-db FL Max target SNR dB (default: 40.0) │ │ │ │ tokenize │ │ --max-samples INT Limit samples per split (dev/test) │ diff --git a/src/cli.py b/src/cli.py index e73659f..0bcf392 100644 --- a/src/cli.py +++ b/src/cli.py @@ -56,14 +56,26 @@ def main(): help="Number of mel frequency bins (default: 80)") p.add_argument("--max-speech-samples", type=int, default=None, help="Max voice-tool-call training samples (default: all)") - p.add_argument("--audio-aug-mode", type=str, default="white", choices=["none", "white", "full"], - help="Waveform augmentation mode: none, white, or full (default: white)") + p.add_argument("--audio-aug-mode", type=str, default="white", choices=["none", "white", "person", "full"], + help="Waveform augmentation mode: none, white, person, or full (default: white)") p.add_argument("--white-noise-p", type=float, default=0.5, help="Probability of applying white noise per sample (default: 0.5)") p.add_argument("--white-noise-min-snr-db", type=float, default=8.0, help="Minimum white-noise SNR in dB (default: 8.0)") p.add_argument("--white-noise-max-snr-db", type=float, default=30.0, help="Maximum white-noise SNR in dB (default: 30.0)") + p.add_argument("--person-noise-n", type=int, default=10, + help="Number of background speaker clips to mix per sample (default: 10)") + p.add_argument("--person-noise-r1", type=float, default=3.0, + help="Minimum distance for person noise sampling (default: 3.0)") + p.add_argument("--person-noise-r2", type=float, default=10.0, + help="Maximum distance for person noise sampling (default: 10.0)") + p.add_argument("--person-noise-r-ref", type=float, default=1.0, + help="Reference distance used in distance gain computation (default: 1.0)") + p.add_argument("--person-noise-min-snr-db", type=float, default=15.0, + help="Minimum target SNR for person noise mixing (default: 15.0)") + p.add_argument("--person-noise-max-snr-db", type=float, default=40.0, + help="Maximum target SNR for person noise mixing (default: 40.0)") p = sub.add_parser("tokenize", add_help=False) p.add_argument("--max-samples", type=int, default=None, diff --git a/src/data.py b/src/data.py index 680d4ed..b35e94f 100644 --- a/src/data.py +++ b/src/data.py @@ -988,6 +988,22 @@ def load_prepared_mels(mel_cache_id, mmap=False): return np.load(mel_file, mmap_mode=mmap_mode) +def _fit_audio_to_length(wave, target_len, rng): + """Crop or tile a waveform to target length.""" + wave = np.asarray(wave, dtype=np.float32).flatten() + if target_len <= 0: + return np.zeros(0, dtype=np.float32) + if wave.size == 0: + return np.zeros(target_len, dtype=np.float32) + if wave.size == target_len: + return wave + if wave.size > target_len: + start = int(rng.integers(0, wave.size - target_len + 1)) + return wave[start:start + target_len] + reps = int(np.ceil(target_len / wave.size)) + return np.tile(wave, reps)[:target_len].astype(np.float32) + + class WhiteNoiseAugmenter: """Simple white Gaussian noise augmenter with random SNR.""" @@ -1023,12 +1039,69 @@ def __call__(self, samples, sample_rate): return np.clip(mixed, -1.0, 1.0).astype(np.float32) +class PersonNoiseAugmenter: + """Background speech augmentation with distance sampling and SNR control.""" + + def __init__(self, background_pool, n=10, r1=3.0, r2=10.0, r_ref=1.0, + min_snr_db=15.0, max_snr_db=40.0, seed=None): + if r1 <= 0 or r2 <= r1 or r_ref <= 0: + raise ValueError("Require r1 > 0, r2 > r1, and r_ref > 0 for person noise augmentation.") + pool = [] + for wave in background_pool: + arr = np.asarray(wave, dtype=np.float32).flatten() + if arr.size > 0: + pool.append(arr) + self.pool = pool + self.n = max(1, int(n)) + self.r1 = float(r1) + self.r2 = float(r2) + self.r_ref = float(r_ref) + self.min_snr_db = float(min(min_snr_db, max_snr_db)) + self.max_snr_db = float(max(min_snr_db, max_snr_db)) + self._rng = np.random.default_rng(seed) + self.name = "person" + self.transforms = ("person_noise",) + + def __call__(self, samples, sample_rate): + del sample_rate + audio = np.asarray(samples, dtype=np.float32).flatten() + if audio.size == 0 or len(self.pool) == 0: + return audio + + replace = len(self.pool) < self.n + clip_ids = self._rng.choice(len(self.pool), size=self.n, replace=replace) + + z = self._rng.uniform(0.0, 1.0, size=self.n) + r_z = np.sqrt(self.r1**2 + z * (self.r2**2 - self.r1**2)) + gains = self.r_ref / np.maximum(r_z, 1e-6) + + noise = np.zeros_like(audio, dtype=np.float32) + for clip_id, gain in zip(clip_ids, gains): + clip = _fit_audio_to_length(self.pool[int(clip_id)], len(audio), self._rng) + noise += float(gain) * clip + + clean_rms = float(np.sqrt(np.mean(audio * audio) + 1e-12)) + noise_rms = float(np.sqrt(np.mean(noise * noise) + 1e-12)) + if clean_rms <= 1e-8 or noise_rms <= 1e-8: + return audio + + target_snr_db = float(self._rng.uniform(self.min_snr_db, self.max_snr_db)) + desired_noise_rms = clean_rms / (10.0 ** (target_snr_db / 20.0)) + noise_scaled = noise * (desired_noise_rms / noise_rms) + + mixed = audio + noise_scaled.astype(np.float32) + return np.clip(mixed, -1.0, 1.0).astype(np.float32) + + def build_audio_augmenter(sr=16000, mode="white", white_noise_p=0.5, - white_noise_min_snr_db=8.0, white_noise_max_snr_db=30.0): + white_noise_min_snr_db=8.0, white_noise_max_snr_db=30.0, + person_noise_n=10, person_noise_r1=3.0, person_noise_r2=10.0, + person_noise_r_ref=1.0, person_noise_min_snr_db=15.0, + person_noise_max_snr_db=40.0, person_noise_pool=None): """Build a waveform augmentation pipeline for training. Args: - mode: "none", "white", or "full". + mode: "none", "white", "person", or "full". """ del sr mode = (mode or "none").lower() @@ -1043,6 +1116,20 @@ def build_audio_augmenter(sr=16000, mode="white", white_noise_p=0.5, if mode == "white": return white_noise + if mode == "person": + if person_noise_pool is None or len(person_noise_pool) == 0: + print(" WARNING: person noise requested, but no background speech pool was provided") + return None + return PersonNoiseAugmenter( + background_pool=person_noise_pool, + n=person_noise_n, + r1=person_noise_r1, + r2=person_noise_r2, + r_ref=person_noise_r_ref, + min_snr_db=person_noise_min_snr_db, + max_snr_db=person_noise_max_snr_db, + ) + if mode != "full": raise ValueError(f"Unknown audio augmentation mode: {mode}") From 2ca7d35203abd5121125024ebac90d5740c5cadc Mon Sep 17 00:00:00 2001 From: ammesatyajit Date: Mon, 9 Mar 2026 04:34:45 -0700 Subject: [PATCH 3/4] add person noise augmentation caveat: it currently does not explicitly exclude the current utterance from being picked as background. Signed-off-by: ammesatyajit --- src/cli.py | 1 - src/data.py | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/cli.py b/src/cli.py index 0bcf392..8e535f6 100644 --- a/src/cli.py +++ b/src/cli.py @@ -76,7 +76,6 @@ def main(): help="Minimum target SNR for person noise mixing (default: 15.0)") p.add_argument("--person-noise-max-snr-db", type=float, default=40.0, help="Maximum target SNR for person noise mixing (default: 40.0)") - p = sub.add_parser("tokenize", add_help=False) p.add_argument("--max-samples", type=int, default=None, help="Limit samples per split (for dev/test)") diff --git a/src/data.py b/src/data.py index b35e94f..61678ea 100644 --- a/src/data.py +++ b/src/data.py @@ -1004,6 +1004,22 @@ def _fit_audio_to_length(wave, target_len, rng): return np.tile(wave, reps)[:target_len].astype(np.float32) +def _fit_audio_to_length(wave, target_len, rng): + """Crop or tile a waveform to target length.""" + wave = np.asarray(wave, dtype=np.float32).flatten() + if target_len <= 0: + return np.zeros(0, dtype=np.float32) + if wave.size == 0: + return np.zeros(target_len, dtype=np.float32) + if wave.size == target_len: + return wave + if wave.size > target_len: + start = int(rng.integers(0, wave.size - target_len + 1)) + return wave[start:start + target_len] + reps = int(np.ceil(target_len / wave.size)) + return np.tile(wave, reps)[:target_len].astype(np.float32) + + class WhiteNoiseAugmenter: """Simple white Gaussian noise augmenter with random SNR.""" From 0510ec99bdca9e294a1a385146cd6f4eb363fcea Mon Sep 17 00:00:00 2001 From: ammesatyajit Date: Tue, 10 Mar 2026 00:38:33 -0700 Subject: [PATCH 4/4] Wire mel augmentation into active train path Signed-off-by: ammesatyajit --- src/cli.py | 4 +- src/data.py | 297 +++++++++++++++++++++------------------------------ src/train.py | 26 ++++- 3 files changed, 149 insertions(+), 178 deletions(-) diff --git a/src/cli.py b/src/cli.py index 8e535f6..3030aa2 100644 --- a/src/cli.py +++ b/src/cli.py @@ -57,9 +57,9 @@ def main(): p.add_argument("--max-speech-samples", type=int, default=None, help="Max voice-tool-call training samples (default: all)") p.add_argument("--audio-aug-mode", type=str, default="white", choices=["none", "white", "person", "full"], - help="Waveform augmentation mode: none, white, person, or full (default: white)") + help="Speech augmentation mode for precomputed mels: none, white, person, or full (default: white)") p.add_argument("--white-noise-p", type=float, default=0.5, - help="Probability of applying white noise per sample (default: 0.5)") + help="Probability of applying mel-white-noise per sample (default: 0.5)") p.add_argument("--white-noise-min-snr-db", type=float, default=8.0, help="Minimum white-noise SNR in dB (default: 8.0)") p.add_argument("--white-noise-max-snr-db", type=float, default=30.0, diff --git a/src/data.py b/src/data.py index 61678ea..d685c20 100644 --- a/src/data.py +++ b/src/data.py @@ -988,40 +988,62 @@ def load_prepared_mels(mel_cache_id, mmap=False): return np.load(mel_file, mmap_mode=mmap_mode) -def _fit_audio_to_length(wave, target_len, rng): - """Crop or tile a waveform to target length.""" - wave = np.asarray(wave, dtype=np.float32).flatten() +def _fit_mel_frames_to_length(mel_frames, target_len, rng): + """Crop or tile valid mel frames to target length.""" + mel_frames = np.asarray(mel_frames, dtype=np.float32) if target_len <= 0: - return np.zeros(0, dtype=np.float32) - if wave.size == 0: - return np.zeros(target_len, dtype=np.float32) - if wave.size == target_len: - return wave - if wave.size > target_len: - start = int(rng.integers(0, wave.size - target_len + 1)) - return wave[start:start + target_len] - reps = int(np.ceil(target_len / wave.size)) - return np.tile(wave, reps)[:target_len].astype(np.float32) - - -def _fit_audio_to_length(wave, target_len, rng): - """Crop or tile a waveform to target length.""" - wave = np.asarray(wave, dtype=np.float32).flatten() - if target_len <= 0: - return np.zeros(0, dtype=np.float32) - if wave.size == 0: - return np.zeros(target_len, dtype=np.float32) - if wave.size == target_len: - return wave - if wave.size > target_len: - start = int(rng.integers(0, wave.size - target_len + 1)) - return wave[start:start + target_len] - reps = int(np.ceil(target_len / wave.size)) - return np.tile(wave, reps)[:target_len].astype(np.float32) + return np.zeros((0, mel_frames.shape[-1]), dtype=np.float32) + if mel_frames.shape[0] == 0: + return np.zeros((target_len, mel_frames.shape[-1]), dtype=np.float32) + if mel_frames.shape[0] == target_len: + return mel_frames + if mel_frames.shape[0] > target_len: + start = int(rng.integers(0, mel_frames.shape[0] - target_len + 1)) + return mel_frames[start:start + target_len] + reps = int(np.ceil(target_len / mel_frames.shape[0])) + tiled = np.tile(mel_frames, (reps, 1)) + return tiled[:target_len].astype(np.float32) + + +def _split_valid_mel_frames(mel): + """Return (valid_frames, frame_mask) where mask marks non-padding rows.""" + mel = np.asarray(mel, dtype=np.float32) + frame_mask = np.any(mel != 0.0, axis=1) + return mel[frame_mask], frame_mask + + +def _mix_log_mel_with_noise_power(clean_log_mel, noise_power, target_snr_db): + """Mix additive noise power into log-mel features at a target SNR. + + Assumes the STFT cross-term is zero in expectation: + |X + N|^2 ~= |X|^2 + |N|^2 + so we add noise in mel-power domain and map back to log-mel. + """ + clean_log_mel = np.asarray(clean_log_mel, dtype=np.float32) + noise_power = np.asarray(noise_power, dtype=np.float32) + if clean_log_mel.size == 0 or noise_power.size == 0: + return clean_log_mel + + clean_power = np.exp(np.clip(clean_log_mel, -30.0, 30.0)) + noise_power = np.maximum(noise_power, 0.0) + + signal_power = float(np.mean(clean_power)) + noise_mean = float(np.mean(noise_power)) + if signal_power <= 1e-12 or noise_mean <= 1e-12: + return clean_log_mel + + desired_noise_power = signal_power / (10.0 ** (target_snr_db / 10.0)) + noise_scale = desired_noise_power / noise_mean + mixed_power = clean_power + noise_power * noise_scale + return np.log(np.maximum(mixed_power, 1e-10)).astype(np.float32) -class WhiteNoiseAugmenter: - """Simple white Gaussian noise augmenter with random SNR.""" +class MelWhiteNoiseAugmenter: + """White-noise augmenter operating directly on log-mel batches. + + Uses an independent random-phase approximation (zero expected cross-term) + and mixes in mel-power domain at a sampled SNR. + """ def __init__(self, p=0.5, min_snr_db=8.0, max_snr_db=30.0, seed=None): self.p = float(np.clip(p, 0.0, 1.0)) @@ -1029,45 +1051,37 @@ def __init__(self, p=0.5, min_snr_db=8.0, max_snr_db=30.0, seed=None): self.max_snr_db = float(max(min_snr_db, max_snr_db)) self._rng = np.random.default_rng(seed) self.name = "white" - self.transforms = ("white_noise",) - - def __call__(self, samples, sample_rate): - del sample_rate - audio = np.asarray(samples, dtype=np.float32) - if audio.size == 0 or self.p <= 0.0: - return audio - if self._rng.random() > self.p: - return audio - - signal_power = float(np.mean(audio * audio)) - if signal_power <= 1e-12: - return audio + self.transforms = ("white_noise_mel",) - snr_db = float(self._rng.uniform(self.min_snr_db, self.max_snr_db)) - noise = np.asarray(self._rng.standard_normal(audio.shape), dtype=np.float32) - noise_power = float(np.mean(noise * noise)) - if noise_power <= 1e-12: - return audio + def __call__(self, mel_batch): + mel_batch = np.array(mel_batch, dtype=np.float32, copy=True) + for i in range(mel_batch.shape[0]): + if self._rng.random() > self.p: + continue + clean_valid, frame_mask = _split_valid_mel_frames(mel_batch[i]) + if clean_valid.shape[0] == 0: + continue + snr_db = float(self._rng.uniform(self.min_snr_db, self.max_snr_db)) + # Complex-white-noise power proxy: N_re^2 + N_im^2 (chi-square with 2 dof). + n_re = np.asarray(self._rng.standard_normal(clean_valid.shape), dtype=np.float32) + n_im = np.asarray(self._rng.standard_normal(clean_valid.shape), dtype=np.float32) + white_power = n_re * n_re + n_im * n_im + mel_batch[i, frame_mask, :] = _mix_log_mel_with_noise_power(clean_valid, white_power, snr_db) + return mel_batch - target_noise_power = signal_power / (10.0 ** (snr_db / 10.0)) - scaled_noise = noise * np.sqrt(target_noise_power / noise_power) - mixed = audio + scaled_noise.astype(np.float32) - return np.clip(mixed, -1.0, 1.0).astype(np.float32) +class MelPersonNoiseAugmenter: + """Background person-noise augmenter operating on precomputed log-mels. -class PersonNoiseAugmenter: - """Background speech augmentation with distance sampling and SNR control.""" + Mixes in mel-power domain with distance-weighted gains under the same + zero-cross-term approximation used for white noise. + """ - def __init__(self, background_pool, n=10, r1=3.0, r2=10.0, r_ref=1.0, + def __init__(self, mel_pool, n=10, r1=3.0, r2=10.0, r_ref=1.0, min_snr_db=15.0, max_snr_db=40.0, seed=None): if r1 <= 0 or r2 <= r1 or r_ref <= 0: raise ValueError("Require r1 > 0, r2 > r1, and r_ref > 0 for person noise augmentation.") - pool = [] - for wave in background_pool: - arr = np.asarray(wave, dtype=np.float32).flatten() - if arr.size > 0: - pool.append(arr) - self.pool = pool + self.pool = mel_pool self.n = max(1, int(n)) self.r1 = float(r1) self.r2 = float(r2) @@ -1076,37 +1090,42 @@ def __init__(self, background_pool, n=10, r1=3.0, r2=10.0, r_ref=1.0, self.max_snr_db = float(max(min_snr_db, max_snr_db)) self._rng = np.random.default_rng(seed) self.name = "person" - self.transforms = ("person_noise",) - - def __call__(self, samples, sample_rate): - del sample_rate - audio = np.asarray(samples, dtype=np.float32).flatten() - if audio.size == 0 or len(self.pool) == 0: - return audio - - replace = len(self.pool) < self.n - clip_ids = self._rng.choice(len(self.pool), size=self.n, replace=replace) - - z = self._rng.uniform(0.0, 1.0, size=self.n) - r_z = np.sqrt(self.r1**2 + z * (self.r2**2 - self.r1**2)) - gains = self.r_ref / np.maximum(r_z, 1e-6) + self.transforms = ("person_noise_mel",) + + def __call__(self, mel_batch): + mel_batch = np.array(mel_batch, dtype=np.float32, copy=True) + pool_len = len(self.pool) + if pool_len == 0: + return mel_batch + + for i in range(mel_batch.shape[0]): + clean_valid, frame_mask = _split_valid_mel_frames(mel_batch[i]) + target_len = clean_valid.shape[0] + if target_len == 0: + continue - noise = np.zeros_like(audio, dtype=np.float32) - for clip_id, gain in zip(clip_ids, gains): - clip = _fit_audio_to_length(self.pool[int(clip_id)], len(audio), self._rng) - noise += float(gain) * clip + replace = pool_len < self.n + clip_ids = self._rng.choice(pool_len, size=self.n, replace=replace) + z = self._rng.uniform(0.0, 1.0, size=self.n) + r_z = np.sqrt(self.r1**2 + z * (self.r2**2 - self.r1**2)) + gain_power = (self.r_ref / np.maximum(r_z, 1e-6)) ** 2 + + noise_power_sum = np.zeros_like(clean_valid, dtype=np.float32) + for clip_id, g_pow in zip(clip_ids, gain_power): + noise_mel = np.asarray(self.pool[int(clip_id)], dtype=np.float32) + noise_valid, _ = _split_valid_mel_frames(noise_mel) + if noise_valid.shape[0] == 0: + continue + noise_seg = _fit_mel_frames_to_length(noise_valid, target_len, self._rng) + noise_power_sum += float(g_pow) * np.exp(np.clip(noise_seg, -30.0, 30.0)) - clean_rms = float(np.sqrt(np.mean(audio * audio) + 1e-12)) - noise_rms = float(np.sqrt(np.mean(noise * noise) + 1e-12)) - if clean_rms <= 1e-8 or noise_rms <= 1e-8: - return audio + if np.mean(noise_power_sum) <= 1e-12: + continue - target_snr_db = float(self._rng.uniform(self.min_snr_db, self.max_snr_db)) - desired_noise_rms = clean_rms / (10.0 ** (target_snr_db / 20.0)) - noise_scaled = noise * (desired_noise_rms / noise_rms) + snr_db = float(self._rng.uniform(self.min_snr_db, self.max_snr_db)) + mel_batch[i, frame_mask, :] = _mix_log_mel_with_noise_power(clean_valid, noise_power_sum, snr_db) - mixed = audio + noise_scaled.astype(np.float32) - return np.clip(mixed, -1.0, 1.0).astype(np.float32) + return mel_batch def build_audio_augmenter(sr=16000, mode="white", white_noise_p=0.5, @@ -1114,17 +1133,13 @@ def build_audio_augmenter(sr=16000, mode="white", white_noise_p=0.5, person_noise_n=10, person_noise_r1=3.0, person_noise_r2=10.0, person_noise_r_ref=1.0, person_noise_min_snr_db=15.0, person_noise_max_snr_db=40.0, person_noise_pool=None): - """Build a waveform augmentation pipeline for training. - - Args: - mode: "none", "white", "person", or "full". - """ + """Build a mel-domain augmentation pipeline for training.""" del sr mode = (mode or "none").lower() if mode == "none": return None - white_noise = WhiteNoiseAugmenter( + white_noise = MelWhiteNoiseAugmenter( p=white_noise_p, min_snr_db=white_noise_min_snr_db, max_snr_db=white_noise_max_snr_db, @@ -1134,10 +1149,10 @@ def build_audio_augmenter(sr=16000, mode="white", white_noise_p=0.5, if mode == "person": if person_noise_pool is None or len(person_noise_pool) == 0: - print(" WARNING: person noise requested, but no background speech pool was provided") + print(" WARNING: person noise requested, but no mel pool was provided") return None - return PersonNoiseAugmenter( - background_pool=person_noise_pool, + return MelPersonNoiseAugmenter( + mel_pool=person_noise_pool, n=person_noise_n, r1=person_noise_r1, r2=person_noise_r2, @@ -1146,97 +1161,29 @@ def build_audio_augmenter(sr=16000, mode="white", white_noise_p=0.5, max_snr_db=person_noise_max_snr_db, ) - if mode != "full": - raise ValueError(f"Unknown audio augmentation mode: {mode}") - - try: - import audiomentations as A - except ImportError: - print(" WARNING: audiomentations not installed — falling back to white-noise-only augmentation") + if mode == "full": + print(" WARNING: full waveform augmentation is not used on precomputed mels — using white mel noise") return white_noise - augmenter = A.Compose([ - A.AddGaussianSNR( - min_snr_db=white_noise.min_snr_db, - max_snr_db=white_noise.max_snr_db, - p=white_noise.p, - ), - A.TimeStretch(min_rate=0.9, max_rate=1.1, p=0.3), - A.PitchShift(min_semitones=-2, max_semitones=2, p=0.3), - A.Gain(min_gain_db=-6, max_gain_db=6, p=0.4), - A.LowPassFilter(min_cutoff_freq=3000, max_cutoff_freq=7500, p=0.2), - A.HighPassFilter(min_cutoff_freq=50, max_cutoff_freq=400, p=0.2), - A.ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=5, p=0.1), - ]) - augmenter.name = "full" - return augmenter - - -def _load_mel_batch(audio_arrays, n_mels, max_mel_len, augmenter=None, sr=16000): - """Compute mel spectrograms for a batch of audio arrays. - - If augmenter is provided, applies waveform augmentation before mel computation. - """ - mels = [] - for audio in audio_arrays: - audio = np.array(audio, dtype=np.float32) - if audio.ndim > 1: - audio = audio.mean(axis=1) - - if augmenter is not None: - audio = augmenter(samples=audio, sample_rate=sr) - - mel = compute_mel_spectrogram(audio, sr=sr, n_mels=n_mels) - - if mel.shape[0] > max_mel_len: - mel = mel[:max_mel_len] - elif mel.shape[0] < max_mel_len: - pad_len = max_mel_len - mel.shape[0] - mel = np.pad(mel, ((0, pad_len), (0, 0))) - - mels.append(mel) - - return np.stack(mels).astype(np.float32) - - -def _load_audio_batch(ds_indices): - """Load and decode audio for a batch of dataset indices.""" - import io - import soundfile as sf - - ds = _load_unified_dataset() - arrays = [] - for idx in ds_indices: - ex = ds[int(idx)] - audio_val = ex.get("audio") - raw_bytes = None - if isinstance(audio_val, dict): - raw_bytes = audio_val.get("bytes") - elif isinstance(audio_val, bytes): - raw_bytes = audio_val - if raw_bytes is None: - arrays.append(np.zeros(16000, dtype=np.float32)) - continue - audio_array, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32") - if audio_array.ndim > 1: - audio_array = audio_array.mean(axis=1) - arrays.append(audio_array.astype(np.float32)) - return arrays + raise ValueError(f"Unknown audio augmentation mode: {mode}") def get_speech_batches(mel_data, dec_inputs, dec_targets, batch_size, - shuffle=True, loss_mask=None): + shuffle=True, loss_mask=None, augmenter=None): """Yield speech batches from precomputed mel data. mel_data: array of shape (N, max_mel_len, n_mels), possibly memory-mapped. - Uses per-batch fancy indexing to avoid copying full arrays. + augmenter: callable applied to each mel batch on-the-fly. """ n = len(mel_data) indices = np.random.permutation(n) if shuffle else np.arange(n) for i in range(0, n - batch_size + 1, batch_size): idx = indices[i : i + batch_size] - batch = (np.array(mel_data[idx]), np.array(dec_inputs[idx]), np.array(dec_targets[idx])) + batch_mel = np.array(mel_data[idx], dtype=np.float32) + if augmenter is not None: + batch_mel = augmenter(batch_mel) + batch = (batch_mel, np.array(dec_inputs[idx]), np.array(dec_targets[idx])) if loss_mask is not None: batch = batch + (np.array(loss_mask[idx]),) yield batch diff --git a/src/train.py b/src/train.py index f2f9ec8..195a022 100644 --- a/src/train.py +++ b/src/train.py @@ -15,6 +15,7 @@ from .data import ( get_batches, get_tokenizer, get_speech_batches, + build_audio_augmenter, load_prepared_data, load_prepared_mels, load_example_with_audio, PrefetchIterator, count_batches, @@ -487,6 +488,7 @@ def train(args): train_mels = None val_mels = None + speech_augmenter = None if not no_speech: step_idx += 1 print(f"\n[{step_idx}/{total_data_steps}] Loading precomputed mel spectrograms (mmap)...") @@ -494,6 +496,28 @@ def train(args): val_mels = load_prepared_mels(val_data["mel_cache_id"], mmap=True) print(f" {len(train_mels):,} train / {len(val_mels):,} val mel spectrograms (memory-mapped)") + speech_augmenter = build_audio_augmenter( + sr=16000, + mode=getattr(args, "audio_aug_mode", "white"), + white_noise_p=getattr(args, "white_noise_p", 0.5), + white_noise_min_snr_db=getattr(args, "white_noise_min_snr_db", 8.0), + white_noise_max_snr_db=getattr(args, "white_noise_max_snr_db", 30.0), + person_noise_n=getattr(args, "person_noise_n", 10), + person_noise_r1=getattr(args, "person_noise_r1", 3.0), + person_noise_r2=getattr(args, "person_noise_r2", 10.0), + person_noise_r_ref=getattr(args, "person_noise_r_ref", 1.0), + person_noise_min_snr_db=getattr(args, "person_noise_min_snr_db", 15.0), + person_noise_max_snr_db=getattr(args, "person_noise_max_snr_db", 40.0), + person_noise_pool=train_mels, + ) + if speech_augmenter is not None: + aug_name = getattr(speech_augmenter, "name", getattr(args, "audio_aug_mode", "white")) + num_transforms = len(getattr(speech_augmenter, "transforms", ())) + if num_transforms > 0: + print(f" Speech augmentation: {aug_name} ({num_transforms} transforms)") + else: + print(f" Speech augmentation: {aug_name}") + effective_batch_size = args.batch_size * num_devices resume_checkpoint = getattr(args, "checkpoint", None) @@ -650,7 +674,7 @@ def train(args): if not no_speech and train_mels is not None: speech_batch_iter = PrefetchIterator( lambda: get_speech_batches(train_mels, dec_inputs, dec_targets, unique_batch_size, - loss_mask=train_loss_mask), + loss_mask=train_loss_mask, augmenter=speech_augmenter), prefetch=4, )