Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,23 @@ needle [command]
│ --checkpoint PATH Resume from checkpoint │
│ --checkpoint-dir DIR Checkpoint directory │
│ --seed INT Random seed (default: 42) │
│ --no-speech Disable speech (text-only training) │
│ --max-mel-len INT Max mel frames (default: 1024) │
│ --n-mels INT Mel frequency bins (default: 80) │
│ --max-speech-samples INT Max voice-tool-call samples │
│ --audio-aug-mode STR none|white|person|full (default:white)│
│ --white-noise-p FLOAT White-noise apply prob (default: 0.5) │
│ --white-noise-min-snr-db FLOAT Min SNR dB (default: 8.0) │
│ --white-noise-max-snr-db FLOAT Max SNR dB (default: 30.0) │
│ --person-noise-n INT Bg speaker clips per sample (def: 10) │
│ --person-noise-r1 FLOAT Min distance for person noise (3.0) │
│ --person-noise-r2 FLOAT Max distance for person noise (10.0) │
│ --person-noise-r-ref FL Reference distance for gain (1.0) │
│ --person-noise-min-snr-db FL Min target SNR dB (default: 15.0) │
│ --person-noise-max-snr-db FL Max target SNR dB (default: 40.0) │
│ --curriculum Sort batches easy->hard each epoch │
│ --contrastive-weight FL Contrastive loss weight (default: 0.1)│
│ --contrastive-dim INT Contrastive head dim (default: 128) │
│ │
│ tokenize │
│ --max-samples INT Limit samples per split (dev/test) │
Expand All @@ -177,11 +194,6 @@ needle [command]
│ --throughput-runs INT Throughput runs (default: 10) │
│ --tool-call-samples INT Tool-call eval samples (default: 200) │
│ │
│ evaluate │
│ --checkpoint PATH Path to model checkpoint (required) │
│ --benchmarks [...] wikitext2 lambada hellaswag arc_easy │
│ --max-samples INT Samples per benchmark (default: 500) │
│ │
│ tpu │
│ create NAME Create TPU (auto-finds zone) │
│ --type STR Accelerator (default: v6e-8) │
Expand Down
37 changes: 27 additions & 10 deletions src/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,39 @@ def main():
help="Matryoshka FFN shrink factors, e.g. 2=half width (default: 2 4)")
p.add_argument("--dropout", type=float, default=0.1,
help="Dropout rate for residual connections (default: 0.1)")
p.add_argument("--no-speech", action="store_true", help="Disable speech training (text-only)")
p.add_argument("--max-mel-len", type=int, default=1024,
help="Max mel spectrogram frames (default: 1024)")
p.add_argument("--n-mels", type=int, default=80,
help="Number of mel frequency bins (default: 80)")
p.add_argument("--max-speech-samples", type=int, default=None,
help="Max voice-tool-call training samples (default: all)")
p.add_argument("--audio-aug-mode", type=str, default="white", choices=["none", "white", "person", "full"],
help="Speech augmentation mode for precomputed mels: none, white, person, or full (default: white)")
p.add_argument("--white-noise-p", type=float, default=0.5,
help="Probability of applying mel-white-noise per sample (default: 0.5)")
p.add_argument("--white-noise-min-snr-db", type=float, default=8.0,
help="Minimum white-noise SNR in dB (default: 8.0)")
p.add_argument("--white-noise-max-snr-db", type=float, default=30.0,
help="Maximum white-noise SNR in dB (default: 30.0)")
p.add_argument("--person-noise-n", type=int, default=10,
help="Number of background speaker clips to mix per sample (default: 10)")
p.add_argument("--person-noise-r1", type=float, default=3.0,
help="Minimum distance for person noise sampling (default: 3.0)")
p.add_argument("--person-noise-r2", type=float, default=10.0,
help="Maximum distance for person noise sampling (default: 10.0)")
p.add_argument("--person-noise-r-ref", type=float, default=1.0,
help="Reference distance used in distance gain computation (default: 1.0)")
p.add_argument("--person-noise-min-snr-db", type=float, default=15.0,
help="Minimum target SNR for person noise mixing (default: 15.0)")
p.add_argument("--person-noise-max-snr-db", type=float, default=40.0,
help="Maximum target SNR for person noise mixing (default: 40.0)")
p.add_argument("--curriculum", action="store_true",
help="Sort batches easy→hard by tool count each epoch")
p.add_argument("--contrastive-weight", type=float, default=0.1,
help="Weight for CLIP-style contrastive loss (default: 0.1)")
p.add_argument("--contrastive-dim", type=int, default=128,
help="Dimension of contrastive projection head (default: 128)")

