diff --git a/README.md b/README.md index 1b4aa48..4ddc320 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,23 @@ needle [command] │ --checkpoint PATH Resume from checkpoint │ │ --checkpoint-dir DIR Checkpoint directory │ │ --seed INT Random seed (default: 42) │ + │ --no-speech Disable speech (text-only training) │ + │ --max-mel-len INT Max mel frames (default: 1024) │ + │ --n-mels INT Mel frequency bins (default: 80) │ + │ --max-speech-samples INT Max voice-tool-call samples │ + │ --audio-aug-mode STR none|white|person|full (default:white)│ + │ --white-noise-p FLOAT White-noise apply prob (default: 0.5) │ + │ --white-noise-min-snr-db FLOAT Min SNR dB (default: 8.0) │ + │ --white-noise-max-snr-db FLOAT Max SNR dB (default: 30.0) │ + │ --person-noise-n INT Bg speaker clips per sample (def: 10) │ + │ --person-noise-r1 FLOAT Min distance for person noise (3.0) │ + │ --person-noise-r2 FLOAT Max distance for person noise (10.0) │ + │ --person-noise-r-ref FL Reference distance for gain (1.0) │ + │ --person-noise-min-snr-db FL Min target SNR dB (default: 15.0) │ + │ --person-noise-max-snr-db FL Max target SNR dB (default: 40.0) │ + │ --curriculum Sort batches easy->hard each epoch │ + │ --contrastive-weight FL Contrastive loss weight (default: 0.1)│ + │ --contrastive-dim INT Contrastive head dim (default: 128) │ │ │ │ tokenize │ │ --max-samples INT Limit samples per split (dev/test) │ @@ -177,11 +194,6 @@ needle [command] │ --throughput-runs INT Throughput runs (default: 10) │ │ --tool-call-samples INT Tool-call eval samples (default: 200) │ │ │ - │ evaluate │ - │ --checkpoint PATH Path to model checkpoint (required) │ - │ --benchmarks [...] 
wikitext2 lambada hellaswag arc_easy │ - │ --max-samples INT Samples per benchmark (default: 500) │ - │ │ │ tpu │ │ create NAME Create TPU (auto-finds zone) │ │ --type STR Accelerator (default: v6e-8) │ diff --git a/src/cli.py b/src/cli.py index 0943516..a43833c 100644 --- a/src/cli.py +++ b/src/cli.py @@ -50,13 +50,39 @@ def main(): help="Matryoshka FFN shrink factors, e.g. 2=half width (default: 2 4)") p.add_argument("--dropout", type=float, default=0.1, help="Dropout rate for residual connections (default: 0.1)") + p.add_argument("--no-speech", action="store_true", help="Disable speech training (text-only)") + p.add_argument("--max-mel-len", type=int, default=1024, + help="Max mel spectrogram frames (default: 1024)") + p.add_argument("--n-mels", type=int, default=80, + help="Number of mel frequency bins (default: 80)") + p.add_argument("--max-speech-samples", type=int, default=None, + help="Max voice-tool-call training samples (default: all)") + p.add_argument("--audio-aug-mode", type=str, default="white", choices=["none", "white", "person", "full"], + help="Speech augmentation mode for precomputed mels: none, white, person, or full (default: white)") + p.add_argument("--white-noise-p", type=float, default=0.5, + help="Probability of applying mel-white-noise per sample (default: 0.5)") + p.add_argument("--white-noise-min-snr-db", type=float, default=8.0, + help="Minimum white-noise SNR in dB (default: 8.0)") + p.add_argument("--white-noise-max-snr-db", type=float, default=30.0, + help="Maximum white-noise SNR in dB (default: 30.0)") + p.add_argument("--person-noise-n", type=int, default=10, + help="Number of background speaker clips to mix per sample (default: 10)") + p.add_argument("--person-noise-r1", type=float, default=3.0, + help="Minimum distance for person noise sampling (default: 3.0)") + p.add_argument("--person-noise-r2", type=float, default=10.0, + help="Maximum distance for person noise sampling (default: 10.0)") + 
p.add_argument("--person-noise-r-ref", type=float, default=1.0, + help="Reference distance used in distance gain computation (default: 1.0)") + p.add_argument("--person-noise-min-snr-db", type=float, default=15.0, + help="Minimum target SNR for person noise mixing (default: 15.0)") + p.add_argument("--person-noise-max-snr-db", type=float, default=40.0, + help="Maximum target SNR for person noise mixing (default: 40.0)") p.add_argument("--curriculum", action="store_true", help="Sort batches easy→hard by tool count each epoch") p.add_argument("--contrastive-weight", type=float, default=0.1, help="Weight for CLIP-style contrastive loss (default: 0.1)") p.add_argument("--contrastive-dim", type=int, default=128, help="Dimension of contrastive projection head (default: 128)") - p = sub.add_parser("tokenize", add_help=False) p.add_argument("--max-samples", type=int, default=None, help="Limit samples per split (for dev/test)") @@ -94,12 +120,6 @@ def main(): help="Samples for tool-call accuracy eval (default: 200)") p.add_argument("--throughput-runs", type=int, default=10) - p = sub.add_parser("evaluate", add_help=False) - p.add_argument("--checkpoint", type=str, required=True) - p.add_argument("--benchmarks", type=str, nargs="*", - choices=["wikitext2", "lambada", "hellaswag", "arc_easy"]) - p.add_argument("--max-samples", type=int, default=500) - p = sub.add_parser("tpu", add_help=False) tpu_sub = p.add_subparsers(dest="tpu_action") @@ -149,9 +169,6 @@ def main(): elif args.command == "eval": from .eval import main as eval_main_fn eval_main_fn(args) - elif args.command == "evaluate": - from .evaluate import main as eval_main - eval_main(args) elif args.command == "tpu": from .tpu import tpu_dispatch tpu_dispatch(args) diff --git a/src/data.py b/src/data.py index 9c40fd2..e69a31e 100644 --- a/src/data.py +++ b/src/data.py @@ -117,7 +117,7 @@ def _mark_json_value(s, char_w, key, value_str, weight): val_end = val_start + len(value_str) char_w[val_start:val_end] = 
def load_prepared_mels(mel_cache_id, mmap=False):
    """Load the precomputed mel spectrogram array for a prepared cache.

    Two on-disk layouts under ``CACHE_DIR`` are supported:

    * sharded: ``<id>_mel_manifest.json`` (read for its ``n_shards`` entry)
      plus ``<id>_mels_00000.npy``, ``<id>_mels_00001.npy``, ...
    * single file: ``<id>_mels.npy``

    Files that are missing locally are first requested through the GCS
    cache helpers before the local path is re-checked.

    Args:
        mel_cache_id: File-name prefix of the mel cache. Callers typically
            pass the ``"mel_cache_id"`` entry of the prepared-data dict,
            which is ``None`` when the metadata has no such key.
        mmap: If true, return a memory-mapped view (``ShardedMmapArray`` for
            sharded caches, ``np.load(..., mmap_mode="r")`` otherwise)
            instead of reading everything into memory.

    Returns:
        An array-like of mel spectrograms: a concatenated ``np.ndarray``, a
        NumPy memmap, or a ``ShardedMmapArray``.

    Raises:
        FileNotFoundError: If ``mel_cache_id`` is empty/``None`` or the cache
            cannot be found locally or downloaded.
    """
    if not mel_cache_id:
        # Without this guard a missing id surfaces as an opaque TypeError from
        # os.path.join; report it the same way as any other missing mel cache.
        raise FileNotFoundError(
            "No mel cache id available for this dataset (mel_cache_id is empty). "
            "Run 'needle tokenize' first."
        )

    cache_path = os.path.join(CACHE_DIR, mel_cache_id)
    manifest_path = cache_path + "_mel_manifest.json"

    # The manifest only exists for sharded caches; try to fetch it if absent.
    if not os.path.exists(manifest_path):
        _gcs_cache_download(mel_cache_id, ["_mel_manifest.json"])

    if os.path.exists(manifest_path):
        with open(manifest_path) as f:
            manifest = _json.load(f)
        n_shards = manifest["n_shards"]

        _gcs_download_shards(mel_cache_id, n_shards, ["_mels"])

        shard_paths = [
            os.path.join(CACHE_DIR, f"{mel_cache_id}_mels_{i:05d}.npy")
            for i in range(n_shards)
        ]

        if mmap:
            return ShardedMmapArray(shard_paths, mmap_mode="r")
        return np.concatenate([np.load(p) for p in shard_paths])

    # Single-file layout.
    mel_file = cache_path + "_mels.npy"
    if not os.path.exists(mel_file):
        if not _gcs_cache_download(mel_cache_id, ["_mels.npy"]):
            raise FileNotFoundError(
                f"Mel cache '{mel_cache_id}' not found. Run 'needle tokenize' first."
            )

    mmap_mode = "r" if mmap else None
    return np.load(mel_file, mmap_mode=mmap_mode)
