diff --git a/src/whispercpp-transcribe.test.ts b/src/whispercpp-transcribe.test.ts
new file mode 100644
index 0000000..31a3132
--- /dev/null
+++ b/src/whispercpp-transcribe.test.ts
@@ -0,0 +1,99 @@
+import { test, expect } from 'bun:test'
+import { mkdtemp, rm } from 'node:fs/promises'
+import { tmpdir } from 'node:os'
+import path from 'node:path'
+import { downloadWhisperModelFile } from './whispercpp-transcribe'
+
+async function createTempDir() {
+  const dir = await mkdtemp(path.join(tmpdir(), 'whisper-model-'))
+  return {
+    path: dir,
+    async [Symbol.asyncDispose]() {
+      await rm(dir, { recursive: true, force: true })
+    },
+  }
+}
+
+test('downloadWhisperModelFile retries rate-limited model downloads', async () => {
+  await using tempDir = await createTempDir()
+  const modelPath = path.join(tempDir.path, 'model.bin')
+  const delays: number[] = []
+  let calls = 0
+  const fetchModel = async () => {
+    calls++
+    if (calls === 1) {
+      return new Response('rate limited', {
+        status: 429,
+        statusText: 'Too Many Requests',
+        headers: { 'retry-after': '1' },
+      })
+    }
+    return new Response('model bytes')
+  }
+
+  await downloadWhisperModelFile(modelPath, {
+    attemptCount: 2,
+    fetch: fetchModel,
+    sleep: async (delayMs) => {
+      delays.push(delayMs)
+    },
+  })
+
+  expect(calls).toBe(2)
+  expect(delays).toEqual([1000])
+  expect(await Bun.file(modelPath).text()).toBe('model bytes')
+})
+
+test('downloadWhisperModelFile retries interrupted model downloads', async () => {
+  await using tempDir = await createTempDir()
+  const modelPath = path.join(tempDir.path, 'model.bin')
+  const delays: number[] = []
+  let calls = 0
+  const fetchModel = async () => {
+    calls++
+    if (calls === 1) {
+      const response = new Response('partial model bytes')
+      response.arrayBuffer = async () => {
+        throw new TypeError('body interrupted')
+      }
+      return response
+    }
+    return new Response('model bytes')
+  }
+
+  await downloadWhisperModelFile(modelPath, {
+    attemptCount: 2,
+    fetch: fetchModel,
+    sleep: async (delayMs) => {
+      delays.push(delayMs)
+    },
+  })
+
+  expect(calls).toBe(2)
+  expect(delays).toEqual([5000])
+  expect(await Bun.file(modelPath).text()).toBe('model bytes')
+})
+
+test('downloadWhisperModelFile does not retry permanent download failures', async () => {
+  await using tempDir = await createTempDir()
+  const modelPath = path.join(tempDir.path, 'model.bin')
+  let calls = 0
+  const fetchModel = async () => {
+    calls++
+    return new Response('missing', {
+      status: 404,
+      statusText: 'Not Found',
+    })
+  }
+
+  await expect(
+    downloadWhisperModelFile(modelPath, {
+      attemptCount: 3,
+      fetch: fetchModel,
+      sleep: async () => {},
+    }),
+  ).rejects.toThrow('404 Not Found')
+
+  expect(calls).toBe(1)
+  expect(await Bun.file(modelPath).exists()).toBe(false)
+})
diff --git a/src/whispercpp-transcribe.ts b/src/whispercpp-transcribe.ts
index 70155d7..bdf9ddb 100644
--- a/src/whispercpp-transcribe.ts
+++ b/src/whispercpp-transcribe.ts
@@ -8,6 +8,10 @@ const DEFAULT_MODEL_URL =
   'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin'
 const DEFAULT_LANGUAGE = 'en'
 const DEFAULT_BINARY = 'whisper-cli'