p = sub.add_parser("tokenize", add_help=False)
p.add_argument("--max-samples", type=int, default=None,
help="Limit samples per split (for dev/test)")
Expand Down Expand Up @@ -94,12 +120,6 @@ def main():
help="Samples for tool-call accuracy eval (default: 200)")
p.add_argument("--throughput-runs", type=int, default=10)

p = sub.add_parser("evaluate", add_help=False)
p.add_argument("--checkpoint", type=str, required=True)
p.add_argument("--benchmarks", type=str, nargs="*",
choices=["wikitext2", "lambada", "hellaswag", "arc_easy"])
p.add_argument("--max-samples", type=int, default=500)

p = sub.add_parser("tpu", add_help=False)
tpu_sub = p.add_subparsers(dest="tpu_action")

Expand Down Expand Up @@ -149,9 +169,6 @@ def main():
elif args.command == "eval":
from .eval import main as eval_main_fn
eval_main_fn(args)
elif args.command == "evaluate":
from .evaluate import main as eval_main
eval_main(args)
elif args.command == "tpu":
from .tpu import tpu_dispatch
tpu_dispatch(args)
239 changes: 238 additions & 1 deletion src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _mark_json_value(s, char_w, key, value_str, weight):
val_end = val_start + len(value_str)
char_w[val_start:val_end] = np.maximum(char_w[val_start:val_end], weight)
return

pattern_ns = f'"{_re.escape(key)}"\\s*:\\s*{_re.escape(value_str)}'
for m in _re.finditer(pattern_ns, s):
colon_offset = s[m.start():m.end()].index(':')
Expand Down Expand Up @@ -976,6 +976,7 @@ def _load_optional(suffix):
"dec_targets": np.load(cache_path + "_dec_tgt.npy", mmap_mode=mmap_mode),
"loss_mask": np.load(cache_path + "_loss_mask.npy", mmap_mode=mmap_mode),
"kept_indices": np.load(cache_path + "_kept_idx.npy", mmap_mode=mmap_mode),
"mel_cache_id": meta.get("mel_cache_id"),
"tool_counts": tc,
"query_only": _load_optional("_query_only.npy"),
"tool_individual": _load_optional("_tool_individual.npy"),
Expand All @@ -984,6 +985,242 @@ def _load_optional(suffix):
}


def load_prepared_mels(mel_cache_id, mmap=False):
    """Return the precomputed mel array stored under ``mel_cache_id``.

    Two on-disk layouts are handled:

    * sharded -- a ``<id>_mel_manifest.json`` file whose ``"n_shards"`` entry
      gives the number of ``<id>_mels_00000.npy``-style shard files;
    * single file -- one ``<id>_mels.npy`` file.

    With ``mmap=True`` the data is opened read-only and memory-mapped (through
    ``ShardedMmapArray`` for the sharded layout); otherwise it is loaded into
    memory, concatenating shards along the first axis.

    Raises FileNotFoundError when no manifest exists and the single-file cache
    is neither on disk nor obtainable via ``_gcs_cache_download``.
    """
    base = os.path.join(CACHE_DIR, mel_cache_id)
    manifest_file = base + "_mel_manifest.json"

    # Best-effort fetch of the manifest; the return value is deliberately
    # ignored because the existence check right below decides which layout
    # applies. (Presumably this pulls from a remote cache -- confirm against
    # the helper's definition.)
    if not os.path.exists(manifest_file):
        _gcs_cache_download(mel_cache_id, ["_mel_manifest.json"])

    if not os.path.exists(manifest_file):
        # Single-file layout.
        single_file = base + "_mels.npy"
        if not os.path.exists(single_file) and not _gcs_cache_download(mel_cache_id, ["_mels.npy"]):
            raise FileNotFoundError(
                f"Mel cache '{mel_cache_id}' not found. Run 'needle tokenize' first."
            )
        return np.load(single_file, mmap_mode="r" if mmap else None)

    # Sharded layout: the manifest tells us how many shard files to expect.
    with open(manifest_file) as fh:
        shard_count = _json.load(fh)["n_shards"]

    _gcs_download_shards(mel_cache_id, shard_count, ["_mels"])

    paths = []
    for shard_idx in range(shard_count):
        paths.append(os.path.join(CACHE_DIR, f"{mel_cache_id}_mels_{shard_idx:05d}.npy"))

    if mmap:
        return ShardedMmapArray(paths, mmap_mode="r")
    return np.concatenate([np.load(path) for path in paths])


