Skip to content
Merged

V7 #10

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions scripts/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
Fetch ONNX Runtime Web dist assets

This folder contains a small helper script to download the ESM/UMD build and WASM
assets for `onnxruntime-web` into the project `./ort/` folder so the app can
load the WASM locally and avoid cross-origin or CDN transform issues.

Usage

1. Run the script (requires `npm` and `tar` available on the PATH). The version argument is optional and defaults to `1.18.0`:

```bash
./scripts/fetch-onnx-dist.sh 1.18.0
```

2. Serve the repository (or your static files) so `/ort/` is accessible from the
app root. The code already defaults to `DEFAULT_ORT_WASM_PATH = '/ort/'` in
`src/stt/config.js`.

3. Optionally set `window.ORT_WASM_PATH = '/ort/'` and `window.SILERO_VAD_MODEL = '/models/silero_v5_16k.onnx'` in
`index.html` to be explicit during development.
48 changes: 48 additions & 0 deletions scripts/fetch-onnx-dist.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/usr/bin/env bash
set -euo pipefail

# Fetch onnxruntime-web dist/ assets and place them in project ./ort/ folder
# Usage: ./scripts/fetch-onnx-dist.sh [version]
# Example: ./scripts/fetch-onnx-dist.sh 1.18.0
#
# Requires `npm` and `tar` on the PATH. The version defaults to 1.18.0.

VERSION=${1:-1.18.0}
ROOT_DIR=$(cd "$(dirname "$0")/.." && pwd)
DEST_DIR="$ROOT_DIR/ort"

# Scratch directory for the download. It is deliberately NOT called TMPDIR:
# that name is the standard variable consulted by mktemp and other tools, and
# reassigning it would leak into every child process started below.
WORK_DIR=$(mktemp -d)
# Remove the scratch directory on every exit path, including failures
# triggered by `set -e`, so aborted runs do not leave temp directories behind.
trap 'rm -rf "$WORK_DIR"' EXIT

echo "Fetching onnxruntime-web@$VERSION into $ROOT_DIR/ort"
cd "$WORK_DIR"

# Use npm pack to download the package tarball.
# The call is wrapped in `if !` because a bare failing command would abort
# the script under `set -e` before any diagnostic could be printed.
echo "Downloading npm package..."
if ! npm pack "onnxruntime-web@$VERSION" >/dev/null 2>&1; then
  echo "Failed to download onnxruntime-web@$VERSION" >&2
  exit 1
fi

# Locate the tarball with a glob instead of `ls | head`: with pipefail an
# unmatched `ls` would abort the script before the check below could run.
shopt -s nullglob
TARBALLS=(onnxruntime-web-*.tgz)
shopt -u nullglob
if [ "${#TARBALLS[@]}" -eq 0 ]; then
  echo "Failed to download onnxruntime-web@$VERSION" >&2
  exit 1
fi
TARBALL=${TARBALLS[0]}

# Extract tarball. `package/` is pre-created so the diagnostic listing below
# still works when the archive has an unexpected layout.
mkdir -p package
tar -xzf "$TARBALL"

# Ensure dist exists
if [ ! -d package/dist ]; then
  echo "package/dist not found inside tarball. Listing package/ contents:" >&2
  ls -la package >&2
  exit 1
fi

# Copy dist files into repo ./ort/
# Remove old files and recreate directory (only after the download has been
# validated, so a failed fetch never wipes an existing ./ort/ folder).
rm -rf "$DEST_DIR"
mkdir -p "$DEST_DIR"
# `dist/.` copies the directory contents including dotfiles and does not
# depend on a glob matching at least one entry.
cp -R package/dist/. "$DEST_DIR/"

echo "Copied $(ls -1 "$DEST_DIR" | wc -l) files to $DEST_DIR"

# Cleanup of $WORK_DIR is handled by the EXIT trap installed above.

echo "Done. To serve the ONNX WASM assets locally, ensure your static server serves the ./ort/ folder at '/ort/'."
echo "You can also set window.ORT_WASM_PATH = '/ort/' before loading the app to be explicit."
3 changes: 2 additions & 1 deletion src/openai.js
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ export async function summarizeText({ lowQuality = '', highQuality = '' }) {
return response.choices?.[0]?.message?.content || '';
}

