Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions src/whispercpp-transcribe.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import { test, expect } from 'bun:test'
import { mkdtemp, rm } from 'node:fs/promises'
import { tmpdir } from 'node:os'
import path from 'node:path'
import { downloadWhisperModelFile } from './whispercpp-transcribe'

async function createTempDir() {
const dir = await mkdtemp(path.join(tmpdir(), 'whisper-model-'))
return {
path: dir,
async [Symbol.asyncDispose]() {
await rm(dir, { recursive: true, force: true })
},
}
}

test('downloadWhisperModelFile retries rate-limited model downloads', async () => {
await using tempDir = await createTempDir()
const modelPath = path.join(tempDir.path, 'model.bin')
const delays: number[] = []
let calls = 0
const fetchModel = async () => {
calls++
if (calls === 1) {
return new Response('rate limited', {
status: 429,
statusText: 'Too Many Requests',
headers: { 'retry-after': '1' },
})
}
return new Response('model bytes')
}

await downloadWhisperModelFile(modelPath, {
attemptCount: 2,
fetch: fetchModel,
sleep: async (delayMs) => {
delays.push(delayMs)
},
})

expect(calls).toBe(2)
expect(delays).toEqual([1000])
expect(await Bun.file(modelPath).text()).toBe('model bytes')
})

test('downloadWhisperModelFile retries interrupted model downloads', async () => {
await using tempDir = await createTempDir()
const modelPath = path.join(tempDir.path, 'model.bin')
const delays: number[] = []
let calls = 0
const fetchModel = async () => {
calls++
if (calls === 1) {
const response = new Response('partial model bytes')
response.arrayBuffer = async () => {
throw new TypeError('body interrupted')
}
return response
}
return new Response('model bytes')
}

await downloadWhisperModelFile(modelPath, {
attemptCount: 2,
fetch: fetchModel,
sleep: async (delayMs) => {
delays.push(delayMs)
},
})

expect(calls).toBe(2)
expect(delays).toEqual([5000])
expect(await Bun.file(modelPath).text()).toBe('model bytes')
})

test('downloadWhisperModelFile does not retry permanent download failures', async () => {
await using tempDir = await createTempDir()
const modelPath = path.join(tempDir.path, 'model.bin')
let calls = 0
const fetchModel = async () => {
calls++
return new Response('missing', {
status: 404,
statusText: 'Not Found',
})
}

await expect(
downloadWhisperModelFile(modelPath, {
attemptCount: 3,
fetch: fetchModel,
sleep: async () => {},
}),
).rejects.toThrow('404 Not Found')

expect(calls).toBe(1)
expect(await Bun.file(modelPath).exists()).toBe(false)
})
121 changes: 113 additions & 8 deletions src/whispercpp-transcribe.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ const DEFAULT_MODEL_URL =
'https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin'
const DEFAULT_LANGUAGE = 'en'
const DEFAULT_BINARY = 'whisper-cli'
const MODEL_DOWNLOAD_ATTEMPTS = 8
const MODEL_DOWNLOAD_TIMEOUT_MS = 120_000
const MODEL_DOWNLOAD_INITIAL_RETRY_DELAY_MS = 5_000
const MODEL_DOWNLOAD_MAX_RETRY_DELAY_MS = 60_000

type TranscribeOptions = {
modelPath?: string
Expand All @@ -18,6 +22,24 @@ type TranscribeOptions = {
progress?: StepProgressReporter
}

type ModelDownloadOptions = {
progress?: StepProgressReporter
fetch?: (url: string, init?: RequestInit) => Promise<Response>
sleep?: (delayMs: number) => Promise<void>
attemptCount?: number
}

class ModelDownloadHttpError extends Error {
constructor(
readonly status: number,
readonly statusText: string,
readonly retryAfter: string | null,
) {
super(`Failed to download whisper.cpp model (${status} ${statusText}).`)
this.name = 'ModelDownloadHttpError'
}
}

export type TranscriptSegment = {
start: number
end: number
Expand Down Expand Up @@ -101,19 +123,102 @@ async function ensureModelFile(
throw new Error(`Whisper model not found at ${modelPath}.`)
}

progress?.setLabel('Downloading model')
await downloadWhisperModelFile(modelPath, { progress })
}

export async function downloadWhisperModelFile(
modelPath: string,
options: ModelDownloadOptions = {},
) {
options.progress?.setLabel('Downloading model')
await mkdir(path.dirname(modelPath), { recursive: true })
const response = await fetch(DEFAULT_MODEL_URL)
if (!response.ok) {
throw new Error(
`Failed to download whisper.cpp model (${response.status} ${response.statusText}).`,
)
}

const bytes = await response.arrayBuffer()
const bytes = await fetchModelBytes(options)
await Bun.write(modelPath, bytes)
}

async function fetchModelBytes(options: ModelDownloadOptions) {
const fetcher = options.fetch ?? fetch
const sleep = options.sleep ?? Bun.sleep
const attemptCount = Math.max(
1,
Math.floor(options.attemptCount ?? MODEL_DOWNLOAD_ATTEMPTS),
)
let lastError: unknown

for (let attempt = 1; attempt <= attemptCount; attempt++) {
try {
const response = await fetcher(DEFAULT_MODEL_URL, {
signal: AbortSignal.timeout(MODEL_DOWNLOAD_TIMEOUT_MS),
})
if (!response.ok) {
throw new ModelDownloadHttpError(
response.status,
response.statusText,
response.headers.get('retry-after'),
)
}
return await response.arrayBuffer()
} catch (error) {
lastError = error
if (!shouldRetryModelDownload(error) || attempt === attemptCount) {
break
}
await sleep(getModelDownloadRetryDelayMs(error, attempt))
}
}

if (lastError instanceof Error) {
throw lastError
}
throw new Error('Failed to download whisper.cpp model.')
}

function shouldRetryModelDownload(error: unknown) {
if (!(error instanceof ModelDownloadHttpError)) {
return true
}
return (
error.status === 408 ||
error.status === 425 ||
error.status === 429 ||
error.status >= 500
)
}

function getModelDownloadRetryDelayMs(error: unknown, attempt: number) {
if (error instanceof ModelDownloadHttpError) {
const retryAfterDelayMs = parseRetryAfterDelayMs(error.retryAfter)
if (retryAfterDelayMs !== null) {
return retryAfterDelayMs
}
}
return Math.min(
MODEL_DOWNLOAD_INITIAL_RETRY_DELAY_MS * 2 ** (attempt - 1),
MODEL_DOWNLOAD_MAX_RETRY_DELAY_MS,
)
}

function parseRetryAfterDelayMs(retryAfter: string | null) {
if (!retryAfter) {
return null
}

const seconds = Number(retryAfter)
if (Number.isFinite(seconds) && seconds >= 0) {
return Math.min(seconds * 1000, MODEL_DOWNLOAD_MAX_RETRY_DELAY_MS)
}

const retryAt = Date.parse(retryAfter)
if (Number.isNaN(retryAt)) {
return null
}
return Math.min(
Math.max(0, retryAt - Date.now()),
MODEL_DOWNLOAD_MAX_RETRY_DELAY_MS,
)
}

async function readTranscriptText(transcriptPath: string, fallback: string) {
const transcriptFile = Bun.file(transcriptPath)
if (await transcriptFile.exists()) {
Expand Down
Loading