def _fit_mel_frames_to_length(mel_frames, target_len, rng):
"""Crop or tile valid mel frames to target length."""
mel_frames = np.asarray(mel_frames, dtype=np.float32)
if target_len <= 0:
return np.zeros((0, mel_frames.shape[-1]), dtype=np.float32)
if mel_frames.shape[0] == 0:
return np.zeros((target_len, mel_frames.shape[-1]), dtype=np.float32)
if mel_frames.shape[0] == target_len:
return mel_frames
if mel_frames.shape[0] > target_len:
start = int(rng.integers(0, mel_frames.shape[0] - target_len + 1))
return mel_frames[start:start + target_len]
reps = int(np.ceil(target_len / mel_frames.shape[0]))
tiled = np.tile(mel_frames, (reps, 1))
return tiled[:target_len].astype(np.float32)


def _split_valid_mel_frames(mel):
"""Return (valid_frames, frame_mask) where mask marks non-padding rows."""
mel = np.asarray(mel, dtype=np.float32)
frame_mask = np.any(mel != 0.0, axis=1)
return mel[frame_mask], frame_mask


def _mix_log_mel_with_noise_power(clean_log_mel, noise_power, target_snr_db):
"""Mix additive noise power into log-mel features at a target SNR.

Assumes the STFT cross-term is zero in expectation:
|X + N|^2 ~= |X|^2 + |N|^2
so we add noise in mel-power domain and map back to log-mel.
"""
clean_log_mel = np.asarray(clean_log_mel, dtype=np.float32)
noise_power = np.asarray(noise_power, dtype=np.float32)
if clean_log_mel.size == 0 or noise_power.size == 0:
return clean_log_mel

clean_power = np.exp(np.clip(clean_log_mel, -30.0, 30.0))
noise_power = np.maximum(noise_power, 0.0)

signal_power = float(np.mean(clean_power))
noise_mean = float(np.mean(noise_power))
if signal_power <= 1e-12 or noise_mean <= 1e-12:
return clean_log_mel

desired_noise_power = signal_power / (10.0 ** (target_snr_db / 10.0))
noise_scale = desired_noise_power / noise_mean
mixed_power = clean_power + noise_power * noise_scale
return np.log(np.maximum(mixed_power, 1e-10)).astype(np.float32)


class MelWhiteNoiseAugmenter:
    """Adds synthetic white noise to batches of log-mel features.

    Each sample is augmented with probability ``p``. The noise is generated
    directly as a power value per mel bin and mixed in the mel-power domain
    at an SNR drawn uniformly from ``[min_snr_db, max_snr_db]``, relying on
    the zero-expected-cross-term approximation.
    """

    def __init__(self, p=0.5, min_snr_db=8.0, max_snr_db=30.0, seed=None):
        # Tolerate sloppy arguments: clamp the probability to [0, 1] and
        # accept the SNR bounds in either order.
        snr_lo = min(min_snr_db, max_snr_db)
        snr_hi = max(min_snr_db, max_snr_db)
        self.p = float(np.clip(p, 0.0, 1.0))
        self.min_snr_db = float(snr_lo)
        self.max_snr_db = float(snr_hi)
        self._rng = np.random.default_rng(seed)
        self.name = "white"
        self.transforms = ("white_noise_mel",)

    def __call__(self, mel_batch):
        """Return an augmented float32 copy of ``mel_batch``; the input is not modified."""
        out = np.array(mel_batch, dtype=np.float32, copy=True)
        rng = self._rng
        # Each ``sample`` is a view into ``out``, so writes land in the copy.
        for sample in out:
            if rng.random() > self.p:
                continue
            valid, mask = _split_valid_mel_frames(sample)
            if valid.shape[0] == 0:
                # Nothing but padding in this sample.
                continue
            snr = float(rng.uniform(self.min_snr_db, self.max_snr_db))
            # Complex-white-noise power proxy: N_re^2 + N_im^2 (chi-square with 2 dof).
            real = np.asarray(rng.standard_normal(valid.shape), dtype=np.float32)
            imag = np.asarray(rng.standard_normal(valid.shape), dtype=np.float32)
            power = np.square(real) + np.square(imag)
            # Only non-padding frames are overwritten; padding rows stay zero.
            sample[mask, :] = _mix_log_mel_with_noise_power(valid, power, snr)
        return out


