Data Fie
if (typeof value === 'boolean') return value ? 'true' : 'false';
if (typeof value === 'object') return JSON.stringify(value);
return String(value);
+ },
+
+ filterDataFields(data) {
+ const audioKeys = new Set([
+ 'audio_wav_base64', 'audio_duration_s', 'audio_snippet_s',
+ 'audio_sample_rate', 'audio_samples', 'audio_rms_dbfs',
+ 'audio_peak_dbfs', 'audio_dc_offset'
+ ]);
+ return Object.entries(data).filter(([key]) => !audioKeys.has(key));
}
}
}
diff --git a/core/trace/audio_snippet.go b/core/trace/audio_snippet.go
new file mode 100644
index 000000000000..2f2190ca8aa0
--- /dev/null
+++ b/core/trace/audio_snippet.go
@@ -0,0 +1,102 @@
+package trace
+
+import (
+ "bytes"
+ "encoding/base64"
+ "math"
+ "os"
+
+ "github.com/mudler/LocalAI/pkg/audio"
+ "github.com/mudler/LocalAI/pkg/sound"
+ "github.com/mudler/xlog"
+)
+
+// MaxSnippetSeconds is the maximum number of seconds of audio captured per trace.
+const MaxSnippetSeconds = 30
+
+// AudioSnippet captures the first MaxSnippetSeconds of a WAV file and computes
+// quality metrics. The result is a map suitable for merging into a BackendTrace
+// Data field.
+func AudioSnippet(wavPath string) map[string]any {
+ raw, err := os.ReadFile(wavPath)
+ if err != nil {
+ xlog.Warn("audio snippet: read failed", "path", wavPath, "error", err)
+ return nil
+ }
+ // Only process WAV files (RIFF header)
+ if len(raw) <= audio.WAVHeaderSize || string(raw[:4]) != "RIFF" {
+ xlog.Debug("audio snippet: not a WAV file or too small", "path", wavPath, "bytes", len(raw))
+ return nil
+ }
+
+ pcm, sampleRate := audio.ParseWAV(raw)
+ if sampleRate == 0 {
+ sampleRate = 16000
+ }
+
+ return AudioSnippetFromPCM(pcm, sampleRate, len(pcm))
+}
+
+// AudioSnippetFromPCM builds an audio snippet from raw PCM bytes (int16 LE mono).
+// totalPCMBytes is the full audio size before truncation (used to compute total duration).
+func AudioSnippetFromPCM(pcm []byte, sampleRate int, totalPCMBytes int) map[string]any {
+ if len(pcm) == 0 || len(pcm)%2 != 0 {
+ return nil
+ }
+
+ samples := sound.BytesToInt16sLE(pcm)
+ totalSamples := totalPCMBytes / 2
+ durationS := float64(totalSamples) / float64(sampleRate)
+
+ // Truncate to first MaxSnippetSeconds
+ maxSamples := MaxSnippetSeconds * sampleRate
+ if len(samples) > maxSamples {
+ samples = samples[:maxSamples]
+ }
+
+ snippetDuration := float64(len(samples)) / float64(sampleRate)
+
+ rms := sound.CalculateRMS16(samples)
+ rmsDBFS := -math.Inf(1)
+ if rms > 0 {
+ rmsDBFS = 20 * math.Log10(rms/32768.0)
+ }
+
+ var peak int16
+ var dcSum int64
+ for _, s := range samples {
+ if s < 0 && -s > peak {
+ peak = -s
+ } else if s > peak {
+ peak = s
+ }
+ dcSum += int64(s)
+ }
+ peakDBFS := -math.Inf(1)
+ if peak > 0 {
+ peakDBFS = 20 * math.Log10(float64(peak) / 32768.0)
+ }
+ dcOffset := float64(dcSum) / float64(len(samples)) / 32768.0
+
+ // Encode the snippet as WAV
+ snippetPCM := sound.Int16toBytesLE(samples)
+ hdr := audio.NewWAVHeaderWithRate(uint32(len(snippetPCM)), uint32(sampleRate))
+ var buf bytes.Buffer
+ buf.Grow(audio.WAVHeaderSize + len(snippetPCM))
+ if err := hdr.Write(&buf); err != nil {
+ xlog.Warn("audio snippet: write header failed", "error", err)
+ return nil
+ }
+ buf.Write(snippetPCM)
+
+ return map[string]any{
+ "audio_wav_base64": base64.StdEncoding.EncodeToString(buf.Bytes()),
+ "audio_duration_s": math.Round(durationS*100) / 100,
+ "audio_snippet_s": math.Round(snippetDuration*100) / 100,
+ "audio_sample_rate": sampleRate,
+ "audio_samples": totalSamples,
+ "audio_rms_dbfs": math.Round(rmsDBFS*10) / 10,
+ "audio_peak_dbfs": math.Round(peakDBFS*10) / 10,
+ "audio_dc_offset": math.Round(dcOffset*10000) / 10000,
+ }
+}
diff --git a/core/trace/backend_trace.go b/core/trace/backend_trace.go
index 0dfbd8458e9c..4e6237f9f700 100644
--- a/core/trace/backend_trace.go
+++ b/core/trace/backend_trace.go
@@ -85,7 +85,7 @@ func GetBackendTraces() []BackendTrace {
}
sort.Slice(traces, func(i, j int) bool {
- return traces[i].Timestamp.Before(traces[j].Timestamp)
+ return traces[i].Timestamp.After(traces[j].Timestamp)
})
return traces
diff --git a/docs/content/advanced/advanced-usage.md b/docs/content/advanced/advanced-usage.md
index 7742eb29a874..1b7fba2976c9 100644
--- a/docs/content/advanced/advanced-usage.md
+++ b/docs/content/advanced/advanced-usage.md
@@ -188,6 +188,8 @@ there are additional environment variables available that modify the behavior of
| `EXTRA_BACKENDS` | | A space separated list of backends to prepare. For example `EXTRA_BACKENDS="backend/python/diffusers backend/python/transformers"` prepares the python environment on start |
| `DISABLE_AUTODETECT` | `false` | Disable autodetect of CPU flagset on start |
| `LLAMACPP_GRPC_SERVERS` | | A list of llama.cpp workers to distribute the workload. For example `LLAMACPP_GRPC_SERVERS="address1:port,address2:port"` |
+| `OPUS_LIBRARY` | | Path to the libopus shared library (e.g. `/usr/lib/libopus.so.0`). Used by the WebRTC realtime API for Opus audio encoding/decoding. When unset, standard system paths are searched automatically. |
+| `OPUS_SHIM_LIBRARY` | | Path to the libopusshim shared library (e.g. `/usr/local/lib/libopusshim.so`). This thin wrapper is built from `pkg/opus/shim/` during `make build` when libopus-dev is installed. |
Here is how to configure these variables:
diff --git a/go.mod b/go.mod
index 55b2f8228867..0244b2a3209b 100644
--- a/go.mod
+++ b/go.mod
@@ -131,6 +131,7 @@ require (
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/oxffaa/gopher-parse-sitemap v0.0.0-20191021113419-005d2eb1def4 // indirect
github.com/philippgille/chromem-go v0.7.0 // indirect
+ github.com/pion/transport/v4 v4.0.1 // indirect
github.com/pjbgf/sha1cd v0.3.2 // indirect
github.com/rs/zerolog v1.31.0 // indirect
github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect
@@ -208,25 +209,24 @@ require (
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
github.com/nicksnyder/go-i18n/v2 v2.5.1 // indirect
github.com/otiai10/mint v1.6.3 // indirect
- github.com/pion/datachannel v1.5.10 // indirect
+ github.com/pion/datachannel v1.6.0 // indirect
github.com/pion/dtls/v2 v2.2.12 // indirect
- github.com/pion/dtls/v3 v3.0.6 // indirect
- github.com/pion/ice/v4 v4.0.10 // indirect
- github.com/pion/interceptor v0.1.40 // indirect
- github.com/pion/logging v0.2.3 // indirect
- github.com/pion/mdns/v2 v2.0.7 // indirect
+ github.com/pion/dtls/v3 v3.1.2 // indirect
+ github.com/pion/ice/v4 v4.2.1 // indirect
+ github.com/pion/interceptor v0.1.44 // indirect
+ github.com/pion/logging v0.2.4 // indirect
+ github.com/pion/mdns/v2 v2.1.0 // indirect
github.com/pion/randutil v0.1.0 // indirect
- github.com/pion/rtcp v1.2.15 // indirect
- github.com/pion/rtp v1.8.19 // indirect
- github.com/pion/sctp v1.8.39 // indirect
- github.com/pion/sdp/v3 v3.0.13 // indirect
- github.com/pion/srtp/v3 v3.0.6 // indirect
+ github.com/pion/rtcp v1.2.16 // indirect
+ github.com/pion/rtp v1.10.1
+ github.com/pion/sctp v1.9.2 // indirect
+ github.com/pion/sdp/v3 v3.0.18 // indirect
+ github.com/pion/srtp/v3 v3.0.10 // indirect
github.com/pion/stun v0.6.1 // indirect
- github.com/pion/stun/v3 v3.0.0 // indirect
+ github.com/pion/stun/v3 v3.1.1 // indirect
github.com/pion/transport/v2 v2.2.10 // indirect
- github.com/pion/transport/v3 v3.0.7 // indirect
- github.com/pion/turn/v4 v4.0.2 // indirect
- github.com/pion/webrtc/v4 v4.1.2 // indirect
+ github.com/pion/turn/v4 v4.1.4 // indirect
+ github.com/pion/webrtc/v4 v4.2.9
github.com/prometheus/otlptranslator v1.0.0 // indirect
github.com/rymdport/portal v0.4.2 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
diff --git a/go.sum b/go.sum
index d31c0b4af09a..341656be1a3b 100644
--- a/go.sum
+++ b/go.sum
@@ -656,8 +656,6 @@ github.com/mschoch/smat v0.2.0 h1:8imxQsjDm8yFEAVBe7azKmKSgzSkZXDuKkSq9374khM=
github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw=
github.com/mudler/LocalAGI v0.0.0-20260306154948-5a27c471ca78 h1:B3FgipRORpDtDvNlCC/w4N6PPwIyn7M/mzeRiq0EV4o=
github.com/mudler/LocalAGI v0.0.0-20260306154948-5a27c471ca78/go.mod h1:e/00in01SHCpzUD/UyJMopn7P+vJMjsk6qkxZC1qPW0=
-github.com/mudler/cogito v0.9.2 h1:KbzNpuJ782njeBKfg3q7kLIBHTCFi9DgXhPTXnZqu1Y=
-github.com/mudler/cogito v0.9.2/go.mod h1:6sfja3lcu2nWRzEc0wwqGNu/eCG3EWgij+8s7xyUeQ4=
github.com/mudler/cogito v0.9.3-0.20260306202429-e073d115bd04 h1:33Lqv8VBaV/AoaaVtZ5+Bcig4T9fvj0dQmKFCon5Xxo=
github.com/mudler/cogito v0.9.3-0.20260306202429-e073d115bd04/go.mod h1:6sfja3lcu2nWRzEc0wwqGNu/eCG3EWgij+8s7xyUeQ4=
github.com/mudler/edgevpn v0.31.1 h1:7qegiDWd0kAg6ljhNHxqvp8hbo/6BbzSdbb7/2WZfiY=
@@ -750,48 +748,50 @@ github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxc
github.com/phpdave11/gofpdi v1.0.7/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
github.com/pierrec/lz4/v4 v4.1.2 h1:qvY3YFXRQE/XB8MlLzJH7mSzBs74eA2gg52YTk6jUPM=
github.com/pierrec/lz4/v4 v4.1.2/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
-github.com/pion/datachannel v1.5.10 h1:ly0Q26K1i6ZkGf42W7D4hQYR90pZwzFOjTq5AuCKk4o=
-github.com/pion/datachannel v1.5.10/go.mod h1:p/jJfC9arb29W7WrxyKbepTU20CFgyx5oLo8Rs4Py/M=
+github.com/pion/datachannel v1.6.0 h1:XecBlj+cvsxhAMZWFfFcPyUaDZtd7IJvrXqlXD/53i0=
+github.com/pion/datachannel v1.6.0/go.mod h1:ur+wzYF8mWdC+Mkis5Thosk+u/VOL287apDNEbFpsIk=
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk=
github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
-github.com/pion/dtls/v3 v3.0.6 h1:7Hkd8WhAJNbRgq9RgdNh1aaWlZlGpYTzdqjy9x9sK2E=
-github.com/pion/dtls/v3 v3.0.6/go.mod h1:iJxNQ3Uhn1NZWOMWlLxEEHAN5yX7GyPvvKw04v9bzYU=
-github.com/pion/ice/v4 v4.0.10 h1:P59w1iauC/wPk9PdY8Vjl4fOFL5B+USq1+xbDcN6gT4=
-github.com/pion/ice/v4 v4.0.10/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw=
-github.com/pion/interceptor v0.1.40 h1:e0BjnPcGpr2CFQgKhrQisBU7V3GXK6wrfYrGYaU6Jq4=
-github.com/pion/interceptor v0.1.40/go.mod h1:Z6kqH7M/FYirg3frjGJ21VLSRJGBXB/KqaTIrdqnOic=
+github.com/pion/dtls/v3 v3.1.2 h1:gqEdOUXLtCGW+afsBLO0LtDD8GnuBBjEy6HRtyofZTc=
+github.com/pion/dtls/v3 v3.1.2/go.mod h1:Hw/igcX4pdY69z1Hgv5x7wJFrUkdgHwAn/Q/uo7YHRo=
+github.com/pion/ice/v4 v4.2.1 h1:XPRYXaLiFq3LFDG7a7bMrmr3mFr27G/gtXN3v/TVfxY=
+github.com/pion/ice/v4 v4.2.1/go.mod h1:2quLV1S5v1tAx3VvAJaH//KGitRXvo4RKlX6D3tnN+c=
+github.com/pion/interceptor v0.1.44 h1:sNlZwM8dWXU9JQAkJh8xrarC0Etn8Oolcniukmuy0/I=
+github.com/pion/interceptor v0.1.44/go.mod h1:4atVlBkcgXuUP+ykQF0qOCGU2j7pQzX2ofvPRFsY5RY=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
-github.com/pion/logging v0.2.3 h1:gHuf0zpoh1GW67Nr6Gj4cv5Z9ZscU7g/EaoC/Ke/igI=
-github.com/pion/logging v0.2.3/go.mod h1:z8YfknkquMe1csOrxK5kc+5/ZPAzMxbKLX5aXpbpC90=
-github.com/pion/mdns/v2 v2.0.7 h1:c9kM8ewCgjslaAmicYMFQIde2H9/lrZpjBkN8VwoVtM=
-github.com/pion/mdns/v2 v2.0.7/go.mod h1:vAdSYNAT0Jy3Ru0zl2YiW3Rm/fJCwIeM0nToenfOJKA=
+github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
+github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
+github.com/pion/mdns/v2 v2.1.0 h1:3IJ9+Xio6tWYjhN6WwuY142P/1jA0D5ERaIqawg/fOY=
+github.com/pion/mdns/v2 v2.1.0/go.mod h1:pcez23GdynwcfRU1977qKU0mDxSeucttSHbCSfFOd9A=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
-github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo=
-github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0=
-github.com/pion/rtp v1.8.19 h1:jhdO/3XhL/aKm/wARFVmvTfq0lC/CvN1xwYKmduly3c=
-github.com/pion/rtp v1.8.19/go.mod h1:bAu2UFKScgzyFqvUKmbvzSdPr+NGbZtv6UB2hesqXBk=
-github.com/pion/sctp v1.8.39 h1:PJma40vRHa3UTO3C4MyeJDQ+KIobVYRZQZ0Nt7SjQnE=
-github.com/pion/sctp v1.8.39/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE=
-github.com/pion/sdp/v3 v3.0.13 h1:uN3SS2b+QDZnWXgdr69SM8KB4EbcnPnPf2Laxhty/l4=
-github.com/pion/sdp/v3 v3.0.13/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E=
-github.com/pion/srtp/v3 v3.0.6 h1:E2gyj1f5X10sB/qILUGIkL4C2CqK269Xq167PbGCc/4=
-github.com/pion/srtp/v3 v3.0.6/go.mod h1:BxvziG3v/armJHAaJ87euvkhHqWe9I7iiOy50K2QkhY=
+github.com/pion/rtcp v1.2.16 h1:fk1B1dNW4hsI78XUCljZJlC4kZOPk67mNRuQ0fcEkSo=
+github.com/pion/rtcp v1.2.16/go.mod h1:/as7VKfYbs5NIb4h6muQ35kQF/J0ZVNz2Z3xKoCBYOo=
+github.com/pion/rtp v1.10.1 h1:xP1prZcCTUuhO2c83XtxyOHJteISg6o8iPsE2acaMtA=
+github.com/pion/rtp v1.10.1/go.mod h1:rF5nS1GqbR7H/TCpKwylzeq6yDM+MM6k+On5EgeThEM=
+github.com/pion/sctp v1.9.2 h1:HxsOzEV9pWoeggv7T5kewVkstFNcGvhMPx0GvUOUQXo=
+github.com/pion/sctp v1.9.2/go.mod h1:OTOlsQ5EDQ6mQ0z4MUGXt2CgQmKyafBEXhUVqLRB6G8=
+github.com/pion/sdp/v3 v3.0.18 h1:l0bAXazKHpepazVdp+tPYnrsy9dfh7ZbT8DxesH5ZnI=
+github.com/pion/sdp/v3 v3.0.18/go.mod h1:ZREGo6A9ZygQ9XkqAj5xYCQtQpif0i6Pa81HOiAdqQ8=
+github.com/pion/srtp/v3 v3.0.10 h1:tFirkpBb3XccP5VEXLi50GqXhv5SKPxqrdlhDCJlZrQ=
+github.com/pion/srtp/v3 v3.0.10/go.mod h1:3mOTIB0cq9qlbn59V4ozvv9ClW/BSEbRp4cY0VtaR7M=
github.com/pion/stun v0.6.1 h1:8lp6YejULeHBF8NmV8e2787BogQhduZugh5PdhDyyN4=
github.com/pion/stun v0.6.1/go.mod h1:/hO7APkX4hZKu/D0f2lHzNyvdkTGtIy3NDmLR7kSz/8=
-github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
-github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
+github.com/pion/stun/v3 v3.1.1 h1:CkQxveJ4xGQjulGSROXbXq94TAWu8gIX2dT+ePhUkqw=
+github.com/pion/stun/v3 v3.1.1/go.mod h1:qC1DfmcCTQjl9PBaMa5wSn3x9IPmKxSdcCsxBcDBndM=
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
github.com/pion/transport/v2 v2.2.10 h1:ucLBLE8nuxiHfvkFKnkDQRYWYfp8ejf4YBOPfaQpw6Q=
github.com/pion/transport/v2 v2.2.10/go.mod h1:sq1kSLWs+cHW9E+2fJP95QudkzbK7wscs8yYgQToO5E=
-github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
-github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
-github.com/pion/turn/v4 v4.0.2 h1:ZqgQ3+MjP32ug30xAbD6Mn+/K4Sxi3SdNOTFf+7mpps=
-github.com/pion/turn/v4 v4.0.2/go.mod h1:pMMKP/ieNAG/fN5cZiN4SDuyKsXtNTr0ccN7IToA1zs=
-github.com/pion/webrtc/v4 v4.1.2 h1:mpuUo/EJ1zMNKGE79fAdYNFZBX790KE7kQQpLMjjR54=
-github.com/pion/webrtc/v4 v4.1.2/go.mod h1:xsCXiNAmMEjIdFxAYU0MbB3RwRieJsegSB2JZsGN+8U=
+github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
+github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
+github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o=
+github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM=
+github.com/pion/turn/v4 v4.1.4 h1:EU11yMXKIsK43FhcUnjLlrhE4nboHZq+TXBIi3QpcxQ=
+github.com/pion/turn/v4 v4.1.4/go.mod h1:ES1DXVFKnOhuDkqn9hn5VJlSWmZPaRJLyBXoOeO/BmQ=
+github.com/pion/webrtc/v4 v4.2.9 h1:DZIh1HAhPIL3RvwEDFsmL5hfPSLEpxsQk9/Jir2vkJE=
+github.com/pion/webrtc/v4 v4.2.9/go.mod h1:9EmLZve0H76eTzf8v2FmchZ6tcBXtDgpfTEu+drW6SY=
github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4=
github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
diff --git a/pkg/audio/audio.go b/pkg/audio/audio.go
index 946d902f0539..c06d568b1960 100644
--- a/pkg/audio/audio.go
+++ b/pkg/audio/audio.go
@@ -53,3 +53,46 @@ func NewWAVHeader(pcmLen uint32) WAVHeader {
func (h *WAVHeader) Write(writer io.Writer) error {
return binary.Write(writer, binary.LittleEndian, h)
}
+
+// NewWAVHeaderWithRate creates a WAV header for mono 16-bit PCM at the given sample rate.
+func NewWAVHeaderWithRate(pcmLen, sampleRate uint32) WAVHeader {
+ header := WAVHeader{
+ ChunkID: [4]byte{'R', 'I', 'F', 'F'},
+ Format: [4]byte{'W', 'A', 'V', 'E'},
+ Subchunk1ID: [4]byte{'f', 'm', 't', ' '},
+ Subchunk1Size: 16,
+ AudioFormat: 1,
+ NumChannels: 1,
+ SampleRate: sampleRate,
+ ByteRate: sampleRate * 2,
+ BlockAlign: 2,
+ BitsPerSample: 16,
+ Subchunk2ID: [4]byte{'d', 'a', 't', 'a'},
+ Subchunk2Size: pcmLen,
+ }
+ header.ChunkSize = 36 + header.Subchunk2Size
+ return header
+}
+
+// WAVHeaderSize is the size of a standard PCM WAV header in bytes.
+const WAVHeaderSize = 44
+
+// StripWAVHeader removes a WAV header from audio data, returning raw PCM.
+// If the data is too short to contain a header, it is returned unchanged.
+func StripWAVHeader(data []byte) []byte {
+ if len(data) > WAVHeaderSize {
+ return data[WAVHeaderSize:]
+ }
+ return data
+}
+
+// ParseWAV strips the WAV header and returns the raw PCM along with the
+// sample rate read from the header. If the data is too short to contain a
+// valid header the PCM is returned as-is with sampleRate=0.
+func ParseWAV(data []byte) (pcm []byte, sampleRate int) {
+ if len(data) <= WAVHeaderSize {
+ return data, 0
+ }
+ sr := int(binary.LittleEndian.Uint32(data[24:28]))
+ return data[WAVHeaderSize:], sr
+}
diff --git a/pkg/audio/audio_suite_test.go b/pkg/audio/audio_suite_test.go
new file mode 100644
index 000000000000..9c3dd78635a6
--- /dev/null
+++ b/pkg/audio/audio_suite_test.go
@@ -0,0 +1,13 @@
+package audio
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestAudio(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Audio Suite")
+}
diff --git a/pkg/audio/audio_test.go b/pkg/audio/audio_test.go
new file mode 100644
index 000000000000..836aa27aeb48
--- /dev/null
+++ b/pkg/audio/audio_test.go
@@ -0,0 +1,99 @@
+package audio
+
+import (
+ "bytes"
+ "encoding/binary"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("WAV utilities", func() {
+ Describe("NewWAVHeader", func() {
+ It("produces a valid 44-byte header", func() {
+ hdr := NewWAVHeader(3200)
+ var buf bytes.Buffer
+ Expect(hdr.Write(&buf)).To(Succeed())
+ Expect(buf.Len()).To(Equal(WAVHeaderSize))
+
+ b := buf.Bytes()
+ Expect(string(b[0:4])).To(Equal("RIFF"))
+ Expect(string(b[8:12])).To(Equal("WAVE"))
+ Expect(string(b[12:16])).To(Equal("fmt "))
+
+ Expect(binary.LittleEndian.Uint16(b[20:22])).To(Equal(uint16(1))) // PCM
+ Expect(binary.LittleEndian.Uint16(b[22:24])).To(Equal(uint16(1))) // mono
+ Expect(binary.LittleEndian.Uint32(b[24:28])).To(Equal(uint32(16000)))
+ Expect(binary.LittleEndian.Uint32(b[28:32])).To(Equal(uint32(32000)))
+ Expect(string(b[36:40])).To(Equal("data"))
+ Expect(binary.LittleEndian.Uint32(b[40:44])).To(Equal(uint32(3200)))
+ })
+ })
+
+ Describe("NewWAVHeaderWithRate", func() {
+ It("uses the custom sample rate", func() {
+ hdr := NewWAVHeaderWithRate(4800, 24000)
+ var buf bytes.Buffer
+ Expect(hdr.Write(&buf)).To(Succeed())
+ b := buf.Bytes()
+
+ Expect(binary.LittleEndian.Uint32(b[24:28])).To(Equal(uint32(24000)))
+ Expect(binary.LittleEndian.Uint32(b[28:32])).To(Equal(uint32(48000)))
+ })
+ })
+
+ Describe("StripWAVHeader", func() {
+ It("strips the 44-byte header", func() {
+ pcm := []byte{0xDE, 0xAD, 0xBE, 0xEF}
+ hdr := NewWAVHeader(uint32(len(pcm)))
+ var buf bytes.Buffer
+ Expect(hdr.Write(&buf)).To(Succeed())
+ buf.Write(pcm)
+
+ got := StripWAVHeader(buf.Bytes())
+ Expect(got).To(Equal(pcm))
+ })
+
+ It("returns short data unchanged", func() {
+ short := []byte{0x01, 0x02, 0x03}
+ Expect(StripWAVHeader(short)).To(Equal(short))
+
+ exact := make([]byte, WAVHeaderSize)
+ Expect(StripWAVHeader(exact)).To(Equal(exact))
+ })
+ })
+
+ Describe("ParseWAV", func() {
+ It("returns sample rate and PCM data", func() {
+ pcm := make([]byte, 100)
+ for i := range pcm {
+ pcm[i] = byte(i)
+ }
+
+ hdr24 := NewWAVHeaderWithRate(uint32(len(pcm)), 24000)
+ var buf24 bytes.Buffer
+ hdr24.Write(&buf24)
+ buf24.Write(pcm)
+
+ gotPCM, gotRate := ParseWAV(buf24.Bytes())
+ Expect(gotRate).To(Equal(24000))
+ Expect(gotPCM).To(Equal(pcm))
+
+ hdr16 := NewWAVHeader(uint32(len(pcm)))
+ var buf16 bytes.Buffer
+ hdr16.Write(&buf16)
+ buf16.Write(pcm)
+
+ gotPCM, gotRate = ParseWAV(buf16.Bytes())
+ Expect(gotRate).To(Equal(16000))
+ Expect(gotPCM).To(Equal(pcm))
+ })
+
+ It("returns zero rate for short data", func() {
+ short := []byte{0x01, 0x02, 0x03}
+ gotPCM, gotRate := ParseWAV(short)
+ Expect(gotRate).To(Equal(0))
+ Expect(gotPCM).To(Equal(short))
+ })
+ })
+})
diff --git a/pkg/opus/opus.go b/pkg/opus/opus.go
new file mode 100644
index 000000000000..e4a670e84551
--- /dev/null
+++ b/pkg/opus/opus.go
@@ -0,0 +1,261 @@
+package opus
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "runtime"
+ "sync"
+
+ "github.com/ebitengine/purego"
+)
+
+const (
+ ApplicationVoIP = 2048
+ ApplicationAudio = 2049
+ ApplicationRestrictedLowDelay = 2051
+)
+
+var (
+ initOnce sync.Once
+ initErr error
+
+ opusLib uintptr
+ shimLib uintptr
+
+ // libopus functions
+ cEncoderCreate func(fs int32, channels int32, application int32, errPtr *int32) uintptr
+ cEncode func(st uintptr, pcm *int16, frameSize int32, data *byte, maxBytes int32) int32
+ cEncoderDestroy func(st uintptr)
+
+ cDecoderCreate func(fs int32, channels int32, errPtr *int32) uintptr
+ cDecode func(st uintptr, data *byte, dataLen int32, pcm *int16, frameSize int32, decodeFec int32) int32
+ cDecoderDestroy func(st uintptr)
+
+ // shim functions (non-variadic wrappers for opus_encoder_ctl)
+ cSetBitrate func(st uintptr, bitrate int32) int32
+ cSetComplexity func(st uintptr, complexity int32) int32
+)
+
+func loadLib(names []string) (uintptr, error) {
+ var firstErr error
+ for _, name := range names {
+ h, err := purego.Dlopen(name, purego.RTLD_NOW|purego.RTLD_GLOBAL)
+ if err == nil {
+ return h, nil
+ }
+ if firstErr == nil {
+ firstErr = err
+ }
+ }
+ return 0, firstErr
+}
+
+func ensureInit() error {
+ initOnce.Do(func() {
+ initErr = doInit()
+ })
+ return initErr
+}
+
+const shimHint = "ensure libopus-dev is installed and rebuild, or set OPUS_LIBRARY / OPUS_SHIM_LIBRARY env vars"
+
+func doInit() error {
+ opusNames := opusSearchPaths()
+ var err error
+ opusLib, err = loadLib(opusNames)
+ if err != nil {
+ return fmt.Errorf("opus: failed to load libopus (%s): %w", shimHint, err)
+ }
+
+ purego.RegisterLibFunc(&cEncoderCreate, opusLib, "opus_encoder_create")
+ purego.RegisterLibFunc(&cEncode, opusLib, "opus_encode")
+ purego.RegisterLibFunc(&cEncoderDestroy, opusLib, "opus_encoder_destroy")
+ purego.RegisterLibFunc(&cDecoderCreate, opusLib, "opus_decoder_create")
+ purego.RegisterLibFunc(&cDecode, opusLib, "opus_decode")
+ purego.RegisterLibFunc(&cDecoderDestroy, opusLib, "opus_decoder_destroy")
+
+ shimNames := shimSearchPaths()
+ shimLib, err = loadLib(shimNames)
+ if err != nil {
+ return fmt.Errorf("opus: failed to load libopusshim (%s): %w", shimHint, err)
+ }
+
+ purego.RegisterLibFunc(&cSetBitrate, shimLib, "opus_shim_encoder_set_bitrate")
+ purego.RegisterLibFunc(&cSetComplexity, shimLib, "opus_shim_encoder_set_complexity")
+
+ return nil
+}
+
+func opusSearchPaths() []string {
+ var paths []string
+
+ if env := os.Getenv("OPUS_LIBRARY"); env != "" {
+ paths = append(paths, env)
+ }
+
+ if exe, err := os.Executable(); err == nil {
+ dir := filepath.Dir(exe)
+ paths = append(paths, filepath.Join(dir, "libopus.so.0"), filepath.Join(dir, "libopus.so"))
+ if runtime.GOOS == "darwin" {
+ paths = append(paths, filepath.Join(dir, "libopus.dylib"))
+ }
+ }
+
+ paths = append(paths, "libopus.so.0", "libopus.so", "libopus.dylib", "opus.dll")
+
+ if runtime.GOOS == "darwin" {
+ paths = append(paths,
+ "/opt/homebrew/lib/libopus.dylib",
+ "/usr/local/lib/libopus.dylib",
+ )
+ }
+
+ return paths
+}
+
+func shimSearchPaths() []string {
+ var paths []string
+
+ if env := os.Getenv("OPUS_SHIM_LIBRARY"); env != "" {
+ paths = append(paths, env)
+ }
+
+ if exe, err := os.Executable(); err == nil {
+ dir := filepath.Dir(exe)
+ paths = append(paths, filepath.Join(dir, "libopusshim.so"))
+ if runtime.GOOS == "darwin" {
+ paths = append(paths, filepath.Join(dir, "libopusshim.dylib"))
+ }
+ }
+
+ paths = append(paths, "./libopusshim.so", "libopusshim.so")
+ if runtime.GOOS == "darwin" {
+ paths = append(paths, "./libopusshim.dylib", "libopusshim.dylib")
+ }
+ return paths
+}
+
+// Encoder wraps a libopus OpusEncoder via purego.
+type Encoder struct {
+ st uintptr
+}
+
+func NewEncoder(sampleRate, channels, application int) (*Encoder, error) {
+ if err := ensureInit(); err != nil {
+ return nil, err
+ }
+
+ var opusErr int32
+ st := cEncoderCreate(int32(sampleRate), int32(channels), int32(application), &opusErr)
+ if opusErr != 0 || st == 0 {
+ return nil, fmt.Errorf("opus_encoder_create failed: error %d", opusErr)
+ }
+ return &Encoder{st: st}, nil
+}
+
+// Encode encodes a frame of PCM int16 samples. It returns the number of bytes
+// written to out, or a negative error code.
+func (e *Encoder) Encode(pcm []int16, frameSize int, out []byte) (int, error) {
+ if len(pcm) == 0 || len(out) == 0 {
+ return 0, errors.New("opus encode: empty input or output buffer")
+ }
+ n := cEncode(e.st, &pcm[0], int32(frameSize), &out[0], int32(len(out)))
+ if n < 0 {
+ return 0, fmt.Errorf("opus_encode failed: error %d", n)
+ }
+ return int(n), nil
+}
+
+func (e *Encoder) SetBitrate(bitrate int) error {
+ if ret := cSetBitrate(e.st, int32(bitrate)); ret != 0 {
+ return fmt.Errorf("opus set bitrate: error %d", ret)
+ }
+ return nil
+}
+
+func (e *Encoder) SetComplexity(complexity int) error {
+ if ret := cSetComplexity(e.st, int32(complexity)); ret != 0 {
+ return fmt.Errorf("opus set complexity: error %d", ret)
+ }
+ return nil
+}
+
+func (e *Encoder) Close() {
+ if e.st != 0 {
+ cEncoderDestroy(e.st)
+ e.st = 0
+ }
+}
+
+// Decoder wraps a libopus OpusDecoder via purego.
+type Decoder struct {
+ st uintptr
+}
+
+func NewDecoder(sampleRate, channels int) (*Decoder, error) {
+ if err := ensureInit(); err != nil {
+ return nil, err
+ }
+
+ var opusErr int32
+ st := cDecoderCreate(int32(sampleRate), int32(channels), &opusErr)
+ if opusErr != 0 || st == 0 {
+ return nil, fmt.Errorf("opus_decoder_create failed: error %d", opusErr)
+ }
+ return &Decoder{st: st}, nil
+}
+
+// Decode decodes an Opus packet into pcm. frameSize is the max number of
+// samples per channel that pcm can hold. Returns the number of decoded samples
+// per channel.
+func (d *Decoder) Decode(data []byte, pcm []int16, frameSize int, fec bool) (int, error) {
+ if len(pcm) == 0 {
+ return 0, errors.New("opus decode: empty output buffer")
+ }
+
+ var dataPtr *byte
+ var dataLen int32
+ if len(data) > 0 {
+ dataPtr = &data[0]
+ dataLen = int32(len(data))
+ }
+
+ decodeFec := int32(0)
+ if fec {
+ decodeFec = 1
+ }
+
+ n := cDecode(d.st, dataPtr, dataLen, &pcm[0], int32(frameSize), decodeFec)
+ if n < 0 {
+ return 0, fmt.Errorf("opus_decode failed: error %d", n)
+ }
+ return int(n), nil
+}
+
+func (d *Decoder) Close() {
+ if d.st != 0 {
+ cDecoderDestroy(d.st)
+ d.st = 0
+ }
+}
+
+// Initialized reports whether the opus libraries were loaded successfully.
+func Initialized() bool {
+ return ensureInit() == nil
+}
+
+// Init eagerly loads the opus libraries, returning any error.
+// Calling this is optional; the libraries are loaded lazily on first use.
+func Init() error {
+ return ensureInit()
+}
+
+// Reset allows re-initialization (for testing).
+func Reset() {
+ initOnce = sync.Once{}
+ initErr = nil
+ opusLib = 0
+ shimLib = 0
+}
diff --git a/pkg/opus/shim/Makefile b/pkg/opus/shim/Makefile
new file mode 100644
index 000000000000..d9467fa39ff7
--- /dev/null
+++ b/pkg/opus/shim/Makefile
@@ -0,0 +1,10 @@
+OPUS_CFLAGS := $(shell pkg-config --cflags opus)
+OPUS_LIBS := $(shell pkg-config --libs opus)
+
+libopusshim.so: opus_shim.c
+ $(CC) -shared -fPIC -o $@ $< $(OPUS_CFLAGS) $(OPUS_LIBS)
+
+clean:
+ rm -f libopusshim.so
+
+.PHONY: clean
diff --git a/pkg/opus/shim/opus_shim.c b/pkg/opus/shim/opus_shim.c
new file mode 100644
index 000000000000..75d3babb4625
--- /dev/null
+++ b/pkg/opus/shim/opus_shim.c
@@ -0,0 +1,9 @@
+#include
+
+int opus_shim_encoder_set_bitrate(OpusEncoder *st, opus_int32 bitrate) {
+ return opus_encoder_ctl(st, OPUS_SET_BITRATE(bitrate));
+}
+
+int opus_shim_encoder_set_complexity(OpusEncoder *st, opus_int32 complexity) {
+ return opus_encoder_ctl(st, OPUS_SET_COMPLEXITY(complexity));
+}
diff --git a/pkg/sound/int16.go b/pkg/sound/int16.go
index f56aa14f9ebe..1b30827d2e9a 100644
--- a/pkg/sound/int16.go
+++ b/pkg/sound/int16.go
@@ -25,11 +25,23 @@ func CalculateRMS16(buffer []int16) float64 {
}
func ResampleInt16(input []int16, inputRate, outputRate int) []int16 {
+ if len(input) == 0 {
+ return nil
+ }
+ if inputRate == outputRate {
+ out := make([]int16, len(input))
+ copy(out, input)
+ return out
+ }
+
// Calculate the resampling ratio
ratio := float64(inputRate) / float64(outputRate)
// Calculate the length of the resampled output
outputLength := int(float64(len(input)) / ratio)
+ if outputLength <= 0 {
+ return []int16{input[0]}
+ }
// Allocate a slice for the resampled output
output := make([]int16, outputLength)
diff --git a/pkg/sound/int16_test.go b/pkg/sound/int16_test.go
new file mode 100644
index 000000000000..f803efda5ce0
--- /dev/null
+++ b/pkg/sound/int16_test.go
@@ -0,0 +1,120 @@
+package sound
+
+import (
+ "math"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+var _ = Describe("Int16 utilities", func() {
+ Describe("BytesToInt16sLE / Int16toBytesLE", func() {
+ It("round-trips correctly", func() {
+ values := []int16{0, 1, -1, 32767, -32768}
+ b := Int16toBytesLE(values)
+ got := BytesToInt16sLE(b)
+
+ Expect(got).To(Equal(values))
+ })
+
+ It("panics on odd-length input", func() {
+ Expect(func() {
+ BytesToInt16sLE([]byte{0x01, 0x02, 0x03})
+ }).To(Panic())
+ })
+
+ It("returns empty slice for empty bytes input", func() {
+ got := BytesToInt16sLE([]byte{})
+ Expect(got).To(BeEmpty())
+ })
+
+ It("returns empty slice for empty int16 input", func() {
+ got := Int16toBytesLE([]int16{})
+ Expect(got).To(BeEmpty())
+ })
+ })
+
+ Describe("ResampleInt16", func() {
+ It("returns identical output for same rate", func() {
+ src := generateSineWave(440, 16000, 320)
+ dst := ResampleInt16(src, 16000, 16000)
+
+ Expect(dst).To(Equal(src))
+ })
+
+ It("downsamples 48k to 16k", func() {
+ src := generateSineWave(440, 48000, 960)
+ dst := ResampleInt16(src, 48000, 16000)
+
+ Expect(dst).To(HaveLen(320))
+
+ freq := estimateFrequency(dst, 16000)
+ Expect(freq).To(BeNumerically("~", 440, 50))
+ })
+
+ It("upsamples 16k to 48k", func() {
+ src := generateSineWave(440, 16000, 320)
+ dst := ResampleInt16(src, 16000, 48000)
+
+ Expect(dst).To(HaveLen(960))
+
+ freq := estimateFrequency(dst, 48000)
+ Expect(freq).To(BeNumerically("~", 440, 50))
+ })
+
+ It("preserves quality through double resampling", func() {
+ src := generateSineWave(440, 48000, 4800) // 100ms
+
+ direct := ResampleInt16(src, 48000, 16000)
+
+ step1 := ResampleInt16(src, 48000, 24000)
+ double := ResampleInt16(step1, 24000, 16000)
+
+ minLen := len(direct)
+ if len(double) < minLen {
+ minLen = len(double)
+ }
+
+ corr := computeCorrelation(direct[:minLen], double[:minLen])
+ Expect(corr).To(BeNumerically(">=", 0.95))
+ })
+
+ It("handles single sample", func() {
+ src := []int16{1000}
+ got := ResampleInt16(src, 48000, 16000)
+ Expect(got).NotTo(BeEmpty())
+ Expect(got[0]).To(Equal(int16(1000)))
+ })
+
+ It("returns nil for empty input", func() {
+ got := ResampleInt16(nil, 48000, 16000)
+ Expect(got).To(BeNil())
+ })
+ })
+
+ Describe("CalculateRMS16", func() {
+ It("computes correct RMS for constant signal", func() {
+ buf := make([]int16, 1000)
+ for i := range buf {
+ buf[i] = 1000
+ }
+ rms := CalculateRMS16(buf)
+ Expect(rms).To(BeNumerically("~", 1000, 0.01))
+ })
+
+ It("returns zero for silence", func() {
+ buf := make([]int16, 1000)
+ rms := CalculateRMS16(buf)
+ Expect(rms).To(BeZero())
+ })
+
+ It("computes correct RMS for known sine wave", func() {
+ amplitude := float64(math.MaxInt16 / 2)
+ buf := generateSineWave(440, 16000, 16000) // 1 second
+ rms := CalculateRMS16(buf)
+ expectedRMS := amplitude / math.Sqrt(2)
+
+ Expect(rms).To(BeNumerically("~", expectedRMS, expectedRMS*0.02))
+ })
+ })
+})
diff --git a/pkg/sound/sound_suite_test.go b/pkg/sound/sound_suite_test.go
new file mode 100644
index 000000000000..5287aa95570d
--- /dev/null
+++ b/pkg/sound/sound_suite_test.go
@@ -0,0 +1,13 @@
+package sound
+
+import (
+ "testing"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+func TestSound(t *testing.T) {
+ RegisterFailHandler(Fail)
+ RunSpecs(t, "Sound Suite")
+}
diff --git a/pkg/sound/testutil_test.go b/pkg/sound/testutil_test.go
new file mode 100644
index 000000000000..0f044df68ec5
--- /dev/null
+++ b/pkg/sound/testutil_test.go
@@ -0,0 +1,72 @@
+package sound
+
+import "math"
+
+// generateSineWave produces a sine wave of the given frequency at the given sample rate.
+func generateSineWave(freq float64, sampleRate, numSamples int) []int16 {
+ out := make([]int16, numSamples)
+ for i := range out {
+ t := float64(i) / float64(sampleRate)
+ out[i] = int16(math.MaxInt16 / 2 * math.Sin(2*math.Pi*freq*t))
+ }
+ return out
+}
+
+// computeCorrelation returns the normalised Pearson correlation between two
+// equal-length int16 slices. Returns 0 when either signal has zero energy.
+func computeCorrelation(a, b []int16) float64 {
+ n := len(a)
+ if n == 0 || n != len(b) {
+ return 0
+ }
+ var sumAB, sumA2, sumB2 float64
+ for i := 0; i < n; i++ {
+ fa, fb := float64(a[i]), float64(b[i])
+ sumAB += fa * fb
+ sumA2 += fa * fa
+ sumB2 += fb * fb
+ }
+ denom := math.Sqrt(sumA2 * sumB2)
+ if denom == 0 {
+ return 0
+ }
+ return sumAB / denom
+}
+
+// estimateFrequency estimates the dominant frequency of a mono int16 signal
+// using zero-crossing count.
+func estimateFrequency(samples []int16, sampleRate int) float64 {
+ if len(samples) < 2 {
+ return 0
+ }
+ crossings := 0
+ for i := 1; i < len(samples); i++ {
+ if (samples[i-1] >= 0 && samples[i] < 0) || (samples[i-1] < 0 && samples[i] >= 0) {
+ crossings++
+ }
+ }
+ duration := float64(len(samples)) / float64(sampleRate)
+ // Each full cycle has 2 zero crossings.
+ return float64(crossings) / (2 * duration)
+}
+
+// computeRMS returns the root-mean-square of an int16 slice.
+func computeRMS(samples []int16) float64 {
+ if len(samples) == 0 {
+ return 0
+ }
+ var sum float64
+ for _, s := range samples {
+ v := float64(s)
+ sum += v * v
+ }
+ return math.Sqrt(sum / float64(len(samples)))
+}
+
+// generatePCMBytes creates a little-endian int16 PCM byte slice containing a
+// sine wave of the given frequency at the given sample rate and duration.
+func generatePCMBytes(freq float64, sampleRate, durationMs int) []byte {
+ numSamples := sampleRate * durationMs / 1000
+ samples := generateSineWave(freq, sampleRate, numSamples)
+ return Int16toBytesLE(samples)
+}
diff --git a/tests/e2e/e2e_suite_test.go b/tests/e2e/e2e_suite_test.go
index 66d9d6cd7ffb..d55e7d7ab469 100644
--- a/tests/e2e/e2e_suite_test.go
+++ b/tests/e2e/e2e_suite_test.go
@@ -87,10 +87,10 @@ var _ = BeforeSuite(func() {
Expect(os.Chmod(mockBackendPath, 0755)).To(Succeed())
// Create model config YAML
- modelConfig := map[string]interface{}{
+ modelConfig := map[string]any{
"name": "mock-model",
"backend": "mock-backend",
- "parameters": map[string]interface{}{
+ "parameters": map[string]any{
"model": "mock-model.bin",
},
}
@@ -99,11 +99,92 @@ var _ = BeforeSuite(func() {
Expect(err).ToNot(HaveOccurred())
Expect(os.WriteFile(configPath, configYAML, 0644)).To(Succeed())
- // Set up system state
- systemState, err := system.GetSystemState(
- system.WithBackendPath(backendPath),
+ // Create pipeline model configs for realtime API tests.
+ // Each component model uses the same mock-backend binary.
+ for _, name := range []string{"mock-vad", "mock-stt", "mock-llm", "mock-tts"} {
+ cfg := map[string]any{
+ "name": name,
+ "backend": "mock-backend",
+ "parameters": map[string]any{
+ "model": name + ".bin",
+ },
+ }
+ data, err := yaml.Marshal(cfg)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(os.WriteFile(filepath.Join(modelsPath, name+".yaml"), data, 0644)).To(Succeed())
+ }
+
+ // Pipeline model that wires the component models together.
+ pipelineCfg := map[string]any{
+ "name": "realtime-pipeline",
+ "pipeline": map[string]any{
+ "vad": "mock-vad",
+ "transcription": "mock-stt",
+ "llm": "mock-llm",
+ "tts": "mock-tts",
+ },
+ }
+ pipelineData, err := yaml.Marshal(pipelineCfg)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(os.WriteFile(filepath.Join(modelsPath, "realtime-pipeline.yaml"), pipelineData, 0644)).To(Succeed())
+
+ // If REALTIME_TEST_MODEL=realtime-test-pipeline, auto-create a pipeline
+ // config from the REALTIME_VAD/STT/LLM/TTS env vars so real-model tests
+ // can run without the user having to write a YAML file manually.
+ if os.Getenv("REALTIME_TEST_MODEL") == "realtime-test-pipeline" {
+ rtVAD := os.Getenv("REALTIME_VAD")
+ rtSTT := os.Getenv("REALTIME_STT")
+ rtLLM := os.Getenv("REALTIME_LLM")
+ rtTTS := os.Getenv("REALTIME_TTS")
+
+ if rtVAD != "" && rtSTT != "" && rtLLM != "" && rtTTS != "" {
+ testPipeline := map[string]any{
+ "name": "realtime-test-pipeline",
+ "pipeline": map[string]any{
+ "vad": rtVAD,
+ "transcription": rtSTT,
+ "llm": rtLLM,
+ "tts": rtTTS,
+ },
+ }
+ data, writeErr := yaml.Marshal(testPipeline)
+ Expect(writeErr).ToNot(HaveOccurred())
+ Expect(os.WriteFile(filepath.Join(modelsPath, "realtime-test-pipeline.yaml"), data, 0644)).To(Succeed())
+ xlog.Info("created realtime-test-pipeline",
+ "vad", rtVAD, "stt", rtSTT, "llm", rtLLM, "tts", rtTTS)
+ }
+ }
+
+ // Import model configs from an external directory (e.g. real model YAMLs
+ // and weights mounted into a container). Symlinks avoid copying large files.
+ if rtModels := os.Getenv("REALTIME_MODELS_PATH"); rtModels != "" {
+ entries, err := os.ReadDir(rtModels)
+ Expect(err).ToNot(HaveOccurred())
+ for _, entry := range entries {
+ src := filepath.Join(rtModels, entry.Name())
+ dst := filepath.Join(modelsPath, entry.Name())
+ if _, err := os.Stat(dst); err == nil {
+ continue // don't overwrite mock configs
+ }
+ if entry.IsDir() {
+ continue
+ }
+ Expect(os.Symlink(src, dst)).To(Succeed())
+ }
+ }
+
+ // Set up system state. When REALTIME_BACKENDS_PATH is set, use it so the
+ // application can discover real backend binaries for real-model tests.
+ systemOpts := []system.SystemStateOptions{
system.WithModelPath(modelsPath),
- )
+ }
+ if realBackends := os.Getenv("REALTIME_BACKENDS_PATH"); realBackends != "" {
+ systemOpts = append(systemOpts, system.WithBackendPath(realBackends))
+ } else {
+ systemOpts = append(systemOpts, system.WithBackendPath(backendPath))
+ }
+
+ systemState, err := system.GetSystemState(systemOpts...)
Expect(err).ToNot(HaveOccurred())
// Create application
@@ -120,7 +201,7 @@ var _ = BeforeSuite(func() {
)
Expect(err).ToNot(HaveOccurred())
- // Register backend with application's model loader
+ // Register mock backend (always available for non-realtime tests).
application.ModelLoader().SetExternalBackend("mock-backend", mockBackendPath)
// Create HTTP app
diff --git a/tests/e2e/mock-backend/main.go b/tests/e2e/mock-backend/main.go
index e94a7bf4266f..d39967f53ff5 100644
--- a/tests/e2e/mock-backend/main.go
+++ b/tests/e2e/mock-backend/main.go
@@ -7,9 +7,11 @@ import (
"flag"
"fmt"
"log"
+ "math"
"net"
"os"
"path/filepath"
+ "strconv"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/xlog"
@@ -177,12 +179,28 @@ func (m *MockBackend) SoundGeneration(ctx context.Context, in *pb.SoundGeneratio
}, nil
}
-// writeMinimalWAV writes a minimal valid WAV file (short silence) so the HTTP handler can send it.
+// ttsSampleRate returns the sample rate to use for TTS output, configurable
+// via the MOCK_TTS_SAMPLE_RATE environment variable (default 16000).
+func ttsSampleRate() int {
+ if s := os.Getenv("MOCK_TTS_SAMPLE_RATE"); s != "" {
+ if v, err := strconv.Atoi(s); err == nil && v > 0 {
+ return v
+ }
+ }
+ return 16000
+}
+
+// writeMinimalWAV writes a WAV file containing a 440Hz sine wave (0.5s)
+// so that tests can verify audio integrity end-to-end. The sample rate
+// is configurable via MOCK_TTS_SAMPLE_RATE to test rate mismatch bugs.
func writeMinimalWAV(path string) error {
- const sampleRate = 16000
+ sampleRate := ttsSampleRate()
const numChannels = 1
const bitsPerSample = 16
- const numSamples = 1600 // 0.1s
+ const freq = 440.0
+ const durationSec = 0.5
+ numSamples := int(float64(sampleRate) * durationSec)
+
dataSize := numSamples * numChannels * (bitsPerSample / 8)
const headerLen = 44
f, err := os.Create(path)
@@ -203,23 +221,56 @@ func writeMinimalWAV(path string) error {
_ = binary.Write(f, binary.LittleEndian, uint32(sampleRate*numChannels*(bitsPerSample/8)))
_ = binary.Write(f, binary.LittleEndian, uint16(numChannels*(bitsPerSample/8)))
_ = binary.Write(f, binary.LittleEndian, uint16(bitsPerSample))
- // data chunk
+ // data chunk — 440Hz sine wave
_, _ = f.Write([]byte("data"))
_ = binary.Write(f, binary.LittleEndian, uint32(dataSize))
- _, _ = f.Write(make([]byte, dataSize))
+ for i := range numSamples {
+ t := float64(i) / float64(sampleRate)
+ sample := int16(math.MaxInt16 / 2 * math.Sin(2*math.Pi*freq*t))
+ _ = binary.Write(f, binary.LittleEndian, sample)
+ }
return nil
}
func (m *MockBackend) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) {
- xlog.Debug("AudioTranscription called")
+ dst := in.GetDst()
+ wavSR := 0
+ dataLen := 0
+ rms := 0.0
+
+ if dst != "" {
+ if data, err := os.ReadFile(dst); err == nil {
+ if len(data) >= 44 {
+ wavSR = int(binary.LittleEndian.Uint32(data[24:28]))
+ dataLen = int(binary.LittleEndian.Uint32(data[40:44]))
+
+ // Compute RMS of the PCM payload (16-bit LE samples)
+ pcm := data[44:]
+ var sumSq float64
+ nSamples := len(pcm) / 2
+ for i := range nSamples {
+ s := int16(pcm[2*i]) | int16(pcm[2*i+1])<<8
+ v := float64(s)
+ sumSq += v * v
+ }
+ if nSamples > 0 {
+ rms = math.Sqrt(sumSq / float64(nSamples))
+ }
+ }
+ }
+ }
+
+ xlog.Debug("AudioTranscription called", "dst", dst, "wav_sample_rate", wavSR, "data_len", dataLen, "rms", rms)
+
+ text := fmt.Sprintf("transcribed: rms=%.1f samples=%d sr=%d", rms, dataLen/2, wavSR)
return &pb.TranscriptResult{
- Text: "This is a mocked transcription.",
+ Text: text,
Segments: []*pb.TranscriptSegment{
{
Id: 0,
Start: 0,
End: 3000,
- Text: "This is a mocked transcription.",
+ Text: text,
Tokens: []int32{1, 2, 3, 4, 5, 6},
},
},
@@ -349,16 +400,30 @@ func (m *MockBackend) GetMetrics(ctx context.Context, in *pb.MetricsRequest) (*p
}
func (m *MockBackend) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, error) {
- xlog.Debug("VAD called", "audio_length", len(in.Audio))
+ // Compute RMS of the received float32 audio to decide whether speech is present.
+ var sumSq float64
+ for _, s := range in.Audio {
+ v := float64(s)
+ sumSq += v * v
+ }
+ rms := 0.0
+ if len(in.Audio) > 0 {
+ rms = math.Sqrt(sumSq / float64(len(in.Audio)))
+ }
+ xlog.Debug("VAD called", "audio_length", len(in.Audio), "rms", rms)
+
+ // If audio is near-silence, return no segments (no speech detected).
+ if rms < 0.001 {
+ return &pb.VADResponse{}, nil
+ }
+
+ // Audio has signal — return a single segment covering the duration.
+ duration := float64(len(in.Audio)) / 16000.0
return &pb.VADResponse{
Segments: []*pb.VADSegment{
{
Start: 0.0,
- End: 1.5,
- },
- {
- Start: 2.0,
- End: 3.5,
+ End: float32(duration),
},
},
}, nil
diff --git a/tests/e2e/realtime_webrtc_test.go b/tests/e2e/realtime_webrtc_test.go
new file mode 100644
index 000000000000..d04c55e5941f
--- /dev/null
+++ b/tests/e2e/realtime_webrtc_test.go
@@ -0,0 +1,459 @@
+package e2e_test
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "math"
+ "net/http"
+ "os"
+ "sync"
+ "time"
+
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+ "github.com/pion/webrtc/v4"
+ "github.com/pion/webrtc/v4/pkg/media"
+)
+
+// --- WebRTC test client ---
+
+type webrtcTestClient struct {
+ pc *webrtc.PeerConnection
+ dc *webrtc.DataChannel
+ sendTrack *webrtc.TrackLocalStaticSample
+
+ events chan map[string]any
+ audioData chan []byte // raw Opus frames received
+
+ dcOpen chan struct{} // closed when data channel opens
+ mu sync.Mutex
+}
+
+func newWebRTCTestClient() *webrtcTestClient {
+ m := &webrtc.MediaEngine{}
+ Expect(m.RegisterDefaultCodecs()).To(Succeed())
+
+ api := webrtc.NewAPI(webrtc.WithMediaEngine(m))
+
+ pc, err := api.NewPeerConnection(webrtc.Configuration{})
+ Expect(err).ToNot(HaveOccurred())
+
+ // Create outbound audio track (Opus)
+ sendTrack, err := webrtc.NewTrackLocalStaticSample(
+ webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeOpus},
+ "audio-client",
+ "test-client",
+ )
+ Expect(err).ToNot(HaveOccurred())
+
+ rtpSender, err := pc.AddTrack(sendTrack)
+ Expect(err).ToNot(HaveOccurred())
+
+ // Drain RTCP
+ go func() {
+ buf := make([]byte, 1500)
+ for {
+ if _, _, err := rtpSender.Read(buf); err != nil {
+ return
+ }
+ }
+ }()
+
+ // Create the "oai-events" data channel (must be created by client)
+ dc, err := pc.CreateDataChannel("oai-events", nil)
+ Expect(err).ToNot(HaveOccurred())
+
+ c := &webrtcTestClient{
+ pc: pc,
+ dc: dc,
+ sendTrack: sendTrack,
+ events: make(chan map[string]any, 256),
+ audioData: make(chan []byte, 4096),
+ dcOpen: make(chan struct{}),
+ }
+
+ dc.OnOpen(func() {
+ close(c.dcOpen)
+ })
+
+ dc.OnMessage(func(msg webrtc.DataChannelMessage) {
+ var evt map[string]any
+ if err := json.Unmarshal(msg.Data, &evt); err == nil {
+ c.events <- evt
+ }
+ })
+
+ // Collect incoming audio tracks
+ pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
+ for {
+ pkt, _, err := track.ReadRTP()
+ if err != nil {
+ return
+ }
+ c.audioData <- pkt.Payload
+ }
+ })
+
+ return c
+}
+
+// connect performs SDP exchange with the server and waits for the data channel to open.
+func (c *webrtcTestClient) connect(model string) {
+ offer, err := c.pc.CreateOffer(nil)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(c.pc.SetLocalDescription(offer)).To(Succeed())
+
+ // Wait for ICE gathering
+ gatherDone := webrtc.GatheringCompletePromise(c.pc)
+ select {
+ case <-gatherDone:
+ case <-time.After(10 * time.Second):
+ Fail("ICE gathering timed out")
+ }
+
+ localDesc := c.pc.LocalDescription()
+ Expect(localDesc).ToNot(BeNil())
+
+ // POST to /v1/realtime/calls
+ reqBody, err := json.Marshal(map[string]string{
+ "sdp": localDesc.SDP,
+ "model": model,
+ })
+ Expect(err).ToNot(HaveOccurred())
+
+ resp, err := http.Post(
+ fmt.Sprintf("http://127.0.0.1:%d/v1/realtime/calls", apiPort),
+ "application/json",
+ bytes.NewReader(reqBody),
+ )
+ Expect(err).ToNot(HaveOccurred())
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ Expect(err).ToNot(HaveOccurred())
+ Expect(resp.StatusCode).To(Equal(http.StatusCreated),
+ "expected 201, got %d: %s", resp.StatusCode, string(body))
+
+ var callResp struct {
+ SDP string `json:"sdp"`
+ SessionID string `json:"session_id"`
+ }
+ Expect(json.Unmarshal(body, &callResp)).To(Succeed())
+ Expect(callResp.SDP).ToNot(BeEmpty())
+
+ // Set the answer
+ Expect(c.pc.SetRemoteDescription(webrtc.SessionDescription{
+ Type: webrtc.SDPTypeAnswer,
+ SDP: callResp.SDP,
+ })).To(Succeed())
+
+ // Wait for data channel to open
+ Eventually(c.dcOpen, 15*time.Second).Should(BeClosed())
+}
+
+// sendEvent sends a JSON event via the data channel.
+func (c *webrtcTestClient) sendEvent(event any) {
+ data, err := json.Marshal(event)
+ ExpectWithOffset(1, err).ToNot(HaveOccurred())
+ ExpectWithOffset(1, c.dc.Send(data)).To(Succeed())
+}
+
+// readEvent reads the next event from the data channel with timeout.
+func (c *webrtcTestClient) readEvent(timeout time.Duration) map[string]any {
+ select {
+ case evt := <-c.events:
+ return evt
+ case <-time.After(timeout):
+ Fail("timed out reading event from data channel")
+ return nil
+ }
+}
+
+// drainUntilEvent reads events until one with the given type appears.
+func (c *webrtcTestClient) drainUntilEvent(eventType string, timeout time.Duration) map[string]any {
+ deadline := time.Now().Add(timeout)
+ for time.Now().Before(deadline) {
+ remaining := time.Until(deadline)
+ if remaining <= 0 {
+ break
+ }
+ evt := c.readEvent(remaining)
+ if evt["type"] == eventType {
+ return evt
+ }
+ }
+ Fail("timed out waiting for event: " + eventType)
+ return nil
+}
+
+// sendSineWave encodes a sine wave to Opus and sends it over the audio track.
+// This is a simplified version that sends raw PCM wrapped as Opus-compatible
+// media samples. In a real client the Opus encoder would be used.
+func (c *webrtcTestClient) sendSilence(durationMs int) {
+ // Send silence as zero-filled PCM samples via track.
+ // We use 20ms Opus frames at 48kHz.
+ framesNeeded := durationMs / 20
+ // Minimal valid Opus silence frame (Opus DTX/silence)
+ silenceFrame := make([]byte, 3)
+ silenceFrame[0] = 0xF8 // Config: CELT-only, no VAD, 20ms frame
+ silenceFrame[1] = 0xFF
+ silenceFrame[2] = 0xFE
+
+ for range framesNeeded {
+ _ = c.sendTrack.WriteSample(media.Sample{
+ Data: silenceFrame,
+ Duration: 20 * time.Millisecond,
+ })
+ time.Sleep(5 * time.Millisecond)
+ }
+}
+
+func (c *webrtcTestClient) close() {
+ if c.pc != nil {
+ c.pc.Close()
+ }
+}
+
+// --- Tests ---
+
+var _ = Describe("Realtime WebRTC API", Label("Realtime"), func() {
+ Context("Signaling", func() {
+ It("should complete SDP exchange and receive session.created", func() {
+ client := newWebRTCTestClient()
+ defer client.close()
+
+ client.connect(pipelineModel())
+
+ evt := client.readEvent(30 * time.Second)
+ Expect(evt["type"]).To(Equal("session.created"))
+
+ session, ok := evt["session"].(map[string]any)
+ Expect(ok).To(BeTrue())
+ Expect(session["id"]).ToNot(BeEmpty())
+ })
+ })
+
+ Context("Event exchange via DataChannel", func() {
+ It("should handle session.update", func() {
+ client := newWebRTCTestClient()
+ defer client.close()
+
+ client.connect(pipelineModel())
+
+ // Read session.created
+ created := client.readEvent(30 * time.Second)
+ Expect(created["type"]).To(Equal("session.created"))
+
+ // Disable VAD
+ client.sendEvent(disableVADEvent())
+
+ updated := client.drainUntilEvent("session.updated", 10*time.Second)
+ Expect(updated).ToNot(BeNil())
+ })
+
+ It("should handle conversation.item.create and response.create", func() {
+ client := newWebRTCTestClient()
+ defer client.close()
+
+ client.connect(pipelineModel())
+
+ created := client.readEvent(30 * time.Second)
+ Expect(created["type"]).To(Equal("session.created"))
+
+ // Disable VAD
+ client.sendEvent(disableVADEvent())
+ client.drainUntilEvent("session.updated", 10*time.Second)
+
+ // Create text item
+ client.sendEvent(map[string]any{
+ "type": "conversation.item.create",
+ "item": map[string]any{
+ "type": "message",
+ "role": "user",
+ "content": []map[string]any{
+ {
+ "type": "input_text",
+ "text": "Hello from WebRTC",
+ },
+ },
+ },
+ })
+
+ added := client.drainUntilEvent("conversation.item.added", 10*time.Second)
+ Expect(added).ToNot(BeNil())
+
+ // Trigger response
+ client.sendEvent(map[string]any{
+ "type": "response.create",
+ })
+
+ done := client.drainUntilEvent("response.done", 60*time.Second)
+ Expect(done).ToNot(BeNil())
+ })
+ })
+
+ Context("Audio track", func() {
+ It("should receive audio on the incoming track after TTS", Label("real-models"), func() {
+ if os.Getenv("REALTIME_TEST_MODEL") == "" {
+ Skip("REALTIME_TEST_MODEL not set")
+ }
+
+ client := newWebRTCTestClient()
+ defer client.close()
+
+ client.connect(pipelineModel())
+
+ created := client.readEvent(30 * time.Second)
+ Expect(created["type"]).To(Equal("session.created"))
+
+ // Disable VAD
+ client.sendEvent(disableVADEvent())
+ client.drainUntilEvent("session.updated", 10*time.Second)
+
+ // Send text and trigger response with TTS
+ client.sendEvent(map[string]any{
+ "type": "conversation.item.create",
+ "item": map[string]any{
+ "type": "message",
+ "role": "user",
+ "content": []map[string]any{
+ {
+ "type": "input_text",
+ "text": "Say hello",
+ },
+ },
+ },
+ })
+ client.drainUntilEvent("conversation.item.added", 10*time.Second)
+
+ client.sendEvent(map[string]any{
+ "type": "response.create",
+ })
+
+ // Collect audio frames while waiting for response.done
+ var audioFrames [][]byte
+ deadline := time.Now().Add(60 * time.Second)
+ loop:
+ for time.Now().Before(deadline) {
+ select {
+ case frame := <-client.audioData:
+ audioFrames = append(audioFrames, frame)
+ case evt := <-client.events:
+ if evt["type"] == "response.done" {
+ break loop
+ }
+ case <-time.After(time.Until(deadline)):
+ break loop
+ }
+ }
+
+ // We should have received some audio frames
+ Expect(len(audioFrames)).To(BeNumerically(">", 0),
+ "expected to receive audio frames on the WebRTC track")
+ })
+ })
+
+ Context("Disconnect cleanup", func() {
+ It("should handle repeated connect/disconnect cycles", func() {
+ for i := range 3 {
+ By(fmt.Sprintf("Cycle %d", i+1))
+ client := newWebRTCTestClient()
+ client.connect(pipelineModel())
+
+ evt := client.readEvent(30 * time.Second)
+ Expect(evt["type"]).To(Equal("session.created"))
+
+ client.close()
+ // Brief pause to let server clean up
+ time.Sleep(500 * time.Millisecond)
+ }
+ })
+ })
+
+ Context("Audio integrity", Label("real-models"), func() {
+ It("should receive recognizable audio from TTS through WebRTC", func() {
+ if os.Getenv("REALTIME_TEST_MODEL") == "" {
+ Skip("REALTIME_TEST_MODEL not set")
+ }
+
+ client := newWebRTCTestClient()
+ defer client.close()
+
+ client.connect(pipelineModel())
+
+ created := client.readEvent(30 * time.Second)
+ Expect(created["type"]).To(Equal("session.created"))
+
+ // Disable VAD
+ client.sendEvent(disableVADEvent())
+ client.drainUntilEvent("session.updated", 10*time.Second)
+
+ // Create text item and trigger response
+ client.sendEvent(map[string]any{
+ "type": "conversation.item.create",
+ "item": map[string]any{
+ "type": "message",
+ "role": "user",
+ "content": []map[string]any{
+ {
+ "type": "input_text",
+ "text": "Say hello",
+ },
+ },
+ },
+ })
+ client.drainUntilEvent("conversation.item.added", 10*time.Second)
+
+ client.sendEvent(map[string]any{
+ "type": "response.create",
+ })
+
+ // Collect Opus frames and decode them
+ var totalBytes int
+ deadline := time.Now().Add(60 * time.Second)
+ loop:
+ for time.Now().Before(deadline) {
+ select {
+ case frame := <-client.audioData:
+ totalBytes += len(frame)
+ case evt := <-client.events:
+ if evt["type"] == "response.done" {
+ // Drain any remaining audio
+ time.Sleep(200 * time.Millisecond)
+ drainAudio:
+ for {
+ select {
+ case frame := <-client.audioData:
+ totalBytes += len(frame)
+ default:
+ break drainAudio
+ }
+ }
+ break loop
+ }
+ case <-time.After(time.Until(deadline)):
+ break loop
+ }
+ }
+
+ // Verify we received meaningful audio data
+ Expect(totalBytes).To(BeNumerically(">", 100),
+ "expected to receive meaningful audio data")
+ })
+ })
+})
+
+// computeRMSInt16 computes RMS of int16 samples (used by audio integrity tests).
+func computeRMSInt16(samples []int16) float64 {
+ if len(samples) == 0 {
+ return 0
+ }
+ var sum float64
+ for _, s := range samples {
+ v := float64(s)
+ sum += v * v
+ }
+ return math.Sqrt(sum / float64(len(samples)))
+}
diff --git a/tests/e2e/realtime_ws_test.go b/tests/e2e/realtime_ws_test.go
new file mode 100644
index 000000000000..c69186f3345e
--- /dev/null
+++ b/tests/e2e/realtime_ws_test.go
@@ -0,0 +1,269 @@
+package e2e_test
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "math"
+ "net/url"
+ "os"
+ "time"
+
+ "github.com/gorilla/websocket"
+ . "github.com/onsi/ginkgo/v2"
+ . "github.com/onsi/gomega"
+)
+
+// --- WebSocket test helpers ---
+
+func connectWS(model string) *websocket.Conn {
+ u := url.URL{
+ Scheme: "ws",
+ Host: fmt.Sprintf("127.0.0.1:%d", apiPort),
+ Path: "/v1/realtime",
+ RawQuery: "model=" + url.QueryEscape(model),
+ }
+ conn, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
+ ExpectWithOffset(1, err).ToNot(HaveOccurred(), "websocket dial failed")
+ if resp != nil && resp.Body != nil {
+ resp.Body.Close()
+ }
+ return conn
+}
+
+func readServerEvent(conn *websocket.Conn, timeout time.Duration) map[string]any {
+ conn.SetReadDeadline(time.Now().Add(timeout))
+ _, msg, err := conn.ReadMessage()
+ ExpectWithOffset(1, err).ToNot(HaveOccurred(), "read server event")
+ var evt map[string]any
+ ExpectWithOffset(1, json.Unmarshal(msg, &evt)).To(Succeed())
+ return evt
+}
+
+func sendClientEvent(conn *websocket.Conn, event any) {
+ data, err := json.Marshal(event)
+ ExpectWithOffset(1, err).ToNot(HaveOccurred())
+ ExpectWithOffset(1, conn.WriteMessage(websocket.TextMessage, data)).To(Succeed())
+}
+
+// drainUntil reads events until it finds one with the given type, or times out.
+func drainUntil(conn *websocket.Conn, eventType string, timeout time.Duration) map[string]any {
+ deadline := time.Now().Add(timeout)
+ for time.Now().Before(deadline) {
+ evt := readServerEvent(conn, time.Until(deadline))
+ if evt["type"] == eventType {
+ return evt
+ }
+ }
+ Fail("timed out waiting for event: " + eventType)
+ return nil
+}
+
+// generatePCMBase64 creates base64-encoded 16-bit LE PCM of a sine wave.
+func generatePCMBase64(freq float64, sampleRate, durationMs int) string {
+ numSamples := sampleRate * durationMs / 1000
+ pcm := make([]byte, numSamples*2)
+ for i := range numSamples {
+ t := float64(i) / float64(sampleRate)
+ sample := int16(math.MaxInt16 / 2 * math.Sin(2*math.Pi*freq*t))
+ pcm[2*i] = byte(sample)
+ pcm[2*i+1] = byte(sample >> 8)
+ }
+ return base64.StdEncoding.EncodeToString(pcm)
+}
+
+// pipelineModel returns the model name to use for realtime tests.
+func pipelineModel() string {
+ if m := os.Getenv("REALTIME_TEST_MODEL"); m != "" {
+ return m
+ }
+ return "realtime-pipeline"
+}
+
+// disableVADEvent returns a session.update event that disables server VAD.
+func disableVADEvent() map[string]any {
+ return map[string]any{
+ "type": "session.update",
+ "session": map[string]any{
+ "audio": map[string]any{
+ "input": map[string]any{
+ "turn_detection": nil,
+ },
+ },
+ },
+ }
+}
+
+// --- Tests ---
+
+var _ = Describe("Realtime WebSocket API", Label("Realtime"), func() {
+ Context("Session management", func() {
+ It("should return session.created on connect", func() {
+ conn := connectWS(pipelineModel())
+ defer conn.Close()
+
+ evt := readServerEvent(conn, 30*time.Second)
+ Expect(evt["type"]).To(Equal("session.created"))
+
+ session, ok := evt["session"].(map[string]any)
+ Expect(ok).To(BeTrue(), "session field should be an object")
+ Expect(session["id"]).ToNot(BeEmpty())
+ })
+
+ It("should return session.updated after session.update", func() {
+ conn := connectWS(pipelineModel())
+ defer conn.Close()
+
+ // Read session.created
+ created := readServerEvent(conn, 30*time.Second)
+ Expect(created["type"]).To(Equal("session.created"))
+
+ // Send session.update to disable VAD
+ sendClientEvent(conn, disableVADEvent())
+
+ evt := drainUntil(conn, "session.updated", 10*time.Second)
+ Expect(evt["type"]).To(Equal("session.updated"))
+ })
+ })
+
+ Context("Manual audio commit", func() {
+ It("should produce a response with audio when audio is committed", func() {
+ conn := connectWS(pipelineModel())
+ defer conn.Close()
+
+ // Read session.created
+ created := readServerEvent(conn, 30*time.Second)
+ Expect(created["type"]).To(Equal("session.created"))
+
+ // Disable server VAD so we can manually commit
+ sendClientEvent(conn, disableVADEvent())
+ drainUntil(conn, "session.updated", 10*time.Second)
+
+ // Append 1 second of 440Hz sine wave at 24kHz (the default remote sample rate)
+ audio := generatePCMBase64(440, 24000, 1000)
+ sendClientEvent(conn, map[string]any{
+ "type": "input_audio_buffer.append",
+ "audio": audio,
+ })
+
+ // Commit the audio buffer
+ sendClientEvent(conn, map[string]any{
+ "type": "input_audio_buffer.commit",
+ })
+
+ // We should receive the response event sequence.
+ // The exact events depend on the pipeline, but we expect at least:
+ // - input_audio_buffer.committed
+ // - conversation.item.input_audio_transcription.completed
+ // - response.output_audio.delta (with base64 audio)
+ // - response.done
+
+ committed := drainUntil(conn, "input_audio_buffer.committed", 30*time.Second)
+ Expect(committed).ToNot(BeNil())
+
+ // Wait for the full response cycle to complete
+ done := drainUntil(conn, "response.done", 60*time.Second)
+ Expect(done).ToNot(BeNil())
+ })
+ })
+
+ Context("Text conversation item", func() {
+ It("should create a text item and trigger a response", func() {
+ conn := connectWS(pipelineModel())
+ defer conn.Close()
+
+ // Read session.created
+ created := readServerEvent(conn, 30*time.Second)
+ Expect(created["type"]).To(Equal("session.created"))
+
+ // Disable VAD
+ sendClientEvent(conn, disableVADEvent())
+ drainUntil(conn, "session.updated", 10*time.Second)
+
+ // Create a text conversation item
+ sendClientEvent(conn, map[string]any{
+ "type": "conversation.item.create",
+ "item": map[string]any{
+ "type": "message",
+ "role": "user",
+ "content": []map[string]any{
+ {
+ "type": "input_text",
+ "text": "Hello, how are you?",
+ },
+ },
+ },
+ })
+
+ // Wait for item to be added
+ added := drainUntil(conn, "conversation.item.added", 10*time.Second)
+ Expect(added).ToNot(BeNil())
+
+ // Trigger a response
+ sendClientEvent(conn, map[string]any{
+ "type": "response.create",
+ })
+
+ // Wait for response to complete
+ done := drainUntil(conn, "response.done", 60*time.Second)
+ Expect(done).ToNot(BeNil())
+ })
+ })
+
+ Context("Audio integrity", func() {
+ It("should return non-empty audio data in response.output_audio.delta", Label("real-models"), func() {
+ if os.Getenv("REALTIME_TEST_MODEL") == "" {
+ Skip("REALTIME_TEST_MODEL not set")
+ }
+
+ conn := connectWS(pipelineModel())
+ defer conn.Close()
+
+ created := readServerEvent(conn, 30*time.Second)
+ Expect(created["type"]).To(Equal("session.created"))
+
+ // Disable VAD
+ sendClientEvent(conn, disableVADEvent())
+ drainUntil(conn, "session.updated", 10*time.Second)
+
+ // Create a text item and trigger response
+ sendClientEvent(conn, map[string]any{
+ "type": "conversation.item.create",
+ "item": map[string]any{
+ "type": "message",
+ "role": "user",
+ "content": []map[string]any{
+ {
+ "type": "input_text",
+ "text": "Say hello",
+ },
+ },
+ },
+ })
+ drainUntil(conn, "conversation.item.added", 10*time.Second)
+
+ sendClientEvent(conn, map[string]any{
+ "type": "response.create",
+ })
+
+ // Collect audio deltas
+ var totalAudioBytes int
+ deadline := time.Now().Add(60 * time.Second)
+ for time.Now().Before(deadline) {
+ evt := readServerEvent(conn, time.Until(deadline))
+ if evt["type"] == "response.output_audio.delta" {
+ if delta, ok := evt["delta"].(string); ok {
+ decoded, err := base64.StdEncoding.DecodeString(delta)
+ Expect(err).ToNot(HaveOccurred())
+ totalAudioBytes += len(decoded)
+ }
+ }
+ if evt["type"] == "response.done" {
+ break
+ }
+ }
+
+ Expect(totalAudioBytes).To(BeNumerically(">", 0), "expected non-empty audio in response")
+ })
+ })
+})