From 49b1ff1482d5e0f9d8f4b650f9868c431919017d Mon Sep 17 00:00:00 2001 From: Richard Palethorpe Date: Tue, 3 Mar 2026 10:39:25 +0000 Subject: [PATCH 1/2] feat(realtime): WebRTC support Signed-off-by: Richard Palethorpe --- .github/workflows/test.yml | 4 +- .github/workflows/tests-e2e.yml | 2 +- .gitignore | 4 + Dockerfile | 6 +- Makefile | 86 +- core/backend/transcript.go | 46 +- core/backend/tts.go | 58 +- core/http/endpoints/openai/opus.go | 100 ++ core/http/endpoints/openai/opus_test.go | 1267 +++++++++++++++++ core/http/endpoints/openai/realtime.go | 947 +++++++----- .../endpoints/openai/realtime_transport.go | 23 + .../openai/realtime_transport_webrtc.go | 250 ++++ .../endpoints/openai/realtime_transport_ws.go | 47 + core/http/endpoints/openai/realtime_webrtc.go | 250 ++++ core/http/endpoints/openai/types/types.go | 43 +- core/http/middleware/trace.go | 2 +- core/http/react-ui/src/pages/Settings.jsx | 14 +- core/http/react-ui/src/pages/Talk.jsx | 798 +++++++++-- core/http/react-ui/src/pages/Traces.jsx | 347 ++++- core/http/react-ui/src/utils/api.js | 6 + core/http/react-ui/src/utils/config.js | 4 + core/http/routes/openai.go | 1 + core/http/routes/ui.go | 38 + core/http/static/talk.js | 689 +++++++-- core/http/views/talk.html | 251 +++- core/http/views/traces.html | 53 +- core/trace/audio_snippet.go | 102 ++ core/trace/backend_trace.go | 2 +- docs/content/advanced/advanced-usage.md | 2 + go.mod | 30 +- go.sum | 64 +- pkg/audio/audio.go | 43 + pkg/audio/audio_test.go | 155 ++ pkg/opus/opus.go | 261 ++++ pkg/opus/shim/Makefile | 10 + pkg/opus/shim/libopusshim.so | Bin 0 -> 15240 bytes pkg/opus/shim/opus_shim.c | 9 + pkg/sound/int16.go | 12 + pkg/sound/int16_test.go | 162 +++ pkg/sound/testutil_test.go | 72 + tests/e2e/e2e_suite_test.go | 95 +- tests/e2e/mock-backend/main.go | 93 +- tests/e2e/realtime_webrtc_test.go | 459 ++++++ tests/e2e/realtime_ws_test.go | 269 ++++ 44 files changed, 6293 insertions(+), 883 deletions(-) create mode 100644 
core/http/endpoints/openai/opus.go create mode 100644 core/http/endpoints/openai/opus_test.go create mode 100644 core/http/endpoints/openai/realtime_transport.go create mode 100644 core/http/endpoints/openai/realtime_transport_webrtc.go create mode 100644 core/http/endpoints/openai/realtime_transport_ws.go create mode 100644 core/http/endpoints/openai/realtime_webrtc.go create mode 100644 core/trace/audio_snippet.go create mode 100644 pkg/audio/audio_test.go create mode 100644 pkg/opus/opus.go create mode 100644 pkg/opus/shim/Makefile create mode 100755 pkg/opus/shim/libopusshim.so create mode 100644 pkg/opus/shim/opus_shim.c create mode 100644 pkg/sound/int16_test.go create mode 100644 pkg/sound/testutil_test.go create mode 100644 tests/e2e/realtime_webrtc_test.go create mode 100644 tests/e2e/realtime_ws_test.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fd6eece4c7af..b9b71f5e8efd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -93,7 +93,7 @@ jobs: - name: Dependencies run: | sudo apt-get update - sudo apt-get install curl ffmpeg + sudo apt-get install curl ffmpeg libopus-dev - name: Setup Node.js uses: actions/setup-node@v4 with: @@ -195,7 +195,7 @@ jobs: run: go version - name: Dependencies run: | - brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm + brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm opus pip install --user --no-cache-dir grpcio-tools grpcio - name: Setup Node.js uses: actions/setup-node@v4 diff --git a/.github/workflows/tests-e2e.yml b/.github/workflows/tests-e2e.yml index 490eb296ab43..147ea44dab23 100644 --- a/.github/workflows/tests-e2e.yml +++ b/.github/workflows/tests-e2e.yml @@ -43,7 +43,7 @@ jobs: - name: Dependencies run: | sudo apt-get update - sudo apt-get install -y build-essential + sudo apt-get install -y build-essential libopus-dev - name: Setup Node.js uses: actions/setup-node@v4 with: diff --git a/.gitignore 
b/.gitignore index 3d7e27f7a96d..3dcb309ca40d 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,7 @@ test-models/ test-dir/ tests/e2e-aio/backends tests/e2e-aio/models +mock-backend release/ @@ -69,3 +70,6 @@ docs/static/gallery.html # React UI build artifacts (keep placeholder dist/index.html) core/http/react-ui/node_modules/ core/http/react-ui/dist + +# Extracted backend binaries for container-based testing +local-backends/ diff --git a/Dockerfile b/Dockerfile index 666f19ab1895..f319ce5b74c1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && \ apt-get install -y --no-install-recommends \ ca-certificates curl wget espeak-ng libgomp1 \ - ffmpeg libopenblas0 libopenblas-dev sox && \ + ffmpeg libopenblas0 libopenblas-dev libopus0 sox && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -190,6 +190,7 @@ RUN apt-get update && \ curl libssl-dev \ git \ git-lfs \ + libopus-dev pkg-config \ unzip upx-ucl python3 python-is-python3 && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* @@ -378,6 +379,9 @@ COPY ./entrypoint.sh . 
# Copy the binary COPY --from=builder /build/local-ai ./ +# Copy the opus shim if it was built +RUN --mount=from=builder,src=/build/,dst=/mnt/build \ + if [ -f /mnt/build/libopusshim.so ]; then cp /mnt/build/libopusshim.so ./; fi # Make sure the models directory exists RUN mkdir -p /models /backends diff --git a/Makefile b/Makefile index 54c21088afa0..1e0175357a1e 100644 --- a/Makefile +++ b/Makefile @@ -106,7 +106,17 @@ react-ui-docker: core/http/react-ui/dist: react-ui ## Build: -build: protogen-go install-go-tools core/http/react-ui/dist ## Build the project + +# Build the opus shim if libopus is available +build-opus-shim: + @if command -v pkg-config >/dev/null 2>&1 && pkg-config --exists opus; then \ + echo "$(GREEN)I Building opus shim (libopus found)$(RESET)"; \ + $(MAKE) -C pkg/opus/shim; \ + else \ + echo "$(YELLOW)W libopus-dev not found, skipping opus shim build (WebRTC audio will not work)$(RESET)"; \ + fi + +build: protogen-go install-go-tools build-opus-shim core/http/react-ui/dist ## Build the project $(info ${GREEN}I local-ai build info:${RESET}) $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET}) $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET}) @@ -114,6 +124,7 @@ build: protogen-go install-go-tools core/http/react-ui/dist ## Build the project $(info ${GREEN}I UPX: ${YELLOW}$(UPX)${RESET}) rm -rf $(BINARY_NAME) || true CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./cmd/local-ai + @if [ -f pkg/opus/shim/libopusshim.so ]; then cp pkg/opus/shim/libopusshim.so .; fi build-launcher: ## Build the launcher application $(info ${GREEN}I local-ai launcher build info:${RESET}) @@ -151,7 +162,7 @@ test-models/testmodel.ggml: wget -q https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav cp tests/models_fixtures/* test-models -prepare-test: protogen-go +prepare-test: protogen-go build-opus-shim cp tests/models_fixtures/* test-models 
######################################################## @@ -163,6 +174,7 @@ test: test-models/testmodel.ggml protogen-go @echo 'Running tests' export GO_TAGS="debug" $(MAKE) prepare-test + OPUS_SHIM_LIBRARY=$(abspath ./pkg/opus/shim/libopusshim.so) \ HUGGINGFACE_GRPC=$(abspath ./)/backend/python/transformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backends \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS) $(MAKE) test-llama-gguf @@ -218,9 +230,10 @@ prepare-e2e: run-e2e-image: docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --name e2e-tests-$(RANDOM) localai-tests -test-e2e: build-mock-backend prepare-e2e run-e2e-image +test-e2e: build-mock-backend build-opus-shim prepare-e2e run-e2e-image @echo 'Running e2e tests' BUILD_TYPE=$(BUILD_TYPE) \ + OPUS_SHIM_LIBRARY=$(abspath ./pkg/opus/shim/libopusshim.so) \ LOCALAI_API=http://$(E2E_BRIDGE_IP):5390 \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e $(MAKE) clean-mock-backend @@ -250,6 +263,73 @@ test-stablediffusion: prepare-test test-stores: $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stores" --flake-attempts $(TEST_FLAKES) -v -r tests/integration +test-realtime: build-mock-backend + @echo 'Running realtime e2e tests (mock backend)' + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime && !real-models" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e + +# Real-model realtime tests. Set REALTIME_TEST_MODEL to use your own pipeline, +# or leave unset to auto-build one from the component env vars below. 
+REALTIME_VAD?=silero-vad-ggml +REALTIME_STT?=whisper-1 +REALTIME_LLM?=qwen3-0.6b +REALTIME_TTS?=tts-1 +REALTIME_BACKENDS_PATH?=$(abspath ./)/backends + +test-realtime-models: build-mock-backend + @echo 'Running realtime e2e tests (real models)' + REALTIME_TEST_MODEL=$${REALTIME_TEST_MODEL:-realtime-test-pipeline} \ + REALTIME_VAD=$(REALTIME_VAD) \ + REALTIME_STT=$(REALTIME_STT) \ + REALTIME_LLM=$(REALTIME_LLM) \ + REALTIME_TTS=$(REALTIME_TTS) \ + REALTIME_BACKENDS_PATH=$(REALTIME_BACKENDS_PATH) \ + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e + +# --- Container-based real-model testing --- + +REALTIME_BACKEND_NAMES ?= silero-vad whisper llama-cpp kokoro +REALTIME_MODELS_DIR ?= $(abspath ./models) +REALTIME_BACKENDS_DIR ?= $(abspath ./local-backends) +REALTIME_DOCKER_FLAGS ?= --gpus all + +local-backends: + mkdir -p local-backends + +extract-backend-%: docker-build-% local-backends + @echo "Extracting backend $*..." + @CID=$$(docker create local-ai-backend:$*) && \ + rm -rf local-backends/$* && mkdir -p local-backends/$* && \ + docker cp $$CID:/ - | tar -xf - -C local-backends/$* && \ + docker rm $$CID > /dev/null + +extract-realtime-backends: $(addprefix extract-backend-,$(REALTIME_BACKEND_NAMES)) + +test-realtime-models-docker: build-mock-backend + docker build --target build-requirements \ + --build-arg BUILD_TYPE=$(or $(BUILD_TYPE),cublas) \ + --build-arg CUDA_MAJOR_VERSION=$(or $(CUDA_MAJOR_VERSION),13) \ + --build-arg CUDA_MINOR_VERSION=$(or $(CUDA_MINOR_VERSION),0) \ + -t localai-test-runner . 
+ docker run --rm \ + $(REALTIME_DOCKER_FLAGS) \ + -v $(abspath ./):/build \ + -v $(REALTIME_MODELS_DIR):/models:ro \ + -v $(REALTIME_BACKENDS_DIR):/backends \ + -v localai-go-cache:/root/go/pkg/mod \ + -v localai-go-build-cache:/root/.cache/go-build \ + -e REALTIME_TEST_MODEL=$${REALTIME_TEST_MODEL:-realtime-test-pipeline} \ + -e REALTIME_VAD=$(REALTIME_VAD) \ + -e REALTIME_STT=$(REALTIME_STT) \ + -e REALTIME_LLM=$(REALTIME_LLM) \ + -e REALTIME_TTS=$(REALTIME_TTS) \ + -e REALTIME_BACKENDS_PATH=/backends \ + -e REALTIME_MODELS_PATH=/models \ + -w /build \ + localai-test-runner \ + bash -c 'git config --global --add safe.directory /build && \ + make protogen-go && make build-mock-backend && \ + go run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e' + test-container: docker build --target requirements -t local-ai-test-container . docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container diff --git a/core/backend/transcript.go b/core/backend/transcript.go index dbbf718a3a48..7568e4e40706 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -3,11 +3,12 @@ package backend import ( "context" "fmt" + "maps" "time" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" @@ -30,9 +31,12 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt } var startTime time.Time + var audioSnippet map[string]any if appConfig.EnableTracing { trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) startTime = time.Now() + // Capture audio before the backend call — the backend may delete the file. 
+ audioSnippet = trace.AudioSnippet(audio) } r, err := transcriptionModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ @@ -45,6 +49,16 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt }) if err != nil { if appConfig.EnableTracing { + errData := map[string]any{ + "audio_file": audio, + "language": language, + "translate": translate, + "diarize": diarize, + "prompt": prompt, + } + if audioSnippet != nil { + maps.Copy(errData, audioSnippet) + } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), @@ -53,13 +67,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt Backend: modelConfig.Backend, Summary: trace.TruncateString(audio, 200), Error: err.Error(), - Data: map[string]any{ - "audio_file": audio, - "language": language, - "translate": translate, - "diarize": diarize, - "prompt": prompt, - }, + Data: errData, }) } return nil, err @@ -84,6 +92,18 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt } if appConfig.EnableTracing { + data := map[string]any{ + "audio_file": audio, + "language": language, + "translate": translate, + "diarize": diarize, + "prompt": prompt, + "result_text": tr.Text, + "segments_count": len(tr.Segments), + } + if audioSnippet != nil { + maps.Copy(data, audioSnippet) + } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), @@ -91,15 +111,7 @@ func ModelTranscription(audio, language string, translate, diarize bool, prompt ModelName: modelConfig.Name, Backend: modelConfig.Backend, Summary: trace.TruncateString(audio+" -> "+tr.Text, 200), - Data: map[string]any{ - "audio_file": audio, - "language": language, - "translate": translate, - "diarize": diarize, - "prompt": prompt, - "result_text": tr.Text, - "segments_count": len(tr.Segments), - }, + Data: data, }) } diff --git a/core/backend/tts.go b/core/backend/tts.go index 
7859cd67cb71..69193db12a5d 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "maps" "os" "path/filepath" "time" @@ -84,6 +85,16 @@ func ModelTTS( errStr = fmt.Sprintf("TTS error: %s", res.Message) } + data := map[string]any{ + "text": text, + "voice": voice, + "language": language, + } + if err == nil && res.Success { + if snippet := trace.AudioSnippet(filePath); snippet != nil { + maps.Copy(data, snippet) + } + } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), @@ -92,11 +103,7 @@ func ModelTTS( Backend: modelConfig.Backend, Summary: trace.TruncateString(text, 200), Error: errStr, - Data: map[string]any{ - "text": text, - "voice": voice, - "language": language, - }, + Data: data, }) } @@ -158,6 +165,11 @@ func ModelTTSStream( headerSent := false var callbackErr error + // Collect up to 30s of audio for tracing + var snippetPCM []byte + var totalPCMBytes int + snippetCapped := false + err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{ Text: text, Model: modelPath, @@ -166,7 +178,7 @@ func ModelTTSStream( }, func(reply *proto.Reply) { // First message contains sample rate info if !headerSent && len(reply.Message) > 0 { - var info map[string]interface{} + var info map[string]any if json.Unmarshal(reply.Message, &info) == nil { if sr, ok := info["sample_rate"].(float64); ok { sampleRate = uint32(sr) @@ -207,6 +219,22 @@ func ModelTTSStream( if writeErr := audioCallback(reply.Audio); writeErr != nil { callbackErr = writeErr } + // Accumulate PCM for tracing snippet + totalPCMBytes += len(reply.Audio) + if appConfig.EnableTracing && !snippetCapped { + maxBytes := int(sampleRate) * 2 * trace.MaxSnippetSeconds // 16-bit mono + if len(snippetPCM)+len(reply.Audio) <= maxBytes { + snippetPCM = append(snippetPCM, reply.Audio...) 
+ } else { + remaining := maxBytes - len(snippetPCM) + if remaining > 0 { + // Align to sample boundary (2 bytes per sample) + remaining = remaining &^ 1 + snippetPCM = append(snippetPCM, reply.Audio[:remaining]...) + } + snippetCapped = true + } + } } }) @@ -221,6 +249,17 @@ func ModelTTSStream( errStr = resultErr.Error() } + data := map[string]any{ + "text": text, + "voice": voice, + "language": language, + "streaming": true, + } + if resultErr == nil && len(snippetPCM) > 0 { + if snippet := trace.AudioSnippetFromPCM(snippetPCM, int(sampleRate), totalPCMBytes); snippet != nil { + maps.Copy(data, snippet) + } + } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), @@ -229,12 +268,7 @@ func ModelTTSStream( Backend: modelConfig.Backend, Summary: trace.TruncateString(text, 200), Error: errStr, - Data: map[string]any{ - "text": text, - "voice": voice, - "language": language, - "streaming": true, - }, + Data: data, }) } diff --git a/core/http/endpoints/openai/opus.go b/core/http/endpoints/openai/opus.go new file mode 100644 index 000000000000..86ef7b5236d6 --- /dev/null +++ b/core/http/endpoints/openai/opus.go @@ -0,0 +1,100 @@ +package openai + +import ( + "fmt" + + "github.com/mudler/LocalAI/pkg/opus" + "github.com/mudler/LocalAI/pkg/sound" +) + +const ( + opusSampleRate = 48000 + opusChannels = 1 + // 20ms frame at 48kHz mono = 960 samples + opusFrameSize = 960 + // Maximum Opus packet size + opusMaxPacketSize = 4000 + // Maximum decoded frame size (120ms at 48kHz) + opusMaxFrameSize = 5760 +) + +// OpusEncoder wraps libopus (via purego shim) for encoding PCM int16 LE to Opus frames. 
+type OpusEncoder struct { + enc *opus.Encoder +} + +func NewOpusEncoder() (*OpusEncoder, error) { + enc, err := opus.NewEncoder(opusSampleRate, opusChannels, opus.ApplicationAudio) + if err != nil { + return nil, fmt.Errorf("opus encoder: %w", err) + } + if err := enc.SetBitrate(64000); err != nil { + enc.Close() + return nil, fmt.Errorf("opus set bitrate: %w", err) + } + if err := enc.SetComplexity(10); err != nil { + enc.Close() + return nil, fmt.Errorf("opus set complexity: %w", err) + } + return &OpusEncoder{enc: enc}, nil +} + +// Encode takes PCM int16 LE bytes at the given sampleRate and returns Opus frames. +// It resamples to 48kHz if needed, then encodes in 20ms frames. +func (e *OpusEncoder) Encode(pcmInt16LE []byte, sampleRate int) ([][]byte, error) { + samples := sound.BytesToInt16sLE(pcmInt16LE) + if len(samples) == 0 { + return nil, nil + } + + if sampleRate != opusSampleRate { + samples = sound.ResampleInt16(samples, sampleRate, opusSampleRate) + } + + var frames [][]byte + packet := make([]byte, opusMaxPacketSize) + + for offset := 0; offset+opusFrameSize <= len(samples); offset += opusFrameSize { + frame := samples[offset : offset+opusFrameSize] + n, err := e.enc.Encode(frame, opusFrameSize, packet) + if err != nil { + return frames, fmt.Errorf("opus encode: %w", err) + } + out := make([]byte, n) + copy(out, packet[:n]) + frames = append(frames, out) + } + + return frames, nil +} + +func (e *OpusEncoder) Close() { + e.enc.Close() +} + +// OpusDecoder wraps libopus (via purego shim) for decoding Opus frames to PCM int16 LE. +type OpusDecoder struct { + dec *opus.Decoder +} + +func NewOpusDecoder() (*OpusDecoder, error) { + dec, err := opus.NewDecoder(opusSampleRate, opusChannels) + if err != nil { + return nil, fmt.Errorf("opus decoder: %w", err) + } + return &OpusDecoder{dec: dec}, nil +} + +// Decode takes a single Opus frame and returns PCM int16 samples at 48kHz.
+func (d *OpusDecoder) Decode(opusFrame []byte) ([]int16, error) { + pcm := make([]int16, opusMaxFrameSize) + n, err := d.dec.Decode(opusFrame, pcm, opusMaxFrameSize, false) + if err != nil { + return nil, fmt.Errorf("opus decode: %w", err) + } + return pcm[:n], nil +} + +func (d *OpusDecoder) Close() { + d.dec.Close() +} diff --git a/core/http/endpoints/openai/opus_test.go b/core/http/endpoints/openai/opus_test.go new file mode 100644 index 000000000000..77314c9ab5ab --- /dev/null +++ b/core/http/endpoints/openai/opus_test.go @@ -0,0 +1,1267 @@ +package openai + +import ( + "encoding/binary" + "fmt" + "io" + "math" + "math/rand/v2" + "os" + "os/exec" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/mudler/LocalAI/pkg/opus" + "github.com/mudler/LocalAI/pkg/sound" + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" +) + +// --- helpers (mirror pkg/sound/testutil_test.go but in this package) --- + +func generateSineWave(freq float64, sampleRate, numSamples int) []int16 { + out := make([]int16, numSamples) + for i := range out { + t := float64(i) / float64(sampleRate) + out[i] = int16(math.MaxInt16 / 2 * math.Sin(2*math.Pi*freq*t)) + } + return out +} + +func computeRMS(samples []int16) float64 { + if len(samples) == 0 { + return 0 + } + var sum float64 + for _, s := range samples { + v := float64(s) + sum += v * v + } + return math.Sqrt(sum / float64(len(samples))) +} + +// estimateFrequency uses zero-crossing count to estimate the dominant frequency. 
+func estimateFrequency(samples []int16, sampleRate int) float64 { + if len(samples) < 2 { + return 0 + } + crossings := 0 + for i := 1; i < len(samples); i++ { + if (samples[i-1] >= 0 && samples[i] < 0) || (samples[i-1] < 0 && samples[i] >= 0) { + crossings++ + } + } + duration := float64(len(samples)) / float64(sampleRate) + return float64(crossings) / (2 * duration) +} + +// encodeDecodeRoundtrip encodes PCM at the given sample rate and decodes +// all resulting frames, returning the concatenated decoded samples. +func encodeDecodeRoundtrip(t *testing.T, pcmBytes []byte, sampleRate int) []int16 { + t.Helper() + enc, err := NewOpusEncoder() + if err != nil { + t.Fatalf("NewOpusEncoder: %v", err) + } + defer enc.Close() + + dec, err := NewOpusDecoder() + if err != nil { + t.Fatalf("NewOpusDecoder: %v", err) + } + defer dec.Close() + + frames, err := enc.Encode(pcmBytes, sampleRate) + if err != nil { + t.Fatalf("Encode: %v", err) + } + + var all []int16 + for _, frame := range frames { + d, err := dec.Decode(frame) + if err != nil { + t.Fatalf("Decode: %v", err) + } + all = append(all, d...) + } + return all +} + +// --- Opus encoder tests --- + +// TestOpus_ChromeLikeVoIPDecode tests decoding Opus frames encoded with +// VoIP mode at 32kbps (similar to Chrome's WebRTC encoder settings). +// Chrome uses SILK mode for voice, which exercises different code paths +// in the decoder compared to ApplicationAudio (CELT-preferring). 
+func TestOpus_ChromeLikeVoIPDecode(t *testing.T) { + // Chrome typically encodes voice at 32kbps in VoIP mode + enc, err := opus.NewEncoder(48000, 1, opus.ApplicationVoIP) + if err != nil { + t.Fatalf("NewEncoder(VoIP): %v", err) + } + defer enc.Close() + if err := enc.SetBitrate(32000); err != nil { + t.Fatalf("SetBitrate: %v", err) + } + if err := enc.SetComplexity(5); err != nil { + t.Fatalf("SetComplexity: %v", err) + } + + dec, err := NewOpusDecoder() + if err != nil { + t.Fatalf("NewOpusDecoder: %v", err) + } + defer dec.Close() + + // Encode 1 second of 440Hz sine at 48kHz + sine := generateSineWave(440, 48000, 48000) + packet := make([]byte, 4000) + + var allDecoded []int16 + for offset := 0; offset+opusFrameSize <= len(sine); offset += opusFrameSize { + frame := sine[offset : offset+opusFrameSize] + n, err := enc.Encode(frame, opusFrameSize, packet) + if err != nil { + t.Fatalf("VoIP encode: %v", err) + } + + decoded, err := dec.Decode(packet[:n]) + if err != nil { + t.Fatalf("Decode VoIP frame: %v (packet size=%d)", err, n) + } + allDecoded = append(allDecoded, decoded...) + } + + if len(allDecoded) == 0 { + t.Fatal("no decoded samples from VoIP encoder") + } + + // Skip warmup + skip := min(len(allDecoded)/4, 48000*100/1000) + tail := allDecoded[skip:] + rms := computeRMS(tail) + + t.Logf("VoIP/SILK roundtrip: %d decoded samples, RMS=%.1f", len(allDecoded), rms) + if rms < 50 { + t.Errorf("VoIP decoded RMS=%.1f is too low; SILK decoder may be broken", rms) + } +} + +// TestOpus_StereoEncoderMonoDecoder tests decoding stereo-encoded Opus +// with a mono decoder. Chrome signals opus/48000/2 in SDP and may send +// stereo Opus. The mono decoder should downmix correctly. 
+func TestOpus_StereoEncoderMonoDecoder(t *testing.T) { + // Encode as stereo (2 channels) — similar to what Chrome might send + enc, err := opus.NewEncoder(48000, 2, opus.ApplicationVoIP) + if err != nil { + t.Fatalf("NewEncoder(stereo): %v", err) + } + defer enc.Close() + if err := enc.SetBitrate(32000); err != nil { + t.Fatalf("SetBitrate: %v", err) + } + + // Decode with our standard mono decoder + dec, err := NewOpusDecoder() + if err != nil { + t.Fatalf("NewOpusDecoder: %v", err) + } + defer dec.Close() + + // Create stereo signal: same sine in both channels (interleaved L,R,L,R...) + mono := generateSineWave(440, 48000, 48000) + stereo := make([]int16, len(mono)*2) + for i, s := range mono { + stereo[i*2] = s // L + stereo[i*2+1] = s // R + } + + packet := make([]byte, 4000) + var allDecoded []int16 + for offset := 0; offset+opusFrameSize*2 <= len(stereo); offset += opusFrameSize * 2 { + frame := stereo[offset : offset+opusFrameSize*2] + n, err := enc.Encode(frame, opusFrameSize, packet) + if err != nil { + t.Fatalf("Stereo encode: %v", err) + } + + decoded, err := dec.Decode(packet[:n]) + if err != nil { + t.Fatalf("Decode stereo->mono: %v (packet size=%d)", err, n) + } + allDecoded = append(allDecoded, decoded...) + } + + if len(allDecoded) == 0 { + t.Fatal("no decoded samples from stereo encoder") + } + + skip := min(len(allDecoded)/4, 48000*100/1000) + tail := allDecoded[skip:] + rms := computeRMS(tail) + + t.Logf("Stereo->Mono: %d decoded samples, RMS=%.1f", len(allDecoded), rms) + if rms < 50 { + t.Errorf("Stereo->Mono decoded RMS=%.1f is too low; cross-channel decoding may be broken", rms) + } +} + +// TestOpus_DecodeLibopusEncoded uses ffmpeg (real libopus) to encode audio, +// then decodes with our opus-go decoder. This simulates Chrome sending Opus +// frames to the server. Skipped if ffmpeg is not available. 
+func TestOpus_DecodeLibopusEncoded(t *testing.T) { + ffmpegPath, err := exec.LookPath("ffmpeg") + if err != nil { + t.Skip("ffmpeg not found") + } + + tmpDir := t.TempDir() + + // Generate 1 second of 440Hz tone as raw PCM (16-bit LE mono 48kHz) + sine := generateSineWave(440, 48000, 48000) + pcmPath := filepath.Join(tmpDir, "input.raw") + pcmBytes := sound.Int16toBytesLE(sine) + if err := os.WriteFile(pcmPath, pcmBytes, 0644); err != nil { + t.Fatalf("write PCM: %v", err) + } + + for _, tc := range []struct { + name string + bitrate string + app string + }{ + {"voip_32k", "32000", "voip"}, + {"voip_64k", "64000", "voip"}, + {"audio_64k", "64000", "audio"}, + {"audio_128k", "128000", "audio"}, + } { + t.Run(tc.name, func(t *testing.T) { + testDecodeLibopus(t, ffmpegPath, tmpDir, pcmPath, sine, tc.bitrate, tc.app) + }) + } +} + +func testDecodeLibopus(t *testing.T, ffmpegPath, tmpDir, pcmPath string, _ []int16, bitrate, app string) { + t.Helper() + + oggPath := filepath.Join(tmpDir, fmt.Sprintf("libopus_%s_%s.ogg", app, bitrate)) + cmd := exec.Command(ffmpegPath, + "-y", + "-f", "s16le", "-ar", "48000", "-ac", "1", "-i", pcmPath, + "-c:a", "libopus", + "-b:a", bitrate, + "-application", app, + "-frame_duration", "20", + "-vbr", "on", + oggPath, + ) + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("ffmpeg encode: %v\n%s", err, out) + } + + // Read the Ogg/Opus file and extract raw Opus frames + oggData, err := os.ReadFile(oggPath) + if err != nil { + t.Fatalf("read ogg: %v", err) + } + + frames := extractOpusFramesFromOgg(t, oggData) + if len(frames) == 0 { + t.Fatal("no Opus frames extracted from Ogg container") + } + t.Logf("Extracted %d Opus frames from libopus encoder (first frame %d bytes)", len(frames), len(frames[0])) + + // Decode with our opus-go decoder + dec, err := NewOpusDecoder() + if err != nil { + t.Fatalf("NewOpusDecoder: %v", err) + } + defer dec.Close() + + var allDecoded []int16 + decodeErrors := 0 + for i, frame := range frames { 
+ decoded, err := dec.Decode(frame) + if err != nil { + decodeErrors++ + if decodeErrors <= 5 { + t.Logf("frame %d: decode error: %v (size=%d)", i, err, len(frame)) + } + continue + } + if i < 5 { + t.Logf("frame %d: payload=%d bytes, decoded=%d samples (%.1fms @ 48kHz)", + i, len(frame), len(decoded), float64(len(decoded))/48.0) + } + allDecoded = append(allDecoded, decoded...) + } + + if decodeErrors > 0 { + t.Logf("Total decode errors: %d/%d frames", decodeErrors, len(frames)) + } + + if len(allDecoded) == 0 { + t.Fatal("no decoded samples from libopus-encoded Opus") + } + + // Skip warmup and check quality + skip := min(len(allDecoded)/4, 48000*100/1000) + tail := allDecoded[skip:] + rms := computeRMS(tail) + freq := estimateFrequency(tail, 48000) + + t.Logf("libopus->opus-go: %d decoded samples, RMS=%.1f, freq≈%.0f Hz", len(allDecoded), rms, freq) + + if rms < 50 { + t.Errorf("RMS=%.1f is too low — opus-go cannot decode libopus output", rms) + } + if math.Abs(freq-440) > 30 { + t.Errorf("frequency %.0f Hz deviates from expected 440 Hz (ratio=%.3f)", freq, freq/440.0) + } +} + +// extractOpusFramesFromOgg parses an Ogg container and extracts raw Opus audio frames. 
+func extractOpusFramesFromOgg(t *testing.T, data []byte) [][]byte { + t.Helper() + var frames [][]byte + pos := 0 + pageNum := 0 + + for pos+27 <= len(data) { + // Check for OggS sync + if string(data[pos:pos+4]) != "OggS" { + t.Fatalf("invalid Ogg page at offset %d", pos) + } + + nSegments := int(data[pos+26]) + if pos+27+nSegments > len(data) { + break + } + + segTable := data[pos+27 : pos+27+nSegments] + dataStart := pos + 27 + nSegments + + // Calculate total page data size + var totalDataSize int + for _, s := range segTable { + totalDataSize += int(s) + } + + if dataStart+totalDataSize > len(data) { + break + } + + // Skip first two pages (OpusHead + OpusTags) + if pageNum >= 2 { + // Extract packets from segment table + pageData := data[dataStart : dataStart+totalDataSize] + offset := 0 + var packet []byte + for _, segSize := range segTable { + packet = append(packet, pageData[offset:offset+int(segSize)]...) + offset += int(segSize) + if segSize < 255 { + // End of packet + if len(packet) > 0 { + frameCopy := make([]byte, len(packet)) + copy(frameCopy, packet) + frames = append(frames, frameCopy) + } + packet = nil + } + } + // If last segment was 255, packet continues on next page + if len(packet) > 0 { + frameCopy := make([]byte, len(packet)) + copy(frameCopy, packet) + frames = append(frames, frameCopy) + } + } + + pos = dataStart + totalDataSize + pageNum++ + } + + return frames +} + +func TestOpusEncodeDecode_Roundtrip_48kHz(t *testing.T) { + // Use a longer signal (1 second) so the codec can stabilise past its + // lookahead period and produce meaningful output. + sine := generateSineWave(440, 48000, 48000) + pcmBytes := sound.Int16toBytesLE(sine) + + decoded := encodeDecodeRoundtrip(t, pcmBytes, 48000) + if len(decoded) == 0 { + t.Fatal("no decoded samples") + } + + // Skip initial codec warmup (first 50ms) for frequency estimation. + skip := 48000 * 50 / 1000 // 2400 samples at 48kHz + // The decoder may return fewer samples per frame (e.g. 
480 instead of 960), + // so the total decoded length may differ. Adjust skip proportionally. + decodedSR := 48000 // decoder is initialised at 48kHz + skipDecoded := decodedSR * 50 / 1000 + if skipDecoded > len(decoded)/2 { + skipDecoded = len(decoded) / 4 + } + tail := decoded[skipDecoded:] + + rms := computeRMS(tail) + t.Logf("48kHz roundtrip: %d decoded samples, RMS=%.1f (skip=%d, analysed=%d)", + len(decoded), rms, skip, len(tail)) + + if rms < 50 { + t.Errorf("decoded audio RMS=%.1f is too low; signal appears silent", rms) + } +} + +func TestOpusEncodeDecode_Roundtrip_16kHz(t *testing.T) { + // 1 second of 440Hz at 16kHz. Encoder resamples 16k->48k internally. + sine16k := generateSineWave(440, 16000, 16000) + pcmBytes := sound.Int16toBytesLE(sine16k) + + decoded := encodeDecodeRoundtrip(t, pcmBytes, 16000) + if len(decoded) == 0 { + t.Fatal("no decoded samples") + } + + // Resample back to 16kHz + decoded16k := sound.ResampleInt16(decoded, 48000, 16000) + + // Skip warmup + skip := min(len(decoded16k)/4, 16000*50/1000) + tail := decoded16k[skip:] + + rms := computeRMS(tail) + t.Logf("16kHz roundtrip: %d decoded@48k -> %d resampled@16k, RMS=%.1f", + len(decoded), len(decoded16k), rms) + + if rms < 50 { + t.Errorf("decoded audio RMS=%.1f is too low; signal appears silent", rms) + } +} + +func TestOpusEncode_EmptyInput(t *testing.T) { + enc, err := NewOpusEncoder() + if err != nil { + t.Fatalf("NewOpusEncoder: %v", err) + } + defer enc.Close() + + frames, err := enc.Encode([]byte{}, 48000) + if err != nil { + t.Fatalf("Encode empty: %v", err) + } + if frames != nil { + t.Errorf("expected nil frames for empty input, got %d frames", len(frames)) + } +} + +func TestOpusEncode_SubFrameInput_SilentDrop(t *testing.T) { + // Less than 960 samples at 48kHz = not enough for a single frame. + // The encoder silently drops these trailing samples. 
+ enc, err := NewOpusEncoder() + if err != nil { + t.Fatalf("NewOpusEncoder: %v", err) + } + defer enc.Close() + + sine := generateSineWave(440, 48000, 500) // < 960 + pcmBytes := sound.Int16toBytesLE(sine) + + frames, err := enc.Encode(pcmBytes, 48000) + if err != nil { + t.Fatalf("Encode: %v", err) + } + if len(frames) != 0 { + t.Errorf("expected 0 frames for %d samples (< 960), got %d", len(sine), len(frames)) + } +} + +func TestOpusEncode_MultiFrame(t *testing.T) { + enc, err := NewOpusEncoder() + if err != nil { + t.Fatalf("NewOpusEncoder: %v", err) + } + defer enc.Close() + + // 2880 samples at 48kHz = exactly 3 frames of 960 + sine := generateSineWave(440, 48000, 2880) + pcmBytes := sound.Int16toBytesLE(sine) + + frames, err := enc.Encode(pcmBytes, 48000) + if err != nil { + t.Fatalf("Encode: %v", err) + } + if len(frames) != 3 { + t.Errorf("expected 3 frames for 2880 samples, got %d", len(frames)) + } +} + +func TestOpusDecode_FrameSize(t *testing.T) { + // Document the actual decoded frame size from the pure Go opus-go library. + enc, err := NewOpusEncoder() + if err != nil { + t.Fatalf("NewOpusEncoder: %v", err) + } + defer enc.Close() + + dec, err := NewOpusDecoder() + if err != nil { + t.Fatalf("NewOpusDecoder: %v", err) + } + defer dec.Close() + + sine := generateSineWave(440, 48000, 960) + pcmBytes := sound.Int16toBytesLE(sine) + + frames, err := enc.Encode(pcmBytes, 48000) + if err != nil { + t.Fatalf("Encode: %v", err) + } + if len(frames) != 1 { + t.Fatalf("expected 1 frame, got %d", len(frames)) + } + + decoded, err := dec.Decode(frames[0]) + if err != nil { + t.Fatalf("Decode: %v", err) + } + + t.Logf("Encoder input: 960 samples (20ms @ 48kHz)") + t.Logf("Decoder output: %d samples (%.1fms @ 48kHz)", + len(decoded), float64(len(decoded))/48.0) + + // The decoder may return a different frame size due to internal + // bandwidth decisions in VoIP mode. Document the actual value. 
+ if len(decoded) != 960 && len(decoded) != 480 { + t.Errorf("unexpected decoded frame size %d (expected 960 or 480)", len(decoded)) + } +} + +func TestOpus_FullWebRTCOutputPath(t *testing.T) { + // Simulates the TTS -> SendAudio path: + // PCM at 16kHz -> Encode(pcm, 16000) -> Opus frames -> Decode -> 48kHz samples + // Use 1 second of audio to let codec stabilise. + sine16k := generateSineWave(440, 16000, 16000) + pcmBytes := sound.Int16toBytesLE(sine16k) + + decoded := encodeDecodeRoundtrip(t, pcmBytes, 16000) + if len(decoded) == 0 { + t.Fatal("no frames produced") + } + + rms := computeRMS(decoded) + t.Logf("WebRTC output path: %d decoded samples at 48kHz, RMS=%.1f", len(decoded), rms) + + if rms < 50 { + t.Errorf("decoded audio RMS=%.1f is too low; expected recognisable signal", rms) + } +} + +func TestOpus_FullWebRTCInputPath(t *testing.T) { + // Simulates the client -> server path: + // PCM@48k -> Encode -> Decode -> Resample 48k->24k->16k + // Verify that the pipeline produces non-silent audio. + sine48k := generateSineWave(440, 48000, 48000) // 1 second + pcmBytes := sound.Int16toBytesLE(sine48k) + + decoded48k := encodeDecodeRoundtrip(t, pcmBytes, 48000) + if len(decoded48k) == 0 { + t.Fatal("no decoded samples") + } + + // WebRTC path: 48k -> 24k -> (VAD) -> 16k + step24k := sound.ResampleInt16(decoded48k, 48000, 24000) + webrtcPath := sound.ResampleInt16(step24k, 24000, 16000) + + rms := computeRMS(webrtcPath) + t.Logf("WebRTC input path: %d decoded@48k -> %d@24k -> %d@16k, RMS=%.1f", + len(decoded48k), len(step24k), len(webrtcPath), rms) + + if rms < 50 { + t.Errorf("WebRTC input path RMS=%.1f is too low; signal lost in pipeline", rms) + } +} + +// --- Bug documentation tests --- + +func TestOpusBug_TrailingSampleLoss(t *testing.T) { + // Encode 1000 samples at 48kHz -> only 1 frame (960 samples) returned. + // 40 trailing samples are silently lost. 
+ enc, err := NewOpusEncoder() + if err != nil { + t.Fatalf("NewOpusEncoder: %v", err) + } + defer enc.Close() + + sine := generateSineWave(440, 48000, 1000) + pcmBytes := sound.Int16toBytesLE(sine) + + frames, err := enc.Encode(pcmBytes, 48000) + if err != nil { + t.Fatalf("Encode: %v", err) + } + if len(frames) != 1 { + t.Fatalf("expected 1 frame, got %d", len(frames)) + } + + dec, err := NewOpusDecoder() + if err != nil { + t.Fatalf("NewOpusDecoder: %v", err) + } + defer dec.Close() + + decoded, err := dec.Decode(frames[0]) + if err != nil { + t.Fatalf("Decode: %v", err) + } + + // The encoder only encoded 960 of 1000 input samples. + // Decoded frame size may be 960 or 480 depending on codec mode. + // Either way, 40 input samples are permanently lost. + t.Logf("Input: 1000 samples, Encoded: 1 frame, Decoded: %d samples (40 samples lost)", len(decoded)) + if len(decoded) > 960 { + t.Errorf("decoded more samples (%d) than the encoder consumed (960)", len(decoded)) + } +} + +func TestOpusBug_TTSSampleRateMismatch(t *testing.T) { + // If TTS produces 24kHz audio but the pipeline assumes 16kHz, + // the Opus encoder resamples from 16kHz to 48kHz (3x) instead of + // 24kHz to 48kHz (2x). The result is pitched up by 50%. + // + // This test uses a longer signal and compares the two paths to + // demonstrate the frequency distortion. 
+ + // Generate 440Hz at 24kHz (what TTS actually produces) + sine24k := generateSineWave(440, 24000, 24000) // 1 second + pcmBytes := sound.Int16toBytesLE(sine24k) + + // BUG path: Pipeline passes sampleRate=16000 (assumed) instead of 24000 (actual) + decodedBug := encodeDecodeRoundtrip(t, pcmBytes, 16000) + // CORRECT path: Pipeline should pass sampleRate=24000 + decodedCorrect := encodeDecodeRoundtrip(t, pcmBytes, 24000) + + // Skip warmup for frequency estimation + skipBug := min(len(decodedBug)/4, 48000*100/1000) + skipCorrect := min(len(decodedCorrect)/4, 48000*100/1000) + + bugTail := decodedBug[skipBug:] + correctTail := decodedCorrect[skipCorrect:] + + bugFreq := estimateFrequency(bugTail, 48000) + correctFreq := estimateFrequency(correctTail, 48000) + + t.Logf("Bug path: %d decoded samples, freq≈%.0f Hz (expected ~660 Hz = 440*1.5)", len(decodedBug), bugFreq) + t.Logf("Correct path: %d decoded samples, freq≈%.0f Hz (expected ~440 Hz)", len(decodedCorrect), correctFreq) + + // The bug path produces significantly more decoded samples because + // the encoder thinks the input is 16kHz and upsamples by 3x instead of 2x. + // This also means the perceived playback speed and pitch are wrong. + if len(decodedBug) > 0 && len(decodedCorrect) > 0 { + ratio := float64(len(decodedBug)) / float64(len(decodedCorrect)) + t.Logf("Sample count ratio (bug/correct): %.2f (expected ~1.5)", ratio) + if ratio < 1.1 { + t.Error("expected bug path to produce significantly more samples due to wrong resample ratio") + } + } +} + +// TestOpus_CrossLibraryCompat encodes a sine wave with opus-go, wraps the +// output in a minimal Ogg/Opus container, and decodes it with ffmpeg. This +// catches issues where the pure-Go encoder produces Opus frames that only +// its own decoder can parse (but not a browser or standard decoder). +// Skipped if ffmpeg is not available. 
+func TestOpus_CrossLibraryCompat(t *testing.T) { + ffmpegPath, err := exec.LookPath("ffmpeg") + if err != nil { + t.Skip("ffmpeg not found, skipping cross-library compatibility test") + } + + // Encode 1 second of 440Hz sine at 48kHz with opus-go + sine := generateSineWave(440, 48000, 48000) + pcmBytes := sound.Int16toBytesLE(sine) + + enc, err := NewOpusEncoder() + if err != nil { + t.Fatalf("NewOpusEncoder: %v", err) + } + defer enc.Close() + + frames, err := enc.Encode(pcmBytes, 48000) + if err != nil { + t.Fatalf("Encode: %v", err) + } + if len(frames) == 0 { + t.Fatal("no frames produced") + } + t.Logf("opus-go produced %d frames (first frame %d bytes)", len(frames), len(frames[0])) + + // Wrap the Opus frames in an Ogg/Opus container so ffmpeg can decode them. + tmpDir := t.TempDir() + oggPath := filepath.Join(tmpDir, "opus_go_output.ogg") + if err := writeOggOpus(oggPath, frames, 48000, 1); err != nil { + t.Fatalf("writeOggOpus: %v", err) + } + + // Decode with ffmpeg + decodedWavPath := filepath.Join(tmpDir, "ffmpeg_decoded.wav") + cmd := exec.Command(ffmpegPath, "-y", "-i", oggPath, "-ar", "48000", "-ac", "1", "-c:a", "pcm_s16le", decodedWavPath) + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("ffmpeg failed to decode opus-go output: %v\n%s", err, out) + } + + // Read the decoded WAV and check audio quality + decodedData, err := os.ReadFile(decodedWavPath) + if err != nil { + t.Fatalf("read decoded WAV: %v", err) + } + + // Use our robust ParseWAV to handle ffmpeg's WAV output + decodedPCM, sr := parseTestWAV(decodedData) + if sr == 0 { + t.Fatal("ffmpeg output has no WAV header") + } + decodedSamples := sound.BytesToInt16sLE(decodedPCM) + + // Skip codec warmup (first 100ms), check RMS of the rest + skip := min(len(decodedSamples)/4, sr*100/1000) + if skip >= len(decodedSamples) { + skip = 0 + } + tail := decodedSamples[skip:] + rms := computeRMS(tail) + + t.Logf("ffmpeg decoded opus-go output: %d samples at %dHz, RMS=%.1f", 
len(decodedSamples), sr, rms) + + if rms < 50 { + t.Errorf("ffmpeg decoded RMS=%.1f is too low — opus-go frames are likely incompatible with standard decoders", rms) + } else { + t.Logf("PASS: opus-go Opus frames are decodable by ffmpeg (libopus) with good signal quality") + } +} + +// parseTestWAV is a simple WAV parser for test output (ffmpeg always writes standard headers). +func parseTestWAV(data []byte) (pcm []byte, sampleRate int) { + if len(data) < 44 || string(data[0:4]) != "RIFF" { + return data, 0 + } + // Walk chunks to find "data" + pos := 12 + sr := int(binary.LittleEndian.Uint32(data[24:28])) + for pos+8 <= len(data) { + id := string(data[pos : pos+4]) + sz := int(binary.LittleEndian.Uint32(data[pos+4 : pos+8])) + if id == "data" { + end := pos + 8 + sz + if end > len(data) { + end = len(data) + } + return data[pos+8 : end], sr + } + pos += 8 + sz + if sz%2 != 0 { + pos++ + } + } + return data[44:], sr +} + +// writeOggOpus writes Opus frames into a minimal Ogg/Opus container file. 
+func writeOggOpus(path string, frames [][]byte, sampleRate, channels int) error { + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + + serial := uint32(0x4C6F6341) // "LocA" + var pageSeq uint32 + const preSkip = 312 // standard Opus pre-skip for 48kHz + + // Page 1: OpusHead (BOS page) + opusHead := make([]byte, 19) + copy(opusHead[0:8], "OpusHead") + opusHead[8] = 1 // version + opusHead[9] = byte(channels) // channel count + binary.LittleEndian.PutUint16(opusHead[10:12], uint16(preSkip)) // pre-skip + binary.LittleEndian.PutUint32(opusHead[12:16], uint32(sampleRate)) // input sample rate + binary.LittleEndian.PutUint16(opusHead[16:18], 0) // output gain + opusHead[18] = 0 // channel mapping family + if err := writeOggPage(f, serial, pageSeq, 0, 0x02, [][]byte{opusHead}); err != nil { + return err + } + pageSeq++ + + // Page 2: OpusTags + opusTags := make([]byte, 16) + copy(opusTags[0:8], "OpusTags") + binary.LittleEndian.PutUint32(opusTags[8:12], 0) // vendor string length + binary.LittleEndian.PutUint32(opusTags[12:16], 0) // comment list length + if err := writeOggPage(f, serial, pageSeq, 0, 0x00, [][]byte{opusTags}); err != nil { + return err + } + pageSeq++ + + // Audio pages: one Opus frame per page for simplicity + var granulePos uint64 + for i, frame := range frames { + granulePos += 960 // 20ms at 48kHz + headerType := byte(0x00) + if i == len(frames)-1 { + headerType = 0x04 // EOS + } + if err := writeOggPage(f, serial, pageSeq, granulePos, headerType, [][]byte{frame}); err != nil { + return err + } + pageSeq++ + } + + return nil +} + +// writeOggPage writes a single Ogg page containing the given packets. 
+func writeOggPage(w io.Writer, serial, pageSeq uint32, granulePos uint64, headerType byte, packets [][]byte) error { + // Build segment table + var segments []byte + var pageData []byte + for _, pkt := range packets { + remaining := len(pkt) + for remaining >= 255 { + segments = append(segments, 255) + remaining -= 255 + } + segments = append(segments, byte(remaining)) + pageData = append(pageData, pkt...) + } + + // Build page header (27 bytes + segment table) + hdr := make([]byte, 27+len(segments)) + copy(hdr[0:4], "OggS") + hdr[4] = 0 // version + hdr[5] = headerType + binary.LittleEndian.PutUint64(hdr[6:14], granulePos) + binary.LittleEndian.PutUint32(hdr[14:18], serial) + binary.LittleEndian.PutUint32(hdr[18:22], pageSeq) + // CRC at [22:26] — filled after computing + hdr[26] = byte(len(segments)) + copy(hdr[27:], segments) + + // Compute CRC-32 over header + page data + crc := oggCRC32(hdr, pageData) + binary.LittleEndian.PutUint32(hdr[22:26], crc) + + if _, err := w.Write(hdr); err != nil { + return err + } + _, err := w.Write(pageData) + return err +} + +// oggCRC32 computes the Ogg CRC-32 checksum (polynomial 0x04C11DB7). +func oggCRC32(header, data []byte) uint32 { + var crc uint32 + for _, b := range header { + crc = (crc << 8) ^ oggCRCTable[byte(crc>>24)^b] + } + for _, b := range data { + crc = (crc << 8) ^ oggCRCTable[byte(crc>>24)^b] + } + return crc +} + +var oggCRCTable = func() [256]uint32 { + var t [256]uint32 + for i := range 256 { + r := uint32(i) << 24 + for range 8 { + if r&0x80000000 != 0 { + r = (r << 1) ^ 0x04C11DB7 + } else { + r <<= 1 + } + } + t[i] = r + } + return t +}() + +// goertzel computes the power at a specific frequency using the Goertzel algorithm. +// Returns power in linear scale (not dB). 
+func goertzel(samples []int16, targetFreq float64, sampleRate int) float64 { + N := len(samples) + if N == 0 { + return 0 + } + k := 0.5 + float64(N)*targetFreq/float64(sampleRate) + w := 2 * math.Pi * k / float64(N) + coeff := 2 * math.Cos(w) + var s1, s2 float64 + for _, sample := range samples { + s0 := float64(sample) + coeff*s1 - s2 + s2 = s1 + s1 = s0 + } + return s1*s1 + s2*s2 - coeff*s1*s2 +} + +// computeTHD computes Total Harmonic Distortion for a signal with known fundamental. +// THD = sqrt(sum of harmonic powers) / fundamental power, returned as percentage. +func computeTHD(samples []int16, fundamentalHz float64, sampleRate, numHarmonics int) float64 { + fundPower := goertzel(samples, fundamentalHz, sampleRate) + if fundPower <= 0 { + return 0 + } + var harmonicSum float64 + for h := 2; h <= numHarmonics; h++ { + harmonicSum += goertzel(samples, fundamentalHz*float64(h), sampleRate) + } + return math.Sqrt(harmonicSum/fundPower) * 100 +} + +// TestWebRTCPipeline_TestToneQuality exercises the full audio pipeline: +// +// PCM (24kHz) → resample to 48kHz → Opus encode → RTP packetize → +// WebRTC transport (local loopback) → RTP depacketize → Opus decode → PCM (48kHz) +// +// Two local PeerConnections are connected via SDP exchange (no network). +// The sender uses the same RTP construction as WebRTCTransport.SendAudio. +// Quality metrics are computed on the received/decoded audio and logged. 
+// +// This test catches regressions in: +// - Opus encoder output quality +// - RTP packetization (sequence numbers, timestamps, marker bit) +// - Sample rate handling in the encode path +// - Packet delivery through pion's internal transport +func TestWebRTCPipeline_TestToneQuality(t *testing.T) { + const ( + toneFreq = 440.0 + toneSampleRate = 24000 // matches sendTestTone + toneDuration = 1 // seconds + toneAmplitude = 16000 + toneNumSamples = toneSampleRate * toneDuration + ) + + // Generate test tone (same as sendTestTone in realtime.go) + pcm := make([]byte, toneNumSamples*2) + for i := 0; i < toneNumSamples; i++ { + sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate))) + binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample)) + } + + // Encode to Opus frames (same path as SendAudio) + enc, err := NewOpusEncoder() + if err != nil { + t.Fatalf("NewOpusEncoder: %v", err) + } + defer enc.Close() + + opusFrames, err := enc.Encode(pcm, toneSampleRate) + if err != nil { + t.Fatalf("Encode: %v", err) + } + if len(opusFrames) == 0 { + t.Fatal("no Opus frames produced") + } + t.Logf("Encoded %d Opus frames from %d PCM samples at %dHz", len(opusFrames), toneNumSamples, toneSampleRate) + + // --- Create sender PeerConnection --- + senderME := &webrtc.MediaEngine{} + if err := senderME.RegisterDefaultCodecs(); err != nil { + t.Fatalf("sender RegisterDefaultCodecs: %v", err) + } + senderAPI := webrtc.NewAPI(webrtc.WithMediaEngine(senderME)) + senderPC, err := senderAPI.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + t.Fatalf("sender NewPeerConnection: %v", err) + } + defer senderPC.Close() + + audioTrack, err := webrtc.NewTrackLocalStaticRTP( + webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeOpus, + ClockRate: 48000, + Channels: 2, + }, + "audio", "test", + ) + if err != nil { + t.Fatalf("NewTrackLocalStaticRTP: %v", err) + } + + rtpSender, err := senderPC.AddTrack(audioTrack) + if err != nil { + 
t.Fatalf("AddTrack: %v", err) + } + // Drain RTCP + go func() { + buf := make([]byte, 1500) + for { + if _, _, err := rtpSender.Read(buf); err != nil { + return + } + } + }() + + // --- Create receiver PeerConnection --- + receiverME := &webrtc.MediaEngine{} + if err := receiverME.RegisterDefaultCodecs(); err != nil { + t.Fatalf("receiver RegisterDefaultCodecs: %v", err) + } + receiverAPI := webrtc.NewAPI(webrtc.WithMediaEngine(receiverME)) + receiverPC, err := receiverAPI.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + t.Fatalf("receiver NewPeerConnection: %v", err) + } + defer receiverPC.Close() + + // Collect received RTP payloads (Opus frames) + type receivedPacket struct { + seqNum uint16 + timestamp uint32 + marker bool + payload []byte + } + var ( + receivedMu sync.Mutex + receivedPackets []receivedPacket + trackDone = make(chan struct{}) + ) + + receiverPC.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + defer close(trackDone) + for { + pkt, _, err := track.ReadRTP() + if err != nil { + return + } + payload := make([]byte, len(pkt.Payload)) + copy(payload, pkt.Payload) + receivedMu.Lock() + receivedPackets = append(receivedPackets, receivedPacket{ + seqNum: pkt.Header.SequenceNumber, + timestamp: pkt.Header.Timestamp, + marker: pkt.Header.Marker, + payload: payload, + }) + receivedMu.Unlock() + } + }) + + // --- Exchange SDP --- + offer, err := senderPC.CreateOffer(nil) + if err != nil { + t.Fatalf("CreateOffer: %v", err) + } + if err := senderPC.SetLocalDescription(offer); err != nil { + t.Fatalf("sender SetLocalDescription: %v", err) + } + senderGatherDone := webrtc.GatheringCompletePromise(senderPC) + select { + case <-senderGatherDone: + case <-time.After(5 * time.Second): + t.Fatal("sender ICE gathering timeout") + } + + if err := receiverPC.SetRemoteDescription(*senderPC.LocalDescription()); err != nil { + t.Fatalf("receiver SetRemoteDescription: %v", err) + } + answer, err := receiverPC.CreateAnswer(nil) + if 
err != nil { + t.Fatalf("CreateAnswer: %v", err) + } + if err := receiverPC.SetLocalDescription(answer); err != nil { + t.Fatalf("receiver SetLocalDescription: %v", err) + } + receiverGatherDone := webrtc.GatheringCompletePromise(receiverPC) + select { + case <-receiverGatherDone: + case <-time.After(5 * time.Second): + t.Fatal("receiver ICE gathering timeout") + } + + if err := senderPC.SetRemoteDescription(*receiverPC.LocalDescription()); err != nil { + t.Fatalf("sender SetRemoteDescription: %v", err) + } + + // Wait for connection + connected := make(chan struct{}) + senderPC.OnConnectionStateChange(func(s webrtc.PeerConnectionState) { + if s == webrtc.PeerConnectionStateConnected { + select { + case <-connected: + default: + close(connected) + } + } + }) + select { + case <-connected: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for WebRTC connection") + } + + // --- Send test tone via RTP (same logic as SendAudio) --- + const samplesPerFrame = 960 + seqNum := uint16(rand.UintN(65536)) + timestamp := rand.Uint32() + marker := true + + ticker := time.NewTicker(20 * time.Millisecond) + defer ticker.Stop() + + for i, frame := range opusFrames { + pkt := &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: marker, + SequenceNumber: seqNum, + Timestamp: timestamp, + }, + Payload: frame, + } + seqNum++ + timestamp += samplesPerFrame + marker = false + + if err := audioTrack.WriteRTP(pkt); err != nil { + t.Fatalf("WriteRTP frame %d: %v", i, err) + } + if i < len(opusFrames)-1 { + <-ticker.C + } + } + + // Wait for packets to arrive (give extra time for jitter buffer) + time.Sleep(500 * time.Millisecond) + + // Close sender to trigger track end on receiver + senderPC.Close() + + // Wait for track reader to finish (with timeout) + select { + case <-trackDone: + case <-time.After(2 * time.Second): + // Track reader may not exit cleanly on all platforms + } + + // --- Decode received Opus frames --- + receivedMu.Lock() + pkts := 
make([]receivedPacket, len(receivedPackets)) + copy(pkts, receivedPackets) + receivedMu.Unlock() + + if len(pkts) == 0 { + t.Fatal("no RTP packets received") + } + + dec, err := NewOpusDecoder() + if err != nil { + t.Fatalf("NewOpusDecoder: %v", err) + } + defer dec.Close() + + var allDecoded []int16 + decodeErrors := 0 + for _, pkt := range pkts { + samples, err := dec.Decode(pkt.payload) + if err != nil { + decodeErrors++ + continue + } + allDecoded = append(allDecoded, samples...) + } + + if len(allDecoded) == 0 { + t.Fatal("no decoded samples") + } + + // --- Analyse RTP packet delivery --- + frameLoss := len(opusFrames) - len(pkts) + seqGaps := 0 + for i := 1; i < len(pkts); i++ { + expected := pkts[i-1].seqNum + 1 + if pkts[i].seqNum != expected { + seqGaps++ + } + } + markerCount := 0 + for _, pkt := range pkts { + if pkt.marker { + markerCount++ + } + } + + t.Log("── RTP Delivery ──") + t.Logf(" Frames sent: %d", len(opusFrames)) + t.Logf(" Packets recv: %d", len(pkts)) + t.Logf(" Frame loss: %d", frameLoss) + t.Logf(" Sequence gaps: %d", seqGaps) + t.Logf(" Marker packets: %d (expect 1)", markerCount) + t.Logf(" Decode errors: %d", decodeErrors) + + // --- Audio quality metrics --- + // Skip codec warmup (first 100ms at 48kHz = 4800 samples) + skip := 48000 * 100 / 1000 + if skip > len(allDecoded)/2 { + skip = len(allDecoded) / 4 + } + tail := allDecoded[skip:] + + rms := computeRMS(tail) + freq := estimateFrequency(tail, 48000) + thd := computeTHD(tail, toneFreq, 48000, 10) + + t.Log("── Audio Quality ──") + t.Logf(" Decoded samples: %d (%.1f ms at 48kHz)", len(allDecoded), float64(len(allDecoded))/48.0) + t.Logf(" RMS level: %.1f", rms) + t.Logf(" Peak frequency: %.0f Hz (expected %.0f Hz)", freq, toneFreq) + t.Logf(" THD (h2-h10): %.1f%%", thd) + + // --- Assertions --- + if frameLoss > 0 { + t.Errorf("lost %d frames in localhost transport", frameLoss) + } + if seqGaps > 0 { + t.Errorf("detected %d sequence number gaps", seqGaps) + } + if markerCount != 
1 { + t.Errorf("expected exactly 1 marker packet (first packet), got %d", markerCount) + } + if rms < 50 { + t.Errorf("RMS=%.1f is too low; signal appears silent or severely attenuated", rms) + } + freqDelta := math.Abs(freq - toneFreq) + if freqDelta > 20 { + t.Errorf("peak frequency %.0f Hz deviates from expected %.0f Hz by %.0f Hz", freq, toneFreq, freqDelta) + } + if thd > 50 { + t.Errorf("THD=%.1f%% is too high; signal is severely distorted", thd) + } + + // Log a summary line for quick scanning + result := "PASS" + issues := []string{} + if frameLoss > 0 { + issues = append(issues, fmt.Sprintf("%d frames lost", frameLoss)) + } + if freqDelta > 20 { + issues = append(issues, fmt.Sprintf("freq off by %.0f Hz", freqDelta)) + } + if thd > 50 { + issues = append(issues, fmt.Sprintf("THD %.1f%%", thd)) + } + if rms < 50 { + issues = append(issues, "silent") + } + if len(issues) > 0 { + result = "FAIL: " + fmt.Sprintf("%v", issues) + } + t.Logf("── Summary: %s ──", result) +} diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index 415e75b18f62..b51c57181d48 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -3,8 +3,10 @@ package openai import ( "context" "encoding/base64" + "encoding/binary" "encoding/json" "fmt" + "math" "os" "sync" "time" @@ -40,23 +42,17 @@ const ( maxAudioBufferSize = 100 * 1024 * 1024 // Maximum WebSocket message size in bytes (10MB) to prevent DoS attacks maxWebSocketMessageSize = 10 * 1024 * 1024 + + defaultInstructions = "You are a helpful voice assistant. " + + "Your responses will be spoken aloud using text-to-speech, so keep them concise and conversational. " + + "Do not use markdown formatting, bullet points, numbered lists, code blocks, or special characters. " + + "Speak naturally as you would in a phone conversation. " + + "Avoid parenthetical asides, URLs, and anything that cannot be clearly vocalized." 
) // A model can be "emulated" that is: transcribe audio to text -> feed text to the LLM -> generate audio as result // If the model support instead audio-to-audio, we will use the specific gRPC calls instead -// LockedWebsocket wraps a websocket connection with a mutex for safe concurrent writes -type LockedWebsocket struct { - *websocket.Conn - sync.Mutex -} - -func (l *LockedWebsocket) WriteMessage(messageType int, data []byte) error { - l.Lock() - defer l.Unlock() - return l.Conn.WriteMessage(messageType, data) -} - // Session represents a single WebSocket connection and its state type Session struct { ID string @@ -77,8 +73,48 @@ type Session struct { ModelInterface Model // The pipeline model config or the config for an any-to-any model ModelConfig *config.ModelConfig - InputSampleRate int - MaxOutputTokens types.IntOrInf + InputSampleRate int + OutputSampleRate int + MaxOutputTokens types.IntOrInf + + // Response cancellation: protects activeResponseCancel/activeResponseDone + responseMu sync.Mutex + activeResponseCancel context.CancelFunc + activeResponseDone chan struct{} +} + +// cancelActiveResponse cancels any in-flight response and waits for its +// goroutine to exit. This ensures we never have overlapping responses and +// that interrupted responses are fully cleaned up before starting a new one. +func (s *Session) cancelActiveResponse() { + s.responseMu.Lock() + cancel := s.activeResponseCancel + done := s.activeResponseDone + s.responseMu.Unlock() + + if cancel != nil { + cancel() + } + if done != nil { + <-done + } +} + +// startResponse cancels any active response and returns a new context for +// the replacement response. The caller MUST close the returned done channel +// when the response goroutine exits. 
+func (s *Session) startResponse(parent context.Context) (context.Context, chan struct{}) { + s.cancelActiveResponse() + + ctx, cancel := context.WithCancel(parent) + done := make(chan struct{}) + + s.responseMu.Lock() + s.activeResponseCancel = cancel + s.activeResponseDone = done + s.responseMu.Unlock() + + return ctx, done } func (s *Session) FromClient(session *types.SessionUnion) { @@ -187,378 +223,414 @@ func Realtime(application *application.Application) echo.HandlerFunc { func registerRealtime(application *application.Application, model string) func(c *websocket.Conn) { return func(conn *websocket.Conn) { - c := &LockedWebsocket{Conn: conn} - + t := NewWebSocketTransport(conn) evaluator := application.TemplatesEvaluator() - xlog.Debug("Realtime WebSocket connection established", "address", c.RemoteAddr().String(), "model", model) + xlog.Debug("Realtime WebSocket connection established", "address", conn.RemoteAddr().String(), "model", model) + runRealtimeSession(application, t, model, evaluator) + } +} - // TODO: Allow any-to-any model to be specified - cl := application.ModelConfigLoader() - cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(model, application.ApplicationConfig()) - if err != nil { - xlog.Error("failed to load model config", "error", err) - sendError(c, "model_load_error", "Failed to load model config", "", "") - return - } +// runRealtimeSession runs the main event loop for a realtime session. +// It is transport-agnostic and works with both WebSocket and WebRTC. 
+func runRealtimeSession(application *application.Application, t Transport, model string, evaluator *templates.Evaluator) { + // TODO: Allow any-to-any model to be specified + cl := application.ModelConfigLoader() + cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(model, application.ApplicationConfig()) + if err != nil { + xlog.Error("failed to load model config", "error", err) + sendError(t, "model_load_error", "Failed to load model config", "", "") + return + } - if cfg == nil || (cfg.Pipeline.VAD == "" && cfg.Pipeline.Transcription == "" && cfg.Pipeline.TTS == "" && cfg.Pipeline.LLM == "") { - xlog.Error("model is not a pipeline", "model", model) - sendError(c, "invalid_model", "Model is not a pipeline model", "", "") - return - } + if cfg == nil || (cfg.Pipeline.VAD == "" && cfg.Pipeline.Transcription == "" && cfg.Pipeline.TTS == "" && cfg.Pipeline.LLM == "") { + xlog.Error("model is not a pipeline", "model", model) + sendError(t, "invalid_model", "Model is not a pipeline model", "", "") + return + } - sttModel := cfg.Pipeline.Transcription - - sessionID := generateSessionID() - session := &Session{ - ID: sessionID, - TranscriptionOnly: false, - Model: model, - Voice: cfg.TTSConfig.Voice, - ModelConfig: cfg, - TurnDetection: &types.TurnDetectionUnion{ - ServerVad: &types.ServerVad{ - Threshold: 0.5, - PrefixPaddingMs: 300, - SilenceDurationMs: 500, - CreateResponse: true, - }, + sttModel := cfg.Pipeline.Transcription + + sessionID := generateSessionID() + session := &Session{ + ID: sessionID, + TranscriptionOnly: false, + Model: model, + Voice: cfg.TTSConfig.Voice, + Instructions: defaultInstructions, + ModelConfig: cfg, + TurnDetection: &types.TurnDetectionUnion{ + ServerVad: &types.ServerVad{ + Threshold: 0.5, + PrefixPaddingMs: 300, + SilenceDurationMs: 500, + CreateResponse: true, }, - InputAudioTranscription: &types.AudioTranscription{ - Model: sttModel, - }, - Conversations: make(map[string]*Conversation), - InputSampleRate: defaultRemoteSampleRate, 
- } + }, + InputAudioTranscription: &types.AudioTranscription{ + Model: sttModel, + }, + Conversations: make(map[string]*Conversation), + InputSampleRate: defaultRemoteSampleRate, + OutputSampleRate: defaultRemoteSampleRate, + } + + // Create a default conversation + conversationID := generateConversationID() + conversation := &Conversation{ + ID: conversationID, + // TODO: We need to truncate the conversation items when a new item is added and we have run out of space. There are multiple places where items + // can be added so we could use a datastructure here that enforces truncation upon addition + Items: []*types.MessageItemUnion{}, + } + session.Conversations[conversationID] = conversation + session.DefaultConversationID = conversationID + + m, err := newModel( + &cfg.Pipeline, + application.ModelConfigLoader(), + application.ModelLoader(), + application.ApplicationConfig(), + evaluator, + ) + if err != nil { + xlog.Error("failed to load model", "error", err) + sendError(t, "model_load_error", "Failed to load model", "", "") + return + } + session.ModelInterface = m + + // Store the session and notify the transport (for WebRTC audio track handling) + sessionLock.Lock() + sessions[sessionID] = session + sessionLock.Unlock() + + // For WebRTC, inbound audio arrives as Opus (48kHz) and is decoded+resampled + // to localSampleRate in handleIncomingAudioTrack. Set InputSampleRate to + // match so handleVAD doesn't needlessly double-resample. + if _, ok := t.(*WebRTCTransport); ok { + session.InputSampleRate = localSampleRate + } - // Create a default conversation - conversationID := generateConversationID() - conversation := &Conversation{ - ID: conversationID, - // TODO: We need to truncate the conversation items when a new item is added and we have run out of space. 
There are multiple places where items - // can be added so we could use a datastructure here that enforces truncation upon addition - Items: []*types.MessageItemUnion{}, + if sn, ok := t.(interface{ SetSession(*Session) }); ok { + sn.SetSession(session) + } + + sendEvent(t, types.SessionCreatedEvent{ + ServerEventBase: types.ServerEventBase{ + EventID: "event_TODO", + }, + Session: session.ToServer(), + }) + + var ( + msg []byte + wg sync.WaitGroup + done = make(chan struct{}) + ) + + vadServerStarted := false + toggleVAD := func() { + if session.TurnDetection != nil && session.TurnDetection.ServerVad != nil && !vadServerStarted { + xlog.Debug("Starting VAD goroutine...") + done = make(chan struct{}) + wg.Add(1) + go func() { + defer wg.Done() + conversation := session.Conversations[session.DefaultConversationID] + handleVAD(session, conversation, t, done) + }() + vadServerStarted = true + } else if (session.TurnDetection == nil || session.TurnDetection.ServerVad == nil) && vadServerStarted { + xlog.Debug("Stopping VAD goroutine...") + close(done) + vadServerStarted = false } - session.Conversations[conversationID] = conversation - session.DefaultConversationID = conversationID - - m, err := newModel( - &cfg.Pipeline, - application.ModelConfigLoader(), - application.ModelLoader(), - application.ApplicationConfig(), - evaluator, - ) + } + + toggleVAD() + + for { + msg, err = t.ReadEvent() if err != nil { - xlog.Error("failed to load model", "error", err) - sendError(c, "model_load_error", "Failed to load model", "", "") - return + xlog.Error("read error", "error", err) + break } - session.ModelInterface = m - // Store the session - sessionLock.Lock() - sessions[sessionID] = session - sessionLock.Unlock() + // Handle diagnostic events that aren't part of the OpenAI protocol + var rawType struct { + Type string `json:"type"` + } + if json.Unmarshal(msg, &rawType) == nil && rawType.Type == "test_tone" { + xlog.Debug("Generating test tone") + go sendTestTone(t) + 
continue + } - sendEvent(c, types.SessionCreatedEvent{ - ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", - }, - Session: session.ToServer(), - }) + // Parse the incoming message + event, err := types.UnmarshalClientEvent(msg) + if err != nil { + xlog.Error("invalid json", "error", err) + sendError(t, "invalid_json", "Invalid JSON format", "", "") + continue + } - var ( - msg []byte - wg sync.WaitGroup - done = make(chan struct{}) - ) + switch e := event.(type) { + case types.SessionUpdateEvent: + xlog.Debug("recv", "message", string(msg)) + + // Handle transcription session update + if e.Session.Transcription != nil { + if err := updateTransSession( + session, + &e.Session, + application.ModelConfigLoader(), + application.ModelLoader(), + application.ApplicationConfig(), + ); err != nil { + xlog.Error("failed to update session", "error", err) + sendError(t, "session_update_error", "Failed to update session", "", "") + continue + } - vadServerStarted := false - toggleVAD := func() { - if session.TurnDetection.ServerVad != nil && !vadServerStarted { - xlog.Debug("Starting VAD goroutine...") - wg.Add(1) - go func() { - defer wg.Done() - conversation := session.Conversations[session.DefaultConversationID] - handleVAD(session, conversation, c, done) - }() - vadServerStarted = true - } else if session.TurnDetection.ServerVad == nil && vadServerStarted { - xlog.Debug("Stopping VAD goroutine...") + toggleVAD() - go func() { - done <- struct{}{} - }() - vadServerStarted = false + sendEvent(t, types.SessionUpdatedEvent{ + ServerEventBase: types.ServerEventBase{ + EventID: "event_TODO", + }, + Session: session.ToServer(), + }) } - } - toggleVAD() + // Handle realtime session update + if e.Session.Realtime != nil { + if err := updateSession( + session, + &e.Session, + application.ModelConfigLoader(), + application.ModelLoader(), + application.ApplicationConfig(), + evaluator, + ); err != nil { + xlog.Error("failed to update session", "error", err) + 
sendError(t, "session_update_error", "Failed to update session", "", "") + continue + } + + toggleVAD() - for { - if _, msg, err = c.ReadMessage(); err != nil { - xlog.Error("read error", "error", err) - break + sendEvent(t, types.SessionUpdatedEvent{ + ServerEventBase: types.ServerEventBase{ + EventID: "event_TODO", + }, + Session: session.ToServer(), + }) } - // Parse the incoming message - event, err := types.UnmarshalClientEvent(msg) - if err != nil { - xlog.Error("invalid json", "error", err) - sendError(c, "invalid_json", "Invalid JSON format", "", "") + case types.InputAudioBufferAppendEvent: + // Handle 'input_audio_buffer.append' + if e.Audio == "" { + xlog.Error("Audio data is missing in 'input_audio_buffer.append'") + sendError(t, "missing_audio_data", "Audio data is missing", "", "") continue } - switch e := event.(type) { - case types.SessionUpdateEvent: - xlog.Debug("recv", "message", string(msg)) - - // Handle transcription session update - if e.Session.Transcription != nil { - if err := updateTransSession( - session, - &e.Session, - application.ModelConfigLoader(), - application.ModelLoader(), - application.ApplicationConfig(), - ); err != nil { - xlog.Error("failed to update session", "error", err) - sendError(c, "session_update_error", "Failed to update session", "", "") - continue - } - - toggleVAD() + // Decode base64 audio data + decodedAudio, err := base64.StdEncoding.DecodeString(e.Audio) + if err != nil { + xlog.Error("failed to decode audio data", "error", err) + sendError(t, "invalid_audio_data", "Failed to decode audio data", "", "") + continue + } - sendEvent(c, types.SessionUpdatedEvent{ - ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", - }, - Session: session.ToServer(), - }) - } + // Check buffer size limits before appending + session.AudioBufferLock.Lock() + newSize := len(session.InputAudioBuffer) + len(decodedAudio) + if newSize > maxAudioBufferSize { + session.AudioBufferLock.Unlock() + xlog.Error("audio buffer 
size limit exceeded", "current_size", len(session.InputAudioBuffer), "incoming_size", len(decodedAudio), "limit", maxAudioBufferSize) + sendError(t, "buffer_size_exceeded", fmt.Sprintf("Audio buffer size limit exceeded (max %d bytes)", maxAudioBufferSize), "", "") + continue + } - // Handle realtime session update - if e.Session.Realtime != nil { - if err := updateSession( - session, - &e.Session, - application.ModelConfigLoader(), - application.ModelLoader(), - application.ApplicationConfig(), - evaluator, - ); err != nil { - xlog.Error("failed to update session", "error", err) - sendError(c, "session_update_error", "Failed to update session", "", "") - continue - } + // Append to InputAudioBuffer + session.InputAudioBuffer = append(session.InputAudioBuffer, decodedAudio...) + session.AudioBufferLock.Unlock() - toggleVAD() + case types.InputAudioBufferCommitEvent: + xlog.Debug("recv", "message", string(msg)) - sendEvent(c, types.SessionUpdatedEvent{ - ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", - }, - Session: session.ToServer(), - }) - } + sessionLock.Lock() + isServerVAD := session.TurnDetection != nil && session.TurnDetection.ServerVad != nil + sessionLock.Unlock() - case types.InputAudioBufferAppendEvent: - // Handle 'input_audio_buffer.append' - if e.Audio == "" { - xlog.Error("Audio data is missing in 'input_audio_buffer.append'") - sendError(c, "missing_audio_data", "Audio data is missing", "", "") - continue - } + // TODO: At the least need to check locking and timer state in the VAD Go routine before allowing this + if isServerVAD { + sendNotImplemented(t, "input_audio_buffer.commit in conjunction with VAD") + continue + } - // Decode base64 audio data - decodedAudio, err := base64.StdEncoding.DecodeString(e.Audio) - if err != nil { - xlog.Error("failed to decode audio data", "error", err) - sendError(c, "invalid_audio_data", "Failed to decode audio data", "", "") - continue - } + session.AudioBufferLock.Lock() + allAudio := 
make([]byte, len(session.InputAudioBuffer)) + copy(allAudio, session.InputAudioBuffer) + session.InputAudioBuffer = nil + session.AudioBufferLock.Unlock() - // Check buffer size limits before appending - session.AudioBufferLock.Lock() - newSize := len(session.InputAudioBuffer) + len(decodedAudio) - if newSize > maxAudioBufferSize { - session.AudioBufferLock.Unlock() - xlog.Error("audio buffer size limit exceeded", "current_size", len(session.InputAudioBuffer), "incoming_size", len(decodedAudio), "limit", maxAudioBufferSize) - sendError(c, "buffer_size_exceeded", fmt.Sprintf("Audio buffer size limit exceeded (max %d bytes)", maxAudioBufferSize), "", "") - continue - } + sendEvent(t, types.InputAudioBufferCommittedEvent{ + ServerEventBase: types.ServerEventBase{}, + ItemID: generateItemID(), + }) - // Append to InputAudioBuffer - session.InputAudioBuffer = append(session.InputAudioBuffer, decodedAudio...) - session.AudioBufferLock.Unlock() + respCtx, respDone := session.startResponse(context.Background()) + go func() { + defer close(respDone) + commitUtterance(respCtx, allAudio, session, conversation, t) + }() + + case types.ConversationItemCreateEvent: + xlog.Debug("recv", "message", string(msg)) + // Add the item to the conversation + item := e.Item + // Ensure IDs are present + if item.User != nil && item.User.ID == "" { + item.User.ID = generateItemID() + } + if item.Assistant != nil && item.Assistant.ID == "" { + item.Assistant.ID = generateItemID() + } + if item.System != nil && item.System.ID == "" { + item.System.ID = generateItemID() + } + if item.FunctionCall != nil && item.FunctionCall.ID == "" { + item.FunctionCall.ID = generateItemID() + } + if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" { + item.FunctionCallOutput.ID = generateItemID() + } - case types.InputAudioBufferCommitEvent: - xlog.Debug("recv", "message", string(msg)) + conversation.Lock.Lock() + conversation.Items = append(conversation.Items, &item) + 
conversation.Lock.Unlock() - sessionLock.Lock() - isServerVAD := session.TurnDetection.ServerVad != nil - sessionLock.Unlock() + sendEvent(t, types.ConversationItemAddedEvent{ + ServerEventBase: types.ServerEventBase{ + EventID: e.EventID, + }, + PreviousItemID: e.PreviousItemID, + Item: item, + }) - // TODO: At the least need to check locking and timer state in the VAD Go routine before allowing this - if isServerVAD { - sendNotImplemented(c, "input_audio_buffer.commit in conjunction with VAD") - continue - } + case types.ConversationItemDeleteEvent: + sendError(t, "not_implemented", "Deleting items not implemented", "", "event_TODO") - session.AudioBufferLock.Lock() - allAudio := make([]byte, len(session.InputAudioBuffer)) - copy(allAudio, session.InputAudioBuffer) - session.InputAudioBuffer = nil - session.AudioBufferLock.Unlock() + case types.ConversationItemRetrieveEvent: + xlog.Debug("recv", "message", string(msg)) - go commitUtterance(context.TODO(), allAudio, session, conversation, c) + if e.ItemID == "" { + sendError(t, "invalid_item_id", "Need item_id, but none specified", "", "event_TODO") + continue + } - case types.ConversationItemCreateEvent: - xlog.Debug("recv", "message", string(msg)) - // Add the item to the conversation - item := e.Item - // Ensure IDs are present - if item.User != nil && item.User.ID == "" { - item.User.ID = generateItemID() - } - if item.Assistant != nil && item.Assistant.ID == "" { - item.Assistant.ID = generateItemID() - } - if item.System != nil && item.System.ID == "" { - item.System.ID = generateItemID() - } - if item.FunctionCall != nil && item.FunctionCall.ID == "" { - item.FunctionCall.ID = generateItemID() + conversation.Lock.Lock() + var retrievedItem types.MessageItemUnion + for _, item := range conversation.Items { + // We need to check ID in the union + var id string + if item.System != nil { + id = item.System.ID + } else if item.User != nil { + id = item.User.ID + } else if item.Assistant != nil { + id = 
item.Assistant.ID + } else if item.FunctionCall != nil { + id = item.FunctionCall.ID + } else if item.FunctionCallOutput != nil { + id = item.FunctionCallOutput.ID } - if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" { - item.FunctionCallOutput.ID = generateItemID() - } - - conversation.Lock.Lock() - conversation.Items = append(conversation.Items, &item) - conversation.Lock.Unlock() - - sendEvent(c, types.ConversationItemAddedEvent{ - ServerEventBase: types.ServerEventBase{ - EventID: e.EventID, - }, - PreviousItemID: e.PreviousItemID, - Item: item, - }) - case types.ConversationItemDeleteEvent: - sendError(c, "not_implemented", "Deleting items not implemented", "", "event_TODO") + if id == e.ItemID { + retrievedItem = *item + break + } + } + conversation.Lock.Unlock() - case types.ConversationItemRetrieveEvent: - xlog.Debug("recv", "message", string(msg)) + sendEvent(t, types.ConversationItemRetrievedEvent{ + ServerEventBase: types.ServerEventBase{ + EventID: "event_TODO", + }, + Item: retrievedItem, + }) - if e.ItemID == "" { - sendError(c, "invalid_item_id", "Need item_id, but none specified", "", "event_TODO") - continue - } + case types.ResponseCreateEvent: + xlog.Debug("recv", "message", string(msg)) + // Handle optional items to add to context + if len(e.Response.Input) > 0 { conversation.Lock.Lock() - var retrievedItem types.MessageItemUnion - for _, item := range conversation.Items { - // We need to check ID in the union - var id string - if item.System != nil { - id = item.System.ID - } else if item.User != nil { - id = item.User.ID - } else if item.Assistant != nil { - id = item.Assistant.ID - } else if item.FunctionCall != nil { - id = item.FunctionCall.ID - } else if item.FunctionCallOutput != nil { - id = item.FunctionCallOutput.ID + for _, item := range e.Response.Input { + // Ensure IDs are present + if item.User != nil && item.User.ID == "" { + item.User.ID = generateItemID() } - - if id == e.ItemID { - retrievedItem = *item - 
break + if item.Assistant != nil && item.Assistant.ID == "" { + item.Assistant.ID = generateItemID() } - } - conversation.Lock.Unlock() - - sendEvent(c, types.ConversationItemRetrievedEvent{ - ServerEventBase: types.ServerEventBase{ - EventID: "event_TODO", - }, - Item: retrievedItem, - }) - - case types.ResponseCreateEvent: - xlog.Debug("recv", "message", string(msg)) - - // Handle optional items to add to context - if len(e.Response.Input) > 0 { - conversation.Lock.Lock() - for _, item := range e.Response.Input { - // Ensure IDs are present - if item.User != nil && item.User.ID == "" { - item.User.ID = generateItemID() - } - if item.Assistant != nil && item.Assistant.ID == "" { - item.Assistant.ID = generateItemID() - } - if item.System != nil && item.System.ID == "" { - item.System.ID = generateItemID() - } - if item.FunctionCall != nil && item.FunctionCall.ID == "" { - item.FunctionCall.ID = generateItemID() - } - if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" { - item.FunctionCallOutput.ID = generateItemID() - } - - conversation.Items = append(conversation.Items, &item) + if item.System != nil && item.System.ID == "" { + item.System.ID = generateItemID() + } + if item.FunctionCall != nil && item.FunctionCall.ID == "" { + item.FunctionCall.ID = generateItemID() + } + if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" { + item.FunctionCallOutput.ID = generateItemID() } - conversation.Lock.Unlock() - } - go triggerResponse(session, conversation, c, &e.Response) + conversation.Items = append(conversation.Items, &item) + } + conversation.Lock.Unlock() + } - case types.ResponseCancelEvent: - xlog.Debug("recv", "message", string(msg)) + respCtx, respDone := session.startResponse(context.Background()) + go func() { + defer close(respDone) + triggerResponse(respCtx, session, conversation, t, &e.Response) + }() - // Handle cancellation of ongoing responses - // Implement cancellation logic as needed - sendNotImplemented(c, 
"response.cancel") + case types.ResponseCancelEvent: + xlog.Debug("recv", "message", string(msg)) + session.cancelActiveResponse() - default: - xlog.Error("unknown message type") - // sendError(c, "unknown_message_type", fmt.Sprintf("Unknown message type: %s", incomingMsg.Type), "", "") - } + default: + xlog.Error("unknown message type") + // sendError(t, "unknown_message_type", fmt.Sprintf("Unknown message type: %s", incomingMsg.Type), "", "") } + } - // Close the done channel to signal goroutines to exit - close(done) - wg.Wait() + // Cancel any in-flight response before tearing down + session.cancelActiveResponse() - // Remove the session from the sessions map - sessionLock.Lock() - delete(sessions, sessionID) - sessionLock.Unlock() + // Signal any running VAD goroutine to exit. + if vadServerStarted { + close(done) } + wg.Wait() + + // Remove the session from the sessions map + sessionLock.Lock() + delete(sessions, sessionID) + sessionLock.Unlock() } -// Helper function to send events to the client -func sendEvent(c *LockedWebsocket, event types.ServerEvent) { - eventBytes, err := json.Marshal(event) - if err != nil { - xlog.Error("failed to marshal event", "error", err) - return - } - if err = c.WriteMessage(websocket.TextMessage, eventBytes); err != nil { +// sendEvent sends a server event via the transport, logging any errors. +func sendEvent(t Transport, event types.ServerEvent) { + if err := t.SendEvent(event); err != nil { xlog.Error("write error", "error", err) } } -// Helper function to send errors to the client -func sendError(c *LockedWebsocket, code, message, param, eventID string) { +// sendError sends an error event to the client. 
+func sendError(t Transport, code, message, param, eventID string) { errorEvent := types.ErrorEvent{ ServerEventBase: types.ServerEventBase{ EventID: eventID, @@ -572,11 +644,35 @@ func sendError(c *LockedWebsocket, code, message, param, eventID string) { }, } - sendEvent(c, errorEvent) + sendEvent(t, errorEvent) } -func sendNotImplemented(c *LockedWebsocket, message string) { - sendError(c, "not_implemented", message, "", "event_TODO") +func sendNotImplemented(t Transport, message string) { + sendError(t, "not_implemented", message, "", "event_TODO") +} + +// sendTestTone generates a 1-second 440 Hz sine wave and sends it through +// the transport's audio path. This exercises the full Opus encode → RTP → +// browser decode pipeline without involving TTS. +func sendTestTone(t Transport) { + const ( + freq = 440.0 + sampleRate = 24000 + duration = 1 // seconds + amplitude = 16000 + numSamples = sampleRate * duration + ) + + pcm := make([]byte, numSamples*2) // 16-bit samples = 2 bytes each + for i := 0; i < numSamples; i++ { + sample := int16(amplitude * math.Sin(2*math.Pi*freq*float64(i)/sampleRate)) + binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample)) + } + + xlog.Debug("Sending test tone", "samples", numSamples, "sample_rate", sampleRate, "freq", freq) + if err := t.SendAudio(context.Background(), pcm, sampleRate); err != nil { + xlog.Error("test tone send failed", "error", err) + } } func updateTransSession(session *Session, update *types.SessionUnion, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error { @@ -616,7 +712,7 @@ func updateTransSession(session *Session, update *types.SessionUnion, cl *config trCur.Prompt = trUpd.Prompt } - if update.Transcription.Audio.Input.TurnDetection != nil { + if update.Transcription.Audio.Input.TurnDetectionSet { session.TurnDetection = update.Transcription.Audio.Input.TurnDetection } @@ -675,7 +771,7 @@ func updateSession(session *Session, update *types.SessionUnion, cl 
*config.Mode session.ModelInterface = m } - if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.TurnDetection != nil { + if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.TurnDetectionSet { session.TurnDetection = rt.Audio.Input.TurnDetection } @@ -685,6 +781,12 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode } } + if rt.Audio != nil && rt.Audio.Output != nil && rt.Audio.Output.Format != nil && rt.Audio.Output.Format.PCM != nil { + if rt.Audio.Output.Format.PCM.Rate > 0 { + session.OutputSampleRate = rt.Audio.Output.Format.PCM.Rate + } + } + if rt.Instructions != "" { session.Instructions = rt.Instructions } @@ -705,7 +807,7 @@ func updateSession(session *Session, update *types.SessionUnion, cl *config.Mode // handleVAD is a goroutine that listens for audio data from the client, // runs VAD on the audio data, and commits utterances to the conversation -func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done chan struct{}) { +func handleVAD(session *Session, conv *Conversation, t Transport, done chan struct{}) { vadContext, cancel := context.WithCancel(context.Background()) go func() { <-done @@ -713,7 +815,7 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch }() silenceThreshold := 0.5 // Default 500ms - if session.TurnDetection.ServerVad != nil { + if session.TurnDetection != nil && session.TurnDetection.ServerVad != nil { silenceThreshold = float64(session.TurnDetection.ServerVad.SilenceDurationMs) / 1000 } @@ -734,7 +836,7 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch session.AudioBufferLock.Unlock() aints := sound.BytesToInt16sLE(allAudio) - if len(aints) == 0 || len(aints) < int(silenceThreshold)*session.InputSampleRate { + if len(aints) == 0 || len(aints) < int(silenceThreshold*float64(session.InputSampleRate)) { continue } @@ -748,7 +850,7 @@ func handleVAD(session *Session, conv *Conversation, c 
*LockedWebsocket, done ch continue } xlog.Error("failed to process audio", "error", err) - sendError(c, "processing_error", "Failed to process audio: "+err.Error(), "", "") + sendError(t, "processing_error", "Failed to process audio: "+err.Error(), "", "") continue } @@ -760,21 +862,17 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch session.InputAudioBuffer = nil session.AudioBufferLock.Unlock() - // NOTE: OpenAI doesn't send this message unless the client requests it - // xlog.Debug("Detected silence for a while, clearing audio buffer") - // sendEvent(c, types.InputAudioBufferClearedEvent{ - // ServerEventBase: types.ServerEventBase{ - // EventID: "event_TODO", - // }, - // }) - continue } else if len(segments) == 0 { continue } if !speechStarted { - sendEvent(c, types.InputAudioBufferSpeechStartedEvent{ + // Barge-in: cancel any in-flight response so we stop + // sending audio and don't keep the interrupted reply in history. + session.cancelActiveResponse() + + sendEvent(t, types.InputAudioBufferSpeechStartedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, @@ -795,7 +893,7 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch session.InputAudioBuffer = nil session.AudioBufferLock.Unlock() - sendEvent(c, types.InputAudioBufferSpeechStoppedEvent{ + sendEvent(t, types.InputAudioBufferSpeechStoppedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, @@ -803,7 +901,7 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch }) speechStarted = false - sendEvent(c, types.InputAudioBufferCommittedEvent{ + sendEvent(t, types.InputAudioBufferCommittedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, @@ -813,13 +911,17 @@ func handleVAD(session *Session, conv *Conversation, c *LockedWebsocket, done ch abytes := sound.Int16toBytesLE(aints) // TODO: Remove prefix silence that is is over 
TurnDetectionParams.PrefixPaddingMs - go commitUtterance(vadContext, abytes, session, conv, c) + respCtx, respDone := session.startResponse(vadContext) + go func() { + defer close(respDone) + commitUtterance(respCtx, abytes, session, conv, t) + }() } } } } -func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, c *LockedWebsocket) { +func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, t Transport) { if len(utt) == 0 { return } @@ -851,15 +953,15 @@ func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Co if session.InputAudioTranscription != nil { tr, err := session.ModelInterface.Transcribe(ctx, f.Name(), session.InputAudioTranscription.Language, false, false, session.InputAudioTranscription.Prompt) if err != nil { - sendError(c, "transcription_failed", err.Error(), "", "event_TODO") + sendError(t, "transcription_failed", err.Error(), "", "event_TODO") return } else if tr == nil { - sendError(c, "transcription_failed", "trancribe result is nil", "", "event_TODO") + sendError(t, "transcription_failed", "trancribe result is nil", "", "event_TODO") return } transcript = tr.Text - sendEvent(c, types.ConversationItemInputAudioTranscriptionCompletedEvent{ + sendEvent(t, types.ConversationItemInputAudioTranscriptionCompletedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, @@ -871,12 +973,12 @@ func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Co Transcript: transcript, }) } else { - sendNotImplemented(c, "any-to-any models") + sendNotImplemented(t, "any-to-any models") return } if !session.TranscriptionOnly { - generateResponse(session, utt, transcript, conv, c, websocket.TextMessage) + generateResponse(ctx, session, utt, transcript, conv, t) } } @@ -901,7 +1003,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]schema.VADS } // Function to generate a response based on the conversation -func 
generateResponse(session *Session, utt []byte, transcript string, conv *Conversation, c *LockedWebsocket, mt int) { +func generateResponse(ctx context.Context, session *Session, utt []byte, transcript string, conv *Conversation, t Transport) { xlog.Debug("Generating realtime response...") // Create user message item @@ -922,14 +1024,14 @@ func generateResponse(session *Session, utt []byte, transcript string, conv *Con conv.Items = append(conv.Items, &item) conv.Lock.Unlock() - sendEvent(c, types.ConversationItemAddedEvent{ + sendEvent(t, types.ConversationItemAddedEvent{ Item: item, }) - triggerResponse(session, conv, c, nil) + triggerResponse(ctx, session, conv, t, nil) } -func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, overrides *types.ResponseCreateParams) { +func triggerResponse(ctx context.Context, session *Session, conv *Conversation, t Transport, overrides *types.ResponseCreateParams) { config := session.ModelInterface.PredictConfig() // Default values @@ -1077,7 +1179,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o } responseID := generateUniqueID() - sendEvent(c, types.ResponseCreatedEvent{ + sendEvent(t, types.ResponseCreatedEvent{ ServerEventBase: types.ServerEventBase{}, Response: types.Response{ ID: responseID, @@ -1086,15 +1188,29 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o }, }) - predFunc, err := session.ModelInterface.Predict(context.TODO(), conversationHistory, images, nil, nil, nil, tools, toolChoice, nil, nil, nil) + predFunc, err := session.ModelInterface.Predict(ctx, conversationHistory, images, nil, nil, nil, tools, toolChoice, nil, nil, nil) if err != nil { - sendError(c, "inference_failed", fmt.Sprintf("backend error: %v", err), "", "") // item.Assistant.ID is unknown here + sendError(t, "inference_failed", fmt.Sprintf("backend error: %v", err), "", "") // item.Assistant.ID is unknown here return } pred, err := predFunc() if err != 
nil { - sendError(c, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", "") + sendError(t, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", "") + return + } + + // Check for cancellation after LLM inference (barge-in may have fired) + if ctx.Err() != nil { + xlog.Debug("Response cancelled after LLM inference (barge-in)") + sendEvent(t, types.ResponseDoneEvent{ + ServerEventBase: types.ServerEventBase{}, + Response: types.Response{ + ID: responseID, + Object: "realtime.response", + Status: types.ResponseStatusCancelled, + }, + }) return } @@ -1194,14 +1310,14 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o conv.Items = append(conv.Items, &item) conv.Lock.Unlock() - sendEvent(c, types.ResponseOutputItemAddedEvent{ + sendEvent(t, types.ResponseOutputItemAddedEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, OutputIndex: 0, Item: item, }) - sendEvent(c, types.ResponseContentPartAddedEvent{ + sendEvent(t, types.ResponseContentPartAddedEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, @@ -1210,15 +1326,54 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o Part: item.Assistant.Content[0], }) - audioFilePath, res, err := session.ModelInterface.TTS(context.TODO(), finalSpeech, session.Voice, session.InputAudioTranscription.Language) + // removeItemFromConv removes the last occurrence of an item with + // the given assistant ID from conversation history. + removeItemFromConv := func(assistantID string) { + conv.Lock.Lock() + for i := len(conv.Items) - 1; i >= 0; i-- { + if conv.Items[i].Assistant != nil && conv.Items[i].Assistant.ID == assistantID { + conv.Items = append(conv.Items[:i], conv.Items[i+1:]...) + break + } + } + conv.Lock.Unlock() + } + + // sendCancelledResponse emits the cancelled status and cleans up the + // assistant item so the interrupted reply is not in chat history. 
+ sendCancelledResponse := func() { + removeItemFromConv(item.Assistant.ID) + sendEvent(t, types.ResponseDoneEvent{ + ServerEventBase: types.ServerEventBase{}, + Response: types.Response{ + ID: responseID, + Object: "realtime.response", + Status: types.ResponseStatusCancelled, + }, + }) + } + + // Check for cancellation before TTS + if ctx.Err() != nil { + xlog.Debug("Response cancelled before TTS (barge-in)") + sendCancelledResponse() + return + } + + audioFilePath, res, err := session.ModelInterface.TTS(ctx, finalSpeech, session.Voice, session.InputAudioTranscription.Language) if err != nil { + if ctx.Err() != nil { + xlog.Debug("TTS cancelled (barge-in)") + sendCancelledResponse() + return + } xlog.Error("TTS failed", "error", err) - sendError(c, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID) + sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID) return } if !res.Success { xlog.Error("TTS failed", "message", res.Message) - sendError(c, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID) + sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID) return } defer os.Remove(audioFilePath) @@ -1226,21 +1381,41 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o audioBytes, err := os.ReadFile(audioFilePath) if err != nil { xlog.Error("failed to read TTS file", "error", err) - sendError(c, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID) + sendError(t, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID) return } - // Strip WAV header (44 bytes) to get raw PCM data - // The OpenAI Realtime API expects raw PCM, not WAV files - const wavHeaderSize = 44 - pcmData := audioBytes - if len(audioBytes) > wavHeaderSize { - pcmData = audioBytes[wavHeaderSize:] + // Parse WAV header to get raw PCM and the actual sample 
rate from the TTS backend. + pcmData, ttsSampleRate := laudio.ParseWAV(audioBytes) + if ttsSampleRate == 0 { + ttsSampleRate = localSampleRate + } + xlog.Debug("TTS audio parsed", "raw_bytes", len(audioBytes), "pcm_bytes", len(pcmData), "sample_rate", ttsSampleRate) + + // SendAudio (WebRTC) passes PCM at the TTS sample rate directly to the + // Opus encoder, which resamples to 48kHz internally. This avoids a + // lossy intermediate resample through 16kHz. + if err := t.SendAudio(ctx, pcmData, ttsSampleRate); err != nil { + if ctx.Err() != nil { + xlog.Debug("Audio playback cancelled (barge-in)") + sendCancelledResponse() + return + } + xlog.Error("failed to send audio via transport", "error", err) } - audioString := base64.StdEncoding.EncodeToString(pcmData) + // The base64 event (used by WebSocket clients) should be at the + // session's output sample rate. This is separate from InputSampleRate + // which tracks inbound audio (e.g. 16kHz for WebRTC). + wsPCM := pcmData + if ttsSampleRate != session.OutputSampleRate { + samples := sound.BytesToInt16sLE(pcmData) + resampled := sound.ResampleInt16(samples, ttsSampleRate, session.OutputSampleRate) + wsPCM = sound.Int16toBytesLE(resampled) + } + audioString := base64.StdEncoding.EncodeToString(wsPCM) - sendEvent(c, types.ResponseOutputAudioTranscriptDeltaEvent{ + sendEvent(t, types.ResponseOutputAudioTranscriptDeltaEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, @@ -1248,7 +1423,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o ContentIndex: 0, Delta: finalSpeech, }) - sendEvent(c, types.ResponseOutputAudioTranscriptDoneEvent{ + sendEvent(t, types.ResponseOutputAudioTranscriptDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, @@ -1257,7 +1432,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o Transcript: finalSpeech, }) - sendEvent(c, 
types.ResponseOutputAudioDeltaEvent{ + sendEvent(t, types.ResponseOutputAudioDeltaEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, @@ -1265,7 +1440,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o ContentIndex: 0, Delta: audioString, }) - sendEvent(c, types.ResponseOutputAudioDoneEvent{ + sendEvent(t, types.ResponseOutputAudioDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, @@ -1273,7 +1448,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o ContentIndex: 0, }) - sendEvent(c, types.ResponseContentPartDoneEvent{ + sendEvent(t, types.ResponseContentPartDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, @@ -1287,7 +1462,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o item.Assistant.Content[0].Audio = audioString conv.Lock.Unlock() - sendEvent(c, types.ResponseOutputItemDoneEvent{ + sendEvent(t, types.ResponseOutputItemDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, OutputIndex: 0, @@ -1321,14 +1496,14 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o outputIndex++ } - sendEvent(c, types.ResponseOutputItemAddedEvent{ + sendEvent(t, types.ResponseOutputItemAddedEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, OutputIndex: outputIndex, Item: fcItem, }) - sendEvent(c, types.ResponseFunctionCallArgumentsDeltaEvent{ + sendEvent(t, types.ResponseFunctionCallArgumentsDeltaEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: toolCallID, @@ -1337,7 +1512,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o Delta: tc.Arguments, }) - sendEvent(c, types.ResponseFunctionCallArgumentsDoneEvent{ + sendEvent(t, types.ResponseFunctionCallArgumentsDoneEvent{ ServerEventBase: 
types.ServerEventBase{}, ResponseID: responseID, ItemID: toolCallID, @@ -1347,7 +1522,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o Name: tc.Name, }) - sendEvent(c, types.ResponseOutputItemDoneEvent{ + sendEvent(t, types.ResponseOutputItemDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, OutputIndex: outputIndex, @@ -1355,7 +1530,7 @@ func triggerResponse(session *Session, conv *Conversation, c *LockedWebsocket, o }) } - sendEvent(c, types.ResponseDoneEvent{ + sendEvent(t, types.ResponseDoneEvent{ ServerEventBase: types.ServerEventBase{}, Response: types.Response{ ID: responseID, diff --git a/core/http/endpoints/openai/realtime_transport.go b/core/http/endpoints/openai/realtime_transport.go new file mode 100644 index 000000000000..5ffcb0ba917e --- /dev/null +++ b/core/http/endpoints/openai/realtime_transport.go @@ -0,0 +1,23 @@ +package openai + +import ( + "context" + + "github.com/mudler/LocalAI/core/http/endpoints/openai/types" +) + +// Transport abstracts event and audio I/O so the same session logic +// can serve both WebSocket and WebRTC connections. +type Transport interface { + // SendEvent marshals and sends a server event to the client. + SendEvent(event types.ServerEvent) error + // ReadEvent reads the next raw client event (JSON bytes). + ReadEvent() ([]byte, error) + // SendAudio sends raw PCM audio to the client at the given sample rate. + // For WebSocket this is a no-op (audio is sent via JSON events). + // For WebRTC this encodes to Opus and writes to the media track. + // The context allows cancellation for barge-in support. + SendAudio(ctx context.Context, pcmData []byte, sampleRate int) error + // Close tears down the underlying connection. 
+ Close() error +} diff --git a/core/http/endpoints/openai/realtime_transport_webrtc.go b/core/http/endpoints/openai/realtime_transport_webrtc.go new file mode 100644 index 000000000000..f25d573db163 --- /dev/null +++ b/core/http/endpoints/openai/realtime_transport_webrtc.go @@ -0,0 +1,250 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "math/rand/v2" + "sync" + "time" + + "github.com/mudler/LocalAI/core/http/endpoints/openai/types" + "github.com/mudler/xlog" + "github.com/pion/rtp" + "github.com/pion/webrtc/v4" +) + +// WebRTCTransport implements Transport over a pion/webrtc PeerConnection. +// Events travel via the "oai-events" DataChannel; audio goes over an RTP track. +type WebRTCTransport struct { + pc *webrtc.PeerConnection + dc *webrtc.DataChannel + audioTrack *webrtc.TrackLocalStaticRTP + encoder *OpusEncoder + inEvents chan []byte + outEvents chan []byte // buffered outbound event queue + closed chan struct{} + closeOnce sync.Once + flushed chan struct{} // closed when sender goroutine has drained outEvents + dcReady chan struct{} // closed when data channel is open + dcReadyOnce sync.Once + sessionCh chan *Session // delivers session from runRealtimeSession to handleIncomingAudioTrack + + // RTP state for outbound audio — protected by rtpMu + rtpMu sync.Mutex + rtpSeqNum uint16 + rtpTimestamp uint32 + rtpMarker bool // true → next packet gets marker bit set +} + +func NewWebRTCTransport(pc *webrtc.PeerConnection, audioTrack *webrtc.TrackLocalStaticRTP) (*WebRTCTransport, error) { + enc, err := NewOpusEncoder() + if err != nil { + return nil, fmt.Errorf("webrtc transport: %w", err) + } + + t := &WebRTCTransport{ + pc: pc, + audioTrack: audioTrack, + encoder: enc, + inEvents: make(chan []byte, 256), + outEvents: make(chan []byte, 256), + closed: make(chan struct{}), + flushed: make(chan struct{}), + dcReady: make(chan struct{}), + sessionCh: make(chan *Session, 1), + rtpSeqNum: uint16(rand.UintN(65536)), + rtpTimestamp: rand.Uint32(), 
+ rtpMarker: true, // first packet of the stream gets marker + } + + // The client creates the "oai-events" data channel (so m=application is + // included in the SDP offer). We receive it here via OnDataChannel. + pc.OnDataChannel(func(dc *webrtc.DataChannel) { + if dc.Label() != "oai-events" { + return + } + t.dc = dc + dc.OnOpen(func() { + t.dcReadyOnce.Do(func() { close(t.dcReady) }) + }) + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + select { + case t.inEvents <- msg.Data: + case <-t.closed: + } + }) + // The channel may already be open by the time OnDataChannel fires + if dc.ReadyState() == webrtc.DataChannelStateOpen { + t.dcReadyOnce.Do(func() { close(t.dcReady) }) + } + }) + + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + xlog.Debug("WebRTC connection state", "state", state.String()) + if state == webrtc.PeerConnectionStateFailed || + state == webrtc.PeerConnectionStateClosed || + state == webrtc.PeerConnectionStateDisconnected { + t.closeOnce.Do(func() { close(t.closed) }) + } + }) + + go t.sendLoop() + + return t, nil +} + +// sendLoop is a dedicated goroutine that drains outEvents and sends them +// over the data channel. It waits for the data channel to open before +// sending, and drains any remaining events when closed is signalled. 
+func (t *WebRTCTransport) sendLoop() { + defer close(t.flushed) + + // Wait for data channel to be ready + select { + case <-t.dcReady: + case <-t.closed: + return + } + + for { + select { + case data, ok := <-t.outEvents: + if !ok { + return + } + if err := t.dc.SendText(string(data)); err != nil { + xlog.Error("data channel send failed", "error", err) + return + } + case <-t.closed: + // Drain any remaining queued events before exiting + for { + select { + case data := <-t.outEvents: + if err := t.dc.SendText(string(data)); err != nil { + return + } + default: + return + } + } + } + } +} + +func (t *WebRTCTransport) SendEvent(event types.ServerEvent) error { + data, err := json.Marshal(event) + if err != nil { + return fmt.Errorf("marshal event: %w", err) + } + + select { + case t.outEvents <- data: + return nil + case <-t.closed: + return fmt.Errorf("transport closed") + } +} + +func (t *WebRTCTransport) ReadEvent() ([]byte, error) { + select { + case msg := <-t.inEvents: + return msg, nil + case <-t.closed: + return nil, fmt.Errorf("transport closed") + } +} + +// SendAudio encodes raw PCM int16 LE to Opus and writes RTP packets to the +// audio track. The encoder resamples from the given sampleRate to 48kHz +// internally. Frames are paced at real-time intervals (20ms per frame) to +// avoid overwhelming the browser's jitter buffer with a burst of packets. +// +// The context allows callers to cancel mid-stream for barge-in support. +// When cancelled, the marker bit is set so the next audio segment starts +// cleanly in the browser's jitter buffer. +// +// RTP packets are constructed manually (rather than via WriteSample) so we +// can control the marker bit. pion's WriteSample sets the marker bit on +// every Opus packet, which causes Chrome's NetEq jitter buffer to reset +// its timing estimation for each frame, producing severe audio distortion. 
+func (t *WebRTCTransport) SendAudio(ctx context.Context, pcmData []byte, sampleRate int) error { + frames, err := t.encoder.Encode(pcmData, sampleRate) + if err != nil { + return err + } + + const frameDuration = 20 * time.Millisecond + const samplesPerFrame = 960 // 20ms at 48kHz + + ticker := time.NewTicker(frameDuration) + defer ticker.Stop() + + for i, frame := range frames { + t.rtpMu.Lock() + pkt := &rtp.Packet{ + Header: rtp.Header{ + Version: 2, + Marker: t.rtpMarker, + SequenceNumber: t.rtpSeqNum, + Timestamp: t.rtpTimestamp, + // SSRC and PayloadType are overridden by pion's writeRTP + }, + Payload: frame, + } + t.rtpSeqNum++ + t.rtpTimestamp += samplesPerFrame + t.rtpMarker = false // only the first packet gets marker + t.rtpMu.Unlock() + + if err := t.audioTrack.WriteRTP(pkt); err != nil { + return fmt.Errorf("write rtp: %w", err) + } + + // Pace output at ~real-time so the browser's jitter buffer + // receives packets at the expected rate. Skip wait after last frame. + if i < len(frames)-1 { + select { + case <-ticker.C: + case <-ctx.Done(): + // Barge-in: mark the next packet so the browser knows + // a new audio segment is starting after the interruption. + t.rtpMu.Lock() + t.rtpMarker = true + t.rtpMu.Unlock() + return ctx.Err() + case <-t.closed: + return fmt.Errorf("transport closed during audio send") + } + } + } + return nil +} + +// SetSession delivers the session to any goroutine waiting in WaitForSession. +func (t *WebRTCTransport) SetSession(s *Session) { + select { + case t.sessionCh <- s: + case <-t.closed: + } +} + +// WaitForSession blocks until the session is available or the transport closes. 
+func (t *WebRTCTransport) WaitForSession() *Session { + select { + case s := <-t.sessionCh: + return s + case <-t.closed: + return nil + } +} + +func (t *WebRTCTransport) Close() error { + // Signal no more events and unblock the sender if it's waiting + t.closeOnce.Do(func() { close(t.closed) }) + // Wait for the sender to drain any remaining queued events + <-t.flushed + t.encoder.Close() + return t.pc.Close() +} diff --git a/core/http/endpoints/openai/realtime_transport_ws.go b/core/http/endpoints/openai/realtime_transport_ws.go new file mode 100644 index 000000000000..6621f2ca6b82 --- /dev/null +++ b/core/http/endpoints/openai/realtime_transport_ws.go @@ -0,0 +1,47 @@ +package openai + +import ( + "context" + "encoding/json" + "sync" + + "github.com/gorilla/websocket" + "github.com/mudler/LocalAI/core/http/endpoints/openai/types" + "github.com/mudler/xlog" +) + +// WebSocketTransport implements Transport over a gorilla/websocket connection. +type WebSocketTransport struct { + conn *websocket.Conn + mu sync.Mutex +} + +func NewWebSocketTransport(conn *websocket.Conn) *WebSocketTransport { + return &WebSocketTransport{conn: conn} +} + +func (t *WebSocketTransport) SendEvent(event types.ServerEvent) error { + eventBytes, err := json.Marshal(event) + if err != nil { + xlog.Error("failed to marshal event", "error", err) + return err + } + t.mu.Lock() + defer t.mu.Unlock() + return t.conn.WriteMessage(websocket.TextMessage, eventBytes) +} + +func (t *WebSocketTransport) ReadEvent() ([]byte, error) { + _, msg, err := t.conn.ReadMessage() + return msg, err +} + +// SendAudio is a no-op for WebSocket — audio is delivered via JSON events +// (base64-encoded in response.audio.delta). 
+func (t *WebSocketTransport) SendAudio(_ context.Context, _ []byte, _ int) error { + return nil +} + +func (t *WebSocketTransport) Close() error { + return t.conn.Close() +} diff --git a/core/http/endpoints/openai/realtime_webrtc.go b/core/http/endpoints/openai/realtime_webrtc.go new file mode 100644 index 000000000000..6d3ead99b820 --- /dev/null +++ b/core/http/endpoints/openai/realtime_webrtc.go @@ -0,0 +1,250 @@ +package openai + +import ( + "math" + "net/http" + "time" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/pkg/sound" + "github.com/mudler/xlog" + "github.com/pion/webrtc/v4" +) + +// RealtimeCallRequest is the JSON body for POST /v1/realtime/calls. +type RealtimeCallRequest struct { + SDP string `json:"sdp"` + Model string `json:"model"` +} + +// RealtimeCallResponse is the JSON response for POST /v1/realtime/calls. +type RealtimeCallResponse struct { + SDP string `json:"sdp"` + SessionID string `json:"session_id"` +} + +// RealtimeCalls handles POST /v1/realtime/calls for WebRTC signaling. 
+func RealtimeCalls(application *application.Application) echo.HandlerFunc { + return func(c echo.Context) error { + var req RealtimeCallRequest + if err := c.Bind(&req); err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"}) + } + if req.SDP == "" { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "sdp is required"}) + } + if req.Model == "" { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "model is required"}) + } + + // Create a MediaEngine with Opus support + m := &webrtc.MediaEngine{} + if err := m.RegisterDefaultCodecs(); err != nil { + xlog.Error("failed to register codecs", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "codec registration failed"}) + } + + api := webrtc.NewAPI(webrtc.WithMediaEngine(m)) + + pc, err := api.NewPeerConnection(webrtc.Configuration{}) + if err != nil { + xlog.Error("failed to create peer connection", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create peer connection"}) + } + + // Create outbound audio track (Opus, 48kHz). + // We use TrackLocalStaticRTP (not TrackLocalStaticSample) so that + // SendAudio can construct RTP packets directly and control the marker + // bit. pion's WriteSample sets the marker bit on every Opus packet, + // which causes Chrome's NetEq jitter buffer to reset for each frame. 
+ audioTrack, err := webrtc.NewTrackLocalStaticRTP( + webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeOpus, + ClockRate: 48000, + Channels: 2, // Opus in WebRTC is always signaled as 2 channels per RFC 7587 + }, + "audio", + "localai", + ) + if err != nil { + pc.Close() + xlog.Error("failed to create audio track", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create audio track"}) + } + + rtpSender, err := pc.AddTrack(audioTrack) + if err != nil { + pc.Close() + xlog.Error("failed to add audio track", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to add audio track"}) + } + + // Drain RTCP (control protocol) packets we don't have anything useful to do with + go func() { + buf := make([]byte, 1500) + for { + if _, _, err := rtpSender.Read(buf); err != nil { + return + } + } + }() + + // Create the transport (the data channel is created by the client and + // received via pc.OnDataChannel inside NewWebRTCTransport) + transport, err := NewWebRTCTransport(pc, audioTrack) + if err != nil { + pc.Close() + xlog.Error("failed to create webrtc transport", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create transport"}) + } + + // Handle incoming audio track from the client + pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + codec := track.Codec() + if codec.MimeType != webrtc.MimeTypeOpus { + xlog.Warn("unexpected track codec, ignoring", "mime", codec.MimeType) + return + } + xlog.Debug("Received audio track from client", + "codec", codec.MimeType, + "clock_rate", codec.ClockRate, + "channels", codec.Channels, + "sdp_fmtp", codec.SDPFmtpLine, + "payload_type", codec.PayloadType, + ) + + decoder, err := NewOpusDecoder() + if err != nil { + xlog.Error("failed to create opus decoder", "error", err) + return + } + defer decoder.Close() + + handleIncomingAudioTrack(track, decoder, 
transport) + }) + + // Set the remote SDP (client's offer) + if err := pc.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: req.SDP, + }); err != nil { + transport.Close() + xlog.Error("failed to set remote description", "error", err) + return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid SDP offer"}) + } + + // Create answer + answer, err := pc.CreateAnswer(nil) + if err != nil { + transport.Close() + xlog.Error("failed to create answer", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create answer"}) + } + + if err := pc.SetLocalDescription(answer); err != nil { + transport.Close() + xlog.Error("failed to set local description", "error", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to set local description"}) + } + + // Wait for ICE gathering to complete (with timeout) + gatherDone := webrtc.GatheringCompletePromise(pc) + select { + case <-gatherDone: + case <-time.After(10 * time.Second): + xlog.Warn("ICE gathering timed out, using partial candidates") + } + + localDesc := pc.LocalDescription() + if localDesc == nil { + transport.Close() + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "no local description"}) + } + + sessionID := generateSessionID() + + // Start the realtime session in a goroutine + evaluator := application.TemplatesEvaluator() + go func() { + defer transport.Close() + runRealtimeSession(application, transport, req.Model, evaluator) + }() + + return c.JSON(http.StatusCreated, RealtimeCallResponse{ + SDP: localDesc.SDP, + SessionID: sessionID, + }) + } +} + +// handleIncomingAudioTrack reads Opus frames from a remote WebRTC track, +// decodes them to PCM, resamples to the session's input sample rate, +// and appends to the session's InputAudioBuffer. 
+func handleIncomingAudioTrack(track *webrtc.TrackRemote, decoder *OpusDecoder, transport *WebRTCTransport) { + session := transport.WaitForSession() + if session == nil { + xlog.Error("could not find session for incoming audio track (transport closed)") + sendError(transport, "session_error", "Session failed to start — check server logs", "", "") + return + } + + var frameCount int + var decodeErrors int + for { + pkt, _, err := track.ReadRTP() + if err != nil { + xlog.Debug("audio track read ended", "error", err) + return + } + + samples, err := decoder.Decode(pkt.Payload) + if err != nil { + decodeErrors++ + xlog.Warn("opus decode error", "error", err, "payload_bytes", len(pkt.Payload), "errors_so_far", decodeErrors) + continue + } + + // Log decode diagnostics for the first 50 frames to help debug audio issues + frameCount++ + if frameCount <= 50 { + var sumSq float64 + var peak int16 + for _, s := range samples { + sumSq += float64(s) * float64(s) + if s > peak { + peak = s + } else if -s > peak { + peak = -s + } + } + rms := math.Sqrt(sumSq / float64(len(samples))) + xlog.Debug("opus decode frame", + "frame", frameCount, + "payload_bytes", len(pkt.Payload), + "decoded_samples", len(samples), + "rms", int(rms), + "peak", peak, + "marker", pkt.Marker, + ) + } + + // Resample from 48kHz to the session's input sample rate (16kHz for + // WebRTC, set in runRealtimeSession). This single resample feeds both + // the audio buffer and VAD without a lossy intermediate step. + if session.InputSampleRate != opusSampleRate { + samples = sound.ResampleInt16(samples, opusSampleRate, session.InputSampleRate) + } + + pcmBytes := sound.Int16toBytesLE(samples) + + session.AudioBufferLock.Lock() + newSize := len(session.InputAudioBuffer) + len(pcmBytes) + if newSize <= maxAudioBufferSize { + session.InputAudioBuffer = append(session.InputAudioBuffer, pcmBytes...) 
+ } else { + xlog.Warn("audio buffer full, dropping incoming audio") + } + session.AudioBufferLock.Unlock() + } +} diff --git a/core/http/endpoints/openai/types/types.go b/core/http/endpoints/openai/types/types.go index 751e79b6fbd5..2f75486adcc3 100644 --- a/core/http/endpoints/openai/types/types.go +++ b/core/http/endpoints/openai/types/types.go @@ -712,17 +712,39 @@ type SessionAudioInput struct { // Configuration for input audio noise reduction. This can be set to null to turn off. Noise reduction filters audio added to the input audio buffer before it is sent to VAD and the model. Filtering the audio can improve VAD and turn detection accuracy (reducing false positives) and model performance by improving perception of the input audio. NoiseReduction *AudioNoiseReduction `json:"noise_reduction,omitempty"` - // Configuration for input audio transcription, defaults to off and can be set to null to turn off once on. Input audio transcription is not native to the model, since the model consumes audio directly. Transcription runs asynchronously through the /audio/transcriptions endpoint and should be treated as guidance of input audio content rather than precisely what the model heard. The client can optionally set the language and prompt for transcription, these offer additional guidance to the transcription service. + // Configuration for turn detection: Server VAD or Semantic VAD. Set to null + // to turn off, in which case the client must manually trigger model response. TurnDetection *TurnDetectionUnion `json:"turn_detection,omitempty"` - // Configuration for turn detection, ether Server VAD or Semantic VAD. This can be set to null to turn off, in which case the client must manually trigger model response. - // - // Server VAD means that the model will detect the start and end of speech based on audio volume and respond at the end of user speech. 
- // - // Semantic VAD is more advanced and uses a turn detection model (in conjunction with VAD) to semantically estimate whether the user has finished speaking, then dynamically sets a timeout based on this probability. For example, if user audio trails off with "uhhm", the model will score a low probability of turn end and wait longer for the user to continue speaking. This can be useful for more natural conversations, but may have a higher latency. + // True when the JSON payload explicitly included "turn_detection" (even as null). + // Standard Go JSON can't distinguish absent from null for pointer fields. + TurnDetectionSet bool `json:"-"` + + // Configuration for input audio transcription, defaults to off and can be + // set to null to turn off once on. Transcription *AudioTranscription `json:"transcription,omitempty"` } +func (s *SessionAudioInput) UnmarshalJSON(data []byte) error { + // Check whether turn_detection key exists in the raw JSON. + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + type alias SessionAudioInput + var a alias + if err := json.Unmarshal(data, &a); err != nil { + return err + } + *s = SessionAudioInput(a) + + if _, ok := raw["turn_detection"]; ok { + s.TurnDetectionSet = true + } + return nil +} + type SessionAudioOutput struct { Format *AudioFormatUnion `json:"format,omitempty"` Speed float32 `json:"speed,omitempty"` @@ -1012,10 +1034,13 @@ func (r *SessionUnion) UnmarshalJSON(data []byte) error { return err } switch SessionType(t.Type) { - case SessionTypeRealtime: - return json.Unmarshal(data, &r.Realtime) + case SessionTypeRealtime, "": + // Default to realtime when no type field is present (e.g. session.update events). 
+ r.Realtime = &RealtimeSession{} + return json.Unmarshal(data, r.Realtime) case SessionTypeTranscription: - return json.Unmarshal(data, &r.Transcription) + r.Transcription = &TranscriptionSession{} + return json.Unmarshal(data, r.Transcription) default: return fmt.Errorf("unknown session type: %s", t.Type) } diff --git a/core/http/middleware/trace.go b/core/http/middleware/trace.go index 15bc970a965e..2d7bcf16719d 100644 --- a/core/http/middleware/trace.go +++ b/core/http/middleware/trace.go @@ -158,7 +158,7 @@ func GetTraces() []APIExchange { mu.Unlock() sort.Slice(traces, func(i, j int) bool { - return traces[i].Timestamp.Before(traces[j].Timestamp) + return traces[i].Timestamp.After(traces[j].Timestamp) }) return traces diff --git a/core/http/react-ui/src/pages/Settings.jsx b/core/http/react-ui/src/pages/Settings.jsx index b112c91215ad..9ac3c00a9cf0 100644 --- a/core/http/react-ui/src/pages/Settings.jsx +++ b/core/http/react-ui/src/pages/Settings.jsx @@ -55,6 +55,7 @@ const SECTIONS = [ { id: 'memory', icon: 'fa-memory', color: 'var(--color-accent)', label: 'Memory' }, { id: 'backends', icon: 'fa-cogs', color: 'var(--color-accent)', label: 'Backends' }, { id: 'performance', icon: 'fa-gauge-high', color: 'var(--color-success)', label: 'Performance' }, + { id: 'tracing', icon: 'fa-bug', color: 'var(--color-warning)', label: 'Tracing' }, { id: 'api', icon: 'fa-globe', color: 'var(--color-warning)', label: 'API & CORS' }, { id: 'p2p', icon: 'fa-network-wired', color: 'var(--color-accent)', label: 'P2P' }, { id: 'galleries', icon: 'fa-images', color: 'var(--color-accent)', label: 'Galleries' }, @@ -327,10 +328,19 @@ export default function Settings() { update('debug', v)} /> - + + + + {/* Tracing */} +
sectionRefs.current.tracing = el} style={{ marginBottom: 'var(--spacing-xl)' }}> +

+ Tracing +

+
+ update('enable_tracing', v)} /> - + update('tracing_max_items', parseInt(e.target.value) || 0)} placeholder="100" disabled={!settings.enable_tracing} />
diff --git a/core/http/react-ui/src/pages/Talk.jsx b/core/http/react-ui/src/pages/Talk.jsx index 590b89bda32d..fa9a784fad09 100644 --- a/core/http/react-ui/src/pages/Talk.jsx +++ b/core/http/react-ui/src/pages/Talk.jsx @@ -1,196 +1,688 @@ -import { useState, useRef, useCallback } from 'react' +import { useState, useRef, useEffect, useCallback } from 'react' import { useOutletContext } from 'react-router-dom' -import ModelSelector from '../components/ModelSelector' -import LoadingSpinner from '../components/LoadingSpinner' -import { chatApi, ttsApi, audioApi } from '../utils/api' +import { realtimeApi } from '../utils/api' + +const STATUS_STYLES = { + disconnected: { icon: 'fa-solid fa-circle', color: 'var(--color-text-secondary)', bg: 'transparent' }, + connecting: { icon: 'fa-solid fa-spinner fa-spin', color: 'var(--color-primary)', bg: 'var(--color-primary-light)' }, + connected: { icon: 'fa-solid fa-circle', color: 'var(--color-success)', bg: 'rgba(34,197,94,0.1)' }, + listening: { icon: 'fa-solid fa-microphone', color: 'var(--color-success)', bg: 'rgba(34,197,94,0.1)' }, + thinking: { icon: 'fa-solid fa-brain fa-beat', color: 'var(--color-primary)', bg: 'var(--color-primary-light)' }, + speaking: { icon: 'fa-solid fa-volume-high fa-beat-fade', color: 'var(--color-accent)', bg: 'rgba(168,85,247,0.1)' }, + error: { icon: 'fa-solid fa-circle', color: 'var(--color-error)', bg: 'var(--color-error-light)' }, +} export default function Talk() { const { addToast } = useOutletContext() - const [llmModel, setLlmModel] = useState('') - const [whisperModel, setWhisperModel] = useState('') - const [ttsModel, setTtsModel] = useState('') - const [isRecording, setIsRecording] = useState(false) - const [loading, setLoading] = useState(false) - const [status, setStatus] = useState('Press the record button to start talking.') - const [audioUrl, setAudioUrl] = useState(null) - const [conversationHistory, setConversationHistory] = useState([]) - const mediaRecorderRef = 
useRef(null) - const chunksRef = useRef([]) + + // Pipeline models + const [pipelineModels, setPipelineModels] = useState([]) + const [selectedModel, setSelectedModel] = useState('') + const [modelsLoading, setModelsLoading] = useState(true) + + // Connection state + const [status, setStatus] = useState('disconnected') + const [statusText, setStatusText] = useState('Disconnected') + const [isConnected, setIsConnected] = useState(false) + + // Transcript + const [transcript, setTranscript] = useState([]) + const streamingRef = useRef(null) // tracks the index of the in-progress assistant message + + // Session settings + const [instructions, setInstructions] = useState( + 'You are a helpful voice assistant. Your responses will be spoken aloud using text-to-speech, so keep them concise and conversational. Do not use markdown formatting, bullet points, numbered lists, code blocks, or special characters. Speak naturally as you would in a phone conversation.' + ) + const [voice, setVoice] = useState('') + const [voiceEdited, setVoiceEdited] = useState(false) + const [language, setLanguage] = useState('') + + // Diagnostics + const [diagVisible, setDiagVisible] = useState(false) + + // Refs for WebRTC / audio + const pcRef = useRef(null) + const dcRef = useRef(null) + const localStreamRef = useRef(null) const audioRef = useRef(null) + const hasErrorRef = useRef(false) + + // Diagnostics refs + const audioCtxRef = useRef(null) + const analyserRef = useRef(null) + const diagFrameRef = useRef(null) + const statsIntervalRef = useRef(null) + const waveCanvasRef = useRef(null) + const specCanvasRef = useRef(null) + const transcriptEndRef = useRef(null) + + // Diagnostics stats (not worth re-rendering for every frame) + const [diagStats, setDiagStats] = useState({ + peakFreq: '--', thd: '--', rms: '--', sampleRate: '--', + packetsRecv: '--', packetsLost: '--', jitter: '--', concealed: '--', raw: '', + }) + + // Fetch pipeline models on mount + useEffect(() => { + 
realtimeApi.pipelineModels() + .then(models => { + setPipelineModels(models || []) + if (models?.length > 0) { + setSelectedModel(models[0].name) + if (!voiceEdited) setVoice(models[0].voice || '') + } + }) + .catch(err => addToast(`Failed to load pipeline models: ${err.message}`, 'error')) + .finally(() => setModelsLoading(false)) + }, []) + + // Auto-scroll transcript + useEffect(() => { + transcriptEndRef.current?.scrollIntoView({ behavior: 'smooth' }) + }, [transcript]) + + const selectedModelInfo = pipelineModels.find(m => m.name === selectedModel) + + // ── Status helper ── + const updateStatus = useCallback((state, text) => { + setStatus(state) + setStatusText(text || state) + }, []) + + // ── Session update ── + const sendSessionUpdate = useCallback(() => { + const dc = dcRef.current + if (!dc || dc.readyState !== 'open') return + if (!instructions.trim() && !voice.trim() && !language.trim()) return + + const session = {} + if (instructions.trim()) session.instructions = instructions.trim() + if (voice.trim() || language.trim()) { + session.audio = {} + if (voice.trim()) session.audio.output = { voice: voice.trim() } + if (language.trim()) session.audio.input = { transcription: { language: language.trim() } } + } + + dc.send(JSON.stringify({ type: 'session.update', session })) + }, [instructions, voice, language]) + + // ── Server event handler ── + const handleServerEvent = useCallback((event) => { + switch (event.type) { + case 'session.created': + sendSessionUpdate() + updateStatus('listening', 'Listening...') + break + case 'session.updated': + break + case 'input_audio_buffer.speech_started': + updateStatus('listening', 'Hearing you speak...') + break + case 'input_audio_buffer.speech_stopped': + updateStatus('thinking', 'Processing...') + break + case 'conversation.item.input_audio_transcription.completed': + if (event.transcript) { + streamingRef.current = null + setTranscript(prev => [...prev, { role: 'user', text: event.transcript }]) + } + 
updateStatus('thinking', 'Generating response...') + break + case 'response.output_audio_transcript.delta': + if (event.delta) { + setTranscript(prev => { + if (streamingRef.current !== null) { + const updated = [...prev] + updated[streamingRef.current] = { + ...updated[streamingRef.current], + text: updated[streamingRef.current].text + event.delta, + } + return updated + } + streamingRef.current = prev.length + return [...prev, { role: 'assistant', text: event.delta }] + }) + } + break + case 'response.output_audio_transcript.done': + if (event.transcript) { + setTranscript(prev => { + if (streamingRef.current !== null) { + const updated = [...prev] + updated[streamingRef.current] = { ...updated[streamingRef.current], text: event.transcript } + return updated + } + return [...prev, { role: 'assistant', text: event.transcript }] + }) + } + streamingRef.current = null + break + case 'response.output_audio.delta': + updateStatus('speaking', 'Speaking...') + break + case 'response.done': + updateStatus('listening', 'Listening...') + break + case 'error': + hasErrorRef.current = true + updateStatus('error', 'Error: ' + (event.error?.message || 'Unknown error')) + break + } + }, [sendSessionUpdate, updateStatus]) - const startRecording = async () => { - if (!navigator.mediaDevices) { - addToast('MediaDevices API not supported', 'error') + // ── Connect ── + const connect = useCallback(async () => { + if (!selectedModel) { + addToast('Please select a pipeline model first.', 'warning') return } + if (!navigator.mediaDevices?.getUserMedia) { + updateStatus('error', 'Microphone access requires HTTPS or localhost.') + return + } + + updateStatus('connecting', 'Connecting...') + setIsConnected(true) + try { - const stream = await navigator.mediaDevices.getUserMedia({ audio: true }) - const recorder = new MediaRecorder(stream) - chunksRef.current = [] - recorder.ondataavailable = (e) => chunksRef.current.push(e.data) - recorder.start() - mediaRecorderRef.current = recorder - 
setIsRecording(true) - setStatus('Recording... Click to stop.') + const localStream = await navigator.mediaDevices.getUserMedia({ audio: true }) + localStreamRef.current = localStream + + const pc = new RTCPeerConnection({}) + pcRef.current = pc + + for (const track of localStream.getAudioTracks()) { + pc.addTrack(track, localStream) + } + + pc.ontrack = (event) => { + if (audioRef.current) audioRef.current.srcObject = event.streams[0] + if (diagVisible) startDiagnostics() + } + + const dc = pc.createDataChannel('oai-events') + dcRef.current = dc + dc.onmessage = (msg) => { + try { + const text = typeof msg.data === 'string' ? msg.data : new TextDecoder().decode(msg.data) + handleServerEvent(JSON.parse(text)) + } catch (e) { + console.error('Failed to parse server event:', e) + } + } + dc.onclose = () => console.log('Data channel closed') + + pc.onconnectionstatechange = () => { + if (pc.connectionState === 'connected') { + updateStatus('connected', 'Connected, waiting for session...') + } else if (pc.connectionState === 'failed' || pc.connectionState === 'closed') { + disconnect() + } + } + + const offer = await pc.createOffer() + await pc.setLocalDescription(offer) + + await new Promise((resolve) => { + if (pc.iceGatheringState === 'complete') return resolve() + pc.onicegatheringstatechange = () => { + if (pc.iceGatheringState === 'complete') resolve() + } + setTimeout(resolve, 5000) + }) + + const data = await realtimeApi.call({ + sdp: pc.localDescription.sdp, + model: selectedModel, + }) + + await pc.setRemoteDescription({ type: 'answer', sdp: data.sdp }) } catch (err) { - addToast(`Microphone error: ${err.message}`, 'error') + hasErrorRef.current = true + updateStatus('error', 'Connection failed: ' + err.message) + disconnect() + } + }, [selectedModel, diagVisible, handleServerEvent, updateStatus, addToast]) + + // ── Disconnect ── + const disconnect = useCallback(() => { + stopDiagnostics() + if (dcRef.current) { dcRef.current.close(); dcRef.current = null } 
+ if (pcRef.current) { pcRef.current.close(); pcRef.current = null } + if (localStreamRef.current) { + localStreamRef.current.getTracks().forEach(t => t.stop()) + localStreamRef.current = null } + if (audioRef.current) audioRef.current.srcObject = null + + if (!hasErrorRef.current) updateStatus('disconnected', 'Disconnected') + hasErrorRef.current = false + setIsConnected(false) + }, [updateStatus]) + + // Cleanup on unmount + useEffect(() => { + return () => { + stopDiagnostics() + if (dcRef.current) dcRef.current.close() + if (pcRef.current) pcRef.current.close() + if (localStreamRef.current) localStreamRef.current.getTracks().forEach(t => t.stop()) + } + }, []) + + // ── Test tone ── + const sendTestTone = useCallback(() => { + const dc = dcRef.current + if (!dc || dc.readyState !== 'open') return + dc.send(JSON.stringify({ type: 'test_tone' })) + setTranscript(prev => [...prev, { role: 'assistant', text: '(Test tone requested)' }]) + }, []) + + // ── Diagnostics ── + function startDiagnostics() { + const audioEl = audioRef.current + if (!audioEl?.srcObject) return + + if (!audioCtxRef.current) { + const ctx = new AudioContext() + const source = ctx.createMediaStreamSource(audioEl.srcObject) + const analyser = ctx.createAnalyser() + analyser.fftSize = 8192 + analyser.smoothingTimeConstant = 0.3 + source.connect(analyser) + audioCtxRef.current = ctx + analyserRef.current = analyser + setDiagStats(prev => ({ ...prev, sampleRate: ctx.sampleRate + ' Hz' })) + } + + if (!diagFrameRef.current) drawDiagnostics() + if (!statsIntervalRef.current) { + pollWebRTCStats() + statsIntervalRef.current = setInterval(pollWebRTCStats, 1000) + } + } + + function stopDiagnostics() { + if (diagFrameRef.current) { cancelAnimationFrame(diagFrameRef.current); diagFrameRef.current = null } + if (statsIntervalRef.current) { clearInterval(statsIntervalRef.current); statsIntervalRef.current = null } + if (audioCtxRef.current) { audioCtxRef.current.close(); audioCtxRef.current = null; 
analyserRef.current = null } } - const stopRecording = useCallback(() => { - if (!mediaRecorderRef.current) return - - mediaRecorderRef.current.onstop = async () => { - setIsRecording(false) - setLoading(true) - - const audioBlob = new Blob(chunksRef.current, { type: 'audio/webm' }) - - try { - // 1. Transcribe - setStatus('Transcribing audio...') - const formData = new FormData() - formData.append('file', audioBlob) - formData.append('model', whisperModel) - const transcription = await audioApi.transcribe(formData) - const userText = transcription.text - - setStatus(`You said: "${userText}". Generating response...`) - - // 2. Chat completion - const newHistory = [...conversationHistory, { role: 'user', content: userText }] - const chatResponse = await chatApi.complete({ - model: llmModel, - messages: newHistory, - }) - const assistantText = chatResponse?.choices?.[0]?.message?.content || '' - const updatedHistory = [...newHistory, { role: 'assistant', content: assistantText }] - setConversationHistory(updatedHistory) - - setStatus(`Response: "${assistantText}". Generating speech...`) - - // 3. TTS - const ttsBlob = await ttsApi.generateV1({ input: assistantText, model: ttsModel }) - const url = URL.createObjectURL(ttsBlob) - setAudioUrl(url) - setStatus('Press the record button to continue.') - - // Auto-play - setTimeout(() => audioRef.current?.play(), 100) - } catch (err) { - addToast(`Error: ${err.message}`, 'error') - setStatus('Error occurred. 
Try again.') - } finally { - setLoading(false) + function drawDiagnostics() { + const analyser = analyserRef.current + if (!analyser) { diagFrameRef.current = null; return } + + diagFrameRef.current = requestAnimationFrame(drawDiagnostics) + + // Waveform + const waveCanvas = waveCanvasRef.current + if (waveCanvas) { + const wCtx = waveCanvas.getContext('2d') + const timeData = new Float32Array(analyser.fftSize) + analyser.getFloatTimeDomainData(timeData) + const w = waveCanvas.width, h = waveCanvas.height + wCtx.fillStyle = '#000'; wCtx.fillRect(0, 0, w, h) + wCtx.strokeStyle = '#0f0'; wCtx.lineWidth = 1; wCtx.beginPath() + const sliceWidth = w / timeData.length + let x = 0 + for (let i = 0; i < timeData.length; i++) { + const y = (1 - timeData[i]) * h / 2 + i === 0 ? wCtx.moveTo(x, y) : wCtx.lineTo(x, y) + x += sliceWidth } + wCtx.stroke() + + let sumSq = 0 + for (let i = 0; i < timeData.length; i++) sumSq += timeData[i] * timeData[i] + const rms = Math.sqrt(sumSq / timeData.length) + const rmsDb = rms > 0 ? 
(20 * Math.log10(rms)).toFixed(1) : '-Inf' + setDiagStats(prev => ({ ...prev, rms: rmsDb + ' dBFS' })) } - mediaRecorderRef.current.stop() - mediaRecorderRef.current.stream?.getTracks().forEach(t => t.stop()) - }, [whisperModel, llmModel, ttsModel, conversationHistory]) + // Spectrum + const specCanvas = specCanvasRef.current + if (specCanvas && audioCtxRef.current) { + const sCtx = specCanvas.getContext('2d') + const freqData = new Float32Array(analyser.frequencyBinCount) + analyser.getFloatFrequencyData(freqData) + const sw = specCanvas.width, sh = specCanvas.height + sCtx.fillStyle = '#000'; sCtx.fillRect(0, 0, sw, sh) + + const sampleRate = audioCtxRef.current.sampleRate + const binHz = sampleRate / analyser.fftSize + const maxFreqDisplay = 4000 + const maxBin = Math.min(Math.ceil(maxFreqDisplay / binHz), freqData.length) + const barWidth = sw / maxBin + + sCtx.fillStyle = '#0cf' + let peakBin = 0, peakVal = -Infinity + for (let i = 0; i < maxBin; i++) { + const db = freqData[i] + if (db > peakVal) { peakVal = db; peakBin = i } + const barH = Math.max(0, ((db + 100) / 100) * sh) + sCtx.fillRect(i * barWidth, sh - barH, Math.max(1, barWidth - 0.5), barH) + } + + // Frequency labels + sCtx.fillStyle = '#888'; sCtx.font = '10px monospace' + for (let f = 500; f <= maxFreqDisplay; f += 500) { + sCtx.fillText(f + '', (f / binHz) * barWidth - 10, sh - 2) + } + + // 440 Hz marker + const bin440 = Math.round(440 / binHz) + const x440 = bin440 * barWidth + sCtx.strokeStyle = '#f00'; sCtx.lineWidth = 1 + sCtx.beginPath(); sCtx.moveTo(x440, 0); sCtx.lineTo(x440, sh); sCtx.stroke() + sCtx.fillStyle = '#f00'; sCtx.fillText('440', x440 + 2, 10) - const resetConversation = () => { - setConversationHistory([]) - setAudioUrl(null) - setStatus('Conversation reset. 
Press record to start.') - addToast('Conversation reset', 'info') + const peakFreq = peakBin * binHz + const fundamentalBin = Math.round(440 / binHz) + const fundamentalPower = Math.pow(10, freqData[fundamentalBin] / 10) + let harmonicPower = 0 + for (let h = 2; h <= 10; h++) { + const hBin = Math.round(440 * h / binHz) + if (hBin < freqData.length) harmonicPower += Math.pow(10, freqData[hBin] / 10) + } + const thd = fundamentalPower > 0 + ? (Math.sqrt(harmonicPower / fundamentalPower) * 100).toFixed(1) + '%' + : '--%' + + setDiagStats(prev => ({ + ...prev, + peakFreq: peakFreq.toFixed(0) + ' Hz (' + peakVal.toFixed(1) + ' dB)', + thd, + })) + } } - const allModelsSet = llmModel && whisperModel && ttsModel + async function pollWebRTCStats() { + const pc = pcRef.current + if (!pc) return + try { + const stats = await pc.getStats() + const raw = [] + stats.forEach((report) => { + if (report.type === 'inbound-rtp' && report.kind === 'audio') { + setDiagStats(prev => ({ + ...prev, + packetsRecv: report.packetsReceived ?? '--', + packetsLost: report.packetsLost ?? '--', + jitter: report.jitter !== undefined ? (report.jitter * 1000).toFixed(1) + ' ms' : '--', + concealed: report.concealedSamples ?? '--', + })) + raw.push('-- inbound-rtp (audio) --') + raw.push(' packetsReceived: ' + report.packetsReceived) + raw.push(' packetsLost: ' + report.packetsLost) + raw.push(' jitter: ' + (report.jitter !== undefined ? 
(report.jitter * 1000).toFixed(2) + ' ms' : 'N/A')) + raw.push(' bytesReceived: ' + report.bytesReceived) + raw.push(' concealedSamples: ' + report.concealedSamples) + raw.push(' totalSamplesReceived: ' + report.totalSamplesReceived) + } + }) + setDiagStats(prev => ({ ...prev, raw: raw.join('\n') })) + } catch (_e) { /* stats polling error */ } + } + + const toggleDiagnostics = useCallback(() => { + setDiagVisible(prev => { + const next = !prev + if (next) { + setTimeout(startDiagnostics, 0) + } else { + stopDiagnostics() + } + return next + }) + }, []) + + const statusStyle = STATUS_STYLES[status] || STATUS_STYLES.disconnected + // ── Render ── return (
-
+