def _fit_mel_frames_to_length(mel_frames, target_len, rng):
    """Crop or tile mel frames so that exactly ``target_len`` rows remain.

    Args:
        mel_frames: 2-D array-like indexed as (frames, mel bins).
        target_len: Desired number of frames; values <= 0 give an empty result.
        rng: ``numpy.random.Generator`` used to pick the crop offset.

    Returns:
        A float32 array with ``target_len`` rows. Longer inputs are randomly
        cropped, shorter inputs are repeated end-to-end and truncated, and an
        empty input yields zeros.
    """
    mel_frames = np.asarray(mel_frames, dtype=np.float32)
    n_bins = mel_frames.shape[-1]
    if target_len <= 0:
        return np.zeros((0, n_bins), dtype=np.float32)

    n_frames = mel_frames.shape[0]
    if n_frames == 0:
        return np.zeros((target_len, n_bins), dtype=np.float32)
    if n_frames == target_len:
        return mel_frames
    if n_frames > target_len:
        # Random contiguous window; the upper bound of integers() is exclusive.
        start = int(rng.integers(0, n_frames - target_len + 1))
        return mel_frames[start:start + target_len]

    reps = int(np.ceil(target_len / n_frames))
    return np.tile(mel_frames, (reps, 1))[:target_len]


def _split_valid_mel_frames(mel):
    """Separate non-padding frames from a padded mel spectrogram.

    A frame counts as padding when every value in its row is exactly 0.0.

    Returns:
        ``(valid_frames, frame_mask)`` where ``frame_mask`` is a boolean vector
        over rows and ``valid_frames`` is ``mel[frame_mask]`` as float32.
    """
    mel = np.asarray(mel, dtype=np.float32)
    frame_mask = np.any(mel != 0.0, axis=1)
    return mel[frame_mask], frame_mask


def _mix_log_mel_with_noise_power(clean_log_mel, noise_power, target_snr_db):
    """Mix additive noise power into log-mel features at a target SNR.

    Assumes the STFT cross-term is zero in expectation,
    ``|X + N|^2 ~= |X|^2 + |N|^2``, so noise is added in the mel-power domain
    and the sum is mapped back with a natural log.
    NOTE(review): this treats the stored features as natural-log power;
    confirm against the mel extraction code.

    Args:
        clean_log_mel: Log-mel features of the clean signal.
        noise_power: Non-negative noise power, broadcastable to the clean
            features; negative entries are clamped to 0.
        target_snr_db: Desired ratio of mean clean power to mean noise power,
            in dB.

    Returns:
        float32 log-mel features of the mixture. The clean input is returned
        unchanged when either array is empty or either mean power is ~0.
    """
    clean_log_mel = np.asarray(clean_log_mel, dtype=np.float32)
    noise_power = np.asarray(noise_power, dtype=np.float32)
    if clean_log_mel.size == 0 or noise_power.size == 0:
        return clean_log_mel

    # Clip before exp so extreme log values cannot overflow float32.
    clean_power = np.exp(np.clip(clean_log_mel, -30.0, 30.0))
    noise_power = np.maximum(noise_power, 0.0)

    signal_power = float(np.mean(clean_power))
    noise_mean = float(np.mean(noise_power))
    if signal_power <= 1e-12 or noise_mean <= 1e-12:
        return clean_log_mel

    # Rescale the noise so that mean(signal) / mean(noise) hits the target SNR.
    desired_noise_power = signal_power / (10.0 ** (target_snr_db / 10.0))
    noise_scale = desired_noise_power / noise_mean
    mixed_power = clean_power + noise_power * noise_scale
    return np.log(np.maximum(mixed_power, 1e-10)).astype(np.float32)


