From c9876826ea4696fcede051663ccc738455c9a886 Mon Sep 17 00:00:00 2001 From: Chris Busillo Date: Sun, 3 May 2026 12:39:02 -0400 Subject: [PATCH] feat(core): reuse Responses WebSocket sessions Prewarm Responses WebSocket sessions with generate=false, reuse the connection across cloned model clients, and chain compatible follow-up requests with previous_response_id plus incremental input. --- code-rs/core/src/client.rs | 709 +++++++++++++++++++++++++++---------- 1 file changed, 514 insertions(+), 195 deletions(-) diff --git a/code-rs/core/src/client.rs b/code-rs/core/src/client.rs index 54894da534b8..61a024a0909b 100644 --- a/code-rs/core/src/client.rs +++ b/code-rs/core/src/client.rs @@ -24,6 +24,7 @@ use serde::Deserialize; use serde::Serialize; use serde_json::Value; use tokio::sync::mpsc; +use tokio::sync::Mutex as TokioMutex; use tokio::time::timeout; use tokio_util::io::ReaderStream; use tokio_stream::wrappers::ReceiverStream; @@ -32,6 +33,8 @@ use tracing::trace; use tracing::warn; use uuid::Uuid; use chrono::{DateTime, Duration as ChronoDuration, Utc}; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::Error as WsError; use tokio_tungstenite::tungstenite::client::IntoClientRequest; @@ -114,6 +117,109 @@ const MODEL_CAP_RESET_AFTER_HEADER: &str = "x-codex-model-cap-reset-after-second const CODE_OPENAI_SUBAGENT_ENV: &str = "CODE_OPENAI_SUBAGENT"; +type ResponsesWebSocketStream = WebSocketStream>; + +#[derive(Clone, Debug, PartialEq)] +struct ResponsesRequestSnapshot { + comparable_payload: Value, + input: Vec, +} + +#[derive(Default)] +struct ResponsesWebsocketSession { + connection: Option, + turn_state: Arc>, + last_request: Option, + last_response_id: Option, +} + +impl std::fmt::Debug for ResponsesWebsocketSession { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ResponsesWebsocketSession") + .field("connected", &self.connection.is_some()) + .field("has_turn_state", &self.turn_state.get().is_some()) + .field("has_last_request", &self.last_request.is_some()) + .field("has_last_response_id", &self.last_response_id.is_some()) + .finish() + } +} + +fn responses_request_snapshot( + payload_json: &Value, + input: &[ResponseItem], +) -> ResponsesRequestSnapshot { + let mut comparable_payload = payload_json.clone(); + if let Some(obj) = comparable_payload.as_object_mut() { + obj.remove("input"); + } + ResponsesRequestSnapshot { + comparable_payload, + input: input.to_vec(), + } +} + +fn incremental_input_for_websocket_request( + previous: &ResponsesRequestSnapshot, + current: &ResponsesRequestSnapshot, +) -> Option> { + if previous.comparable_payload != current.comparable_payload { + return None; + } + if !current.input.starts_with(&previous.input) { + return None; + } + Some(current.input[previous.input.len()..].to_vec()) +} + +fn build_responses_websocket_payload( + payload_json: &Value, + input_override: Option>, + previous_response_id: Option, + generate: Option, +) -> Result { + let mut ws_payload = serde_json::Map::new(); + ws_payload.insert( + "type".to_string(), + serde_json::Value::String("response.create".to_string()), + ); + if let Some(obj) = payload_json.as_object() { + for (k, v) in obj { + ws_payload.insert(k.clone(), v.clone()); + } + } + if let Some(input) = input_override { + ws_payload.insert("input".to_string(), serde_json::to_value(input)?); + } + if let Some(previous_response_id) = previous_response_id { + ws_payload.insert( + "previous_response_id".to_string(), + Value::String(previous_response_id), + ); + } + if let Some(generate) = generate { + ws_payload.insert("generate".to_string(), Value::Bool(generate)); + } + Ok(serde_json::to_string(&Value::Object(ws_payload))?) +} + +fn terminal_response_id_from_websocket_event(text: &str) -> Option> { + let event: SseEvent = serde_json::from_str(text).ok()?; + match event.kind.as_str() { + "response.completed" => { + let response = event.response?; + let completed: ResponseCompleted = serde_json::from_value(response).ok()?; + Some(Some(completed.id)) + } + "response.done" => { + let response = event.response?; + let done: ResponseDone = serde_json::from_value(response).ok()?; + Some(done.id) + } + "response.failed" | "response.incomplete" => Some(None), + _ => None, + } +} + #[derive(Default, Debug)] struct StreamCheckpoint { /// Highest sequence_number observed across attempts. Used to drop replayed deltas. @@ -265,6 +371,7 @@ pub struct ModelClient { summary: ReasoningSummaryConfig, reasoning_summary_disabled: AtomicBool, websockets_disabled: AtomicBool, + websocket_session: Arc>, verbosity: TextVerbosityConfig, debug_logger: Arc>, } @@ -286,6 +393,7 @@ impl Clone for ModelClient { websockets_disabled: AtomicBool::new( self.websockets_disabled.load(Ordering::Relaxed), ), + websocket_session: Arc::clone(&self.websocket_session), verbosity: self.verbosity, debug_logger: Arc::clone(&self.debug_logger), } @@ -319,6 +427,7 @@ impl ModelClient { summary, reasoning_summary_disabled: AtomicBool::new(false), websockets_disabled: AtomicBool::new(false), + websocket_session: Arc::new(TokioMutex::new(ResponsesWebsocketSession::default())), verbosity: effective_verbosity, debug_logger, } @@ -585,13 +694,20 @@ impl ModelClient { match self.provider.wire_api { WireApi::Responses => { if let Some(ws_version) = self.active_ws_version_for_prompt(prompt) { - match self - .stream_responses_websocket(prompt, log_tag, ws_version) + let ws_result = match self + .prewarm_responses_websocket_if_needed(prompt, log_tag, ws_version) .await { + Ok(()) => self + .stream_responses_websocket(prompt, log_tag, ws_version, false) + .await, + Err(err) => Err(err), + }; + match ws_result { Ok(stream) => Ok(stream), Err(err) => { self.websockets_disabled.store(true, Ordering::Relaxed); + self.reset_responses_websocket_session().await; warn!( "preferred websocket transport failed; falling back to responses HTTP stream: {err}" ); @@ -612,13 +728,20 @@ impl ModelClient { let ws_version = self .active_ws_version_for_prompt(prompt) .unwrap_or(preferred_ws_version_from_env()); - match self - .stream_responses_websocket(prompt, log_tag, ws_version) + let ws_result = match self + .prewarm_responses_websocket_if_needed(prompt, log_tag, ws_version) .await { + Ok(()) => self + .stream_responses_websocket(prompt, log_tag, ws_version, false) + .await, + Err(err) => Err(err), + }; + match ws_result { Ok(stream) => Ok(stream), Err(err) => { self.websockets_disabled.store(true, Ordering::Relaxed); + self.reset_responses_websocket_session().await; warn!( "responses_websocket transport failed; falling back to responses HTTP stream: {err}" ); @@ -678,11 +801,46 @@ impl ModelClient { } } + async fn reset_responses_websocket_session(&self) { + let mut session = self.websocket_session.lock().await; + *session = ResponsesWebsocketSession::default(); + } + + async fn prewarm_responses_websocket_if_needed( + &self, + prompt: &Prompt, + log_tag: Option<&str>, + ws_version: ResponsesWebsocketVersion, + ) -> Result<()> { + { + let session = self.websocket_session.lock().await; + if session.last_request.is_some() { + return Ok(()); + } + } + + let mut stream = self + .stream_responses_websocket(prompt, log_tag, ws_version, true) + .await?; + while let Some(event) = stream.next().await { + match event? { + ResponseEvent::Completed { .. } => return Ok(()), + _ => {} + } + } + Err(CodexErr::Stream( + "websocket prewarm ended before response.completed".to_string(), + None, + None, + )) + } + async fn stream_responses_websocket( &self, prompt: &Prompt, log_tag: Option<&str>, ws_version: ResponsesWebsocketVersion, + warmup: bool, ) -> Result { let auth_manager = self.auth_manager.clone(); let auth_mode = auth_manager @@ -751,7 +909,6 @@ impl ModelClient { let model_slug = request_model; let session_id = prompt.session_id_override.unwrap_or(self.session_id); let session_id_str = session_id.to_string(); - let turn_state: Arc> = Arc::new(OnceLock::new()); let mut attempt = 0; let max_retries = self.provider.request_max_retries(); let mut request_id = String::new(); @@ -851,6 +1008,10 @@ impl ModelClient { req_builder = attach_openai_subagent_header(req_builder); req_builder = attach_codex_beta_features_header(req_builder, &self.config); + let turn_state = { + let session = self.websocket_session.lock().await; + Arc::clone(&session.turn_state) + }; if let Some(state) = turn_state.get() { req_builder = req_builder.header(X_CODEX_TURN_STATE_HEADER, state); } @@ -907,220 +1068,294 @@ impl ModelClient { }), ); - // Wrap the normal /responses request payload in the WebSocket envelope. - let mut ws_payload = serde_json::Map::new(); - ws_payload.insert( - "type".to_string(), - serde_json::Value::String("response.create".to_string()), - ); - if let Some(obj) = payload_json.as_object() { - for (k, v) in obj { - ws_payload.insert(k.clone(), v.clone()); - } - } - let ws_payload_text = serde_json::to_string(&serde_json::Value::Object(ws_payload))?; - - let connect = timeout( - self.provider.websocket_connect_timeout(), - tokio_tungstenite::connect_async(ws_request), - ) - .await; - match connect { - Ok(Ok((mut ws_stream, response))) => { - let (tx_event, rx_event) = mpsc::channel::>(1600); - - let response_headers = header_map_to_json(response.headers()); - if tx_event - .send(Ok(ResponseEvent::ResponseHeaders(response_headers))) - .await - .is_err() - { - debug!("receiver dropped response headers event"); - } - - if let Some(value) = response - .headers() - .get(X_CODEX_TURN_STATE_HEADER) - .and_then(|value| value.to_str().ok()) - { - if let Some(existing) = turn_state.get() - && existing != value - { - warn!( - existing, - new = value, - "received unexpected x-codex-turn-state during websocket connect" - ); - } else { - let _ = turn_state.set(value.to_string()); + let current_snapshot = + responses_request_snapshot(&payload_json, &input_with_instructions); + let (input_override, previous_response_id) = if warmup { + (None, None) + } else { + let session = self.websocket_session.lock().await; + match (&session.last_request, &session.last_response_id) { + (Some(previous), Some(response_id)) => { + match incremental_input_for_websocket_request( + previous, + ¤t_snapshot, + ) { + Some(input) => (Some(input), Some(response_id.clone())), + None => (None, None), } } - - if let Some(snapshot) = parse_rate_limit_snapshot(response.headers()) { - debug!( - "rate limit headers:\n{}", - format_rate_limit_headers(response.headers()) - ); + _ => (None, None), + } + }; + let ws_payload_text = build_responses_websocket_payload( + &payload_json, + input_override, + previous_response_id, + warmup.then_some(false), + )?; + + let (tx_event, rx_event) = mpsc::channel::>(1600); + let mut session = self.websocket_session.lock().await; + if session.connection.is_none() { + let connect = timeout( + self.provider.websocket_connect_timeout(), + tokio_tungstenite::connect_async(ws_request), + ) + .await; + match connect { + Ok(Ok((ws_stream, response))) => { + let response_headers = header_map_to_json(response.headers()); if tx_event - .send(Ok(ResponseEvent::RateLimits(snapshot))) + .send(Ok(ResponseEvent::ResponseHeaders(response_headers))) .await .is_err() { - debug!("receiver dropped rate limit snapshot event"); + debug!("receiver dropped response headers event"); } - } - let models_etag = response - .headers() - .get("X-Models-Etag") - .and_then(|value| value.to_str().ok()) - .map(ToString::to_string); - if let Some(etag) = models_etag { - if tx_event - .send(Ok(ResponseEvent::ModelsEtag(etag))) - .await - .is_err() + if let Some(value) = response + .headers() + .get(X_CODEX_TURN_STATE_HEADER) + .and_then(|value| value.to_str().ok()) { - debug!("receiver dropped models etag event"); + if let Some(existing) = session.turn_state.get() + && existing != value + { + warn!( + existing, + new = value, + "received unexpected x-codex-turn-state during websocket connect" + ); + } else { + let _ = session.turn_state.set(value.to_string()); + } } - } - if response.headers().contains_key("x-reasoning-included") { - if tx_event - .send(Ok(ResponseEvent::ServerReasoningIncluded(true))) - .await - .is_err() - { - debug!("receiver dropped server reasoning included event"); + if let Some(snapshot) = parse_rate_limit_snapshot(response.headers()) { + debug!( + "rate limit headers:\n{}", + format_rate_limit_headers(response.headers()) + ); + if tx_event + .send(Ok(ResponseEvent::RateLimits(snapshot))) + .await + .is_err() + { + debug!("receiver dropped rate limit snapshot event"); + } } - } - ws_stream - .send(Message::Text(ws_payload_text)) - .await - .map_err(|err| { - CodexErr::Stream( - format!("[ws] failed to send websocket request: {err}"), - None, - Some(request_id.clone()), - ) - })?; - - // Keep websocket ingress bounded so a slow downstream consumer - // cannot cause unbounded buffering and memory growth. - let (tx_bytes, rx_bytes) = - mpsc::channel::>(RESPONSES_WEBSOCKET_INGRESS_BUFFER); - let request_id_for_ws = request_id.clone(); - let ws_reader_handle = tokio::spawn(async move { - loop { - let Some(next) = ws_stream.next().await else { - break; - }; - match next { - Ok(Message::Text(text)) => { - if let Some(error) = parse_wrapped_websocket_error_event(&text) - .and_then(map_wrapped_websocket_error_event) - { - let _ = tx_bytes.send(Err(error)).await; - break; - } + let models_etag = response + .headers() + .get("X-Models-Etag") + .and_then(|value| value.to_str().ok()) + .map(ToString::to_string); + if let Some(etag) = models_etag { + if tx_event + .send(Ok(ResponseEvent::ModelsEtag(etag))) + .await + .is_err() + { + debug!("receiver dropped models etag event"); + } + } - let chunk = format!("data: {text}\n\n"); - if tx_bytes.send(Ok(Bytes::from(chunk))).await.is_err() { - break; - } - } - Ok(Message::Ping(payload)) => { - if ws_stream.send(Message::Pong(payload)).await.is_err() { - break; - } - } - Ok(Message::Pong(_)) => {} - Ok(Message::Close(_)) => break, - Ok(Message::Binary(_)) => { - let _ = tx_bytes - .send(Err(CodexErr::Stream( - "[ws] unexpected binary websocket event".to_string(), - None, - Some(request_id_for_ws.clone()), - ))) - .await; - break; - } - Ok(_) => {} - Err(err) => { - let _ = tx_bytes - .send(Err(CodexErr::Stream( - format!("[ws] websocket error: {err}"), - None, - Some(request_id_for_ws.clone()), - ))) - .await; - break; - } + if response.headers().contains_key("x-reasoning-included") { + if tx_event + .send(Ok(ResponseEvent::ServerReasoningIncluded(true))) + .await + .is_err() + { + debug!("receiver dropped server reasoning included event"); } } - }); - let stream = ReceiverStream::new(rx_bytes); - let debug_logger = Arc::clone(&self.debug_logger); - let request_id_clone = request_id.clone(); - let otel_event_manager = self.otel_event_manager.clone(); - let stream_idle_timeout = self.provider.stream_idle_timeout(); - tokio::spawn(async move { - process_sse( - stream, - tx_event, - stream_idle_timeout, - debug_logger, - request_id_clone, - otel_event_manager, - Arc::new(RwLock::new(StreamCheckpoint::default())), - ) - .await; - // process_sse may finish before the server closes the websocket. - // Abort the websocket reader task to avoid lingering open sockets. - ws_reader_handle.abort(); - }); + session.connection = Some(ws_stream); + } + Ok(Err(err)) => { + drop(session); + if websocket_connect_is_upgrade_required(&err) { + self.websockets_disabled.store(true, Ordering::Relaxed); + warn!("responses websocket upgrade required; falling back to HTTP responses transport"); + return self.stream_responses(prompt, log_tag).await; + } - return Ok(ResponseStream { rx_event }); - } - Ok(Err(err)) => { - if websocket_connect_is_upgrade_required(&err) { + let err = CodexErr::Stream( + format!("[ws] failed to connect: {err}"), + None, + Some(request_id.clone()), + ); + if (attempt as u64) < max_retries { + tokio::time::sleep(backoff(attempt as u64)).await; + continue; + } self.websockets_disabled.store(true, Ordering::Relaxed); - warn!("responses websocket upgrade required; falling back to HTTP responses transport"); - return self.stream_responses(prompt, log_tag).await; + return Err(err); } - - let err = CodexErr::Stream( - format!("[ws] failed to connect: {err}"), - None, - Some(request_id.clone()), - ); - if (attempt as u64) < max_retries { - tokio::time::sleep(backoff(attempt as u64)).await; - continue; + Err(_) => { + drop(session); + let err = CodexErr::Stream( + format!( + "[ws] timed out connecting after {} ms", + self.provider.websocket_connect_timeout().as_millis() + ), + None, + Some(request_id.clone()), + ); + if (attempt as u64) < max_retries { + tokio::time::sleep(backoff(attempt as u64)).await; + continue; + } + self.websockets_disabled.store(true, Ordering::Relaxed); + return Err(err); } - self.websockets_disabled.store(true, Ordering::Relaxed); - return Err(err); } - Err(_) => { - let err = CodexErr::Stream( - format!( - "[ws] timed out connecting after {} ms", - self.provider.websocket_connect_timeout().as_millis() - ), - None, - Some(request_id.clone()), - ); - if (attempt as u64) < max_retries { - tokio::time::sleep(backoff(attempt as u64)).await; - continue; - } - self.websockets_disabled.store(true, Ordering::Relaxed); - return Err(err); + } + + let Some(ws_stream) = session.connection.as_mut() else { + return Err(CodexErr::Stream( + "[ws] websocket connection is closed".to_string(), + None, + Some(request_id.clone()), + )); + }; + if let Err(err) = ws_stream.send(Message::Text(ws_payload_text)).await { + session.connection = None; + session.last_request = None; + session.last_response_id = None; + let err = CodexErr::Stream( + format!("[ws] failed to send websocket request: {err}"), + None, + Some(request_id.clone()), + ); + if (attempt as u64) < max_retries { + drop(session); + tokio::time::sleep(backoff(attempt as u64)).await; + continue; } + self.websockets_disabled.store(true, Ordering::Relaxed); + return Err(err); } + drop(session); + + // Keep websocket ingress bounded so a slow downstream consumer + // cannot cause unbounded buffering and memory growth. The reader + // exits on the terminal response event and leaves the connection in + // the session for the next chained request. + let (tx_bytes, rx_bytes) = + mpsc::channel::>(RESPONSES_WEBSOCKET_INGRESS_BUFFER); + let request_id_for_ws = request_id.clone(); + let websocket_session = Arc::clone(&self.websocket_session); + tokio::spawn(async move { + let mut session = websocket_session.lock().await; + let Some(ws_stream) = session.connection.as_mut() else { + let _ = tx_bytes + .send(Err(CodexErr::Stream( + "[ws] websocket connection is closed".to_string(), + None, + Some(request_id_for_ws.clone()), + ))) + .await; + return; + }; + + loop { + let Some(next) = ws_stream.next().await else { + session.connection = None; + break; + }; + match next { + Ok(Message::Text(text)) => { + if let Some(error) = parse_wrapped_websocket_error_event(&text) + .and_then(map_wrapped_websocket_error_event) + { + session.connection = None; + session.last_request = None; + session.last_response_id = None; + let _ = tx_bytes.send(Err(error)).await; + break; + } + + let terminal_response_id = terminal_response_id_from_websocket_event(&text); + let chunk = format!("data: {text}\n\n"); + if tx_bytes.send(Ok(Bytes::from(chunk))).await.is_err() { + break; + } + if let Some(response_id) = terminal_response_id { + match response_id { + Some(response_id) if !response_id.is_empty() => { + session.last_request = Some(current_snapshot); + session.last_response_id = Some(response_id); + } + _ => { + session.last_request = None; + session.last_response_id = None; + } + } + break; + } + } + Ok(Message::Ping(payload)) => { + if ws_stream.send(Message::Pong(payload)).await.is_err() { + session.connection = None; + break; + } + } + Ok(Message::Pong(_)) => {} + Ok(Message::Close(_)) => { + session.connection = None; + break; + } + Ok(Message::Binary(_)) => { + session.connection = None; + session.last_request = None; + session.last_response_id = None; + let _ = tx_bytes + .send(Err(CodexErr::Stream( + "[ws] unexpected binary websocket event".to_string(), + None, + Some(request_id_for_ws.clone()), + ))) + .await; + break; + } + Ok(_) => {} + Err(err) => { + session.connection = None; + session.last_request = None; + session.last_response_id = None; + let _ = tx_bytes + .send(Err(CodexErr::Stream( + format!("[ws] websocket error: {err}"), + None, + Some(request_id_for_ws.clone()), + ))) + .await; + break; + } + } + } + }); + + let stream = ReceiverStream::new(rx_bytes); + let debug_logger = Arc::clone(&self.debug_logger); + let request_id_clone = request_id.clone(); + let otel_event_manager = self.otel_event_manager.clone(); + let stream_idle_timeout = self.provider.stream_idle_timeout(); + tokio::spawn(async move { + process_sse( + stream, + tx_event, + stream_idle_timeout, + debug_logger, + request_id_clone, + otel_event_manager, + Arc::new(RwLock::new(StreamCheckpoint::default())), + ) + .await; + }); + + return Ok(ResponseStream { rx_event }); } } @@ -3039,6 +3274,90 @@ mod tests { // Helpers // ──────────────────────────── + fn response_text_item(text: &str) -> ResponseItem { + ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![code_protocol::models::ContentItem::InputText { + text: text.to_string(), + }], + end_turn: None, + phase: None, + } + } + + #[test] + fn websocket_incremental_input_requires_matching_request_prefix() { + let first = response_text_item("one"); + let second = response_text_item("two"); + let payload = json!({ + "model": "gpt-5.5", + "instructions": "be useful", + "input": [first.clone()], + "stream": true + }); + let previous = responses_request_snapshot(&payload, std::slice::from_ref(&first)); + let current_payload = json!({ + "model": "gpt-5.5", + "instructions": "be useful", + "input": [first.clone(), second.clone()], + "stream": true + }); + let current = + responses_request_snapshot(¤t_payload, &[first.clone(), second.clone()]); + + assert_eq!( + incremental_input_for_websocket_request(&previous, ¤t), + Some(vec![second.clone()]) + ); + + let changed_payload = json!({ + "model": "gpt-5.5", + "instructions": "be terse", + "input": [first, second], + "stream": true + }); + let changed = responses_request_snapshot(&changed_payload, ¤t.input); + assert_eq!( + incremental_input_for_websocket_request(&previous, &changed), + None + ); + } + + #[test] + fn websocket_payload_adds_generate_and_previous_response_fields() { + let delta = response_text_item("follow up"); + let payload = json!({ + "model": "gpt-5.5", + "instructions": "be useful", + "input": [], + "stream": true, + "prompt_cache_key": "session-1" + }); + + let warmup = build_responses_websocket_payload(&payload, None, None, Some(false)) + .expect("warmup payload"); + let warmup_json: Value = serde_json::from_str(&warmup).expect("warmup json"); + assert_eq!(warmup_json["type"], "response.create"); + assert_eq!(warmup_json["generate"], false); + assert!(warmup_json.get("previous_response_id").is_none()); + + let chained = build_responses_websocket_payload( + &payload, + Some(vec![delta.clone()]), + Some("resp_previous".to_string()), + None, + ) + .expect("chained payload"); + let chained_json: Value = serde_json::from_str(&chained).expect("chained json"); + assert_eq!(chained_json["previous_response_id"], "resp_previous"); + assert_eq!( + chained_json["input"], + serde_json::to_value(vec![delta]).expect("delta input json") + ); + assert!(chained_json.get("generate").is_none()); + } + #[test] fn unauthorized_outcome_returns_permanent_error_for_permanent_refresh_failure() { let err = RefreshTokenError::permanent("token revoked");