Talk

-

Voice conversation with AI

+

Real-time voice conversation via WebRTC

- {/* Main interaction area */} -
- {/* Big record button */} - - - {/* Status */} -

- {loading ? : null} - {' '}{status} -

- - {/* Recording indicator */} - {isRecording && ( +
+ {/* Connection status */} +
+ + {statusText} +
+ + {/* Info note */} +
+ +

+ Note: Select a pipeline model and click Connect. + Your microphone streams continuously; the server detects speech and responds automatically. +

+
+ + {/* Pipeline model selector */} +
+ + +
+ + {/* Pipeline details */} + {selectedModelInfo && (
- - Recording... + {[ + { label: 'VAD', value: selectedModelInfo.vad }, + { label: 'Transcription', value: selectedModelInfo.transcription }, + { label: 'LLM', value: selectedModelInfo.llm }, + { label: 'TTS', value: selectedModelInfo.tts }, + ].map(item => ( +
+
{item.label}
+
{item.value}
+
+ ))}
)} - {/* Audio playback */} - {audioUrl && ( -
-
- - - + + + + + +
- + {{template "views/partials/footer" .}}
diff --git a/core/http/views/traces.html b/core/http/views/traces.html index 3e66c82b41f8..6287cc47782f 100644 --- a/core/http/views/traces.html +++ b/core/http/views/traces.html @@ -254,12 +254,54 @@

Response

+ + +