class MelPersonNoiseAugmenter:
    """Mixes background-speaker clips into batches of precomputed log-mels.

    For every sample, ``n`` clips are drawn from ``mel_pool``, each weighted
    by a distance-dependent power gain, summed in the mel-power domain and
    mixed into the sample at an SNR drawn uniformly from
    ``[min_snr_db, max_snr_db]``. The same zero-cross-term approximation as
    for white noise is assumed.
    """

    def __init__(self, mel_pool, n=10, r1=3.0, r2=10.0, r_ref=1.0,
                 min_snr_db=15.0, max_snr_db=40.0, seed=None):
        if r1 <= 0 or r2 <= r1 or r_ref <= 0:
            raise ValueError("Require r1 > 0, r2 > r1, and r_ref > 0 for person noise augmentation.")
        # SNR bounds may be supplied in either order.
        snr_lo = min(min_snr_db, max_snr_db)
        snr_hi = max(min_snr_db, max_snr_db)
        self.pool = mel_pool
        self.n = max(1, int(n))  # always mix at least one clip
        self.r1 = float(r1)
        self.r2 = float(r2)
        self.r_ref = float(r_ref)
        self.min_snr_db = float(snr_lo)
        self.max_snr_db = float(snr_hi)
        self._rng = np.random.default_rng(seed)
        self.name = "person"
        self.transforms = ("person_noise_mel",)

    def __call__(self, mel_batch):
        """Return an augmented float32 copy of ``mel_batch``; the input is not modified."""
        out = np.array(mel_batch, dtype=np.float32, copy=True)
        n_clips_available = len(self.pool)
        if n_clips_available == 0:
            return out

        rng = self._rng
        r1_sq = self.r1**2
        r2_sq = self.r2**2

        # Each ``sample`` is a view into ``out``, so writes land in the copy.
        for sample in out:
            speech, mask = _split_valid_mel_frames(sample)
            n_frames = speech.shape[0]
            if n_frames == 0:
                continue

            # Sample with replacement only when the pool is too small to
            # supply ``n`` distinct clips.
            picks = rng.choice(n_clips_available, size=self.n,
                               replace=n_clips_available < self.n)
            # Distance = sqrt of a uniform draw between r1^2 and r2^2.
            u = rng.uniform(0.0, 1.0, size=self.n)
            distance = np.sqrt(r1_sq + u * (r2_sq - r1_sq))
            # Power gain falls off with the square of distance relative to r_ref.
            power_gain = (self.r_ref / np.maximum(distance, 1e-6)) ** 2

            babble_power = np.zeros_like(speech, dtype=np.float32)
            for pick, gain in zip(picks, power_gain):
                clip = np.asarray(self.pool[int(pick)], dtype=np.float32)
                clip_valid, _ = _split_valid_mel_frames(clip)
                if clip_valid.shape[0] == 0:
                    continue
                segment = _fit_mel_frames_to_length(clip_valid, n_frames, rng)
                babble_power += float(gain) * np.exp(np.clip(segment, -30.0, 30.0))

            # Every drawn clip was empty/silent: leave the sample untouched.
            if np.mean(babble_power) <= 1e-12:
                continue

            snr = float(rng.uniform(self.min_snr_db, self.max_snr_db))
            sample[mask, :] = _mix_log_mel_with_noise_power(speech, babble_power, snr)

        return out


