From 84a8748b5bfc94795e9b32a64d6fe3ed060f6e5d Mon Sep 17 00:00:00 2001 From: Chris Busillo Date: Sun, 3 May 2026 15:50:13 -0400 Subject: [PATCH] fix(core): guard Responses WebSocket session state Refresh websocket turn state on reconnect, avoid empty chained deltas except after warmup, and wait for the reader task to acquire the session lock before returning the stream. --- code-rs/core/src/client.rs | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/code-rs/core/src/client.rs b/code-rs/core/src/client.rs index 90fd36f3a9c1..e561d3c1e0f5 100644 --- a/code-rs/core/src/client.rs +++ b/code-rs/core/src/client.rs @@ -23,8 +23,9 @@ use reqwest::header::HeaderValue; use serde::Deserialize; use serde::Serialize; use serde_json::Value; -use tokio::sync::mpsc; use tokio::sync::Mutex as TokioMutex; +use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::time::timeout; use tokio_util::io::ReaderStream; use tokio_stream::wrappers::ReceiverStream; @@ -131,6 +132,7 @@ struct ResponsesWebsocketSession { turn_state: Arc>, last_request: Option, last_response_id: Option, + last_response_from_warmup: bool, } impl std::fmt::Debug for ResponsesWebsocketSession { @@ -140,6 +142,7 @@ impl std::fmt::Debug for ResponsesWebsocketSession { .field("has_turn_state", &self.turn_state.get().is_some()) .field("has_last_request", &self.last_request.is_some()) .field("has_last_response_id", &self.last_response_id.is_some()) + .field("last_response_from_warmup", &self.last_response_from_warmup) .finish() } } @@ -1081,8 +1084,13 @@ impl ModelClient { previous, ¤t_snapshot, ) { - Some(input) => (Some(input), Some(response_id.clone())), + Some(input) + if !input.is_empty() || session.last_response_from_warmup => + { + (Some(input), Some(response_id.clone())) + } None => (None, None), + _ => (None, None), } } _ => (None, None), @@ -1125,8 +1133,11 @@ impl ModelClient { warn!( existing, new = value, - "received unexpected x-codex-turn-state during websocket connect" + "received new x-codex-turn-state during websocket connect" ); + let refreshed = Arc::new(OnceLock::new()); + let _ = refreshed.set(value.to_string()); + session.turn_state = refreshed; } else { let _ = session.turn_state.set(value.to_string()); } @@ -1224,6 +1235,7 @@ impl ModelClient { session.connection = None; session.last_request = None; session.last_response_id = None; + session.last_response_from_warmup = false; let err = CodexErr::Stream( format!("[ws] failed to send websocket request: {err}"), None, @@ -1247,8 +1259,10 @@ impl ModelClient { mpsc::channel::>(RESPONSES_WEBSOCKET_INGRESS_BUFFER); let request_id_for_ws = request_id.clone(); let websocket_session = Arc::clone(&self.websocket_session); + let (reader_ready_tx, reader_ready_rx) = oneshot::channel(); tokio::spawn(async move { let mut session = websocket_session.lock().await; + let _ = reader_ready_tx.send(()); let Some(ws_stream) = session.connection.as_mut() else { let _ = tx_bytes .send(Err(CodexErr::Stream( @@ -1265,6 +1279,7 @@ impl ModelClient { session.connection = None; session.last_request = None; session.last_response_id = None; + session.last_response_from_warmup = false; break; }; match next { @@ -1275,6 +1290,7 @@ impl ModelClient { session.connection = None; session.last_request = None; session.last_response_id = None; + session.last_response_from_warmup = false; let _ = tx_bytes.send(Err(error)).await; break; } @@ -1289,10 +1305,12 @@ impl ModelClient { Some(response_id) if !response_id.is_empty() => { session.last_request = Some(current_snapshot); session.last_response_id = Some(response_id); + session.last_response_from_warmup = warmup; } _ => { session.last_request = None; session.last_response_id = None; + session.last_response_from_warmup = false; } } break; @@ -1307,12 +1325,16 @@ impl ModelClient { Ok(Message::Pong(_)) => {} Ok(Message::Close(_)) => { session.connection = None; + session.last_request = None; + session.last_response_id = None; + session.last_response_from_warmup = false; break; } Ok(Message::Binary(_)) => { session.connection = None; session.last_request = None; session.last_response_id = None; + session.last_response_from_warmup = false; let _ = tx_bytes .send(Err(CodexErr::Stream( "[ws] unexpected binary websocket event".to_string(), @@ -1327,6 +1349,7 @@ impl ModelClient { session.connection = None; session.last_request = None; session.last_response_id = None; + session.last_response_from_warmup = false; let _ = tx_bytes .send(Err(CodexErr::Stream( format!("[ws] websocket error: {err}"), @@ -1339,6 +1362,7 @@ impl ModelClient { } } }); + let _ = reader_ready_rx.await; let stream = ReceiverStream::new(rx_bytes); let debug_logger = Arc::clone(&self.debug_logger);