diff --git a/Cargo.lock b/Cargo.lock index 3fcc3ee56f0d..6ce7a570c9ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4706,6 +4706,7 @@ dependencies = [ "futures", "goose", "goose-mcp", + "goose-providers", "indicatif", "open", "rand 0.8.6", @@ -4769,15 +4770,21 @@ name = "goose-providers" version = "1.37.0" dependencies = [ "anyhow", + "async-stream", "base64 0.22.1", "chrono", + "futures", "once_cell", "regex", + "reqwest 0.13.4", "rmcp", "serde", "serde_json", + "tempfile", "test-case", "thiserror 1.0.69", + "tokio", + "tokio-stream", "tracing", "unicode-normalization", "utoipa 4.2.3", @@ -4825,6 +4832,7 @@ dependencies = [ "futures", "goose", "goose-mcp", + "goose-providers", "hex", "http 1.4.1", "openssl", diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index 15e2c5304809..2f33e21b1148 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -23,6 +23,7 @@ path = "src/bin/generate_manpages.rs" clap_mangen = { version = "0.3", default-features = false } goose = { path = "../goose", default-features = false } goose-mcp = { path = "../goose-mcp", default-features = false } +goose-providers = { path = "../goose-providers", default-features = false } rmcp = { workspace = true } clap = { workspace = true } cliclack = { version = "0.5", default-features = false } diff --git a/crates/goose-cli/src/commands/info.rs b/crates/goose-cli/src/commands/info.rs index d702b3e43eb9..821e7acbaf18 100644 --- a/crates/goose-cli/src/commands/info.rs +++ b/crates/goose-cli/src/commands/info.rs @@ -3,8 +3,8 @@ use console::style; use goose::config::paths::Paths; use goose::config::Config; use goose::conversation::message::Message; -use goose::providers::errors::ProviderError; use goose::session::session_manager::{DB_NAME, SESSIONS_FOLDER}; +use goose_providers::errors::ProviderError; use serde_yaml; use std::time::Duration; diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index cd7e63b74098..67c06704d4d4 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -2091,11 +2091,11 @@ fn handle_agent_error(e: &anyhow::Error, is_stream_json_mode: bool) { }); } - if e.downcast_ref::() + if e.downcast_ref::() .map(|provider_error| { matches!( provider_error, - goose::providers::errors::ProviderError::ContextLengthExceeded(_) + goose_providers::errors::ProviderError::ContextLengthExceeded(_) ) }) .unwrap_or(false) @@ -2366,7 +2366,7 @@ mod tests { assert_eq!(current.model_name, "gpt-5.4"); assert_eq!( current.thinking_effort(), - Some(goose::model::ThinkingEffort::High) + Some(goose_providers::thinking::ThinkingEffort::High) ); let switched = build_switched_model_config("openai", "gpt-5.4", ¤t).unwrap(); diff --git a/crates/goose-providers/Cargo.toml b/crates/goose-providers/Cargo.toml index e51dea54d0bd..267d2b91c8b2 100644 --- a/crates/goose-providers/Cargo.toml +++ b/crates/goose-providers/Cargo.toml @@ -13,11 +13,14 @@ workspace = true [dependencies] anyhow = { workspace = true } +async-stream = { workspace = true } base64 = { workspace = true } chrono = { workspace = true } +futures = { workspace = true } once_cell = { workspace = true } -regex = { workspace = true } -rmcp = { workspace = true, features = ["server"] } +regex = { workspace = true, features = ["unicode"] } +reqwest = { workspace = true } +rmcp = { workspace = true, features = ["server", "macros"] } serde = { workspace = true } serde_json = { workspace = true } thiserror = { workspace = true } @@ -28,3 +31,6 @@ uuid = { workspace = true, features = ["v4", "std"] } [dev-dependencies] test-case = { workspace = true } +tempfile = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } diff --git a/crates/goose-providers/src/base.rs b/crates/goose-providers/src/base.rs new file mode 100644 index 000000000000..e063ddebc9e4 --- /dev/null +++ b/crates/goose-providers/src/base.rs @@ -0,0 +1,17 @@ +use std::future::Future; + +pub struct Error; + +pub struct Model { + pub name: String, +} + +pub struct StreamingRequest { + pub model: Model, +} + +pub struct StreamingResponse; + +pub trait Provider { + fn stream(req: StreamingRequest) -> impl Future>; +} diff --git a/crates/goose-providers/src/canonical/mod.rs b/crates/goose-providers/src/canonical.rs similarity index 100% rename from crates/goose-providers/src/canonical/mod.rs rename to crates/goose-providers/src/canonical.rs diff --git a/crates/goose-providers/src/conversation/mod.rs b/crates/goose-providers/src/conversation.rs similarity index 99% rename from crates/goose-providers/src/conversation/mod.rs rename to crates/goose-providers/src/conversation.rs index 31d3f52d61b3..4f8c92197bcd 100644 --- a/crates/goose-providers/src/conversation/mod.rs +++ b/crates/goose-providers/src/conversation.rs @@ -7,6 +7,7 @@ use thiserror::Error; use utoipa::ToSchema; pub mod message; +pub mod token_usage; mod tool_result_serde; #[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq)] diff --git a/crates/goose-providers/src/conversation/token_usage.rs b/crates/goose-providers/src/conversation/token_usage.rs new file mode 100644 index 000000000000..b0419be385ea --- /dev/null +++ b/crates/goose-providers/src/conversation/token_usage.rs @@ -0,0 +1,148 @@ +use std::ops::{Add, AddAssign}; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProviderUsage { + pub model: String, + pub usage: Usage, +} + +impl ProviderUsage { + pub fn new(model: String, usage: Usage) -> Self { + Self { model, usage } + } + + /// Combine this ProviderUsage with another, adding their token counts + /// Uses the model from this ProviderUsage + pub fn combine_with(&self, other: &ProviderUsage) -> ProviderUsage { + ProviderUsage { + model: self.model.clone(), + usage: self.usage + other.usage, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, Copy)] +pub struct Usage { + pub input_tokens: Option, + pub output_tokens: Option, + pub total_tokens: Option, + pub cache_read_input_tokens: Option, + pub cache_write_input_tokens: Option, +} + +fn sum_optionals(a: Option, b: Option) -> Option +where + T: Add + Default, +{ + match (a, b) { + (Some(x), Some(y)) => Some(x + y), + (Some(x), None) => Some(x + T::default()), + (None, Some(y)) => Some(T::default() + y), + (None, None) => None, + } +} + +impl Add for Usage { + type Output = Self; + + fn add(self, other: Self) -> Self { + Self::new( + sum_optionals(self.input_tokens, other.input_tokens), + sum_optionals(self.output_tokens, other.output_tokens), + sum_optionals(self.total_tokens, other.total_tokens), + ) + .with_cache_tokens( + sum_optionals(self.cache_read_input_tokens, other.cache_read_input_tokens), + sum_optionals( + self.cache_write_input_tokens, + other.cache_write_input_tokens, + ), + ) + } +} + +impl AddAssign for Usage { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl Usage { + pub fn new( + input_tokens: Option, + output_tokens: Option, + total_tokens: Option, + ) -> Self { + let calculated_total = if total_tokens.is_none() { + match (input_tokens, output_tokens) { + (Some(input), Some(output)) => Some(input + output), + (Some(input), None) => Some(input), + (None, Some(output)) => Some(output), + (None, None) => None, + } + } else { + total_tokens + }; + + Self { + input_tokens, + output_tokens, + total_tokens: calculated_total, + cache_read_input_tokens: None, + cache_write_input_tokens: None, + } + } + + pub fn with_cache_tokens( + mut self, + cache_read_input_tokens: Option, + cache_write_input_tokens: Option, + ) -> Self { + self.cache_read_input_tokens = cache_read_input_tokens; + self.cache_write_input_tokens = cache_write_input_tokens; + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use anyhow::Result; + use serde_json::json; + + #[test] + fn test_usage_serialization() -> Result<()> { + let usage = Usage::new(Some(10), Some(20), Some(30)); + let serialized = serde_json::to_string(&usage)?; + let deserialized: Usage = serde_json::from_str(&serialized)?; + + assert_eq!(usage.input_tokens, deserialized.input_tokens); + assert_eq!(usage.output_tokens, deserialized.output_tokens); + assert_eq!(usage.total_tokens, deserialized.total_tokens); + + // Test JSON structure + let json_value: serde_json::Value = serde_json::from_str(&serialized)?; + assert_eq!(json_value["input_tokens"], json!(10)); + assert_eq!(json_value["output_tokens"], json!(20)); + assert_eq!(json_value["total_tokens"], json!(30)); + + Ok(()) + } + + #[test] + fn test_usage_addition_includes_cached_tokens() { + let usage_a = + Usage::new(Some(100), Some(20), Some(120)).with_cache_tokens(Some(10), Some(5)); + let usage_b = Usage::new(Some(50), Some(8), Some(58)).with_cache_tokens(Some(4), Some(1)); + + let combined = usage_a + usage_b; + + assert_eq!(combined.input_tokens, Some(150)); + assert_eq!(combined.output_tokens, Some(28)); + assert_eq!(combined.total_tokens, Some(178)); + assert_eq!(combined.cache_read_input_tokens, Some(14)); + assert_eq!(combined.cache_write_input_tokens, Some(6)); + } +} diff --git a/crates/goose/src/providers/errors.rs b/crates/goose-providers/src/errors.rs similarity index 100% rename from crates/goose/src/providers/errors.rs rename to crates/goose-providers/src/errors.rs diff --git a/crates/goose-providers/src/formats.rs b/crates/goose-providers/src/formats.rs new file mode 100644 index 000000000000..d8c308735bfe --- /dev/null +++ b/crates/goose-providers/src/formats.rs @@ -0,0 +1 @@ +pub mod openai; diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose-providers/src/formats/openai.rs similarity index 93% rename from crates/goose/src/providers/formats/openai.rs rename to crates/goose-providers/src/formats/openai.rs index b9d7ba285d58..1ae1acb34159 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose-providers/src/formats/openai.rs @@ -1,17 +1,17 @@ use crate::conversation::message::{Message, MessageContent, ProviderMetadata}; +use crate::conversation::token_usage::{ProviderUsage, Usage}; +use crate::errors::ProviderError; +use crate::images::{convert_image, detect_image_path, load_image_file, ImageFormat}; +use crate::json::safely_parse_json; use crate::mcp_utils::extract_text_from_resource; -use crate::model::ModelConfig; -use crate::providers::base::{split_think_blocks, ProviderUsage, ThinkFilter, Usage}; -use crate::providers::errors::ProviderError; -use crate::providers::utils::{ - convert_image, detect_image_path, extract_reasoning_effort, is_openai_responses_model, - is_valid_function_name, load_image_file, openai_reasoning_effort_for_thinking, - safely_parse_json, sanitize_function_name, ImageFormat, +use crate::thinking::{ + split_think_blocks, ThinkFilter, ThinkingEffort, GEMINI_THOUGHT_SIGNATURE_KEY, }; use anyhow::{anyhow, Error}; use async_stream::try_stream; use chrono; use futures::Stream; +use regex::Regex; use rmcp::model::{ object, AnnotateAble, CallToolRequestParams, Content, ErrorCode, ErrorData, RawContent, Role, Tool, @@ -21,6 +21,7 @@ use serde_json::{json, Value}; use std::borrow::Cow; use std::collections::HashMap; use std::ops::Deref; +use std::sync::OnceLock; type ToolCallData = HashMap< i32, @@ -1085,7 +1086,7 @@ where let metadata = if let Some(sig) = &last_signature { let mut combined = extra_fields.clone().unwrap_or_default(); combined.insert( - crate::providers::formats::google::THOUGHT_SIGNATURE_KEY.to_string(), + GEMINI_THOUGHT_SIGNATURE_KEY.to_string(), json!(sig) ); Some(combined) @@ -1218,7 +1219,7 @@ where } pub fn create_request( - model_config: &ModelConfig, + model_config: ModelConfigParams, system: &str, messages: &[Message], tools: &[Tool], @@ -1238,8 +1239,16 @@ pub fn create_request( ) } +pub struct ModelConfigParams<'a> { + pub model_name: &'a str, + pub thinking_effort: Option, + pub temperature: Option, + pub max_tokens: Option, + pub request_params: Option<&'a HashMap>, +} + pub fn create_request_with_options( - model_config: &ModelConfig, + model_config: ModelConfigParams, system: &str, messages: &[Message], tools: &[Tool], @@ -1253,11 +1262,11 @@ pub fn create_request_with_options( )); } - let (model_name, legacy_reasoning_effort) = extract_reasoning_effort(&model_config.model_name); + let (model_name, legacy_reasoning_effort) = extract_reasoning_effort(model_config.model_name); let is_reasoning_model = is_openai_responses_model(&model_name); let reasoning_effort = if is_reasoning_model { model_config - .thinking_effort() + .thinking_effort .map_or(legacy_reasoning_effort, |effort| { openai_reasoning_effort_for_thinking(&model_name, effort) }) @@ -1319,7 +1328,7 @@ pub fn create_request_with_options( payload["stream_options"] = json!({"include_usage": true}); } - if let Some(params) = &model_config.request_params { + if let Some(params) = model_config.request_params { if let Some(obj) = payload.as_object_mut() { for (key, value) in params { if key != "thinking_effort" && !is_reserved_request_param_key(key) { @@ -1332,6 +1341,105 @@ pub fn create_request_with_options( Ok(payload) } +/// Extract an explicit reasoning-effort suffix from a model name. +/// +/// Returns `(base_model_name, Some(effort))` when the user appended a +/// recognised suffix like `-high` or `-xhigh`, e.g. `gpt-5.4-high` → +/// `("gpt-5.4", Some("high"))`. +/// +/// When no suffix is present the effort is `None` — callers should omit +/// the `reasoning` field entirely so the API applies its own per-model +/// default. This avoids hard-coding a default that may be invalid for +/// certain models (e.g. `gpt-5-pro` only accepts `high`; older o-series +/// models reject `none` and `xhigh`). +pub fn extract_reasoning_effort(model_name: &str) -> (String, Option) { + if !is_openai_responses_model(model_name) { + return (model_name.to_string(), None); + } + + static RE: OnceLock = OnceLock::new(); + let re = RE.get_or_init(|| { + Regex::new(r"(?i)^(?P.+)-(?Pnone|low|medium|high|xhigh)$").unwrap() + }); + + if let Some(captures) = re.captures(model_name) { + let base = captures["base"].to_string(); + let effort = captures["effort"].to_ascii_lowercase(); + return (base, Some(effort)); + } + + (model_name.to_string(), None) +} + +/// True when the model should use the OpenAI Responses API. +/// +/// The Responses API is backwards-compatible with all OpenAI reasoning +/// models, so every `o`-series (`o1`, `o3`, `o4`, …) and `gpt-5` variant +/// routes here. The matcher intentionally scans the full model identifier so +/// hosted aliases like `databricks-gpt-5.4`, `goose-o3-mini`, or +/// `headless-goose-o3-mini` work without provider-specific normalization. +pub fn is_openai_responses_model(model_name: &str) -> bool { + static RE: OnceLock = OnceLock::new(); + let re = + RE.get_or_init(|| Regex::new(r"(?i)(?:^|[-/])(?:o\d+(?:$|-)|gpt-5(?:$|[-.]))").unwrap()); + re.is_match(model_name) +} + +pub fn openai_reasoning_effort_for_thinking( + model_name: &str, + effort: ThinkingEffort, +) -> Option { + if effort == ThinkingEffort::Off { + return Some("none".to_string()); + } + + let supported = openai_reasoning_efforts_for_model(model_name); + let preferred: &[&str] = match effort { + ThinkingEffort::Off => unreachable!(), + ThinkingEffort::Low => &["low", "medium", "high", "xhigh"], + ThinkingEffort::Medium => &["medium", "high", "low", "xhigh"], + ThinkingEffort::High => &["high", "medium", "xhigh", "low"], + ThinkingEffort::Max => &["xhigh", "high", "medium", "low"], + }; + + preferred + .iter() + .find(|level| supported.contains(level)) + .map(|level| (*level).to_string()) +} + +fn openai_reasoning_efforts_for_model(model_name: &str) -> &'static [&'static str] { + let normalized = model_name.to_ascii_lowercase(); + + if normalized.contains("gpt-5") { + if normalized.contains("-pro") || normalized.contains("/pro") { + &["high"] + } else if normalized.contains("gpt-5.4") + || normalized.contains("gpt-5-4") + || normalized.contains("gpt-5.5") + || normalized.contains("gpt-5-5") + { + &["low", "medium", "high", "xhigh"] + } else { + &["low", "medium", "high"] + } + } else { + &["low", "medium", "high"] + } +} + +pub fn sanitize_function_name(name: &str) -> String { + static RE: OnceLock = OnceLock::new(); + let re = RE.get_or_init(|| Regex::new(r"[^a-zA-Z0-9_-]").unwrap()); + re.replace_all(name, "_").to_string() +} + +pub fn is_valid_function_name(name: &str) -> bool { + static RE: OnceLock = OnceLock::new(); + let re = RE.get_or_init(|| Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap()); + re.is_match(name) +} + #[cfg(test)] mod tests { use super::*; @@ -1471,31 +1579,6 @@ mod tests { let timeout_schema = &tools[0]["function"]["parameters"]["properties"]["timeout_secs"]; assert_eq!(timeout_schema["type"], "integer"); assert!(!timeout_schema["type"].is_array()); - - // Test case 5: Verify the actual ShellParams schema is compatible (no anyOf for timeout_secs) - use crate::agents::platform_extensions::developer::shell::ShellParams; - use schemars::schema_for; - let schema_value = serde_json::to_value(schema_for!(ShellParams)).unwrap(); - let schema_obj = schema_value.as_object().unwrap().clone(); - let tool = rmcp::model::Tool::new("shell", "run shell", schema_obj); - let mut tools = vec![json!({ - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.input_schema, - } - })]; - validate_tool_schemas(&mut tools); - let timeout = &tools[0]["function"]["parameters"]["properties"]["timeout_secs"]; - assert!( - timeout.get("anyOf").is_none(), - "timeout_secs should not have anyOf after validation, got: {timeout}" - ); - assert_eq!( - timeout["type"], "integer", - "timeout_secs should have type=integer" - ); } const OPENAI_TOOL_USE_RESPONSE: &str = r#"{ @@ -2051,19 +2134,15 @@ mod tests { #[test] fn test_create_request_gpt_4o() -> anyhow::Result<()> { // Test default medium reasoning effort for O3 model - let model_config = ModelConfig { - model_name: "gpt-4o".to_string(), - context_limit: Some(4096), + let model_config = ModelConfigParams { + model_name: "gpt-4o", + thinking_effort: None, temperature: None, max_tokens: Some(1024), - toolshim: false, - toolshim_model: None, - fast_model_config: None, request_params: None, - reasoning: None, }; let request = create_request( - &model_config, + model_config, "system", &[], &[], @@ -2094,19 +2173,15 @@ mod tests { // Unknown models on OpenAI-compatible local providers (llama_swap, // lmstudio) have no canonical record and no GOOSE_MAX_TOKENS, so the // request must not pin the legacy 4096 default. See issue #9007. - let model_config = ModelConfig { - model_name: "some-unknown-local-model".to_string(), - context_limit: None, + let model_config = ModelConfigParams { + model_name: "some-unknown-local-model", + thinking_effort: None, temperature: None, max_tokens: None, - toolshim: false, - toolshim_model: None, - fast_model_config: None, request_params: None, - reasoning: None, }; let request = create_request( - &model_config, + model_config, "system", &[], &[], @@ -2127,44 +2202,34 @@ mod tests { #[test] fn test_request_params_preserve_reserved_fields() -> anyhow::Result<()> { - let model_config = ModelConfig { - model_name: "glm-4.7".to_string(), - context_limit: Some(204800), + let params = std::collections::HashMap::from([ + ( + "thinking".to_string(), + json!({ + "type": "enabled", + "clear_thinking": false + }), + ), + ("stream".to_string(), json!(false)), + ( + "stream_options".to_string(), + json!({"include_usage": false}), + ), + ("model".to_string(), json!("wrong-model")), + ("messages".to_string(), json!([])), + ("max_tokens".to_string(), json!(1)), + ("temperature".to_string(), json!(2.0)), + ("provider_custom".to_string(), json!("allowed")), + ]); + let model_config = ModelConfigParams { + model_name: "glm-4.7", + thinking_effort: None, temperature: None, max_tokens: Some(4096), - toolshim: false, - toolshim_model: None, - fast_model_config: None, - request_params: Some(std::collections::HashMap::from([ - ( - "thinking".to_string(), - json!({ - "type": "enabled", - "clear_thinking": false - }), - ), - ("stream".to_string(), json!(false)), - ( - "stream_options".to_string(), - json!({"include_usage": false}), - ), - ("model".to_string(), json!("wrong-model")), - ("messages".to_string(), json!([])), - ("max_tokens".to_string(), json!(1)), - ("temperature".to_string(), json!(2.0)), - ("provider_custom".to_string(), json!("allowed")), - ])), - reasoning: None, + request_params: Some(¶ms), }; - let request = create_request( - &model_config, - "system", - &[], - &[], - &ImageFormat::OpenAi, - true, - )?; + let request = create_request(model_config, "system", &[], &[], &ImageFormat::OpenAi, true)?; assert_eq!( request["thinking"], @@ -2186,19 +2251,15 @@ mod tests { #[test] fn test_create_request_o1_default() -> anyhow::Result<()> { - let model_config = ModelConfig { - model_name: "o1".to_string(), - context_limit: Some(4096), + let model_config = ModelConfigParams { + model_name: "o1", + thinking_effort: None, temperature: None, max_tokens: Some(1024), - toolshim: false, - toolshim_model: None, - fast_model_config: None, request_params: None, - reasoning: None, }; let request = create_request( - &model_config, + model_config, "system", &[], &[], @@ -2232,19 +2293,15 @@ mod tests { fn test_create_request_o1_medium_effort() -> anyhow::Result<()> { let mut params = std::collections::HashMap::new(); params.insert("thinking_effort".to_string(), json!("medium")); - let model_config = ModelConfig { - model_name: "o1".to_string(), - context_limit: Some(4096), + let model_config = ModelConfigParams { + model_name: "o1", + thinking_effort: Some(ThinkingEffort::Medium), temperature: None, max_tokens: Some(1024), - toolshim: false, - toolshim_model: None, - fast_model_config: None, - request_params: Some(params), - reasoning: None, + request_params: Some(¶ms), }; let request = create_request( - &model_config, + model_config, "system", &[], &[], @@ -2263,19 +2320,15 @@ mod tests { fn test_create_request_o3_off_effort_preserves_none() -> anyhow::Result<()> { let mut params = std::collections::HashMap::new(); params.insert("thinking_effort".to_string(), json!("off")); - let model_config = ModelConfig { - model_name: "o3".to_string(), - context_limit: Some(4096), + let model_config = ModelConfigParams { + model_name: "o3", + thinking_effort: Some(ThinkingEffort::Off), temperature: None, max_tokens: Some(1024), - toolshim: false, - toolshim_model: None, - fast_model_config: None, - request_params: Some(params), - reasoning: None, + request_params: Some(¶ms), }; let request = create_request( - &model_config, + model_config, "system", &[], &[], @@ -2294,19 +2347,15 @@ mod tests { fn test_create_request_gpt5_pro_max_effort_uses_supported_level() -> anyhow::Result<()> { let mut params = std::collections::HashMap::new(); params.insert("thinking_effort".to_string(), json!("max")); - let model_config = ModelConfig { - model_name: "gpt-5.2-pro-2025-12-11".to_string(), - context_limit: Some(4096), + let model_config = ModelConfigParams { + model_name: "gpt-5.2-pro-2025-12-11", + thinking_effort: Some(ThinkingEffort::Max), temperature: None, max_tokens: Some(1024), - toolshim: false, - toolshim_model: None, - fast_model_config: None, - request_params: Some(params), - reasoning: None, + request_params: Some(¶ms), }; let request = create_request( - &model_config, + model_config, "system", &[], &[], @@ -2325,19 +2374,15 @@ mod tests { fn test_create_request_o3_custom_reasoning_effort() -> anyhow::Result<()> { let mut params = std::collections::HashMap::new(); params.insert("thinking_effort".to_string(), json!("high")); - let model_config = ModelConfig { - model_name: "o3-mini".to_string(), - context_limit: Some(4096), + let model_config = ModelConfigParams { + model_name: "o3-mini", + thinking_effort: Some(ThinkingEffort::High), temperature: None, max_tokens: Some(1024), - toolshim: false, - toolshim_model: None, - fast_model_config: None, - request_params: Some(params), - reasoning: None, + request_params: Some(¶ms), }; let request = create_request( - &model_config, + model_config, "system", &[], &[], @@ -2968,16 +3013,12 @@ data: [DONE]"#; #[test] fn test_create_request_preserves_reasoning_content_for_legacy_compat() -> anyhow::Result<()> { - let model_config = ModelConfig { - model_name: "deepseek-reasoner".to_string(), - context_limit: Some(128000), + let model_config = ModelConfigParams { + model_name: "deepseek-reasoner", + thinking_effort: None, temperature: None, max_tokens: Some(1024), - toolshim: false, - toolshim_model: None, - fast_model_config: None, request_params: None, - reasoning: None, }; let message = Message::assistant() .with_content(MessageContent::thinking("preserve this", "")) @@ -2988,7 +3029,7 @@ data: [DONE]"#; ); let request = create_request( - &model_config, + model_config, "system", &[message], &[], @@ -3454,4 +3495,80 @@ data: [DONE]"#; let parsed: DeltaToolCallFunction = serde_json::from_str(raw).unwrap(); assert_eq!(parsed.arguments, "{\"k\":1}"); } + + #[test] + fn test_is_openai_responses_model_matches_o_and_gpt5_families() { + for model in [ + "o3", + "o3-mini", + "o4-mini", + "gpt-5", + "gpt-5-pro", + "gpt-5.4", + "gpt-5.4-mini", + "gpt-5-4", + "gpt-5-2-pro", + "databricks-gpt-5.4", + "goose-gpt-5.4-high", + "headless-goose-o3-mini", + ] { + assert!(is_openai_responses_model(model), "{model} should match"); + } + } + + #[test] + fn test_is_openai_responses_model_rejects_other_families() { + for model in [ + "gpt-4o", + "claude-sonnet-4", + "databricks-claude-sonnet-4", + "llama-3-70b", + ] { + assert!( + !is_openai_responses_model(model), + "{model} should not match" + ); + } + } + + #[test] + fn test_extract_reasoning_effort_for_responses_models() { + for (model, expected_name, expected_effort) in [ + ("o3-none", "o3", Some("none")), + ("o3-xhigh", "o3", Some("xhigh")), + ("gpt-5-low", "gpt-5", Some("low")), + ("gpt-5.4", "gpt-5.4", None), + ( + "databricks-gpt-5.4-high", + "databricks-gpt-5.4", + Some("high"), + ), + ("databricks-o3-low", "databricks-o3", Some("low")), + ("goose-gpt-5-high", "goose-gpt-5", Some("high")), + ("gpt-4o", "gpt-4o", None), + ] { + let (name, effort) = extract_reasoning_effort(model); + assert_eq!(name, expected_name, "unexpected base model for {model}"); + assert_eq!( + effort.as_deref(), + expected_effort, + "unexpected effort for {model}" + ); + } + } + + #[test] + fn test_sanitize_function_name() { + assert_eq!(sanitize_function_name("hello-world"), "hello-world"); + assert_eq!(sanitize_function_name("hello world"), "hello_world"); + assert_eq!(sanitize_function_name("hello@world"), "hello_world"); + } + + #[test] + fn test_is_valid_function_name() { + assert!(is_valid_function_name("hello-world")); + assert!(is_valid_function_name("hello_world")); + assert!(!is_valid_function_name("hello world")); + assert!(!is_valid_function_name("hello@world")); + } } diff --git a/crates/goose-providers/src/images.rs b/crates/goose-providers/src/images.rs new file mode 100644 index 000000000000..ebbfbd210ee9 --- /dev/null +++ b/crates/goose-providers/src/images.rs @@ -0,0 +1,218 @@ +use std::{io::Read as _, path::Path}; + +use base64::Engine as _; +use rmcp::model::{AnnotateAble as _, ImageContent, RawImageContent}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; + +use crate::errors::ProviderError; + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +pub enum ImageFormat { + OpenAi, + Anthropic, +} + +/// Convert an image content into an image json based on format +pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value { + match image_format { + ImageFormat::OpenAi => json!({ + "type": "image_url", + "image_url": { + "url": format!("data:{};base64,{}", image.mime_type, image.data) + } + }), + ImageFormat::Anthropic => json!({ + "type": "image", + "source": { + "type": "base64", + "media_type": image.mime_type, + "data": image.data, + } + }), + } +} + +/// Detect if a string contains a path to an image file +pub fn detect_image_path(text: &str) -> Option<&str> { + // Basic image file extension check + let extensions = [".png", ".jpg", ".jpeg"]; + + // Find any word that ends with an image extension + for word in text.split_whitespace() { + if extensions + .iter() + .any(|ext| word.to_lowercase().ends_with(ext)) + { + let path = Path::new(word); + // Check if it's an absolute path and file exists + if path.is_absolute() && path.is_file() { + // Verify it's actually an image file + if is_image_file(path) { + return Some(word); + } + } + } + } + None +} + +/// Check if a file is actually an image by examining its magic bytes +fn is_image_file(path: &Path) -> bool { + if let Ok(mut file) = std::fs::File::open(path) { + let mut buffer = [0u8; 8]; // Large enough for most image magic numbers + if file.read(&mut buffer).is_ok() { + // Check magic numbers for common image formats + return match &buffer[0..4] { + // PNG: 89 50 4E 47 + [0x89, 0x50, 0x4E, 0x47] => true, + // JPEG: FF D8 FF + [0xFF, 0xD8, 0xFF, _] => true, + // GIF: 47 49 46 38 + [0x47, 0x49, 0x46, 0x38] => true, + _ => false, + }; + } + } + false +} + +/// Convert a local image file to base64 encoded ImageContent +pub fn load_image_file(path: &str) -> Result { + let path = Path::new(path); + + // Verify it's an image before proceeding + if !is_image_file(path) { + return Err(ProviderError::RequestFailed( + "File is not a valid image".to_string(), + )); + } + + // Read the file + let bytes = std::fs::read(path) + .map_err(|e| ProviderError::RequestFailed(format!("Failed to read image file: {}", e)))?; + + // Detect mime type from extension + let mime_type = match path.extension().and_then(|e| e.to_str()) { + Some(ext) => match ext.to_lowercase().as_str() { + "png" => "image/png", + "jpg" | "jpeg" => "image/jpeg", + _ => { + return Err(ProviderError::RequestFailed( + "Unsupported image format".to_string(), + )) + } + }, + None => { + return Err(ProviderError::RequestFailed( + "Unknown image format".to_string(), + )) + } + }; + + // Convert to base64 + let data = base64::prelude::BASE64_STANDARD.encode(&bytes); + + Ok(RawImageContent { + mime_type: mime_type.to_string(), + data, + meta: None, + } + .no_annotation()) +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile; + + #[test] + fn test_detect_image_path() { + // Create a temporary PNG file with valid PNG magic numbers + let temp_dir = tempfile::tempdir().unwrap(); + let png_path = temp_dir.path().join("test.png"); + let png_data = [ + 0x89, 0x50, 0x4E, 0x47, // PNG magic number + 0x0D, 0x0A, 0x1A, 0x0A, // PNG header + 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data + ]; + std::fs::write(&png_path, png_data).unwrap(); + let png_path_str = png_path.to_str().unwrap(); + + // Create a fake PNG (wrong magic numbers) + let fake_png_path = temp_dir.path().join("fake.png"); + std::fs::write(&fake_png_path, b"not a real png").unwrap(); + + // Test with valid PNG file using absolute path + let text = format!("Here is an image {}", png_path_str); + assert_eq!(detect_image_path(&text), Some(png_path_str)); + + // Test with non-image file that has .png extension + let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap()); + assert_eq!(detect_image_path(&text), None); + + // Test with nonexistent file + let text = "Here is a fake.png that doesn't exist"; + assert_eq!(detect_image_path(text), None); + + // Test with non-image file + let text = "Here is a file.txt"; + assert_eq!(detect_image_path(text), None); + + // Test with relative path (should not match) + let text = "Here is a relative/path/image.png"; + assert_eq!(detect_image_path(text), None); + } + + #[test] + fn test_load_image_file() { + // Create a temporary PNG file with valid PNG magic numbers + let temp_dir = tempfile::tempdir().unwrap(); + let png_path = temp_dir.path().join("test.png"); + let png_data = [ + 0x89, 0x50, 0x4E, 0x47, // PNG magic number + 0x0D, 0x0A, 0x1A, 0x0A, // PNG header + 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data + ]; + std::fs::write(&png_path, png_data).unwrap(); + let png_path_str = png_path.to_str().unwrap(); + + // Create a fake PNG (wrong magic numbers) + let fake_png_path = temp_dir.path().join("fake.png"); + std::fs::write(&fake_png_path, b"not a real png").unwrap(); + let fake_png_path_str = fake_png_path.to_str().unwrap(); + + // Test loading valid PNG file + let result = load_image_file(png_path_str); + assert!(result.is_ok()); + let image = result.unwrap(); + assert_eq!(image.mime_type, "image/png"); + + // Test loading fake PNG file + let result = load_image_file(fake_png_path_str); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("not a valid image")); + + // Test nonexistent file + let result = load_image_file("nonexistent.png"); + assert!(result.is_err()); + + // Create a GIF file with valid header bytes + let gif_path = temp_dir.path().join("test.gif"); + // Minimal GIF89a header + let gif_data = [0x47, 0x49, 0x46, 0x38, 0x39, 0x61]; + std::fs::write(&gif_path, gif_data).unwrap(); + let gif_path_str = gif_path.to_str().unwrap(); + + // Test loading unsupported GIF format + let result = load_image_file(gif_path_str); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Unsupported image format")); + } +} diff --git a/crates/goose-providers/src/json.rs b/crates/goose-providers/src/json.rs new file mode 100644 index 000000000000..c8a23418b242 --- /dev/null +++ b/crates/goose-providers/src/json.rs @@ -0,0 +1,221 @@ +/// Safely parse a JSON string that may contain doubly-encoded or malformed JSON. +/// This function first attempts to parse the input string as-is. If that fails, +/// it applies control character escaping and truncated JSON repair and tries again. +/// +/// This approach preserves valid JSON like `{"key1": "value1",\n"key2": "value"}` +/// (which contains a literal \n but is perfectly valid JSON) while still fixing +/// broken JSON like `{"key1": "value1\n","key2": "value"}` (which contains an +/// unescaped newline character). +pub fn safely_parse_json(s: &str) -> Result { + // First, try parsing the string as-is + match serde_json::from_str(s) { + Ok(value) => Ok(value), + Err(_) => { + for candidate in [ + repair_truncated_json(s), + json_escape_control_chars_in_string(s), + ] { + if let Ok(value) = serde_json::from_str(&candidate) { + return Ok(value); + } + } + + let repaired = repair_truncated_json(&json_escape_control_chars_in_string(s)); + serde_json::from_str(&repaired) + } + } +} + +fn repair_truncated_json(s: &str) -> String { + let mut repaired = String::with_capacity(s.len() + 8); + let mut in_string = false; + let mut escape_next = false; + let mut closers = Vec::new(); + + for c in s.chars() { + repaired.push(c); + + if in_string { + if escape_next { + escape_next = false; + continue; + } + + match c { + '\\' => escape_next = true, + '"' => in_string = false, + _ => {} + } + continue; + } + + match c { + '"' => in_string = true, + '{' => closers.push('}'), + '[' => closers.push(']'), + '}' | ']' => { + if closers.last() == Some(&c) { + closers.pop(); + } + } + _ => {} + } + } + + if in_string { + if escape_next { + repaired.push('\\'); + } + repaired.push('"'); + } + + while let Some(closer) = closers.pop() { + repaired.push(closer); + } + + repaired +} + +/// Helper to escape control characters in a string that is supposed to be a JSON document. +/// This function iterates through the input string `s` and replaces any literal +/// control characters (U+0000 to U+001F) with their JSON-escaped equivalents +/// (e.g., '\n' becomes "\\n", '\u0001' becomes "\\u0001"). +/// +/// It does NOT escape quotes (") or backslashes (\) because it assumes `s` is a +/// full JSON document, and these characters might be structural (e.g., object delimiters, +/// existing valid escape sequences). The goal is to fix common LLM errors where +/// control characters are emitted raw into what should be JSON string values, +/// making the overall JSON structure unparsable. +/// +/// If the input string `s` has other JSON syntax errors (e.g., an unescaped quote +/// *within* a string value like `{"key": "string with " quote"}`), this function +/// will not fix them. It specifically targets unescaped control characters. +pub fn json_escape_control_chars_in_string(s: &str) -> String { + let mut r = String::with_capacity(s.len()); // Pre-allocate for efficiency + for c in s.chars() { + match c { + // ASCII Control characters (U+0000 to U+001F) + '\u{0000}'..='\u{001F}' => { + match c { + '\u{0008}' => r.push_str("\\b"), // Backspace + '\u{000C}' => r.push_str("\\f"), // Form feed + '\n' => r.push_str("\\n"), // Line feed + '\r' => r.push_str("\\r"), // Carriage return + '\t' => r.push_str("\\t"), // Tab + // Other control characters (e.g., NUL, SOH, VT, etc.) + // that don't have a specific short escape sequence. + _ => { + r.push_str(&format!("\\u{:04x}", c as u32)); + } + } + } + // Other characters are passed through. + // This includes quotes (") and backslashes (\). If these are part of the + // JSON structure (e.g. {"key": "value"}) or part of an already correctly + // escaped sequence within a string value (e.g. "string with \\\" quote"), + // they are preserved as is. This function does not attempt to fix + // malformed quote or backslash usage *within* string values if the LLM + // generates them incorrectly (e.g. {"key": "unescaped " quote in string"}). + _ => r.push(c), + } + } + r +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_safely_parse_json() { + // Test valid JSON that should parse without escaping (contains proper escape sequence) + let valid_json = r#"{"key1": "value1","key2": "value2"}"#; + let result = safely_parse_json(valid_json).unwrap(); + assert_eq!(result["key1"], "value1"); + assert_eq!(result["key2"], "value2"); + + // Test JSON with actual unescaped newlines that needs escaping + let invalid_json = "{\"key1\": \"value1\n\",\"key2\": \"value2\"}"; + let result = safely_parse_json(invalid_json).unwrap(); + assert_eq!(result["key1"], "value1\n"); + assert_eq!(result["key2"], "value2"); + + // Test already valid JSON - should parse on first try + let good_json = r#"{"test": "value"}"#; + let result = safely_parse_json(good_json).unwrap(); + assert_eq!(result["test"], "value"); + + // Test truncated JSON with unclosed string, object, and array + let truncated_json = r#"{"key": "unclosed_string","nested": {"items": [1, 2, 3"#; + let result = safely_parse_json(truncated_json).unwrap(); + assert_eq!(result["key"], "unclosed_string"); + assert_eq!(result["nested"]["items"], json!([1, 2, 3])); + + // Test dangling backslash at end of a truncated string + let dangling_escape_json = String::from(r#"{"path":"abc\"#); + let result = safely_parse_json(&dangling_escape_json).unwrap(); + assert_eq!(result["path"], "abc\\"); + + // Test empty object + let empty_json = "{}"; + let result = safely_parse_json(empty_json).unwrap(); + assert!(result.as_object().unwrap().is_empty()); + + // Test JSON with escaped newlines (valid JSON) - should parse on first try + let escaped_json = r#"{"key": "value with\nnewline"}"#; + let result = safely_parse_json(escaped_json).unwrap(); + assert_eq!(result["key"], "value with\nnewline"); + } + + #[test] + fn test_json_escape_control_chars_in_string() { + // Test basic control character escaping + assert_eq!( + json_escape_control_chars_in_string("Hello\nWorld"), + "Hello\\nWorld" + ); + assert_eq!( + json_escape_control_chars_in_string("Hello\tWorld"), + "Hello\\tWorld" + ); + assert_eq!( + json_escape_control_chars_in_string("Hello\rWorld"), + "Hello\\rWorld" + ); + + // Test multiple control characters + assert_eq!( + json_escape_control_chars_in_string("Hello\n\tWorld\r"), + "Hello\\n\\tWorld\\r" + ); + + // Test that quotes and backslashes are preserved (not escaped) + assert_eq!( + json_escape_control_chars_in_string("Hello \"World\""), + "Hello \"World\"" + ); + assert_eq!( + json_escape_control_chars_in_string("Hello\\World"), + "Hello\\World" + ); + + // Test JSON-like string with control characters + assert_eq!( + json_escape_control_chars_in_string("{\"message\": \"Hello\nWorld\"}"), + "{\"message\": \"Hello\\nWorld\"}" + ); + + // Test no changes for normal strings + assert_eq!( + json_escape_control_chars_in_string("Hello World"), + "Hello World" + ); + + // Test other control characters get unicode escapes + assert_eq!( + json_escape_control_chars_in_string("Hello\u{0001}World"), + "Hello\\u0001World" + ); + } +} diff --git a/crates/goose-providers/src/lib.rs b/crates/goose-providers/src/lib.rs index 7b77b3a95464..17a6d4e23a5d 100644 --- a/crates/goose-providers/src/lib.rs +++ b/crates/goose-providers/src/lib.rs @@ -1,4 +1,10 @@ +pub mod base; pub mod canonical; pub mod conversation; -mod mcp_utils; -mod utils; +pub mod errors; +pub mod formats; +pub mod images; +pub mod json; +pub(crate) mod mcp_utils; +pub mod thinking; +pub mod utils; diff --git a/crates/goose-providers/src/thinking.rs b/crates/goose-providers/src/thinking.rs new file mode 100644 index 000000000000..abfcc904f107 --- /dev/null +++ b/crates/goose-providers/src/thinking.rs @@ -0,0 +1,587 @@ +use std::{fmt, str::FromStr, sync::LazyLock}; + +use regex::Regex; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +pub const GEMINI_THOUGHT_SIGNATURE_KEY: &str = "thoughtSignature"; + +pub fn split_think_blocks(text: &str) -> (String, String) { + let mut filter = ThinkFilter::new(); + let mut out = filter.push(text); + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + (out.content, out.thinking) +} + +#[derive(Debug, Default, PartialEq, Eq)] +pub struct FilterOut { + pub content: String, + pub thinking: String, +} + +pub struct ThinkFilter { + buffer: String, + inside_think: bool, + think_depth: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum ThinkTag { + Open, + Close, + SelfClosing, +} + +enum BufferEvent { + Tag { + pos: usize, + end: usize, + kind: ThinkTag, + }, + Partial(usize), +} + +impl ThinkFilter { + pub fn new() -> Self { + Self { + buffer: String::new(), + inside_think: false, + think_depth: 0, + } + } + + pub fn push(&mut self, chunk: &str) -> FilterOut { + self.buffer.push_str(chunk); + self.process_buffer() + } + + pub fn finish(mut self) -> FilterOut { + let mut out = self.process_buffer(); + if !self.buffer.is_empty() { + if self.inside_think { + out.thinking.push_str(&self.buffer); + } else { + out.content.push_str(&self.buffer); + } + self.buffer.clear(); + } + out + } + + fn process_buffer(&mut self) -> FilterOut { + let mut out = FilterOut::default(); + + loop { + match next_buffer_event(&self.buffer, self.inside_think) { + Some(BufferEvent::Tag { pos, end, kind }) => { + if pos > 0 { + let prefix = self.buffer.get(..pos).unwrap_or_default().to_string(); + if self.inside_think { + out.thinking.push_str(&prefix); + } else { + out.content.push_str(&prefix); + } + } + + self.buffer.drain(..end); + + match kind { + ThinkTag::Open => { + self.think_depth += 1; + self.inside_think = true; + } + ThinkTag::Close => { + self.think_depth = self.think_depth.saturating_sub(1); + self.inside_think = self.think_depth > 0; + } + ThinkTag::SelfClosing => {} + } + } + Some(BufferEvent::Partial(pos)) => { + if pos > 0 { + let prefix = self.buffer.get(..pos).unwrap_or_default().to_string(); + if self.inside_think { + out.thinking.push_str(&prefix); + } else { + out.content.push_str(&prefix); + } + self.buffer.drain(..pos); + } + break; + } + None => { + if !self.buffer.is_empty() { + if self.inside_think { + out.thinking.push_str(&self.buffer); + } else { + out.content.push_str(&self.buffer); + } + self.buffer.clear(); + } + break; + } + } + } + + out + } +} + +impl Default for ThinkFilter { + fn default() -> Self { + Self::new() + } +} + +fn next_buffer_event(buffer: &str, inside_think: bool) -> Option { + let mut search_from = 0; + + while let Some(rel_pos) = buffer.get(search_from..).and_then(|rest| rest.find('<')) { + let pos = search_from + rel_pos; + let suffix = buffer.get(pos..).unwrap_or_default(); + + if let Some((kind, end)) = parse_think_tag(buffer, pos) { + if inside_think || matches!(kind, ThinkTag::Open | ThinkTag::SelfClosing) { + return Some(BufferEvent::Tag { pos, end, kind }); + } + } else if !contains_unquoted_gt(suffix) && is_possible_partial_think_tag(suffix) { + return Some(BufferEvent::Partial(pos)); + } + + search_from = pos + 1; + } + + None +} + +fn parse_think_tag(buffer: &str, start: usize) -> Option<(ThinkTag, usize)> { + let bytes = buffer.as_bytes(); + if bytes.get(start) != Some(&b'<') { + return None; + } + + let mut idx = start + 1; + let is_close = if bytes.get(idx) == Some(&b'/') { + idx += 1; + true + } else { + false + }; + + let name_start = idx; + while bytes.get(idx).is_some_and(u8::is_ascii_alphabetic) { + idx += 1; + } + + if idx == name_start { + return None; + } + + let name = buffer.get(name_start..idx).unwrap_or_default(); + let is_think = name.eq_ignore_ascii_case("think") || name.eq_ignore_ascii_case("thinking"); + if !is_think { + return None; + } + + if is_close { + while bytes.get(idx).is_some_and(u8::is_ascii_whitespace) { + idx += 1; + } + if bytes.get(idx) == Some(&b'>') { + return Some((ThinkTag::Close, idx + 1)); + } + return None; + } + + // Require a real tag boundary immediately after the name (>, /, or whitespace). + // Without this, `` or `` would be classified as a + // think tag and stripped from normal content. + let valid_open_boundary = match bytes.get(idx) { + Some(&b) => b == b'>' || b == b'/' || b.is_ascii_whitespace(), + None => false, + }; + if !valid_open_boundary { + return None; + } + + let mut quote: Option = None; + let mut last_non_ws: Option = None; + while let Some(&byte) = bytes.get(idx) { + match quote { + Some(quote_byte) => { + if byte == quote_byte { + quote = None; + } + } + None if matches!(byte, b'"' | b'\'') => { + quote = Some(byte); + last_non_ws = Some(byte); + } + None if byte == b'>' => { + let kind = if last_non_ws == Some(b'/') { + ThinkTag::SelfClosing + } else { + ThinkTag::Open + }; + return Some((kind, idx + 1)); + } + None if !byte.is_ascii_whitespace() => { + last_non_ws = Some(byte); + } + None => {} + } + idx += 1; + } + + None +} + +fn is_possible_partial_think_tag(suffix: &str) -> bool { + if contains_unquoted_gt(suffix) { + return false; + } + + // Allow a trailing `/` so a chunk boundary that lands between `` in a self-closing `` (or ``) is still recognised + // as a partial tag and buffered until the `>` arrives in the next chunk. + static OPEN_RE: LazyLock = LazyLock::new(|| { + Regex::new(r"(?is)^<(?:t(?:h(?:i(?:n(?:k(?:i(?:n(?:g)?)?)?)?)?)?)?)(?:\s.*|/)?$").unwrap() + }); + static CLOSE_RE: LazyLock = LazyLock::new(|| { + Regex::new(r"(?is)^ bool { + let mut quote: Option = None; + for &byte in text.as_bytes() { + match quote { + Some(quote_byte) => { + if byte == quote_byte { + quote = None; + } + } + None if matches!(byte, b'"' | b'\'') => quote = Some(byte), + None if byte == b'>' => return true, + None => {} + } + } + false +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "lowercase")] +pub enum ThinkingEffort { + Off, + Low, + Medium, + High, + Max, +} + +impl FromStr for ThinkingEffort { + type Err = String; + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "off" | "disabled" | "none" => Ok(Self::Off), + "low" => Ok(Self::Low), + "medium" | "med" => Ok(Self::Medium), + "high" => Ok(Self::High), + "max" | "xhigh" => Ok(Self::Max), + other => Err(format!("unknown thinking effort: '{other}'")), + } + } +} + +impl fmt::Display for ThinkingEffort { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Off => write!(f, "off"), + Self::Low => write!(f, "low"), + Self::Medium => write!(f, "medium"), + Self::High => write!(f, "high"), + Self::Max => write!(f, "max"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_split_think_blocks_extracts_inline_reasoning() { + assert_eq!( + split_think_blocks("xy"), + ("y".to_string(), "x".to_string()) + ); + } + + #[test] + fn test_split_think_blocks_is_case_insensitive() { + assert_eq!( + split_think_blocks("xy"), + ("y".to_string(), "x".to_string()) + ); + } + + #[test] + fn test_split_think_blocks_handles_multiple_blocks() { + assert_eq!( + split_think_blocks("abcd"), + ("bd".to_string(), "ac".to_string()) + ); + } + + #[test] + fn test_split_think_blocks_without_tags() { + assert_eq!( + split_think_blocks("plain content"), + ("plain content".to_string(), String::new()) + ); + } + + #[test] + fn test_split_think_blocks_handles_attributes() { + assert_eq!( + split_think_blocks(r#"ab"#), + ("b".to_string(), "a".to_string()) + ); + } + + #[test] + fn test_split_think_blocks_handles_quoted_gt_in_self_closing_attributes() { + for input in [ + r#"Visible"#, + "Visible", + ] { + assert_eq!( + split_think_blocks(input), + ("Visible".to_string(), String::new()), + "mismatch for {input:?}" + ); + } + } + + #[test] + fn test_split_think_blocks_handles_quoted_gt_in_open_attributes() { + assert_eq!( + split_think_blocks(r#"HiddenVisible"#), + ("Visible".to_string(), "Hidden".to_string()) + ); + } + + #[test] + fn test_split_think_blocks_handles_thinking_variant() { + assert_eq!( + split_think_blocks("ab"), + ("b".to_string(), "a".to_string()) + ); + } + + #[test] + fn test_think_filter_streaming_across_partial_tags() { + let mut filter = ThinkFilter::new(); + let mut out = FilterOut::default(); + + for chunk in ["xy"] { + let partial = filter.push(chunk); + out.content.push_str(&partial.content); + out.thinking.push_str(&partial.thinking); + } + + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + + assert_eq!(out.content, "y"); + assert_eq!(out.thinking, "x"); + } + + #[test] + fn test_think_filter_preserves_non_think_tags() { + let mut filter = ThinkFilter::new(); + let mut out = filter.push(""); + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + + assert_eq!(out.content, "
"); + assert!(out.thinking.is_empty()); + } + + #[test] + fn test_think_filter_finish_treats_unterminated_think_as_thinking() { + let mut filter = ThinkFilter::new(); + let mut out = filter.push("unfinished"); + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + + assert!(out.content.is_empty()); + assert_eq!(out.thinking, "unfinished"); + } + + #[test] + fn test_think_filter_tracks_generation_prompt_open_block() { + let mut filter = ThinkFilter::new(); + let _ = filter.push("<|assistant|>\n"); + let mut out = filter.push("hidden reasoningvisible answer"); + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + + assert_eq!(out.content, "visible answer"); + assert_eq!(out.thinking, "hidden reasoning"); + } + + #[test] + fn test_think_filter_preserves_tags_with_think_prefix() { + for input in [ + "hello", + "payload", + "note", + ] { + let mut filter = ThinkFilter::new(); + let mut out = filter.push(input); + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + + assert_eq!(out.content, input, "content mismatch for {input:?}"); + assert!( + out.thinking.is_empty(), + "unexpected thinking for {input:?}: {:?}", + out.thinking + ); + } + } + + #[test] + fn test_think_filter_accepts_think_with_attributes() { + let mut filter = ThinkFilter::new(); + let mut out = filter.push("hiddenvisible"); + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + + assert_eq!(out.content, "visible"); + assert_eq!(out.thinking, "hidden"); + } + + #[test] + fn test_think_filter_treats_self_closing_as_noop() { + // `` carries no reasoning payload. It must not flip the filter + // into "inside_think" mode, and the tag itself must not leak into + // visible content. + for input in [ + "before after", + "before after", + "before after", + "before after", + ] { + let mut filter = ThinkFilter::new(); + let mut out = filter.push(input); + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + + assert_eq!( + out.content, "before after", + "content mismatch for {input:?}" + ); + assert!( + out.thinking.is_empty(), + "unexpected thinking for {input:?}: {:?}", + out.thinking + ); + } + } + + #[test] + fn test_think_filter_self_closing_does_not_swallow_following_content() { + // Regression: a self-closing `` used to be classified as an + // Open tag, which incremented think_depth and routed everything after + // it into the thinking bucket for the rest of the stream. + let mut filter = ThinkFilter::new(); + let mut out = filter.push("visible chunk 1"); + let final_out = filter.push("visible chunk 2"); + let tail_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + out.content.push_str(&tail_out.content); + out.thinking.push_str(&tail_out.thinking); + + assert_eq!(out.content, "visible chunk 1visible chunk 2"); + assert!(out.thinking.is_empty()); + } + + #[test] + fn test_think_filter_streaming_across_self_closing_boundary() { + // Regression: a chunk boundary between `` in a + // self-closing `` used to fall out of the partial-tag regex + // (which only allowed `...`), so the `` arrived. + for (a, b) in [ + ("before after"), + ("before after"), + ("head tail"), + ] { + let mut filter = ThinkFilter::new(); + let mut out = filter.push(a); + let second = filter.push(b); + let final_out = filter.finish(); + out.content.push_str(&second.content); + out.content.push_str(&final_out.content); + out.thinking.push_str(&second.thinking); + out.thinking.push_str(&final_out.thinking); + + assert!( + !out.content.contains('<'), + "partial tag leaked into content for ({a:?}, {b:?}): {:?}", + out.content + ); + assert!( + out.thinking.is_empty(), + "unexpected thinking for ({a:?}, {b:?}): {:?}", + out.thinking + ); + } + } + + #[test] + fn test_think_filter_streaming_across_quoted_attribute_boundary() { + let mut filter = ThinkFilter::new(); + let mut out = filter.push(r#"Visible"#); + let final_out = filter.finish(); + out.content.push_str(&second.content); + out.content.push_str(&final_out.content); + out.thinking.push_str(&second.thinking); + out.thinking.push_str(&final_out.thinking); + + assert_eq!(out.content, "Visible"); + assert!(out.thinking.is_empty()); + } + + #[test] + fn test_think_filter_self_closing_inside_open_block_closes_nothing() { + // `` inside an open `` block is still a no-op: depth + // should stay at 1 until the real `` arrives. + let mut filter = ThinkFilter::new(); + let mut out = filter.push("before hidden1 hidden2visible"); + let final_out = filter.finish(); + out.content.push_str(&final_out.content); + out.thinking.push_str(&final_out.thinking); + + assert_eq!(out.content, "before visible"); + assert_eq!(out.thinking, "hidden1 hidden2"); + } +} diff --git a/crates/goose-server/Cargo.toml b/crates/goose-server/Cargo.toml index c1f04fc6dddf..68ac9df48cd5 100644 --- a/crates/goose-server/Cargo.toml +++ b/crates/goose-server/Cargo.toml @@ -53,6 +53,7 @@ native-tls = [ [dependencies] goose = { path = "../goose", default-features = false } goose-mcp = { path = "../goose-mcp", default-features = false } +goose-providers = { path = "../goose-providers", default-features = false } rmcp = { workspace = true } axum = { workspace = true, features = ["ws", "macros"] } tokio = { workspace = true } diff --git a/crates/goose-server/src/openapi.rs b/crates/goose-server/src/openapi.rs index 3a0df42a07ff..b4973f9aa937 100644 --- a/crates/goose-server/src/openapi.rs +++ b/crates/goose-server/src/openapi.rs @@ -5,10 +5,11 @@ use goose::config::permission::PermissionLevel; use goose::config::ExtensionEntry; use goose::conversation::Conversation; use goose::download_manager::{DownloadProgress, DownloadStatus}; -use goose::model::{ModelConfig, ThinkingEffort}; +use goose::model::ModelConfig; use goose::permission::permission_confirmation::{Permission, PrincipalType}; use goose::providers::base::{ConfigKey, ModelInfo, ProviderMetadata, ProviderType}; use goose::session::{Session, SessionInsights, SessionType, SystemInfo}; +use goose_providers::thinking::ThinkingEffort; use rmcp::model::{ Annotations, Content, EmbeddedResource, Icon, IconTheme, ImageContent, JsonObject, RawAudioContent, RawContent, RawEmbeddedResource, RawImageContent, RawResource, RawTextContent, @@ -392,6 +393,9 @@ derive_utoipa!(IconTheme as IconThemeSchema); super::routes::config_management::upsert_config, super::routes::config_management::remove_config, super::routes::config_management::read_config, + super::routes::config_management::add_extension, + super::routes::config_management::remove_extension, + super::routes::config_management::get_extensions, super::routes::config_management::read_all_config, super::routes::config_management::list_provider_secrets, super::routes::config_management::delete_provider_secret, @@ -427,6 +431,8 @@ derive_utoipa!(IconTheme as IconThemeSchema); super::routes::agent::export_app, super::routes::agent::import_app, super::routes::agent::update_from_session, + super::routes::agent::agent_add_extension, + super::routes::agent::agent_remove_extension, super::routes::agent::update_agent_provider, super::routes::agent::update_session, super::routes::action_required::confirm_tool_action, @@ -446,6 +452,7 @@ derive_utoipa!(IconTheme as IconThemeSchema); super::routes::session::import_session_nostr, super::routes::session::update_session_user_recipe_values, super::routes::session::fork_session, + super::routes::session::get_session_extensions, super::routes::schedule::create_schedule, super::routes::schedule::list_schedules, super::routes::schedule::delete_schedule, @@ -491,6 +498,8 @@ derive_utoipa!(IconTheme as IconThemeSchema); super::routes::config_management::SlashCommandsResponse, super::routes::config_management::SlashCommand, super::routes::config_management::CommandType, + super::routes::config_management::ExtensionResponse, + super::routes::config_management::ExtensionQuery, super::routes::config_management::ToolPermission, super::routes::config_management::UpsertPermissionsQuery, super::routes::config_management::UpdateCustomProviderRequest, @@ -523,6 +532,7 @@ derive_utoipa!(IconTheme as IconThemeSchema); super::routes::session::UpdateSessionUserRecipeValuesResponse, super::routes::session::ForkRequest, super::routes::session::ForkResponse, + super::routes::session::SessionExtensionsResponse, Message, MessageContent, MessageMetadata, @@ -641,6 +651,8 @@ derive_utoipa!(IconTheme as IconThemeSchema); super::routes::agent::RestartAgentRequest, super::routes::agent::UpdateWorkingDirRequest, super::routes::agent::UpdateFromSessionRequest, + super::routes::agent::AddExtensionRequest, + super::routes::agent::RemoveExtensionRequest, super::routes::agent::ResumeAgentResponse, super::routes::agent::RestartAgentResponse, goose::agents::ExtensionLoadResult, diff --git a/crates/goose-server/src/routes/errors.rs b/crates/goose-server/src/routes/errors.rs index 0866f8129397..aa7ffa553acf 100644 --- a/crates/goose-server/src/routes/errors.rs +++ b/crates/goose-server/src/routes/errors.rs @@ -5,7 +5,7 @@ use axum::{ }; use goose::config::ConfigError; use goose::model::ConfigError as ModelConfigError; -use goose::providers::errors::ProviderError; +use goose_providers::errors::ProviderError; use serde::Serialize; use utoipa::ToSchema; diff --git a/crates/goose-server/src/routes/session.rs b/crates/goose-server/src/routes/session.rs index 408e800591b8..96bbc9590c52 100644 --- a/crates/goose-server/src/routes/session.rs +++ b/crates/goose-server/src/routes/session.rs @@ -9,11 +9,12 @@ use axum::{ routing::{delete, get, put}, Json, Router, }; +use goose::agents::ExtensionConfig; use goose::recipe::Recipe; #[cfg(feature = "nostr")] use goose::session::nostr_share; use goose::session::session_manager::{SessionInsights, SessionType}; -use goose::session::Session; +use goose::session::{EnabledExtensionsState, Session}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; @@ -569,6 +570,47 @@ async fn fork_session( })) } +#[derive(Serialize, ToSchema)] +#[serde(rename_all = "camelCase")] +pub struct SessionExtensionsResponse { + extensions: Vec, +} + +#[utoipa::path( + get, + path = "/sessions/{session_id}/extensions", + params( + ("session_id" = String, Path, description = "Unique identifier for the session") + ), + responses( + (status = 200, description = "Session extensions retrieved successfully", body = SessionExtensionsResponse), + (status = 401, description = "Unauthorized - Invalid or missing API key"), + (status = 404, description = "Session not found"), + (status = 500, description = "Internal server error") + ), + security( + ("api_key" = []) + ), + tag = "Session Management" +)] +async fn get_session_extensions( + State(state): State>, + Path(session_id): Path, +) -> Result, StatusCode> { + let session = state + .session_manager() + .get_session(&session_id, false) + .await + .map_err(|_| StatusCode::NOT_FOUND)?; + + let extensions = EnabledExtensionsState::extensions_or_default( + Some(&session.extension_data), + goose::config::Config::global(), + ); + + Ok(Json(SessionExtensionsResponse { extensions })) +} + pub fn routes(state: Arc) -> Router { Router::new() .route("/sessions", get(list_sessions)) @@ -595,6 +637,10 @@ pub fn routes(state: Arc) -> Router { put(update_session_user_recipe_values), ) .route("/sessions/{session_id}/fork", post(fork_session)) + .route( + "/sessions/{session_id}/extensions", + get(get_session_extensions), + ) .with_state(state) } #[derive(Deserialize, ToSchema)] diff --git a/crates/goose/src/acp/provider.rs b/crates/goose/src/acp/provider.rs index de4c214fd073..16bca02d5b98 100644 --- a/crates/goose/src/acp/provider.rs +++ b/crates/goose/src/acp/provider.rs @@ -14,6 +14,7 @@ use agent_client_protocol_schema::AGENT_METHOD_NAMES; use anyhow::{Context, Result}; use async_stream::try_stream; use futures::future::BoxFuture; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; use rmcp::model::{CallToolRequestParams, CallToolResult, Content as RmcpContent, Role, Tool}; use std::collections::{HashMap, HashSet}; use std::future::Future; @@ -36,9 +37,9 @@ use crate::conversation::message::{Message, MessageContent, TOOL_META_EXTERNAL_D use crate::model::ModelConfig; use crate::permission::permission_confirmation::PrincipalType; use crate::permission::{Permission, PermissionConfirmation}; -use crate::providers::base::{MessageStream, PermissionRouting, Provider, ProviderUsage, Usage}; -use crate::providers::errors::ProviderError; +use crate::providers::base::{MessageStream, PermissionRouting, Provider}; use crate::subprocess::configure_subprocess; +use goose_providers::errors::ProviderError; /// Sentinel: resolved to the actual model name during connect(). pub const ACP_CURRENT_MODEL: &str = "current"; diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 6ab6e62501b6..d9d381d49711 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -43,7 +43,6 @@ use crate::permission::permission_inspector::PermissionInspector; use crate::permission::permission_judge::PermissionCheckResult; use crate::permission::PermissionConfirmation; use crate::providers::base::{PermissionRouting, Provider}; -use crate::providers::errors::ProviderError; use crate::recipe::{Author, Recipe, Response, Settings}; use crate::scheduler_trait::SchedulerTrait; use crate::security::adversary_inspector::AdversaryInspector; @@ -54,6 +53,7 @@ use crate::session::{Session, SessionManager, SessionNameUpdate}; use crate::tool_inspection::ToolInspectionManager; use crate::tool_monitor::RepetitionInspector; use crate::utils::is_token_cancelled; +use goose_providers::errors::ProviderError; use regex::Regex; use rmcp::model::{ CallToolRequestParams, CallToolResult, Content, ErrorCode, ErrorData, GetPromptResult, Prompt, @@ -2928,12 +2928,10 @@ mod tests { use super::*; use crate::permission::permission_confirmation::PrincipalType; use crate::plugins::discovery::{DiscoveredPlugin, PluginScope}; - use crate::providers::base::{ - stream_from_single_message, MessageStream, PermissionRouting, ProviderUsage, Usage, - }; - use crate::providers::errors::ProviderError; + use crate::providers::base::{stream_from_single_message, MessageStream, PermissionRouting}; use crate::recipe::Response; use crate::session::session_manager::SessionType; + use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; use rmcp::model::Tool; use std::path::PathBuf; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -2972,8 +2970,7 @@ mod tests { _: &str, _: &[crate::conversation::message::Message], _: &[rmcp::model::Tool], - ) -> Result - { + ) -> Result { unimplemented!() } fn permission_routing(&self) -> PermissionRouting { diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 89e7535a1d82..84c7db62a6fd 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -1,4 +1,5 @@ use anyhow::Result; +use goose_providers::errors::ProviderError; use regex::Regex; use std::sync::Arc; @@ -14,12 +15,12 @@ use crate::conversation::message::{Message, MessageContent, ToolRequest}; use crate::conversation::Conversation; #[cfg(test)] use crate::providers::base::stream_from_single_message; -use crate::providers::base::{MessageStream, Provider, ProviderUsage}; -use crate::providers::errors::ProviderError; +use crate::providers::base::{MessageStream, Provider}; use crate::providers::toolshim::{ augment_message_with_selected_tool_interpreter, convert_tool_messages_to_text, modify_system_prompt_for_tool_json, sanitize_residual_markers, }; +use goose_providers::conversation::token_usage::ProviderUsage; use rmcp::model::Tool; use tracing::warn; @@ -316,10 +317,6 @@ impl Agent { while let Some(result) = stream.next().await { let (msg_opt, usage_opt) = result?; - if let Some(usage) = usage_opt.as_ref() { - crate::providers::base::set_current_model(&usage.model); - } - if let Some(msg) = msg_opt { accumulated_message = Some(match accumulated_message { Some(mut prev) => { @@ -361,10 +358,6 @@ impl Agent { while let Some(result) = stream.next().await { let (message, usage) = result?; - if let Some(usage) = usage.as_ref() { - crate::providers::base::set_current_model(&usage.model); - } - yield (message, usage); } } @@ -630,10 +623,10 @@ mod tests { use crate::config::GooseMode; use crate::conversation::message::Message; use crate::model::ModelConfig; - use crate::providers::base::{Provider, ProviderUsage, Usage}; - use crate::providers::errors::ProviderError; + use crate::providers::base::Provider; use crate::session::session_manager::SessionType; use async_trait::async_trait; + use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; use rmcp::object; #[derive(Clone)] diff --git a/crates/goose/src/context_mgmt/mod.rs b/crates/goose/src/context_mgmt/mod.rs index 4d62daf8aebe..dbae784b4b82 100644 --- a/crates/goose/src/context_mgmt/mod.rs +++ b/crates/goose/src/context_mgmt/mod.rs @@ -2,12 +2,13 @@ use crate::conversation::message::{ActionRequiredData, MessageMetadata}; use crate::conversation::message::{Message, MessageContent}; use crate::conversation::{merge_consecutive_messages, Conversation}; use crate::prompt_template::render_template; +use crate::providers::base::Provider; #[cfg(test)] use crate::providers::base::{stream_from_single_message, MessageStream}; -use crate::providers::base::{Provider, ProviderUsage}; -use crate::providers::errors::ProviderError; use crate::{config::Config, token_counter::create_token_counter}; use anyhow::Result; +use goose_providers::conversation::token_usage::ProviderUsage; +use goose_providers::errors::ProviderError; use indoc::indoc; use rmcp::model::Role; use serde::Serialize; @@ -319,10 +320,15 @@ async fn do_compact( Ok((mut response, mut provider_usage)) => { response.role = Role::User; - provider_usage - .ensure_tokens(&system_prompt, &summarization_request, &response, &[]) - .await - .map_err(|e| anyhow::anyhow!("Failed to ensure usage tokens: {}", e))?; + crate::providers::usage_estimator::ensure_usage_tokens( + &mut provider_usage, + &system_prompt, + &summarization_request, + &response, + &[], + ) + .await + .map_err(|e| anyhow::anyhow!("Failed to ensure usage tokens: {}", e))?; return Ok((response, provider_usage)); } @@ -561,11 +567,10 @@ pub fn maybe_summarize_tool_pairs( #[cfg(test)] mod tests { use super::*; - use crate::{ - model::ModelConfig, - providers::{base::Usage, errors::ProviderError}, - }; + use crate::model::ModelConfig; use async_trait::async_trait; + use goose_providers::conversation::token_usage::Usage; + use goose_providers::errors::ProviderError; use rmcp::model::{AnnotateAble, CallToolRequestParams, RawContent, Tool}; fn create_tool_pair( diff --git a/crates/goose/src/doctor.rs b/crates/goose/src/doctor.rs index 2da8cadf334f..2fadd576dd33 100644 --- a/crates/goose/src/doctor.rs +++ b/crates/goose/src/doctor.rs @@ -5,11 +5,12 @@ use crate::agents::ExtensionConfig; use crate::config::Config; use crate::conversation::message::Message; use crate::model::ModelConfig; +use crate::providers; use crate::providers::base::Provider; -use crate::providers::{self, errors::ProviderError}; use crate::session::{ config_path, latest_llm_log_path, latest_server_log_path, read_capped, read_tail, SystemInfo, }; +use goose_providers::errors::ProviderError; pub async fn run(agent: &crate::agents::Agent, session_id: &str) -> anyhow::Result { if let Some(msg) = ensure_working_provider(agent, session_id).await? { diff --git a/crates/goose/src/execution/manager.rs b/crates/goose/src/execution/manager.rs index 3c4a7414db03..0e4604bc7480 100644 --- a/crates/goose/src/execution/manager.rs +++ b/crates/goose/src/execution/manager.rs @@ -610,8 +610,9 @@ mod tests { use crate::conversation::message::Message; use crate::model::ModelConfig; - use crate::providers::base::{MessageStream, Provider, ProviderUsage, Usage}; - use crate::providers::errors::ProviderError; + use crate::providers::base::{MessageStream, Provider}; + use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; + use goose_providers::errors::ProviderError; struct FailingProvider; diff --git a/crates/goose/src/model.rs b/crates/goose/src/model.rs index 8109705a6efd..739777cad882 100644 --- a/crates/goose/src/model.rs +++ b/crates/goose/src/model.rs @@ -1,51 +1,15 @@ +use goose_providers::formats::openai::{extract_reasoning_effort, is_openai_responses_model}; +use goose_providers::thinking::ThinkingEffort; use once_cell::sync::Lazy; use serde::de::Deserializer; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; -use std::fmt; -use std::str::FromStr; use thiserror::Error; use utoipa::ToSchema; pub const DEFAULT_CONTEXT_LIMIT: usize = 128_000; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)] -#[serde(rename_all = "lowercase")] -pub enum ThinkingEffort { - Off, - Low, - Medium, - High, - Max, -} - -impl FromStr for ThinkingEffort { - type Err = String; - fn from_str(s: &str) -> Result { - match s.to_lowercase().as_str() { - "off" | "disabled" | "none" => Ok(Self::Off), - "low" => Ok(Self::Low), - "medium" | "med" => Ok(Self::Medium), - "high" => Ok(Self::High), - "max" | "xhigh" => Ok(Self::Max), - other => Err(format!("unknown thinking effort: '{other}'")), - } - } -} - -impl fmt::Display for ThinkingEffort { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Off => write!(f, "off"), - Self::Low => write!(f, "low"), - Self::Medium => write!(f, "medium"), - Self::High => write!(f, "high"), - Self::Max => write!(f, "max"), - } - } -} - #[derive(Debug, Clone, Deserialize)] struct PredefinedModel { name: String, @@ -224,8 +188,7 @@ impl ModelConfig { let canonical = crate::providers::canonical::maybe_get_canonical_model(provider_name, &self.model_name) .or_else(|| { - let (base, _effort) = - crate::providers::utils::extract_reasoning_effort(&self.model_name); + let (base, _effort) = extract_reasoning_effort(&self.model_name); if base != self.model_name { crate::providers::canonical::maybe_get_canonical_model(provider_name, &base) } else { @@ -409,7 +372,7 @@ impl ModelConfig { } pub fn is_openai_reasoning_model(&self) -> bool { - crate::providers::utils::is_openai_responses_model(&self.model_name) + is_openai_responses_model(&self.model_name) } pub fn is_reasoning_model(&self) -> bool { @@ -878,7 +841,7 @@ mod tests { let config = ModelConfig::new_or_fail("gpt-4o").with_canonical_limits("openai"); assert_eq!(config.context_limit, Some(128_000)); - assert_eq!(config.max_tokens, Some(4_096)); + assert_eq!(config.max_tokens, Some(16_384)); assert_eq!(config.reasoning, Some(false)); } diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index c08161aae2bc..37f1290578b7 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -2,6 +2,7 @@ use anyhow::Result; use async_stream::try_stream; use async_trait::async_trait; use futures::TryStreamExt; +use goose_providers::errors::ProviderError; use reqwest::StatusCode; use serde_json::Value; use std::io; @@ -10,7 +11,6 @@ use tokio_util::io::StreamReader; use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, MessageStream, ModelInfo, Provider, ProviderDef, ProviderMetadata}; -use super::errors::ProviderError; use super::formats::anthropic::{ create_request_with_options, response_to_streaming_message, thinking_type, AnthropicFormatOptions, ThinkingType, diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index f79a3a3f3fb2..43f8a5939b86 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -2,6 +2,9 @@ use anyhow::Result; use async_trait::async_trait; use futures::future::BoxFuture; use futures::Stream; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; +use goose_providers::errors::ProviderError; +use regex::Regex; use serde::{Deserialize, Serialize}; /// Default HTTP timeout for all provider API calls. @@ -10,7 +13,6 @@ use serde::{Deserialize, Serialize}; pub const DEFAULT_PROVIDER_TIMEOUT_SECS: u64 = 600; use super::canonical::{map_to_canonical_model, CanonicalModelRegistry}; -use super::errors::ProviderError; use super::inventory::{default_inventory_identity, InventoryIdentityInput}; use super::retry::RetryConfig; use crate::config::base::ConfigValue; @@ -24,282 +26,12 @@ use rmcp::model::Tool; use utoipa::ToSchema; use once_cell::sync::Lazy; -use regex::Regex; -use std::ops::{Add, AddAssign}; use std::path::PathBuf; use std::pin::Pin; -use std::sync::LazyLock; -use std::sync::Mutex; - -#[derive(Debug, Default, PartialEq, Eq)] -pub struct FilterOut { - pub content: String, - pub thinking: String, -} - -pub struct ThinkFilter { - buffer: String, - inside_think: bool, - think_depth: usize, -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum ThinkTag { - Open, - Close, - // `` is XML-legal but carries no reasoning payload. Treat it as a - // no-op so we don't flip `inside_think` forever and swallow the rest of - // the stream into the thinking bucket. - SelfClosing, -} - -enum BufferEvent { - Tag { - pos: usize, - end: usize, - kind: ThinkTag, - }, - Partial(usize), -} - -impl ThinkFilter { - pub fn new() -> Self { - Self { - buffer: String::new(), - inside_think: false, - think_depth: 0, - } - } - - pub fn push(&mut self, chunk: &str) -> FilterOut { - self.buffer.push_str(chunk); - self.process_buffer() - } - - pub fn finish(mut self) -> FilterOut { - let mut out = self.process_buffer(); - if !self.buffer.is_empty() { - if self.inside_think { - out.thinking.push_str(&self.buffer); - } else { - out.content.push_str(&self.buffer); - } - self.buffer.clear(); - } - out - } - - fn process_buffer(&mut self) -> FilterOut { - let mut out = FilterOut::default(); - - loop { - match next_buffer_event(&self.buffer, self.inside_think) { - Some(BufferEvent::Tag { pos, end, kind }) => { - if pos > 0 { - let prefix = self.buffer.get(..pos).unwrap_or_default().to_string(); - if self.inside_think { - out.thinking.push_str(&prefix); - } else { - out.content.push_str(&prefix); - } - } - - self.buffer.drain(..end); - - match kind { - ThinkTag::Open => { - self.think_depth += 1; - self.inside_think = true; - } - ThinkTag::Close => { - self.think_depth = self.think_depth.saturating_sub(1); - self.inside_think = self.think_depth > 0; - } - ThinkTag::SelfClosing => {} - } - } - Some(BufferEvent::Partial(pos)) => { - if pos > 0 { - let prefix = self.buffer.get(..pos).unwrap_or_default().to_string(); - if self.inside_think { - out.thinking.push_str(&prefix); - } else { - out.content.push_str(&prefix); - } - self.buffer.drain(..pos); - } - break; - } - None => { - if !self.buffer.is_empty() { - if self.inside_think { - out.thinking.push_str(&self.buffer); - } else { - out.content.push_str(&self.buffer); - } - self.buffer.clear(); - } - break; - } - } - } - - out - } -} - -impl Default for ThinkFilter { - fn default() -> Self { - Self::new() - } -} - -pub fn split_think_blocks(text: &str) -> (String, String) { - let mut filter = ThinkFilter::new(); - let mut out = filter.push(text); - let final_out = filter.finish(); - out.content.push_str(&final_out.content); - out.thinking.push_str(&final_out.thinking); - (out.content, out.thinking) -} - -fn next_buffer_event(buffer: &str, inside_think: bool) -> Option { - let mut search_from = 0; - - while let Some(rel_pos) = buffer.get(search_from..).and_then(|rest| rest.find('<')) { - let pos = search_from + rel_pos; - let suffix = buffer.get(pos..).unwrap_or_default(); - - if let Some((kind, end)) = parse_think_tag(buffer, pos) { - if inside_think || matches!(kind, ThinkTag::Open | ThinkTag::SelfClosing) { - return Some(BufferEvent::Tag { pos, end, kind }); - } - } else if !contains_unquoted_gt(suffix) && is_possible_partial_think_tag(suffix) { - return Some(BufferEvent::Partial(pos)); - } - - search_from = pos + 1; - } - - None -} - -fn parse_think_tag(buffer: &str, start: usize) -> Option<(ThinkTag, usize)> { - let bytes = buffer.as_bytes(); - if bytes.get(start) != Some(&b'<') { - return None; - } - - let mut idx = start + 1; - let is_close = if bytes.get(idx) == Some(&b'/') { - idx += 1; - true - } else { - false - }; - - let name_start = idx; - while bytes.get(idx).is_some_and(u8::is_ascii_alphabetic) { - idx += 1; - } - - if idx == name_start { - return None; - } - - let name = buffer.get(name_start..idx).unwrap_or_default(); - let is_think = name.eq_ignore_ascii_case("think") || name.eq_ignore_ascii_case("thinking"); - if !is_think { - return None; - } - - if is_close { - while bytes.get(idx).is_some_and(u8::is_ascii_whitespace) { - idx += 1; - } - if bytes.get(idx) == Some(&b'>') { - return Some((ThinkTag::Close, idx + 1)); - } - return None; - } - - // Require a real tag boundary immediately after the name (>, /, or whitespace). - // Without this, `` or `` would be classified as a - // think tag and stripped from normal content. - let valid_open_boundary = match bytes.get(idx) { - Some(&b) => b == b'>' || b == b'/' || b.is_ascii_whitespace(), - None => false, - }; - if !valid_open_boundary { - return None; - } +use std::sync::{LazyLock, Mutex}; - let mut quote: Option = None; - let mut last_non_ws: Option = None; - while let Some(&byte) = bytes.get(idx) { - match quote { - Some(quote_byte) => { - if byte == quote_byte { - quote = None; - } - } - None if matches!(byte, b'"' | b'\'') => { - quote = Some(byte); - last_non_ws = Some(byte); - } - None if byte == b'>' => { - let kind = if last_non_ws == Some(b'/') { - ThinkTag::SelfClosing - } else { - ThinkTag::Open - }; - return Some((kind, idx + 1)); - } - None if !byte.is_ascii_whitespace() => { - last_non_ws = Some(byte); - } - None => {} - } - idx += 1; - } - - None -} - -fn is_possible_partial_think_tag(suffix: &str) -> bool { - if contains_unquoted_gt(suffix) { - return false; - } - - // Allow a trailing `/` so a chunk boundary that lands between `` in a self-closing `` (or ``) is still recognised - // as a partial tag and buffered until the `>` arrives in the next chunk. - static OPEN_RE: LazyLock = LazyLock::new(|| { - Regex::new(r"(?is)^<(?:t(?:h(?:i(?:n(?:k(?:i(?:n(?:g)?)?)?)?)?)?)?)(?:\s.*|/)?$").unwrap() - }); - static CLOSE_RE: LazyLock = LazyLock::new(|| { - Regex::new(r"(?is)^ bool { - let mut quote: Option = None; - for &byte in text.as_bytes() { - match quote { - Some(quote_byte) => { - if byte == quote_byte { - quote = None; - } - } - None if matches!(byte, b'"' | b'\'') => quote = Some(byte), - None if byte == b'>' => return true, - None => {} - } - } - false -} +/// A global store for the current model being used, we use this as when a provider returns, it tells us the real model, not an alias +pub static CURRENT_MODEL: Lazy>> = Lazy::new(|| Mutex::new(None)); fn strip_xml_tags(text: &str) -> String { static BLOCK_RE: LazyLock = LazyLock::new(|| { @@ -363,21 +95,6 @@ fn extract_short_title(text: &str) -> String { text.to_string() } -/// A global store for the current model being used, we use this as when a provider returns, it tells us the real model, not an alias -pub static CURRENT_MODEL: Lazy>> = Lazy::new(|| Mutex::new(None)); - -/// Set the current model in the global store -pub fn set_current_model(model: &str) { - if let Ok(mut current_model) = CURRENT_MODEL.lock() { - *current_model = Some(model.to_string()); - } -} - -/// Get the current model from the global store, the real model, not an alias -pub fn get_current_model() -> Option { - CURRENT_MODEL.lock().ok().and_then(|model| model.clone()) -} - pub static MSG_COUNT_FOR_SESSION_NAME_GENERATION: usize = 3; /// Information about a model's capabilities @@ -672,129 +389,6 @@ impl ConfigKey { } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ProviderUsage { - pub model: String, - pub usage: Usage, -} - -impl ProviderUsage { - pub fn new(model: String, usage: Usage) -> Self { - Self { model, usage } - } - - /// Ensures this ProviderUsage has token counts, estimating them if necessary - pub async fn ensure_tokens( - &mut self, - system_prompt: &str, - request_messages: &[Message], - response: &Message, - tools: &[Tool], - ) -> Result<(), ProviderError> { - crate::providers::usage_estimator::ensure_usage_tokens( - self, - system_prompt, - request_messages, - response, - tools, - ) - .await - .map_err(|e| ProviderError::ExecutionError(format!("Failed to ensure usage tokens: {}", e))) - } - - /// Combine this ProviderUsage with another, adding their token counts - /// Uses the model from this ProviderUsage - pub fn combine_with(&self, other: &ProviderUsage) -> ProviderUsage { - ProviderUsage { - model: self.model.clone(), - usage: self.usage + other.usage, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, Copy)] -pub struct Usage { - pub input_tokens: Option, - pub output_tokens: Option, - pub total_tokens: Option, - pub cache_read_input_tokens: Option, - pub cache_write_input_tokens: Option, -} - -fn sum_optionals(a: Option, b: Option) -> Option -where - T: Add + Default, -{ - match (a, b) { - (Some(x), Some(y)) => Some(x + y), - (Some(x), None) => Some(x + T::default()), - (None, Some(y)) => Some(T::default() + y), - (None, None) => None, - } -} - -impl Add for Usage { - type Output = Self; - - fn add(self, other: Self) -> Self { - Self::new( - sum_optionals(self.input_tokens, other.input_tokens), - sum_optionals(self.output_tokens, other.output_tokens), - sum_optionals(self.total_tokens, other.total_tokens), - ) - .with_cache_tokens( - sum_optionals(self.cache_read_input_tokens, other.cache_read_input_tokens), - sum_optionals( - self.cache_write_input_tokens, - other.cache_write_input_tokens, - ), - ) - } -} - -impl AddAssign for Usage { - fn add_assign(&mut self, rhs: Self) { - *self = *self + rhs; - } -} - -impl Usage { - pub fn new( - input_tokens: Option, - output_tokens: Option, - total_tokens: Option, - ) -> Self { - let calculated_total = if total_tokens.is_none() { - match (input_tokens, output_tokens) { - (Some(input), Some(output)) => Some(input + output), - (Some(input), None) => Some(input), - (None, Some(output)) => Some(output), - (None, None) => None, - } - } else { - total_tokens - }; - - Self { - input_tokens, - output_tokens, - total_tokens: calculated_total, - cache_read_input_tokens: None, - cache_write_input_tokens: None, - } - } - - pub fn with_cache_tokens( - mut self, - cache_read_input_tokens: Option, - cache_write_input_tokens: Option, - ) -> Self { - self.cache_read_input_tokens = cache_read_input_tokens; - self.cache_write_input_tokens = cache_write_input_tokens; - self - } -} - pub(crate) fn current_working_dir() -> PathBuf { std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")) } @@ -1277,8 +871,6 @@ mod tests { use std::collections::HashMap; use test_case::test_case; - use serde_json::json; - #[test] fn test_strip_xml_tags() { assert_eq!(strip_xml_tags("reasoninganswer"), "answer"); @@ -1306,278 +898,6 @@ mod tests { ); } - #[test] - fn test_split_think_blocks_extracts_inline_reasoning() { - assert_eq!( - split_think_blocks("xy"), - ("y".to_string(), "x".to_string()) - ); - } - - #[test] - fn test_split_think_blocks_is_case_insensitive() { - assert_eq!( - split_think_blocks("xy"), - ("y".to_string(), "x".to_string()) - ); - } - - #[test] - fn test_split_think_blocks_handles_multiple_blocks() { - assert_eq!( - split_think_blocks("abcd"), - ("bd".to_string(), "ac".to_string()) - ); - } - - #[test] - fn test_split_think_blocks_without_tags() { - assert_eq!( - split_think_blocks("plain content"), - ("plain content".to_string(), String::new()) - ); - } - - #[test] - fn test_split_think_blocks_handles_attributes() { - assert_eq!( - split_think_blocks(r#"ab"#), - ("b".to_string(), "a".to_string()) - ); - } - - #[test] - fn test_split_think_blocks_handles_quoted_gt_in_self_closing_attributes() { - for input in [ - r#"Visible"#, - "Visible", - ] { - assert_eq!( - split_think_blocks(input), - ("Visible".to_string(), String::new()), - "mismatch for {input:?}" - ); - } - } - - #[test] - fn test_split_think_blocks_handles_quoted_gt_in_open_attributes() { - assert_eq!( - split_think_blocks(r#"HiddenVisible"#), - ("Visible".to_string(), "Hidden".to_string()) - ); - } - - #[test] - fn test_split_think_blocks_handles_thinking_variant() { - assert_eq!( - split_think_blocks("ab"), - ("b".to_string(), "a".to_string()) - ); - } - - #[test] - fn test_think_filter_streaming_across_partial_tags() { - let mut filter = ThinkFilter::new(); - let mut out = FilterOut::default(); - - for chunk in ["xy"] { - let partial = filter.push(chunk); - out.content.push_str(&partial.content); - out.thinking.push_str(&partial.thinking); - } - - let final_out = filter.finish(); - out.content.push_str(&final_out.content); - out.thinking.push_str(&final_out.thinking); - - assert_eq!(out.content, "y"); - assert_eq!(out.thinking, "x"); - } - - #[test] - fn test_think_filter_preserves_non_think_tags() { - let mut filter = ThinkFilter::new(); - let mut out = filter.push("
"); - let final_out = filter.finish(); - out.content.push_str(&final_out.content); - out.thinking.push_str(&final_out.thinking); - - assert_eq!(out.content, "
"); - assert!(out.thinking.is_empty()); - } - - #[test] - fn test_think_filter_finish_treats_unterminated_think_as_thinking() { - let mut filter = ThinkFilter::new(); - let mut out = filter.push("unfinished"); - let final_out = filter.finish(); - out.content.push_str(&final_out.content); - out.thinking.push_str(&final_out.thinking); - - assert!(out.content.is_empty()); - assert_eq!(out.thinking, "unfinished"); - } - - #[test] - fn test_think_filter_tracks_generation_prompt_open_block() { - let mut filter = ThinkFilter::new(); - let _ = filter.push("<|assistant|>\n"); - let mut out = filter.push("hidden reasoningvisible answer"); - let final_out = filter.finish(); - out.content.push_str(&final_out.content); - out.thinking.push_str(&final_out.thinking); - - assert_eq!(out.content, "visible answer"); - assert_eq!(out.thinking, "hidden reasoning"); - } - - #[test] - fn test_think_filter_preserves_tags_with_think_prefix() { - for input in [ - "hello", - "payload", - "note", - ] { - let mut filter = ThinkFilter::new(); - let mut out = filter.push(input); - let final_out = filter.finish(); - out.content.push_str(&final_out.content); - out.thinking.push_str(&final_out.thinking); - - assert_eq!(out.content, input, "content mismatch for {input:?}"); - assert!( - out.thinking.is_empty(), - "unexpected thinking for {input:?}: {:?}", - out.thinking - ); - } - } - - #[test] - fn test_think_filter_accepts_think_with_attributes() { - let mut filter = ThinkFilter::new(); - let mut out = filter.push("hiddenvisible"); - let final_out = filter.finish(); - out.content.push_str(&final_out.content); - out.thinking.push_str(&final_out.thinking); - - assert_eq!(out.content, "visible"); - assert_eq!(out.thinking, "hidden"); - } - - #[test] - fn test_think_filter_treats_self_closing_as_noop() { - // `` carries no reasoning payload. It must not flip the filter - // into "inside_think" mode, and the tag itself must not leak into - // visible content. - for input in [ - "before after", - "before after", - "before after", - "before after", - ] { - let mut filter = ThinkFilter::new(); - let mut out = filter.push(input); - let final_out = filter.finish(); - out.content.push_str(&final_out.content); - out.thinking.push_str(&final_out.thinking); - - assert_eq!( - out.content, "before after", - "content mismatch for {input:?}" - ); - assert!( - out.thinking.is_empty(), - "unexpected thinking for {input:?}: {:?}", - out.thinking - ); - } - } - - #[test] - fn test_think_filter_self_closing_does_not_swallow_following_content() { - // Regression: a self-closing `` used to be classified as an - // Open tag, which incremented think_depth and routed everything after - // it into the thinking bucket for the rest of the stream. - let mut filter = ThinkFilter::new(); - let mut out = filter.push("visible chunk 1"); - let final_out = filter.push("visible chunk 2"); - let tail_out = filter.finish(); - out.content.push_str(&final_out.content); - out.thinking.push_str(&final_out.thinking); - out.content.push_str(&tail_out.content); - out.thinking.push_str(&tail_out.thinking); - - assert_eq!(out.content, "visible chunk 1visible chunk 2"); - assert!(out.thinking.is_empty()); - } - - #[test] - fn test_think_filter_streaming_across_self_closing_boundary() { - // Regression: a chunk boundary between `` in a - // self-closing `` used to fall out of the partial-tag regex - // (which only allowed `...`), so the `` arrived. - for (a, b) in [ - ("before after"), - ("before after"), - ("head tail"), - ] { - let mut filter = ThinkFilter::new(); - let mut out = filter.push(a); - let second = filter.push(b); - let final_out = filter.finish(); - out.content.push_str(&second.content); - out.content.push_str(&final_out.content); - out.thinking.push_str(&second.thinking); - out.thinking.push_str(&final_out.thinking); - - assert!( - !out.content.contains('<'), - "partial tag leaked into content for ({a:?}, {b:?}): {:?}", - out.content - ); - assert!( - out.thinking.is_empty(), - "unexpected thinking for ({a:?}, {b:?}): {:?}", - out.thinking - ); - } - } - - #[test] - fn test_think_filter_streaming_across_quoted_attribute_boundary() { - let mut filter = ThinkFilter::new(); - let mut out = filter.push(r#"Visible"#); - let final_out = filter.finish(); - out.content.push_str(&second.content); - out.content.push_str(&final_out.content); - out.thinking.push_str(&second.thinking); - out.thinking.push_str(&final_out.thinking); - - assert_eq!(out.content, "Visible"); - assert!(out.thinking.is_empty()); - } - - #[test] - fn test_think_filter_self_closing_inside_open_block_closes_nothing() { - // `` inside an open `` block is still a no-op: depth - // should stay at 1 until the real `` arrives. - let mut filter = ThinkFilter::new(); - let mut out = filter.push("before hidden1 hidden2visible"); - let final_out = filter.finish(); - out.content.push_str(&final_out.content); - out.thinking.push_str(&final_out.thinking); - - assert_eq!(out.content, "before visible"); - assert_eq!(out.thinking, "hidden1 hidden2"); - } - #[test] fn test_extract_short_title() { assert_eq!(extract_short_title("List files"), "List files"); @@ -1723,42 +1043,6 @@ mod tests { assert_eq!(usage.model, "unknown"); } - #[test] - fn test_usage_serialization() -> Result<()> { - let usage = Usage::new(Some(10), Some(20), Some(30)); - let serialized = serde_json::to_string(&usage)?; - let deserialized: Usage = serde_json::from_str(&serialized)?; - - assert_eq!(usage.input_tokens, deserialized.input_tokens); - assert_eq!(usage.output_tokens, deserialized.output_tokens); - assert_eq!(usage.total_tokens, deserialized.total_tokens); - - // Test JSON structure - let json_value: serde_json::Value = serde_json::from_str(&serialized)?; - assert_eq!(json_value["input_tokens"], json!(10)); - assert_eq!(json_value["output_tokens"], json!(20)); - assert_eq!(json_value["total_tokens"], json!(30)); - - Ok(()) - } - - #[test] - fn test_set_and_get_current_model() { - // Set the model - set_current_model("gpt-4o"); - - // Get the model and verify - let model = get_current_model(); - assert_eq!(model, Some("gpt-4o".to_string())); - - // Change the model - set_current_model("claude-sonnet-4-20250514"); - - // Get the updated model and verify - let model = get_current_model(); - assert_eq!(model, Some("claude-sonnet-4-20250514".to_string())); - } - #[test] fn test_provider_metadata_context_limits() { // Test that ProviderMetadata::new correctly sets context limits @@ -1843,19 +1127,4 @@ mod tests { assert_eq!(info.output_token_cost, Some(0.00001)); assert_eq!(info.currency, Some("$".to_string())); } - - #[test] - fn test_usage_addition_includes_cached_tokens() { - let usage_a = - Usage::new(Some(100), Some(20), Some(120)).with_cache_tokens(Some(10), Some(5)); - let usage_b = Usage::new(Some(50), Some(8), Some(58)).with_cache_tokens(Some(4), Some(1)); - - let combined = usage_a + usage_b; - - assert_eq!(combined.input_tokens, Some(150)); - assert_eq!(combined.output_tokens, Some(28)); - assert_eq!(combined.total_tokens, Some(178)); - assert_eq!(combined.cache_read_input_tokens, Some(14)); - assert_eq!(combined.cache_write_input_tokens, Some(6)); - } } diff --git a/crates/goose/src/providers/bedrock.rs b/crates/goose/src/providers/bedrock.rs index 13be0e79554d..256d21dc03f7 100644 --- a/crates/goose/src/providers/bedrock.rs +++ b/crates/goose/src/providers/bedrock.rs @@ -1,9 +1,6 @@ use std::collections::HashMap; -use super::base::{ - ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, -}; -use super::errors::ProviderError; +use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; use super::retry::{ProviderRetry, RetryConfig}; use crate::conversation::message::Message; use crate::model::ModelConfig; @@ -14,6 +11,8 @@ use aws_sdk_bedrockruntime::config::ProvideCredentials; use aws_sdk_bedrockruntime::operation::converse::ConverseError; use aws_sdk_bedrockruntime::{types as bedrock, Client}; use futures::future::BoxFuture; +use goose_providers::conversation::token_usage::ProviderUsage; +use goose_providers::errors::ProviderError; use reqwest::header::HeaderValue; use rmcp::model::Tool; use serde_json::Value; diff --git a/crates/goose/src/providers/chatgpt_codex.rs b/crates/goose/src/providers/chatgpt_codex.rs index 17cd31bc07fe..65ddca08c4b7 100644 --- a/crates/goose/src/providers/chatgpt_codex.rs +++ b/crates/goose/src/providers/chatgpt_codex.rs @@ -3,7 +3,6 @@ use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::providers::api_client::AuthProvider; use crate::providers::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; -use crate::providers::errors::ProviderError; use crate::providers::formats::openai_responses::responses_api_to_streaming_message; use crate::providers::openai_compatible::handle_status; use crate::providers::retry::ProviderRetry; @@ -16,6 +15,7 @@ use base64::Engine; use chrono::{DateTime, Utc}; use futures::future::BoxFuture; use futures::{StreamExt, TryStreamExt}; +use goose_providers::errors::ProviderError; use jsonwebtoken::jwk::JwkSet; use jsonwebtoken::{decode, decode_header, DecodingKey, Validation}; use reqwest::header::{HeaderName, HeaderValue}; @@ -230,7 +230,7 @@ fn get_reasoning_effort(model_name: &str) -> String { } fn reasoning_effort_for_config(model_config: &ModelConfig) -> Option { - use crate::model::ThinkingEffort; + use goose_providers::thinking::ThinkingEffort; model_config .thinking_effort() diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index e8957c965912..3899a642e146 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -2,6 +2,8 @@ use anyhow::Result; use async_stream::try_stream; use async_trait::async_trait; use futures::future::BoxFuture; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; +use goose_providers::errors::ProviderError; use rmcp::model::{Role, Tool}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -17,9 +19,8 @@ use tokio::sync::oneshot; use super::base::{ stream_from_single_message, ConfigKey, MessageStream, PermissionRouting, Provider, ProviderDef, - ProviderMetadata, ProviderUsage, Usage, + ProviderMetadata, }; -use super::errors::ProviderError; use super::utils::filter_extensions_from_system_prompt; use crate::config::base::ClaudeCodeCommand; use crate::config::paths::Paths; diff --git a/crates/goose/src/providers/cli_common.rs b/crates/goose/src/providers/cli_common.rs index f2531ec9f5a3..fdfc89ee2201 100644 --- a/crates/goose/src/providers/cli_common.rs +++ b/crates/goose/src/providers/cli_common.rs @@ -1,9 +1,9 @@ use serde_json::Value; -use super::base::{ProviderUsage, Usage}; -use super::errors::ProviderError; use crate::conversation::message::{Message, MessageContent}; use crate::utils::safe_truncate; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; +use goose_providers::errors::ProviderError; use rmcp::model::Role; pub(crate) fn extract_usage_tokens(usage_info: &Value) -> Usage { diff --git a/crates/goose/src/providers/codex.rs b/crates/goose/src/providers/codex.rs index 8b420d4146e3..094c3e10b058 100644 --- a/crates/goose/src/providers/codex.rs +++ b/crates/goose/src/providers/codex.rs @@ -2,6 +2,8 @@ use anyhow::Result; use async_trait::async_trait; use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use futures::future::BoxFuture; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; +use goose_providers::thinking::ThinkingEffort; use serde_json::json; use std::collections::HashMap; use std::io::Write; @@ -11,10 +13,7 @@ use tempfile::NamedTempFile; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; -use super::base::{ - ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, -}; -use super::errors::ProviderError; +use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; use super::utils::{filter_extensions_from_system_prompt, RequestLog}; use crate::config::base::{CodexCommand, CodexSkipGitCheck}; use crate::config::paths::Paths; @@ -23,6 +22,7 @@ use crate::config::{Config, ExtensionConfig, GooseMode}; use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::subprocess::configure_subprocess; +use goose_providers::errors::ProviderError; use rmcp::model::Role; use rmcp::model::Tool; @@ -60,25 +60,22 @@ pub struct CodexProvider { } impl CodexProvider { - fn legacy_reasoning_effort() -> Option { + fn legacy_reasoning_effort() -> Option { Config::global() .get_param::("CODEX_REASONING_EFFORT") .ok() .and_then(|effort| match effort.to_lowercase().as_str() { - "none" => Some(crate::model::ThinkingEffort::Off), - "low" => Some(crate::model::ThinkingEffort::Low), - "medium" => Some(crate::model::ThinkingEffort::Medium), - "high" => Some(crate::model::ThinkingEffort::High), - "xhigh" => Some(crate::model::ThinkingEffort::Max), + "none" => Some(ThinkingEffort::Off), + "low" => Some(ThinkingEffort::Low), + "medium" => Some(ThinkingEffort::Medium), + "high" => Some(ThinkingEffort::High), + "xhigh" => Some(ThinkingEffort::Max), _ => None, }) } - fn map_thinking_effort( - _model_name: &str, - effort: Option, - ) -> Option { - use crate::model::ThinkingEffort; + fn map_thinking_effort(_model_name: &str, effort: Option) -> Option { + use ThinkingEffort; match effort .or_else(Self::legacy_reasoning_effort) .unwrap_or(ThinkingEffort::High) @@ -1238,7 +1235,7 @@ mod tests { #[test] fn test_map_thinking_effort() { - use crate::model::ThinkingEffort; + use ThinkingEffort; let _guard = env_lock::lock_env([ ("CODEX_REASONING_EFFORT", None::<&str>), diff --git a/crates/goose/src/providers/cursor_agent.rs b/crates/goose/src/providers/cursor_agent.rs index b9ef06e14e97..12477f1c722d 100644 --- a/crates/goose/src/providers/cursor_agent.rs +++ b/crates/goose/src/providers/cursor_agent.rs @@ -9,9 +9,7 @@ use tokio::process::Command; use super::base::{ stream_from_single_message, ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, - ProviderUsage, Usage, }; -use super::errors::ProviderError; use super::utils::{filter_extensions_from_system_prompt, RequestLog}; use crate::config::base::CursorAgentCommand; use crate::config::search_path::SearchPaths; @@ -19,6 +17,8 @@ use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::subprocess::configure_subprocess; use futures::future::BoxFuture; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; +use goose_providers::errors::ProviderError; use rmcp::model::Tool; const CURSOR_AGENT_PROVIDER_NAME: &str = "cursor-agent"; diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 76b9bc45942a..dbde81aff32b 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -1,6 +1,10 @@ use anyhow::Result; use async_trait::async_trait; use futures::future::BoxFuture; +use goose_providers::formats::openai::{ + extract_reasoning_effort, is_openai_responses_model, openai_reasoning_effort_for_thinking, +}; +use goose_providers::images::ImageFormat; use serde_json::Value; use std::collections::HashSet; use std::sync::LazyLock; @@ -14,7 +18,6 @@ use super::base::{ }; use super::databricks_auth::{DatabricksAuth, DatabricksAuthProvider}; use super::embedding::EmbeddingCapable; -use super::errors::ProviderError; use super::formats::databricks::create_request; use super::formats::openai_responses::create_responses_request; use super::openai_compatible::{ @@ -22,7 +25,7 @@ use super::openai_compatible::{ stream_openai_compat, stream_responses_compat, }; use super::retry::ProviderRetry; -use super::utils::{is_openai_responses_model, ImageFormat, RequestLog}; +use super::utils::RequestLog; use crate::config::ConfigError; use crate::conversation::message::Message; use crate::instance_id::get_instance_id; @@ -31,6 +34,7 @@ use crate::providers::retry::{ RetryConfig, DEFAULT_BACKOFF_MULTIPLIER, DEFAULT_INITIAL_RETRY_INTERVAL_MS, DEFAULT_MAX_RETRIES, DEFAULT_MAX_RETRY_INTERVAL_MS, }; +use goose_providers::errors::ProviderError; use rmcp::model::Tool; use serde_json::json; @@ -485,7 +489,7 @@ impl DatabricksProvider { if is_embedding { "serving-endpoints/text-embedding-3-small/invocations".to_string() } else { - let (clean_name, _) = super::utils::extract_reasoning_effort(model_name); + let (clean_name, _) = extract_reasoning_effort(model_name); if Self::is_responses_model(&clean_name) { "serving-endpoints/responses".to_string() } else { @@ -586,7 +590,7 @@ impl Provider for DatabricksProvider { messages: &[Message], tools: &[Tool], ) -> Result { - let (endpoint_name, _) = super::utils::extract_reasoning_effort(&model_config.model_name); + let (endpoint_name, _) = extract_reasoning_effort(&model_config.model_name); let endpoint_info = self.resolve_endpoint_info_cached(&endpoint_name).await.ok(); let effective_model_name = endpoint_info .as_ref() @@ -618,7 +622,7 @@ impl Provider for DatabricksProvider { payload["model"] = Value::String(endpoint_name.clone()); if payload.get("reasoning").is_none() { if let Some(effort) = model_config.thinking_effort().and_then(|effort| { - super::utils::openai_reasoning_effort_for_thinking(effective_model_name, effort) + openai_reasoning_effort_for_thinking(effective_model_name, effort) }) { payload.as_object_mut().unwrap().insert( "reasoning".to_string(), @@ -821,7 +825,7 @@ impl Provider for DatabricksProvider { } async fn fetch_model_info(&self, model_name: &str) -> Result { - let (endpoint_name, _) = super::utils::extract_reasoning_effort(model_name); + let (endpoint_name, _) = extract_reasoning_effort(model_name); let endpoint_info = self.resolve_endpoint_info_cached(&endpoint_name).await?; Ok(Self::model_info_from_endpoint(endpoint_info)) } diff --git a/crates/goose/src/providers/databricks_v2.rs b/crates/goose/src/providers/databricks_v2.rs index 2385ae9bde8c..15b715ef4dc4 100644 --- a/crates/goose/src/providers/databricks_v2.rs +++ b/crates/goose/src/providers/databricks_v2.rs @@ -3,6 +3,10 @@ use async_stream::try_stream; use async_trait::async_trait; use futures::future::BoxFuture; use futures::TryStreamExt; +use goose_providers::formats::openai::{ + self, extract_reasoning_effort, is_openai_responses_model, ModelConfigParams, +}; +use goose_providers::images::ImageFormat; use serde::Serialize; use serde_json::Value; use std::io; @@ -17,11 +21,10 @@ use super::base::{ DEFAULT_PROVIDER_TIMEOUT_SECS, }; use super::databricks_auth::{DatabricksAuth, DatabricksAuthProvider}; -use super::errors::ProviderError; -use super::formats::{anthropic, openai, openai_responses}; +use super::formats::{anthropic, openai_responses}; use super::openai_compatible::{handle_status, stream_openai_compat, stream_responses_compat}; use super::retry::ProviderRetry; -use super::utils::{extract_reasoning_effort, is_openai_responses_model, ImageFormat, RequestLog}; +use super::utils::RequestLog; use crate::config::ConfigError; use crate::conversation::message::Message; use crate::model::ModelConfig; @@ -29,6 +32,7 @@ use crate::providers::retry::{ RetryConfig, DEFAULT_BACKOFF_MULTIPLIER, DEFAULT_INITIAL_RETRY_INTERVAL_MS, DEFAULT_MAX_RETRIES, DEFAULT_MAX_RETRY_INTERVAL_MS, }; +use goose_providers::errors::ProviderError; use rmcp::model::Tool; const DATABRICKS_V2_PROVIDER_NAME: &str = "databricks_v2"; @@ -247,7 +251,13 @@ impl DatabricksV2Provider { tools: &[Tool], ) -> Result { let mut payload = openai::create_request( - model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, system, messages, tools, diff --git a/crates/goose/src/providers/formats/anthropic.rs b/crates/goose/src/providers/formats/anthropic.rs index d39b0bacc465..ac7b50675d7a 100644 --- a/crates/goose/src/providers/formats/anthropic.rs +++ b/crates/goose/src/providers/formats/anthropic.rs @@ -1,10 +1,11 @@ use crate::conversation::message::{Message, MessageContent}; use crate::mcp_utils::extract_text_from_resource; -use crate::model::{ModelConfig, ThinkingEffort}; -use crate::providers::base::Usage; -use crate::providers::errors::ProviderError; -use crate::providers::utils::{convert_image, ImageFormat}; +use crate::model::ModelConfig; use anyhow::{anyhow, Result}; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; +use goose_providers::errors::ProviderError; +use goose_providers::images::{convert_image, ImageFormat}; +use goose_providers::thinking::ThinkingEffort; use rmcp::model::{object, CallToolRequestParams, ErrorCode, ErrorData, JsonObject, Role, Tool}; use rmcp::object as json_object; use serde_json::{json, Value}; @@ -484,7 +485,7 @@ pub fn get_usage(data: &Value) -> Result { let total_tokens_i32 = (total_input_i32 as i64 + output_tokens_i32 as i64).min(i32::MAX as i64) as i32; - tracing::debug!("🔍 Anthropic ACTUAL token counts from direct object: input={}, output={}, total={}", + tracing::debug!("🔍 Anthropic ACTUAL token counts from direct object: input={}, output={}, total={}", total_input_i32, output_tokens_i32, total_tokens_i32); Ok(Usage::new( @@ -663,12 +664,7 @@ pub fn create_request_with_options( /// Process streaming response from Anthropic's API pub fn response_to_streaming_message( mut stream: S, -) -> impl futures::Stream< - Item = anyhow::Result<( - Option, - Option, - )>, -> + 'static +) -> impl futures::Stream, Option)>> + 'static where S: futures::Stream> + Unpin + Send + 'static, { @@ -702,7 +698,7 @@ where try_stream! { let mut accumulated_tool_calls: std::collections::HashMap = std::collections::HashMap::new(); let mut current_tool_id: Option = None; - let mut final_usage: Option = None; + let mut final_usage: Option = None; let mut message_id: Option = None; let mut thinking: Option = None; @@ -744,7 +740,7 @@ where .and_then(|v| v.as_str()) .unwrap_or("unknown") .to_string(); - final_usage = Some(crate::providers::base::ProviderUsage::new(model, usage)); + final_usage = Some(ProviderUsage::new(model, usage)); } } continue; @@ -884,14 +880,14 @@ where (None, None) => None, }; - let merged_usage = crate::providers::base::Usage::new(merged_input, merged_output, merged_total); - final_usage = Some(crate::providers::base::ProviderUsage::new(existing_usage.model.clone(), merged_usage)); + let merged_usage = Usage::new(merged_input, merged_output, merged_total); + final_usage = Some(ProviderUsage::new(existing_usage.model.clone(), merged_usage)); } else { let model = event.data.get("model") .and_then(|v| v.as_str()) .unwrap_or("unknown") .to_string(); - final_usage = Some(crate::providers::base::ProviderUsage::new(model, delta_usage)); + final_usage = Some(ProviderUsage::new(model, delta_usage)); } } continue; @@ -903,7 +899,7 @@ where .and_then(|v| v.as_str()) .unwrap_or("unknown") .to_string(); - final_usage = Some(crate::providers::base::ProviderUsage::new(model, usage)); + final_usage = Some(ProviderUsage::new(model, usage)); } break; } diff --git a/crates/goose/src/providers/formats/bedrock.rs b/crates/goose/src/providers/formats/bedrock.rs index 23ed0c0698c4..7c531761db78 100644 --- a/crates/goose/src/providers/formats/bedrock.rs +++ b/crates/goose/src/providers/formats/bedrock.rs @@ -14,8 +14,8 @@ use rmcp::model::{ }; use serde_json::Value; -use super::super::base::Usage; use crate::conversation::message::{Message, MessageContent}; +use goose_providers::conversation::token_usage::Usage; pub fn to_bedrock_message_with_caching( message: &Message, diff --git a/crates/goose/src/providers/formats/databricks.rs b/crates/goose/src/providers/formats/databricks.rs index 38d84111b942..aaa352ae6c12 100644 --- a/crates/goose/src/providers/formats/databricks.rs +++ b/crates/goose/src/providers/formats/databricks.rs @@ -3,12 +3,14 @@ use crate::model::ModelConfig; use crate::providers::formats::anthropic::{ thinking_budget_tokens, thinking_effort, thinking_type, ThinkingType, }; -use crate::providers::utils::{ - convert_image, detect_image_path, extract_reasoning_effort, is_openai_responses_model, - is_valid_function_name, load_image_file, openai_reasoning_effort_for_thinking, - safely_parse_json, sanitize_function_name, ImageFormat, -}; + use anyhow::{anyhow, Error}; +use goose_providers::formats::openai::{ + extract_reasoning_effort, is_openai_responses_model, is_valid_function_name, + openai_reasoning_effort_for_thinking, sanitize_function_name, +}; +use goose_providers::images::{convert_image, detect_image_path, load_image_file, ImageFormat}; +use goose_providers::json::safely_parse_json; use rmcp::model::{ object, AnnotateAble, CallToolRequestParams, Content, ErrorCode, ErrorData, RawContent, ResourceContents, Role, Tool, diff --git a/crates/goose/src/providers/formats/gcpvertexai.rs b/crates/goose/src/providers/formats/gcpvertexai.rs index 31d70bf334b7..6e064eee0ec1 100644 --- a/crates/goose/src/providers/formats/gcpvertexai.rs +++ b/crates/goose/src/providers/formats/gcpvertexai.rs @@ -1,8 +1,8 @@ use super::{anthropic, google}; use crate::conversation::message::Message; use crate::model::ModelConfig; -use crate::providers::base::Usage; use anyhow::{Context, Result}; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; use rmcp::model::Tool; use serde_json::Value; @@ -10,12 +10,8 @@ use std::fmt; pub type StreamingMessageStream = std::pin::Pin< Box< - dyn futures::Stream< - Item = anyhow::Result<( - Option, - Option, - )>, - > + Send + dyn futures::Stream, Option)>> + + Send + 'static, >, >; diff --git a/crates/goose/src/providers/formats/google.rs b/crates/goose/src/providers/formats/google.rs index b35c2db504a0..ad232b6b0b0d 100644 --- a/crates/goose/src/providers/formats/google.rs +++ b/crates/goose/src/providers/formats/google.rs @@ -1,8 +1,9 @@ use crate::model::ModelConfig; -use crate::providers::base::Usage; -use crate::providers::errors::ProviderError; -use crate::providers::utils::{is_valid_function_name, sanitize_function_name}; use anyhow::Result; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; +use goose_providers::errors::ProviderError; +use goose_providers::formats::openai::{is_valid_function_name, sanitize_function_name}; +use goose_providers::thinking::ThinkingEffort; use rmcp::model::{ object, AnnotateAble, CallToolRequestParams, ErrorCode, ErrorData, RawContent, Role, Tool, }; @@ -359,12 +360,7 @@ pub fn get_usage(data: &Value) -> Result { pub fn response_to_streaming_message( mut stream: S, -) -> impl futures::Stream< - Item = anyhow::Result<( - Option, - Option, - )>, -> + 'static +) -> impl futures::Stream, Option)>> + 'static where S: futures::Stream> + Unpin + Send + 'static, { @@ -372,7 +368,7 @@ where use futures::StreamExt; try_stream! { - let mut final_usage: Option = None; + let mut final_usage: Option = None; let mut last_signature: Option = None; let stream_id = Uuid::new_v4().to_string(); let mut incomplete_data: Option = None; @@ -446,7 +442,7 @@ where .and_then(|v| v.as_str()) .unwrap_or("unknown") .to_string(); - final_usage = Some(crate::providers::base::ProviderUsage::new(model, usage)); + final_usage = Some(ProviderUsage::new(model, usage)); } } @@ -542,7 +538,6 @@ fn get_thinking_config(model_config: &ModelConfig) -> Option { } if is_gemini_3 { - use crate::model::ThinkingEffort; let effort = model_config .thinking_effort() .unwrap_or(ThinkingEffort::Off); diff --git a/crates/goose/src/providers/formats/mod.rs b/crates/goose/src/providers/formats/mod.rs index 7e7218f4d6da..abfbf25a60e2 100644 --- a/crates/goose/src/providers/formats/mod.rs +++ b/crates/goose/src/providers/formats/mod.rs @@ -5,7 +5,6 @@ pub mod databricks; pub mod gcpvertexai; pub mod google; pub mod ollama; -pub mod openai; pub mod openai_responses; pub mod openrouter; pub mod snowflake; diff --git a/crates/goose/src/providers/formats/ollama.rs b/crates/goose/src/providers/formats/ollama.rs index eb2c8ced32f4..77e463449365 100644 --- a/crates/goose/src/providers/formats/ollama.rs +++ b/crates/goose/src/providers/formats/ollama.rs @@ -10,18 +10,20 @@ //! - qwen3-coder-32b use crate::conversation::message::{Message, MessageContent}; -use crate::providers::base::ProviderUsage; -use crate::providers::utils::is_valid_function_name; use async_stream::try_stream; use chrono; use futures::Stream; +use goose_providers::{ + conversation::token_usage::ProviderUsage, + formats::openai::{self, is_valid_function_name}, +}; use regex::Regex; use rmcp::model::{object, CallToolRequestParams, ErrorCode, ErrorData, Role}; use serde_json::Value; use std::borrow::Cow; use uuid::Uuid; -pub use super::openai::{ +pub use goose_providers::formats::openai::{ create_request, format_messages, format_tools, get_usage, validate_tool_schemas, }; @@ -82,7 +84,7 @@ pub fn parse_xml_tool_calls(content: &str) -> (Option, Vec anyhow::Result { - let message = super::openai::response_to_message(response)?; + let message = openai::response_to_message(response)?; let has_tool_requests = message .content @@ -163,7 +165,7 @@ where try_stream! { use futures::StreamExt; - let base_stream = super::openai::response_to_streaming_message(stream); + let base_stream = openai::response_to_streaming_message(stream); let mut base_stream = std::pin::pin!(base_stream); let mut accumulated_text = String::new(); diff --git a/crates/goose/src/providers/formats/openai_responses.rs b/crates/goose/src/providers/formats/openai_responses.rs index 167a1c229288..bcb64d544f8d 100644 --- a/crates/goose/src/providers/formats/openai_responses.rs +++ b/crates/goose/src/providers/formats/openai_responses.rs @@ -1,14 +1,14 @@ use crate::conversation::message::{Message, MessageContent}; use crate::mcp_utils::extract_text_from_resource; use crate::model::ModelConfig; -use crate::providers::base::{ProviderUsage, Usage}; -use crate::providers::utils::{ - extract_reasoning_effort, is_openai_responses_model, openai_reasoning_effort_for_thinking, -}; use anyhow::{anyhow, Error}; use async_stream::try_stream; use chrono; use futures::Stream; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; +use goose_providers::formats::openai::{ + extract_reasoning_effort, is_openai_responses_model, openai_reasoning_effort_for_thinking, +}; use rmcp::model::{object, CallToolRequestParams, RawContent, Role, Tool}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; diff --git a/crates/goose/src/providers/formats/openrouter.rs b/crates/goose/src/providers/formats/openrouter.rs index 22ac7465b252..a6d68e0c46d9 100644 --- a/crates/goose/src/providers/formats/openrouter.rs +++ b/crates/goose/src/providers/formats/openrouter.rs @@ -1,6 +1,7 @@ use crate::conversation::message::{Message, MessageContent, ProviderMetadata}; -use crate::model::{ModelConfig, ThinkingEffort}; -use crate::providers::formats::openai; +use crate::model::ModelConfig; +use goose_providers::formats::openai; +use goose_providers::thinking::ThinkingEffort; use rmcp::model::Role; use serde_json::{json, Value}; diff --git a/crates/goose/src/providers/formats/snowflake.rs b/crates/goose/src/providers/formats/snowflake.rs index 34f63af29da2..011c99b15197 100644 --- a/crates/goose/src/providers/formats/snowflake.rs +++ b/crates/goose/src/providers/formats/snowflake.rs @@ -1,9 +1,9 @@ use crate::conversation::message::{Message, MessageContent}; use crate::mcp_utils::extract_text_from_resource; use crate::model::ModelConfig; -use crate::providers::base::Usage; -use crate::providers::errors::ProviderError; use anyhow::{anyhow, Result}; +use goose_providers::conversation::token_usage::Usage; +use goose_providers::errors::ProviderError; use rmcp::model::{object, CallToolRequestParams, Role, Tool}; use rmcp::object; use serde_json::{json, Value}; diff --git a/crates/goose/src/providers/gcpvertexai.rs b/crates/goose/src/providers/gcpvertexai.rs index 08a7589c2ce1..8f215b925aea 100644 --- a/crates/goose/src/providers/gcpvertexai.rs +++ b/crates/goose/src/providers/gcpvertexai.rs @@ -21,7 +21,6 @@ use crate::providers::base::{ DEFAULT_PROVIDER_TIMEOUT_SECS, }; -use crate::providers::errors::ProviderError; use crate::providers::formats::gcpvertexai::{ create_request, response_to_streaming_message, GcpLocation, ModelProvider, RequestContext, DEFAULT_MODEL, KNOWN_MODELS, @@ -31,6 +30,7 @@ use crate::providers::openai_compatible::{map_http_error_to_provider_error, sani use crate::providers::retry::RetryConfig; use crate::providers::utils::RequestLog; use crate::session_context::SESSION_ID_HEADER; +use goose_providers::errors::ProviderError; use rmcp::model::Tool; const GCP_VERTEX_AI_PROVIDER_NAME: &str = "gcp_vertex_ai"; diff --git a/crates/goose/src/providers/gemini_cli.rs b/crates/goose/src/providers/gemini_cli.rs index 07c3974cd4a9..0d1c29aaf6cd 100644 --- a/crates/goose/src/providers/gemini_cli.rs +++ b/crates/goose/src/providers/gemini_cli.rs @@ -9,10 +9,8 @@ use tokio::process::Command; use super::base::{ stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata, - ProviderUsage, Usage, }; use super::cli_common::{error_from_event, extract_usage_tokens}; -use super::errors::ProviderError; use super::utils::filter_extensions_from_system_prompt; use crate::config::base::GeminiCliCommand; use crate::config::search_path::SearchPaths; @@ -23,6 +21,8 @@ use crate::providers::base::ConfigKey; use crate::subprocess::configure_subprocess; use async_stream::try_stream; use futures::future::BoxFuture; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; +use goose_providers::errors::ProviderError; use rmcp::model::Role; use rmcp::model::Tool; diff --git a/crates/goose/src/providers/gemini_oauth.rs b/crates/goose/src/providers/gemini_oauth.rs index b3ac6a493480..65ad9525c19c 100644 --- a/crates/goose/src/providers/gemini_oauth.rs +++ b/crates/goose/src/providers/gemini_oauth.rs @@ -5,9 +5,9 @@ use crate::providers::base::{ ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, DEFAULT_PROVIDER_TIMEOUT_SECS, }; -use crate::providers::errors::ProviderError; use crate::providers::formats::google::{create_request, response_to_streaming_message}; use crate::providers::google::GOOGLE_DOC_URL; +use goose_providers::errors::ProviderError; const GEMINI_OAUTH_DEFAULT_MODEL: &str = "gemini-3-flash-preview"; const GEMINI_OAUTH_DEFAULT_FAST_MODEL: &str = "gemini-2.5-flash-lite"; diff --git a/crates/goose/src/providers/githubcopilot.rs b/crates/goose/src/providers/githubcopilot.rs index 8005db171529..2cfbcf289c3f 100644 --- a/crates/goose/src/providers/githubcopilot.rs +++ b/crates/goose/src/providers/githubcopilot.rs @@ -8,6 +8,9 @@ use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use axum::http; use chrono::{DateTime, Utc}; +use goose_providers::errors::ProviderError; +use goose_providers::formats::openai::{is_openai_responses_model, ModelConfigParams}; +use goose_providers::images::ImageFormat; use reqwest::{Client, Response}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -22,15 +25,13 @@ tokio::task_local! { } use super::base::{ - collect_stream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, - DEFAULT_PROVIDER_TIMEOUT_SECS, + collect_stream, Provider, ProviderDef, ProviderMetadata, DEFAULT_PROVIDER_TIMEOUT_SECS, }; -use super::errors::ProviderError; -use super::formats::openai::{create_request, get_usage, response_to_message}; use super::formats::openai_responses::create_responses_request; use super::openai_compatible::handle_response_openai_compat; use super::retry::ProviderRetry; -use super::utils::{get_model, is_openai_responses_model, ImageFormat, RequestLog}; +use super::utils::{get_model, RequestLog}; +use goose_providers::formats::openai::{create_request, get_usage, response_to_message}; use crate::config::{Config, ConfigError}; use crate::conversation::message::{Message, MessageContent}; @@ -38,6 +39,7 @@ use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, MessageStream}; use futures::future::BoxFuture; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; use rmcp::model::{RawContent, Tool}; use std::ops::Deref; @@ -442,7 +444,13 @@ impl GithubCopilotProvider { if supports_streaming { let payload = create_request( - model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, system, messages, tools, @@ -478,7 +486,13 @@ impl GithubCopilotProvider { Some(session_id) }; let payload = create_request( - model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, system, messages, tools, @@ -726,8 +740,7 @@ fn promote_tool_choice(response: Value) -> Value { #[cfg(test)] mod tests { - use super::{normalize_host, promote_tool_choice, GithubCopilotProvider, GithubCopilotUrls}; - use crate::providers::utils::is_openai_responses_model; + use super::*; use serde_json::json; #[test] diff --git a/crates/goose/src/providers/google.rs b/crates/goose/src/providers/google.rs index ba272dd9fa18..a12f98551715 100644 --- a/crates/goose/src/providers/google.rs +++ b/crates/goose/src/providers/google.rs @@ -1,10 +1,10 @@ use super::api_client::{ApiClient, AuthMethod}; use super::base::MessageStream; -use super::errors::ProviderError; use super::openai_compatible::{handle_status, map_http_error_to_provider_error, sanitize_url}; use super::retry::ProviderRetry; use super::utils::RequestLog; use crate::conversation::message::Message; +use goose_providers::errors::ProviderError; use crate::model::ModelConfig; use crate::providers::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata}; diff --git a/crates/goose/src/providers/http_status.rs b/crates/goose/src/providers/http_status.rs index 279b5bc0e9be..58e317e1bc30 100644 --- a/crates/goose/src/providers/http_status.rs +++ b/crates/goose/src/providers/http_status.rs @@ -7,12 +7,11 @@ use std::time::{Duration, SystemTime}; use chrono::{DateTime, NaiveDateTime, TimeZone, Utc}; +use goose_providers::errors::ProviderError; use reqwest::header::{HeaderMap, RETRY_AFTER}; use reqwest::{Response, StatusCode}; use serde_json::Value; -use super::errors::ProviderError; - /// Strip credentials and sensitive query parameters from a URL for safe /// inclusion in error messages and logs. Drops userinfo (`user:pass@`) and /// all query parameters (which may contain API keys like `?key=...`). diff --git a/crates/goose/src/providers/huggingface.rs b/crates/goose/src/providers/huggingface.rs index b36e1bea2bda..5416b00e5f8b 100644 --- a/crates/goose/src/providers/huggingface.rs +++ b/crates/goose/src/providers/huggingface.rs @@ -3,7 +3,6 @@ use super::base::{ ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, DEFAULT_PROVIDER_TIMEOUT_SECS, }; -use super::errors::ProviderError; use super::huggingface_auth; use super::inventory::{default_inventory_identity, InventoryIdentityInput}; use super::openai_compatible::OpenAiCompatibleProvider; @@ -13,6 +12,7 @@ use crate::conversation::message::Message; use crate::model::ModelConfig; use anyhow::{anyhow, Result}; use futures::future::BoxFuture; +use goose_providers::errors::ProviderError; use rmcp::model::Tool; pub const HUGGINGFACE_API_HOST: &str = "https://router.huggingface.co/v1"; diff --git a/crates/goose/src/providers/kimicode.rs b/crates/goose/src/providers/kimicode.rs index f1fad7b3c9fb..5cb6e4275419 100644 --- a/crates/goose/src/providers/kimicode.rs +++ b/crates/goose/src/providers/kimicode.rs @@ -20,7 +20,6 @@ use super::base::{ ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, DEFAULT_PROVIDER_TIMEOUT_SECS, }; -use super::errors::ProviderError; use super::formats::anthropic::{create_request, response_to_streaming_message}; use super::oauth_device_flow::{ refresh_device_flow_token, run_device_flow, DeviceFlowConfig, DeviceFlowTokens, RequestEncoding, @@ -31,6 +30,7 @@ use super::utils::RequestLog; use crate::conversation::message::Message; use crate::model::ModelConfig; use futures::future::BoxFuture; +use goose_providers::errors::ProviderError; use rmcp::model::Tool; const KIMI_CODE_PROVIDER_NAME: &str = "kimi_code"; diff --git a/crates/goose/src/providers/litellm.rs b/crates/goose/src/providers/litellm.rs index d6a3a711de8e..7a9c29317bef 100644 --- a/crates/goose/src/providers/litellm.rs +++ b/crates/goose/src/providers/litellm.rs @@ -1,21 +1,24 @@ use anyhow::Result; use async_trait::async_trait; use futures::future::BoxFuture; +use goose_providers::conversation::token_usage::ProviderUsage; +use goose_providers::errors::ProviderError; +use goose_providers::images::ImageFormat; use serde_json::{json, Value}; use std::collections::HashMap; use super::api_client::{ApiClient, AuthMethod}; use super::base::{ - ConfigKey, MessageStream, ModelInfo, Provider, ProviderDef, ProviderMetadata, ProviderUsage, + ConfigKey, MessageStream, ModelInfo, Provider, ProviderDef, ProviderMetadata, DEFAULT_PROVIDER_TIMEOUT_SECS, }; use super::embedding::EmbeddingCapable; -use super::errors::ProviderError; use super::openai_compatible::handle_response_openai_compat; use super::retry::ProviderRetry; -use super::utils::{get_model, ImageFormat, RequestLog}; +use super::utils::{get_model, RequestLog}; use crate::conversation::message::Message; use crate::model::ModelConfig; +use goose_providers::formats::openai::ModelConfigParams; use rmcp::model::Tool; const LITELLM_PROVIDER_NAME: &str = "litellm"; @@ -225,8 +228,14 @@ impl Provider for LiteLLMProvider { } else { Some(session_id) }; - let mut payload = super::formats::openai::create_request( - model_config, + let mut payload = goose_providers::formats::openai::create_request( + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, system, messages, tools, @@ -245,8 +254,8 @@ impl Provider for LiteLLMProvider { }) .await?; - let message = super::formats::openai::response_to_message(&response)?; - let usage = super::formats::openai::get_usage(&response); + let message = goose_providers::formats::openai::response_to_message(&response)?; + let usage = goose_providers::formats::openai::get_usage(&response); let response_model = get_model(&response); let mut log = RequestLog::start(model_config, &payload)?; log.write(&response, Some(&usage))?; diff --git a/crates/goose/src/providers/local_inference.rs b/crates/goose/src/providers/local_inference.rs index e7229f5fdd8f..e0c4ef3316d4 100644 --- a/crates/goose/src/providers/local_inference.rs +++ b/crates/goose/src/providers/local_inference.rs @@ -8,16 +8,16 @@ mod tool_parsing; use crate::config::ExtensionConfig; use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; -use crate::providers::base::{ - MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, -}; -use crate::providers::errors::ProviderError; +use crate::providers::base::{MessageStream, Provider, ProviderDef, ProviderMetadata}; use crate::providers::utils::RequestLog; use anyhow::Result; use async_stream::try_stream; use async_trait::async_trait; use backend::{BackendLoadedModel, LocalInferenceBackend}; use futures::future::BoxFuture; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; +use goose_providers::errors::ProviderError; +use goose_providers::images::ImageFormat; use llamacpp::{LlamaCppBackend, LLAMACPP_BACKEND_ID}; use local_model_registry::ChatTemplate; use rmcp::model::Tool; @@ -189,8 +189,7 @@ pub fn recommend_local_model(runtime: &InferenceRuntime) -> String { } fn build_openai_messages_json(system: &str, messages: &[Message]) -> String { - use crate::providers::formats::openai::format_messages; - use crate::providers::utils::ImageFormat; + use goose_providers::formats::openai::format_messages; let mut arr: Vec = vec![json!({"role": "system", "content": system})]; arr.extend(format_messages(messages, &ImageFormat::OpenAi)); diff --git a/crates/goose/src/providers/local_inference/backend.rs b/crates/goose/src/providers/local_inference/backend.rs index 0968aafcf13e..d3eac5de1581 100644 --- a/crates/goose/src/providers/local_inference/backend.rs +++ b/crates/goose/src/providers/local_inference/backend.rs @@ -2,9 +2,9 @@ use rmcp::model::Tool; use std::any::Any; use crate::conversation::message::Message; -use crate::providers::errors::ProviderError; use crate::providers::local_inference::local_model_registry::ModelSettings; use crate::providers::utils::RequestLog; +use goose_providers::errors::ProviderError; use super::{ResolvedModelPaths, StreamSender}; diff --git a/crates/goose/src/providers/local_inference/llamacpp/inference_emulated_tools.rs b/crates/goose/src/providers/local_inference/llamacpp/inference_emulated_tools.rs index bcde00435989..703a51a19ba4 100644 --- a/crates/goose/src/providers/local_inference/llamacpp/inference_emulated_tools.rs +++ b/crates/goose/src/providers/local_inference/llamacpp/inference_emulated_tools.rs @@ -21,7 +21,7 @@ //! support should use the `inference_native_tools` path instead. use crate::conversation::message::{Message, MessageContent}; -use crate::providers::errors::ProviderError; +use goose_providers::errors::ProviderError; use rmcp::model::{CallToolRequestParams, Tool}; use serde_json::json; use std::borrow::Cow; diff --git a/crates/goose/src/providers/local_inference/llamacpp/inference_engine.rs b/crates/goose/src/providers/local_inference/llamacpp/inference_engine.rs index 2a19f5fe7595..5725bff18e67 100644 --- a/crates/goose/src/providers/local_inference/llamacpp/inference_engine.rs +++ b/crates/goose/src/providers/local_inference/llamacpp/inference_engine.rs @@ -1,9 +1,9 @@ -use crate::providers::base::{FilterOut, ThinkFilter}; -use crate::providers::errors::ProviderError; use crate::providers::local_inference::backend::LocalInferenceBackend; use crate::providers::local_inference::local_model_registry::ModelSettings; use crate::providers::local_inference::multimodal::ExtractedImage; use crate::providers::utils::RequestLog; +use goose_providers::errors::ProviderError; +use goose_providers::thinking::{FilterOut, ThinkFilter}; use llama_cpp_2::context::params::LlamaContextParams; use llama_cpp_2::llama_batch::LlamaBatch; use llama_cpp_2::model::{AddBos, ChatTemplateResult, LlamaChatTemplate, LlamaModel}; diff --git a/crates/goose/src/providers/local_inference/llamacpp/inference_native_tools.rs b/crates/goose/src/providers/local_inference/llamacpp/inference_native_tools.rs index 501ad6b8497c..63d1e721ae23 100644 --- a/crates/goose/src/providers/local_inference/llamacpp/inference_native_tools.rs +++ b/crates/goose/src/providers/local_inference/llamacpp/inference_native_tools.rs @@ -1,5 +1,5 @@ use crate::conversation::message::{Message, MessageContent}; -use crate::providers::errors::ProviderError; +use goose_providers::errors::ProviderError; use rmcp::model::CallToolRequestParams; use serde_json::Value; use std::borrow::Cow; diff --git a/crates/goose/src/providers/local_inference/llamacpp/mod.rs b/crates/goose/src/providers/local_inference/llamacpp/mod.rs index 4e076f459325..218745f477bb 100644 --- a/crates/goose/src/providers/local_inference/llamacpp/mod.rs +++ b/crates/goose/src/providers/local_inference/llamacpp/mod.rs @@ -18,8 +18,6 @@ use self::inference_emulated_tools::{ }; use self::inference_engine::{GenerationContext, LoadedChatTemplates, LoadedModel}; use self::inference_native_tools::generate_with_native_tools; -use crate::providers::errors::ProviderError; -use crate::providers::formats::openai::format_tools; use crate::providers::local_inference::backend::{ BackendLoadedModel, LocalGenerationRequest, LocalInferenceBackend, }; @@ -31,6 +29,8 @@ use crate::providers::local_inference::tool_parsing::compact_tools_json; use crate::providers::local_inference::{ build_openai_messages_json, build_openai_text_messages_json, ResolvedModelPaths, }; +use goose_providers::errors::ProviderError; +use goose_providers::formats::openai::format_tools; pub(super) const LLAMACPP_BACKEND_ID: &str = "llamacpp"; diff --git a/crates/goose/src/providers/local_inference/multimodal.rs b/crates/goose/src/providers/local_inference/multimodal.rs index 157e2faec05e..1a464cd02ef4 100644 --- a/crates/goose/src/providers/local_inference/multimodal.rs +++ b/crates/goose/src/providers/local_inference/multimodal.rs @@ -2,7 +2,7 @@ use base64::prelude::*; use serde_json::Value; use crate::conversation::message::{Message, MessageContent}; -use crate::providers::errors::ProviderError; +use goose_providers::errors::ProviderError; #[derive(Debug)] pub struct ExtractedImage { diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs index 359d0465901c..d64215081f2f 100644 --- a/crates/goose/src/providers/mod.rs +++ b/crates/goose/src/providers/mod.rs @@ -27,7 +27,6 @@ pub mod databricks; pub mod databricks_auth; pub mod databricks_v2; pub mod embedding; -pub mod errors; pub mod formats; mod gcpauth; pub mod gcpvertexai; diff --git a/crates/goose/src/providers/nanogpt.rs b/crates/goose/src/providers/nanogpt.rs index 6c28d3b2b6f2..7475c9df6a8d 100644 --- a/crates/goose/src/providers/nanogpt.rs +++ b/crates/goose/src/providers/nanogpt.rs @@ -1,15 +1,16 @@ use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; -use super::errors::ProviderError; use super::openai_compatible::{handle_status, stream_openai_compat}; use super::retry::ProviderRetry; -use super::utils::{ImageFormat, RequestLog}; +use super::utils::RequestLog; use crate::conversation::message::Message; use crate::model::ModelConfig; -use crate::providers::formats::openai::create_request; use anyhow::Result; use async_trait::async_trait; use futures::future::BoxFuture; +use goose_providers::errors::ProviderError; +use goose_providers::formats::openai::{create_request, ModelConfigParams}; +use goose_providers::images::ImageFormat; use rmcp::model::Tool; pub const NANOGPT_PROVIDER_NAME: &str = "nano-gpt"; @@ -175,7 +176,13 @@ impl Provider for NanoGptProvider { tools: &[Tool], ) -> Result { let payload = create_request( - model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, system, messages, tools, diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index f03afa209d6f..6aa6fbe99698 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -3,11 +3,10 @@ use super::base::{ ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, DEFAULT_PROVIDER_TIMEOUT_SECS, }; -use super::errors::ProviderError; use super::inventory::InventoryIdentityInput; use super::openai_compatible::handle_status; use super::retry::{ProviderRetry, RetryConfig}; -use super::utils::{ImageFormat, RequestLog}; +use super::utils::RequestLog; use crate::config::declarative_providers::DeclarativeProviderConfig; use crate::conversation::message::Message; use crate::model::ModelConfig; @@ -17,6 +16,9 @@ use async_stream::try_stream; use async_trait::async_trait; use futures::future::BoxFuture; use futures::TryStreamExt; +use goose_providers::errors::ProviderError; +use goose_providers::formats::openai::ModelConfigParams; +use goose_providers::images::ImageFormat; use reqwest::Response; use rmcp::model::Tool; use serde_json::{json, Value}; @@ -319,7 +321,13 @@ impl Provider for OllamaProvider { tools: &[Tool], ) -> Result { let mut payload = create_request( - model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, system, messages, tools, @@ -558,7 +566,6 @@ mod tests { #[test] fn test_raw_create_request_contains_unsupported_ollama_fields() { use crate::providers::formats::ollama::create_request; - use crate::providers::utils::ImageFormat; let model_config = ModelConfig::new("llama3.1") .unwrap() @@ -566,7 +573,13 @@ mod tests { let messages = vec![crate::conversation::message::Message::user().with_text("hi")]; let payload = create_request( - &model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, "You are a helpful assistant.", &messages, &[], @@ -588,7 +601,6 @@ mod tests { #[test] fn test_apply_ollama_options_preserves_stream_options_by_default() { use crate::providers::formats::ollama::create_request; - use crate::providers::utils::ImageFormat; let _guard = env_lock::lock_env([ ("GOOSE_INPUT_LIMIT", None::<&str>), @@ -600,7 +612,13 @@ mod tests { let messages = vec![crate::conversation::message::Message::user().with_text("hi")]; let mut payload = create_request( - &model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, "You are a helpful assistant.", &messages, &[], @@ -633,7 +651,6 @@ mod tests { #[test] fn test_apply_ollama_options_strips_stream_options_when_disabled() { use crate::providers::formats::ollama::create_request; - use crate::providers::utils::ImageFormat; let _guard = env_lock::lock_env([ ("GOOSE_INPUT_LIMIT", None::<&str>), @@ -645,7 +662,13 @@ mod tests { let messages = vec![crate::conversation::message::Message::user().with_text("hi")]; let mut payload = create_request( - &model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, "You are a helpful assistant.", &messages, &[], @@ -740,8 +763,8 @@ mod tests { assert!(config.transient_only); - use super::super::errors::ProviderError; use super::super::retry::should_retry; + use goose_providers::errors::ProviderError; assert!(!should_retry( &ProviderError::RequestFailed("Resource not found (404)".into()), diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 3ecc24ff90b8..dfec1948f386 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -3,10 +3,6 @@ use super::base::{ ConfigKey, ModelInfo, Provider, ProviderDef, ProviderMetadata, DEFAULT_PROVIDER_TIMEOUT_SECS, }; use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse}; -use super::errors::ProviderError; -use super::formats::openai::{ - create_request_with_options, get_usage, response_to_message, OpenAiFormatOptions, -}; use super::formats::openai_responses::{ create_responses_request, get_responses_usage, responses_api_to_message, ResponsesApiResponse, }; @@ -15,12 +11,18 @@ use super::openai_compatible::{ handle_response_openai_compat, handle_status, stream_openai_compat, stream_responses_compat, }; use super::retry::ProviderRetry; -use super::utils::ImageFormat; use crate::config::declarative_providers::DeclarativeProviderConfig; use crate::conversation::message::Message; use anyhow::Result; use async_trait::async_trait; use futures::future::BoxFuture; +use goose_providers::conversation::token_usage::ProviderUsage; +use goose_providers::errors::ProviderError; +use goose_providers::formats::openai::{ + create_request_with_options, get_usage, response_to_message, OpenAiFormatOptions, +}; +use goose_providers::formats::openai::{is_openai_responses_model, ModelConfigParams}; +use goose_providers::images::ImageFormat; use reqwest::StatusCode; use std::collections::HashMap; @@ -476,7 +478,7 @@ impl OpenAiProvider { } fn is_responses_model(model_name: &str) -> bool { - super::utils::is_openai_responses_model(model_name) + is_openai_responses_model(model_name) } fn should_use_responses_api(model_name: &str, base_path: &str) -> bool { @@ -821,8 +823,7 @@ impl Provider for OpenAiProvider { let message = responses_api_to_message(&responses_api_response)?; let usage_data = get_responses_usage(&responses_api_response); - let usage = - super::base::ProviderUsage::new(model_config.model_name.clone(), usage_data); + let usage = ProviderUsage::new(model_config.model_name.clone(), usage_data); log.write( &serde_json::to_value(&message).unwrap_or_default(), @@ -833,7 +834,13 @@ impl Provider for OpenAiProvider { } } else { let payload = create_request_with_options( - model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, system, messages, tools, @@ -871,8 +878,7 @@ impl Provider for OpenAiProvider { })?; let usage_data = get_usage(json.get("usage").unwrap_or(&serde_json::Value::Null)); - let usage = - super::base::ProviderUsage::new(model_config.model_name.clone(), usage_data); + let usage = ProviderUsage::new(model_config.model_name.clone(), usage_data); log.write( &serde_json::to_value(&message).unwrap_or_default(), diff --git a/crates/goose/src/providers/openai_compatible.rs b/crates/goose/src/providers/openai_compatible.rs index 7ba2a5ff436e..3355cccec474 100644 --- a/crates/goose/src/providers/openai_compatible.rs +++ b/crates/goose/src/providers/openai_compatible.rs @@ -1,6 +1,8 @@ use anyhow::Error; use async_stream::try_stream; use futures::TryStreamExt; +use goose_providers::conversation::token_usage::ProviderUsage; +use goose_providers::images::ImageFormat; use reqwest::Response; #[cfg(test)] use reqwest::StatusCode; @@ -11,16 +13,17 @@ use tokio_util::codec::{FramedRead, LinesCodec}; use tokio_util::io::StreamReader; use super::api_client::ApiClient; -use super::base::{stream_from_single_message, MessageStream, Provider, ProviderUsage}; -use super::errors::ProviderError; +use super::base::{stream_from_single_message, MessageStream, Provider}; use super::retry::ProviderRetry; -use super::utils::{ImageFormat, RequestLog}; +use super::utils::RequestLog; use crate::conversation::message::Message; use crate::model::ModelConfig; -use crate::providers::formats::openai::{ +use crate::providers::formats::openai_responses::responses_api_to_streaming_message; +use goose_providers::errors::ProviderError; +use goose_providers::formats::openai::{ create_request, get_usage, response_to_message, response_to_streaming_message, + ModelConfigParams, }; -use crate::providers::formats::openai_responses::responses_api_to_streaming_message; use rmcp::model::Tool; pub struct OpenAiCompatibleProvider { @@ -63,7 +66,13 @@ impl OpenAiCompatibleProvider { for_streaming: bool, ) -> Result { create_request( - model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, system, messages, tools, diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs index 59afb94a84ad..7d1ad92b0462 100644 --- a/crates/goose/src/providers/openrouter.rs +++ b/crates/goose/src/providers/openrouter.rs @@ -1,18 +1,19 @@ use anyhow::Result; use async_trait::async_trait; use futures::future::BoxFuture; +use goose_providers::images::ImageFormat; use serde_json::{json, Value}; use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; -use super::errors::ProviderError; use super::openai_compatible::{handle_status, stream_openai_compat}; use super::retry::ProviderRetry; -use super::utils::{ImageFormat, RequestLog}; +use super::utils::RequestLog; use crate::conversation::message::Message; use crate::model::ModelConfig; -use crate::providers::formats::openai::create_request; use crate::providers::formats::openrouter as openrouter_format; +use goose_providers::errors::ProviderError; +use goose_providers::formats::openai::{create_request, ModelConfigParams}; use rmcp::model::Tool; pub const OPENROUTER_PROVIDER_NAME: &str = "openrouter"; @@ -256,7 +257,13 @@ impl Provider for OpenRouterProvider { tools: &[Tool], ) -> Result { let mut payload = create_request( - model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, system, messages, tools, diff --git a/crates/goose/src/providers/retry.rs b/crates/goose/src/providers/retry.rs index f27f322dc179..5cb12faa3a2a 100644 --- a/crates/goose/src/providers/retry.rs +++ b/crates/goose/src/providers/retry.rs @@ -1,6 +1,6 @@ -use super::errors::ProviderError; use crate::providers::base::Provider; use async_trait::async_trait; +use goose_providers::errors::ProviderError; use std::future::Future; use std::time::Duration; use tokio::time::sleep; diff --git a/crates/goose/src/providers/sagemaker_tgi.rs b/crates/goose/src/providers/sagemaker_tgi.rs index d05e15113170..bc4ba5d72c6d 100644 --- a/crates/goose/src/providers/sagemaker_tgi.rs +++ b/crates/goose/src/providers/sagemaker_tgi.rs @@ -10,18 +10,17 @@ use rmcp::model::Tool; use serde_json::{json, Value}; use smithy_transport_reqwest::ReqwestHttpClient; -use super::base::{ - ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, -}; -use super::errors::ProviderError; +use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; use super::retry::ProviderRetry; use super::utils::RequestLog; use crate::conversation::message::{Message, MessageContent}; use crate::session_context::SESSION_ID_HEADER; +use goose_providers::errors::ProviderError; use crate::model::ModelConfig; use chrono::Utc; use futures::future::BoxFuture; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; use rmcp::model::Role; const SAGEMAKER_TGI_PROVIDER_NAME: &str = "sagemaker_tgi"; diff --git a/crates/goose/src/providers/snowflake.rs b/crates/goose/src/providers/snowflake.rs index de5c262bf15d..7f3639a688a2 100644 --- a/crates/goose/src/providers/snowflake.rs +++ b/crates/goose/src/providers/snowflake.rs @@ -1,19 +1,19 @@ use anyhow::Result; use async_trait::async_trait; +use goose_providers::conversation::token_usage::ProviderUsage; +use goose_providers::images::ImageFormat; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use super::api_client::{ApiClient, AuthMethod}; -use super::base::{ - ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, -}; -use super::errors::ProviderError; +use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; use super::formats::snowflake::{create_request, get_usage, response_to_message}; use super::openai_compatible::{map_http_error_to_provider_error, sanitize_url}; use super::retry::ProviderRetry; -use super::utils::{get_model, ImageFormat, RequestLog}; +use super::utils::{get_model, RequestLog}; use crate::config::ConfigError; use crate::conversation::message::Message; +use goose_providers::errors::ProviderError; use crate::model::ModelConfig; use futures::future::BoxFuture; diff --git a/crates/goose/src/providers/testprovider.rs b/crates/goose/src/providers/testprovider.rs index e06d1f27057f..ef8cb5c52e6d 100644 --- a/crates/goose/src/providers/testprovider.rs +++ b/crates/goose/src/providers/testprovider.rs @@ -9,12 +9,13 @@ use std::sync::{Arc, Mutex}; #[cfg(test)] use super::base::stream_from_single_message; -use super::base::{MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage}; -use super::errors::ProviderError; +use super::base::{MessageStream, Provider, ProviderDef, ProviderMetadata}; use crate::conversation::message::{Message, ToolResponse}; use crate::model::ModelConfig; use crate::utils::bytes_to_hex; use futures::future::BoxFuture; +use goose_providers::conversation::token_usage::ProviderUsage; +use goose_providers::errors::ProviderError; use rmcp::model::{CallToolResult, Tool}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -225,8 +226,8 @@ impl Provider for TestProvider { mod tests { use super::*; use crate::conversation::message::{Message, MessageContent}; - use crate::providers::base::{ProviderUsage, Usage}; use chrono::Utc; + use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; use rmcp::model::{RawTextContent, Role, TextContent}; use std::env; diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs index c57396d3eab6..00d4e9c77127 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -1,6 +1,5 @@ use super::api_client::{ApiClient, AuthMethod}; use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; -use super::errors::ProviderError; use super::openai_compatible::{ handle_response_openai_compat, handle_status, map_http_error_to_provider_error, stream_openai_compat, @@ -12,9 +11,11 @@ use crate::conversation::message::Message; use anyhow::Result; use async_trait::async_trait; use futures::future::BoxFuture; +use goose_providers::errors::ProviderError; +use goose_providers::images::ImageFormat; use crate::model::ModelConfig; -use crate::providers::formats::openai::create_request; +use goose_providers::formats::openai::{create_request, ModelConfigParams}; use rmcp::model::Tool; use serde_json::Value; @@ -139,11 +140,17 @@ impl Provider for TetrateProvider { tools: &[Tool], ) -> Result { let payload = create_request( - model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, system, messages, tools, - &super::utils::ImageFormat::OpenAi, + &ImageFormat::OpenAi, true, )?; diff --git a/crates/goose/src/providers/toolshim.rs b/crates/goose/src/providers/toolshim.rs index 6ea6cebe7e48..58f63297e81d 100644 --- a/crates/goose/src/providers/toolshim.rs +++ b/crates/goose/src/providers/toolshim.rs @@ -30,7 +30,6 @@ //! - `augment_message_with_tool_calls`: A utility function that takes any message, extracts text content, sends it to an interpreter, and adds any detected tool calls back to the message. //! -use super::errors::ProviderError; #[cfg(feature = "local-inference")] use super::local_inference::LOCAL_LLM_MODEL_CONFIG_KEY; use super::ollama::OLLAMA_DEFAULT_PORT; @@ -39,9 +38,12 @@ use crate::conversation::message::{Message, MessageContent}; use crate::conversation::Conversation; use crate::model::ModelConfig; use crate::providers::base::DEFAULT_PROVIDER_TIMEOUT_SECS; -use crate::providers::formats::openai::create_request; use anyhow::Result; use futures::StreamExt; +use goose_providers::errors::ProviderError; +use goose_providers::formats::openai::create_request; +use goose_providers::formats::openai::ModelConfigParams; +use goose_providers::images::ImageFormat; use reqwest::Client; use rmcp::model::{object, CallToolRequestParams, RawContent, Tool}; use serde_json::{json, Value}; @@ -696,11 +698,17 @@ impl OllamaInterpreter { .with_canonical_limits("ollama"); let mut payload = create_request( - &model_config, + ModelConfigParams { + model_name: model_config.model_name.as_str(), + thinking_effort: model_config.thinking_effort(), + temperature: model_config.temperature, + max_tokens: model_config.max_tokens, + request_params: model_config.request_params.as_ref(), + }, system_prompt, &messages, &[], // No tools - &super::utils::ImageFormat::OpenAi, + &ImageFormat::OpenAi, false, )?; diff --git a/crates/goose/src/providers/usage_estimator.rs b/crates/goose/src/providers/usage_estimator.rs index 9385a57fe376..1f0d5f5301f6 100644 --- a/crates/goose/src/providers/usage_estimator.rs +++ b/crates/goose/src/providers/usage_estimator.rs @@ -1,7 +1,7 @@ use crate::conversation::message::Message; -use crate::providers::base::ProviderUsage; use crate::token_counter::create_token_counter; use anyhow::Result; +use goose_providers::conversation::token_usage::ProviderUsage; use rmcp::model::Tool; /// Ensures that ProviderUsage has token counts, estimating them if necessary. @@ -51,7 +51,7 @@ pub async fn ensure_usage_tokens( mod tests { use super::*; use crate::conversation::message::Message; - use crate::providers::base::Usage; + use goose_providers::conversation::token_usage::Usage; #[tokio::test] async fn test_ensure_usage_tokens_already_complete() { diff --git a/crates/goose/src/providers/utils-to-move.md b/crates/goose/src/providers/utils-to-move.md new file mode 100644 index 000000000000..c7087bdda795 --- /dev/null +++ b/crates/goose/src/providers/utils-to-move.md @@ -0,0 +1,8 @@ +utils to move + +use crate::providers::base:: + split_think_blocks + ProviderUsage + -> rm ensure_tokens + ThinkFilter + Usage diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 919db63fc293..5044313a1964 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -1,49 +1,17 @@ -use super::base::Usage; -use super::errors::GoogleErrorCode; use crate::config::paths::Paths; -use crate::model::{ModelConfig, ThinkingEffort}; -use crate::providers::errors::ProviderError; use anyhow::{anyhow, Result}; -use base64::Engine; use fs_err::File; -use regex::Regex; +use goose_providers::conversation::token_usage::Usage; +use goose_providers::errors::{GoogleErrorCode, ProviderError}; use reqwest::{Response, StatusCode}; -use rmcp::model::{AnnotateAble, ImageContent, RawImageContent}; -use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; +use serde::Serialize; +use serde_json::Value; use std::fmt::Display; -use std::io::{BufWriter, Read, Write}; -use std::path::{Path, PathBuf}; -use std::sync::OnceLock; +use std::io::{BufWriter, Write}; +use std::path::PathBuf; use std::time::Duration; use uuid::Uuid; -#[derive(Debug, Copy, Clone, Serialize, Deserialize)] -pub enum ImageFormat { - OpenAi, - Anthropic, -} - -/// Convert an image content into an image json based on format -pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value { - match image_format { - ImageFormat::OpenAi => json!({ - "type": "image_url", - "image_url": { - "url": format!("data:{};base64,{}", image.mime_type, image.data) - } - }), - ImageFormat::Anthropic => json!({ - "type": "image", - "source": { - "type": "base64", - "media_type": image.mime_type, - "data": image.data, - } - }), - } -} - pub fn filter_extensions_from_system_prompt(system: &str) -> String { let Some(extensions_start) = system.find("# Extensions") else { return system.to_string(); @@ -194,105 +162,6 @@ pub async fn handle_response_google_compat(response: Response) -> Result bool { - static RE: OnceLock = OnceLock::new(); - let re = - RE.get_or_init(|| Regex::new(r"(?i)(?:^|[-/])(?:o\d+(?:$|-)|gpt-5(?:$|[-.]))").unwrap()); - re.is_match(model_name) -} - -/// Extract an explicit reasoning-effort suffix from a model name. -/// -/// Returns `(base_model_name, Some(effort))` when the user appended a -/// recognised suffix like `-high` or `-xhigh`, e.g. `gpt-5.4-high` → -/// `("gpt-5.4", Some("high"))`. -/// -/// When no suffix is present the effort is `None` — callers should omit -/// the `reasoning` field entirely so the API applies its own per-model -/// default. This avoids hard-coding a default that may be invalid for -/// certain models (e.g. `gpt-5-pro` only accepts `high`; older o-series -/// models reject `none` and `xhigh`). -pub fn extract_reasoning_effort(model_name: &str) -> (String, Option) { - if !is_openai_responses_model(model_name) { - return (model_name.to_string(), None); - } - - static RE: OnceLock = OnceLock::new(); - let re = RE.get_or_init(|| { - Regex::new(r"(?i)^(?P.+)-(?Pnone|low|medium|high|xhigh)$").unwrap() - }); - - if let Some(captures) = re.captures(model_name) { - let base = captures["base"].to_string(); - let effort = captures["effort"].to_ascii_lowercase(); - return (base, Some(effort)); - } - - (model_name.to_string(), None) -} - -pub fn openai_reasoning_effort_for_thinking( - model_name: &str, - effort: ThinkingEffort, -) -> Option { - if effort == ThinkingEffort::Off { - return Some("none".to_string()); - } - - let supported = openai_reasoning_efforts_for_model(model_name); - let preferred: &[&str] = match effort { - ThinkingEffort::Off => unreachable!(), - ThinkingEffort::Low => &["low", "medium", "high", "xhigh"], - ThinkingEffort::Medium => &["medium", "high", "low", "xhigh"], - ThinkingEffort::High => &["high", "medium", "xhigh", "low"], - ThinkingEffort::Max => &["xhigh", "high", "medium", "low"], - }; - - preferred - .iter() - .find(|level| supported.contains(level)) - .map(|level| (*level).to_string()) -} - -fn openai_reasoning_efforts_for_model(model_name: &str) -> &'static [&'static str] { - let normalized = model_name.to_ascii_lowercase(); - - if normalized.contains("gpt-5") { - if normalized.contains("-pro") || normalized.contains("/pro") { - &["high"] - } else if normalized.contains("gpt-5.4") - || normalized.contains("gpt-5-4") - || normalized.contains("gpt-5.5") - || normalized.contains("gpt-5-5") - { - &["low", "medium", "high", "xhigh"] - } else { - &["low", "medium", "high"] - } - } else { - &["low", "medium", "high"] - } -} - -pub fn sanitize_function_name(name: &str) -> String { - static RE: OnceLock = OnceLock::new(); - let re = RE.get_or_init(|| Regex::new(r"[^a-zA-Z0-9_-]").unwrap()); - re.replace_all(name, "_").to_string() -} - -pub fn is_valid_function_name(name: &str) -> bool { - static RE: OnceLock = OnceLock::new(); - let re = RE.get_or_init(|| Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap()); - re.is_match(name) -} - /// Extract the model name from a JSON object. Common with most providers to have this top level attribute. pub fn get_model(data: &Value) -> String { if let Some(model) = data.get("model") { @@ -306,94 +175,6 @@ pub fn get_model(data: &Value) -> String { } } -/// Check if a file is actually an image by examining its magic bytes -fn is_image_file(path: &Path) -> bool { - if let Ok(mut file) = std::fs::File::open(path) { - let mut buffer = [0u8; 8]; // Large enough for most image magic numbers - if file.read(&mut buffer).is_ok() { - // Check magic numbers for common image formats - return match &buffer[0..4] { - // PNG: 89 50 4E 47 - [0x89, 0x50, 0x4E, 0x47] => true, - // JPEG: FF D8 FF - [0xFF, 0xD8, 0xFF, _] => true, - // GIF: 47 49 46 38 - [0x47, 0x49, 0x46, 0x38] => true, - _ => false, - }; - } - } - false -} - -/// Detect if a string contains a path to an image file -pub fn detect_image_path(text: &str) -> Option<&str> { - // Basic image file extension check - let extensions = [".png", ".jpg", ".jpeg"]; - - // Find any word that ends with an image extension - for word in text.split_whitespace() { - if extensions - .iter() - .any(|ext| word.to_lowercase().ends_with(ext)) - { - let path = Path::new(word); - // Check if it's an absolute path and file exists - if path.is_absolute() && path.is_file() { - // Verify it's actually an image file - if is_image_file(path) { - return Some(word); - } - } - } - } - None -} - -/// Convert a local image file to base64 encoded ImageContent -pub fn load_image_file(path: &str) -> Result { - let path = Path::new(path); - - // Verify it's an image before proceeding - if !is_image_file(path) { - return Err(ProviderError::RequestFailed( - "File is not a valid image".to_string(), - )); - } - - // Read the file - let bytes = std::fs::read(path) - .map_err(|e| ProviderError::RequestFailed(format!("Failed to read image file: {}", e)))?; - - // Detect mime type from extension - let mime_type = match path.extension().and_then(|e| e.to_str()) { - Some(ext) => match ext.to_lowercase().as_str() { - "png" => "image/png", - "jpg" | "jpeg" => "image/jpeg", - _ => { - return Err(ProviderError::RequestFailed( - "Unsupported image format".to_string(), - )) - } - }, - None => { - return Err(ProviderError::RequestFailed( - "Unknown image format".to_string(), - )) - } - }; - - // Convert to base64 - let data = base64::prelude::BASE64_STANDARD.encode(&bytes); - - Ok(RawImageContent { - mime_type: mime_type.to_string(), - data, - meta: None, - } - .no_annotation()) -} - pub fn unescape_json_values(value: &Value) -> Value { let mut cloned = value.clone(); unescape_json_values_in_place(&mut cloned); @@ -437,8 +218,9 @@ pub struct RequestLog { pub const LOGS_TO_KEEP: usize = 10; impl RequestLog { - pub fn start(model_config: &ModelConfig, payload: &Payload) -> Result + pub fn start(model_config: ModelConfig, payload: &Payload) -> Result where + ModelConfig: Serialize, Payload: Serialize, { let logs_dir = Paths::in_state_dir("logs"); @@ -521,130 +303,6 @@ impl Drop for RequestLog { } } -/// Safely parse a JSON string that may contain doubly-encoded or malformed JSON. -/// This function first attempts to parse the input string as-is. If that fails, -/// it applies control character escaping and truncated JSON repair and tries again. -/// -/// This approach preserves valid JSON like `{"key1": "value1",\n"key2": "value"}` -/// (which contains a literal \n but is perfectly valid JSON) while still fixing -/// broken JSON like `{"key1": "value1\n","key2": "value"}` (which contains an -/// unescaped newline character). -pub fn safely_parse_json(s: &str) -> Result { - // First, try parsing the string as-is - match serde_json::from_str(s) { - Ok(value) => Ok(value), - Err(_) => { - for candidate in [ - repair_truncated_json(s), - json_escape_control_chars_in_string(s), - ] { - if let Ok(value) = serde_json::from_str(&candidate) { - return Ok(value); - } - } - - let repaired = repair_truncated_json(&json_escape_control_chars_in_string(s)); - serde_json::from_str(&repaired) - } - } -} - -fn repair_truncated_json(s: &str) -> String { - let mut repaired = String::with_capacity(s.len() + 8); - let mut in_string = false; - let mut escape_next = false; - let mut closers = Vec::new(); - - for c in s.chars() { - repaired.push(c); - - if in_string { - if escape_next { - escape_next = false; - continue; - } - - match c { - '\\' => escape_next = true, - '"' => in_string = false, - _ => {} - } - continue; - } - - match c { - '"' => in_string = true, - '{' => closers.push('}'), - '[' => closers.push(']'), - '}' | ']' => { - if closers.last() == Some(&c) { - closers.pop(); - } - } - _ => {} - } - } - - if in_string { - if escape_next { - repaired.push('\\'); - } - repaired.push('"'); - } - - while let Some(closer) = closers.pop() { - repaired.push(closer); - } - - repaired -} - -/// Helper to escape control characters in a string that is supposed to be a JSON document. -/// This function iterates through the input string `s` and replaces any literal -/// control characters (U+0000 to U+001F) with their JSON-escaped equivalents -/// (e.g., '\n' becomes "\\n", '\u0001' becomes "\\u0001"). -/// -/// It does NOT escape quotes (") or backslashes (\) because it assumes `s` is a -/// full JSON document, and these characters might be structural (e.g., object delimiters, -/// existing valid escape sequences). The goal is to fix common LLM errors where -/// control characters are emitted raw into what should be JSON string values, -/// making the overall JSON structure unparsable. -/// -/// If the input string `s` has other JSON syntax errors (e.g., an unescaped quote -/// *within* a string value like `{"key": "string with " quote"}`), this function -/// will not fix them. It specifically targets unescaped control characters. -pub fn json_escape_control_chars_in_string(s: &str) -> String { - let mut r = String::with_capacity(s.len()); // Pre-allocate for efficiency - for c in s.chars() { - match c { - // ASCII Control characters (U+0000 to U+001F) - '\u{0000}'..='\u{001F}' => { - match c { - '\u{0008}' => r.push_str("\\b"), // Backspace - '\u{000C}' => r.push_str("\\f"), // Form feed - '\n' => r.push_str("\\n"), // Line feed - '\r' => r.push_str("\\r"), // Carriage return - '\t' => r.push_str("\\t"), // Tab - // Other control characters (e.g., NUL, SOH, VT, etc.) - // that don't have a specific short escape sequence. - _ => { - r.push_str(&format!("\\u{:04x}", c as u32)); - } - } - } - // Other characters are passed through. - // This includes quotes (") and backslashes (\). If these are part of the - // JSON structure (e.g. {"key": "value"}) or part of an already correctly - // escaped sequence within a string value (e.g. "string with \\\" quote"), - // they are preserved as is. This function does not attempt to fix - // malformed quote or backslash usage *within* string values if the LLM - // generates them incorrectly (e.g. {"key": "unescaped " quote in string"}). - _ => r.push(c), - } - } - r -} - #[cfg(test)] mod tests { use super::*; @@ -659,11 +317,8 @@ mod tests { let logs_dir = Paths::in_state_dir("logs"); assert!(!logs_dir.exists(), "logs dir should not exist yet"); - let log = RequestLog::start( - &ModelConfig::new("test").unwrap(), - &json!({"model": "test"}), - ) - .expect("RequestLog::start should create missing logs dir"); + let log = RequestLog::start(json!({"name": "test"}), &json!({"model": "test"})) + .expect("RequestLog::start should create missing logs dir"); drop(log); assert!(logs_dir.is_dir(), "logs dir should have been created"); @@ -671,111 +326,6 @@ mod tests { std::env::remove_var("GOOSE_PATH_ROOT"); } - #[test] - fn test_detect_image_path() { - // Create a temporary PNG file with valid PNG magic numbers - let temp_dir = tempfile::tempdir().unwrap(); - let png_path = temp_dir.path().join("test.png"); - let png_data = [ - 0x89, 0x50, 0x4E, 0x47, // PNG magic number - 0x0D, 0x0A, 0x1A, 0x0A, // PNG header - 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data - ]; - std::fs::write(&png_path, png_data).unwrap(); - let png_path_str = png_path.to_str().unwrap(); - - // Create a fake PNG (wrong magic numbers) - let fake_png_path = temp_dir.path().join("fake.png"); - std::fs::write(&fake_png_path, b"not a real png").unwrap(); - - // Test with valid PNG file using absolute path - let text = format!("Here is an image {}", png_path_str); - assert_eq!(detect_image_path(&text), Some(png_path_str)); - - // Test with non-image file that has .png extension - let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap()); - assert_eq!(detect_image_path(&text), None); - - // Test with nonexistent file - let text = "Here is a fake.png that doesn't exist"; - assert_eq!(detect_image_path(text), None); - - // Test with non-image file - let text = "Here is a file.txt"; - assert_eq!(detect_image_path(text), None); - - // Test with relative path (should not match) - let text = "Here is a relative/path/image.png"; - assert_eq!(detect_image_path(text), None); - } - - #[test] - fn test_load_image_file() { - // Create a temporary PNG file with valid PNG magic numbers - let temp_dir = tempfile::tempdir().unwrap(); - let png_path = temp_dir.path().join("test.png"); - let png_data = [ - 0x89, 0x50, 0x4E, 0x47, // PNG magic number - 0x0D, 0x0A, 0x1A, 0x0A, // PNG header - 0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data - ]; - std::fs::write(&png_path, png_data).unwrap(); - let png_path_str = png_path.to_str().unwrap(); - - // Create a fake PNG (wrong magic numbers) - let fake_png_path = temp_dir.path().join("fake.png"); - std::fs::write(&fake_png_path, b"not a real png").unwrap(); - let fake_png_path_str = fake_png_path.to_str().unwrap(); - - // Test loading valid PNG file - let result = load_image_file(png_path_str); - assert!(result.is_ok()); - let image = result.unwrap(); - assert_eq!(image.mime_type, "image/png"); - - // Test loading fake PNG file - let result = load_image_file(fake_png_path_str); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("not a valid image")); - - // Test nonexistent file - let result = load_image_file("nonexistent.png"); - assert!(result.is_err()); - - // Create a GIF file with valid header bytes - let gif_path = temp_dir.path().join("test.gif"); - // Minimal GIF89a header - let gif_data = [0x47, 0x49, 0x46, 0x38, 0x39, 0x61]; - std::fs::write(&gif_path, gif_data).unwrap(); - let gif_path_str = gif_path.to_str().unwrap(); - - // Test loading unsupported GIF format - let result = load_image_file(gif_path_str); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Unsupported image format")); - } - - #[test] - fn test_sanitize_function_name() { - assert_eq!(sanitize_function_name("hello-world"), "hello-world"); - assert_eq!(sanitize_function_name("hello world"), "hello_world"); - assert_eq!(sanitize_function_name("hello@world"), "hello_world"); - } - - #[test] - fn test_is_valid_function_name() { - assert!(is_valid_function_name("hello-world")); - assert!(is_valid_function_name("hello_world")); - assert!(!is_valid_function_name("hello world")); - assert!(!is_valid_function_name("hello@world")); - } - #[test] fn unescape_json_values_with_object() { let value = json!({"text": "Hello\\nWorld"}); @@ -892,98 +442,6 @@ mod tests { } } - #[test] - fn test_safely_parse_json() { - // Test valid JSON that should parse without escaping (contains proper escape sequence) - let valid_json = r#"{"key1": "value1","key2": "value2"}"#; - let result = safely_parse_json(valid_json).unwrap(); - assert_eq!(result["key1"], "value1"); - assert_eq!(result["key2"], "value2"); - - // Test JSON with actual unescaped newlines that needs escaping - let invalid_json = "{\"key1\": \"value1\n\",\"key2\": \"value2\"}"; - let result = safely_parse_json(invalid_json).unwrap(); - assert_eq!(result["key1"], "value1\n"); - assert_eq!(result["key2"], "value2"); - - // Test already valid JSON - should parse on first try - let good_json = r#"{"test": "value"}"#; - let result = safely_parse_json(good_json).unwrap(); - assert_eq!(result["test"], "value"); - - // Test truncated JSON with unclosed string, object, and array - let truncated_json = r#"{"key": "unclosed_string","nested": {"items": [1, 2, 3"#; - let result = safely_parse_json(truncated_json).unwrap(); - assert_eq!(result["key"], "unclosed_string"); - assert_eq!(result["nested"]["items"], json!([1, 2, 3])); - - // Test dangling backslash at end of a truncated string - let dangling_escape_json = String::from(r#"{"path":"abc\"#); - let result = safely_parse_json(&dangling_escape_json).unwrap(); - assert_eq!(result["path"], "abc\\"); - - // Test empty object - let empty_json = "{}"; - let result = safely_parse_json(empty_json).unwrap(); - assert!(result.as_object().unwrap().is_empty()); - - // Test JSON with escaped newlines (valid JSON) - should parse on first try - let escaped_json = r#"{"key": "value with\nnewline"}"#; - let result = safely_parse_json(escaped_json).unwrap(); - assert_eq!(result["key"], "value with\nnewline"); - } - - #[test] - fn test_json_escape_control_chars_in_string() { - // Test basic control character escaping - assert_eq!( - json_escape_control_chars_in_string("Hello\nWorld"), - "Hello\\nWorld" - ); - assert_eq!( - json_escape_control_chars_in_string("Hello\tWorld"), - "Hello\\tWorld" - ); - assert_eq!( - json_escape_control_chars_in_string("Hello\rWorld"), - "Hello\\rWorld" - ); - - // Test multiple control characters - assert_eq!( - json_escape_control_chars_in_string("Hello\n\tWorld\r"), - "Hello\\n\\tWorld\\r" - ); - - // Test that quotes and backslashes are preserved (not escaped) - assert_eq!( - json_escape_control_chars_in_string("Hello \"World\""), - "Hello \"World\"" - ); - assert_eq!( - json_escape_control_chars_in_string("Hello\\World"), - "Hello\\World" - ); - - // Test JSON-like string with control characters - assert_eq!( - json_escape_control_chars_in_string("{\"message\": \"Hello\nWorld\"}"), - "{\"message\": \"Hello\\nWorld\"}" - ); - - // Test no changes for normal strings - assert_eq!( - json_escape_control_chars_in_string("Hello World"), - "Hello World" - ); - - // Test other control characters get unicode escapes - assert_eq!( - json_escape_control_chars_in_string("Hello\u{0001}World"), - "Hello\\u0001World" - ); - } - #[test] fn test_parse_google_retry_delay() { let payload = json!({ @@ -1001,65 +459,4 @@ mod tests { Some(Duration::from_secs(42)) ); } - - #[test] - fn test_is_openai_responses_model_matches_o_and_gpt5_families() { - for model in [ - "o3", - "o3-mini", - "o4-mini", - "gpt-5", - "gpt-5-pro", - "gpt-5.4", - "gpt-5.4-mini", - "gpt-5-4", - "gpt-5-2-pro", - "databricks-gpt-5.4", - "goose-gpt-5.4-high", - "headless-goose-o3-mini", - ] { - assert!(is_openai_responses_model(model), "{model} should match"); - } - } - - #[test] - fn test_is_openai_responses_model_rejects_other_families() { - for model in [ - "gpt-4o", - "claude-sonnet-4", - "databricks-claude-sonnet-4", - "llama-3-70b", - ] { - assert!( - !is_openai_responses_model(model), - "{model} should not match" - ); - } - } - - #[test] - fn test_extract_reasoning_effort_for_responses_models() { - for (model, expected_name, expected_effort) in [ - ("o3-none", "o3", Some("none")), - ("o3-xhigh", "o3", Some("xhigh")), - ("gpt-5-low", "gpt-5", Some("low")), - ("gpt-5.4", "gpt-5.4", None), - ( - "databricks-gpt-5.4-high", - "databricks-gpt-5.4", - Some("high"), - ), - ("databricks-o3-low", "databricks-o3", Some("low")), - ("goose-gpt-5-high", "goose-gpt-5", Some("high")), - ("gpt-4o", "gpt-4o", None), - ] { - let (name, effort) = extract_reasoning_effort(model); - assert_eq!(name, expected_name, "unexpected base model for {model}"); - assert_eq!( - effort.as_deref(), - expected_effort, - "unexpected effort for {model}" - ); - } - } } diff --git a/crates/goose/src/providers/xai_oauth.rs b/crates/goose/src/providers/xai_oauth.rs index 4d2f2ad8b011..2e4e1afb0bc6 100644 --- a/crates/goose/src/providers/xai_oauth.rs +++ b/crates/goose/src/providers/xai_oauth.rs @@ -1,6 +1,5 @@ use super::api_client::{ApiClient, AuthMethod, AuthProvider}; use super::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; -use super::errors::ProviderError; use super::openai_compatible::OpenAiCompatibleProvider; use super::xai::{XAI_API_HOST, XAI_DEFAULT_MODEL, XAI_KNOWN_MODELS}; use crate::config::paths::Paths; @@ -12,6 +11,7 @@ use axum::{extract::Query, response::Html, routing::get, Router}; use base64::Engine; use chrono::{DateTime, Utc}; use futures::future::BoxFuture; +use goose_providers::errors::ProviderError; use rmcp::model::Tool; use serde::{Deserialize, Serialize}; use sha2::Digest; diff --git a/crates/goose/src/recipe/manifest.rs b/crates/goose/src/recipe/manifest.rs index f9d11406a902..092e41827f58 100644 --- a/crates/goose/src/recipe/manifest.rs +++ b/crates/goose/src/recipe/manifest.rs @@ -133,8 +133,8 @@ sub_recipes: let sub_recipes = recipe.sub_recipes.unwrap(); assert_eq!( - sub_recipes[0].path, - child_path.to_string_lossy().to_string() + fs::canonicalize(sub_recipes[0].path.clone()).unwrap(), + fs::canonicalize(child_path.to_string_lossy().to_string()).unwrap() ); } } diff --git a/crates/goose/tests/acp_custom_requests_test.rs b/crates/goose/tests/acp_custom_requests_test.rs index 9db97bc817d6..cf3388417e67 100644 --- a/crates/goose/tests/acp_custom_requests_test.rs +++ b/crates/goose/tests/acp_custom_requests_test.rs @@ -10,7 +10,7 @@ use common_tests::fixtures::{ use goose::acp::server::AcpProviderFactory; use goose::model::ModelConfig; use goose::providers::base::{MessageStream, Provider}; -use goose::providers::errors::ProviderError; +use goose_providers::errors::ProviderError; use goose_test_support::{EnforceSessionId, IgnoreSessionId}; use serial_test::serial; use std::path::PathBuf; diff --git a/crates/goose/tests/acp_secret_cache_invalidation_test.rs b/crates/goose/tests/acp_secret_cache_invalidation_test.rs index ef6a34df3905..850cd06a336c 100644 --- a/crates/goose/tests/acp_secret_cache_invalidation_test.rs +++ b/crates/goose/tests/acp_secret_cache_invalidation_test.rs @@ -8,9 +8,9 @@ use goose::config::paths::Paths; use goose::config::{Config, ConfigError}; use goose::model::ModelConfig; use goose::providers::base::{MessageStream, Provider}; -use goose::providers::errors::ProviderError; use goose::providers::inventory::ProviderInventoryService; use goose::session::session_manager::SessionStorage; +use goose_providers::errors::ProviderError; use goose_test_support::EnforceSessionId; use serial_test::serial; use std::sync::Arc; diff --git a/crates/goose/tests/agent.rs b/crates/goose/tests/agent.rs index 0667b03ee6c6..0fa4da3b0e11 100644 --- a/crates/goose/tests/agent.rs +++ b/crates/goose/tests/agent.rs @@ -345,10 +345,10 @@ mod tests { use goose::model::ModelConfig; use goose::providers::base::{ stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata, - ProviderUsage, Usage, }; - use goose::providers::errors::ProviderError; use goose::session::session_manager::SessionType; + use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; + use goose_providers::errors::ProviderError; use rmcp::model::{CallToolRequestParams, Tool}; use rmcp::object; use std::path::PathBuf; @@ -506,10 +506,10 @@ mod tests { use goose::model::ModelConfig; use goose::providers::base::{ stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata, - ProviderUsage, Usage, }; - use goose::providers::errors::ProviderError; use goose::session::session_manager::SessionType; + use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; + use goose_providers::errors::ProviderError; use rmcp::model::{AnnotateAble, CallToolRequestParams, CallToolResult, RawContent, Tool}; use std::path::PathBuf; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -857,12 +857,11 @@ mod tests { use goose::config::GooseMode; use goose::conversation::message::Message; use goose::model::ModelConfig; - use goose::providers::base::{ - MessageStream, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage, - }; - use goose::providers::errors::ProviderError; + use goose::providers::base::{MessageStream, Provider, ProviderDef, ProviderMetadata}; use goose::session::session_manager::SessionType; use goose::session::SessionManager; + use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; + use goose_providers::errors::ProviderError; use rmcp::model::{CallToolRequestParams, Role, Tool}; use rmcp::object; use std::path::PathBuf; @@ -1131,11 +1130,11 @@ mod tests { use goose::model::ModelConfig; use goose::providers::base::{ stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata, - ProviderUsage, Usage, }; - use goose::providers::errors::ProviderError; use goose::session::session_manager::SessionType; use goose::session::SessionManager; + use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; + use goose_providers::errors::ProviderError; use rmcp::model::Tool; use std::path::PathBuf; use std::sync::atomic::{AtomicU32, Ordering}; @@ -1397,12 +1396,11 @@ mod tests { use goose::config::GooseMode; use goose::conversation::message::Message; use goose::model::ModelConfig; - use goose::providers::base::{ - stream_from_single_message, MessageStream, Provider, ProviderUsage, Usage, - }; - use goose::providers::errors::ProviderError; + use goose::providers::base::{stream_from_single_message, MessageStream, Provider}; use goose::session::session_manager::SessionType; use goose::session::SessionManager; + use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; + use goose_providers::errors::ProviderError; use rmcp::model::Tool; use std::path::PathBuf; use std::sync::Arc; diff --git a/crates/goose/tests/compaction.rs b/crates/goose/tests/compaction.rs index 67b16b34a6a5..a2d58aadd294 100644 --- a/crates/goose/tests/compaction.rs +++ b/crates/goose/tests/compaction.rs @@ -8,11 +8,11 @@ use goose::conversation::Conversation; use goose::model::ModelConfig; use goose::providers::base::{ stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata, - ProviderUsage, Usage, }; -use goose::providers::errors::ProviderError; use goose::session::session_manager::SessionType; use goose::session::Session; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; +use goose_providers::errors::ProviderError; use rmcp::model::Tool; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; diff --git a/crates/goose/tests/mcp_integration_test.rs b/crates/goose/tests/mcp_integration_test.rs index 00dca62a83aa..859383d7b461 100644 --- a/crates/goose/tests/mcp_integration_test.rs +++ b/crates/goose/tests/mcp_integration_test.rs @@ -21,9 +21,9 @@ use async_trait::async_trait; use goose::conversation::message::Message; use goose::providers::base::{ stream_from_single_message, MessageStream, Provider, ProviderDef, ProviderMetadata, - ProviderUsage, Usage, }; -use goose::providers::errors::ProviderError; +use goose_providers::conversation::token_usage::{ProviderUsage, Usage}; +use goose_providers::errors::ProviderError; use once_cell::sync::Lazy; use std::process::Command; diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index c25711e6e4e3..24f49d8aab4c 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -16,7 +16,6 @@ use goose::providers::claude_code::CLAUDE_CODE_DEFAULT_MODEL; use goose::providers::codex::CODEX_DEFAULT_MODEL; use goose::providers::create_with_named_model; use goose::providers::databricks::DATABRICKS_DEFAULT_MODEL; -use goose::providers::errors::ProviderError; use goose::providers::google::GOOGLE_DEFAULT_MODEL; use goose::providers::litellm::LITELLM_DEFAULT_MODEL; use goose::providers::openai::OPEN_AI_DEFAULT_MODEL; @@ -25,6 +24,7 @@ use goose::providers::sagemaker_tgi::SAGEMAKER_TGI_DEFAULT_MODEL; use goose::providers::snowflake::SNOWFLAKE_DEFAULT_MODEL; use goose::providers::xai::XAI_DEFAULT_MODEL; use goose::session::{SessionManager, SessionType}; +use goose_providers::errors::ProviderError; use goose_test_support::{ EnforceSessionId, ExpectedSessionId, IgnoreSessionId, McpFixture, FAKE_CODE, }; diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index 398813947fd9..fda990138938 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -25,6 +25,11 @@ export type ActionRequiredData = { user_data: unknown; }; +export type AddExtensionRequest = { + config: ExtensionConfig; + session_id: string; +}; + export type Annotations = { audience?: Array; lastModified?: string; @@ -488,6 +493,17 @@ export type ExtensionLoadResult = { success: boolean; }; +export type ExtensionQuery = { + config: ExtensionConfig; + enabled: boolean; + name: string; +}; + +export type ExtensionResponse = { + extensions: Array; + warnings?: Array; +}; + export type FeaturesResponse = { /** * Map of feature name to enabled status @@ -1159,6 +1175,11 @@ export type RedactedThinkingContent = { data: string; }; +export type RemoveExtensionRequest = { + name: string; + session_id: string; +}; + export type RepoVariantsResponse = { available_memory_bytes: number; downloaded_quants: Array; @@ -1350,6 +1371,10 @@ export type SessionDisplayInfo = { workingDir: string; }; +export type SessionExtensionsResponse = { + extensions: Array; +}; + export type SessionInsights = { totalSessions: number; totalTokens: number; @@ -1779,6 +1804,37 @@ export type ConfirmToolActionResponses = { 200: unknown; }; +export type AgentAddExtensionData = { + body: AddExtensionRequest; + path?: never; + query?: never; + url: '/agent/add_extension'; +}; + +export type AgentAddExtensionErrors = { + /** + * Unauthorized - invalid secret key + */ + 401: unknown; + /** + * Agent not initialized + */ + 424: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type AgentAddExtensionResponses = { + /** + * Extension added + */ + 200: string; +}; + +export type AgentAddExtensionResponse = AgentAddExtensionResponses[keyof AgentAddExtensionResponses]; + export type CallToolData = { body: CallToolRequest; path?: never; @@ -1949,6 +2005,37 @@ export type ReadResourceResponses = { export type ReadResourceResponse2 = ReadResourceResponses[keyof ReadResourceResponses]; +export type AgentRemoveExtensionData = { + body: RemoveExtensionRequest; + path?: never; + query?: never; + url: '/agent/remove_extension'; +}; + +export type AgentRemoveExtensionErrors = { + /** + * Unauthorized - invalid secret key + */ + 401: unknown; + /** + * Agent not initialized + */ + 424: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type AgentRemoveExtensionResponses = { + /** + * Extension removed + */ + 200: string; +}; + +export type AgentRemoveExtensionResponse = AgentRemoveExtensionResponses[keyof AgentRemoveExtensionResponses]; + export type RestartAgentData = { body: RestartAgentRequest; path?: never; @@ -2384,6 +2471,89 @@ export type UpdateCustomProviderResponses = { export type UpdateCustomProviderResponse = UpdateCustomProviderResponses[keyof UpdateCustomProviderResponses]; +export type GetExtensionsData = { + body?: never; + path?: never; + query?: never; + url: '/config/extensions'; +}; + +export type GetExtensionsErrors = { + /** + * Internal server error + */ + 500: unknown; +}; + +export type GetExtensionsResponses = { + /** + * All extensions retrieved successfully + */ + 200: ExtensionResponse; +}; + +export type GetExtensionsResponse = GetExtensionsResponses[keyof GetExtensionsResponses]; + +export type AddExtensionData = { + body: ExtensionQuery; + path?: never; + query?: never; + url: '/config/extensions'; +}; + +export type AddExtensionErrors = { + /** + * Invalid request + */ + 400: unknown; + /** + * Could not serialize config.yaml + */ + 422: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type AddExtensionResponses = { + /** + * Extension added or updated successfully + */ + 200: string; +}; + +export type AddExtensionResponse = AddExtensionResponses[keyof AddExtensionResponses]; + +export type RemoveExtensionData = { + body?: never; + path: { + name: string; + }; + query?: never; + url: '/config/extensions/{name}'; +}; + +export type RemoveExtensionErrors = { + /** + * Extension not found + */ + 404: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type RemoveExtensionResponses = { + /** + * Extension removed successfully + */ + 200: string; +}; + +export type RemoveExtensionResponse = RemoveExtensionResponses[keyof RemoveExtensionResponses]; + export type UpsertPermissionsData = { body: UpsertPermissionsQuery; path?: never; @@ -4404,6 +4574,42 @@ export type ExportSessionResponses = { export type ExportSessionResponse = ExportSessionResponses[keyof ExportSessionResponses]; +export type GetSessionExtensionsData = { + body?: never; + path: { + /** + * Unique identifier for the session + */ + session_id: string; + }; + query?: never; + url: '/sessions/{session_id}/extensions'; +}; + +export type GetSessionExtensionsErrors = { + /** + * Unauthorized - Invalid or missing API key + */ + 401: unknown; + /** + * Session not found + */ + 404: unknown; + /** + * Internal server error + */ + 500: unknown; +}; + +export type GetSessionExtensionsResponses = { + /** + * Session extensions retrieved successfully + */ + 200: SessionExtensionsResponse; +}; + +export type GetSessionExtensionsResponse = GetSessionExtensionsResponses[keyof GetSessionExtensionsResponses]; + export type ForkSessionData = { body: ForkRequest; path: { diff --git a/ui/desktop/src/components/ConfigContext.tsx b/ui/desktop/src/components/ConfigContext.tsx index df80183efd3e..018f2f872e5b 100644 --- a/ui/desktop/src/components/ConfigContext.tsx +++ b/ui/desktop/src/components/ConfigContext.tsx @@ -1,16 +1,21 @@ import React, { createContext, useContext, useState, useEffect, useMemo, useCallback } from 'react'; -import { readAllConfig, readConfig, removeConfig, upsertConfig, providers } from '../api'; import { - getConfiguredExtensions, - addConfiguredExtension, - removeConfiguredExtension, -} from '../acp/extensions'; + readAllConfig, + readConfig, + removeConfig, + upsertConfig, + addExtension as apiAddExtension, + removeExtension as apiRemoveExtension, + providers, +} from '../api'; +import { getConfiguredExtensions } from '../acp/extensions'; import { pruneDeprecatedBundledExtensions, syncBundledExtensions } from './settings/extensions'; import type { ConfigResponse, UpsertConfigQuery, ConfigKeyQuery, ProviderDetails, + ExtensionQuery, ExtensionConfig, } from '../api'; @@ -108,7 +113,10 @@ export const ConfigProvider: React.FC = ({ children }) => { const addExtension = useCallback( async (name: string, config: ExtensionConfig, enabled: boolean) => { - await addConfiguredExtension(name, config, enabled); + const query: ExtensionQuery = { name, config, enabled }; + await apiAddExtension({ + body: query, + }); await reloadConfig(); // Refresh extensions list after successful addition await refreshExtensions(); @@ -118,7 +126,7 @@ export const ConfigProvider: React.FC = ({ children }) => { const removeExtension = useCallback( async (name: string) => { - await removeConfiguredExtension(name); + await apiRemoveExtension({ path: { name: name } }); await reloadConfig(); // Refresh extensions list after successful removal await refreshExtensions(); @@ -198,10 +206,11 @@ export const ConfigProvider: React.FC = ({ children }) => { config: ExtensionConfig, enabled: boolean ) => { - await addConfiguredExtension(name, config, enabled); + const query: ExtensionQuery = { name, config, enabled }; + await apiAddExtension({ body: query }); }; const removeExtensionForSync = async (name: string) => { - await removeConfiguredExtension(name); + await apiRemoveExtension({ path: { name } }); }; extensions = await pruneDeprecatedBundledExtensions(extensions, removeExtensionForSync); await syncBundledExtensions(extensions, addExtensionForSync); diff --git a/ui/desktop/src/components/settings/SettingsView.test.tsx b/ui/desktop/src/components/settings/SettingsView.test.tsx index dc6864f307fe..f0b24b488ec2 100644 --- a/ui/desktop/src/components/settings/SettingsView.test.tsx +++ b/ui/desktop/src/components/settings/SettingsView.test.tsx @@ -7,13 +7,21 @@ import SettingsView from './SettingsView'; import { IntlTestWrapper } from '../../i18n/test-utils'; vi.mock('../../api/sdk.gen', () => ({ - getTunnelStatus: vi.fn().mockResolvedValue({ data: { state: 'running' } }), + getTunnelStatus: vi.fn().mockResolvedValue({ data: { state: 'disabled' } }), })); vi.mock('../../utils/analytics', () => ({ trackSettingsTabViewed: vi.fn(), })); +vi.mock('../../contexts/FeaturesContext', () => ({ + useFeatures: () => ({ + localInference: false, + codeMode: true, + isLoading: false, + }), +})); + vi.mock('../Layout/MainPanelLayout', () => ({ MainPanelLayout: ({ children }: { children: React.ReactNode }) =>
{children}
, })); @@ -59,7 +67,7 @@ vi.mock('./config/ConfigSettings', () => ({ })); describe('SettingsView', () => { - it('hides local inference and mesh settings tabs from ApeCloud builds', async () => { + it('hides unavailable local inference and mesh settings tabs', async () => { render( { expect(screen.queryByTestId('settings-local-inference-tab')).not.toBeInTheDocument(); expect(screen.queryByTestId('settings-mesh-tab')).not.toBeInTheDocument(); - expect(screen.queryByTestId('settings-sharing-tab')).not.toBeInTheDocument(); expect(screen.getByText('Models section')).toBeInTheDocument(); }); }); diff --git a/ui/desktop/src/components/settings/extensions/agent-api.ts b/ui/desktop/src/components/settings/extensions/agent-api.ts index 064491f58ddf..bd38284873c4 100644 --- a/ui/desktop/src/components/settings/extensions/agent-api.ts +++ b/ui/desktop/src/components/settings/extensions/agent-api.ts @@ -1,6 +1,5 @@ import { toastService } from '../../../toasts'; -import { ExtensionConfig } from '../../../api'; -import { addSessionExtension, removeSessionExtension } from '../../../acp/extensions'; +import { agentAddExtension, ExtensionConfig, agentRemoveExtension } from '../../../api'; import { errorMessage } from '../../../utils/conversionUtils'; import { createExtensionRecoverHints, @@ -21,7 +20,10 @@ export async function addToAgent( : 0; try { - await addSessionExtension(sessionId, extensionConfig); + await agentAddExtension({ + body: { session_id: sessionId, config: extensionConfig }, + throwOnError: true, + }); if (showToast) { toastService.dismiss(toastId); toastService.success({ @@ -59,7 +61,10 @@ export async function removeFromAgent( : 0; try { - await removeSessionExtension(sessionId, extensionName); + await agentRemoveExtension({ + body: { session_id: sessionId, name: extensionName }, + throwOnError: true, + }); if (showToast) { toastService.dismiss(toastId); toastService.success({ diff --git a/ui/desktop/src/i18n/messages/en.json b/ui/desktop/src/i18n/messages/en.json index 0e418fa8faeb..e1ac03527bdf 100644 --- a/ui/desktop/src/i18n/messages/en.json +++ b/ui/desktop/src/i18n/messages/en.json @@ -4124,12 +4124,18 @@ "settingsView.tabKeyboard": { "defaultMessage": "Keyboard" }, + "settingsView.tabLocalInference": { + "defaultMessage": "Local Inference" + }, "settingsView.tabModels": { "defaultMessage": "Models" }, "settingsView.tabPrompts": { "defaultMessage": "Prompts" }, + "settingsView.tabSession": { + "defaultMessage": "Session" + }, "settingsView.title": { "defaultMessage": "Settings" }, diff --git a/ui/text/src/extensions.tsx b/ui/text/src/extensions.tsx index 82877b54cbd8..3928fb6091ba 100644 --- a/ui/text/src/extensions.tsx +++ b/ui/text/src/extensions.tsx @@ -1,7 +1,12 @@ import React, { useCallback, useEffect, useState } from "react"; import { Box, Text, useInput, useStdout } from "ink"; import { TextInput } from "@inkjs/ui"; -import type { GooseClient } from "@aaif/goose-sdk"; +import type { + GooseClient, + GooseExtension, + GooseExtensionEntry, + McpServerStdio, +} from "@aaif/goose-sdk"; import { CRANBERRY, GOLD, @@ -21,17 +26,63 @@ type ExtEntry = { [key: string]: unknown; }; -function isExtEntry(v: unknown): v is ExtEntry { - return ( - !!v && - typeof v === "object" && - "enabled" in v && - "type" in v && - "name" in v && - typeof (v as ExtEntry).enabled === "boolean" && - typeof (v as ExtEntry).type === "string" && - typeof (v as ExtEntry).name === "string" - ); +function entryToExtEntry(entry: GooseExtensionEntry): ExtEntry | null { + const ext = entry.extension; + if (ext.type !== "mcp") { + return { + enabled: entry.enabled, + type: ext.type, + name: ext.name, + description: ext.description ?? "", + display_name: ext.display_name ?? null, + timeout: "timeout" in ext ? (ext.timeout ?? null) : null, + bundled: ext.bundled ?? null, + }; + } + const server = ext.server; + if ("type" in server && server.type === "sse") return null; + const common = { + enabled: entry.enabled, + description: ext.description ?? "", + env_keys: ext.envKeys ?? [], + timeout: ext.timeout ?? null, + bundled: ext.bundled ?? null, + }; + if ("type" in server && server.type === "http") { + return { + ...common, + type: "streamable_http", + name: server.name, + uri: server.url, + headers: Object.fromEntries( + (server.headers ?? []).map((h) => [h.name, h.value]), + ), + socket: ext.socket ?? null, + }; + } + const stdio = server as McpServerStdio; + return { + ...common, + type: "stdio", + name: stdio.name, + cmd: stdio.command, + args: stdio.args, + }; +} + +function toGooseExtension(e: ExtEntry): GooseExtension { + if (e.type === "streamable_http") { + return { + type: "mcp", + server: { type: "http", name: e.name, url: String(e.uri ?? ""), headers: [] }, + description: e.description || undefined, + }; + } + return { + type: "mcp", + server: { name: e.name, command: String(e.cmd ?? ""), args: (e.args as string[]) ?? [], env: [] }, + description: e.description || undefined, + }; } type AddType = "stdio" | "streamable_http"; @@ -126,9 +177,9 @@ export default function ExtensionsManager({ client.goose.sessionExtensionsList_unstable({ sessionId }), ]); - const allExtensions = (configResp.extensions as unknown[]).filter( - isExtEntry, - ); + const allExtensions = (configResp.extensions as GooseExtensionEntry[]) + .map(entryToExtEntry) + .filter((e): e is ExtEntry => e !== null); const activeNames = new Set( (sessionResp.extensions as Array<{ name?: string }>).map((e) => e.name), ); @@ -188,8 +239,7 @@ export default function ExtensionsManager({ const config = buildConfig(addType, addValue, addName, description); withSaving(async () => { await client.goose.configExtensionsAdd_unstable({ - name: config.name, - extensionConfig: config as any, + extension: toGooseExtension(config), enabled: true, }); await client.goose.sessionExtensionsAdd_unstable({