+const MODEL_DOWNLOAD_ATTEMPTS = 8
+const MODEL_DOWNLOAD_TIMEOUT_MS = 120_000
+const MODEL_DOWNLOAD_INITIAL_RETRY_DELAY_MS = 5_000
+const MODEL_DOWNLOAD_MAX_RETRY_DELAY_MS = 60_000
 
 type TranscribeOptions = {
   modelPath?: string
@@ -18,6 +22,24 @@ type TranscribeOptions = {
   progress?: StepProgressReporter
 }
 
+type ModelDownloadOptions = {
+  progress?: StepProgressReporter
+  fetch?: (url: string, init?: RequestInit) => Promise<Response>
+  sleep?: (delayMs: number) => Promise<void>
+  attemptCount?: number
+}
+
+class ModelDownloadHttpError extends Error {
+  constructor(
+    readonly status: number,
+    readonly statusText: string,
+    readonly retryAfter: string | null,
+  ) {
+    super(`Failed to download whisper.cpp model (${status} ${statusText}).`)
+    this.name = 'ModelDownloadHttpError'
+  }
+}
+
 export type TranscriptSegment = {
   start: number
   end: number
@@ -101,19 +123,102 @@ async function ensureModelFile(
     throw new Error(`Whisper model not found at ${modelPath}.`)
   }
-  progress?.setLabel('Downloading model')
+  await downloadWhisperModelFile(modelPath, { progress })
+}
+
+export async function downloadWhisperModelFile(
+  modelPath: string,
+  options: ModelDownloadOptions = {},
+) {
+  options.progress?.setLabel('Downloading model')
   await mkdir(path.dirname(modelPath), { recursive: true })
-  const response = await fetch(DEFAULT_MODEL_URL)
-  if (!response.ok) {
-    throw new Error(
-      `Failed to download whisper.cpp model (${response.status} ${response.statusText}).`,
-    )
-  }
-  const bytes = await response.arrayBuffer()
+  const bytes = await fetchModelBytes(options)
   await Bun.write(modelPath, bytes)
 }
 
+async function fetchModelBytes(options: ModelDownloadOptions) {
+  const fetcher = options.fetch ?? fetch
+  const sleep = options.sleep ?? Bun.sleep
+  const attemptCount = Math.max(
+    1,
+    Math.floor(options.attemptCount ?? MODEL_DOWNLOAD_ATTEMPTS),
+  )
+  let lastError: unknown
+
+  for (let attempt = 1; attempt <= attemptCount; attempt++) {
+    try {
+      const response = await fetcher(DEFAULT_MODEL_URL, {
+        signal: AbortSignal.timeout(MODEL_DOWNLOAD_TIMEOUT_MS),
+      })
+      if (!response.ok) {
+        throw new ModelDownloadHttpError(
+          response.status,
+          response.statusText,
+          response.headers.get('retry-after'),
+        )
+      }
+      return await response.arrayBuffer()
+    } catch (error) {
+      lastError = error
+      if (!shouldRetryModelDownload(error) || attempt === attemptCount) {
+        break
+      }
+      await sleep(getModelDownloadRetryDelayMs(error, attempt))
+    }
+  }
+
+  if (lastError instanceof Error) {
+    throw lastError
+  }
+  throw new Error('Failed to download whisper.cpp model.')
+}
+
+function shouldRetryModelDownload(error: unknown) {
+  if (!(error instanceof ModelDownloadHttpError)) {
+    return true
+  }
+  return (
+    error.status === 408 ||
+    error.status === 425 ||
+    error.status === 429 ||
+    error.status >= 500
+  )
+}
+
+function getModelDownloadRetryDelayMs(error: unknown, attempt: number) {
+  if (error instanceof ModelDownloadHttpError) {
+    const retryAfterDelayMs = parseRetryAfterDelayMs(error.retryAfter)
+    if (retryAfterDelayMs !== null) {
+      return retryAfterDelayMs
+    }
+  }
+  return Math.min(
+    MODEL_DOWNLOAD_INITIAL_RETRY_DELAY_MS * 2 ** (attempt - 1),
+    MODEL_DOWNLOAD_MAX_RETRY_DELAY_MS,
+  )
+}
+
+function parseRetryAfterDelayMs(retryAfter: string | null) {
+  if (!retryAfter) {
+    return null
+  }
+
+  const seconds = Number(retryAfter)
+  if (Number.isFinite(seconds) && seconds >= 0) {
+    return Math.min(seconds * 1000, MODEL_DOWNLOAD_MAX_RETRY_DELAY_MS)
+  }
+
+  const retryAt = Date.parse(retryAfter)
+  if (Number.isNaN(retryAt)) {
+    return null
+  }
+  return Math.min(
+    Math.max(0, retryAt - Date.now()),
+    MODEL_DOWNLOAD_MAX_RETRY_DELAY_MS,
+  )
+}
+
 async function readTranscriptText(transcriptPath: string, fallback: string) {
   const transcriptFile = Bun.file(transcriptPath)
   if (await transcriptFile.exists()) {