diff --git a/.gitignore b/.gitignore index 5f373f3..28ff0e5 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ dist .DS_Store packages/sdk/tests/e2e-output .env +package-lock.json diff --git a/examples/react-vite/package.json b/examples/react-vite/package.json index 77f4f52..81c5646 100644 --- a/examples/react-vite/package.json +++ b/examples/react-vite/package.json @@ -9,6 +9,7 @@ "preview": "vite preview" }, "dependencies": { + "@aws/ivs-web-broadcast": "^1.0.0", "@decartai/sdk": "workspace:*", "react": "^19.0.0", "react-dom": "^19.0.0" diff --git a/examples/react-vite/src/App.tsx b/examples/react-vite/src/App.tsx index cc620a3..3083391 100644 --- a/examples/react-vite/src/App.tsx +++ b/examples/react-vite/src/App.tsx @@ -3,6 +3,7 @@ import { VideoStream } from "./components/VideoStream"; function App() { const [prompt, setPrompt] = useState("anime style, vibrant colors"); + const [transport, setTransport] = useState<"webrtc" | "ivs">("webrtc"); return (
@@ -20,7 +21,21 @@ function App() {
- +
+ +
+ + ); } diff --git a/examples/react-vite/src/components/VideoStream.tsx b/examples/react-vite/src/components/VideoStream.tsx index b9dbc75..7654d93 100644 --- a/examples/react-vite/src/components/VideoStream.tsx +++ b/examples/react-vite/src/components/VideoStream.tsx @@ -3,9 +3,10 @@ import { useEffect, useRef, useState } from "react"; interface VideoStreamProps { prompt: string; + transport?: "webrtc" | "ivs"; } -export function VideoStream({ prompt }: VideoStreamProps) { +export function VideoStream({ prompt, transport = "webrtc" }: VideoStreamProps) { const inputRef = useRef(null); const outputRef = useRef(null); const realtimeClientRef = useRef(null); @@ -33,8 +34,6 @@ export function VideoStream({ prompt }: VideoStreamProps) { inputRef.current.srcObject = stream; } - setStatus("connecting..."); - const apiKey = import.meta.env.VITE_DECART_API_KEY; if (!apiKey) { throw new Error("DECART_API_KEY is not set"); @@ -44,8 +43,11 @@ export function VideoStream({ prompt }: VideoStreamProps) { apiKey, }); + setStatus(`connecting via ${transport}...`); + const realtimeClient = await client.realtime.connect(stream, { model, + transport, onRemoteStream: (transformedStream: MediaStream) => { if (outputRef.current) { outputRef.current.srcObject = transformedStream; diff --git a/packages/sdk/index.html b/packages/sdk/index.html index dcc9b62..21d05c1 100644 --- a/packages/sdk/index.html +++ b/packages/sdk/index.html @@ -241,6 +241,13 @@

Configuration

+
+ + +
@@ -377,6 +384,7 @@

Console Logs

const elements = { apiKey: document.getElementById('api-key'), modelSelect: document.getElementById('model-select'), + transportSelect: document.getElementById('transport-select'), realtimeBaseUrl: document.getElementById('realtime-base-url'), initialPrompt: document.getElementById('initial-prompt'), startCamera: document.getElementById('start-camera'), @@ -443,6 +451,27 @@

Console Logs

elements.statusText.textContent = status.charAt(0).toUpperCase() + status.slice(1); } + // Transport selection handler — load IVS SDK from CDN when IVS is selected + let ivsScriptLoaded = false; + elements.transportSelect.addEventListener('change', (e) => { + const transport = e.target.value; + addLog(`Selected transport: ${transport}`, 'info'); + + if (transport === 'ivs' && !ivsScriptLoaded) { + addLog('Loading IVS Web Broadcast SDK from CDN...', 'info'); + const script = document.createElement('script'); + script.src = 'https://web-broadcast.live-video.net/1.14.0/amazon-ivs-web-broadcast.js'; + script.onload = () => { + ivsScriptLoaded = true; + addLog('IVS Web Broadcast SDK loaded', 'success'); + }; + script.onerror = () => { + addLog('Failed to load IVS SDK — IVS transport will not work', 'error'); + }; + document.head.appendChild(script); + } + }); + // Model selection handler elements.modelSelect.addEventListener('change', (e) => { const selectedModel = e.target.value; @@ -527,8 +556,12 @@

Console Logs

initialImage = await initialImageResponse.blob(); } + const selectedTransport = elements.transportSelect.value; + addLog(`Connecting with transport: ${selectedTransport}`, 'info'); + decartRealtime = await decartClient.realtime.connect(localStream, { model, + transport: selectedTransport, onRemoteStream: (stream) => { addLog('Received remote stream from Decart', 'success'); elements.remoteVideo.srcObject = stream; diff --git a/packages/sdk/package.json b/packages/sdk/package.json index 412679a..36c3171 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -59,5 +59,13 @@ "mitt": "^3.0.1", "p-retry": "^6.2.1", "zod": "^4.0.17" + }, + "peerDependencies": { + "@aws/ivs-web-broadcast": ">=1.14.0" + }, + "peerDependenciesMeta": { + "@aws/ivs-web-broadcast": { + "optional": true + } } } diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index caddfa2..f2107c7 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -24,6 +24,8 @@ export type { RealTimeClientConnectOptions, RealTimeClientInitialState, } from "./realtime/client"; +export type { CompositeLatencyEstimate } from "./realtime/composite-latency"; +export type { PixelLatencyMeasurement } from "./realtime/pixel-latency"; export type { ConnectionPhase, DiagnosticEvent, diff --git a/packages/sdk/src/realtime/client.ts b/packages/sdk/src/realtime/client.ts index f1869b0..91c25b7 100644 --- a/packages/sdk/src/realtime/client.ts +++ b/packages/sdk/src/realtime/client.ts @@ -4,8 +4,9 @@ import { modelStateSchema } from "../shared/types"; import { classifyWebrtcError, type DecartSDKError } from "../utils/errors"; import type { Logger } from "../utils/logger"; import { AudioStreamManager } from "./audio-stream-manager"; -import type { DiagnosticEvent } from "./diagnostics"; +import type { DiagnosticEmitter, DiagnosticEvent } from "./diagnostics"; import { createEventBuffer } from "./event-buffer"; +import { IVSManager } from "./ivs-manager"; import { realtimeMethods, type SetInput } from "./methods"; import { decodeSubscribeToken, @@ -15,8 +16,13 @@ import { type SubscribeOptions, } from "./subscribe-client"; import { type ITelemetryReporter, NullTelemetryReporter, TelemetryReporter } from "./telemetry-reporter"; +import type { RealtimeTransportManager } from "./transport-manager"; +import type { CompositeLatencyEstimate } from "./composite-latency"; +import type { PixelLatencyMeasurement } from "./pixel-latency"; +import { LatencyDiagnostics } from "./latency-diagnostics"; import type { ConnectionState, GenerationTickMessage, SessionIdMessage } from "./types"; import { WebRTCManager } from "./webrtc-manager"; +import { IVSStatsCollector } from "./ivs-stats-collector"; import { type WebRTCStats, WebRTCStatsCollector } from "./webrtc-stats"; async function blobToBase64(blob: Blob): Promise { @@ -93,6 +99,14 @@ const realTimeClientConnectOptionsSchema = z.object({ }), initialState: realTimeClientInitialStateSchema.optional(), customizeOffer: createAsyncFunctionSchema(z.function()).optional(), + transport: z.enum(["webrtc", "ivs"]).optional().default("webrtc"), + latencyTracking: z + .object({ + composite: z.boolean().optional(), + pixelMarker: z.boolean().optional(), + videoElement: z.custom().optional(), + }) + .optional(), }); export type RealTimeClientConnectOptions = Omit, "model"> & { model: ModelDefinition | CustomModelDefinition; @@ -104,6 +118,8 @@ export type Events = { generationTick: { seconds: number }; diagnostic: DiagnosticEvent; stats: WebRTCStats; + compositeLatency: CompositeLatencyEstimate; + pixelLatency: PixelLatencyMeasurement; }; export type RealTimeClient = { @@ -151,7 +167,8 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { inputStream = stream ?? new MediaStream(); } - let webrtcManager: WebRTCManager | undefined; + const transport = parsedOptions.data.transport; + let transportManager: RealtimeTransportManager | undefined; let telemetryReporter: ITelemetryReporter = new NullTelemetryReporter(); let handleConnectionStateChange: ((state: ConnectionState) => void) | null = null; @@ -171,32 +188,44 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { const { emitter: eventEmitter, emitOrBuffer, flush, stop } = createEventBuffer(); - webrtcManager = new WebRTCManager({ - webrtcUrl: `${url}?api_key=${encodeURIComponent(apiKey)}&model=${encodeURIComponent(options.model.name)}`, + const sharedCallbacks = { integration, logger, - onDiagnostic: (name, data) => { + onDiagnostic: ((name: DiagnosticEvent["name"], data: DiagnosticEvent["data"]) => { emitOrBuffer("diagnostic", { name, data } as Events["diagnostic"]); addTelemetryDiagnostic(name, data); - }, + }) as DiagnosticEmitter, onRemoteStream, - onConnectionStateChange: (state) => { + onConnectionStateChange: (state: ConnectionState) => { emitOrBuffer("connectionChange", state); handleConnectionStateChange?.(state); }, - onError: (error) => { - logger.error("WebRTC error", { error: error.message }); + onError: (error: Error) => { + logger.error(`${transport} error`, { error: error.message }); emitOrBuffer("error", classifyWebrtcError(error)); }, - customizeOffer: options.customizeOffer as ((offer: RTCSessionDescriptionInit) => Promise) | undefined, - vp8MinBitrate: 300, - vp8StartBitrate: 600, modelName: options.model.name, initialImage, initialPrompt, - }); + }; - const manager = webrtcManager; + if (transport === "ivs") { + const ivsUrlPath = options.model.urlPath.replace(/\/?$/, "-ivs"); + transportManager = new IVSManager({ + ivsUrl: `${baseUrl}${ivsUrlPath}?api_key=${encodeURIComponent(apiKey)}&model=${encodeURIComponent(options.model.name)}`, + ...sharedCallbacks, + }); + } else { + transportManager = new WebRTCManager({ + webrtcUrl: `${url}?api_key=${encodeURIComponent(apiKey)}&model=${encodeURIComponent(options.model.name)}`, + ...sharedCallbacks, + customizeOffer: options.customizeOffer as ((offer: RTCSessionDescriptionInit) => Promise) | undefined, + vp8MinBitrate: 300, + vp8StartBitrate: 600, + }); + } + + const manager = transportManager; let sessionId: string | null = null; let subscribeToken: string | null = null; @@ -225,7 +254,7 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { }; const sessionIdListener = (msg: SessionIdMessage) => { - subscribeToken = encodeSubscribeToken(msg.session_id, msg.server_ip, msg.server_port); + subscribeToken = encodeSubscribeToken(msg.session_id, msg.server_ip, msg.server_port, transport); sessionId = msg.session_id; // Start telemetry reporter now that we have a session ID @@ -239,6 +268,7 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { sessionId: msg.session_id, model: options.model.name, integration, + transport, logger, }); reporter.start(); @@ -262,68 +292,121 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { const methods = realtimeMethods(manager, imageToBase64); - let statsCollector: WebRTCStatsCollector | null = null; - let statsCollectorPeerConnection: RTCPeerConnection | null = null; - // Video stall detection state (Twilio pattern: fps < 0.5 = stalled) const STALL_FPS_THRESHOLD = 0.5; let videoStalled = false; let stallStartMs = 0; - const startStatsCollection = (): (() => void) => { - statsCollector?.stop(); - videoStalled = false; - stallStartMs = 0; - statsCollector = new WebRTCStatsCollector(); - const pc = manager.getPeerConnection(); - statsCollectorPeerConnection = pc; - if (pc) { - statsCollector.start(pc, (stats) => { - emitOrBuffer("stats", stats); - telemetryReporter.addStats(stats); - - // Stall detection: check if video fps dropped below threshold - const fps = stats.video?.framesPerSecond ?? 0; - if (!videoStalled && stats.video && fps < STALL_FPS_THRESHOLD) { - videoStalled = true; - stallStartMs = Date.now(); - emitOrBuffer("diagnostic", { name: "videoStall", data: { stalled: true, durationMs: 0 } }); - addTelemetryDiagnostic("videoStall", { stalled: true, durationMs: 0 }, stallStartMs); - } else if (videoStalled && fps >= STALL_FPS_THRESHOLD) { - const durationMs = Date.now() - stallStartMs; - videoStalled = false; - emitOrBuffer("diagnostic", { name: "videoStall", data: { stalled: false, durationMs } }); - addTelemetryDiagnostic("videoStall", { stalled: false, durationMs }); - } - }); + const handleStats = (stats: WebRTCStats): void => { + emitOrBuffer("stats", stats); + telemetryReporter.addStats(stats); + + // Stall detection: check if video fps dropped below threshold + const fps = stats.video?.framesPerSecond ?? 0; + if (!videoStalled && stats.video && fps < STALL_FPS_THRESHOLD) { + videoStalled = true; + stallStartMs = Date.now(); + emitOrBuffer("diagnostic", { name: "videoStall", data: { stalled: true, durationMs: 0 } }); + addTelemetryDiagnostic("videoStall", { stalled: true, durationMs: 0 }, stallStartMs); + } else if (videoStalled && fps >= STALL_FPS_THRESHOLD) { + const durationMs = Date.now() - stallStartMs; + videoStalled = false; + emitOrBuffer("diagnostic", { name: "videoStall", data: { stalled: false, durationMs } }); + addTelemetryDiagnostic("videoStall", { stalled: false, durationMs }); } - return () => { + }; + + let statsCollector: WebRTCStatsCollector | IVSStatsCollector | null = null; + let statsCollectorPeerConnection: RTCPeerConnection | null = null; + + if (transport === "webrtc" && manager instanceof WebRTCManager) { + const webrtcManager = manager; + + const startStatsCollection = (): (() => void) => { statsCollector?.stop(); - statsCollector = null; - statsCollectorPeerConnection = null; + videoStalled = false; + stallStartMs = 0; + const collector = new WebRTCStatsCollector(); + statsCollector = collector; + const pc = webrtcManager.getPeerConnection(); + statsCollectorPeerConnection = pc; + if (pc) { + collector.start(pc, handleStats); + } + return () => { + collector.stop(); + statsCollector = null; + statsCollectorPeerConnection = null; + }; }; - }; - handleConnectionStateChange = (state) => { - if (!opts.telemetryEnabled) { - return; - } + handleConnectionStateChange = (state) => { + if (!opts.telemetryEnabled) { + return; + } - if (state !== "connected" && state !== "generating") { - return; - } + if (state !== "connected" && state !== "generating") { + return; + } - const peerConnection = manager.getPeerConnection(); - if (!peerConnection || peerConnection === statsCollectorPeerConnection) { - return; + const peerConnection = webrtcManager.getPeerConnection(); + if (!peerConnection || peerConnection === statsCollectorPeerConnection) { + return; + } + + startStatsCollection(); + }; + + // Auto-start stats when telemetry is enabled + if (opts.telemetryEnabled) { + startStatsCollection(); } + } else if (transport === "ivs" && manager instanceof IVSManager) { + const ivsManager = manager; - startStatsCollection(); - }; + const startIVSStatsCollection = (): void => { + statsCollector?.stop(); + videoStalled = false; + stallStartMs = 0; + const collector = new IVSStatsCollector(); + statsCollector = collector; + collector.start(ivsManager, handleStats); + }; + + handleConnectionStateChange = (state) => { + if (!opts.telemetryEnabled) { + return; + } - // Auto-start stats when telemetry is enabled - if (opts.telemetryEnabled) { - startStatsCollection(); + if (state !== "connected" && state !== "generating") { + return; + } + + // Only start once — IVS doesn't have PC reconnection like WebRTC + if (!statsCollector?.isRunning()) { + startIVSStatsCollection(); + } + }; + + // Auto-start stats when telemetry is enabled + if (opts.telemetryEnabled) { + startIVSStatsCollection(); + } + } + + // Latency diagnostics (composite + pixel marker) + let latencyStartTimer: ReturnType | undefined; + let latencyDiag: LatencyDiagnostics | null = null; + if (parsedOptions.data.latencyTracking) { + latencyDiag = new LatencyDiagnostics({ + ...parsedOptions.data.latencyTracking, + sendMessage: (msg) => manager.sendMessage(msg), + onCompositeLatency: (est) => emitOrBuffer("compositeLatency", est), + onPixelLatency: (m) => emitOrBuffer("pixelLatency", m), + }); + manager.getWebsocketMessageEmitter().on("latencyReport", (msg) => latencyDiag!.onServerReport(msg)); + eventEmitter.on("stats", (stats) => latencyDiag!.onStats(stats)); + latencyStartTimer = setTimeout(() => latencyDiag?.start(), 1000); } const client: RealTimeClient = { @@ -332,6 +415,8 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { isConnected: () => manager.isConnected(), getConnectionState: () => manager.getConnectionState(), disconnect: () => { + clearTimeout(latencyStartTimer); + latencyDiag?.stop(); statsCollector?.stop(); telemetryReporter.stop(); stop(); @@ -368,14 +453,18 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { return client; } catch (error) { telemetryReporter.stop(); - webrtcManager?.cleanup(); + transportManager?.cleanup(); audioStreamManager?.cleanup(); throw error; } }; - const subscribe = async (options: SubscribeOptions): Promise => { - const { sid, ip, port } = decodeSubscribeToken(options.token); + const subscribeWebRTC = async ( + options: SubscribeOptions, + sid: string, + ip: string, + port: number, + ): Promise => { const subscribeUrl = `${baseUrl}/subscribe/${encodeURIComponent(sid)}?IP=${encodeURIComponent(ip)}&port=${encodeURIComponent(port)}&api_key=${encodeURIComponent(apiKey)}`; const { emitter: eventEmitter, emitOrBuffer, flush, stop } = createEventBuffer(); @@ -422,6 +511,105 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { } }; + const subscribeIVS = async (options: SubscribeOptions, sid: string): Promise => { + const { getIVSBroadcastClient } = await import("./ivs-connection"); + const ivs = await getIVSBroadcastClient(); + + const { emitter: eventEmitter, emitOrBuffer, flush, stop } = createEventBuffer(); + + // Fetch viewer token from bouncer (convert wss:// → https:// for HTTP call) + const httpBaseUrl = baseUrl.replace(/^wss:\/\//, "https://").replace(/^ws:\/\//, "http://"); + const resp = await fetch(`${httpBaseUrl}/v1/subscribe-ivs/${encodeURIComponent(sid)}`, { + headers: { "x-api-key": apiKey }, + }); + if (!resp.ok) { + throw new Error(`Failed to get IVS viewer token: ${resp.status}`); + } + const { subscribe_token, server_publish_participant_id } = (await resp.json()) as { + subscribe_token: string; + server_publish_participant_id: string; + }; + + let connected = false; + let connectionState: ConnectionState = "connecting"; + emitOrBuffer("connectionChange", connectionState); + + // Create subscribe-only IVS stage — filter to server's output stream only + const subscribeStrategy = { + stageStreamsToPublish: () => [] as never[], + shouldPublishParticipant: () => false, + shouldSubscribeToParticipant: (participant: { id: string }) => { + if (server_publish_participant_id && participant.id !== server_publish_participant_id) { + return ivs.SubscribeType.NONE; + } + return ivs.SubscribeType.AUDIO_VIDEO; + }, + }; + + const stage = new ivs.Stage(subscribe_token, subscribeStrategy); + + await new Promise((resolve, reject) => { + const timer = setTimeout(() => reject(new Error("IVS viewer subscribe timeout")), 30_000); + + stage.on(ivs.StageEvents.STAGE_PARTICIPANT_STREAMS_ADDED, (...args: unknown[]) => { + const participant = args[0] as { isLocal: boolean }; + const streams = args[1] as { mediaStreamTrack: MediaStreamTrack }[]; + if (participant.isLocal) return; + + clearTimeout(timer); + const remoteStream = new MediaStream(); + for (const s of streams) { + remoteStream.addTrack(s.mediaStreamTrack); + } + options.onRemoteStream(remoteStream); + connected = true; + connectionState = "connected"; + emitOrBuffer("connectionChange", connectionState); + resolve(); + }); + + stage.on(ivs.StageEvents.STAGE_CONNECTION_STATE_CHANGED, (...args: unknown[]) => { + const state = args[0] as string; + if (state === ivs.ConnectionState.DISCONNECTED.toString()) { + clearTimeout(timer); + connected = false; + connectionState = "disconnected"; + emitOrBuffer("connectionChange", connectionState); + } + }); + + stage.join().catch((err) => { + clearTimeout(timer); + reject(err); + }); + }); + + const client: RealTimeSubscribeClient = { + isConnected: () => connected, + getConnectionState: () => connectionState, + disconnect: () => { + stop(); + stage.leave(); + connected = false; + connectionState = "disconnected"; + }, + on: eventEmitter.on, + off: eventEmitter.off, + }; + + flush(); + return client; + }; + + const subscribe = async (options: SubscribeOptions): Promise => { + const { sid, ip, port, transport } = decodeSubscribeToken(options.token); + + if (transport === "ivs") { + return subscribeIVS(options, sid); + } + return subscribeWebRTC(options, sid, ip, port); + }; + return { connect, subscribe, diff --git a/packages/sdk/src/realtime/composite-latency.ts b/packages/sdk/src/realtime/composite-latency.ts new file mode 100644 index 0000000..213bda6 --- /dev/null +++ b/packages/sdk/src/realtime/composite-latency.ts @@ -0,0 +1,43 @@ +import type { LatencyReportMessage } from "./types"; + +export type CompositeLatencyEstimate = { + clientProxyRttMs: number; + serverProxyRttMs: number; + pipelineLatencyMs: number; + compositeE2eMs: number; +}; + +export class CompositeLatencyTracker { + private latestServerReport: { + serverProxyRttMs: number; + pipelineLatencyMs: number; + } | null = null; + + onServerReport(msg: LatencyReportMessage): void { + this.latestServerReport = { + serverProxyRttMs: msg.server_proxy_rtt_ms, + pipelineLatencyMs: msg.pipeline_latency_ms, + }; + } + + /** + * Compute composite E2E estimate. + * @param clientRttSeconds - client RTT in seconds from WebRTC stats, or null if unavailable (IVS) + */ + getEstimate(clientRttSeconds: number | null): CompositeLatencyEstimate | null { + if (!this.latestServerReport) return null; + + const { serverProxyRttMs, pipelineLatencyMs } = this.latestServerReport; + // Client RTT may be unavailable for IVS transport (no candidate-pair stats). + // In that case, report lower-bound estimate with clientProxyRttMs = 0. + const clientProxyRttMs = clientRttSeconds != null ? clientRttSeconds * 1000 : 0; + const compositeE2eMs = clientProxyRttMs + serverProxyRttMs + pipelineLatencyMs; + + return { + clientProxyRttMs, + serverProxyRttMs, + pipelineLatencyMs, + compositeE2eMs, + }; + } +} diff --git a/packages/sdk/src/realtime/diagnostics.ts b/packages/sdk/src/realtime/diagnostics.ts index 69059d9..60f9a23 100644 --- a/packages/sdk/src/realtime/diagnostics.ts +++ b/packages/sdk/src/realtime/diagnostics.ts @@ -1,5 +1,5 @@ /** Connection phase names for timing events. */ -export type ConnectionPhase = "websocket" | "avatar-image" | "initial-prompt" | "webrtc-handshake" | "total"; +export type ConnectionPhase = "websocket" | "avatar-image" | "initial-prompt" | "webrtc-handshake" | "ivs-stage-setup" | "total"; export type PhaseTimingEvent = { phase: ConnectionPhase; diff --git a/packages/sdk/src/realtime/ivs-connection.ts b/packages/sdk/src/realtime/ivs-connection.ts new file mode 100644 index 0000000..7cdb7d6 --- /dev/null +++ b/packages/sdk/src/realtime/ivs-connection.ts @@ -0,0 +1,540 @@ +import mitt from "mitt"; + +import type { Logger } from "../utils/logger"; +import { buildUserAgent } from "../utils/user-agent"; +import type { DiagnosticEmitter } from "./diagnostics"; +import type { + ConnectionState, + IncomingIVSMessage, + OutgoingIVSMessage, + PromptAckMessage, + SetImageAckMessage, + WsMessageEvents, +} from "./types"; + +// ── IVS SDK type declarations ───────────────────────────────────────── +// Minimal type surface for @aws/ivs-web-broadcast so the SDK compiles +// even when the package is not installed. + +interface IVSStageStrategy { + stageStreamsToPublish(): IVSLocalStageStream[]; + shouldPublishParticipant(participant: IVSStageParticipant): boolean; + shouldSubscribeToParticipant(participant: IVSStageParticipant): IVSSubscribeType; +} + +interface IVSStage { + join(): Promise; + leave(): void; + on(event: string, handler: (...args: unknown[]) => void): void; +} + +interface IVSStageParticipant { + id: string; + isLocal: boolean; +} + +export interface IVSStageStream { + mediaStreamTrack: MediaStreamTrack; + requestRTCStats?(): Promise; +} + +export interface IVSLocalStageStream { + requestRTCStats?(): Promise; +} + +declare enum IVSSubscribeType { + NONE = "NONE", + AUDIO_VIDEO = "AUDIO_VIDEO", +} + +declare enum IVSStreamType { + VIDEO = "VIDEO", + AUDIO = "AUDIO", +} + +declare enum IVSStageEvents { + STAGE_CONNECTION_STATE_CHANGED = "STAGE_CONNECTION_STATE_CHANGED", + STAGE_PARTICIPANT_STREAMS_ADDED = "STAGE_PARTICIPANT_STREAMS_ADDED", +} + +declare enum IVSConnectionState { + CONNECTED = "CONNECTED", + DISCONNECTED = "DISCONNECTED", +} + +export interface IVSBroadcastModule { + Stage: new (token: string, strategy: IVSStageStrategy) => IVSStage; + LocalStageStream: new (track: MediaStreamTrack) => IVSLocalStageStream; + SubscribeType: typeof IVSSubscribeType; + StreamType: typeof IVSStreamType; + StageEvents: typeof IVSStageEvents; + ConnectionState: typeof IVSConnectionState; +} + +// ── Dynamic loader ──────────────────────────────────────────────────── + +export async function getIVSBroadcastClient(): Promise { + try { + const moduleName = "@aws/ivs-web-broadcast"; + // biome-ignore lint/suspicious/noExplicitAny: dynamic import of optional dependency + const mod = await (Function(`return import("${moduleName}")`)() as Promise); + return mod.default ?? mod; + } catch { + if (typeof globalThis !== "undefined" && "IVSBroadcastClient" in globalThis) { + // biome-ignore lint/suspicious/noExplicitAny: global fallback + return (globalThis as any).IVSBroadcastClient as IVSBroadcastModule; + } + throw new Error("@aws/ivs-web-broadcast not found. Install via npm or load via script tag."); + } +} + +// ── Types ───────────────────────────────────────────────────────────── + +const SETUP_TIMEOUT_MS = 30_000; + +interface IVSConnectionCallbacks { + onRemoteStream?: (stream: MediaStream) => void; + onStateChange?: (state: ConnectionState) => void; + onError?: (error: Error) => void; + modelName?: string; + initialImage?: string; + initialPrompt?: { text: string; enhance?: boolean }; + logger?: Logger; + onDiagnostic?: DiagnosticEmitter; +} + +const noopDiagnostic: DiagnosticEmitter = () => {}; + +// ── Connection ──────────────────────────────────────────────────────── + +export class IVSConnection { + private ws: WebSocket | null = null; + private publishStage: IVSStage | null = null; + private subscribeStage: IVSStage | null = null; + private connectionReject: ((error: Error) => void) | null = null; + private remoteStageStreams: IVSStageStream[] = []; + private localStageStreams: IVSLocalStageStream[] = []; + private logger: Logger; + private emitDiagnostic: DiagnosticEmitter; + state: ConnectionState = "disconnected"; + websocketMessagesEmitter = mitt(); + + constructor(private callbacks: IVSConnectionCallbacks = {}) { + this.logger = callbacks.logger ?? { debug() {}, info() {}, warn() {}, error() {} }; + this.emitDiagnostic = callbacks.onDiagnostic ?? noopDiagnostic; + } + + async connect(url: string, localStream: MediaStream | null, timeout: number, integration?: string): Promise { + // Phase 1: WebSocket + const userAgent = encodeURIComponent(buildUserAgent(integration)); + const separator = url.includes("?") ? "&" : "?"; + const wsUrl = `${url}${separator}user_agent=${userAgent}`; + + let rejectConnect!: (error: Error) => void; + const connectAbort = new Promise((_, reject) => { + rejectConnect = reject; + }); + connectAbort.catch(() => {}); + this.connectionReject = (error) => rejectConnect(error); + + const totalStart = performance.now(); + try { + const wsStart = performance.now(); + await Promise.race([ + new Promise((resolve, reject) => { + const timer = setTimeout(() => reject(new Error("WebSocket timeout")), timeout); + this.ws = new WebSocket(wsUrl); + + this.ws.onopen = () => { + clearTimeout(timer); + this.emitDiagnostic("phaseTiming", { + phase: "websocket", + durationMs: performance.now() - wsStart, + success: true, + }); + resolve(); + }; + this.ws.onmessage = (e) => { + try { + this.handleMessage(JSON.parse(e.data)); + } catch (err) { + this.logger.error("Message parse error", { error: String(err) }); + } + }; + this.ws.onerror = () => { + clearTimeout(timer); + const error = new Error("WebSocket error"); + this.emitDiagnostic("phaseTiming", { + phase: "websocket", + durationMs: performance.now() - wsStart, + success: false, + error: error.message, + }); + reject(error); + rejectConnect(error); + }; + this.ws.onclose = () => { + this.setState("disconnected"); + clearTimeout(timer); + reject(new Error("WebSocket closed before connection was established")); + rejectConnect(new Error("WebSocket closed")); + }; + }), + connectAbort, + ]); + + this.setState("connecting"); + + // Phase 2: IVS Stage setup — must complete before sending any messages. + // The bouncer creates the stage, sends ivs_stage_ready, then waits for + // ivs_joined before starting its message pump. Any set_image/prompt sent + // before ivs_joined would be consumed by the bouncer's join-wait loop + // and rejected as unexpected. + const stageStart = performance.now(); + await Promise.race([this.setupIVSStages(localStream, timeout), connectAbort]); + this.emitDiagnostic("phaseTiming", { + phase: "ivs-stage-setup", + durationMs: performance.now() - stageStart, + success: true, + }); + + // Phase 3: Post-handshake initial state (image/prompt) + // Now the bouncer's message pump is running and can handle these. + if (this.callbacks.initialImage) { + const imageStart = performance.now(); + await Promise.race([ + this.setImageBase64(this.callbacks.initialImage, { + prompt: this.callbacks.initialPrompt?.text, + enhance: this.callbacks.initialPrompt?.enhance, + }), + connectAbort, + ]); + this.emitDiagnostic("phaseTiming", { + phase: "avatar-image", + durationMs: performance.now() - imageStart, + success: true, + }); + } else if (this.callbacks.initialPrompt) { + const promptStart = performance.now(); + await Promise.race([this.sendInitialPrompt(this.callbacks.initialPrompt), connectAbort]); + this.emitDiagnostic("phaseTiming", { + phase: "initial-prompt", + durationMs: performance.now() - promptStart, + success: true, + }); + } else if (localStream) { + const nullStart = performance.now(); + await Promise.race([this.setImageBase64(null, { prompt: null }), connectAbort]); + this.emitDiagnostic("phaseTiming", { + phase: "initial-prompt", + durationMs: performance.now() - nullStart, + success: true, + }); + } + + this.emitDiagnostic("phaseTiming", { + phase: "total", + durationMs: performance.now() - totalStart, + success: true, + }); + } finally { + this.connectionReject = null; + } + } + + private async setupIVSStages(localStream: MediaStream | null, timeout: number): Promise { + const ivs = await getIVSBroadcastClient(); + + // Wait for bouncer to send ivs_stage_ready + const stageReady = await new Promise<{ + client_publish_token: string; + client_subscribe_token: string; + client_publish_participant_id: string; + }>((resolve, reject) => { + const timer = setTimeout(() => reject(new Error("IVS stage ready timeout")), timeout); + + const handler = (e: MessageEvent) => { + try { + const msg = JSON.parse(e.data); + if (msg.type === "ivs_stage_ready") { + clearTimeout(timer); + if (this.ws) { + this.ws.removeEventListener("message", handler); + } + resolve({ + client_publish_token: msg.client_publish_token, + client_subscribe_token: msg.client_subscribe_token, + client_publish_participant_id: msg.client_publish_participant_id ?? "", + }); + } + } catch { + // ignore parse errors, handled by main onmessage + } + }; + + this.ws?.addEventListener("message", handler); + }); + + // Subscribe stage — receive remote video/audio + const remoteStreamPromise = new Promise((resolve, reject) => { + const timer = setTimeout(() => reject(new Error("IVS subscribe stream timeout")), timeout); + + const clientPubId = stageReady.client_publish_participant_id; + const subscribeStrategy: IVSStageStrategy = { + stageStreamsToPublish: () => [], + shouldPublishParticipant: () => false, + shouldSubscribeToParticipant: (participant: IVSStageParticipant) => { + // Skip our own camera feed — only subscribe to server's processed output + if (clientPubId && participant.id === clientPubId) { + return ivs.SubscribeType.NONE; + } + return ivs.SubscribeType.AUDIO_VIDEO; + }, + }; + + this.subscribeStage = new ivs.Stage(stageReady.client_subscribe_token, subscribeStrategy); + + this.subscribeStage.on(ivs.StageEvents.STAGE_PARTICIPANT_STREAMS_ADDED, (...args: unknown[]) => { + const participant = args[0] as IVSStageParticipant; + const streams = args[1] as IVSStageStream[]; + if (participant.isLocal) return; + if (clientPubId && participant.id === clientPubId) return; + + clearTimeout(timer); + this.remoteStageStreams = streams; + const remoteStream = new MediaStream(); + for (const s of streams) { + remoteStream.addTrack(s.mediaStreamTrack); + } + this.callbacks.onRemoteStream?.(remoteStream); + resolve(); + }); + + this.subscribeStage.on(ivs.StageEvents.STAGE_CONNECTION_STATE_CHANGED, (...args: unknown[]) => { + const state = args[0] as string; + if (state === ivs.ConnectionState.DISCONNECTED.toString()) { + clearTimeout(timer); + reject(new Error("IVS subscribe stage disconnected during setup")); + this.setState("disconnected"); + } + }); + + this.subscribeStage.join().catch((err) => { + clearTimeout(timer); + reject(err); + }); + }); + + // Publish stage — send local camera + audio tracks + if (localStream) { + const localStageStreams: IVSLocalStageStream[] = []; + + const videoTrack = localStream.getVideoTracks()[0]; + if (videoTrack) { + localStageStreams.push(new ivs.LocalStageStream(videoTrack)); + } + const audioTrack = localStream.getAudioTracks()[0]; + if (audioTrack) { + localStageStreams.push(new ivs.LocalStageStream(audioTrack)); + } + this.localStageStreams = localStageStreams; + + const publishStrategy: IVSStageStrategy = { + stageStreamsToPublish: () => localStageStreams, + shouldPublishParticipant: () => true, + shouldSubscribeToParticipant: () => ivs.SubscribeType.NONE, + }; + + this.publishStage = new ivs.Stage(stageReady.client_publish_token, publishStrategy); + + this.publishStage.on(ivs.StageEvents.STAGE_CONNECTION_STATE_CHANGED, (...args: unknown[]) => { + const state = args[0] as string; + if (state === ivs.ConnectionState.CONNECTED.toString()) { + // Notify bouncer that we've joined the publish stage + this.send({ type: "ivs_joined" }); + this.setState("connected"); + } else if (state === ivs.ConnectionState.DISCONNECTED.toString()) { + this.setState("disconnected"); + } + }); + + await this.publishStage.join(); + } + + // Wait for remote stream from subscribe stage + await remoteStreamPromise; + } + + private handleMessage(msg: IncomingIVSMessage): void { + try { + if (msg.type === "error") { + const error = new Error(msg.error) as Error & { source?: string }; + error.source = "server"; + this.callbacks.onError?.(error); + if (this.connectionReject) { + this.connectionReject(error); + this.connectionReject = null; + } + return; + } + + if (msg.type === "set_image_ack") { + this.websocketMessagesEmitter.emit("setImageAck", msg); + return; + } + + if (msg.type === "prompt_ack") { + this.websocketMessagesEmitter.emit("promptAck", msg); + return; + } + + if (msg.type === "generation_started") { + this.setState("generating"); + return; + } + + if (msg.type === "generation_tick") { + this.websocketMessagesEmitter.emit("generationTick", msg); + return; + } + + if (msg.type === "generation_ended") { + return; + } + + if (msg.type === "session_id") { + this.websocketMessagesEmitter.emit("sessionId", msg); + return; + } + + if (msg.type === "latency_report") { + this.websocketMessagesEmitter.emit("latencyReport", msg); + return; + } + + // ivs_stage_ready is handled separately in setupIVSStages via addEventListener + } catch (error) { + this.logger.error("Message handler error", { error: String(error) }); + this.callbacks.onError?.(error as Error); + this.connectionReject?.(error as Error); + } + } + + send(message: OutgoingIVSMessage): boolean { + if (this.ws?.readyState === WebSocket.OPEN) { + this.ws.send(JSON.stringify(message)); + return true; + } + this.logger.warn("Message dropped: WebSocket is not open"); + return false; + } + + async setImageBase64( + imageBase64: string | null, + options?: { prompt?: string | null; enhance?: boolean; timeout?: number }, + ): Promise { + return new Promise((resolve, reject) => { + const timeoutId = setTimeout(() => { + this.websocketMessagesEmitter.off("setImageAck", listener); + reject(new Error("Image send timed out")); + }, options?.timeout ?? SETUP_TIMEOUT_MS); + + const listener = (msg: SetImageAckMessage) => { + clearTimeout(timeoutId); + this.websocketMessagesEmitter.off("setImageAck", listener); + if (msg.success) { + resolve(); + } else { + reject(new Error(msg.error ?? "Failed to send image")); + } + }; + + this.websocketMessagesEmitter.on("setImageAck", listener); + + const message: { + type: "set_image"; + image_data: string | null; + prompt?: string | null; + enhance_prompt?: boolean; + } = { + type: "set_image", + image_data: imageBase64, + }; + + if (options?.prompt !== undefined) { + message.prompt = options.prompt; + } + if (options?.enhance !== undefined) { + message.enhance_prompt = options.enhance; + } + + if (!this.send(message)) { + clearTimeout(timeoutId); + this.websocketMessagesEmitter.off("setImageAck", listener); + reject(new Error("WebSocket is not open")); + } + }); + } + + private async sendInitialPrompt(prompt: { text: string; enhance?: boolean }): Promise { + return new Promise((resolve, reject) => { + const timeoutId = setTimeout(() => { + this.websocketMessagesEmitter.off("promptAck", listener); + reject(new Error("Prompt send timed out")); + }, SETUP_TIMEOUT_MS); + + const listener = (msg: PromptAckMessage) => { + if (msg.prompt === prompt.text) { + clearTimeout(timeoutId); + this.websocketMessagesEmitter.off("promptAck", listener); + if (msg.success) { + resolve(); + } else { + reject(new Error(msg.error ?? "Failed to send prompt")); + } + } + }; + + this.websocketMessagesEmitter.on("promptAck", listener); + + if ( + !this.send({ + type: "prompt", + prompt: prompt.text, + enhance_prompt: prompt.enhance ?? true, + }) + ) { + clearTimeout(timeoutId); + this.websocketMessagesEmitter.off("promptAck", listener); + reject(new Error("WebSocket is not open")); + } + }); + } + + private setState(state: ConnectionState): void { + if (this.state !== state) { + this.state = state; + this.callbacks.onStateChange?.(state); + } + } + + getRemoteStreams(): IVSStageStream[] { + return this.remoteStageStreams; + } + + getLocalStreams(): IVSLocalStageStream[] { + return this.localStageStreams; + } + + cleanup(): void { + this.publishStage?.leave(); + this.publishStage = null; + this.subscribeStage?.leave(); + this.subscribeStage = null; + this.ws?.close(); + this.ws = null; + this.remoteStageStreams = []; + this.localStageStreams = []; + this.setState("disconnected"); + } +} diff --git a/packages/sdk/src/realtime/ivs-manager.ts b/packages/sdk/src/realtime/ivs-manager.ts new file mode 100644 index 0000000..fc325a5 --- /dev/null +++ b/packages/sdk/src/realtime/ivs-manager.ts @@ -0,0 +1,246 @@ +import pRetry, { AbortError } from "p-retry"; + +import type { Logger } from "../utils/logger"; +import type { DiagnosticEmitter } from "./diagnostics"; +import { IVSConnection } from "./ivs-connection"; +import type { RealtimeTransportManager } from "./transport-manager"; +import type { ConnectionState, OutgoingMessage } from "./types"; + +export interface IVSConfig { + ivsUrl: string; + integration?: string; + logger?: Logger; + onDiagnostic?: DiagnosticEmitter; + onRemoteStream: (stream: MediaStream) => void; + onConnectionStateChange?: (state: ConnectionState) => void; + onError?: (error: Error) => void; + modelName?: string; + initialImage?: string; + initialPrompt?: { text: string; enhance?: boolean }; +} + +const PERMANENT_ERRORS = [ + "permission denied", + "not allowed", + "invalid session", + "401", + "invalid api key", + "unauthorized", +]; + +const CONNECTION_TIMEOUT = 60_000 * 5; // 5 minutes + +const RETRY_OPTIONS = { + retries: 5, + factor: 2, + minTimeout: 1000, + maxTimeout: 10000, +} as const; + +export class IVSManager implements RealtimeTransportManager { + private connection: IVSConnection; + private config: IVSConfig; + private logger: Logger; + private localStream: MediaStream | null = null; + private managerState: ConnectionState = "disconnected"; + private hasConnected = false; + private isReconnecting = false; + private intentionalDisconnect = false; + private reconnectGeneration = 0; + + constructor(config: IVSConfig) { + this.config = config; + this.logger = config.logger ?? { debug() {}, info() {}, warn() {}, error() {} }; + this.connection = new IVSConnection({ + onRemoteStream: config.onRemoteStream, + onStateChange: (state) => this.handleConnectionStateChange(state), + onError: config.onError, + modelName: config.modelName, + initialImage: config.initialImage, + initialPrompt: config.initialPrompt, + logger: this.logger, + onDiagnostic: config.onDiagnostic, + }); + } + + private emitState(state: ConnectionState): void { + if (this.managerState !== state) { + this.managerState = state; + if (state === "connected" || state === "generating") this.hasConnected = true; + this.config.onConnectionStateChange?.(state); + } + } + + private handleConnectionStateChange(state: ConnectionState): void { + if (this.intentionalDisconnect) { + this.emitState("disconnected"); + return; + } + + if (this.isReconnecting) { + if (state === "connected" || state === "generating") { + this.isReconnecting = false; + this.emitState(state); + } + return; + } + + if (state === "disconnected" && !this.intentionalDisconnect && this.hasConnected) { + this.reconnect(); + return; + } + + this.emitState(state); + } + + private async reconnect(): Promise { + if (this.isReconnecting || this.intentionalDisconnect) return; + if (!this.localStream) return; + + const reconnectGeneration = ++this.reconnectGeneration; + this.isReconnecting = true; + this.emitState("reconnecting"); + const reconnectStart = performance.now(); + + try { + let attemptCount = 0; + + await pRetry( + async () => { + attemptCount++; + + if (this.intentionalDisconnect || reconnectGeneration !== this.reconnectGeneration) { + throw new AbortError("Reconnect cancelled"); + } + + if (!this.localStream) { + throw new AbortError("Reconnect cancelled: no local stream"); + } + + this.connection.cleanup(); + await this.connection.connect( + this.config.ivsUrl, + this.localStream, + CONNECTION_TIMEOUT, + this.config.integration, + ); + + if (this.intentionalDisconnect || reconnectGeneration !== this.reconnectGeneration) { + this.connection.cleanup(); + throw new AbortError("Reconnect cancelled"); + } + }, + { + ...RETRY_OPTIONS, + onFailedAttempt: (error) => { + if (this.intentionalDisconnect || reconnectGeneration !== this.reconnectGeneration) { + return; + } + this.logger.warn("IVS reconnect attempt failed", { error: error.message, attempt: error.attemptNumber }); + this.config.onDiagnostic?.("reconnect", { + attempt: error.attemptNumber, + maxAttempts: RETRY_OPTIONS.retries + 1, + durationMs: performance.now() - reconnectStart, + success: false, + error: error.message, + }); + this.connection.cleanup(); + }, + shouldRetry: (error) => { + if (this.intentionalDisconnect || reconnectGeneration !== this.reconnectGeneration) { + return false; + } + const msg = error.message.toLowerCase(); + return !PERMANENT_ERRORS.some((err) => msg.includes(err)); + }, + }, + ); + this.config.onDiagnostic?.("reconnect", { + attempt: attemptCount, + maxAttempts: RETRY_OPTIONS.retries + 1, + durationMs: performance.now() - reconnectStart, + success: true, + }); + } catch (error) { + this.isReconnecting = false; + if (this.intentionalDisconnect || reconnectGeneration !== this.reconnectGeneration) { + return; + } + this.emitState("disconnected"); + this.config.onError?.(error instanceof Error ? error : new Error(String(error))); + } + } + + async connect(localStream: MediaStream | null): Promise { + this.localStream = localStream; + this.intentionalDisconnect = false; + this.hasConnected = false; + this.isReconnecting = false; + this.reconnectGeneration += 1; + this.emitState("connecting"); + + return pRetry( + async () => { + if (this.intentionalDisconnect) { + throw new AbortError("Connect cancelled"); + } + await this.connection.connect(this.config.ivsUrl, localStream, CONNECTION_TIMEOUT, this.config.integration); + return true; + }, + { + ...RETRY_OPTIONS, + onFailedAttempt: (error) => { + this.logger.warn("IVS connection attempt failed", { error: error.message, attempt: error.attemptNumber }); + this.connection.cleanup(); + }, + shouldRetry: (error) => { + if (this.intentionalDisconnect) { + return false; + } + const msg = error.message.toLowerCase(); + return !PERMANENT_ERRORS.some((err) => msg.includes(err)); + }, + }, + ); + } + + sendMessage(message: OutgoingMessage): boolean { + return this.connection.send(message); + } + + cleanup(): void { + this.intentionalDisconnect = true; + this.isReconnecting = false; + this.reconnectGeneration += 1; + this.connection.cleanup(); + this.localStream = null; + this.emitState("disconnected"); + } + + isConnected(): boolean { + return this.managerState === "connected" || this.managerState === "generating"; + } + + getConnectionState(): ConnectionState { + return this.managerState; + } + + getWebsocketMessageEmitter() { + return this.connection.websocketMessagesEmitter; + } + + getRemoteStreams() { + return this.connection.getRemoteStreams(); + } + + getLocalStreams() { + return this.connection.getLocalStreams(); + } + + setImage( + imageBase64: string | null, + options?: { prompt?: string; enhance?: boolean; timeout?: number }, + ): Promise { + return this.connection.setImageBase64(imageBase64, options); + } +} diff --git a/packages/sdk/src/realtime/ivs-stats-collector.ts b/packages/sdk/src/realtime/ivs-stats-collector.ts new file mode 100644 index 0000000..2f218e1 --- /dev/null +++ b/packages/sdk/src/realtime/ivs-stats-collector.ts @@ -0,0 +1,93 @@ +import { type WebRTCStats, StatsParser, type StatsOptions } from "./webrtc-stats"; + +const DEFAULT_INTERVAL_MS = 1000; +const MIN_INTERVAL_MS = 500; + +// Minimal interface for IVS streams that support requestRTCStats +interface StatsCapableStream { + requestRTCStats?(): Promise; +} + +export interface IVSStatsSource { + getRemoteStreams(): StatsCapableStream[]; + getLocalStreams(): StatsCapableStream[]; +} + +export class IVSStatsCollector { + private parser = new StatsParser(); + private intervalId: ReturnType | null = null; + private source: IVSStatsSource | null = null; + private onStats: ((stats: WebRTCStats) => void) | null = null; + private intervalMs: number; + + constructor(options: StatsOptions = {}) { + this.intervalMs = Math.max(options.intervalMs ?? DEFAULT_INTERVAL_MS, MIN_INTERVAL_MS); + } + + start(source: IVSStatsSource, onStats: (stats: WebRTCStats) => void): void { + this.stop(); + this.source = source; + this.onStats = onStats; + this.parser.reset(); + this.intervalId = setInterval(() => this.collect(), this.intervalMs); + } + + stop(): void { + if (this.intervalId !== null) { + clearInterval(this.intervalId); + this.intervalId = null; + } + this.source = null; + this.onStats = null; + } + + isRunning(): boolean { + return this.intervalId !== null; + } + + private async collect(): Promise { + if (!this.source || !this.onStats) return; + + try { + // Get RTCStatsReport from remote streams (inbound video/audio) + const remoteStreams = this.source.getRemoteStreams(); + // Get from local streams (outbound video) if available + const localStreams = this.source.getLocalStreams(); + + // Collect all stats reports + const reports: RTCStatsReport[] = []; + + for (const stream of remoteStreams) { + if (stream.requestRTCStats) { + const report = await stream.requestRTCStats(); + if (report) reports.push(report); + } + } + for (const stream of localStreams) { + if (stream.requestRTCStats) { + const report = await stream.requestRTCStats(); + if (report) reports.push(report); + } + } + + if (reports.length === 0) return; + + // Merge all reports into a single Map-like structure that StatsParser can consume + // RTCStatsReport is a Map, so we can merge them + const merged = new Map(); + for (const report of reports) { + report.forEach((value, key) => { + merged.set(key, value); + }); + } + + // StatsParser.parse() expects RTCStatsReport which has a forEach method + // Our merged Map satisfies this interface + const stats = this.parser.parse(merged as unknown as RTCStatsReport); + this.onStats(stats); + } catch { + // Stream might be closed; stop silently + this.stop(); + } + } +} diff --git a/packages/sdk/src/realtime/latency-diagnostics.ts b/packages/sdk/src/realtime/latency-diagnostics.ts new file mode 100644 index 0000000..bdb17f8 --- /dev/null +++ b/packages/sdk/src/realtime/latency-diagnostics.ts @@ -0,0 +1,71 @@ +/** + * Consolidated latency diagnostics for RT sessions. + * + * Bundles CompositeLatencyTracker and PixelLatencyProbe into one + * pluggable object, keeping client.ts clean. + */ + +import type { LatencyReportMessage, OutgoingMessage } from "./types"; +import { CompositeLatencyTracker, type CompositeLatencyEstimate } from "./composite-latency"; +import { PixelLatencyProbe, type PixelLatencyMeasurement } from "./pixel-latency"; +import type { WebRTCStats } from "./webrtc-stats"; + +export type LatencyDiagnosticsOptions = { + composite?: boolean; + pixelMarker?: boolean; + videoElement?: HTMLVideoElement; + sendMessage: (msg: OutgoingMessage) => void; + onCompositeLatency: (estimate: CompositeLatencyEstimate) => void; + onPixelLatency: (measurement: PixelLatencyMeasurement) => void; +}; + +export class LatencyDiagnostics { + private compositeTracker: CompositeLatencyTracker | null = null; + private pixelProbe: PixelLatencyProbe | null = null; + private latestClientRtt: number | null = null; + private readonly videoElement: HTMLVideoElement | undefined; + private readonly onCompositeLatency: (estimate: CompositeLatencyEstimate) => void; + + constructor(options: LatencyDiagnosticsOptions) { + this.onCompositeLatency = options.onCompositeLatency; + this.videoElement = options.videoElement; + + if (options.composite) { + this.compositeTracker = new CompositeLatencyTracker(); + } + + if (options.pixelMarker && options.videoElement) { + this.pixelProbe = new PixelLatencyProbe( + options.sendMessage, + options.onPixelLatency, + ); + } + } + + /** Handle incoming latency_report from server. */ + onServerReport(msg: LatencyReportMessage): void { + if (!this.compositeTracker) return; + this.compositeTracker.onServerReport(msg); + const estimate = this.compositeTracker.getEstimate(this.latestClientRtt); + if (estimate) { + this.onCompositeLatency(estimate); + } + } + + /** Update client RTT from WebRTC stats. */ + onStats(stats: WebRTCStats): void { + this.latestClientRtt = stats.connection?.currentRoundTripTime ?? null; + } + + /** Start pixel probing (call after video is playing). */ + start(): void { + if (this.pixelProbe && this.videoElement) { + this.pixelProbe.start(this.videoElement); + } + } + + /** Tear down everything. */ + stop(): void { + this.pixelProbe?.stop(); + } +} diff --git a/packages/sdk/src/realtime/methods.ts b/packages/sdk/src/realtime/methods.ts index 6755d41..22d0867 100644 --- a/packages/sdk/src/realtime/methods.ts +++ b/packages/sdk/src/realtime/methods.ts @@ -1,6 +1,6 @@ import { z } from "zod"; +import type { RealtimeTransportManager } from "./transport-manager"; import type { PromptAckMessage } from "./types"; -import type { WebRTCManager } from "./webrtc-manager"; const PROMPT_TIMEOUT_MS = 15 * 1000; // 15 seconds const UPDATE_TIMEOUT_MS = 30 * 1000; @@ -23,11 +23,11 @@ const setPromptInputSchema = z.object({ export type SetInput = z.input; export const realtimeMethods = ( - webrtcManager: WebRTCManager, + manager: RealtimeTransportManager, imageToBase64: (image: Blob | File | string) => Promise, ) => { const assertConnected = () => { - const state = webrtcManager.getConnectionState(); + const state = manager.getConnectionState(); if (state !== "connected" && state !== "generating") { throw new Error(`Cannot send message: connection is ${state}`); } @@ -48,7 +48,7 @@ export const realtimeMethods = ( imageBase64 = await imageToBase64(image); } - await webrtcManager.setImage(imageBase64, { prompt, enhance, timeout: UPDATE_TIMEOUT_MS }); + await manager.setImage(imageBase64, { prompt, enhance, timeout: UPDATE_TIMEOUT_MS }); }; const setPrompt = async (prompt: string, { enhance }: { enhance?: boolean } = {}): Promise => { @@ -63,7 +63,7 @@ export const realtimeMethods = ( throw parsedInput.error; } - const emitter = webrtcManager.getWebsocketMessageEmitter(); + const emitter = manager.getWebsocketMessageEmitter(); let promptAckListener: ((msg: PromptAckMessage) => void) | undefined; let timeoutId: ReturnType | undefined; @@ -83,7 +83,7 @@ export const realtimeMethods = ( }); // Send the message first - const sent = webrtcManager.sendMessage({ + const sent = manager.sendMessage({ type: "prompt", prompt: parsedInput.data.prompt, enhance_prompt: parsedInput.data.enhance, diff --git a/packages/sdk/src/realtime/pixel-latency.ts b/packages/sdk/src/realtime/pixel-latency.ts new file mode 100644 index 0000000..2641c5e --- /dev/null +++ b/packages/sdk/src/realtime/pixel-latency.ts @@ -0,0 +1,164 @@ +import type { LatencyProbeMessage } from "./types"; + +export type PixelLatencyMeasurement = { + seq: number; + e2eLatencyMs: number; + timestamp: number; +}; + +export class PixelLatencyProbe { + private static readonly SYNC = [200, 50, 200, 50]; + private static readonly DATA_BITS = 16; + private static readonly CHECKSUM_BITS = 4; + private static readonly TOTAL_PIXELS = 24; + private static readonly PROBE_INTERVAL_MS = 2000; + private static readonly PROBE_TTL_MS = 10000; + + private seq = 0; + private pendingProbes = new Map(); // seq -> clientTime + private canvas: OffscreenCanvas; + private ctx: OffscreenCanvasRenderingContext2D; + private probeIntervalId: ReturnType | null = null; + private running = false; + + constructor( + private sendMessage: (msg: LatencyProbeMessage) => void, + private onMeasurement: (m: PixelLatencyMeasurement) => void, + ) { + this.canvas = new OffscreenCanvas(PixelLatencyProbe.TOTAL_PIXELS, 1); + const ctx = this.canvas.getContext("2d"); + if (!ctx) throw new Error("Failed to create OffscreenCanvas 2d context"); + this.ctx = ctx; + } + + start(videoElement: HTMLVideoElement): void { + if (this.running) return; + this.running = true; + + // Send probes every 2s + this.probeIntervalId = setInterval(() => this.sendProbe(), PixelLatencyProbe.PROBE_INTERVAL_MS); + + // Read frames + this.readFrameLoop(videoElement); + } + + stop(): void { + this.running = false; + if (this.probeIntervalId != null) { + clearInterval(this.probeIntervalId); + this.probeIntervalId = null; + } + this.pendingProbes.clear(); + } + + private sendProbe(): void { + const seq = ++this.seq; + const clientTime = performance.now(); + this.pendingProbes.set(seq, clientTime); + this.sendMessage({ type: "latency_probe", seq, client_time: clientTime }); + + // Clean up old probes + const now = performance.now(); + for (const [s, t] of this.pendingProbes) { + if (now - t > PixelLatencyProbe.PROBE_TTL_MS) { + this.pendingProbes.delete(s); + } + } + } + + private readFrameLoop(video: HTMLVideoElement): void { + if (!this.running) return; + + // Use requestVideoFrameCallback if available (Chrome/Edge), else requestAnimationFrame + if ("requestVideoFrameCallback" in video) { + (video as any).requestVideoFrameCallback((_now: number, _metadata: any) => { + this.readFrame(video); + this.readFrameLoop(video); + }); + } else { + requestAnimationFrame(() => { + this.readFrame(video); + this.readFrameLoop(video); + }); + } + } + + private readFrame(video: HTMLVideoElement): void { + if (video.videoWidth === 0 || video.videoHeight === 0) return; + + try { + // Draw only the bottom-left 24x1 region + this.ctx.drawImage( + video, + 0, + video.videoHeight - 1, // source x, y (bottom-left) + PixelLatencyProbe.TOTAL_PIXELS, + 1, // source width, height + 0, + 0, // dest x, y + PixelLatencyProbe.TOTAL_PIXELS, + 1, // dest width, height + ); + + const imageData = this.ctx.getImageData(0, 0, PixelLatencyProbe.TOTAL_PIXELS, 1); + const pixels = imageData.data; // RGBA, 4 bytes per pixel + + const seq = this.extractSeq(pixels); + if (seq === null) return; + + const clientTime = this.pendingProbes.get(seq); + if (clientTime == null) return; + + this.pendingProbes.delete(seq); + const e2eLatencyMs = performance.now() - clientTime; + + this.onMeasurement({ + seq, + e2eLatencyMs, + timestamp: Date.now(), + }); + } catch { + // Ignore read errors (cross-origin, etc.) + } + } + + private extractSeq(pixels: Uint8ClampedArray): number | null { + // Check sync pattern (R channel of RGBA, since canvas gives us RGB) + // The Y value from yuv420p gets decoded to approximately the same R value + // We use a wide threshold: >= 128 = high, < 128 = low + for (let i = 0; i < PixelLatencyProbe.SYNC.length; i++) { + const r = pixels[i * 4]; // R channel + const expected = PixelLatencyProbe.SYNC[i]; + const isHigh = r >= 128; + const shouldBeHigh = expected >= 128; + if (isHigh !== shouldBeHigh) return null; + } + + // Extract 16-bit seq + let seq = 0; + for (let i = 0; i < PixelLatencyProbe.DATA_BITS; i++) { + const r = pixels[(4 + i) * 4]; + if (r >= 128) { + seq |= 1 << (PixelLatencyProbe.DATA_BITS - 1 - i); + } + } + + // Verify 4-bit XOR checksum + let expectedChecksum = 0; + for (let i = 0; i < PixelLatencyProbe.DATA_BITS; i += 4) { + expectedChecksum ^= (seq >> i) & 0xf; + } + + let actualChecksum = 0; + for (let i = 0; i < PixelLatencyProbe.CHECKSUM_BITS; i++) { + const r = pixels[(20 + i) * 4]; + if (r >= 128) { + actualChecksum |= 1 << (PixelLatencyProbe.CHECKSUM_BITS - 1 - i); + } + } + + if (expectedChecksum !== actualChecksum) return null; + + return seq; + } +} diff --git a/packages/sdk/src/realtime/subscribe-client.ts b/packages/sdk/src/realtime/subscribe-client.ts index 6b1370f..c751018 100644 --- a/packages/sdk/src/realtime/subscribe-client.ts +++ b/packages/sdk/src/realtime/subscribe-client.ts @@ -6,10 +6,16 @@ type TokenPayload = { sid: string; ip: string; port: number; + transport?: "webrtc" | "ivs"; }; -export function encodeSubscribeToken(sessionId: string, serverIp: string, serverPort: number): string { - return btoa(JSON.stringify({ sid: sessionId, ip: serverIp, port: serverPort })); +export function encodeSubscribeToken( + sessionId: string, + serverIp: string, + serverPort: number, + transport?: "webrtc" | "ivs", +): string { + return btoa(JSON.stringify({ sid: sessionId, ip: serverIp, port: serverPort, transport })); } export function decodeSubscribeToken(token: string): TokenPayload { diff --git a/packages/sdk/src/realtime/telemetry-reporter.ts b/packages/sdk/src/realtime/telemetry-reporter.ts index a2a24bd..21c3b0b 100644 --- a/packages/sdk/src/realtime/telemetry-reporter.ts +++ b/packages/sdk/src/realtime/telemetry-reporter.ts @@ -34,6 +34,7 @@ export interface TelemetryReporterOptions { sessionId: string; model?: string; integration?: string; + transport?: "webrtc" | "ivs"; logger: Logger; reportIntervalMs?: number; } @@ -61,6 +62,7 @@ export class TelemetryReporter implements ITelemetryReporter { private sessionId: string; private model?: string; private integration?: string; + private transport?: "webrtc" | "ivs"; private logger: Logger; private reportIntervalMs: number; private intervalId: ReturnType | null = null; @@ -72,6 +74,7 @@ export class TelemetryReporter implements ITelemetryReporter { this.sessionId = options.sessionId; this.model = options.model; this.integration = options.integration; + this.transport = options.transport; this.logger = options.logger; this.reportIntervalMs = options.reportIntervalMs ?? DEFAULT_REPORT_INTERVAL_MS; } @@ -120,6 +123,7 @@ export class TelemetryReporter implements ITelemetryReporter { sdk_version: VERSION, ...(this.model ? { model: this.model } : {}), ...(this.integration ? { integration: this.integration } : {}), + ...(this.transport ? { transport: this.transport } : {}), }; return { diff --git a/packages/sdk/src/realtime/transport-manager.ts b/packages/sdk/src/realtime/transport-manager.ts new file mode 100644 index 0000000..1d6ce24 --- /dev/null +++ b/packages/sdk/src/realtime/transport-manager.ts @@ -0,0 +1,12 @@ +import type { Emitter } from "mitt"; +import type { ConnectionState, OutgoingMessage, WsMessageEvents } from "./types"; + +export interface RealtimeTransportManager { + connect(localStream: MediaStream | null): Promise; + sendMessage(message: OutgoingMessage): boolean; + setImage(imageBase64: string | null, options?: { prompt?: string; enhance?: boolean; timeout?: number }): Promise; + cleanup(): void; + isConnected(): boolean; + getConnectionState(): ConnectionState; + getWebsocketMessageEmitter(): Emitter; +} diff --git a/packages/sdk/src/realtime/types.ts b/packages/sdk/src/realtime/types.ts index e1618e8..3ff67a9 100644 --- a/packages/sdk/src/realtime/types.ts +++ b/packages/sdk/src/realtime/types.ts @@ -71,6 +71,18 @@ export type SessionIdMessage = { server_port: number; }; +export type LatencyReportMessage = { + type: "latency_report"; + server_proxy_rtt_ms: number; + pipeline_latency_ms: number; +}; + +export type LatencyProbeMessage = { + type: "latency_probe"; + seq: number; + client_time: number; +}; + export type ConnectionState = "connecting" | "connected" | "generating" | "disconnected" | "reconnecting"; // Incoming message types (from server) @@ -85,7 +97,8 @@ export type IncomingWebRTCMessage = | GenerationStartedMessage | GenerationTickMessage | GenerationEndedMessage - | SessionIdMessage; + | SessionIdMessage + | LatencyReportMessage; // Outgoing message types (to server) export type OutgoingWebRTCMessage = @@ -93,6 +106,43 @@ export type OutgoingWebRTCMessage = | AnswerMessage | IceCandidateMessage | PromptMessage - | SetAvatarImageMessage; + | SetAvatarImageMessage + | LatencyProbeMessage; + +export type OutgoingMessage = PromptMessage | SetAvatarImageMessage | LatencyProbeMessage; + +// IVS message types +export type IvsStageReadyMessage = { + type: "ivs_stage_ready"; + stage_arn: string; + client_publish_token: string; + client_subscribe_token: string; +}; + +export type IvsJoinedMessage = { + type: "ivs_joined"; +}; -export type OutgoingMessage = PromptMessage | SetAvatarImageMessage; +// IVS incoming messages (from bouncer) +export type IncomingIVSMessage = + | IvsStageReadyMessage + | PromptAckMessage + | ErrorMessage + | SetImageAckMessage + | GenerationStartedMessage + | GenerationTickMessage + | GenerationEndedMessage + | SessionIdMessage + | LatencyReportMessage; + +// IVS outgoing messages (to bouncer) +export type OutgoingIVSMessage = IvsJoinedMessage | PromptMessage | SetAvatarImageMessage | LatencyProbeMessage; + +// Shared WebSocket message events (used by both WebRTC and IVS transports) +export type WsMessageEvents = { + promptAck: PromptAckMessage; + setImageAck: SetImageAckMessage; + sessionId: SessionIdMessage; + generationTick: GenerationTickMessage; + latencyReport: LatencyReportMessage; +}; diff --git a/packages/sdk/src/realtime/webrtc-connection.ts b/packages/sdk/src/realtime/webrtc-connection.ts index dc5802b..facbda1 100644 --- a/packages/sdk/src/realtime/webrtc-connection.ts +++ b/packages/sdk/src/realtime/webrtc-connection.ts @@ -5,12 +5,11 @@ import { buildUserAgent } from "../utils/user-agent"; import type { DiagnosticEmitter, IceCandidateEvent } from "./diagnostics"; import type { ConnectionState, - GenerationTickMessage, IncomingWebRTCMessage, OutgoingWebRTCMessage, PromptAckMessage, - SessionIdMessage, SetImageAckMessage, + WsMessageEvents, } from "./types"; const ICE_SERVERS: RTCIceServer[] = [{ urls: "stun:stun.l.google.com:19302" }]; @@ -30,13 +29,6 @@ interface ConnectionCallbacks { onDiagnostic?: DiagnosticEmitter; } -type WsMessageEvents = { - promptAck: PromptAckMessage; - setImageAck: SetImageAckMessage; - sessionId: SessionIdMessage; - generationTick: GenerationTickMessage; -}; - const noopDiagnostic: DiagnosticEmitter = () => {}; export class WebRTCConnection { @@ -254,6 +246,11 @@ export class WebRTCConnection { return; } + if (msg.type === "latency_report") { + this.websocketMessagesEmitter.emit("latencyReport", msg); + return; + } + // All other messages require peer connection if (!this.pc) return; diff --git a/packages/sdk/src/realtime/webrtc-manager.ts b/packages/sdk/src/realtime/webrtc-manager.ts index 71408fb..c986979 100644 --- a/packages/sdk/src/realtime/webrtc-manager.ts +++ b/packages/sdk/src/realtime/webrtc-manager.ts @@ -2,6 +2,7 @@ import pRetry, { AbortError } from "p-retry"; import type { Logger } from "../utils/logger"; import type { DiagnosticEmitter } from "./diagnostics"; +import type { RealtimeTransportManager } from "./transport-manager"; import type { ConnectionState, OutgoingMessage } from "./types"; import { WebRTCConnection } from "./webrtc-connection"; @@ -39,7 +40,7 @@ const RETRY_OPTIONS = { maxTimeout: 10000, } as const; -export class WebRTCManager { +export class WebRTCManager implements RealtimeTransportManager { private connection: WebRTCConnection; private config: WebRTCConfig; private logger: Logger; diff --git a/packages/sdk/src/realtime/webrtc-stats.ts b/packages/sdk/src/realtime/webrtc-stats.ts index 42319a4..85b27f4 100644 --- a/packages/sdk/src/realtime/webrtc-stats.ts +++ b/packages/sdk/src/realtime/webrtc-stats.ts @@ -63,9 +63,7 @@ export type StatsOptions = { const DEFAULT_INTERVAL_MS = 1000; const MIN_INTERVAL_MS = 500; -export class WebRTCStatsCollector { - private pc: RTCPeerConnection | null = null; - private intervalId: ReturnType | null = null; +export class StatsParser { private prevBytesVideo = 0; private prevBytesAudio = 0; private prevBytesSentVideo = 0; @@ -76,18 +74,9 @@ export class WebRTCStatsCollector { private prevFreezeCount = 0; private prevFreezeDuration = 0; private prevPacketsLostAudio = 0; - private onStats: ((stats: WebRTCStats) => void) | null = null; - private intervalMs: number; - constructor(options: StatsOptions = {}) { - this.intervalMs = Math.max(options.intervalMs ?? DEFAULT_INTERVAL_MS, MIN_INTERVAL_MS); - } - - /** Attach to a peer connection and start polling. */ - start(pc: RTCPeerConnection, onStats: (stats: WebRTCStats) => void): void { - this.stop(); - this.pc = pc; - this.onStats = onStats; + /** Reset all delta-tracking state to zero. */ + reset(): void { this.prevBytesVideo = 0; this.prevBytesAudio = 0; this.prevBytesSentVideo = 0; @@ -97,37 +86,9 @@ export class WebRTCStatsCollector { this.prevFreezeCount = 0; this.prevFreezeDuration = 0; this.prevPacketsLostAudio = 0; - this.intervalId = setInterval(() => this.collect(), this.intervalMs); } - /** Stop polling and release resources. */ - stop(): void { - if (this.intervalId !== null) { - clearInterval(this.intervalId); - this.intervalId = null; - } - this.pc = null; - this.onStats = null; - } - - isRunning(): boolean { - return this.intervalId !== null; - } - - private async collect(): Promise { - if (!this.pc || !this.onStats) return; - - try { - const rawStats = await this.pc.getStats(); - const stats = this.parse(rawStats); - this.onStats(stats); - } catch { - // PC might be closed; stop silently - this.stop(); - } - } - - private parse(rawStats: RTCStatsReport): WebRTCStats { + parse(rawStats: RTCStatsReport): WebRTCStats { const now = performance.now(); const elapsed = this.prevTimestamp > 0 ? (now - this.prevTimestamp) / 1000 : 0; @@ -231,3 +192,51 @@ export class WebRTCStatsCollector { }; } } + +export class WebRTCStatsCollector { + private pc: RTCPeerConnection | null = null; + private intervalId: ReturnType | null = null; + private parser = new StatsParser(); + private onStats: ((stats: WebRTCStats) => void) | null = null; + private intervalMs: number; + + constructor(options: StatsOptions = {}) { + this.intervalMs = Math.max(options.intervalMs ?? DEFAULT_INTERVAL_MS, MIN_INTERVAL_MS); + } + + /** Attach to a peer connection and start polling. */ + start(pc: RTCPeerConnection, onStats: (stats: WebRTCStats) => void): void { + this.stop(); + this.pc = pc; + this.onStats = onStats; + this.parser.reset(); + this.intervalId = setInterval(() => this.collect(), this.intervalMs); + } + + /** Stop polling and release resources. */ + stop(): void { + if (this.intervalId !== null) { + clearInterval(this.intervalId); + this.intervalId = null; + } + this.pc = null; + this.onStats = null; + } + + isRunning(): boolean { + return this.intervalId !== null; + } + + private async collect(): Promise { + if (!this.pc || !this.onStats) return; + + try { + const rawStats = await this.pc.getStats(); + const stats = this.parser.parse(rawStats); + this.onStats(stats); + } catch { + // PC might be closed; stop silently + this.stop(); + } + } +}