diff --git a/.claude/skills/wer-wolf/SKILL.md b/.claude/skills/wer-wolf/SKILL.md new file mode 100644 index 0000000..1827bf7 --- /dev/null +++ b/.claude/skills/wer-wolf/SKILL.md @@ -0,0 +1,86 @@ +--- +name: wer-wolf +description: Benchmark Zee's saved audio samples against every available STT provider/model (named for Word Error Rate — the canonical STT eval metric). Use when the user wants to compare transcription quality across Groq Whisper, OpenAI, Mistral Voxtral, ElevenLabs Scribe, etc., on their own saved recordings, evaluate which model handles their domain vocabulary best, or audit how well hints.txt biasing works per provider. +--- + +# wer-wolf — STT bake-off for Zee samples + +Run every saved sample under `~/Library/Application Support/zee/samples/` through every STT provider Zee supports (that has an API key) and present a side-by-side comparison. + +## What this does + +1. Lists every sample directory under `~/Library/Application Support/zee/samples/`. +2. For each sample, reads `info.json` (original provider/model, original transcribed text, timestamp) and stats `audio.` for KB size + format. +3. Loops over `(provider, model)` pairs whose API key env var is set, swapping `config.json` for each, calling `./zee -transcribe `. Hints are read automatically from `hints.txt` by all five providers. +4. Restores the original `config.json` at the end (also on error — script uses a trap). +5. Renders the result as one block per sample: metadata header followed by an ASCII table of every model's output. + +## Pre-flight checks + +Before running, verify: + +- **Zee binary**. Find the running process first (most accurate, matches what the user is actually using): + ```bash + lsof -c zee 2>/dev/null | awk '$4=="txt" && $NF ~ /zee$/ {print $NF; exit}' + ``` + Fall back to `/Users/supo/Desktop/p/zee/zee` if no process. If neither exists, ask the user where the binary is. +- **Samples directory** exists and has at least one `2026-*` subdir. 
If empty, tell the user to enable `ZEE_SAVE_LAST_AUDIO=1` and capture some recordings first. +- **API keys**. Print which of `GROQ_API_KEY`, `OPENAI_API_KEY`, `MISTRAL_API_KEY`, `ELEVENLABS_API_KEY` are set; the script auto-skips providers whose key is missing. Deepgram is skipped entirely (its only model `nova-3` is streaming, not compatible with batch `-transcribe`). +- **Tray app warning**. Tell the user not to interact with the running tray app's menu during the run — it caches `config.json` in memory and will overwrite the file on the next menu interaction. Backup is restored at the end either way, but mid-run interference can corrupt results. + +## Running it + +```bash +ZEE_BIN= bash ~/.claude/skills/zee-transcribe-compare/scripts/compare.sh +``` + +The script: +- Writes raw results to `/tmp/zee-compare-results.txt`. +- Writes machine-readable per-sample/per-model JSON lines to `/tmp/zee-compare-results.jsonl` (each line: `{sample, provider, model, text, error?}`). Use this for the comparison table — easier than re-parsing the human-readable file. + +## How to render the result + +Start with a header showing the active vocabulary so the user can see what biasing the providers received: + +``` +**Hints in effect** (`~/Library/Application Support/zee/hints.txt`): + +``` + +Then for each sample directory (sorted by timestamp), produce one block: + +``` +### KB — recorded +**Originally transcribed by:** / +**Original text:** "" + +| Provider / Model | Transcription | +|--------------------------|------------------------------------------| +| groq / whisper-v3-turbo | ... | +| ... | ... | +``` + +Mark the row matching the original `(provider, model)` with `*` after the model name so the user can quickly see the baseline. + +If a model errored (network, 4xx, etc.), put the error message in the cell instead of the transcription. + +## Models tested per provider + +(Source of truth: `transcriber/*.go` in the zee repo. Update this list if Zee adds models.) 
+ +| Provider | Model(s) | Hint field | +|--------------|--------------------------------------------|-------------------------| +| groq | `whisper-large-v3-turbo`, `whisper-large-v3` | `prompt` | +| openai | `gpt-4o-transcribe` | `prompt` | +| mistral | `voxtral-mini-latest` | `context_bias[]` | +| elevenlabs | `scribe_v2` | `keyterms[]` | +| deepgram | `nova-3` (streaming-only — skipped) | streaming keyterms | + +All five wire `hints.txt` automatically — no flag required. Each provider receives the same hints joined as a single comma-separated string; how aggressively each provider biases varies a lot in practice (Whisper-via-Groq honors it strongly; Mistral and Scribe much less so based on observed runs). + +## Notes / gotchas + +- Each `-transcribe` call is one HTTP round-trip (~1–3s). 4 samples × 4 models ≈ 12–20s wall time plus API latency. Costs are tiny but real. +- `voxtral-mini-latest` is a moving target — if Mistral renames it, update the script. +- The script intentionally does **not** parallelize providers — keeps output ordered and avoids rate-limit surprises. +- `-transcribe` exits non-zero on transcription errors; the script captures stdout+stderr and continues so one failure doesn't abort the matrix. diff --git a/.claude/skills/wer-wolf/scripts/compare.sh b/.claude/skills/wer-wolf/scripts/compare.sh new file mode 100755 index 0000000..69a3596 --- /dev/null +++ b/.claude/skills/wer-wolf/scripts/compare.sh @@ -0,0 +1,142 @@ +#!/usr/bin/env bash +# zee-transcribe-compare — run every saved Zee sample through every available +# STT provider and emit both human-readable and JSONL results. +# +# Required: ZEE_BIN (absolute path to zee binary). Falls back to discovering +# the running process or /Users/supo/Desktop/p/zee/zee. +set -u + +ZEE_BIN="${ZEE_BIN:-}" +if [ -z "$ZEE_BIN" ]; then + ZEE_BIN=$(lsof -c zee 2>/dev/null | awk '$4=="txt" && $NF ~ /zee$/ {print $NF; exit}') +fi +[ -z "$ZEE_BIN" ] && ZEE_BIN="/Users/supo/Desktop/p/zee/zee" +if [ ! 
-x "$ZEE_BIN" ]; then + echo "ERROR: zee binary not found or not executable: $ZEE_BIN" >&2 + exit 1 +fi + +ZEE_DIR="$HOME/Library/Application Support/zee" +CFG="$ZEE_DIR/config.json" +SAMPLES="$ZEE_DIR/samples" +BACKUP="/tmp/zee-config-backup.$$.json" +HUMAN=/tmp/zee-compare-results.txt +JSONL=/tmp/zee-compare-results.jsonl + +[ -f "$CFG" ] || { echo "ERROR: $CFG missing" >&2; exit 1; } +[ -d "$SAMPLES" ] || { echo "ERROR: $SAMPLES missing" >&2; exit 1; } + +cp "$CFG" "$BACKUP" +restore() { cp "$BACKUP" "$CFG" 2>/dev/null; rm -f "$BACKUP"; } +trap restore EXIT INT TERM + +: > "$HUMAN" +: > "$JSONL" + +# Header: active hints (what each provider receives as biasing). +HINTS_FILE="$ZEE_DIR/hints.txt" +{ + echo "########## hints.txt ##########" + if [ -f "$HINTS_FILE" ]; then + grep -vE '^\s*(#|$)' "$HINTS_FILE" | paste -sd, - + else + echo "(no hints.txt found — providers receive no biasing)" + fi + echo +} | tee -a "$HUMAN" + +# (provider, model, env_var) — keep groq turbo first since it's the fastest baseline. +# Update this list when zee adds providers/models (transcriber/*.go). +COMBOS=( + "groq|whisper-large-v3-turbo|GROQ_API_KEY" + "groq|whisper-large-v3|GROQ_API_KEY" + "openai|gpt-4o-transcribe|OPENAI_API_KEY" + "mistral|voxtral-mini-latest|MISTRAL_API_KEY" + "elevenlabs|scribe_v2|ELEVENLABS_API_KEY" +) + +# JSON string escaper for the JSONL output. Python is on every macOS. 
+jesc() { python3 -c 'import json,sys; print(json.dumps(sys.stdin.read()), end="")'; } + +write_cfg() { + local prov="$1" model="$2" + cat > "$CFG" <&2 + exit 1 +fi + +for combo in "${COMBOS[@]}"; do + IFS='|' read -r prov model envk <<< "$combo" + if [ -z "${!envk:-}" ]; then + echo "########## SKIP $prov/$model ($envk not set) ##########" | tee -a "$HUMAN" + continue + fi + echo "########## $prov / $model ##########" | tee -a "$HUMAN" + write_cfg "$prov" "$model" + for d in "${SAMPLE_DIRS[@]}"; do + sample=$(basename "$d") + audio=$(find "$d" -maxdepth 1 -type f -name "audio.*" | head -1) + [ -z "$audio" ] && continue + echo "----- $sample -----" | tee -a "$HUMAN" + text=$(timeout 45 "$ZEE_BIN" -transcribe "$audio" 2>&1) + rc=$? + echo "$text" | tee -a "$HUMAN" + # Emit JSONL (one line per sample/model). Trim trailing newline from text first. + text_trimmed=$(printf '%s' "$text") + if [ $rc -ne 0 ]; then + printf '{"sample":%s,"provider":%s,"model":%s,"error":%s}\n' \ + "$(printf '%s' "$sample" | jesc)" \ + "$(printf '%s' "$prov" | jesc)" \ + "$(printf '%s' "$model" | jesc)" \ + "$(printf '%s' "$text_trimmed" | jesc)" >> "$JSONL" + else + printf '{"sample":%s,"provider":%s,"model":%s,"text":%s}\n' \ + "$(printf '%s' "$sample" | jesc)" \ + "$(printf '%s' "$prov" | jesc)" \ + "$(printf '%s' "$model" | jesc)" \ + "$(printf '%s' "$text_trimmed" | jesc)" >> "$JSONL" + fi + done +done + +# Also dump per-sample metadata so the rendering step doesn't need to re-stat files. 
+META=/tmp/zee-compare-samples.jsonl +: > "$META" +for d in "${SAMPLE_DIRS[@]}"; do + sample=$(basename "$d") + info="$d/info.json" + audio=$(find "$d" -maxdepth 1 -type f -name "audio.*" | head -1) + [ -z "$audio" ] && continue + size_kb=$(awk -v b="$(stat -f%z "$audio")" 'BEGIN{printf "%.1f", b/1024}') + ext="${audio##*.}" + python3 -c " +import json,sys +info=json.load(open('$info')) +info['sample']='$sample' +info['size_kb']=$size_kb +info['ext']='$ext' +print(json.dumps(info)) +" >> "$META" +done + +echo +echo "DONE — config restored." +echo " samples meta: $META" +echo " raw output: $HUMAN" +echo " per-cell: $JSONL" diff --git a/.gitignore b/.gitignore index e86d9f5..00e7e6e 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,3 @@ transcribe_log.txt # OS .DS_Store - -# Claude -.claude/ diff --git a/CLAUDE.md b/CLAUDE.md index 906566c..50b937d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -43,6 +43,11 @@ make benchmark WAV=file.wav RUNS=5 - `-benchmark ` - run benchmark instead of live recording - `-runs N` - benchmark iterations (default: 3) - `-logpath ` - log directory (default: `$ZEE_LOG_PATH` or OS-specific, use `./` for current directory) +- `-hints ` - comma-separated vocabulary hints (overrides `hints.txt`) +- `-transcribe ` - transcribe an audio file (mp3/flac/wav) and exit + +**Environment variables:** +- `ZEE_SAVE_LAST_AUDIO=1` - enables "Save Last Recording" tray button (saves audio + metadata to `config/samples/`) ## Architecture @@ -56,7 +61,7 @@ Ctrl+Shift+Space keydown → record audio → encode (mode-based) → API call - `main.go` - hotkey handling, audio capture, recording logic, panic recovery - `tray/` - system tray icon, menus (devices, providers, languages, auto-paste), dynamic icons - `encoder/` - AudioEncoder interface, FLAC, MP3, and Adaptive implementations -- `transcriber/` - Groq and DeepGram API clients with shared TracedClient for HTTP timing metrics +- `transcriber/` - STT providers (Groq, OpenAI, Deepgram, Mistral, ElevenLabs) with 
shared TracedClient for HTTP timing metrics - `hotkey/` - global hotkey registration (Ctrl+Shift+Space) with platform-specific backends - `clipboard/` - platform-specific clipboard and paste operations (Cmd+V / Ctrl+V) - `audio/` - platform-specific audio capture (malgo on macOS, PulseAudio on Linux) @@ -66,7 +71,7 @@ Ctrl+Shift+Space keydown → record audio → encode (mode-based) → API call - `device.go` - microphone picker with arrow-key navigation - `vad.go` - voice activity detection using WebRTC VAD with debounced speech confirmation - `silence.go` - silence monitoring with warnings, repeat beeps, and auto-close (toggle mode) -- `settings.go` - persistent settings (language, device, provider/model, auto-paste, auto-start) with JSON config file +- `config/` - persistent settings (`config.json`) and vocabulary hints (`hints.txt`) - `log.go` - diagnostic logging and panic capture to `diagnostics_log.txt` ## Design Philosophy diff --git a/settings.go b/config/config.go similarity index 73% rename from settings.go rename to config/config.go index 01323c6..c3dd740 100644 --- a/settings.go +++ b/config/config.go @@ -1,4 +1,4 @@ -package main +package config import ( "encoding/json" @@ -22,19 +22,20 @@ type Settings struct { const settingsFile = "config.json" var ( - settingsMu sync.Mutex - current Settings - cfgDir string + mu sync.Mutex + current Settings + dir string + defaults = Settings{ + Language: "en", + AutoPaste: true, + } ) -var settingsDefaults = Settings{ - Language: "en", - AutoPaste: true, -} +func SetDir(d string) { dir = d } -func settingsDir() string { - if cfgDir != "" { - return cfgDir +func Dir() string { + if dir != "" { + return dir } home, err := os.UserHomeDir() if err != nil { @@ -58,12 +59,12 @@ func settingsDir() string { } func settingsPath() string { - return filepath.Join(settingsDir(), settingsFile) + return filepath.Join(Dir(), settingsFile) } -func loadSettings() error { - cfgDir = settingsDir() - current = settingsDefaults +func Load() 
error { + dir = Dir() + current = defaults data, err := os.ReadFile(settingsPath()) if err != nil { @@ -81,30 +82,30 @@ func loadSettings() error { current = s if current.Language == "" { - current.Language = settingsDefaults.Language + current.Language = defaults.Language } return nil } -func getSettings() Settings { - settingsMu.Lock() +func Get() Settings { + mu.Lock() s := current - settingsMu.Unlock() + mu.Unlock() return s } -func updateSettings(fn func(*Settings)) { - settingsMu.Lock() +func Update(fn func(*Settings)) { + mu.Lock() fn(&current) s := current - settingsMu.Unlock() + mu.Unlock() - saveSettings(s) + save(s) } -func saveSettings(s Settings) { - dir := cfgDir - if err := os.MkdirAll(dir, 0755); err != nil { +func save(s Settings) { + d := dir + if err := os.MkdirAll(d, 0755); err != nil { log.Warnf("settings: create dir: %v", err) return } @@ -116,7 +117,7 @@ func saveSettings(s Settings) { } data = append(data, '\n') - tmp, err := os.CreateTemp(dir, ".config-*.json") + tmp, err := os.CreateTemp(d, ".config-*.json") if err != nil { log.Warnf("settings: create temp: %v", err) return diff --git a/settings_test.go b/config/config_test.go similarity index 63% rename from settings_test.go rename to config/config_test.go index 8536a45..0546398 100644 --- a/settings_test.go +++ b/config/config_test.go @@ -1,4 +1,4 @@ -package main +package config import ( "os" @@ -8,13 +8,12 @@ import ( ) func TestSettingsDefaults(t *testing.T) { - cfgDir = t.TempDir() - current = Settings{} + SetDir(t.TempDir()) - if err := loadSettings(); err != nil { - t.Fatalf("loadSettings: %v", err) + if err := Load(); err != nil { + t.Fatalf("Load: %v", err) } - s := getSettings() + s := Get() if s.Language != "en" { t.Errorf("Language = %q, want %q", s.Language, "en") } @@ -27,14 +26,13 @@ func TestSettingsDefaults(t *testing.T) { } func TestSettingsRoundTrip(t *testing.T) { - cfgDir = t.TempDir() - current = Settings{} + SetDir(t.TempDir()) - if err := loadSettings(); err != nil { - 
t.Fatalf("loadSettings: %v", err) + if err := Load(); err != nil { + t.Fatalf("Load: %v", err) } - updateSettings(func(s *Settings) { + Update(func(s *Settings) { s.Language = "fr" s.Device = "Blue Yeti" s.Provider = "groq" @@ -43,13 +41,11 @@ func TestSettingsRoundTrip(t *testing.T) { s.AutoStart = true }) - // Re-load from disk - current = Settings{} - if err := loadSettings(); err != nil { - t.Fatalf("loadSettings after update: %v", err) + if err := Load(); err != nil { + t.Fatalf("Load after update: %v", err) } - s := getSettings() + s := Get() if s.Language != "fr" { t.Errorf("Language = %q, want %q", s.Language, "fr") } @@ -71,47 +67,47 @@ func TestSettingsRoundTrip(t *testing.T) { } func TestSettingsCopySafety(t *testing.T) { - cfgDir = t.TempDir() - current = settingsDefaults + SetDir(t.TempDir()) + Load() - s := getSettings() + s := Get() s.Language = "xx" - s2 := getSettings() + s2 := Get() if s2.Language == "xx" { t.Error("mutating returned Settings affected internal state") } } func TestSettingsCorruptFile(t *testing.T) { - cfgDir = t.TempDir() - current = Settings{} + d := t.TempDir() + SetDir(d) - os.WriteFile(filepath.Join(cfgDir, settingsFile), []byte("not json{{{"), 0644) + os.WriteFile(filepath.Join(d, "config.json"), []byte("not json{{{"), 0644) - if err := loadSettings(); err != nil { - t.Fatalf("loadSettings should not error on corrupt file: %v", err) + if err := Load(); err != nil { + t.Fatalf("Load should not error on corrupt file: %v", err) } - s := getSettings() + s := Get() if s.Language != "en" { t.Errorf("Language = %q, want default %q after corrupt file", s.Language, "en") } } func TestSettingsConcurrent(t *testing.T) { - cfgDir = t.TempDir() - current = settingsDefaults + SetDir(t.TempDir()) + Load() var wg sync.WaitGroup for i := 0; i < 50; i++ { wg.Add(2) go func() { defer wg.Done() - updateSettings(func(s *Settings) { s.Language = "es" }) + Update(func(s *Settings) { s.Language = "es" }) }() go func() { defer wg.Done() - _ = 
getSettings() + _ = Get() }() } wg.Wait() diff --git a/config/hints.go b/config/hints.go new file mode 100644 index 0000000..f73ec91 --- /dev/null +++ b/config/hints.go @@ -0,0 +1,67 @@ +package config + +import ( + "bufio" + "os" + "path/filepath" + "strings" + "time" +) + +const hintsFile = "hints.txt" + +const hintsHeader = `# Vocabulary hints for transcription (one per line) +# These help the model recognize domain-specific terms +# Empty lines and lines starting with # are ignored +` + +func HintsPath() string { + return filepath.Join(Dir(), hintsFile) +} + +var ( + hintsCache string + hintsModTime time.Time + hintsFixed bool +) + +func SetHints(s string) { + hintsCache = s + hintsFixed = true +} + +func GetHints() string { + if hintsFixed { + return hintsCache + } + info, err := os.Stat(HintsPath()) + if err != nil { + if os.IsNotExist(err) { + os.MkdirAll(Dir(), 0755) + os.WriteFile(HintsPath(), []byte(hintsHeader), 0644) + } + return hintsCache + } + if info.ModTime().Equal(hintsModTime) { + return hintsCache + } + + f, err := os.Open(HintsPath()) + if err != nil { + return hintsCache + } + defer f.Close() + + var hints []string + sc := bufio.NewScanner(f) + for sc.Scan() { + line := strings.TrimSpace(sc.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + hints = append(hints, line) + } + hintsCache = strings.Join(hints, ", ") + hintsModTime = info.ModTime() + return hintsCache +} diff --git a/main.go b/main.go index 3b4b133..0c551d5 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ import ( "zee/alert" "zee/audio" + "zee/config" "zee/beep" "zee/clipboard" "zee/doctor" @@ -81,6 +82,7 @@ type recordingConfig struct { stream bool format string lang string + hints string autoPaste bool } @@ -171,6 +173,8 @@ func run() { profileFlag := flag.String("profile", "", "Enable pprof profiling server (e.g., :6060 or localhost:6060)") testFlag := flag.Bool("test", false, "Test mode (headless, stdin-driven)") longPressFlag := 
flag.Duration("longpress", 350*time.Millisecond, "Long-press threshold for PTT vs tap (e.g., 350ms)") + hintsFlag := flag.String("hints", "", "Vocabulary hints for transcription (comma-separated)") + transcribeFlag := flag.String("transcribe", "", "Transcribe an audio file and exit") flag.Parse() // Resolve log directory early @@ -217,10 +221,10 @@ func run() { os.Exit(doctor.Run(wavFile)) } // Load persistent settings, merge with CLI flags - if err := loadSettings(); err != nil { + if err := config.Load(); err != nil { log.Warnf("settings: %v", err) } - cfg := getSettings() + cfg := config.Get() flagSet := map[string]bool{} flag.Visit(func(f *flag.Flag) { flagSet[f.Name] = true }) if !flagSet["lang"] && cfg.Language != "" { @@ -240,6 +244,9 @@ func run() { switch *formatFlag { case "mp3@16", "mp3@64", "flac": activeFormat = *formatFlag + if *hintsFlag != "" { + config.SetHints(*hintsFlag) + } default: fatal("Unknown format %q (use mp3@16, mp3@64, or flac)", *formatFlag) } @@ -309,6 +316,11 @@ func run() { return } + if *transcribeFlag != "" { + runTranscribeFile(*transcribeFlag) + return + } + if autoPaste { if err := clipboard.Init(); err != nil { log.Warnf("paste init failed: %v", err) @@ -374,7 +386,7 @@ func run() { } tray.SetDevices(names, preferredDevice, func(name string) { preferredDevice = name - updateSettings(func(s *Settings) { s.Device = name }) + config.Update(func(s *config.Settings) { s.Device = name }) if name == "" { applyDeviceSwitch(ctx, captureConfig, &captureDevice, &selectedDevice, nil) } else { @@ -430,7 +442,7 @@ func run() { activeFormat = *formatFlag } - updateSettings(func(s *Settings) { s.Provider = provider; s.Model = model }) + config.Update(func(s *config.Settings) { s.Provider = provider; s.Model = model }) tray.SetLanguages(newTr.SupportedLanguages()) }) @@ -438,18 +450,21 @@ func run() { configMu.Lock() activeTranscriber.SetLanguage(code) configMu.Unlock() - updateSettings(func(s *Settings) { s.Language = code }) + 
config.Update(func(s *config.Settings) { s.Language = code }) }) tray.SetLogin(login.Enabled()) tray.SetVersion(version) tray.OnSaveAudio(saveLastRecording) + tray.OnEditHints(func() { + exec.Command("open", config.HintsPath()).Run() + }) trayQuit := tray.Init() tray.OnAutoPaste(func(on bool) { configMu.Lock() autoPaste = on configMu.Unlock() - updateSettings(func(s *Settings) { s.AutoPaste = on }) + config.Update(func(s *config.Settings) { s.AutoPaste = on }) }) tray.OnLogin(func(on bool) error { var err error @@ -462,7 +477,7 @@ func run() { log.Errorf("login toggle: %v", err) tray.SetError(err.Error()) } else { - updateSettings(func(s *Settings) { s.AutoStart = on }) + config.Update(func(s *config.Settings) { s.AutoStart = on }) } return err }) @@ -647,6 +662,7 @@ func handleRecording(capture audio.CaptureDevice, sess recSession) (<-chan struc stream: streamEnabled, format: activeFormat, lang: activeTranscriber.GetLanguage(), + hints: config.GetHints(), autoPaste: autoPaste, } configMu.Unlock() @@ -655,6 +671,7 @@ func handleRecording(capture audio.CaptureDevice, sess recSession) (<-chan struc Stream: cfg.stream, Format: cfg.format, Language: cfg.lang, + Hints: cfg.hints, }) if err != nil { return nil, err @@ -813,7 +830,7 @@ func saveLastRecording() { } ts := rec.Timestamp.Format("2006-01-02T15-04-05") - dir := filepath.Join(settingsDir(), "samples", ts) + dir := filepath.Join(config.Dir(), "samples", ts) if err := os.MkdirAll(dir, 0755); err != nil { alert.Error("Save failed: " + err.Error()) return @@ -837,6 +854,42 @@ func saveLastRecording() { alert.Info("Saved to " + dir) } +func runTranscribeFile(audioFile string) { + data, err := os.ReadFile(audioFile) + if err != nil { + fatal("Error reading file: %v", err) + } + + ext := filepath.Ext(audioFile) + format := "mp3" + switch ext { + case ".flac": + format = "flac" + case ".wav": + format = "wav" + case ".mp3": + format = "mp3" + default: + fatal("Unsupported audio format: %s", ext) + } + + type 
directTranscriber interface { + Transcribe(audio []byte, format, lang, hints string) (*transcriber.Result, error) + } + + dt, ok := activeTranscriber.(directTranscriber) + if !ok { + fatal("Provider %q does not support direct file transcription", activeTranscriber.Name()) + } + + result, err := dt.Transcribe(data, format, activeTranscriber.GetLanguage(), config.GetHints()) + if err != nil { + fatal("Transcription error: %v", err) + } + + fmt.Println(result.Text) +} + func runBenchmark(wavFile string, runs int) { fmt.Printf("Benchmark: %s (%d runs)\n", wavFile, runs) diff --git a/transcriber/batch_session.go b/transcriber/batch_session.go index b98cba1..325c5a8 100644 --- a/transcriber/batch_session.go +++ b/transcriber/batch_session.go @@ -9,7 +9,7 @@ import ( "zee/encoder" ) -type transcribeFunc func(audio []byte, format, lang string) (*Result, error) +type transcribeFunc func(audio []byte, format, lang, hints string) (*Result, error) type batchSession struct { cfg SessionConfig @@ -94,7 +94,7 @@ func (bs *batchSession) Close() (SessionResult, error) { audioData := bs.encoder.Bytes() apiFormat := apiFormatFromConfig(bs.cfg.Format) - result, err := bs.transcribe(audioData, apiFormat, bs.cfg.Language) + result, err := bs.transcribe(audioData, apiFormat, bs.cfg.Language, bs.cfg.Hints) if err != nil { return SessionResult{}, err } diff --git a/transcriber/deepgram.go b/transcriber/deepgram.go index c8b76ef..9a5a90c 100644 --- a/transcriber/deepgram.go +++ b/transcriber/deepgram.go @@ -6,6 +6,8 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" + "strings" "zee/encoder" ) @@ -45,18 +47,19 @@ func (d *Deepgram) Models() []ModelInfo { return DeepgramModels } func (d *Deepgram) NewSession(ctx context.Context, cfg SessionConfig) (Session, error) { go d.client.Warm() if cfg.Stream { - return d.newStreamSession(ctx, cfg.Language) + return d.newStreamSession(ctx, cfg.Language, cfg.Hints) } - return newBatchSession(cfg, d.transcribe) + return newBatchSession(cfg, 
d.Transcribe) } -func (d *Deepgram) newStreamSession(ctx context.Context, lang string) (Session, error) { +func (d *Deepgram) newStreamSession(ctx context.Context, lang, hints string) (Session, error) { dial := func() (rawStreamSession, error) { return d.startStream(ctx, streamSessionConfig{ SampleRate: encoder.SampleRate, Channels: encoder.Channels, Language: lang, Model: "nova-3", + Hints: hints, }) } return newStreamSession(dial), nil @@ -77,13 +80,27 @@ type deepgramResponse struct { } `json:"results"` } -func (d *Deepgram) transcribe(audioData []byte, format, lang string) (*Result, error) { +func (d *Deepgram) Transcribe(audioData []byte, format, lang, hints string) (*Result, error) { contentType := "audio/flac" if format == "mp3" { contentType = "audio/mpeg" } - req, err := http.NewRequest("POST", d.apiURL, bytes.NewReader(audioData)) + apiURL := d.apiURL + if hints != "" { + u, err := url.Parse(apiURL) + if err != nil { + return nil, err + } + q := u.Query() + for _, term := range strings.Split(hints, ",") { + q.Add("keyterm", strings.TrimSpace(term)) + } + u.RawQuery = q.Encode() + apiURL = u.String() + } + + req, err := http.NewRequest("POST", apiURL, bytes.NewReader(audioData)) if err != nil { return nil, err } diff --git a/transcriber/deepgram_stream.go b/transcriber/deepgram_stream.go index 5fd0da8..89baa40 100644 --- a/transcriber/deepgram_stream.go +++ b/transcriber/deepgram_stream.go @@ -16,6 +16,7 @@ type streamSessionConfig struct { Channels int Language string Model string + Hints string } type deepgramStreamResponse struct { @@ -58,6 +59,11 @@ func (d *Deepgram) startStream(ctx context.Context, cfg streamSessionConfig) (ra if cfg.Language != "" { q.Set("language", cfg.Language) } + if cfg.Hints != "" { + for _, term := range strings.Split(cfg.Hints, ",") { + q.Add("keyterm", strings.TrimSpace(term)) + } + } endpoint.RawQuery = q.Encode() headers := http.Header{} diff --git a/transcriber/elevenlabs.go b/transcriber/elevenlabs.go index 
b838ba4..19bda6e 100644 --- a/transcriber/elevenlabs.go +++ b/transcriber/elevenlabs.go @@ -7,6 +7,7 @@ import ( "fmt" "mime/multipart" "net/http" + "strings" ) const ModelScribeV2 = "scribe_v2" @@ -55,7 +56,7 @@ func (e *ElevenLabs) NewSession(_ context.Context, cfg SessionConfig) (Session, if cfg.Stream { return nil, fmt.Errorf("elevenlabs does not support streaming transcription") } - return newBatchSession(cfg, e.transcribe) + return newBatchSession(cfg, e.Transcribe) } type elevenLabsResponse struct { @@ -71,7 +72,7 @@ type elevenLabsResponse struct { } `json:"words"` } -func (e *ElevenLabs) transcribe(audioData []byte, format, lang string) (*Result, error) { +func (e *ElevenLabs) Transcribe(audioData []byte, format, lang, hints string) (*Result, error) { var body bytes.Buffer writer := multipart.NewWriter(&body) @@ -88,6 +89,11 @@ func (e *ElevenLabs) transcribe(audioData []byte, format, lang string) (*Result, writer.WriteField("language_code", lang) } writer.WriteField("tag_audio_events", "false") + if hints != "" { + for _, word := range strings.Split(hints, ",") { + writer.WriteField("keyterms[]", strings.TrimSpace(word)) + } + } writer.Close() req, err := http.NewRequest("POST", e.apiURL, &body) diff --git a/transcriber/groq.go b/transcriber/groq.go index 25bec4c..17f0c12 100644 --- a/transcriber/groq.go +++ b/transcriber/groq.go @@ -54,7 +54,7 @@ func (g *Groq) NewSession(_ context.Context, cfg SessionConfig) (Session, error) if cfg.Stream { return nil, fmt.Errorf("groq does not support streaming transcription") } - return newBatchSession(cfg, g.transcribe) + return newBatchSession(cfg, g.Transcribe) } type groqResponse struct { @@ -71,7 +71,7 @@ type groqResponse struct { } `json:"segments"` } -func (g *Groq) transcribe(audioData []byte, format, lang string) (*Result, error) { +func (g *Groq) Transcribe(audioData []byte, format, lang, hints string) (*Result, error) { var body bytes.Buffer writer := multipart.NewWriter(&body) @@ -88,6 +88,9 @@ func (g 
*Groq) transcribe(audioData []byte, format, lang string) (*Result, error if lang != "" { writer.WriteField("language", lang) } + if hints != "" { + writer.WriteField("prompt", hints) + } writer.Close() req, err := http.NewRequest("POST", g.apiURL, &body) diff --git a/transcriber/mistral.go b/transcriber/mistral.go index d54ad08..e7372db 100644 --- a/transcriber/mistral.go +++ b/transcriber/mistral.go @@ -8,6 +8,7 @@ import ( "mime/multipart" "net/http" "strconv" + "strings" ) var voxtralLangs = langsFromCodes([]string{ @@ -45,10 +46,10 @@ func (m *Mistral) NewSession(_ context.Context, cfg SessionConfig) (Session, err if cfg.Stream { return nil, fmt.Errorf("mistral does not support streaming transcription") } - return newBatchSession(cfg, m.transcribe) + return newBatchSession(cfg, m.Transcribe) } -func (m *Mistral) transcribe(audioData []byte, format, lang string) (*Result, error) { +func (m *Mistral) Transcribe(audioData []byte, format, lang, hints string) (*Result, error) { var body bytes.Buffer writer := multipart.NewWriter(&body) @@ -64,6 +65,11 @@ func (m *Mistral) transcribe(audioData []byte, format, lang string) (*Result, er if lang != "" { writer.WriteField("language", lang) } + if hints != "" { + for _, word := range strings.Split(hints, ",") { + writer.WriteField("context_bias[]", strings.TrimSpace(word)) + } + } writer.Close() req, err := http.NewRequest("POST", m.apiURL, &body) diff --git a/transcriber/openai.go b/transcriber/openai.go index 639559c..27fc6ca 100644 --- a/transcriber/openai.go +++ b/transcriber/openai.go @@ -49,10 +49,10 @@ func (o *OpenAI) NewSession(_ context.Context, cfg SessionConfig) (Session, erro if cfg.Stream { return nil, fmt.Errorf("openai does not support streaming transcription") } - return newBatchSession(cfg, o.transcribe) + return newBatchSession(cfg, o.Transcribe) } -func (o *OpenAI) transcribe(audioData []byte, format, lang string) (*Result, error) { +func (o *OpenAI) Transcribe(audioData []byte, format, lang, hints 
string) (*Result, error) { var body bytes.Buffer writer := multipart.NewWriter(&body) @@ -69,6 +69,9 @@ func (o *OpenAI) transcribe(audioData []byte, format, lang string) (*Result, err if lang != "" { writer.WriteField("language", lang) } + if hints != "" { + writer.WriteField("prompt", hints) + } if err := writer.Close(); err != nil { return nil, err } diff --git a/transcriber/session.go b/transcriber/session.go index c1a1eca..4edac56 100644 --- a/transcriber/session.go +++ b/transcriber/session.go @@ -13,6 +13,7 @@ type SessionConfig struct { Stream bool Format string // "mp3@16"|"mp3@64"|"flac" (batch only; ignored for streaming) Language string + Hints string // optional vocabulary hints for the model } type BatchStats struct { diff --git a/transcriber/stream_session.go b/transcriber/stream_session.go index 141693e..9acb4a1 100644 --- a/transcriber/stream_session.go +++ b/transcriber/stream_session.go @@ -37,6 +37,7 @@ type streamSession struct { updates chan string startedAt time.Time connected chan struct{} // closed when WebSocket is ready (or failed) + // hints string // TODO: Deepgram streaming supports keywords param sendDone chan struct{} recvDone chan struct{} diff --git a/transcriber/transcriber_test.go b/transcriber/transcriber_test.go index ec5c7e8..1c96325 100644 --- a/transcriber/transcriber_test.go +++ b/transcriber/transcriber_test.go @@ -72,7 +72,7 @@ func TestNewEncoder(t *testing.T) { } func TestBatchSessionFeedAndClose(t *testing.T) { - fakeFn := func(audio []byte, format, lang string) (*Result, error) { + fakeFn := func(audio []byte, format, lang, hints string) (*Result, error) { return &Result{ Text: "hello world", Metrics: &NetworkMetrics{TTFB: 10 * time.Millisecond}, diff --git a/tray/tray.go b/tray/tray.go index 7030fae..cacf79f 100644 --- a/tray/tray.go +++ b/tray/tray.go @@ -50,6 +50,7 @@ var ( appVersion string checkUpdateCb func() saveAudioCb func() + editHintsCb func() ) var languages []transcriber.Language // set via SetLanguages 
@@ -118,6 +119,7 @@ func SetLastRecording(dur time.Duration, totalMs float64) { func SetVersion(v string) { appVersion = v } func OnCheckUpdate(fn func()) { checkUpdateCb = fn } func OnSaveAudio(fn func()) { saveAudioCb = fn } +func OnEditHints(fn func()) { editHintsCb = fn } func SetLanguage(code string, onSwitch func(string)) { langCode = code diff --git a/tray/tray_darwin.go b/tray/tray_darwin.go index 7030692..4740a9a 100644 --- a/tray/tray_darwin.go +++ b/tray/tray_darwin.go @@ -222,6 +222,13 @@ func onReady() { } }) + mEditHints := mSettings.AddSubMenuItem("Edit Hints…", "Edit vocabulary hints file") + mEditHints.Click(func() { + if editHintsCb != nil { + go editHintsCb() + } + }) + sep := mSettings.AddSubMenuItem("─────────", "") sep.Disable()