def build_audio_augmenter(sr=16000, mode="white", white_noise_p=0.5,
                          white_noise_min_snr_db=8.0, white_noise_max_snr_db=30.0,
                          person_noise_n=10, person_noise_r1=3.0, person_noise_r2=10.0,
                          person_noise_r_ref=1.0, person_noise_min_snr_db=15.0,
                          person_noise_max_snr_db=40.0, person_noise_pool=None,
                          seed=None):
    """Build a mel-domain augmentation pipeline for training.

    Args:
        sr: Accepted for interface compatibility; unused because all
            augmentation here operates on precomputed mels.
        mode: One of "none", "white", "person" or "full" (case-insensitive).
            ``None`` or an empty string is treated as "none".
        white_noise_p, white_noise_min_snr_db, white_noise_max_snr_db:
            Forwarded to ``MelWhiteNoiseAugmenter``.
        person_noise_n, person_noise_r1, person_noise_r2, person_noise_r_ref,
        person_noise_min_snr_db, person_noise_max_snr_db:
            Forwarded to ``MelPersonNoiseAugmenter``.
        person_noise_pool: Collection of mel clips used as background
            speakers; required for "person" mode.
        seed: Optional seed forwarded to the augmenter's random generator so
            that augmentation can be made reproducible. ``None`` (default)
            keeps the previous unseeded behaviour.

    Returns:
        A callable augmenter, or ``None`` when mode is "none" or when
        "person" mode is requested without a usable pool (a warning is
        printed in that case).

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    del sr
    mode = (mode or "none").lower()
    if mode == "none":
        return None
    if mode not in ("white", "person", "full"):
        raise ValueError(f"Unknown audio augmentation mode: {mode}")

    if mode == "person":
        if person_noise_pool is None or len(person_noise_pool) == 0:
            print("  WARNING: person noise requested, but no mel pool was provided")
            return None
        return MelPersonNoiseAugmenter(
            mel_pool=person_noise_pool,
            n=person_noise_n,
            r1=person_noise_r1,
            r2=person_noise_r2,
            r_ref=person_noise_r_ref,
            min_snr_db=person_noise_min_snr_db,
            max_snr_db=person_noise_max_snr_db,
            seed=seed,
        )

    if mode == "full":
        # Waveform-level augmentation cannot be applied to precomputed mels,
        # so "full" degrades to white mel noise.
        print("  WARNING: full waveform augmentation is not used on precomputed mels — using white mel noise")

    # Built only on the paths that actually return it ("white" and "full").
    return MelWhiteNoiseAugmenter(
        p=white_noise_p,
        min_snr_db=white_noise_min_snr_db,
        max_snr_db=white_noise_max_snr_db,
        seed=seed,
    )


def get_speech_batches(mel_data, dec_inputs, dec_targets, batch_size,
                       shuffle=True, loss_mask=None, augmenter=None):
    """Generate fixed-size speech batches from precomputed mel data.

    mel_data: indexable array of mels (may be memory-mapped); its length sets
        the number of samples.
    dec_inputs / dec_targets: arrays indexed in parallel with ``mel_data``.
    shuffle: when true the sample order is a fresh ``np.random.permutation``.
    loss_mask: optional array; when given, its rows are appended as a fourth
        element of each batch.
    augmenter: optional callable applied to each float32 mel batch on the fly.

    Yields tuples ``(mel, dec_inputs, dec_targets[, loss_mask])``. Only full
    batches are produced; a trailing partial batch is dropped.
    """
    total = len(mel_data)
    if shuffle:
        order = np.random.permutation(total)
    else:
        order = np.arange(total)

    for start in range(0, total - batch_size + 1, batch_size):
        sel = order[start : start + batch_size]
        mels = np.array(mel_data[sel], dtype=np.float32)
        if augmenter is not None:
            mels = augmenter(mels)
        parts = [mels, np.array(dec_inputs[sel]), np.array(dec_targets[sel])]
        if loss_mask is not None:
            parts.append(np.array(loss_mask[sel]))
        yield tuple(parts)
class PrefetchIterator:
"""Generic prefetch wrapper: runs any batch-generating callable in a background thread."""

Expand Down
Loading