From fb8d2540f48aafed32a3cd4cf3f2dbdace8587de Mon Sep 17 00:00:00 2001 From: Eilon Moalem Date: Sun, 31 May 2026 13:15:06 +0300 Subject: [PATCH] realtime warmup start --- packages/sdk/src/index.ts | 2 + packages/sdk/src/realtime/client.ts | 370 ++++++++++++++------ packages/sdk/src/realtime/media-channel.ts | 12 +- packages/sdk/src/realtime/stream-session.ts | 20 +- packages/sdk/tests/realtime.unit.test.ts | 121 ++++++- 5 files changed, 405 insertions(+), 120 deletions(-) diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index 81b7053..25e45e2 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -27,6 +27,7 @@ export type { RealTimeClient, RealTimeClientConnectOptions, RealTimeClientInitialState, + RealTimeWarmupClient, } from "./realtime/client"; export type { SetInput } from "./realtime/methods"; export type { @@ -238,6 +239,7 @@ export const createDecartClient = (options: DecartClientOptions = {}) => { return { realtime: { connect: realtimePublish.connect, + warmup: realtimePublish.warmup, subscribe: realtimeSubscribe.subscribe, }, diff --git a/packages/sdk/src/realtime/client.ts b/packages/sdk/src/realtime/client.ts index b35987d..9aac99e 100644 --- a/packages/sdk/src/realtime/client.ts +++ b/packages/sdk/src/realtime/client.ts @@ -70,6 +70,7 @@ export type Events = { diagnostic: DiagnosticEvent; stats: WebRTCStats; }; +type EventBuffer = ReturnType>; export type RealTimeClient = { set: (input: SetInput) => Promise; @@ -91,28 +92,35 @@ export type RealTimeClient = { setImage: (image: Blob | File | string | null, options?: ImageSetOptions) => Promise; }; +export type RealTimeWarmupClient = { + start: (stream: MediaStream) => Promise; + isConnected: () => boolean; + getConnectionState: () => ConnectionState; + disconnect: () => void; + on: (event: K, listener: (data: Events[K]) => void) => void; + off: (event: K, listener: (data: Events[K]) => void) => void; + sessionId: string | null; + subscribeToken: string | null; + getSubscribeToken: () => string | null; +}; + export const createRealTimeClient = (opts: RealTimeClientOptions) => { const { baseUrl, apiKey, integration } = opts; const logger = opts.logger ?? createConsoleLogger("info"); - const connect = async ( + const prepareInputStream = ( stream: MediaStream | null, - options: RealTimeClientConnectOptions, - ): Promise => { - const parsedOptions = realTimeClientConnectOptionsSchema.safeParse(options); - if (!parsedOptions.success) throw parsedOptions.error; - - const { onRemoteStream, onConnectionChange, onQueuePosition, initialState, resolution, preferredVideoCodec } = - parsedOptions.data; - const mirror = parsedOptions.data.mirror ?? false; + mirror: "auto" | boolean, + fps: number, + ): { inputStream: MediaStream; dispose: () => void } => { let inputStream: MediaStream = stream ?? new MediaStream(); - let mirroredStream: MirroredStream | undefined; + if (mirror !== false) { try { const firstVideoTrack = inputStream.getVideoTracks?.()[0]; if (firstVideoTrack && (mirror === true || shouldMirrorTrack(firstVideoTrack))) { - mirroredStream = createMirroredStream(inputStream, { fps: resolveFpsNumber(options.model.fps) }); + mirroredStream = createMirroredStream(inputStream, { fps }); inputStream = mirroredStream.stream; } else if (mirror === true && !firstVideoTrack) { logger.warn("mirror: true requested but no video track was found on the input stream"); @@ -124,125 +132,261 @@ export const createRealTimeClient = (opts: RealTimeClientOptions) => { } } - let session: StreamSession | undefined; - let observability: RealtimeObservability | undefined; + return { + inputStream, + dispose: () => mirroredStream?.dispose(), + }; + }; - try { - const initialImageRef = isFileRefId(initialState?.image) ? initialState.image : undefined; - const initialImage = - initialImageRef === undefined && initialState?.image ? await imageToBase64(initialState.image) : undefined; - const initialPrompt = initialState?.prompt - ? { text: initialState.prompt.text, enhance: initialState.prompt.enhance } - : undefined; - - const url = `${baseUrl}${options.model.urlPath}`; - const { emitter: eventEmitter, emitOrBuffer, flush, stop } = createEventBuffer(); - - observability = new RealtimeObservability({ - telemetryEnabled: opts.telemetryEnabled, - apiKey, - model: options.model.name, - integration, - logger, - onDiagnostic: (event) => emitOrBuffer("diagnostic", event), - onStats: (stats) => emitOrBuffer("stats", stats), - }); + type ParsedConnectOptions = z.infer; - const safariCodec = isDesktopSafari() ? "vp8" : undefined; - const publishCodec: VideoCodec | undefined = safariCodec ?? preferredVideoCodec; + const createClientHandle = ({ + activeSession, + eventEmitter, + stop, + observability, + getSessionId, + getSubscribeToken, + disposeInput, + }: { + activeSession: StreamSession; + eventEmitter: EventBuffer["emitter"]; + stop: () => void; + observability: RealtimeObservability; + getSessionId: () => string | null; + getSubscribeToken: () => string | null; + disposeInput: () => void; + }): RealTimeClient => { + const methods = realtimeMethods(activeSession, imageToBase64); - const queryParams = new URLSearchParams({ - ...(safariCodec ? { livekit_server_codec: safariCodec } : {}), - ...(options.queryParams ?? {}), - api_key: apiKey, - model: options.model.name, - ...(resolution ? { resolution } : {}), - }); + return { + ...methods, + isConnected: () => activeSession.isConnected(), + getConnectionState: () => activeSession.getConnectionState(), + disconnect: () => { + observability.stop(); + stop(); + activeSession.disconnect(); + disposeInput(); + }, + on: eventEmitter.on, + off: eventEmitter.off, + get sessionId() { + return getSessionId(); + }, + get subscribeToken() { + return getSubscribeToken(); + }, + getSubscribeToken, + setImage: async (image: Blob | File | string | null, imgOptions?: ImageSetOptions) => { + if (isFileRefId(image)) { + return activeSession.setImage({ kind: "ref", ref: image }, imgOptions); + } + if (image === null) return activeSession.setImage({ kind: "data", data: null }, imgOptions); + const base64 = await imageToBase64(image); + return activeSession.setImage({ kind: "data", data: base64 }, imgOptions); + }, + }; + }; - session = new StreamSession({ - url: `${url}?${queryParams.toString()}`, - integration, - observability, - localStream: inputStream, - initialImage, - initialImageRef, - initialPrompt, - logger, - videoCodec: publishCodec, - }); + const openSession = async ({ + localStream, + options, + parsedOptions, + livekitWarmup, + }: { + localStream: MediaStream | null; + options: RealTimeClientConnectOptions; + parsedOptions: ParsedConnectOptions; + livekitWarmup: boolean; + }) => { + const { onRemoteStream, onConnectionChange, onQueuePosition, initialState, resolution, preferredVideoCodec } = + parsedOptions; - let sessionId: string | null = null; - let subscribeToken: string | null = null; + const initialImageRef = isFileRefId(initialState?.image) ? initialState.image : undefined; + const initialImage = + initialImageRef === undefined && initialState?.image ? await imageToBase64(initialState.image) : undefined; + const initialPrompt = initialState?.prompt + ? { text: initialState.prompt.text, enhance: initialState.prompt.enhance } + : undefined; - session.on("remoteStream", onRemoteStream); + const url = `${baseUrl}${options.model.urlPath}`; + const { emitter: eventEmitter, emitOrBuffer, flush, stop } = createEventBuffer(); - session.on("connectionChange", (state) => { - emitOrBuffer("connectionChange", state); - onConnectionChange?.(state); - }); + const observability = new RealtimeObservability({ + telemetryEnabled: opts.telemetryEnabled, + apiKey, + model: options.model.name, + integration, + logger, + onDiagnostic: (event) => emitOrBuffer("diagnostic", event), + onStats: (stats) => emitOrBuffer("stats", stats), + }); - session.on("queuePosition", (qp) => { - emitOrBuffer("queuePosition", qp); - onQueuePosition?.(qp); - }); + const safariCodec = isDesktopSafari() ? "vp8" : undefined; + const publishCodec: VideoCodec | undefined = safariCodec ?? preferredVideoCodec; - session.on("sessionStarted", ({ sessionId: id, subscribeToken: token }) => { - sessionId = id; - subscribeToken = token; - observability?.sessionStarted(id); - }); + const queryParams = new URLSearchParams({ + ...(safariCodec ? { livekit_server_codec: safariCodec } : {}), + ...(options.queryParams ?? {}), + ...(livekitWarmup ? { livekit_warmup: "1" } : {}), + api_key: apiKey, + model: options.model.name, + ...(resolution ? { resolution } : {}), + }); + + const session = new StreamSession({ + url: `${url}?${queryParams.toString()}`, + integration, + observability, + localStream, + initialImage, + initialImageRef, + initialPrompt, + logger, + videoCodec: publishCodec, + waitForInitialStateAck: !livekitWarmup, + }); + + let sessionId: string | null = null; + let subscribeToken: string | null = null; + + session.on("remoteStream", onRemoteStream); + + session.on("connectionChange", (state) => { + emitOrBuffer("connectionChange", state); + onConnectionChange?.(state); + }); - session.on("generationTick", (e) => emitOrBuffer("generationTick", e)); - session.on("generationEnded", (e) => emitOrBuffer("generationEnded", e)); + session.on("queuePosition", (qp) => { + emitOrBuffer("queuePosition", qp); + onQueuePosition?.(qp); + }); - session.on("error", (error) => { - logger.error("Realtime error", { error: error.message }); - emitOrBuffer("error", classifyWebrtcError(error)); + session.on("sessionStarted", ({ sessionId: id, subscribeToken: token }) => { + sessionId = id; + subscribeToken = token; + observability.sessionStarted(id); + }); + + session.on("generationTick", (e) => emitOrBuffer("generationTick", e)); + session.on("generationEnded", (e) => emitOrBuffer("generationEnded", e)); + + session.on("error", (error) => { + logger.error("Realtime error", { error: error.message }); + emitOrBuffer("error", classifyWebrtcError(error)); + }); + + try { + await session.connect(); + } catch (error) { + observability.stop(); + session.disconnect(); + stop(); + throw error; + } + + return { + activeSession: session, + eventEmitter, + flush, + stop, + observability, + getSessionId: () => sessionId, + getSubscribeToken: () => subscribeToken, + }; + }; + + const connect = async ( + stream: MediaStream | null, + options: RealTimeClientConnectOptions, + ): Promise => { + const parsedOptions = realTimeClientConnectOptionsSchema.safeParse(options); + if (!parsedOptions.success) throw parsedOptions.error; + + const mirror = parsedOptions.data.mirror ?? false; + const prepared = prepareInputStream(stream, mirror, resolveFpsNumber(parsedOptions.data.model.fps)); + + try { + const sessionContext = await openSession({ + localStream: prepared.inputStream, + options, + parsedOptions: parsedOptions.data, + livekitWarmup: false, }); - const activeSession = session; - await activeSession.connect(); - - const methods = realtimeMethods(activeSession, imageToBase64); - - const client: RealTimeClient = { - ...methods, - isConnected: () => activeSession.isConnected(), - getConnectionState: () => activeSession.getConnectionState(), - disconnect: () => { - observability?.stop(); - stop(); - activeSession.disconnect(); - mirroredStream?.dispose(); - }, - on: eventEmitter.on, - off: eventEmitter.off, - get sessionId() { - return sessionId; - }, - get subscribeToken() { - return subscribeToken; - }, - getSubscribeToken: () => subscribeToken, - setImage: async (image: Blob | File | string | null, imgOptions?: ImageSetOptions) => { - if (isFileRefId(image)) { - return activeSession.setImage({ kind: "ref", ref: image }, imgOptions); - } - if (image === null) return activeSession.setImage({ kind: "data", data: null }, imgOptions); - const base64 = await imageToBase64(image); - return activeSession.setImage({ kind: "data", data: base64 }, imgOptions); - }, - }; - - flush(); + const client = createClientHandle({ + ...sessionContext, + disposeInput: prepared.dispose, + }); + sessionContext.flush(); return client; } catch (error) { - observability?.stop(); - session?.disconnect(); - mirroredStream?.dispose(); + prepared.dispose(); throw error; } }; - return { connect }; + const warmup = async (options: RealTimeClientConnectOptions): Promise => { + const parsedOptions = realTimeClientConnectOptionsSchema.safeParse(options); + if (!parsedOptions.success) throw parsedOptions.error; + + const sessionContext = await openSession({ + localStream: null, + options, + parsedOptions: parsedOptions.data, + livekitWarmup: true, + }); + + let started = false; + let disposeStartedInput: (() => void) | undefined; + const mirror = parsedOptions.data.mirror ?? false; + + const disconnect = () => { + sessionContext.observability.stop(); + sessionContext.stop(); + sessionContext.activeSession.disconnect(); + disposeStartedInput?.(); + }; + + const warmupClient: RealTimeWarmupClient = { + start: async (stream: MediaStream) => { + if (started) { + throw new Error("Realtime warmup has already been started"); + } + started = true; + const prepared = prepareInputStream(stream, mirror, resolveFpsNumber(parsedOptions.data.model.fps)); + disposeStartedInput = prepared.dispose; + try { + await sessionContext.activeSession.publishLocalStream(prepared.inputStream); + } catch (error) { + prepared.dispose(); + throw error; + } + return createClientHandle({ + ...sessionContext, + disposeInput: () => { + prepared.dispose(); + }, + }); + }, + isConnected: () => sessionContext.activeSession.isConnected(), + getConnectionState: () => sessionContext.activeSession.getConnectionState(), + disconnect, + on: sessionContext.eventEmitter.on, + off: sessionContext.eventEmitter.off, + get sessionId() { + return sessionContext.getSessionId(); + }, + get subscribeToken() { + return sessionContext.getSubscribeToken(); + }, + getSubscribeToken: sessionContext.getSubscribeToken, + }; + + sessionContext.flush(); + return warmupClient; + }; + + return { connect, warmup }; }; diff --git a/packages/sdk/src/realtime/media-channel.ts b/packages/sdk/src/realtime/media-channel.ts index 725943d..5630988 100644 --- a/packages/sdk/src/realtime/media-channel.ts +++ b/packages/sdk/src/realtime/media-channel.ts @@ -57,13 +57,19 @@ export class MediaChannel { private remoteStream: MediaStream | null = null; private events: Emitter = mitt(); private readonly logger: Logger; + private localStreamValue: MediaStream | null; constructor(private readonly config: MediaChannelConfig) { this.logger = config.logger ?? createConsoleLogger("warn"); + this.localStreamValue = config.localStream; } get localStream(): MediaStream | null { - return this.config.localStream; + return this.localStreamValue; + } + + setLocalStream(stream: MediaStream | null): void { + this.localStreamValue = stream; } on(event: E, handler: (data: MediaChannelEvents[E]) => void): void { @@ -108,9 +114,9 @@ export class MediaChannel { } async publishLocalTracks(): Promise { - if (!this.config.localStream) return; + if (!this.localStreamValue) return; this.config.observability?.startPhase("publish-local-track"); - await this.publishTracks(this.config.localStream); + await this.publishTracks(this.localStreamValue); this.config.observability?.endPhase("publish-local-track", { success: true }); } diff --git a/packages/sdk/src/realtime/stream-session.ts b/packages/sdk/src/realtime/stream-session.ts index 014baee..6661ff8 100644 --- a/packages/sdk/src/realtime/stream-session.ts +++ b/packages/sdk/src/realtime/stream-session.ts @@ -61,6 +61,7 @@ interface StreamSessionConfig { initialPrompt?: InitialPrompt; logger?: Logger; videoCodec?: VideoCodec; + waitForInitialStateAck?: boolean; } export class StreamSession { @@ -73,12 +74,14 @@ export class StreamSession { private disposed = false; private currentAttempt = 0; + private localStream: MediaStream | null; private readonly initialStateGate = new InitialStateGate(); private readonly logger: Logger; constructor(private readonly config: StreamSessionConfig) { this.logger = config.logger ?? createConsoleLogger("warn"); + this.localStream = config.localStream; this.createTransport(); } @@ -130,6 +133,16 @@ export class StreamSession { return this.signaling.setImage(payload, opts); } + async publishLocalStream(stream: MediaStream): Promise { + this.assertConnected(); + if (this.localStream) { + throw new Error("Local stream has already been published"); + } + this.localStream = stream; + this.media.setLocalStream(stream); + await this.media.publishLocalTracks(); + } + disconnect(): void { this.disposed = true; this.tearDown(); @@ -188,7 +201,8 @@ export class StreamSession { url: roomInfo.livekitUrl, token: roomInfo.token, }); - const isCurrentAttempt = await gateAttempt.waitForReadiness(initialStateAck); + const isCurrentAttempt = + this.config.waitForInitialStateAck === false ? true : await gateAttempt.waitForReadiness(initialStateAck); if (!isCurrentAttempt) { throw new AbortError("Stale connect attempt"); } @@ -243,7 +257,7 @@ export class StreamSession { }; } - if (this.config.localStream) { + if (this.localStream) { return { image: null, prompt: null }; } @@ -314,7 +328,7 @@ export class StreamSession { }); this.media = new MediaChannel({ observability: this.config.observability, - localStream: this.config.localStream, + localStream: this.localStream, logger: this.logger, videoCodec: this.config.videoCodec, }); diff --git a/packages/sdk/tests/realtime.unit.test.ts b/packages/sdk/tests/realtime.unit.test.ts index 270d1f4..61b2f60 100644 --- a/packages/sdk/tests/realtime.unit.test.ts +++ b/packages/sdk/tests/realtime.unit.test.ts @@ -1,5 +1,5 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; -import { models } from "../src/index.js"; +import { createDecartClient, models } from "../src/index.js"; import { REALTIME_CONFIG } from "../src/realtime/config-realtime.js"; import type { ServerError } from "../src/realtime/types.js"; @@ -94,6 +94,17 @@ type FakeWebSocketCloseEvent = { reason: string; }; +describe("top-level realtime client", () => { + it("exposes warmup through createDecartClient", () => { + const client = createDecartClient({ + apiKey: "test", + realtimeBaseUrl: "wss://realtime.example.com", + }); + + expect(typeof client.realtime.warmup).toBe("function"); + }); +}); + describe("Lucy 2.1 realtime", () => { describe("Model Definition", () => { it("has correct model name", () => { @@ -334,6 +345,7 @@ describe("realtime.connect options", () => { onopen: (() => void) | null = null; onmessage: ((event: FakeWebSocketMessageEvent) => void) | null = null; onclose: ((event: FakeWebSocketCloseEvent) => void) | null = null; + sentMessages: unknown[] = []; constructor(readonly url: string) { FakeWebSocket.instances.push(this); @@ -342,6 +354,7 @@ describe("realtime.connect options", () => { send(data: string): void { const message = JSON.parse(data); + this.sentMessages.push(message); if (message.type === "livekit_join") { setTimeout(() => { this.onmessage?.({ @@ -363,6 +376,7 @@ describe("realtime.connect options", () => { beforeEach(() => { FakeWebSocket.instances = []; + liveKitMock.roomInstances.length = 0; vi.stubGlobal("WebSocket", FakeWebSocket); vi.stubGlobal("MediaStream", FakeMediaStream); }); @@ -409,6 +423,111 @@ describe("realtime.connect options", () => { ).rejects.toThrow(); expect(FakeWebSocket.instances).toHaveLength(0); }); + + const createLocalStream = () => + new MediaStream([ + { id: "local-video", kind: "video" }, + { id: "local-audio", kind: "audio" }, + ] as unknown[]) as MediaStream; + + it("warmup adds livekit_warmup and connects LiveKit without publishing tracks", async () => { + const { createRealTimeClient } = await import("../src/realtime/client.js"); + const client = createRealTimeClient({ + baseUrl: "wss://api3.decart.ai", + apiKey: "test-key", + logger: { debug() {}, info() {}, warn() {}, error() {} }, + telemetryEnabled: false, + }); + + const warmupClient = await client.warmup({ + model: models.realtime("lucy-2.1"), + onRemoteStream: vi.fn(), + }); + + const url = new URL(FakeWebSocket.instances[0].url); + expect(url.searchParams.get("livekit_warmup")).toBe("1"); + const room = liveKitMock.roomInstances[0] as InstanceType; + expect(room.connect).toHaveBeenCalledWith("wss://livekit.example.test", "token"); + expect(room.localParticipant.publishTrack).not.toHaveBeenCalled(); + warmupClient.disconnect(); + }); + + it("warmup does not wait for initial prompt ack before becoming ready", async () => { + const { createRealTimeClient } = await import("../src/realtime/client.js"); + const client = createRealTimeClient({ + baseUrl: "wss://api3.decart.ai", + apiKey: "test-key", + logger: { debug() {}, info() {}, warn() {}, error() {} }, + telemetryEnabled: false, + }); + + const warmupClient = await Promise.race([ + client.warmup({ + model: models.realtime("lucy-2.1"), + onRemoteStream: vi.fn(), + initialState: { + prompt: { text: "test prompt" }, + }, + }), + new Promise((_, reject) => setTimeout(() => reject(new Error("warmup timed out")), 100)), + ]); + + expect(FakeWebSocket.instances[0].sentMessages).toContainEqual({ type: "livekit_join" }); + expect(FakeWebSocket.instances[0].sentMessages).toContainEqual({ + type: "prompt", + prompt: "test prompt", + enhance_prompt: true, + }); + warmupClient.disconnect(); + }); + + it("warmup start publishes local video and audio tracks once", async () => { + const { createRealTimeClient } = await import("../src/realtime/client.js"); + const client = createRealTimeClient({ + baseUrl: "wss://api3.decart.ai", + apiKey: "test-key", + logger: { debug() {}, info() {}, warn() {}, error() {} }, + telemetryEnabled: false, + }); + + const warmupClient = await client.warmup({ + model: models.realtime("lucy-2.1"), + preferredVideoCodec: "h264", + onRemoteStream: vi.fn(), + }); + const localStream = createLocalStream(); + const realtimeClient = await warmupClient.start(localStream); + + const room = liveKitMock.roomInstances[0] as InstanceType; + expect(room.localParticipant.publishTrack).toHaveBeenCalledTimes(2); + expect(room.localParticipant.publishTrack).toHaveBeenNthCalledWith( + 1, + localStream.getTracks()[0], + expect.objectContaining({ videoCodec: "h264", source: liveKitMock.Track.Source.Camera }), + ); + expect(room.localParticipant.publishTrack).toHaveBeenNthCalledWith(2, localStream.getTracks()[1]); + await expect(warmupClient.start(localStream)).rejects.toThrow("already been started"); + realtimeClient.disconnect(); + }); + + it("warmup disconnect before start closes the room and websocket", async () => { + const { createRealTimeClient } = await import("../src/realtime/client.js"); + const client = createRealTimeClient({ + baseUrl: "wss://api3.decart.ai", + apiKey: "test-key", + logger: { debug() {}, info() {}, warn() {}, error() {} }, + telemetryEnabled: false, + }); + + const warmupClient = await client.warmup({ + model: models.realtime("lucy-2.1"), + onRemoteStream: vi.fn(), + }); + warmupClient.disconnect(); + + const room = liveKitMock.roomInstances[0] as InstanceType; + expect(room.disconnect).toHaveBeenCalled(); + }); }); describe("SignalingChannel initial handshake", () => {