export async function transcribeFile({ file, language }) {
export async function transcribeFile({ file, language, prompt }) {
if (!file) throw new Error('File is required for transcription');
const model = resolveTranscriptionModel();

Expand All @@ -161,6 +161,7 @@ export async function transcribeFile({ file, language }) {
file,
model,
language,
prompt,
});
return response.text;
} catch (error) {
Expand Down
26 changes: 16 additions & 10 deletions src/openai.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,14 @@ describe('openai helpers', () => {
const result = await transcribeFile({ file, language: 'en' });

expect(result).toBe('hello world');
expect(mockClient.audio.transcriptions.create).toHaveBeenCalledWith({
file,
model: 'gpt-4o-transcribe',
language: 'en',
});
expect(mockClient.audio.transcriptions.create).toHaveBeenCalledWith(
expect.objectContaining({
file,
model: 'gpt-4o-transcribe',
language: 'en',
prompt: undefined,
})
);
});

test('transcribeFile honors TRANSCRIBE_MODEL override', async () => {
Expand All @@ -65,11 +68,14 @@ describe('openai helpers', () => {
const file = { name: 'audio.wav' };
await transcribeFile({ file, language: 'de' });

expect(mockClient.audio.transcriptions.create).toHaveBeenCalledWith({
file,
model: 'custom-model',
language: 'de',
});
expect(mockClient.audio.transcriptions.create).toHaveBeenCalledWith(
expect.objectContaining({
file,
model: 'custom-model',
language: 'de',
prompt: undefined,
})
);
});

test('transcribeFile logs model and status information on failure', async () => {
Expand Down
161 changes: 161 additions & 0 deletions src/stt/audio.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import { STT_CONFIG } from './config.js';

/**
 * Reports whether the browser audio globals used by this module are present.
 *
 * @returns {*} A truthy value (whatever `window.OfflineAudioContext` holds)
 *   when `window` exists, exposes `AudioContext` or `webkitAudioContext`, and
 *   exposes `OfflineAudioContext`; a falsy value otherwise.
 */
/* istanbul ignore next -- depends on browser-specific audio globals */
function hasWebAudio() {
  // Outside a browser there is no `window` at all.
  if (typeof window === 'undefined') {
    return false;
  }

  const liveContextCtor = window.AudioContext || window.webkitAudioContext;
  return liveContextCtor && window.OfflineAudioContext;
}

/**
 * Creates a new audio context, preferring `window.AudioContext` and falling
 * back to `window.webkitAudioContext`.
 *
 * @returns {object} A freshly constructed context instance.
 */
/* istanbul ignore next -- depends on browser-specific audio globals */
function getAudioContext() {
  const ContextClass = window.AudioContext || window.webkitAudioContext;
  const context = new ContextClass();
  return context;
}

/**
 * Collapses a decoded audio buffer into a single channel by averaging all
 * channels sample by sample.
 *
 * @param {object} audioBuffer - Object exposing `numberOfChannels`, `length`
 *   and `getChannelData(index)`.
 * @returns {Float32Array} A new array: a copy of channel 0 for single-channel
 *   input, otherwise the per-sample average of every channel.
 */
/* istanbul ignore next -- exercised via decodeToMono16k in browser */
function mixToMono(audioBuffer) {
  const channelCount = audioBuffer.numberOfChannels;

  // Single-channel input needs no mixing; hand back an independent copy.
  if (channelCount === 1) {
    return audioBuffer.getChannelData(0).slice();
  }

  const frameCount = audioBuffer.length;
  const mixed = new Float32Array(frameCount);

  let channelIndex = 0;
  while (channelIndex < channelCount) {
    const channelSamples = audioBuffer.getChannelData(channelIndex);
    // Each channel contributes an equal share, accumulated directly into the
    // float32 output.
    for (let frame = 0; frame < frameCount; frame += 1) {
      mixed[frame] += channelSamples[frame] / channelCount;
    }
    channelIndex += 1;
  }

  return mixed;
}

/**
 * Resamples mono PCM to a different sample rate by rendering it through an
 * `OfflineAudioContext` created at the target rate.
 *
 * @param {Float32Array} mono - Source samples.
 * @param {number} sourceRate - Sample rate of `mono`.
 * @param {number} targetRate - Desired output sample rate.
 * @returns {Promise<Float32Array>} The same `mono` reference when the rates
 *   already match; otherwise a copy of the rendered channel data.
 */
/* istanbul ignore next -- exercised via decodeToMono16k in browser */
async function resampleMonoBuffer(mono, sourceRate, targetRate) {
  if (sourceRate !== targetRate) {
    // Output length is rounded up so the render covers the whole input.
    const targetLength = Math.ceil((mono.length * targetRate) / sourceRate);
    const renderContext = new OfflineAudioContext(1, targetLength, targetRate);

    // Wrap the input samples in a buffer tagged with the source rate.
    const inputBuffer = renderContext.createBuffer(1, mono.length, sourceRate);
    inputBuffer.copyToChannel(mono, 0);

    const player = renderContext.createBufferSource();
    player.buffer = inputBuffer;
    player.connect(renderContext.destination);
    player.start(0);

    const renderedBuffer = await renderContext.startRendering();
    return renderedBuffer.getChannelData(0).slice();
  }

  return mono;
}

/**
 * Decodes an audio file into mono PCM at the sample rate configured in
 * `STT_CONFIG.sampleRate`.
 *
 * @param {object} file - Object exposing an async `arrayBuffer()` method.
 * @returns {Promise<{pcm: Float32Array, sampleRate: number, durationMs: number}>}
 *   The resampled samples, their sample rate, and the duration in
 *   milliseconds derived from the sample count.
 * @throws {Error} When the Web Audio globals are not available.
 */
/* istanbul ignore next -- browser-only decode path */
export async function decodeToMono16k(file) {
  if (!hasWebAudio()) {
    throw new Error('Web Audio API not available');
  }

  const encodedBytes = await file.arrayBuffer();
  const audioContext = getAudioContext();

  try {
    const audioBuffer = await audioContext.decodeAudioData(encodedBytes);
    const pcm = await resampleMonoBuffer(
      mixToMono(audioBuffer),
      audioBuffer.sampleRate,
      STT_CONFIG.sampleRate
    );

    const sampleRate = STT_CONFIG.sampleRate;
    const durationMs = Math.round((pcm.length / STT_CONFIG.sampleRate) * 1000);
    return { pcm, sampleRate, durationMs };
  } finally {
    // Release the context whether decoding succeeded or failed; `close` is
    // optional-called because not every implementation provides it.
    audioContext.close?.();
  }
}

/**
 * Restricts a value to the inclusive range `[min, max]`.
 *
 * @param {number} value - Value to clamp.
 * @param {number} min - Lower bound.
 * @param {number} max - Upper bound (applied last, so it wins if `min > max`).
 * @returns {number} The clamped value.
 */
export function clampMs(value, min, max) {
  const raisedToMin = Math.max(value, min);
  return Math.min(raisedToMin, max);
}

/**
 * Estimates the size in bytes of a WAV chunk of the given duration, using the
 * sample rate and bytes-per-sample from `STT_CONFIG`.
 *
 * @param {number} durationMs - Chunk duration in milliseconds.
 * @returns {number} PCM payload bytes (rounded up) plus the 44-byte WAV
 *   header overhead.
 */
export function estimateChunkBytes(durationMs) {
  const WAV_HEADER_BYTES = 44;
  const { sampleRate, wavBytesPerSample } = STT_CONFIG;

  const payloadBytes = (durationMs / 1000) * (sampleRate * wavBytesPerSample);
  return Math.ceil(payloadBytes) + WAV_HEADER_BYTES;
}

/**
 * Converts floating-point samples to little-endian signed 16-bit PCM bytes.
 *
 * @param {Float32Array} float32Array - Samples; values outside [-1, 1] are
 *   clamped before conversion.
 * @returns {Uint8Array} Two bytes per input sample.
 */
function floatTo16BitPCM(float32Array) {
  const sampleCount = float32Array.length;
  const bytes = new Uint8Array(sampleCount * 2);
  const writer = new DataView(bytes.buffer);

  for (let index = 0; index < sampleCount; index += 1) {
    const clamped = Math.max(-1, Math.min(1, float32Array[index]));
    // The int16 range is asymmetric: negatives scale to -0x8000, positives
    // to 0x7fff.
    const scale = clamped < 0 ? 0x8000 : 0x7fff;
    writer.setInt16(index * 2, clamped * scale, true);
  }

  return bytes;
}

/**
 * Builds the 44-byte RIFF/WAVE header for a mono, 16-bit PCM payload, using
 * the sample rate and bytes-per-sample from `STT_CONFIG`.
 *
 * @param {number} dataLength - Size in bytes of the PCM payload that follows.
 * @returns {Uint8Array} The 44 header bytes.
 */
function writeWavHeader(dataLength) {
  const header = new Uint8Array(44);
  const view = new DataView(header.buffer);
  const { sampleRate, wavBytesPerSample } = STT_CONFIG;

  // Chunk identifiers are raw ASCII, written one byte at a time.
  const putTag = (offset, tag) => {
    for (let i = 0; i < tag.length; i += 1) {
      view.setUint8(offset + i, tag.charCodeAt(i));
    }
  };
  // All numeric fields are little-endian.
  const putUint32 = (offset, value) => view.setUint32(offset, value, true);
  const putUint16 = (offset, value) => view.setUint16(offset, value, true);

  putTag(0, 'RIFF');
  putUint32(4, 36 + dataLength); // remaining file size after this field
  putTag(8, 'WAVE');
  putTag(12, 'fmt ');
  putUint32(16, 16); // Subchunk1Size
  putUint16(20, 1); // PCM
  putUint16(22, 1); // Mono
  putUint32(24, sampleRate);
  putUint32(28, sampleRate * wavBytesPerSample); // byte rate
  putUint16(32, wavBytesPerSample); // block align
  putUint16(34, 16); // bits per sample
  putTag(36, 'data');
  putUint32(40, dataLength);

  return header;
}

/**
 * Encodes a time window of PCM samples as a WAV blob.
 *
 * @param {Float32Array} pcm - Samples at `STT_CONFIG.sampleRate`.
 * @param {number} startMs - Window start in milliseconds (rounded down to a
 *   sample index, floored at 0).
 * @param {number} endMs - Window end in milliseconds (rounded up to a sample
 *   index, capped at `pcm.length`).
 * @returns {{blob: Blob, durationMs: number} | null} The `audio/wav` blob and
 *   the duration of the samples actually included, or `null` when the window
 *   contains no samples.
 */
export function encodeWavChunk(pcm, startMs, endMs) {
  const availableSamples = pcm.length;
  const { sampleRate } = STT_CONFIG;

  const firstSample = Math.max(0, Math.floor((startMs / 1000) * sampleRate));
  const requestedEnd = Math.ceil((endMs / 1000) * sampleRate);
  const lastSample = Math.min(availableSamples, requestedEnd);

  // Nothing to encode when the window is empty or inverted.
  if (lastSample <= firstSample) {
    return null;
  }

  const payload = floatTo16BitPCM(pcm.slice(firstSample, lastSample));
  const wavParts = [writeWavHeader(payload.length), payload];

  return {
    blob: new Blob(wavParts, { type: 'audio/wav' }),
    durationMs: Math.round(((lastSample - firstSample) / sampleRate) * 1000),
  };
}

/**
 * Converts a sample count to milliseconds at `STT_CONFIG.sampleRate`.
 *
 * @param {number} samples - Number of samples.
 * @returns {number} Duration in milliseconds, rounded to the nearest integer.
 */
export function samplesToMs(samples) {
  const seconds = samples / STT_CONFIG.sampleRate;
  return Math.round(seconds * 1000);
}

/**
 * Converts a duration in milliseconds to a sample count at
 * `STT_CONFIG.sampleRate`.
 *
 * @param {number} ms - Duration in milliseconds.
 * @returns {number} Number of samples, rounded to the nearest integer.
 */
export function msToSamples(ms) {
  const seconds = ms / 1000;
  return Math.round(seconds * STT_CONFIG.sampleRate);
}
38 changes: 38 additions & 0 deletions src/stt/audio.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import {
clampMs,
encodeWavChunk,
estimateChunkBytes,
msToSamples,
samplesToMs,
} from './audio.js';

// Unit tests for the helpers in ./audio.js that do not need browser audio
// globals.
describe('audio helpers', () => {
  test('estimateChunkBytes uses pcm byte rate', () => {
    // One second of audio must exceed 32000 bytes once the header is added.
    expect(estimateChunkBytes(1000)).toBeGreaterThan(32000);
  });

  test('encodeWavChunk converts float32 pcm to wav blob', () => {
    const sampleCount = 16_000;
    // One full sine period spread across the buffer.
    const pcm = Float32Array.from({ length: sampleCount }, (_, index) =>
      Math.sin((index / sampleCount) * Math.PI * 2)
    );

    const encoded = encodeWavChunk(pcm, 0, 1000);

    expect(encoded).not.toBeNull();
    const { blob, durationMs } = encoded;
    expect(durationMs).toBeGreaterThanOrEqual(900);
    expect(blob.type).toBe('audio/wav');
    expect(blob.size).toBeGreaterThan(0);
  });

  test('clampMs enforces bounds', () => {
    // [value, expected] pairs checked against the range 100..200.
    const cases = [
      [50, 100],
      [250, 200],
      [150, 150],
    ];

    for (const [value, expected] of cases) {
      expect(clampMs(value, 100, 200)).toBe(expected);
    }
  });

  test('sample conversions are consistent', () => {
    // Round-tripping ms -> samples -> ms should land back on the input.
    const roundTrippedMs = samplesToMs(msToSamples(1000));
    expect(roundTrippedMs).toBeCloseTo(1000, 0);
  });
});
Loading