class MelWhiteNoiseAugmenter:
    """White-noise augmenter operating directly on log-mel batches.

    Uses an independent random-phase approximation (zero expected cross-term)
    and mixes in the mel-power domain at an SNR sampled uniformly from
    ``[min_snr_db, max_snr_db]``.

    Args:
        p: Per-sample probability of applying noise; clipped to [0, 1].
        min_snr_db: Lower SNR bound in dB (bounds are swapped if reversed).
        max_snr_db: Upper SNR bound in dB.
        seed: Optional seed for the internal ``numpy.random.Generator``.
    """

    def __init__(self, p=0.5, min_snr_db=8.0, max_snr_db=30.0, seed=None):
        self.p = float(np.clip(p, 0.0, 1.0))
        self.min_snr_db = float(min(min_snr_db, max_snr_db))
        self.max_snr_db = float(max(min_snr_db, max_snr_db))
        self._rng = np.random.default_rng(seed)
        self.name = "white"
        self.transforms = ("white_noise_mel",)

    def __call__(self, mel_batch):
        """Return a noisy float32 copy of ``mel_batch``; the input is not modified.

        ``mel_batch`` is indexed as (sample, frame, mel bin). All-zero frames
        are treated as padding and left untouched.
        """
        mel_batch = np.array(mel_batch, dtype=np.float32, copy=True)
        for i in range(mel_batch.shape[0]):
            # random() is drawn from [0, 1): ">=" guarantees p == 0 never
            # augments while p == 1 always does.
            if self._rng.random() >= self.p:
                continue
            clean_valid, frame_mask = _split_valid_mel_frames(mel_batch[i])
            if clean_valid.shape[0] == 0:
                continue
            snr_db = float(self._rng.uniform(self.min_snr_db, self.max_snr_db))
            # Complex-white-noise power proxy: N_re^2 + N_im^2 (chi-square with 2 dof).
            n_re = np.asarray(self._rng.standard_normal(clean_valid.shape), dtype=np.float32)
            n_im = np.asarray(self._rng.standard_normal(clean_valid.shape), dtype=np.float32)
            white_power = n_re * n_re + n_im * n_im
            mel_batch[i, frame_mask, :] = _mix_log_mel_with_noise_power(
                clean_valid, white_power, snr_db
            )
        return mel_batch


class MelPersonNoiseAugmenter:
    """Background person-noise augmenter operating on precomputed log-mels.

    For every sample, ``n`` clips are drawn from ``mel_pool``, each is given a
    distance ``r`` whose square is uniform on ``[r1^2, r2^2]`` and a power gain
    ``(r_ref / r)^2``, and the summed noise power is mixed in at an SNR sampled
    uniformly from ``[min_snr_db, max_snr_db]``. Mixing uses the same
    zero-cross-term approximation as the white-noise augmenter.

    Args:
        mel_pool: Indexable collection of log-mel clips supporting ``len()``.
        n: Number of background clips per sample (at least 1).
        r1: Minimum speaker distance; must be > 0.
        r2: Maximum speaker distance; must be > ``r1``.
        r_ref: Reference distance for the gain; must be > 0.
        min_snr_db: Lower SNR bound in dB (bounds are swapped if reversed).
        max_snr_db: Upper SNR bound in dB.
        seed: Optional seed for the internal ``numpy.random.Generator``.

    Raises:
        ValueError: If the distance constraints above are violated.
    """

    def __init__(self, mel_pool, n=10, r1=3.0, r2=10.0, r_ref=1.0,
                 min_snr_db=15.0, max_snr_db=40.0, seed=None):
        if r1 <= 0 or r2 <= r1 or r_ref <= 0:
            raise ValueError("Require r1 > 0, r2 > r1, and r_ref > 0 for person noise augmentation.")
        self.pool = mel_pool
        self.n = max(1, int(n))
        self.r1 = float(r1)
        self.r2 = float(r2)
        self.r_ref = float(r_ref)
        self.min_snr_db = float(min(min_snr_db, max_snr_db))
        self.max_snr_db = float(max(min_snr_db, max_snr_db))
        self._rng = np.random.default_rng(seed)
        self.name = "person"
        self.transforms = ("person_noise_mel",)

    def __call__(self, mel_batch):
        """Return a float32 copy of ``mel_batch`` with background speech mixed in.

        All-zero frames are treated as padding and left untouched; samples
        with no valid frames, or whose drawn noise is silent, are skipped.
        """
        mel_batch = np.array(mel_batch, dtype=np.float32, copy=True)
        pool_len = len(self.pool)
        if pool_len == 0:
            return mel_batch

        # Sample with replacement only when the pool is too small for n clips.
        replace = pool_len < self.n

        for i in range(mel_batch.shape[0]):
            clean_valid, frame_mask = _split_valid_mel_frames(mel_batch[i])
            target_len = clean_valid.shape[0]
            if target_len == 0:
                continue

            clip_ids = self._rng.choice(pool_len, size=self.n, replace=replace)
            z = self._rng.uniform(0.0, 1.0, size=self.n)
            r_z = np.sqrt(self.r1**2 + z * (self.r2**2 - self.r1**2))
            # Inverse-square power falloff relative to the reference distance.
            gain_power = (self.r_ref / np.maximum(r_z, 1e-6)) ** 2

            noise_power_sum = np.zeros_like(clean_valid, dtype=np.float32)
            for clip_id, g_pow in zip(clip_ids, gain_power):
                noise_mel = np.asarray(self.pool[int(clip_id)], dtype=np.float32)
                noise_valid, _ = _split_valid_mel_frames(noise_mel)
                if noise_valid.shape[0] == 0:
                    continue
                noise_seg = _fit_mel_frames_to_length(noise_valid, target_len, self._rng)
                noise_power_sum += float(g_pow) * np.exp(np.clip(noise_seg, -30.0, 30.0))

            if np.mean(noise_power_sum) <= 1e-12:
                continue

            snr_db = float(self._rng.uniform(self.min_snr_db, self.max_snr_db))
            mel_batch[i, frame_mask, :] = _mix_log_mel_with_noise_power(
                clean_valid, noise_power_sum, snr_db
            )

        return mel_batch


