diff --git a/Cargo.lock b/Cargo.lock index de40447..5c691c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5960,6 +5960,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" + [[package]] name = "sha2" version = "0.10.9" @@ -7185,6 +7191,7 @@ dependencies = [ "getrandom 0.3.4", "js-sys", "rand 0.9.2", + "sha1_smol", "wasm-bindgen", ] diff --git a/Cargo.toml b/Cargo.toml index 2c191ef..02d1016 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,9 +33,10 @@ lip_sync = { version = "0.1", git = "https://github.com/L-jasmine/lip_sync.git" rand = "0.9.0" uuid = { version = "1.14", features = [ "v4", # Lets you generate random UUIDs + "v5", # Lets you generate namespace-based UUIDs "fast-rng", ] } -bytes = "1.10.0" +bytes = "1.11.0" aho-corasick = "1.1.3" lazy-regex = "3.4.2" diff --git a/src/config.rs b/src/config.rs index dc4b69b..2de686a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -353,6 +353,12 @@ pub struct RecordConfig { pub callback_url: Option, } +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct EchokitCC { + pub url: String, + // pub output_optimization: TTSTextOptimizationConfig, +} + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[serde(untagged)] pub enum AIConfig { @@ -361,6 +367,11 @@ pub enum AIConfig { tts: TTSConfig, asr: ASRConfig, }, + Claude { + claude: EchokitCC, + asr: ASRConfig, + tts: TTSConfig, + }, GeminiAndTTS { gemini: GeminiConfig, tts: TTSConfig, diff --git a/src/main.rs b/src/main.rs index eea9bca..9ccf0f7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use axum::{ Router, @@ -163,6 +163,27 @@ async fn routes( } }); } + config::AIConfig::Claude { claude, asr, tts } => { + let session = Arc::new(RwLock::new(Default::default())); + let session_ = session.clone(); + + tokio::spawn(async move { + if let Err(e) = crate::services::ws::stable::claude::run_session_manager( + &tts, &asr, &claude, rx, session, + ) + .await + { + log::error!("Claude session manager exited with error: {}", e); + } + }); + + router = router + .route( + "/proxy/state/{id}", + get(services::ws::stable::claude::has_notification), + ) + .layer(axum::Extension(session_)); + } } router = router diff --git a/src/protocol.rs b/src/protocol.rs index f78e1bf..13edc5f 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -9,8 +9,10 @@ pub enum ServerEvent { ASR { text: String }, Action { action: String }, + Choices { message: String, items: Vec }, StartAudio { text: String }, AudioChunk { data: Vec }, + DisplayText { text: String }, AudioChunkWithVowel { data: Vec, vowel: u8 }, EndAudio, StartVideo, @@ -47,6 +49,7 @@ pub enum ClientCommand { StartChat, Submit, Text { input: String }, + Select { index: usize }, } #[test] diff --git a/src/services/ws.rs b/src/services/ws.rs index 2c9cf0a..622c6b9 100644 --- a/src/services/ws.rs +++ b/src/services/ws.rs @@ -25,6 +25,9 @@ pub enum WsCommand { Video(Vec>), EndResponse, EndVad, + Choices(String, Vec), + DisplayText(String), + Close, } type WsTx = tokio::sync::mpsc::UnboundedSender; type WsRx = tokio::sync::mpsc::UnboundedReceiver; @@ -113,6 +116,7 @@ pub enum ClientMsg { AudioChunk(Bytes), Submit, Text(String), + Select(usize), } pub struct ConnectConfig { @@ -151,6 +155,11 @@ async fn process_socket_io( match r { Some(WsEvent::Command(cmd)) => { + if matches!(cmd, WsCommand::Close) { + log::info!("Received Close command, closing websocket"); + return Ok(()); + } + if config.enable_opus { process_command_with_opus( socket, @@ -184,6 +193,12 @@ async fn process_socket_io( .send(ClientMsg::Text(input)) .await .map_err(|_| anyhow::anyhow!("audio_tx closed"))?, + ProcessMessageResult::Select(index) => { + audio_tx + .send(ClientMsg::Select(index)) + .await + .map_err(|_| anyhow::anyhow!("audio_tx closed"))?; + } ProcessMessageResult::Skip => {} ProcessMessageResult::StartChat => { audio_tx @@ -285,6 +300,18 @@ async fn process_command(ws: &mut WebSocket, cmd: WsCommand) -> anyhow::Result<( ws.send(Message::binary(audio_chunk)).await?; } } + WsCommand::Choices(message, items) => { + let choices = + rmp_serde::to_vec(&crate::protocol::ServerEvent::Choices { message, items }) + .expect("Failed to serialize Choices ServerEvent"); + ws.send(Message::binary(choices)).await?; + } + WsCommand::DisplayText(text) => { + let display_text = + rmp_serde::to_vec(&crate::protocol::ServerEvent::DisplayText { text }) + .expect("Failed to serialize DisplayText ServerEvent"); + ws.send(Message::binary(display_text)).await?; + } WsCommand::EndAudio => { log::trace!("EndAudio"); let end_audio = rmp_serde::to_vec(&crate::protocol::ServerEvent::EndAudio) @@ -306,6 +333,7 @@ async fn process_command(ws: &mut WebSocket, cmd: WsCommand) -> anyhow::Result<( .expect("Failed to serialize EndVad ServerEvent"); ws.send(Message::binary(end_vad)).await?; } + WsCommand::Close => {} } Ok(()) } @@ -341,12 +369,23 @@ async fn process_command_with_opus( .expect("Failed to serialize ASR ServerEvent"); ws.send(Message::binary(asr)).await?; } - WsCommand::Action { action } => { let action = rmp_serde::to_vec(&crate::protocol::ServerEvent::Action { action }) .expect("Failed to serialize Action ServerEvent"); ws.send(Message::binary(action)).await?; } + WsCommand::Choices(message, items) => { + let choices = + rmp_serde::to_vec(&crate::protocol::ServerEvent::Choices { message, items }) + .expect("Failed to serialize Choices ServerEvent"); + ws.send(Message::binary(choices)).await?; + } + WsCommand::DisplayText(text) => { + let display_text = + rmp_serde::to_vec(&crate::protocol::ServerEvent::DisplayText { text }) + .expect("Failed to serialize DisplayText ServerEvent"); + ws.send(Message::binary(display_text)).await?; + } WsCommand::StartAudio(text) => { log::trace!("StartAudio: {text:?}"); opus_encode @@ -453,6 +492,7 @@ async fn process_command_with_opus( .expect("Failed to serialize EndVad ServerEvent"); ws.send(Message::binary(end_vad)).await?; } + WsCommand::Close => {} } Ok(()) } @@ -461,6 +501,7 @@ enum ProcessMessageResult { Audio(Bytes), Submit, Text(String), + Select(usize), StartChat, Close, Skip, @@ -478,13 +519,16 @@ fn process_message(msg: Message) -> ProcessMessageResult { crate::protocol::ClientCommand::Text { input } => { ProcessMessageResult::Text(input) } + crate::protocol::ClientCommand::Select { index } => { + ProcessMessageResult::Select(index) + } } } else { ProcessMessageResult::Skip } } Message::Binary(d) => { - log::debug!("Received binary message of size: {}", d.len()); + log::trace!("Received binary message of size: {}", d.len()); ProcessMessageResult::Audio(d) } Message::Close(c) => { diff --git a/src/services/ws/stable/asr.rs b/src/services/ws/stable/asr.rs index 22bb22f..b76dd36 100644 --- a/src/services/ws/stable/asr.rs +++ b/src/services/ws/stable/asr.rs @@ -98,6 +98,7 @@ impl WhisperASRSession { vad_started |= self.vad_session.detect(&audio_chunk)?; } } + ClientMsg::Select(..) => {} } } } @@ -205,6 +206,7 @@ impl WhisperASRSession { log::warn!("`{id}` received a Unexpected Submit during Stream ASR"); return Err(anyhow::anyhow!("Unexpected Submit during Stream ASR")); } + ClientMsg::Select(..) => {} } } @@ -384,6 +386,7 @@ impl ParaformerASRSession { continue; } + ClientMsg::Select(..) => {} } } @@ -518,6 +521,7 @@ impl ParaformerASRSession { } start_submit = true; } + ClientMsg::Select(..) => {} } } else { log::warn!("`{}` client rx channel closed unexpectedly", session.id); diff --git a/src/services/ws/stable/claude.rs b/src/services/ws/stable/claude.rs new file mode 100644 index 0000000..2084bc5 --- /dev/null +++ b/src/services/ws/stable/claude.rs @@ -0,0 +1,1030 @@ +use std::{ + collections::{HashMap, LinkedList}, + sync::{Arc, RwLock, atomic::AtomicBool}, +}; + +use axum::{Extension, extract::Path, response::IntoResponse}; + +use crate::{ + config::{ASRConfig, EchokitCC, TTSConfig}, + services::ws::{ClientMsg, stable::tts::TTSRequestTx}, +}; + +use super::Session; + +#[derive(Default, Clone)] +pub struct ClaudeNotification { + notification: Arc, +} + +impl ClaudeNotification { + pub fn mark(&self) { + self.notification + .store(true, std::sync::atomic::Ordering::Relaxed); + } + + pub fn value(&self) -> bool { + self.notification.load(std::sync::atomic::Ordering::Relaxed) + } + + pub fn clear(&self) { + self.notification + .store(false, std::sync::atomic::Ordering::Relaxed); + } +} + +#[derive(Default)] +pub struct ClaudeNotifications { + pub sessions: HashMap, +} + +async fn get_input( + session: &mut Session, + asr_session: &mut super::asr::AsrSession, +) -> anyhow::Result { + loop { + log::info!( + "{}:{:x} waiting for asr input", + session.id, + session.request_id + ); + let text = if session.stream_asr { + match asr_session.stream_get_input(session).await { + Ok(t) => t, + Err(e) => { + log::error!( + "{}:{:x} error getting asr input: {}", + session.id, + session.request_id, + e + ); + session.send_end_vad().map_err(|_| { + anyhow::anyhow!( + "{}:{:x} error sending end vad ws command after asr error", + session.id, + session.request_id + ) + })?; + return Err(e); + } + } + } else { + asr_session + .get_input(&session.id, &mut session.client_rx) + .await? + }; + if text.is_empty() { + log::info!( + "{}:{:x} empty asr result, ending session", + session.id, + session.request_id + ); + + session.send_end_response().map_err(|_| { + anyhow::anyhow!( + "{}:{:x} error sending end response ws command for empty asr result", + session.id, + session.request_id + ) + })?; + + continue; + } else { + log::info!( + "{}:{:x} asr result: {}", + session.id, + session.request_id, + text + ); + session.send_asr_result(vec![text.clone()]).map_err(|_| { + anyhow::anyhow!( + "{}:{:x} error sending asr result ws command for message `{}`", + session.id, + session.request_id, + text + ) + })?; + return Ok(text); + } + } +} + +async fn get_choice(session: &mut Session) -> anyhow::Result { + while let Some(evt) = session.client_rx.recv().await { + if let ClientMsg::Select(choice) = evt { + return Ok(choice); + } + log::debug!( + "{}:{:x} ignoring non-select client message during select prompt", + session.id, + session.request_id + ); + } + Err(anyhow::anyhow!( + "client disconnected before making a choice" + )) +} + +#[derive(serde::Serialize, serde::Deserialize)] +struct AskUserQuestionToolArgs { + questions: Vec, +} + +#[derive(serde::Serialize, serde::Deserialize)] +struct AskUserQuestion { + header: String, + question: String, + options: Vec, +} + +#[derive(serde::Serialize, serde::Deserialize)] +struct AskUserQuestionItem { + label: String, + description: String, +} + +struct RunSessionState { + cc_session: cc_session::ClaudeSession, + session: Session, + rx: tokio::sync::mpsc::UnboundedReceiver, + notify: ClaudeNotification, +} + +enum RunSessionSelectResult { + Session(Option), + ClientMsg(Option), + ClaudeMsg(Option), +} + +enum SendStateError { + ClaudeError, + ClientError, +} + +impl RunSessionState { + async fn recv(&mut self) -> anyhow::Result { + async fn recv_client_msg(session: &mut Session) -> Option { + struct PendingClientMsg; + impl Future for PendingClientMsg { + type Output = Option; + + fn poll( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + std::task::Poll::Pending + } + } + + if session.client_rx.is_closed() { + PendingClientMsg.await + } else { + session.client_rx.recv().await + } + } + + let r = tokio::select! { + new_session = self.rx.recv() => { + RunSessionSelectResult::Session(new_session) + } + client_msg = recv_client_msg(&mut self.session) => { + RunSessionSelectResult::ClientMsg(client_msg) + } + claude_msg = self.cc_session.receive_message() => { + RunSessionSelectResult::ClaudeMsg(claude_msg?) + } + }; + + Ok(r) + } + + async fn send_input(&mut self, input: &str) -> anyhow::Result<()> { + self.cc_session + .send_message(&cc_session::WsInputMessage::Input { + input: input.to_string(), + }) + .await + } + + async fn send_display(&mut self, output: &str) -> anyhow::Result<()> { + self.session.send_display_text(output.to_string())?; + Ok(()) + } + + async fn send_output_with_tts( + &mut self, + output: &str, + tts_req_tx: &TTSRequestTx, + ) -> anyhow::Result<()> { + let mut text_splitter = crate::ai::TextSplitter::new(); + text_splitter.push_chunk(&output); + + let finished_output = text_splitter.finish(); + + let mut rx_list = LinkedList::new(); + for chunk in finished_output { + let (tts_response_tx, tts_response_rx) = tokio::sync::mpsc::unbounded_channel(); + if let Err(e) = tts_req_tx.send((chunk.to_string(), tts_response_tx)).await { + log::error!( + "{}:{:x} error sending tts request: {}", + self.session.id, + self.session.request_id, + e + ); + } else { + rx_list.push_back((chunk, tts_response_rx)); + } + } + + for (text_chunk, mut tts_response_rx) in rx_list { + self.session.send_start_audio(text_chunk)?; + while let Some(chunk) = tts_response_rx.recv().await { + self.session.send_audio_chunk(chunk)?; + } + self.session.send_end_audio()?; + } + + self.session.send_end_response()?; + Ok(()) + } + + async fn send_confirm(&mut self) -> anyhow::Result<()> { + self.cc_session + .send_message(&cc_session::WsInputMessage::Confirm {}) + .await + } + + async fn send_select(&mut self, index: usize) -> anyhow::Result<()> { + self.cc_session + .send_message(&cc_session::WsInputMessage::Select { index }) + .await + } + + async fn send_cancel(&mut self) -> anyhow::Result<()> { + self.cc_session + .send_message(&cc_session::WsInputMessage::Cancel {}) + .await + } + + async fn sync_cc_state(&mut self) -> anyhow::Result<()> { + self.cc_session + .send_message(&cc_session::WsInputMessage::CurrentState {}) + .await + } + + fn set_session(&mut self, session: Session) { + self.session = session; + } + + async fn send_self_state( + &mut self, + tts_req_tx: &mut TTSRequestTx, + asr_session: &mut super::asr::AsrSession, + ) -> Result<(), SendStateError> { + let state = self.cc_session.state.clone(); + log::debug!( + "{}:{:x} sending self state: {:?}", + self.session.id, + self.session.request_id, + state + ); + match state { + cc_session::ClaudeCodeState::Output { + output, + is_thinking, + } => { + if is_thinking { + log::debug!( + "{}:{:x} sending thinking output", + self.session.id, + self.session.request_id + ); + let _ = self.send_display(&output).await; + } else { + // self.cc_session.state = cc_session::ClaudeCodeState::Idle; + + if let Err(e) = self.send_output_with_tts(&output, tts_req_tx).await { + log::warn!( + "{}:{:x} error sending tts output: {}", + self.session.id, + self.session.request_id, + e + ); + } + + self.cc_session.last_output = output; + } + } + cc_session::ClaudeCodeState::Idle => { + if !self.cc_session.last_output.is_empty() { + if let Err(e) = self + .session + .send_display_text(self.cc_session.last_output.clone()) + { + log::warn!( + "{}:{:x} error sending display output: {}", + self.session.id, + self.session.request_id, + e + ); + return Err(SendStateError::ClientError); + } + } + + log::info!( + "{}:{:x} waiting for user input", + self.session.id, + self.session.request_id + ); + + match self.wait_input(asr_session).await { + Ok(text) => { + let _ = self.send_input(&text).await; + tokio::time::sleep(std::time::Duration::from_millis(300)).await; + let _ = self.send_confirm().await; + self.cc_session.state = cc_session::ClaudeCodeState::Idle; + } + Err(e) => { + log::warn!( + "{}:{:x} error getting input: {}", + self.session.id, + self.session.request_id, + e + ); + return Err(SendStateError::ClientError); + } + } + } + cc_session::ClaudeCodeState::PreUseTool { + request, + is_pending, + } => { + if is_pending { + for (i, tool) in request.into_iter().enumerate() { + if tool.done { + continue; + } + match self.wait_tool_use_choice(tool.name, tool.input).await { + Ok(-1) => { + self.send_cancel().await.map_err(|e| { + log::error!( + "{}:{:x} error sending tool use cancel: {}", + self.session.id, + self.session.request_id, + e + ); + SendStateError::ClaudeError + })?; + } + Ok(n) => { + self.send_select(n as usize).await.map_err(|e| { + log::error!( + "{}:{:x} error sending tool use confirm: {}", + self.session.id, + self.session.request_id, + e + ); + SendStateError::ClaudeError + })?; + if let cc_session::ClaudeCodeState::PreUseTool { request, .. } = + &mut self.cc_session.state + { + request[i].submited = true; + } + + break; + } + Err(e) => { + log::warn!( + "{}:{:x} error sending tool use choice prompt: {}", + self.session.id, + self.session.request_id, + e + ); + return Err(SendStateError::ClientError); + } + } + } + } + } + cc_session::ClaudeCodeState::StopUseTool { is_error } => { + if is_error { + let _ = self.send_display(&"Tool use stopped with error.").await; + self.session.send_end_response().map_err(|e| { + log::error!( + "{}:{:x} error sending end response after tool use stop: {}", + self.session.id, + self.session.request_id, + e + ); + SendStateError::ClientError + })?; + } else { + let _ = self.send_display(&"Tool use completed successfully.").await; + } + } + }; + + Ok(()) + } + + async fn wait_input( + &mut self, + asr_session: &mut super::asr::AsrSession, + ) -> anyhow::Result { + loop { + log::debug!( + "{}:{:x} waiting for user input confirmation", + self.session.id, + self.session.request_id + ); + let text = get_input(&mut self.session, asr_session).await?; + self.session.send_choice_prompt( + text.clone(), + vec!["Confirm".to_string(), "Cancel".to_string()], + )?; + log::debug!( + "{}:{:x} waiting for user input choice", + self.session.id, + self.session.request_id + ); + let choice = get_choice(&mut self.session).await?; + match choice { + 0 => { + return Ok(text); + } + _ => { + self.session + .send_notify("Input cancelled, please provide input again".to_string())?; + self.session.send_end_response()?; + continue; + } + } + } + } + + async fn wait_tool_use_choice( + &mut self, + tool_name: String, + tool_args: serde_json::Value, + ) -> anyhow::Result { + if tool_name == "AskUserQuestion" { + let tool_args = serde_json::from_value::(tool_args); + if tool_args.is_err() { + self.session.send_notify(format!( + "Claude requested to use tool `AskUserQuestion` with invalid args: {}", + tool_args.err().unwrap() + ))?; + return Ok(-1); + } else { + let tool_args = tool_args.unwrap(); + for question in tool_args.questions { + let options: Vec = question + .options + .iter() + .map(|item| format!("{}: {}", item.label, item.description)) + .collect(); + self.session.send_choice_prompt( + format!( + "{}\n{}\nPlease select one of the following options:", + question.header, question.question + ), + options, + )?; + let choice = get_choice(&mut self.session).await?; + return Ok(choice as i32); + } + Ok(-1) + } + } else { + let tool_args_string = if let serde_json::Value::String(ref s) = tool_args { + s.clone() + } else if let serde_json::Value::Object(ref map) = tool_args { + map.iter() + .map(|(k, v)| format!("{}: {}", k, v)) + .collect::>() + .join("\n") + } else { + tool_args.to_string() + }; + self.session.send_choice_prompt( + format!( + "Claude is requesting to use tool `{}` \n with args:\n{}", + tool_name, tool_args_string + ), + vec!["Confirm".to_string(), "Cancel".to_string()], + )?; + match get_choice(&mut self.session).await? { + 0 => Ok(0), + _ => Ok(-1), + } + } + } +} + +fn update_state( + cc_state: &mut cc_session::ClaudeCodeState, + new_state: cc_session::ClaudeCodeState, +) -> bool { + match (cc_state, new_state) { + ( + cc_session::ClaudeCodeState::PreUseTool { + request, + is_pending, + }, + cc_session::ClaudeCodeState::PreUseTool { + request: new_request, + is_pending: new_is_pending, + }, + ) => { + if *is_pending != new_is_pending { + *request = new_request; + *is_pending = new_is_pending; + return true; + } + + if request.len() != new_request.len() { + *request = new_request; + return true; + } + + for (r, mut nr) in request.iter_mut().zip(new_request.into_iter()) { + if r.submited && !nr.done { + log::debug!( + "Received PreUseTool state without done tool after submited\n {r:?}\n vs\n {nr:?}" + ); + return false; + } + nr.submited = r.submited; + *r = nr; + } + return true; + } + (cc_state, new_state) => { + if cc_state != &new_state { + *cc_state = new_state; + return true; + } + } + } + false +} + +async fn run_session( + id: uuid::Uuid, + url: &str, + tts_req_tx: &mut TTSRequestTx, + asr_session: &mut super::asr::AsrSession, + notify: ClaudeNotification, + mut rx: tokio::sync::mpsc::UnboundedReceiver, +) -> anyhow::Result<()> { + use cc_session::*; + + let mut cc_session = ClaudeSession::new(id.to_string(), url) + .await + .map_err(|e| anyhow::anyhow!("error creating claude session for id `{}`: {}", id, e))?; + + cc_session + .send_message(&WsInputMessage::CreateSession {}) + .await + .map_err(|e| { + anyhow::anyhow!( + "error sending create session message for id `{}`: {}", + id, + e + ) + })?; + + let session = rx.recv().await.ok_or_else(|| { + anyhow::anyhow!( + "session channel closed before receiving session for id `{}`", + id + ) + })?; + + let mut run_session_state = RunSessionState { + cc_session, + session, + rx, + notify, + }; + + loop { + log::debug!("Claude session {} waiting for events", id); + let r = run_session_state.recv().await?; + + match r { + RunSessionSelectResult::ClaudeMsg(Some(log)) => { + log::debug!("Claude session {} received message: {:?}", id, log); + match log { + WsOutputMessage::SessionPtyOutput { .. } => { + continue; + } + WsOutputMessage::SessionEnded { session_id } => { + log::warn!("Claude session {} ended by server", session_id); + return Ok(()); + } + WsOutputMessage::SessionIdle { session_id } => { + log::info!("Claude session {} is idle", session_id); + if run_session_state.cc_session.state != cc_session::ClaudeCodeState::Idle { + run_session_state.cc_session.state = cc_session::ClaudeCodeState::Idle; + } else { + continue; + } + } + WsOutputMessage::SessionState { + session_id: _, + current_state, + } => { + log::debug!( + "Claude session {} received state update: {:?}", + id, + current_state + ); + if !update_state(&mut run_session_state.cc_session.state, current_state) { + log::debug!("Claude session {} state unchanged after update", id,); + continue; + } + } + WsOutputMessage::SessionError { session_id, code } => { + log::error!( + "Claude session {} received error from server: {:?}", + session_id, + code + ); + return Err(anyhow::anyhow!( + "claude session error for id `{}`: {:?}", + id, + code + )); + } + } + + match run_session_state + .send_self_state(tts_req_tx, asr_session) + .await + { + Ok(_) => {} + Err(SendStateError::ClientError) => { + if !matches!( + run_session_state.cc_session.state, + cc_session::ClaudeCodeState::Idle + ) { + run_session_state.notify.mark(); + } else { + log::info!( + "Claude session {} client disconnected during idle state, ending session", + id + ); + } + } + Err(SendStateError::ClaudeError) => { + return Err(anyhow::anyhow!( + "claude session error for id `{}` during state send", + id + )); + } + } + } + RunSessionSelectResult::ClaudeMsg(None) => { + log::warn!("Claude session {} closed by server", id); + cc_session = ClaudeSession::new(id.to_string(), url).await.map_err(|e| { + anyhow::anyhow!("error recreating claude session for id `{}`: {}", id, e) + })?; + + cc_session + .send_message(&WsInputMessage::CreateSession {}) + .await + .map_err(|e| { + anyhow::anyhow!( + "error sending create session message for id `{}`: {}", + id, + e + ) + })?; + } + RunSessionSelectResult::ClientMsg(Some(_)) => { + let _ = run_session_state.session.send_end_vad(); + let _ = run_session_state.session.send_end_response(); + } + RunSessionSelectResult::ClientMsg(None) => { + log::warn!("Claude session {} client disconnected", id); + } + RunSessionSelectResult::Session(Some(new_session)) => { + log::info!("Claude session {} switching to new session", id); + run_session_state.set_session(new_session); + match run_session_state + .send_self_state(tts_req_tx, asr_session) + .await + { + Ok(_) => { + run_session_state.notify.clear(); + } + Err(SendStateError::ClientError) => { + run_session_state.notify.mark(); + } + Err(SendStateError::ClaudeError) => { + return Err(anyhow::anyhow!( + "claude session error for id `{}` during state send", + id + )); + } + } + } + RunSessionSelectResult::Session(None) => { + log::error!("Claude session {} session channel closed", id); + return Err(anyhow::anyhow!("session channel closed for id `{}`", id)); + } + } + } +} + +const NAMESPACE: uuid::Uuid = uuid::uuid!("8e1f6eb8-d389-4e62-9cfd-f1964e499c25"); // Namespace UUID for generating session UUIDs + +pub async fn run_session_manager( + tts: &TTSConfig, + asr: &ASRConfig, + claude: &EchokitCC, + mut session_rx: tokio::sync::mpsc::UnboundedReceiver, + notifications: Arc>, +) -> anyhow::Result<()> { + let mut tts_session_pool = super::tts::TTSSessionPool::new(tts.clone(), 4); + let (tts_req_tx, tts_req_rx) = tokio::sync::mpsc::channel(128); + + let mut sessions: HashMap< + String, + ( + tokio::sync::mpsc::UnboundedSender, + crate::services::ws::WsTx, + uuid::Uuid, + ), + > = HashMap::new(); + + tokio::spawn(async move { + if let Err(e) = tts_session_pool.run_loop(tts_req_rx).await { + log::error!("tts session pool exit by error: {}", e); + } + }); + + while let Some(session) = session_rx.recv().await { + let (session, cmd_tx, session_id) = + if let Some((tx, cmd_tx, id)) = sessions.get_mut(&session.id) { + let _ = cmd_tx.send(crate::services::ws::WsCommand::Close); + + let new_tx = session.cmd_tx.clone(); + + if let Err(e) = tx.send(session) { + (e.0, new_tx, id.clone()) + } else { + let _ = cmd_tx.send(crate::services::ws::WsCommand::Close); + *cmd_tx = new_tx; + + continue; + } + } else { + let cmd_tx = session.cmd_tx.clone(); + let id = session.id.clone(); + ( + session, + cmd_tx, + uuid::Uuid::new_v5(&NAMESPACE, id.as_bytes()), + ) + }; + + // start new session + + let notify = { + let notify = ClaudeNotification::default(); + let mut notifications_lock = notifications.write().unwrap(); + notifications_lock + .sessions + .insert(session.id.clone(), notify.clone()); + notify + }; + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + + let id = session.id.clone(); + log::info!("Starting new session for id: {}", id); + let _ = tx.send(session); + + sessions.insert(id.clone(), (tx, cmd_tx, session_id)); + + // run session + + let asr = asr.clone(); + + let mut tts_req_tx = tts_req_tx.clone(); + + let url = claude.url.clone(); + + tokio::spawn(async move { + let mut asr_session = super::asr::AsrSession::new_from_config(&asr) + .await + .map_err(|e| { + log::error!("error creating asr session for id `{}`: {}", id, e); + anyhow::anyhow!("error creating asr session for id `{}`: {}", id, e) + })?; + + if let Err(e) = run_session( + session_id, + &url, + &mut tts_req_tx, + &mut asr_session, + notify, + rx, + ) + .await + { + log::error!("session `{}` exited with error: {}", id, e); + } + + anyhow::Result::<()>::Ok(()) + }); + } + log::warn!("session manager exiting"); + Ok(()) +} + +pub async fn has_notification( + Extension(sessions): Extension>>, + Path(id): Path, +) -> impl IntoResponse { + let state = { + let sessions_lock = sessions.read().unwrap(); + sessions_lock.sessions.get(&id).map_or(false, |n| n.value()) + }; + + axum::Json(serde_json::json!({ "has_notification": state })) +} + +mod cc_session { + use futures_util::{SinkExt, StreamExt}; + use reqwest_websocket::{RequestBuilderExt, WebSocket}; + + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] + #[serde(tag = "type")] + pub enum WsInputMessage { + #[serde(alias = "create_session")] + CreateSession {}, + #[serde(alias = "get_current_state")] + CurrentState {}, + #[serde(alias = "input")] + Input { input: String }, + #[serde(alias = "cancel")] + Cancel {}, + #[serde(alias = "confirm")] + Confirm {}, + #[serde(alias = "select")] + Select { index: usize }, + } + + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] + #[serde(tag = "error_code")] + pub enum WsOutputError { + #[serde(rename = "session_not_found")] + SessionNotFound, + #[serde(rename = "invalid_input")] + InvalidInput { + error_message: String, + }, + #[serde(rename = "invalid_input_for_state")] + InvalidInputForState { + error_state: String, + error_input: String, + }, + InternalError { + error_message: String, + }, + } + + #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] + pub struct UseTool { + pub id: String, + pub name: String, + pub input: serde_json::Value, + pub done: bool, + #[serde(default)] + pub submited: bool, + } + + #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] + #[serde(tag = "state")] + pub enum ClaudeCodeState { + PreUseTool { + request: Vec, + is_pending: bool, + }, + Output { + output: String, + is_thinking: bool, + }, + StopUseTool { + is_error: bool, + }, + Idle, + } + + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] + #[serde(tag = "type")] + pub enum WsOutputMessage { + #[serde(rename = "session_pty_output")] + SessionPtyOutput { output: String }, + #[serde(rename = "session_ended")] + SessionEnded { session_id: String }, + #[serde(rename = "session_idle")] + SessionIdle { session_id: String }, + #[serde(rename = "session_state")] + SessionState { + session_id: String, + current_state: ClaudeCodeState, + }, + #[serde(rename = "session_error")] + SessionError { + session_id: String, + #[serde(flatten)] + code: WsOutputError, + }, + } + + pub struct ClaudeSession { + pub id: String, + pub socket: WebSocket, + pub state: ClaudeCodeState, + pub last_output: String, + } + + impl ClaudeSession { + pub async fn new(id: String, url: &str) -> anyhow::Result { + let url = format!("{}/{}", url.trim_end_matches('/'), id); + log::info!("Connecting to Claude WebSocket at {}", url); + + let client = reqwest::Client::new(); + let response = client.get(url).upgrade().send().await?; + + let websocket = response.into_websocket().await?; + + Ok(Self { + id: id.to_string(), + socket: websocket, + state: ClaudeCodeState::Output { + output: String::new(), + is_thinking: false, + }, + last_output: String::new(), + }) + } + + pub async fn send_message(&mut self, message: &WsInputMessage) -> anyhow::Result<()> { + let msg_text = serde_json::to_string(message)?; + self.socket + .send(reqwest_websocket::Message::Text(msg_text)) + .await?; + Ok(()) + } + + pub async fn receive_message(&mut self) -> anyhow::Result> { + loop { + let msg = self.socket.next().await; + + if msg.is_none() { + return Ok(None); + } + + let msg = msg.unwrap()?; + + match msg { + reqwest_websocket::Message::Text(s) => { + let output_msg: WsOutputMessage = serde_json::from_str(&s)?; + if matches!(output_msg, WsOutputMessage::SessionPtyOutput { .. }) { + log::trace!( + "Received pty output message for session {}: {:?}", + self.id, + output_msg + ); + continue; + } + return Ok(Some(output_msg)); + } + reqwest_websocket::Message::Binary(bytes) => { + log::warn!( + "Received unexpected binary message for session {}: {} bytes", + self.id, + bytes.len() + ); + } + reqwest_websocket::Message::Close { code, reason } => { + log::info!( + "WebSocket closed for session {}: code={:?}, reason={}", + self.id, + code, + reason + ); + return Ok(None); + } + _ => {} + } + } + } + } +} diff --git a/src/services/ws/stable/gemini.rs b/src/services/ws/stable/gemini.rs index 7b89b77..1e30857 100644 --- a/src/services/ws/stable/gemini.rs +++ b/src/services/ws/stable/gemini.rs @@ -265,6 +265,7 @@ async fn run_session( .send_realtime_input(gemini::types::RealtimeInput::Text(input)) .await?; } + ClientMsg::Select(_) => {} }, GeminiEvent::ServerEvent(server_content) => match server_content { gemini::types::ServerContent::ModelTurn(turn) => { @@ -437,6 +438,7 @@ async fn run_session_with_tts( .send_realtime_input(gemini::types::RealtimeInput::Text(input)) .await?; } + ClientMsg::Select(_) => {} }, GeminiEvent::ServerEvent(server_content) => match server_content { gemini::types::ServerContent::ModelTurn(turn) => { diff --git a/src/services/ws/stable/llm.rs b/src/services/ws/stable/llm.rs index 66614de..be70f3a 100644 --- a/src/services/ws/stable/llm.rs +++ b/src/services/ws/stable/llm.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use lazy_regex::regex; use crate::ai::{ - llm::Content, ChatSession, LLMResponsesChunk, ResponsesSession, StableLLMResponseChunk, + ChatSession, LLMResponsesChunk, ResponsesSession, StableLLMResponseChunk, llm::Content, }; pub type ChunksTx = tokio::sync::mpsc::UnboundedSender<(String, super::tts::TTSResponseRx)>; diff --git a/src/services/ws/stable/mod.rs b/src/services/ws/stable/mod.rs index 9e6d547..0b01237 100644 --- a/src/services/ws/stable/mod.rs +++ b/src/services/ws/stable/mod.rs @@ -19,6 +19,7 @@ use crate::{ }; mod asr; +pub mod claude; pub mod gemini; mod llm; mod tts; @@ -135,6 +136,52 @@ impl Session { }) } + pub fn send_choice_prompt(&self, message: String, choices: Vec) -> anyhow::Result<()> { + self.cmd_tx + .send(super::WsCommand::Choices(message, choices)) + .map_err(|_| { + anyhow::anyhow!( + "{}:{:x} error sending choice prompt ws command", + self.id, + self.request_id + ) + }) + } + + pub fn send_notify(&self, message: String) -> anyhow::Result<()> { + self.cmd_tx + .send(super::WsCommand::Action { action: message }) + .map_err(|_| { + anyhow::anyhow!( + "{}:{:x} error sending notify ws command", + self.id, + self.request_id + ) + }) + } + + pub fn send_display_text(&self, text: String) -> anyhow::Result<()> { + self.cmd_tx + .send(super::WsCommand::DisplayText(text)) + .map_err(|_| { + anyhow::anyhow!( + "{}:{:x} error sending display text ws command", + self.id, + self.request_id + ) + }) + } + + pub fn send_end_audio(&self) -> anyhow::Result<()> { + self.cmd_tx.send(super::WsCommand::EndAudio).map_err(|_| { + anyhow::anyhow!( + "{}:{:x} error sending end audio ws command", + self.id, + self.request_id + ) + }) + } + pub fn send_end_response(&self) -> anyhow::Result<()> { self.cmd_tx .send(super::WsCommand::EndResponse) @@ -273,17 +320,14 @@ async fn handle_tts_requests(mut chunks_rx: ChunksRx, session: &mut Session) -> chunk ); - session - .cmd_tx - .send(super::WsCommand::StartAudio(chunk.clone())) - .map_err(|_| { - anyhow::anyhow!( - "{}:{:x} error sending start audio ws command for chunk `{}`", - session.id, - session.request_id, - chunk - ) - })?; + session.send_start_audio(chunk.clone()).map_err(|_| { + anyhow::anyhow!( + "{}:{:x} error sending start audio ws command for chunk `{}`", + session.id, + session.request_id, + chunk + ) + })?; while let Some(tts_chunk) = tts_resp_rx.recv().await { log::trace!( @@ -297,28 +341,22 @@ async fn handle_tts_requests(mut chunks_rx: ChunksRx, session: &mut Session) -> continue; } - session - .cmd_tx - .send(super::WsCommand::Audio(tts_chunk)) - .map_err(|_| { - anyhow::anyhow!( - "{}:{:x} error sending audio chunk ws command for tts chunk", - session.id, - session.request_id - ) - })?; - } - - session - .cmd_tx - .send(super::WsCommand::EndAudio) - .map_err(|_| { + session.send_audio_chunk(tts_chunk.clone()).map_err(|_| { anyhow::anyhow!( - "{}:{:x} error sending end audio ws command after tts chunk", + "{}:{:x} error sending audio chunk ws command for tts chunk", session.id, session.request_id ) })?; + } + + session.send_end_audio().map_err(|_| { + anyhow::anyhow!( + "{}:{:x} error sending end audio ws command after tts chunk", + session.id, + session.request_id + ) + })?; log::info!( "{}:{:x} finished tts for chunk: {}",