diff --git a/scripts/README.md b/scripts/README.md
new file mode 100644
index 0000000..4da3d99
--- /dev/null
+++ b/scripts/README.md
@@ -0,0 +1,20 @@
+Fetch ONNX Runtime Web dist assets
+
+This folder contains a small helper script that downloads the ESM/UMD builds and
+WASM assets for `onnxruntime-web` into the project's `./ort/` folder, so the app
+can load the WASM locally and avoid cross-origin or CDN-transform issues.
+
+Usage
+
+1. Run the script (requires `npm` and `tar` on the PATH):
+
+```bash
+./scripts/fetch-onnx-dist.sh 1.18.0
+```
+
+2. Serve the repository (or your static files) so `/ort/` is reachable from the
+app root. The code already defaults to `DEFAULT_ORT_WASM_PATH = '/ort/'` in
+`src/stt/config.js`.
+
+3. Optionally set `window.ORT_WASM_PATH = '/ort/'` and `window.SILERO_VAD_MODEL = '/models/silero_v5_16k.onnx'` in
+`index.html` to be explicit during development.
diff --git a/scripts/fetch-onnx-dist.sh b/scripts/fetch-onnx-dist.sh
new file mode 100644
index 0000000..b8d39a3
--- /dev/null
+++ b/scripts/fetch-onnx-dist.sh
@@ -0,0 +1,45 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Fetch onnxruntime-web dist/ assets and place them in the project ./ort/ folder
+# Usage: ./scripts/fetch-onnx-dist.sh [version]
+# Example: ./scripts/fetch-onnx-dist.sh 1.18.0
+
+VERSION=${1:-1.18.0}
+ROOT_DIR=$(cd "$(dirname "$0")/.." && pwd)
+TMPDIR=$(mktemp -d)
+# Clean the temp dir up on any exit, including early failures under `set -e`.
+trap 'rm -rf "$TMPDIR"' EXIT
+
+echo "Fetching onnxruntime-web@$VERSION into $ROOT_DIR/ort"
+cd "$TMPDIR"
+
+# npm pack downloads the tarball and prints its filename on stdout.
+echo "Downloading npm package..."
+TARBALL=$(npm pack "onnxruntime-web@$VERSION")
+if [ -z "$TARBALL" ] || [ ! -f "$TARBALL" ]; then
+  echo "Failed to download onnxruntime-web@$VERSION"
+  exit 1
+fi
+
+# Extract the tarball (npm tarballs unpack into ./package)
+tar -xzf "$TARBALL"
+
+# Ensure dist exists
+if [ ! -d package/dist ]; then
+  echo "package/dist not found inside tarball. Listing package/ contents:"
+  ls -la package
+  exit 1
+fi
+
+# Copy dist files into repo ./ort/
+DEST_DIR="$ROOT_DIR/ort"
+# Remove old files and recreate the directory
+rm -rf "$DEST_DIR"
+mkdir -p "$DEST_DIR"
+cp -r package/dist/* "$DEST_DIR/"
+
+echo "Copied $(ls -1 "$DEST_DIR" | wc -l) files to $DEST_DIR"
+
+echo "Done. To serve the ONNX WASM assets locally, ensure your static server serves the ./ort/ folder at '/ort/'."
+echo "You can also set window.ORT_WASM_PATH = '/ort/' before loading the app to be explicit."
\ No newline at end of file diff --git a/src/openai.js b/src/openai.js index 74f3a47..58f0870 100644 --- a/src/openai.js +++ b/src/openai.js @@ -152,7 +152,7 @@ export async function summarizeText({ lowQuality = '', highQuality = '' }) { return response.choices?.[0]?.message?.content || ''; } -export async function transcribeFile({ file, language }) { +export async function transcribeFile({ file, language, prompt }) { if (!file) throw new Error('File is required for transcription'); const model = resolveTranscriptionModel(); @@ -161,6 +161,7 @@ export async function transcribeFile({ file, language }) { file, model, language, + prompt, }); return response.text; } catch (error) { diff --git a/src/openai.test.js b/src/openai.test.js index 94bf3b1..e7ab9c5 100644 --- a/src/openai.test.js +++ b/src/openai.test.js @@ -49,11 +49,14 @@ describe('openai helpers', () => { const result = await transcribeFile({ file, language: 'en' }); expect(result).toBe('hello world'); - expect(mockClient.audio.transcriptions.create).toHaveBeenCalledWith({ - file, - model: 'gpt-4o-transcribe', - language: 'en', - }); + expect(mockClient.audio.transcriptions.create).toHaveBeenCalledWith( + expect.objectContaining({ + file, + model: 'gpt-4o-transcribe', + language: 'en', + prompt: undefined, + }) + ); }); test('transcribeFile honors TRANSCRIBE_MODEL override', async () => { @@ -65,11 +68,14 @@ describe('openai helpers', () => { const file = { name: 'audio.wav' }; await transcribeFile({ file, language: 'de' }); - expect(mockClient.audio.transcriptions.create).toHaveBeenCalledWith({ - file, - model: 'custom-model', - language: 'de', - }); + expect(mockClient.audio.transcriptions.create).toHaveBeenCalledWith( + expect.objectContaining({ + file, + model: 'custom-model', + language: 'de', + prompt: undefined, + }) + ); }); test('transcribeFile logs model and status information on failure', async () => { diff --git a/src/stt/audio.js b/src/stt/audio.js new file mode 100644 index 0000000..86dc311 --- /dev/null +++ b/src/stt/audio.js @@ -0,0 +1,161 @@ +import { STT_CONFIG } from './config.js'; + +/* istanbul ignore next -- depends on browser-specific audio globals */ +function hasWebAudio() { + return ( + typeof window !== 'undefined' && + (window.AudioContext || window.webkitAudioContext) && + window.OfflineAudioContext + ); +} + +/* istanbul ignore next -- depends on browser-specific audio globals */ +function getAudioContext() { + const Ctor = window.AudioContext || window.webkitAudioContext; + return new Ctor(); +} + +/* istanbul ignore next -- exercised via decodeToMono16k in browser */ +function mixToMono(audioBuffer) { + const { numberOfChannels } = audioBuffer; + if (numberOfChannels === 1) { + return audioBuffer.getChannelData(0).slice(); + } + + const length = audioBuffer.length; + const output = new Float32Array(length); + + for (let channel = 0; channel < numberOfChannels; channel += 1) { + const data = audioBuffer.getChannelData(channel); + for (let i = 0; i < length; i += 1) { + output[i] += data[i] / numberOfChannels; + } + } + + return output; +} + +/* istanbul ignore next -- exercised via decodeToMono16k in browser */ +async function resampleMonoBuffer(mono, sourceRate, targetRate) { + if (sourceRate === targetRate) { + return mono; + } + + const length = Math.ceil((mono.length * targetRate) / sourceRate); + const offline = new OfflineAudioContext(1, length, targetRate); + const buffer = offline.createBuffer(1, mono.length, sourceRate); + buffer.copyToChannel(mono, 0); + + const source = 
offline.createBufferSource(); + source.buffer = buffer; + source.connect(offline.destination); + source.start(0); + + const rendered = await offline.startRendering(); + return rendered.getChannelData(0).slice(); +} + +/* istanbul ignore next -- browser-only decode path */ +export async function decodeToMono16k(file) { + if (!hasWebAudio()) { + throw new Error('Web Audio API not available'); + } + + const arrayBuffer = await file.arrayBuffer(); + const ctx = getAudioContext(); + + try { + const decoded = await ctx.decodeAudioData(arrayBuffer); + const mono = mixToMono(decoded); + const resampled = await resampleMonoBuffer( + mono, + decoded.sampleRate, + STT_CONFIG.sampleRate + ); + + return { + pcm: resampled, + sampleRate: STT_CONFIG.sampleRate, + durationMs: Math.round((resampled.length / STT_CONFIG.sampleRate) * 1000), + }; + } finally { + ctx.close?.(); + } +} + +export function clampMs(value, min, max) { + return Math.min(Math.max(value, min), max); +} + +export function estimateChunkBytes(durationMs) { + const seconds = durationMs / 1000; + const bytesPerSecond = STT_CONFIG.sampleRate * STT_CONFIG.wavBytesPerSample; + return Math.ceil(seconds * bytesPerSecond) + 44; // WAV header overhead +} + +function floatTo16BitPCM(float32Array) { + const buffer = new ArrayBuffer(float32Array.length * 2); + const view = new DataView(buffer); + + for (let i = 0; i < float32Array.length; i += 1) { + let sample = Math.max(-1, Math.min(1, float32Array[i])); + sample = sample < 0 ? sample * 0x8000 : sample * 0x7fff; + view.setInt16(i * 2, sample, true); + } + + return new Uint8Array(buffer); +} + +function writeWavHeader(dataLength) { + const buffer = new ArrayBuffer(44); + const view = new DataView(buffer); + const byteRate = STT_CONFIG.sampleRate * STT_CONFIG.wavBytesPerSample; + const blockAlign = STT_CONFIG.wavBytesPerSample; + + view.setUint32(0, 0x52494646, false); // 'RIFF' + view.setUint32(4, 36 + dataLength, true); + view.setUint32(8, 0x57415645, false); // 'WAVE' + view.setUint32(12, 0x666d7420, false); // 'fmt ' + view.setUint32(16, 16, true); // Subchunk1Size + view.setUint16(20, 1, true); // PCM + view.setUint16(22, 1, true); // Mono + view.setUint32(24, STT_CONFIG.sampleRate, true); + view.setUint32(28, byteRate, true); + view.setUint16(32, blockAlign, true); + view.setUint16(34, 16, true); // bits per sample + view.setUint32(36, 0x64617461, false); // 'data' + view.setUint32(40, dataLength, true); + + return new Uint8Array(buffer); +} + +export function encodeWavChunk(pcm, startMs, endMs) { + const totalSamples = pcm.length; + const sampleRate = STT_CONFIG.sampleRate; + const startIndex = Math.max(0, Math.floor((startMs / 1000) * sampleRate)); + const endIndex = Math.min( + totalSamples, + Math.ceil((endMs / 1000) * sampleRate) + ); + + if (endIndex <= startIndex) { + return null; + } + + const slice = pcm.slice(startIndex, endIndex); + const pcm16 = floatTo16BitPCM(slice); + const header = writeWavHeader(pcm16.length); + const blob = new Blob([header, pcm16], { type: 'audio/wav' }); + return { + blob, + durationMs: Math.round(((endIndex - startIndex) / sampleRate) * 1000), + }; +} + +export function samplesToMs(samples) { + return Math.round((samples / STT_CONFIG.sampleRate) * 1000); +} + +export function msToSamples(ms) { + return Math.round((ms / 1000) * STT_CONFIG.sampleRate); +} diff --git a/src/stt/audio.test.js b/src/stt/audio.test.js new file mode 100644 index 0000000..77f1b3b --- /dev/null +++ b/src/stt/audio.test.js @@ -0,0 +1,38 @@ +import { + clampMs, + encodeWavChunk, + 
estimateChunkBytes, + msToSamples, + samplesToMs, +} from './audio.js'; + +describe('audio helpers', () => { + test('estimateChunkBytes uses pcm byte rate', () => { + const estimate = estimateChunkBytes(1000); + expect(estimate).toBeGreaterThan(32000); + }); + + test('encodeWavChunk converts float32 pcm to wav blob', () => { + const pcm = new Float32Array(16_000); + for (let i = 0; i < pcm.length; i += 1) { + pcm[i] = Math.sin((i / pcm.length) * Math.PI * 2); + } + const result = encodeWavChunk(pcm, 0, 1000); + expect(result).not.toBeNull(); + expect(result.durationMs).toBeGreaterThanOrEqual(900); + expect(result.blob.type).toBe('audio/wav'); + expect(result.blob.size).toBeGreaterThan(0); + }); + + test('clampMs enforces bounds', () => { + expect(clampMs(50, 100, 200)).toBe(100); + expect(clampMs(250, 100, 200)).toBe(200); + expect(clampMs(150, 100, 200)).toBe(150); + }); + + test('sample conversions are consistent', () => { + const samples = msToSamples(1000); + const ms = samplesToMs(samples); + expect(ms).toBeCloseTo(1000, 0); + }); +}); diff --git a/src/stt/chunking.js b/src/stt/chunking.js new file mode 100644 index 0000000..9d78557 --- /dev/null +++ b/src/stt/chunking.js @@ -0,0 +1,155 @@ +import { STT_CONFIG } from './config.js'; +import { estimateChunkBytes } from './audio.js'; + +function normalizeSegment(segment, durationMs) { + const startMs = Math.max( + 0, + Math.min(durationMs, Math.floor(segment.startMs)) + ); + const endMs = Math.max( + startMs, + Math.min(durationMs, Math.ceil(segment.endMs)) + ); + return { startMs, endMs }; +} + +export function normalizeSegments(segments, durationMs) { + const normalized = segments + .map((segment) => normalizeSegment(segment, durationMs)) + .filter((segment) => segment.endMs > segment.startMs); + + if (!normalized.length) return []; + + normalized.sort((a, b) => a.startMs - b.startMs); + const merged = [normalized[0]]; + + for (let i = 1; i < normalized.length; i += 1) { + const current = normalized[i]; + const prev = merged[merged.length - 1]; + + if (current.startMs <= prev.endMs) { + prev.endMs = Math.max(prev.endMs, current.endMs); + } else { + merged.push(current); + } + } + + return merged; +} + +function shouldFinalizeChunk(chunk, segmentEndMs, estimateBytes) { + const duration = segmentEndMs - chunk.startMs; + if (duration <= 0) return false; + if (duration >= STT_CONFIG.maxChunkMs) return true; + if (estimateBytes(duration) >= STT_CONFIG.maxChunkBytes) return true; + return false; +} + +export function packSegmentsIntoChunks(segments, durationMs) { + const normalized = normalizeSegments(segments, durationMs); + const chunks = []; + const estimateBytes = (duration) => estimateChunkBytes(duration); + + if (!normalized.length) { + const safeMax = Math.min( + STT_CONFIG.maxChunkMs, + Math.floor( + (STT_CONFIG.maxChunkBytes / + (STT_CONFIG.sampleRate * STT_CONFIG.wavBytesPerSample)) * + 1000 + ) + ); + const chunkDuration = Math.max(60_000, safeMax); + + for (let start = 0; start < durationMs; start += chunkDuration) { + const end = Math.min(durationMs, start + chunkDuration); + chunks.push({ startMs: start, endMs: end }); + } + return chunks; + } + + let current = { + startMs: normalized[0].startMs, + endMs: normalized[0].endMs, + }; + + for (let i = 1; i < normalized.length; i += 1) { + const segment = normalized[i]; + const prospectiveEnd = Math.max(current.endMs, segment.endMs); + const finalize = shouldFinalizeChunk( + current, + prospectiveEnd, + estimateBytes + ); + + if (finalize) { + chunks.push({ ...current }); + current 
= { + startMs: Math.max(segment.startMs, current.endMs), + endMs: segment.endMs, + }; + } else { + current.endMs = prospectiveEnd; + } + } + + chunks.push({ ...current }); + return chunks; +} + +export function applyChunkOverlaps(chunks, durationMs) { + if (!chunks.length) return []; + + return chunks.map((chunk, index) => { + const startOverlap = index === 0 ? 0 : STT_CONFIG.chunkOverlapMs; + const endOverlap = + index === chunks.length - 1 ? 0 : STT_CONFIG.chunkOverlapMs; + + return { + ...chunk, + renderStartMs: Math.max(0, chunk.startMs - startOverlap), + renderEndMs: Math.min(durationMs, chunk.endMs + endOverlap), + index, + }; + }); +} + +export function buildFallbackChunks(durationMs) { + const estimateBytes = (duration) => estimateChunkBytes(duration); + const safeDuration = Math.min( + STT_CONFIG.maxChunkMs, + Math.floor( + (STT_CONFIG.maxChunkBytes / + (STT_CONFIG.sampleRate * STT_CONFIG.wavBytesPerSample)) * + 1000 + ) + ); + const chunkDuration = Math.max(5 * 60_000, safeDuration); + const chunks = []; + + for (let start = 0; start < durationMs; start += chunkDuration) { + const end = Math.min(durationMs, start + chunkDuration); + const duration = end - start; + if (duration <= 0) continue; + if (estimateBytes(duration) > STT_CONFIG.maxChunkBytes) { + const maxDuration = Math.floor( + (STT_CONFIG.maxChunkBytes / + (STT_CONFIG.sampleRate * STT_CONFIG.wavBytesPerSample)) * + 1000 + ); + const midpoint = start + Math.floor(maxDuration / 2); + chunks.push({ startMs: start, endMs: midpoint }); + chunks.push({ startMs: midpoint, endMs: end }); + } else { + chunks.push({ startMs: start, endMs: end }); + } + } + + return applyChunkOverlaps(chunks, durationMs); +} + +export function planChunks({ segments, durationMs }) { + const packed = packSegmentsIntoChunks(segments, durationMs); + const withOverlap = applyChunkOverlaps(packed, durationMs); + return withOverlap.length ? 
withOverlap : buildFallbackChunks(durationMs); +} diff --git a/src/stt/chunking.test.js b/src/stt/chunking.test.js new file mode 100644 index 0000000..2606150 --- /dev/null +++ b/src/stt/chunking.test.js @@ -0,0 +1,93 @@ +import { planChunks, buildFallbackChunks } from './chunking.js'; +import { STT_CONFIG } from './config.js'; +import { mergeChunkResults, buildPromptFromTail } from './merge.js'; + +describe('chunk planning utilities', () => { + test('planChunks splits segments exceeding max duration', () => { + const segments = [ + { startMs: 0, endMs: 600_000 }, + { startMs: 610_000, endMs: 1_300_000 }, + ]; + const chunks = planChunks({ segments, durationMs: 1_400_000 }); + expect(chunks).toHaveLength(2); + expect(chunks[0].renderStartMs).toBe(0); + expect(chunks[1].renderStartMs).toBeGreaterThan(chunks[0].renderStartMs); + expect(chunks[0].renderEndMs - chunks[0].renderStartMs).toBeLessThanOrEqual( + 1_200_000 + 500 + ); + }); + + test('planChunks falls back when no speech detected', () => { + const chunks = planChunks({ segments: [], durationMs: 900_000 }); + expect(chunks.length).toBeGreaterThan(0); + chunks.forEach((chunk, index) => { + expect(chunk.index).toBe(index); + expect(chunk.renderEndMs).toBeGreaterThan(chunk.renderStartMs); + }); + }); + + test('buildFallbackChunks total rendered duration is reasonable', () => { + const durationMs = 10 * 60 * 1000; // 10 minutes + const chunks = buildFallbackChunks(durationMs); + const totalRendered = chunks.reduce( + (sum, chunk) => sum + (chunk.renderEndMs - chunk.renderStartMs), + 0 + ); + // For fallback chunks with overlap, total should be close to input duration + // Allow up to 10% extra for overlaps + expect(totalRendered).toBeLessThanOrEqual(durationMs * 1.1); + expect(totalRendered).toBeGreaterThanOrEqual(durationMs); + }); + + test('planChunks respects size thresholds during packing', () => { + const originalBytes = STT_CONFIG.maxChunkBytes; + STT_CONFIG.maxChunkBytes = 32_000; + const segments = [ + { startMs: 0, endMs: 40_000 }, + { startMs: 45_000, endMs: 80_000 }, + ]; + const chunks = planChunks({ segments, durationMs: 90_000 }); + expect(chunks.length).toBeGreaterThan(1); + STT_CONFIG.maxChunkBytes = originalBytes; + }); + + test('planChunks keeps single chunk when under limits', () => { + const chunks = planChunks({ + segments: [ + { startMs: 0, endMs: 10_000 }, + { startMs: 12_000, endMs: 18_000 }, + ], + durationMs: 20_000, + }); + expect(chunks).toHaveLength(1); + }); + + test('planChunks merges overlapping segments', () => { + const chunks = planChunks({ + segments: [ + { startMs: 0, endMs: 10_000 }, + { startMs: 9_000, endMs: 15_000 }, + ], + durationMs: 20_000, + }); + expect(chunks).toHaveLength(1); + expect(chunks[0].renderEndMs).toBeGreaterThan(chunks[0].renderStartMs); + }); +}); + +describe('merge helpers', () => { + test('buildPromptFromTail trims tail characters', () => { + const prompt = buildPromptFromTail(' Example transcript text '); + expect(prompt.endsWith('text')).toBe(true); + }); + + test('mergeChunkResults removes duplicate sentences', () => { + const merged = mergeChunkResults([ + { index: 0, text: 'Hello world. This is chunk one.' }, + { index: 1, text: 'This is chunk one. And here is more.' 
},
+    ]);
+    expect(merged).toContain('Hello world.');
+    expect(merged).toContain('And here is more.');
+    expect(merged).not.toContain('This is chunk one.\nThis is chunk one.');
+  });
+});
diff --git a/src/stt/config.js b/src/stt/config.js
new file mode 100644
index 0000000..e5eadf1
--- /dev/null
+++ b/src/stt/config.js
@@ -0,0 +1,27 @@
+export const STT_CONFIG = {
+  sampleRate: 16000, // the pipeline works on 16 kHz mono PCM
+  windowSamples: 512, // Silero VAD window at 16 kHz (~32 ms per frame)
+  threshold: 0.5, // speech probability cutoff
+  minSpeechMs: 250,
+  minSilenceMs: 100,
+  speechPadMs: 200,
+  maxSpeechMs: 15 * 60 * 1000, // force a split for any speech run over 15 min
+  chunkOverlapMs: 500,
+  maxChunkMs: 1200 * 1000, // 20 min upper bound per uploaded chunk
+  maxChunkBytes: 24 * 1024 * 1024, // kept just under the 25 MB upload limit
+  uploadConcurrency: 3,
+  wavBytesPerSample: 2, // 16-bit PCM
+  promptTailChars: 200, // tail of previous text reused as the next chunk's prompt
+};
+
+// Prefer a locally hosted model via window.SILERO_VAD_MODEL (see
+// scripts/README.md); the URL below is only a fallback.
+export const DEFAULT_SILERO_MODEL_URL =
+  typeof window !== 'undefined' && window.SILERO_VAD_MODEL
+    ? window.SILERO_VAD_MODEL
+    : 'https://github.com/snakers4/silero-models/raw/master/models/silero_vad/en/silero_vad.onnx';
+
+export const DEFAULT_ORT_WASM_PATH =
+  typeof window !== 'undefined' && window.ORT_WASM_PATH
+    ? window.ORT_WASM_PATH
+    : '/ort/';
diff --git a/src/stt/merge.js b/src/stt/merge.js
new file mode 100644
index 0000000..1b5084e
--- /dev/null
+++ b/src/stt/merge.js
@@ -0,0 +1,78 @@
+import { STT_CONFIG } from './config.js';
+
+function tokenize(text) {
+  return text
+    .toLowerCase()
+    .replace(/[^\p{L}\p{N}\s]+/gu, ' ')
+    .split(/\s+/)
+    .filter(Boolean);
+}
+
+function cosineSimilarity(aTokens, bTokens) {
+  if (!aTokens.length || !bTokens.length) return 0;
+  const freqA = new Map();
+  const freqB = new Map();
+
+  for (const token of aTokens) {
+    freqA.set(token, (freqA.get(token) || 0) + 1);
+  }
+  for (const token of bTokens) {
+    freqB.set(token, (freqB.get(token) || 0) + 1);
+  }
+
+  let dot = 0;
+  for (const [token, countA] of freqA.entries()) {
+    const countB = freqB.get(token) || 0;
+    dot += countA * countB;
+  }
+
+  const norm = (freq) =>
+    Math.sqrt([...freq.values()].reduce((sum, c) => sum + c * c, 0));
+  const denom = norm(freqA) * norm(freqB);
+  return denom === 0 ? 0 : dot / denom;
+}
+
+function removeDuplicateSentence(previousText, currentText) {
+  if (!previousText || !currentText) return currentText;
+
+  const sentences = currentText.split(/(?<=[.!?])\s+/);
+  if (sentences.length === 0) {
+    return currentText;
+  }
+
+  const firstSentence = sentences[0];
+  const prevTail = previousText.slice(
+    -Math.max(firstSentence.length + 20, 200)
+  );
+  const normalizedTail = prevTail.toLowerCase();
+  const normalizedSentence = firstSentence.toLowerCase();
+  const similarity = cosineSimilarity(
+    tokenize(prevTail),
+    tokenize(firstSentence)
+  );
+
+  if (normalizedTail.includes(normalizedSentence) || similarity >= 0.75) {
+    return currentText.slice(firstSentence.length).trimStart();
+  }
+
+  return currentText;
+}
+
+export function mergeChunkResults(chunks) {
+  const ordered = [...chunks].sort((a, b) => a.index - b.index);
+  let merged = '';
+
+  for (const chunk of ordered) {
+    const cleanText = removeDuplicateSentence(merged, chunk.text || '');
+    merged = merged ?
`${merged}\n${cleanText}` : cleanText; + } + + return merged.trim(); +} + +export function buildPromptFromTail(text) { + if (!text) return ''; + const trimmed = text.trim(); + if (!trimmed) return ''; + return trimmed.slice(-STT_CONFIG.promptTailChars); +} diff --git a/src/stt/transcriber.js b/src/stt/transcriber.js new file mode 100644 index 0000000..7bb3426 --- /dev/null +++ b/src/stt/transcriber.js @@ -0,0 +1,138 @@ +import { decodeToMono16k, encodeWavChunk } from './audio.js'; +import { detectSpeechSegments } from './vad.js'; +import { planChunks, buildFallbackChunks } from './chunking.js'; +import { buildPromptFromTail, mergeChunkResults } from './merge.js'; +import { STT_CONFIG } from './config.js'; +import { transcribeFile } from '../openai.js'; + +function createLimiter(concurrency) { + let active = 0; + const queue = []; + + const next = () => { + if (active >= concurrency) return; + const task = queue.shift(); + if (!task) return; + active += 1; + Promise.resolve() + .then(task.fn) + .then(task.resolve, task.reject) + .finally(() => { + active -= 1; + next(); + }); + }; + + return (fn) => + new Promise((resolve, reject) => { + queue.push({ fn, resolve, reject }); + next(); + }); +} + +function buildChunkFileName(baseName, index) { + const padded = String(index + 1).padStart(3, '0'); + return `${baseName}-chunk-${padded}.wav`; +} + +function createChunkFiles(pcm, chunks, baseName) { + const files = []; + + for (const chunk of chunks) { + const encoded = encodeWavChunk(pcm, chunk.renderStartMs, chunk.renderEndMs); + if (!encoded) continue; + const fileName = buildChunkFileName(baseName, chunk.index); + const file = new File([encoded.blob], fileName, { type: 'audio/wav' }); + files.push({ ...chunk, file, durationMs: encoded.durationMs }); + } + + return files; +} + +function makeFileFromSlice({ file, start, end, index }) { + const slice = file.slice(start, end); + const padded = String(index + 1).padStart(3, '0'); + const originalName = file.name || 'audio'; + const suffix = + file.type && !file.type.includes('wav') && !originalName.endsWith('.wav') + ? '.bin' + : ''; + const name = `${originalName}-fallback-${padded}${suffix}`; + return new File([slice], name, { + type: file.type || 'application/octet-stream', + }); +} + +async function fallbackByteChunking({ file, language }) { + const maxBytes = STT_CONFIG.maxChunkBytes; + const chunks = Math.ceil(file.size / maxBytes); + const results = []; + let accumulated = ''; + + for (let index = 0; index < chunks; index += 1) { + const start = index * maxBytes; + const end = Math.min(file.size, start + maxBytes); + const chunkFile = makeFileFromSlice({ file, start, end, index }); + const prompt = buildPromptFromTail(accumulated); + const text = await transcribeFile({ file: chunkFile, language, prompt }); + results.push({ index, text }); + accumulated = accumulated ? 
`${accumulated}\n${text}` : text; + } + + return mergeChunkResults(results); +} + +export async function chunkedTranscription({ file, language }) { + const baseName = (file?.name || 'audio').replace(/\.[^/.]+$/, ''); + let pcmInfo; + let chunks = []; + + try { + pcmInfo = await decodeToMono16k(file); + } catch (error) { + console.warn( + 'Falling back to byte-based chunking due to decode failure', + error + ); + return fallbackByteChunking({ file, language }); + } + + const { pcm, durationMs } = pcmInfo; + + try { + const vadSegments = await detectSpeechSegments(pcm); + chunks = planChunks({ segments: vadSegments, durationMs }); + } catch (error) { + console.warn('VAD segmentation failed, using fallback chunking', error); + chunks = buildFallbackChunks(durationMs); + } + + const chunkFiles = createChunkFiles(pcm, chunks, baseName); + if (!chunkFiles.length) { + throw new Error('Failed to prepare audio chunks for transcription'); + } + + const limit = createLimiter(STT_CONFIG.uploadConcurrency); + const results = []; + let accumulatedText = ''; + + for (const chunk of chunkFiles) { + const prompt = buildPromptFromTail(accumulatedText); + const task = limit(async () => { + const text = await transcribeFile({ + file: chunk.file, + language, + prompt, + }); + return { index: chunk.index, text }; + }); + + const result = await task; + results.push(result); + accumulatedText = accumulatedText + ? `${accumulatedText}\n${result.text}` + : result.text; + } + + return mergeChunkResults(results); +} diff --git a/src/stt/transcriber.test.js b/src/stt/transcriber.test.js new file mode 100644 index 0000000..2060828 --- /dev/null +++ b/src/stt/transcriber.test.js @@ -0,0 +1,165 @@ +import { chunkedTranscription } from './transcriber.js'; +import { STT_CONFIG } from './config.js'; +import { decodeToMono16k, encodeWavChunk } from './audio.js'; +import { detectSpeechSegments } from './vad.js'; +import { planChunks, buildFallbackChunks } from './chunking.js'; +import { transcribeFile } from '../openai.js'; + +jest.mock('./audio.js', () => ({ + decodeToMono16k: jest.fn(), + encodeWavChunk: jest.fn(), +})); + +jest.mock('./vad.js', () => ({ + detectSpeechSegments: jest.fn(), +})); + +jest.mock('./chunking.js', () => ({ + planChunks: jest.fn(), + buildFallbackChunks: jest.fn(), +})); + +jest.mock('../openai.js', () => ({ + transcribeFile: jest.fn(), +})); + +describe('chunkedTranscription', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + test('splits audio into chunks and preserves prompts', async () => { + const pcm = new Float32Array(32_000); + decodeToMono16k.mockResolvedValue({ pcm, durationMs: 2000 }); + detectSpeechSegments.mockResolvedValue([{ startMs: 0, endMs: 1500 }]); + planChunks.mockReturnValue([ + { index: 0, renderStartMs: 0, renderEndMs: 1200 }, + { index: 1, renderStartMs: 1000, renderEndMs: 2000 }, + ]); + encodeWavChunk.mockImplementation((buffer, startMs, endMs) => ({ + blob: new Blob([`${startMs}-${endMs}`]), + durationMs: endMs - startMs, + })); + transcribeFile + .mockResolvedValueOnce('First chunk content.') + .mockResolvedValueOnce('Continuation second chunk.'); + + const file = new File([new Uint8Array(10)], 'example.wav', { + type: 'audio/wav', + }); + + const result = await chunkedTranscription({ file, language: 'en' }); + + expect(planChunks).toHaveBeenCalled(); + expect(encodeWavChunk).toHaveBeenCalledTimes(2); + expect(transcribeFile).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ prompt: '' }) + ); + expect(transcribeFile).toHaveBeenNthCalledWith( + 2, + 
expect.objectContaining({ prompt: 'First chunk content.' }) + ); + expect(result).toContain('First chunk content.'); + expect(result).toContain('Continuation second chunk.'); + }); + + test('falls back to chunking when VAD fails', async () => { + const pcm = new Float32Array(16_000); + decodeToMono16k.mockResolvedValue({ pcm, durationMs: 1000 }); + detectSpeechSegments.mockRejectedValue(new Error('vad failure')); + buildFallbackChunks.mockReturnValue([ + { index: 0, renderStartMs: 0, renderEndMs: 1000 }, + ]); + encodeWavChunk.mockReturnValue({ + blob: new Blob(['fallback']), + durationMs: 1000, + }); + transcribeFile.mockResolvedValue('Recovered text'); + + const file = new File([new Uint8Array(10)], 'fallback.wav', { + type: 'audio/wav', + }); + + const result = await chunkedTranscription({ file, language: 'en' }); + + expect(buildFallbackChunks).toHaveBeenCalled(); + expect(transcribeFile).toHaveBeenCalledWith( + expect.objectContaining({ prompt: '' }) + ); + expect(result).toBe('Recovered text'); + }); + + test('falls back to byte chunking when decode fails', async () => { + decodeToMono16k.mockRejectedValue(new Error('decode error')); + const originalSize = STT_CONFIG.maxChunkBytes / 2; + const file = new File([new Uint8Array(originalSize)], 'large.bin', { + type: 'application/octet-stream', + }); + transcribeFile.mockResolvedValue('Single chunk text'); + + const result = await chunkedTranscription({ file, language: 'en' }); + + expect(transcribeFile).toHaveBeenCalledTimes(1); + expect(result).toBe('Single chunk text'); + }); + + test('byte chunking splits very large files', async () => { + decodeToMono16k.mockRejectedValue(new Error('decode error')); + const size = STT_CONFIG.maxChunkBytes * 1.5; + const file = new File([new Uint8Array(size)], 'massive.bin', { + type: 'application/octet-stream', + }); + transcribeFile + .mockResolvedValueOnce('Part A') + .mockResolvedValueOnce('Part B'); + + const result = await chunkedTranscription({ file, language: 'en' }); + + expect(transcribeFile).toHaveBeenCalledTimes(2); + expect(result).toContain('Part A'); + expect(result).toContain('Part B'); + }); + + test('skips chunks that fail to encode', async () => { + const pcm = new Float32Array(32_000); + decodeToMono16k.mockResolvedValue({ pcm, durationMs: 2000 }); + detectSpeechSegments.mockResolvedValue([{ startMs: 0, endMs: 1500 }]); + planChunks.mockReturnValue([ + { index: 0, renderStartMs: 0, renderEndMs: 1200 }, + { index: 1, renderStartMs: 1000, renderEndMs: 2000 }, + ]); + encodeWavChunk.mockReturnValueOnce(null).mockReturnValueOnce({ + blob: new Blob(['valid']), + durationMs: 800, + }); + transcribeFile.mockResolvedValue('Only valid chunk'); + + const file = new File([new Uint8Array(10)], 'example.wav', { + type: 'audio/wav', + }); + + const result = await chunkedTranscription({ file, language: 'en' }); + + expect(transcribeFile).toHaveBeenCalledTimes(1); + expect(result).toBe('Only valid chunk'); + }); + + test('throws when no chunks can be encoded', async () => { + const pcm = new Float32Array(16_000); + decodeToMono16k.mockResolvedValue({ pcm, durationMs: 1000 }); + detectSpeechSegments.mockResolvedValue([{ startMs: 0, endMs: 800 }]); + planChunks.mockReturnValue([ + { index: 0, renderStartMs: 0, renderEndMs: 900 }, + ]); + encodeWavChunk.mockReturnValue(null); + + const file = new File([new Uint8Array(10)], 'broken.wav', { + type: 'audio/wav', + }); + + await expect( + chunkedTranscription({ file, language: 'en' }) + ).rejects.toThrow('Failed to prepare audio chunks for 
transcription');
+  });
+});
diff --git a/src/stt/vad.js b/src/stt/vad.js
new file mode 100644
index 0000000..4b98553
--- /dev/null
+++ b/src/stt/vad.js
@@ -0,0 +1,291 @@
+import {
+  DEFAULT_ORT_WASM_PATH,
+  DEFAULT_SILERO_MODEL_URL,
+  STT_CONFIG,
+} from './config.js';
+import { samplesToMs } from './audio.js';
+
+let ortPromise = null;
+let sessionPromise = null;
+
+/* istanbul ignore next -- runtime depends on onnxruntime-web in browser */
+function ensureOrt() {
+  if (!ortPromise) {
+    // Try multiple import sources. Some CDNs or esm transforms wrap the
+    // real export under `default` or produce incomplete modules. Try
+    // esm.sh first (fast), then fall back to known CDN ESM builds.
+    ortPromise = (async () => {
+      const candidates = [
+        'https://esm.sh/onnxruntime-web@1.18.0',
+        'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.18.0/dist/ort.esm.js',
+        'https://unpkg.com/onnxruntime-web@1.18.0/dist/ort.esm.js',
+      ];
+
+      let lastError = null;
+      for (const url of candidates) {
+        try {
+          const module = await import(url);
+          const ort = module?.default || module;
+          if (!ort) throw new Error('empty module');
+
+          // basic sanity: must expose InferenceSession.create
+          if (
+            !ort.InferenceSession ||
+            typeof ort.InferenceSession.create !== 'function'
+          ) {
+            throw new Error(
+              'incomplete ort module (missing InferenceSession.create)'
+            );
+          }
+
+          // configure WASM loader (single-threaded to avoid COOP/COEP)
+          if (ort?.env?.wasm) {
+            ort.env.wasm.numThreads = 1;
+            ort.env.wasm.wasmPaths = DEFAULT_ORT_WASM_PATH;
+          }
+
+          // diagnostic: report which URL produced a usable ort
+          try {
+            console.info('Loaded ONNX Runtime Web from', url);
+          } catch {
+            /* ignore */
+          }
+          return ort;
+        } catch (err) {
+          // try next candidate
+          lastError = err;
+        }
+      }
+
+      console.warn(
+        'Failed to load ONNX Runtime Web from CDN candidates',
+        lastError
+      );
+      throw new Error('ONNX Runtime Web not available');
+    })();
+  }
+  return ortPromise;
+}
+
+/* istanbul ignore next -- runtime depends on onnxruntime-web in browser */
+async function ensureSession() {
+  if (!sessionPromise) {
+    try {
+      const ort = await ensureOrt();
+      if (
+        !ort.InferenceSession ||
+        typeof ort.InferenceSession.create !== 'function'
+      ) {
+        throw new Error(
+          'InferenceSession.create not available in ONNX Runtime Web module'
+        );
+      }
+      sessionPromise = ort.InferenceSession.create(DEFAULT_SILERO_MODEL_URL);
+    } catch (error) {
+      console.warn(
+        'Failed to create ONNX InferenceSession, VAD will not be available',
+        error
+      );
+      throw new Error('ONNX InferenceSession not available');
+    }
+  }
+  return sessionPromise;
+}
+
+// Silero v5 keeps its recurrent state in a single tensor of shape
+// [2, 1, 128]; older v4 models use the separate h/c tensors below.
+function createStateTensor(ort) {
+  return new ort.Tensor('float32', new Float32Array(2 * 128), [2, 1, 128]);
+}
+
+function createHiddenTensor(ort) {
+  return new ort.Tensor('float32', new Float32Array(2 * 1 * 64), [2, 1, 64]);
+}
+
+function createInputTensor(ort, chunk, windowSamples) {
+  const buffer = new Float32Array(windowSamples);
+  buffer.set(chunk);
+  return new ort.Tensor('float32', buffer, [1, windowSamples]);
+}
+
+function createSrTensor(ort) {
+  const rate = BigInt(STT_CONFIG.sampleRate);
+  const srArray = new BigInt64Array([rate]);
+  return new ort.Tensor('int64', srArray, [1]);
+}
+
+function appendSegment({ segments, startMs, endMs, totalDurationMs }) {
+  const start = Math.max(0, startMs - STT_CONFIG.speechPadMs);
+  const end = Math.min(endMs + STT_CONFIG.speechPadMs, totalDurationMs);
+  if (end - start >= STT_CONFIG.minSpeechMs) {
+    segments.push({ startMs: start, endMs: end });
+  }
+}
+
+function mergeSegments(segments) {
+  if (!segments.length) return segments;
+  segments.sort((a, b) => a.startMs - b.startMs);
+  const merged = [segments[0]];
+
+  for (let i = 1; i < segments.length; i += 1) {
+    const prev = merged[merged.length - 1];
+    const current = segments[i];
+    if (current.startMs <= prev.endMs + STT_CONFIG.minSilenceMs) {
+      prev.endMs = Math.max(prev.endMs, current.endMs);
+    } else {
+      merged.push(current);
+    }
+  }
+
+  return merged;
+}
+
+function postProcessProbabilities(probabilities, totalSamples) {
+  const windowMs = samplesToMs(STT_CONFIG.windowSamples);
+  const totalDurationMs = samplesToMs(totalSamples);
+  const segments = [];
+
+  let speechStart = null;
+  let lastSpeechMs = 0;
+  let silenceMs = 0;
+
+  const finalizeSpeech = () => {
+    if (speechStart === null) return;
+    appendSegment({
+      segments,
+      startMs: speechStart,
+      endMs: lastSpeechMs,
+      totalDurationMs,
+    });
+    speechStart = null;
+    silenceMs = 0;
+  };
+
+  for (let i = 0; i < probabilities.length; i += 1) {
+    const prob = probabilities[i];
+    const frameStart = i * windowMs;
+    const frameEnd = frameStart + windowMs;
+
+    if (prob >= STT_CONFIG.threshold) {
+      speechStart = speechStart ?? frameStart;
+      lastSpeechMs = frameEnd;
+      silenceMs = 0;
+
+      if (lastSpeechMs - speechStart >= STT_CONFIG.maxSpeechMs) {
+        finalizeSpeech();
+      }
+      continue;
+    }
+
+    if (speechStart === null) continue;
+
+    silenceMs += windowMs;
+    if (silenceMs >= STT_CONFIG.minSilenceMs) {
+      finalizeSpeech();
+    }
+  }
+
+  finalizeSpeech();
+  return mergeSegments(segments);
+}
+
+const PROBABILITY_KEYS = ['output', 'prob', 'probs', 'output.1', 'speech_prob'];
+
+function readProbability(value) {
+  if (value == null) return null;
+  if (typeof value === 'number') return value;
+
+  const arrayLike = Array.isArray(value) ? value : value?.data;
+  if (arrayLike && typeof arrayLike[0] === 'number') {
+    return arrayLike[0];
+  }
+
+  return null;
+}
+
+function extractSpeechProbability(results) {
+  for (const key of PROBABILITY_KEYS) {
+    const probability = readProbability(results[key]);
+    if (probability !== null && typeof probability === 'number') {
+      return probability;
+    }
+  }
+  return 0;
+}
+
+/* istanbul ignore next -- requires onnx runtime in browser */
+export async function detectSpeechSegments(pcm) {
+  if (!pcm || pcm.length === 0) return [];
+  if (typeof window === 'undefined') {
+    throw new Error('VAD requires browser environment');
+  }
+
+  const ort = await ensureOrt();
+  const session = await ensureSession();
+
+  if (!session || typeof session.run !== 'function') {
+    throw new Error(
+      `ONNX InferenceSession.run not available (type=${typeof (session && session.run)})`
+    );
+  }
+
+  const probabilities = [];
+  const windowSamples = STT_CONFIG.windowSamples;
+  let hTensor = createHiddenTensor(ort);
+  let cTensor = createHiddenTensor(ort);
+  let stateTensor = createStateTensor(ort);
+  const srTensor = createSrTensor(ort);
+
+  for (let offset = 0; offset < pcm.length; offset += windowSamples) {
+    const chunk = pcm.subarray(offset, offset + windowSamples);
+    const inputTensor = createInputTensor(ort, chunk, windowSamples);
+
+    // Feed only the inputs this model declares: v4 models take recurrent
+    // h/c tensors, v5 takes a single `state` tensor, and onnxruntime-web
+    // rejects feeds that name inputs the model does not have.
+    const inputNames = session.inputNames || [];
+    const feeds = { input: inputTensor };
+    if (inputNames.includes('sr')) feeds.sr = srTensor;
+    if (inputNames.includes('h')) feeds.h = hTensor;
+    if (inputNames.includes('c')) feeds.c = cTensor;
+    if (inputNames.includes('state')) feeds.state = stateTensor;
+
+    let results;
+    try {
+      results = await session.run(feeds);
+    } catch (error) {
+      console.warn(
+        'Silero VAD inference failed, falling back to naive chunking',
+        error
+      );
+      throw error;
+    }
+
+    const probability = extractSpeechProbability(results);
+    // Recurrent state comes back as hn/cn (v4) or stateN (v5); alias the
+    // outputs so the carry-over assignments below pick them up.
+    if (results.hn) results.h = results.hn;
+    if (results.cn) results.c = results.cn;
+    if (results.stateN) results.state = results.stateN;
+    probabilities.push(typeof probability === 'number' ?
probability : 0); + + hTensor = results.h || hTensor; + cTensor = results.c || cTensor; + stateTensor = results.state || stateTensor; + } + + return postProcessProbabilities(probabilities, pcm.length); +} + +export function __resetVadForTesting() { + ortPromise = null; + sessionPromise = null; +} + +export const __internal = { + appendSegment, + mergeSegments, + postProcessProbabilities, + readProbability, + extractSpeechProbability, +}; diff --git a/src/stt/vad.test.js b/src/stt/vad.test.js new file mode 100644 index 0000000..514e388 --- /dev/null +++ b/src/stt/vad.test.js @@ -0,0 +1,55 @@ +import { __internal } from './vad.js'; +import { STT_CONFIG } from './config.js'; + +const { postProcessProbabilities, readProbability, extractSpeechProbability } = + __internal; + +describe('vad helpers', () => { + test('readProbability handles different inputs', () => { + expect(readProbability(0.5)).toBe(0.5); + expect(readProbability([0.7])).toBe(0.7); + expect(readProbability({ data: [0.2] })).toBe(0.2); + expect(readProbability({ data: new Float32Array([0.3]) })).toBeCloseTo(0.3); + expect(readProbability(null)).toBeNull(); + }); + + test('extractSpeechProbability picks first available key', () => { + const results = { + output: null, + prob: null, + probs: null, + 'output.1': { data: [0.6] }, + speech_prob: { data: [0.1] }, + }; + expect(extractSpeechProbability(results)).toBe(0.6); + }); + + test('postProcessProbabilities merges short gaps', () => { + const probabilities = new Array(20).fill(0); + for (let i = 1; i <= 3; i += 1) probabilities[i] = 0.9; + for (let i = 5; i <= 7; i += 1) probabilities[i] = 0.85; + const totalSamples = STT_CONFIG.sampleRate * 2; // 2 seconds of audio + const segments = postProcessProbabilities(probabilities, totalSamples); + expect(segments).toHaveLength(1); + const [segment] = segments; + expect(segment.startMs).toBeGreaterThanOrEqual(0); + expect(segment.endMs).toBeGreaterThan(segment.startMs); + }); + + test('postProcessProbabilities respects max speech length', () => { + const originalMaxSpeech = STT_CONFIG.maxSpeechMs; + const originalMinSilence = STT_CONFIG.minSilenceMs; + STT_CONFIG.maxSpeechMs = 100; + STT_CONFIG.minSilenceMs = 50; + const probabilities = [ + ...new Array(10).fill(0.95), + ...new Array(4).fill(0), + ...new Array(10).fill(0.95), + ]; + const totalSamples = STT_CONFIG.sampleRate * 3; + const segments = postProcessProbabilities(probabilities, totalSamples); + expect(segments.length).toBeGreaterThan(0); + STT_CONFIG.maxSpeechMs = originalMaxSpeech; + STT_CONFIG.minSilenceMs = originalMinSilence; + }); +}); diff --git a/src/transcription.js b/src/transcription.js index 93a6c43..a1ea1e3 100644 --- a/src/transcription.js +++ b/src/transcription.js @@ -1,4 +1,4 @@ -import { transcribeFile } from './openai.js'; +import { chunkedTranscription } from './stt/transcriber.js'; const LANGUAGE_STORAGE_KEY = 'transcription_language'; @@ -132,5 +132,5 @@ export function createSpeechRecognitionController({ } export async function transcribeAudioFile({ file, language }) { - return transcribeFile({ file, language }); + return chunkedTranscription({ file, language }); } diff --git a/src/transcription.test.js b/src/transcription.test.js index fca38bb..3cbe967 100644 --- a/src/transcription.test.js +++ b/src/transcription.test.js @@ -4,10 +4,10 @@ import { createSpeechRecognitionController, transcribeAudioFile, } from './transcription.js'; -import { transcribeFile } from './openai.js'; +import { chunkedTranscription } from './stt/transcriber.js'; 
-jest.mock('./openai.js', () => ({ - transcribeFile: jest.fn(), +jest.mock('./stt/transcriber.js', () => ({ + chunkedTranscription: jest.fn(), })); describe('transcription utilities', () => { @@ -223,11 +223,14 @@ describe('transcription utilities', () => { controller.stop(); }); - test('transcribeAudioFile proxies to openai module', async () => { + test('transcribeAudioFile delegates to chunked transcriber', async () => { const file = new File(['data'], 'audio.mp3', { type: 'audio/mpeg' }); - transcribeFile.mockResolvedValue('transcribed'); + chunkedTranscription.mockResolvedValue('transcribed'); const result = await transcribeAudioFile({ file, language: 'en' }); - expect(transcribeFile).toHaveBeenCalledWith({ file, language: 'en' }); + expect(chunkedTranscription).toHaveBeenCalledWith({ + file, + language: 'en', + }); expect(result).toBe('transcribed'); }); });
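
Note on usage: the new chunked pipeline is easiest to see from its call site. Below is a minimal sketch of wiring `transcribeAudioFile` to a file input; the element id and event wiring are illustrative, not part of this change.

```js
import { transcribeAudioFile } from './src/transcription.js';

// Dev-time path overrides (window.ORT_WASM_PATH, window.SILERO_VAD_MODEL) must
// be set in index.html before this module loads; see scripts/README.md.

// Hypothetical UI hook-up: pick an audio file and run the chunked pipeline.
// Internally this decodes to 16 kHz mono PCM, runs Silero VAD to find speech,
// packs segments into WAV chunks under the size/duration limits in STT_CONFIG
// (with 500 ms overlap), transcribes each chunk while passing the previous
// tail as the prompt, and merges the results with overlap de-duplication.
document.querySelector('#audio-input').addEventListener('change', async (event) => {
  const [file] = event.target.files;
  if (!file) return;
  try {
    const text = await transcribeAudioFile({ file, language: 'en' });
    console.log('Transcript:', text);
  } catch (error) {
    console.error('Transcription failed', error);
  }
});
```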