def build_audio_augmenter(sr=16000, mode="white", white_noise_p=0.5,
                          white_noise_min_snr_db=8.0, white_noise_max_snr_db=30.0,
                          person_noise_n=10, person_noise_r1=3.0, person_noise_r2=10.0,
                          person_noise_r_ref=1.0, person_noise_min_snr_db=15.0,
                          person_noise_max_snr_db=40.0, person_noise_pool=None,
                          seed=None):
    """Build a mel-domain augmentation pipeline for training.

    Args:
        sr: Accepted for call-site compatibility; not used by the mel-domain
            augmenters.
        mode: ``"none"``, ``"white"``, ``"person"`` or ``"full"``
            (case-insensitive; ``None`` is treated as ``"none"``).
        white_noise_p, white_noise_min_snr_db, white_noise_max_snr_db:
            Forwarded to ``MelWhiteNoiseAugmenter``.
        person_noise_n, person_noise_r1, person_noise_r2, person_noise_r_ref,
        person_noise_min_snr_db, person_noise_max_snr_db, person_noise_pool:
            Forwarded to ``MelPersonNoiseAugmenter``.
        seed: Optional seed forwarded to the augmenter's random generator so
            that augmentation can be made reproducible. ``None`` keeps the
            previous non-deterministic behaviour.

    Returns:
        A callable mapping a mel batch to an augmented mel batch, or ``None``
        when augmentation is disabled (``"none"``, or ``"person"`` without a
        usable pool). ``"full"`` falls back to white mel noise with a warning.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    mode = (mode or "none").lower()
    if mode == "none":
        return None

    if mode == "person":
        if person_noise_pool is None or len(person_noise_pool) == 0:
            print(" WARNING: person noise requested, but no mel pool was provided")
            return None
        return MelPersonNoiseAugmenter(
            mel_pool=person_noise_pool,
            n=person_noise_n,
            r1=person_noise_r1,
            r2=person_noise_r2,
            r_ref=person_noise_r_ref,
            min_snr_db=person_noise_min_snr_db,
            max_snr_db=person_noise_max_snr_db,
            seed=seed,
        )

    if mode not in ("white", "full"):
        raise ValueError(f"Unknown audio augmentation mode: {mode}")

    if mode == "full":
        print(" WARNING: full waveform augmentation is not used on precomputed mels — using white mel noise")

    return MelWhiteNoiseAugmenter(
        p=white_noise_p,
        min_snr_db=white_noise_min_snr_db,
        max_snr_db=white_noise_max_snr_db,
        seed=seed,
    )
def get_speech_batches(mel_data, dec_inputs, dec_targets, batch_size,
                       shuffle=True, loss_mask=None, augmenter=None):
    """Generate fixed-size speech batches from precomputed mel data.

    Only full batches are produced; a trailing remainder smaller than
    ``batch_size`` is dropped. When ``shuffle`` is true the sample order comes
    from NumPy's global random state (``np.random.permutation``).

    Args:
        mel_data: Indexable mel array whose first axis is the sample axis,
            possibly memory-mapped. Its length sets the number of samples.
        dec_inputs: Decoder input array, indexed with the same sample ids.
        dec_targets: Decoder target array, indexed with the same sample ids.
        batch_size: Number of samples per batch.
        shuffle: Visit samples in random order instead of storage order.
        loss_mask: Optional per-sample mask array; when given it is appended
            to every batch as a fourth element.
        augmenter: Optional callable applied to each float32 mel batch
            on the fly.

    Yields:
        ``(mel, dec_in, dec_tgt)`` or ``(mel, dec_in, dec_tgt, loss_mask)``
        tuples of NumPy arrays.
    """
    total = len(mel_data)
    if shuffle:
        order = np.random.permutation(total)
    else:
        order = np.arange(total)

    for start in range(0, total - batch_size + 1, batch_size):
        rows = order[start:start + batch_size]

        mels = np.array(mel_data[rows], dtype=np.float32)
        if augmenter is not None:
            mels = augmenter(mels)

        fields = [mels, np.array(dec_inputs[rows]), np.array(dec_targets[rows])]
        if loss_mask is not None:
            fields.append(np.array(loss_mask[rows]))
        yield tuple(fields)
def _speech_loss_fn(state, params, mel, tgt_in, tgt_out, causal_mask, ffn_mask, rng, loss_mask):
    """Compute the speech training loss: masked cross-entropy plus z-loss.

    ``params`` are passed through ``_quantize_params`` (group size
    ``_GROUP_SIZE``) before the model's ``forward_speech_masked`` method is
    applied in non-deterministic mode. ``rng`` is split into one key for the
    ``"specaugment"`` stream and one for ``"dropout"``.

    Returns:
        Scalar loss: the ``loss_mask``-weighted mean token cross-entropy plus
        ``1e-4 * mean(logsumexp(logits)^2)``. The z-loss term is averaged over
        every position, not only the masked-in ones.
    """
    quantized = _quantize_params(params, group_size=_GROUP_SIZE)
    encoder_mask = make_mel_padding_mask(mel)
    # Token id 0 is used as the padding id for decoder inputs.
    decoder_mask = causal_mask & make_padding_mask(tgt_in, 0)

    spec_rng, drop_rng = jax.random.split(rng)
    logits, _unused = state.apply_fn(
        {"params": quantized},
        mel,
        tgt_in,
        src_mask=encoder_mask,
        tgt_mask=decoder_mask,
        ffn_mask=ffn_mask,
        deterministic=False,
        method="forward_speech_masked",
        rngs={"specaugment": spec_rng, "dropout": drop_rng},
    )

    # Loss is evaluated in float32 regardless of the model dtype.
    logits = logits.astype(jnp.float32)
    token_ce = optax.softmax_cross_entropy_with_integer_labels(logits, tgt_out)
    # Guard the denominator so an all-zero mask yields 0 rather than NaN.
    ce_loss = jnp.sum(token_ce * loss_mask) / jnp.maximum(jnp.sum(loss_mask), 1.0)

    log_z = jax.nn.logsumexp(logits, axis=-1)
    z_loss = 1e-4 * jnp.mean(log_z ** 2)
    return ce_loss + z_loss


def _train_step_speech(state, ema_params, mel, tgt_in, tgt_out, causal_mask, ffn_mask, rng, loss_mask):
    """Run one data-parallel speech optimisation step.

    Intended to run under ``jax.pmap`` with ``axis_name="batch"``: gradients
    and loss are averaged across that axis before the update.

    Returns:
        ``(state, ema_params, loss, grad_norm)`` where ``ema_params`` is the
        exponential moving average (decay 0.999) of the updated parameters and
        ``grad_norm`` is the global norm of the averaged gradients.
    """
    def objective(p):
        return _speech_loss_fn(state, p, mel, tgt_in, tgt_out, causal_mask, ffn_mask, rng, loss_mask)

    loss, grads = jax.value_and_grad(objective)(state.params)
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    grad_norm = optax.global_norm(grads)

    state = state.apply_gradients(grads=grads)

    ema_decay = 0.999
    ema_params = jax.tree.map(
        lambda avg, new: ema_decay * avg + (1 - ema_decay) * new,
        ema_params,
        state.params,
    )
    return state, ema_params, loss, grad_norm


def _train_step_speech_masked(state, ema_params, mel, tgt_in, tgt_out, causal_mask, prune_mask, ffn_mask, rng, loss_mask):
    """Speech optimisation step that re-applies the prune mask after the update.

    Identical to ``_train_step_speech`` except that the updated parameters are
    multiplied element-wise by ``prune_mask`` before being stored in ``state``
    and before they enter the EMA (decay 0.999).

    Returns:
        ``(state, ema_params, loss, grad_norm)``.
    """
    def objective(p):
        return _speech_loss_fn(state, p, mel, tgt_in, tgt_out, causal_mask, ffn_mask, rng, loss_mask)

    loss, grads = jax.value_and_grad(objective)(state.params)
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    grad_norm = optax.global_norm(grads)

    state = state.apply_gradients(grads=grads)

    # Multiply by the mask after the optimiser update so masked weights stay zeroed.
    pruned = jax.tree.map(lambda weight, keep: weight * keep, state.params, prune_mask)
    state = state.replace(params=pruned)

    ema_decay = 0.999
    ema_params = jax.tree.map(
        lambda avg, new: ema_decay * avg + (1 - ema_decay) * new,
        ema_params,
        pruned,
    )
    return state, ema_params, loss, grad_norm
def _make_p_train_step_speech():
    """Return ``_train_step_speech`` wrapped in ``jax.pmap`` over axis ``"batch"``.

    The first two arguments (train state and EMA params) are donated.
    """
    pmapped = jax.pmap(_train_step_speech, axis_name="batch", donate_argnums=(0, 1))
    return pmapped


def _make_p_train_step_speech_masked():
    """Return ``_train_step_speech_masked`` wrapped in ``jax.pmap`` over axis ``"batch"``.

    The first two arguments (train state and EMA params) are donated.
    """
    pmapped = jax.pmap(_train_step_speech_masked, axis_name="batch", donate_argnums=(0, 1))
    return pmapped


def _make_speech_val_loss_fn(apply_fn):
    """Build a jitted validation-loss function for speech batches.

    Args:
        apply_fn: Model apply function; it is invoked with
            ``method="forward_speech_with_aux"`` in deterministic mode and is
            expected to return three values, the first being the logits.

    Returns:
        A function ``(params, mel, tgt_in, tgt_out, causal_mask, loss_mask)``
        returning ``(summed masked cross-entropy, sum of loss_mask)`` so the
        caller can accumulate a token-weighted average across batches.
    """
    def val_loss_batch(params, mel, tgt_in, tgt_out, causal_mask, loss_mask):
        encoder_mask = make_mel_padding_mask(mel)
        # Token id 0 is used as the padding id for decoder inputs.
        decoder_mask = causal_mask & make_padding_mask(tgt_in, 0)
        logits, _, _ = apply_fn(
            {"params": params},
            mel,
            tgt_in,
            src_mask=encoder_mask,
            tgt_mask=decoder_mask,
            deterministic=True,
            method="forward_speech_with_aux",
        )
        token_ce = optax.softmax_cross_entropy_with_integer_labels(
            logits.astype(jnp.float32), tgt_out
        )
        return jnp.sum(token_ce * loss_mask), jnp.sum(loss_mask)

    return jax.jit(val_loss_batch)
@@ -415,6 +494,9 @@ def shard_batch(batch, num_devices): def train(args): num_devices = jax.local_device_count() + no_speech = getattr(args, "no_speech", False) + n_mels = getattr(args, "n_mels", 80) + max_mel_len = getattr(args, "max_mel_len", 1024) use_wandb = getattr(args, "wandb", False) if use_wandb: @@ -422,13 +504,19 @@ def train(args): if wandb.run is None: wandb.init(project="needle-v1", config=vars(args)) - print(f"\n[1/3] Detecting devices...") + total_data_steps = 4 if not no_speech else 3 + step_idx = 0 + + step_idx += 1 + print(f"\n[{step_idx}/{total_data_steps}] Detecting devices...") print(f" {num_devices} device(s) for data-parallel training") - print(f"\n[2/3] Loading tokenizer...") + step_idx += 1 + print(f"\n[{step_idx}/{total_data_steps}] Loading tokenizer...") tokenizer = get_tokenizer(max_samples=args.max_samples) - print(f"\n[3/3] Loading prepared data from disk (mmap)...") + step_idx += 1 + print(f"\n[{step_idx}/{total_data_steps}] Loading prepared data from disk (mmap)...") train_data = load_prepared_data("train", mmap=True) val_data = load_prepared_data("val", mmap=True) enc_inputs = train_data["enc_inputs"] @@ -442,6 +530,47 @@ def train(args): val_loss_mask = val_data["loss_mask"] print(f" {len(enc_inputs):,} train / {len(val_enc):,} val tool-call pairs (memory-mapped)") + train_mels = None + val_mels = None + speech_augmenter = None + if not no_speech: + step_idx += 1 + print(f"\n[{step_idx}/{total_data_steps}] Loading precomputed mel spectrograms (mmap)...") + train_mels = load_prepared_mels(train_data["mel_cache_id"], mmap=True) + val_mels = load_prepared_mels(val_data["mel_cache_id"], mmap=True) + print(f" {len(train_mels):,} train / {len(val_mels):,} val mel spectrograms (memory-mapped)") + + speech_augmenter = build_audio_augmenter( + sr=16000, + mode=getattr(args, "audio_aug_mode", "white"), + white_noise_p=getattr(args, "white_noise_p", 0.5), + white_noise_min_snr_db=getattr(args, "white_noise_min_snr_db", 8.0), + 
white_noise_max_snr_db=getattr(args, "white_noise_max_snr_db", 30.0), + person_noise_n=getattr(args, "person_noise_n", 10), + person_noise_r1=getattr(args, "person_noise_r1", 3.0), + person_noise_r2=getattr(args, "person_noise_r2", 10.0), + person_noise_r_ref=getattr(args, "person_noise_r_ref", 1.0), + person_noise_min_snr_db=getattr(args, "person_noise_min_snr_db", 15.0), + person_noise_max_snr_db=getattr(args, "person_noise_max_snr_db", 40.0), + person_noise_pool=train_mels, + ) + if speech_augmenter is not None: + aug_name = getattr(speech_augmenter, "name", getattr(args, "audio_aug_mode", "white")) + num_transforms = len(getattr(speech_augmenter, "transforms", ())) + if num_transforms > 0: + print(f" Speech augmentation: {aug_name} ({num_transforms} transforms)") + else: + print(f" Speech augmentation: {aug_name}") + + # Contrastive data (optional — graceful if missing) + cl_query_tokens = train_data.get("query_only") + cl_tool_tokens = train_data.get("tool_individual") + cl_tool_ex_idx = train_data.get("tool_ex_idx") + cl_tool_is_pos = train_data.get("tool_is_pos") + has_contrastive = all(x is not None for x in [cl_query_tokens, cl_tool_tokens, cl_tool_ex_idx, cl_tool_is_pos]) + if has_contrastive: + print(f" Contrastive: {len(cl_query_tokens):,} queries, {len(cl_tool_tokens):,} tools") + # Contrastive data (optional — graceful if missing) cl_query_tokens = train_data.get("query_only") cl_tool_tokens = train_data.get("tool_individual") @@ -473,6 +602,8 @@ def train(args): dtype=args.dtype, activation=getattr(args, "activation", "drelu"), num_memory_slots=getattr(args, "num_memory_slots", 64), + n_mels=n_mels, + dropout_rate=getattr(args, "dropout", 0.1), contrastive_dim=getattr(args, "contrastive_dim", 128), ) @@ -490,6 +621,8 @@ def train(args): n_widths = 1 + len(_MAT_FF_WIDTHS) if _MAT_FF_WIDTHS else 1 p_train_step = _make_p_train_step() p_train_step_masked = _make_p_train_step_masked() + p_train_step_speech = _make_p_train_step_speech() + 
p_train_step_speech_masked = _make_p_train_step_speech_masked() p_train_step_contrastive = _make_p_train_step_contrastive() np.random.seed(args.seed) @@ -498,7 +631,11 @@ def train(args): unique_batch_size = (effective_batch_size // num_devices) * num_devices text_batches_per_epoch = count_batches(len(enc_inputs), unique_batch_size) - num_batches = text_batches_per_epoch + if not no_speech and train_mels is not None: + speech_batches_per_epoch = text_batches_per_epoch + else: + speech_batches_per_epoch = 0 + num_batches = text_batches_per_epoch + speech_batches_per_epoch total_steps = num_batches * args.epochs warmup_steps = max(1, int(total_steps * args.warmup_ratio)) @@ -507,6 +644,7 @@ def train(args): decay_ratio = getattr(args, "decay_ratio", 0.40) state = create_train_state(init_rng, config, scaled_lr, muon_lr, total_steps, warmup_steps, decay_ratio) val_loss_fn = _make_val_loss_fn(state.apply_fn) + speech_vl_fn = _make_speech_val_loss_fn(state.apply_fn) if speech_batches_per_epoch > 0 else None if resume_checkpoint: state = state.replace(params=ckpt_params) @@ -529,6 +667,13 @@ def train(args): print(f" Layers {config.num_encoder_layers:>7} enc / {config.num_decoder_layers} dec") print(f" Activation {config.activation:>12}") print(f" Dtype {config.dtype:>12}") + print(f" Dropout {config.dropout_rate:>12}") + if speech_batches_per_epoch > 0: + print(f" Speech {speech_batches_per_epoch} batches/epoch") + print(f" n_mels {n_mels:>12}") + print(f" max_mel_len {max_mel_len:>12}") + else: + print(f" Speech disabled") print(f" ─────────────────────────────────────") print(f" Devices {num_devices:>12}") print(f" Batch {args.batch_size:>7} x {num_devices} = {effective_batch_size}") @@ -555,8 +700,8 @@ def train(args): print(f" Mat factors {n_widths} (full + {', '.join(str(f)+'x' for f in _MAT_FACTORS)})") print(f" Mat mode unique input ({args.batch_size}/dev, split by width)") - adam_schedule = _wsd_schedule(scaled_lr, total_steps, warmup_steps) - muon_schedule = 
_wsd_schedule(muon_lr, total_steps, warmup_steps) + adam_schedule = _wsd_schedule(scaled_lr, total_steps, warmup_steps, decay_ratio) + muon_schedule = _wsd_schedule(muon_lr, total_steps, warmup_steps, decay_ratio) tokens_per_batch = effective_batch_size * (args.max_enc_len + args.max_dec_len) eval_model = EncoderDecoderTransformer(config) @@ -581,6 +726,7 @@ def train(args): epoch_step = 0 text_losses = [] + speech_losses = [] _curriculum_tc = tool_counts if getattr(args, "curriculum", False) else None text_batch_iter = PrefetchIterator( lambda: get_batches(enc_inputs, dec_inputs, dec_targets, unique_batch_size, @@ -589,7 +735,14 @@ def train(args): prefetch=4, ) - # Contrastive batch iterator (cycles through contrastive data alongside text) + speech_batch_iter = None + if speech_batches_per_epoch > 0: + speech_batch_iter = PrefetchIterator( + lambda: get_speech_batches(train_mels, dec_inputs, dec_targets, unique_batch_size, + loss_mask=train_loss_mask, augmenter=speech_augmenter), + prefetch=4, + ) + cl_batch_iter = None if has_contrastive and _CONTRASTIVE_WEIGHT > 0: cl_batch_iter = PrefetchIterator( @@ -599,53 +752,95 @@ def train(args): prefetch=4, ) - pbar = tqdm(range(text_batches_per_epoch), desc=f"Epoch {epoch + 1}/{args.epochs}") + steps_this_epoch = text_batches_per_epoch + speech_batches_per_epoch + text_idx = 0 + speech_idx = 0 + speech_loss_val = None + pbar = tqdm(range(steps_this_epoch), desc=f"Epoch {epoch + 1}/{args.epochs}") for step_i in pbar: t0 = time.perf_counter() - src, tgt_in, tgt_out, lm = next(text_batch_iter) - - # Get contrastive batch (may be exhausted before text batches) - cl_q_b = None - cl_t_b = None - if cl_batch_iter is not None: - try: - cl_q, cl_t = next(cl_batch_iter) - cl_q_b = shard_batch(cl_q, num_devices) - cl_t_b = shard_batch(cl_t, num_devices) - except StopIteration: - cl_batch_iter = None - - src_b = shard_batch(src, num_devices) - tgt_in_b = shard_batch(tgt_in, num_devices) - tgt_out_b = shard_batch(tgt_out, 
num_devices) - lm_b = shard_batch(lm, num_devices) - - rng, text_rng = jax.random.split(rng) - text_rngs = jax.random.split(text_rng, num_devices) - - if prune_mask is not None: - state, ema_params, loss, grad_norm = p_train_step_masked( - state, ema_params, src_b, tgt_in_b, tgt_out_b, causal_mask, prune_mask, text_ffn_mask, text_rngs, lm_b, - ) + do_speech = (step_i % 2 == 1) and speech_idx < speech_batches_per_epoch + do_text = not do_speech and text_idx < text_batches_per_epoch + if not do_speech and not do_text: + if text_idx < text_batches_per_epoch: + do_text = True + elif speech_idx < speech_batches_per_epoch: + do_speech = True + else: + break + + step_grad_norm = None + + if do_text: + src, tgt_in, tgt_out, lm = next(text_batch_iter) + text_idx += 1 + + cl_q_b = None + cl_t_b = None + if cl_batch_iter is not None: + try: + cl_q, cl_t = next(cl_batch_iter) + cl_q_b = shard_batch(cl_q, num_devices) + cl_t_b = shard_batch(cl_t, num_devices) + except StopIteration: + cl_batch_iter = None + + src_b = shard_batch(src, num_devices) + tgt_in_b = shard_batch(tgt_in, num_devices) + tgt_out_b = shard_batch(tgt_out, num_devices) + lm_b = shard_batch(lm, num_devices) + + rng, text_rng = jax.random.split(rng) + text_rngs = jax.random.split(text_rng, num_devices) + + if prune_mask is not None: + state, ema_params, loss, grad_norm = p_train_step_masked( + state, ema_params, src_b, tgt_in_b, tgt_out_b, causal_mask, prune_mask, text_ffn_mask, text_rngs, lm_b, + ) + else: + state, ema_params, loss, grad_norm = p_train_step( + state, ema_params, src_b, tgt_in_b, tgt_out_b, causal_mask, text_ffn_mask, text_rngs, lm_b, + ) + + if cl_q_b is not None and cl_t_b is not None: + rng, cl_rng = jax.random.split(rng) + cl_rngs = jax.random.split(cl_rng, num_devices) + state, ema_params, cl_loss = p_train_step_contrastive( + state, ema_params, cl_q_b, cl_t_b, cl_rngs, + ) + + text_loss_val = float(loss[0]) + text_losses.append(text_loss_val) + step_grad_norm = float(grad_norm[0]) + 
global_step += 1 else: - state, ema_params, loss, grad_norm = p_train_step( - state, ema_params, src_b, tgt_in_b, tgt_out_b, causal_mask, text_ffn_mask, text_rngs, lm_b, - ) - - # Separate contrastive step (if data available) - if cl_q_b is not None and cl_t_b is not None: - rng, cl_rng = jax.random.split(rng) - cl_rngs = jax.random.split(cl_rng, num_devices) - state, ema_params, cl_loss = p_train_step_contrastive( - state, ema_params, cl_q_b, cl_t_b, cl_rngs, - ) - - text_loss_val = float(loss[0]) - text_losses.append(text_loss_val) - step_grad_norm = float(grad_norm[0]) - global_step += 1 + mel_batch, sp_tgt_in, sp_tgt_out, sp_lm = next(speech_batch_iter) + speech_idx += 1 + + mel_b = shard_batch(mel_batch, num_devices) + sp_tgt_in_b = shard_batch(sp_tgt_in, num_devices) + sp_tgt_out_b = shard_batch(sp_tgt_out, num_devices) + sp_lm_b = shard_batch(sp_lm, num_devices) + + rng, spec_rng = jax.random.split(rng) + spec_rngs = jax.random.split(spec_rng, num_devices) + + if prune_mask is not None: + state, ema_params, sp_loss, sp_grad_norm = p_train_step_speech_masked( + state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, prune_mask, text_ffn_mask, spec_rngs, sp_lm_b, + ) + else: + state, ema_params, sp_loss, sp_grad_norm = p_train_step_speech( + state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, text_ffn_mask, spec_rngs, sp_lm_b, + ) + + speech_loss_val = float(sp_loss[0]) + speech_losses.append(speech_loss_val) + step_grad_norm = float(sp_grad_norm[0]) + text_loss_val = text_losses[-1] if text_losses else float("nan") + global_step += 1 if epoch == weight_prune_epoch and not gradual_sparsify_done: epoch_step += 1 @@ -672,6 +867,7 @@ def train(args): del _eval_params postfix = { + "speech_loss": f"{speech_loss_val:.4f}" if speech_loss_val is not None else "-", "text_loss": f"{text_loss_val:.4f}", "text_ppl": f"{last_val_ppl:.2f}" if last_val_ppl is not None else "?", } @@ -691,6 +887,8 @@ def train(args): "train/tokens_per_sec": 
tokens_per_batch / dt, "train/step": global_step, } + if speech_loss_val is not None: + log_dict["train/speech_loss"] = speech_loss_val if epoch == weight_prune_epoch and not gradual_sparsify_done: log_dict["train/scheduled_sparsity"] = current_sparsity if global_step % eval_every == 0 or global_step == total_steps: @@ -698,6 +896,8 @@ def train(args): wandb.log(log_dict) text_batch_iter.close() + if speech_batch_iter is not None: + speech_batch_iter.close() if cl_batch_iter is not None: cl_batch_iter.close() @@ -760,6 +960,17 @@ def train(args): mat_results[f] = (float(math.exp(min(avg, 20))), _estimate_mat_params(config, f), config.d_ff // f) + speech_val_ppl = None + if speech_vl_fn is not None and val_mels is not None: + sp_total_loss, sp_total_toks = 0.0, 0.0 + for sp_batch in get_speech_batches(val_mels, val_dec_in, val_dec_tgt, args.batch_size, + shuffle=False, loss_mask=val_loss_mask): + vl, vt = speech_vl_fn(eval_params, sp_batch[0], sp_batch[1], sp_batch[2], val_causal, sp_batch[3]) + sp_total_loss += float(vl) + sp_total_toks += float(vt) + speech_val_loss = sp_total_loss / max(sp_total_toks, 1) + speech_val_ppl = float(math.exp(min(speech_val_loss, 20))) + params_np = jax.tree.map(np.array, eval_params) total_params = sum(x.size for x in jax.tree.leaves(params_np)) near_zero = sum(int(np.sum(np.abs(x) < 1e-6)) for x in jax.tree.leaves(params_np)) @@ -951,12 +1162,17 @@ def _call_key(c): ) del eval_params + final_speech_loss = speech_losses[-1] if speech_losses else None print(f"\n ─────────────────────────────────────") print(f" Epoch {epoch + 1}/{args.epochs}") print(f" ─────────────────────────────────────") print(f" Text loss {final_loss:>12.4f}") print(f" Text val ppl {last_val_ppl:>12.2f}") + if final_speech_loss is not None: + print(f" Speech loss {final_speech_loss:>12.4f}") + if speech_val_ppl is not None: + print(f" Speech val ppl {speech_val_ppl:>12.2f}") print(f" Quant val ppl {quant_val_ppl:>12.2f} (INT4 g{_GROUP_SIZE})") print(f" Sparsity 
{sparsity:>11.2f}% ({near_zero:,}/{total_params:,})") if mat_results: @@ -1010,6 +1226,10 @@ def _call_key(c): "epoch/weight_sparsity": sparsity, "epoch": epoch + 1, } + if final_speech_loss is not None: + log_dict["epoch/speech_loss"] = final_speech_loss + if speech_val_ppl is not None: + log_dict["epoch/speech_val_ppl"] = speech_val_ppl for factor, (mat_ppl, mat_params, _) in mat_results.items(): log_dict[f"epoch/mat_ppl_{factor}x"] = mat_ppl log_dict[f"epoch/mat_params_{factor}x"] = mat_params @@ -1033,5 +1253,3 @@ def _call_key(c): if best_ckpt_path: print(f"\nBest checkpoint (call_f1={best_call_f1:.1%}): {best_ckpt_path}") print("\nTraining complete.") - -