Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 53 additions & 3 deletions crates/ai/src/acp/bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use super::{
config::AgentRegistry,
process::{stop_child_tree, terminate_process_group},
types::{
AcpAgentCapabilities, AcpAgentStatus, AcpEvent, AcpSessionInfo, AcpSessionList, AgentConfig,
SessionConfigOption,
AcpAgentCapabilities, AcpAgentStatus, AcpErrorKind, AcpEvent, AcpSessionInfo, AcpSessionList,
AgentConfig, SessionConfigOption,
},
};
use crate::runtime::AthasAppHandle as AppHandle;
Expand All @@ -24,6 +24,24 @@ use tokio::{
task::LocalSet,
};

fn classify_acp_error(error: &str) -> Option<AcpErrorKind> {
let lower = error.to_lowercase();

if lower.contains("authentication required") {
return Some(AcpErrorKind::AuthenticationRequired);
}

let requires_provider_setup = lower.contains("no api key found")
|| lower.contains("missing api key")
|| (lower.contains("api key") && lower.contains("required"))
|| lower.contains("environment variable")
|| lower.contains("--setup")
|| lower.contains("not logged in")
|| lower.contains("login required");

requires_provider_setup.then_some(AcpErrorKind::ProviderSetupRequired)
}

/// Worker state running on the LocalSet thread
pub(super) struct AcpWorker {
connection: Option<Arc<acp::ClientSideConnection>>,
Expand Down Expand Up @@ -68,6 +86,7 @@ impl AcpWorker {
AcpEvent::Error {
session_id: session_id.clone(),
error: format!("ACP agent process exited: {}", status),
error_kind: None,
},
);
let _ = app_handle.emit(
Expand Down Expand Up @@ -188,11 +207,13 @@ impl AcpWorker {
.await
{
log::error!("Failed to run ACP prompt: {}", err);
let error = format!("Failed to run prompt: {}", err);
let _ = app_handle.emit(
"acp-event",
AcpEvent::Error {
session_id: Some(session_id.to_string()),
error: format!("Failed to run prompt: {}", err),
error_kind: classify_acp_error(&error),
error,
},
);
}
Expand Down Expand Up @@ -641,3 +662,32 @@ impl AcpAgentBridge {
);
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn classifies_provider_setup_errors() {
let error =
"Failed to run prompt: No API key found. Set Z_AI_API_KEY or run glm-acp-agent --setup";

assert_eq!(
classify_acp_error(error),
Some(AcpErrorKind::ProviderSetupRequired)
);
}

#[test]
fn classifies_authentication_required_errors() {
assert_eq!(
classify_acp_error("Authentication required before sending prompt"),
Some(AcpErrorKind::AuthenticationRequired)
);
}

#[test]
fn leaves_plain_runtime_errors_unclassified() {
assert_eq!(classify_acp_error("ACP agent process exited: 1"), None);
}
}
133 changes: 130 additions & 3 deletions crates/ai/src/acp/bridge_init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ pub(super) async fn initialize_worker(
SessionBootstrapContext {
auth_methods,
supports_session_resume,
default_mode: config.default_mode.clone(),
default_model: config.default_model.clone(),
map_config_options,
child: &mut child,
io_handle: &io_handle,
Expand Down Expand Up @@ -143,6 +145,8 @@ where
{
auth_methods: Vec<acp::AuthMethod>,
supports_session_resume: bool,
default_mode: Option<String>,
default_model: Option<String>,
map_config_options: F,
child: &'a mut Child,
io_handle: &'a tokio::task::JoinHandle<()>,
Expand Down Expand Up @@ -340,6 +344,14 @@ async fn bootstrap_session(

match load_result {
Ok(Ok(load_response)) => {
apply_session_defaults(
connection.clone(),
acp::SessionId::new(existing_session_id.clone()),
ctx.default_mode.as_deref(),
ctx.default_model.as_deref(),
load_response.config_options.as_ref(),
)
.await;
log::info!("ACP session loaded: {}", existing_session_id);
client.set_session_id(existing_session_id.clone()).await;
return Ok(SessionBootstrap {
Expand Down Expand Up @@ -375,6 +387,14 @@ async fn bootstrap_session(

match resume_result {
Ok(Ok(resume_response)) => {
apply_session_defaults(
connection.clone(),
acp::SessionId::new(existing_session_id.clone()),
ctx.default_mode.as_deref(),
ctx.default_model.as_deref(),
resume_response.config_options.as_ref(),
)
.await;
log::info!("ACP session resumed: {}", existing_session_id);
client.set_session_id(existing_session_id.clone()).await;
return Ok(SessionBootstrap {
Expand Down Expand Up @@ -478,6 +498,14 @@ async fn bootstrap_session(
};

log::info!("ACP session created: {}", session.session_id);
apply_session_defaults(
connection.clone(),
session.session_id.clone(),
ctx.default_mode.as_deref(),
ctx.default_model.as_deref(),
session.config_options.as_ref(),
)
.await;
client.set_session_id(session.session_id.to_string()).await;

Ok(SessionBootstrap {
Expand All @@ -491,7 +519,7 @@ async fn create_session(
connection: Arc<acp::ClientSideConnection>,
cwd: PathBuf,
) -> Result<Result<acp::NewSessionResponse, acp::Error>, tokio::time::error::Elapsed> {
let session_request = acp::NewSessionRequest::new(cwd);
let session_request = new_session_request(cwd);
tokio::time::timeout(
std::time::Duration::from_secs(30),
connection.new_session(session_request),
Expand All @@ -504,7 +532,7 @@ async fn load_session(
cwd: PathBuf,
existing_session_id: String,
) -> Result<Result<acp::LoadSessionResponse, acp::Error>, tokio::time::error::Elapsed> {
let request = acp::LoadSessionRequest::new(existing_session_id, cwd);
let request = load_session_request(existing_session_id, cwd);
tokio::time::timeout(
std::time::Duration::from_secs(30),
connection.load_session(request),
Expand All @@ -517,14 +545,82 @@ async fn resume_session(
cwd: PathBuf,
existing_session_id: String,
) -> Result<Result<acp::ResumeSessionResponse, acp::Error>, tokio::time::error::Elapsed> {
let request = acp::ResumeSessionRequest::new(existing_session_id, cwd);
let request = resume_session_request(existing_session_id, cwd);
tokio::time::timeout(
std::time::Duration::from_secs(30),
connection.resume_session(request),
)
.await
}

fn new_session_request(cwd: PathBuf) -> acp::NewSessionRequest {
acp::NewSessionRequest::new(cwd)
}

fn load_session_request(existing_session_id: String, cwd: PathBuf) -> acp::LoadSessionRequest {
acp::LoadSessionRequest::new(existing_session_id, cwd)
}

fn resume_session_request(existing_session_id: String, cwd: PathBuf) -> acp::ResumeSessionRequest {
acp::ResumeSessionRequest::new(existing_session_id, cwd)
}

async fn apply_session_defaults(
connection: Arc<acp::ClientSideConnection>,
session_id: acp::SessionId,
default_mode: Option<&str>,
default_model: Option<&str>,
config_options: Option<&Vec<acp::SessionConfigOption>>,
) {
if let Some(mode_id) = default_mode.filter(|mode| !mode.trim().is_empty())
&& let Err(error) = connection
.set_session_mode(acp::SetSessionModeRequest::new(
session_id.clone(),
mode_id.to_string(),
))
.await
{
log::warn!("Failed to apply ACP default mode '{}': {}", mode_id, error);
}

let Some(model_id) = default_model.filter(|model| !model.trim().is_empty()) else {
return;
};
let Some(config_id) = model_config_option_id(config_options) else {
log::debug!(
"ACP default model '{}' configured, but the agent did not expose a model config option",
model_id
);
return;
};

if let Err(error) = connection
.set_session_config_option(acp::SetSessionConfigOptionRequest::new(
session_id,
config_id,
model_id.to_string(),
))
.await
{
log::warn!(
"Failed to apply ACP default model '{}': {}",
model_id,
error
);
}
}

fn model_config_option_id(
config_options: Option<&Vec<acp::SessionConfigOption>>,
) -> Option<String> {
config_options?
.iter()
.find(|option| {
option.id.to_string() == "model" || option.category.as_deref() == Some("model")
})
.map(|option| option.id.to_string())
}

fn map_mode_state(modes: acp::SessionModeState) -> SessionModeState {
SessionModeState {
current_mode_id: Some(modes.current_mode_id.to_string()),
Expand Down Expand Up @@ -570,3 +666,34 @@ fn emit_initial_session_state(
log::warn!("Failed to emit initial session config options: {}", e);
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn new_session_request_sets_cwd() {
let request = new_session_request(PathBuf::from("/repo"));

assert_eq!(request.cwd, PathBuf::from("/repo"));
assert!(request.mcp_servers.is_empty());
}

#[test]
fn load_session_request_sets_session_and_cwd() {
let request = load_session_request("session-1".to_string(), PathBuf::from("/repo"));

assert_eq!(request.session_id, acp::SessionId::new("session-1"));
assert_eq!(request.cwd, PathBuf::from("/repo"));
assert!(request.mcp_servers.is_empty());
}

#[test]
fn resume_session_request_sets_session_and_cwd() {
let request = resume_session_request("session-1".to_string(), PathBuf::from("/repo"));

assert_eq!(request.session_id, acp::SessionId::new("session-1"));
assert_eq!(request.cwd, PathBuf::from("/repo"));
assert!(request.mcp_servers.is_empty());
}
}
8 changes: 8 additions & 0 deletions crates/ai/src/acp/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ impl AgentRegistry {
continue;
}

if let Some(path) = config.binary_path.as_ref().map(PathBuf::from)
&& path.is_file()
{
config.installed = true;
config.binary_path = Some(path.to_string_lossy().to_string());
continue;
}

if let Some(path) = find_binary(&config.binary_name) {
config.installed = true;
config.binary_path = Some(path.to_string_lossy().to_string());
Expand Down
12 changes: 12 additions & 0 deletions crates/ai/src/acp/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ pub struct AgentConfig {
pub binary_path: Option<String>,
pub args: Vec<String>,
pub env_vars: HashMap<String, String>,
pub default_mode: Option<String>,
pub default_model: Option<String>,
pub icon: Option<String>,
pub description: Option<String>,
pub installed: bool,
Expand All @@ -224,6 +226,8 @@ impl AgentConfig {
binary_path: None,
args: Vec::new(),
env_vars: HashMap::new(),
default_mode: None,
default_model: None,
icon: None,
description: None,
installed: false,
Expand Down Expand Up @@ -370,6 +374,13 @@ pub struct AcpSessionList {
pub next_cursor: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum AcpErrorKind {
AuthenticationRequired,
ProviderSetupRequired,
}

/// Events emitted to the frontend via Tauri
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
Expand Down Expand Up @@ -445,6 +456,7 @@ pub enum AcpEvent {
Error {
session_id: Option<String>,
error: String,
error_kind: Option<AcpErrorKind>,
},
/// Agent status changed
#[serde(rename_all = "camelCase")]
Expand Down
Loading
Loading