diff --git a/Cargo.lock b/Cargo.lock index 5078c79e21ce1a..854a74f25adc13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5904,9 +5904,12 @@ dependencies = [ "async-trait", "client", "collections", + "credentials_provider", "criterion", "ctor", "dap", + "dirs 4.0.0", + "editor", "extension", "fs", "futures 0.3.31", @@ -5915,8 +5918,11 @@ dependencies = [ "http_client", "language", "language_extension", + "language_model", "log", "lsp", + "markdown", + "menu", "moka", "node_runtime", "parking_lot", @@ -5931,12 +5937,14 @@ dependencies = [ "serde_json", "serde_json_lenient", "settings", + "smol", "task", "telemetry", "tempfile", "theme", "theme_extension", "toml 0.8.23", + "ui", "url", "util", "wasmparser 0.221.3", @@ -8911,6 +8919,7 @@ dependencies = [ "credentials_provider", "deepseek", "editor", + "extension", "fs", "futures 0.3.31", "google_ai", diff --git a/crates/acp_thread/src/connection.rs b/crates/acp_thread/src/connection.rs index 8213786a182e1d..a188c0fbe88d5b 100644 --- a/crates/acp_thread/src/connection.rs +++ b/crates/acp_thread/src/connection.rs @@ -204,12 +204,21 @@ pub trait AgentModelSelector: 'static { } } +/// Icon for a model in the model selector. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AgentModelIcon { + /// A built-in icon from Zed's icon set. + Named(IconName), + /// Path to a custom SVG icon file. 
+ Path(SharedString), +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct AgentModelInfo { pub id: acp::ModelId, pub name: SharedString, pub description: Option, - pub icon: Option, + pub icon: Option, } impl From for AgentModelInfo { diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index aec0767c25422d..7ebcd79f13e0ee 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -18,7 +18,7 @@ pub use templates::*; pub use thread::*; pub use tools::*; -use acp_thread::{AcpThread, AgentModelSelector}; +use acp_thread::{AcpThread, AgentModelIcon, AgentModelSelector}; use agent_client_protocol as acp; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; @@ -161,11 +161,16 @@ impl LanguageModels { model: &Arc, provider: &Arc, ) -> acp_thread::AgentModelInfo { + let icon = if let Some(path) = provider.icon_path() { + Some(AgentModelIcon::Path(path)) + } else { + Some(AgentModelIcon::Named(provider.icon())) + }; acp_thread::AgentModelInfo { id: Self::model_id(model), name: model.name().0, description: None, - icon: Some(provider.icon()), + icon, } } @@ -1356,7 +1361,7 @@ mod internal_tests { id: acp::ModelId::new("fake/fake"), name: "Fake".into(), description: None, - icon: Some(ui::IconName::ZedAssistant), + icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)), }] )]) ); diff --git a/crates/agent_ui/src/acp/model_selector.rs b/crates/agent_ui/src/acp/model_selector.rs index f9710ad9b3aac2..6b8e1d87a0934a 100644 --- a/crates/agent_ui/src/acp/model_selector.rs +++ b/crates/agent_ui/src/acp/model_selector.rs @@ -1,6 +1,6 @@ use std::{cmp::Reverse, rc::Rc, sync::Arc}; -use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector}; +use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector}; use agent_servers::AgentServer; use anyhow::Result; use collections::IndexMap; @@ -292,12 +292,18 @@ impl PickerDelegate for AcpModelPickerDelegate { h_flex() .w_full() .gap_1p5() - 
.when_some(model_info.icon, |this, icon| { - this.child( - Icon::new(icon) + .map(|this| match &model_info.icon { + Some(AgentModelIcon::Path(path)) => this.child( + Icon::from_path(path.clone()) .color(model_icon_color) - .size(IconSize::Small) - ) + .size(IconSize::Small), + ), + Some(AgentModelIcon::Named(icon)) => this.child( + Icon::new(*icon) + .color(model_icon_color) + .size(IconSize::Small), + ), + None => this, }) .child(Label::new(model_info.name.clone()).truncate()), ) diff --git a/crates/agent_ui/src/acp/model_selector_popover.rs b/crates/agent_ui/src/acp/model_selector_popover.rs index e2393c11bd6c23..7fd808bb2059fd 100644 --- a/crates/agent_ui/src/acp/model_selector_popover.rs +++ b/crates/agent_ui/src/acp/model_selector_popover.rs @@ -1,7 +1,7 @@ use std::rc::Rc; use std::sync::Arc; -use acp_thread::{AgentModelInfo, AgentModelSelector}; +use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelSelector}; use agent_servers::AgentServer; use fs::Fs; use gpui::{Entity, FocusHandle}; @@ -64,7 +64,7 @@ impl Render for AcpModelSelectorPopover { .map(|model| model.name.clone()) .unwrap_or_else(|| SharedString::from("Select a Model")); - let model_icon = model.as_ref().and_then(|model| model.icon); + let model_icon = model.as_ref().and_then(|model| model.icon.clone()); let focus_handle = self.focus_handle.clone(); @@ -78,8 +78,13 @@ impl Render for AcpModelSelectorPopover { self.selector.clone(), ButtonLike::new("active-model") .selected_style(ButtonStyle::Tinted(TintColor::Accent)) - .when_some(model_icon, |this, icon| { - this.child(Icon::new(icon).color(color).size(IconSize::XSmall)) + .when_some(model_icon, |this, icon| match icon { + AgentModelIcon::Path(path) => { + this.child(Icon::from_path(path).color(color).size(IconSize::XSmall)) + } + AgentModelIcon::Named(icon_name) => { + this.child(Icon::new(icon_name).color(color).size(IconSize::XSmall)) + } }) .child( Label::new(model_name) diff --git a/crates/agent_ui/src/agent_configuration.rs 
b/crates/agent_ui/src/agent_configuration.rs index f831329e2cde40..3533c28caa93f8 100644 --- a/crates/agent_ui/src/agent_configuration.rs +++ b/crates/agent_ui/src/agent_configuration.rs @@ -260,11 +260,15 @@ impl AgentConfiguration { h_flex() .w_full() .gap_1p5() - .child( + .child(if let Some(icon_path) = provider.icon_path() { + Icon::from_external_svg(icon_path) + .size(IconSize::Small) + .color(Color::Muted) + } else { Icon::new(provider.icon()) .size(IconSize::Small) - .color(Color::Muted), - ) + .color(Color::Muted) + }) .child( h_flex() .w_full() diff --git a/crates/agent_ui/src/agent_model_selector.rs b/crates/agent_ui/src/agent_model_selector.rs index 43982cdda7bd88..924f37db0440dd 100644 --- a/crates/agent_ui/src/agent_model_selector.rs +++ b/crates/agent_ui/src/agent_model_selector.rs @@ -73,7 +73,8 @@ impl Render for AgentModelSelector { .map(|model| model.model.name().0) .unwrap_or_else(|| SharedString::from("Select a Model")); - let provider_icon = model.as_ref().map(|model| model.provider.icon()); + let provider_icon_path = model.as_ref().and_then(|model| model.provider.icon_path()); + let provider_icon_name = model.as_ref().map(|model| model.provider.icon()); let color = if self.menu_handle.is_deployed() { Color::Accent } else { @@ -85,8 +86,17 @@ impl Render for AgentModelSelector { PickerPopoverMenu::new( self.selector.clone(), ButtonLike::new("active-model") - .when_some(provider_icon, |this, icon| { - this.child(Icon::new(icon).color(color).size(IconSize::XSmall)) + .when_some(provider_icon_path.clone(), |this, icon_path| { + this.child( + Icon::from_external_svg(icon_path) + .color(color) + .size(IconSize::XSmall), + ) + }) + .when(provider_icon_path.is_none(), |this| { + this.when_some(provider_icon_name, |this, icon| { + this.child(Icon::new(icon).color(color).size(IconSize::XSmall)) + }) }) .selected_style(ButtonStyle::Tinted(TintColor::Accent)) .child( diff --git a/crates/agent_ui/src/agent_ui.rs b/crates/agent_ui/src/agent_ui.rs index 
f7b07b7bd393b8..4a5382c9e4d67e 100644 --- a/crates/agent_ui/src/agent_ui.rs +++ b/crates/agent_ui/src/agent_ui.rs @@ -346,9 +346,13 @@ fn init_language_model_settings(cx: &mut App) { cx.subscribe( &LanguageModelRegistry::global(cx), |_, event: &language_model::Event, cx| match event { - language_model::Event::ProviderStateChanged(_) - | language_model::Event::AddedProvider(_) - | language_model::Event::RemovedProvider(_) => { + language_model::Event::ProviderStateChanged(_) => { + update_active_language_model_from_settings(cx); + } + language_model::Event::AddedProvider(_) => { + update_active_language_model_from_settings(cx); + } + language_model::Event::RemovedProvider(_) => { update_active_language_model_from_settings(cx); } _ => {} @@ -367,26 +371,49 @@ fn update_active_language_model_from_settings(cx: &mut App) { } } - let default = settings.default_model.as_ref().map(to_selected_model); + // Filter out models from providers that are not authenticated + fn is_provider_authenticated( + selection: &LanguageModelSelection, + registry: &LanguageModelRegistry, + cx: &App, + ) -> bool { + let provider_id = LanguageModelProviderId::from(selection.provider.0.clone()); + registry + .provider(&provider_id) + .map_or(false, |provider| provider.is_authenticated(cx)) + } + + let registry = LanguageModelRegistry::global(cx); + let registry_ref = registry.read(cx); + + let default = settings + .default_model + .as_ref() + .filter(|s| is_provider_authenticated(s, registry_ref, cx)) + .map(to_selected_model); let inline_assistant = settings .inline_assistant_model .as_ref() + .filter(|s| is_provider_authenticated(s, registry_ref, cx)) .map(to_selected_model); let commit_message = settings .commit_message_model .as_ref() + .filter(|s| is_provider_authenticated(s, registry_ref, cx)) .map(to_selected_model); let thread_summary = settings .thread_summary_model .as_ref() + .filter(|s| is_provider_authenticated(s, registry_ref, cx)) .map(to_selected_model); let inline_alternatives = 
settings .inline_alternatives .iter() + .filter(|s| is_provider_authenticated(s, registry_ref, cx)) .map(to_selected_model) .collect::>(); - LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.update(cx, |registry, cx| { registry.select_default_model(default.as_ref(), cx); registry.select_inline_assistant_model(inline_assistant.as_ref(), cx); registry.select_commit_message_model(commit_message.as_ref(), cx); diff --git a/crates/agent_ui/src/language_model_selector.rs b/crates/agent_ui/src/language_model_selector.rs index 5b5a4513c6dca3..9fd717a597e149 100644 --- a/crates/agent_ui/src/language_model_selector.rs +++ b/crates/agent_ui/src/language_model_selector.rs @@ -1,13 +1,12 @@ use std::{cmp::Reverse, sync::Arc}; use collections::IndexMap; +use futures::{StreamExt, channel::mpsc}; use fuzzy::{StringMatch, StringMatchCandidate, match_strings}; -use gpui::{ - Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task, -}; +use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task}; use language_model::{ - AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId, - LanguageModelRegistry, + AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider, + LanguageModelProviderId, LanguageModelRegistry, }; use ordered_float::OrderedFloat; use picker::{Picker, PickerDelegate}; @@ -57,12 +56,12 @@ fn all_models(cx: &App) -> GroupedModels { .into_iter() .map(|model| ModelInfo { model, - icon: provider.icon(), + icon: ProviderIcon::from_provider(provider.as_ref()), }) }) .collect(); - let all = providers + let all: Vec = providers .iter() .flat_map(|provider| { provider @@ -70,7 +69,7 @@ fn all_models(cx: &App) -> GroupedModels { .into_iter() .map(|model| ModelInfo { model, - icon: provider.icon(), + icon: ProviderIcon::from_provider(provider.as_ref()), }) }) .collect(); @@ -78,10 +77,26 @@ fn all_models(cx: &App) -> GroupedModels { GroupedModels::new(all, 
recommended) } +#[derive(Clone)] +enum ProviderIcon { + Name(IconName), + Path(SharedString), +} + +impl ProviderIcon { + fn from_provider(provider: &dyn LanguageModelProvider) -> Self { + if let Some(path) = provider.icon_path() { + Self::Path(path) + } else { + Self::Name(provider.icon()) + } + } +} + #[derive(Clone)] struct ModelInfo { model: Arc, - icon: IconName, + icon: ProviderIcon, } pub struct LanguageModelPickerDelegate { @@ -91,7 +106,7 @@ pub struct LanguageModelPickerDelegate { filtered_entries: Vec, selected_index: usize, _authenticate_all_providers_task: Task<()>, - _subscriptions: Vec, + _refresh_models_task: Task<()>, popover_styles: bool, focus_handle: FocusHandle, } @@ -116,24 +131,40 @@ impl LanguageModelPickerDelegate { filtered_entries: entries, get_active_model: Arc::new(get_active_model), _authenticate_all_providers_task: Self::authenticate_all_providers(cx), - _subscriptions: vec![cx.subscribe_in( - &LanguageModelRegistry::global(cx), - window, - |picker, _, event, window, cx| { - match event { - language_model::Event::ProviderStateChanged(_) - | language_model::Event::AddedProvider(_) - | language_model::Event::RemovedProvider(_) => { - let query = picker.query(cx); + _refresh_models_task: { + // Create a channel to signal when models need refreshing + let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>(); + + // Subscribe to registry events and send refresh signals through the channel + let registry = LanguageModelRegistry::global(cx); + cx.subscribe(®istry, move |_picker, _, event, _cx| match event { + language_model::Event::ProviderStateChanged(_) => { + refresh_tx.unbounded_send(()).ok(); + } + language_model::Event::AddedProvider(_) => { + refresh_tx.unbounded_send(()).ok(); + } + language_model::Event::RemovedProvider(_) => { + refresh_tx.unbounded_send(()).ok(); + } + _ => {} + }) + .detach(); + + // Spawn a task that listens for refresh signals and updates the picker + cx.spawn_in(window, async move |this, cx| { + while let 
Some(()) = refresh_rx.next().await { + let result = this.update_in(cx, |picker, window, cx| { picker.delegate.all_models = Arc::new(all_models(cx)); - // Update matches will automatically drop the previous task - // if we get a provider event again - picker.update_matches(query, window, cx) + picker.refresh(window, cx); + }); + if result.is_err() { + // Picker was dropped, exit the loop + break; } - _ => {} } - }, - )], + }) + }, popover_styles, focus_handle, } @@ -504,11 +535,16 @@ impl PickerDelegate for LanguageModelPickerDelegate { h_flex() .w_full() .gap_1p5() - .child( - Icon::new(model_info.icon) + .child(match &model_info.icon { + ProviderIcon::Name(icon_name) => Icon::new(*icon_name) .color(model_icon_color) .size(IconSize::Small), - ) + ProviderIcon::Path(icon_path) => { + Icon::from_external_svg(icon_path.clone()) + .color(model_icon_color) + .size(IconSize::Small) + } + }) .child(Label::new(model_info.model.name().0).truncate()), ) .end_slot(div().pr_3().when(is_selected, |this| { @@ -657,7 +693,7 @@ mod tests { .into_iter() .map(|(provider, name)| ModelInfo { model: Arc::new(TestLanguageModel::new(name, provider)), - icon: IconName::Ai, + icon: ProviderIcon::Name(IconName::Ai), }) .collect() } diff --git a/crates/agent_ui/src/text_thread_editor.rs b/crates/agent_ui/src/text_thread_editor.rs index 161fad95e68c01..30538898b28a1d 100644 --- a/crates/agent_ui/src/text_thread_editor.rs +++ b/crates/agent_ui/src/text_thread_editor.rs @@ -2097,7 +2097,8 @@ impl TextThreadEditor { .default_model() .map(|default| default.provider); - let provider_icon = match active_provider { + let provider_icon_path = active_provider.as_ref().and_then(|p| p.icon_path()); + let provider_icon_name = match &active_provider { Some(provider) => provider.icon(), None => IconName::Ai, }; @@ -2109,6 +2110,16 @@ impl TextThreadEditor { (Color::Muted, IconName::ChevronDown) }; + let provider_icon_element = if let Some(icon_path) = provider_icon_path { + 
Icon::from_external_svg(icon_path) + .color(color) + .size(IconSize::XSmall) + } else { + Icon::new(provider_icon_name) + .color(color) + .size(IconSize::XSmall) + }; + PickerPopoverMenu::new( self.language_model_selector.clone(), ButtonLike::new("active-model") @@ -2116,7 +2127,7 @@ impl TextThreadEditor { .child( h_flex() .gap_0p5() - .child(Icon::new(provider_icon).color(color).size(IconSize::XSmall)) + .child(provider_icon_element) .child( Label::new(model_name) .color(color) diff --git a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs index fadc4222ae44f3..bdf1ce3640bf50 100644 --- a/crates/ai_onboarding/src/agent_api_keys_onboarding.rs +++ b/crates/ai_onboarding/src/agent_api_keys_onboarding.rs @@ -1,9 +1,25 @@ use gpui::{Action, IntoElement, ParentElement, RenderOnce, point}; -use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; +use language_model::{LanguageModelProvider, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID}; use ui::{Divider, List, ListBulletItem, prelude::*}; +#[derive(Clone)] +enum ProviderIcon { + Name(IconName), + Path(SharedString), +} + +impl ProviderIcon { + fn from_provider(provider: &dyn LanguageModelProvider) -> Self { + if let Some(path) = provider.icon_path() { + Self::Path(path) + } else { + Self::Name(provider.icon()) + } + } +} + pub struct ApiKeysWithProviders { - configured_providers: Vec<(IconName, SharedString)>, + configured_providers: Vec<(ProviderIcon, SharedString)>, } impl ApiKeysWithProviders { @@ -26,14 +42,19 @@ impl ApiKeysWithProviders { } } - fn compute_configured_providers(cx: &App) -> Vec<(IconName, SharedString)> { + fn compute_configured_providers(cx: &App) -> Vec<(ProviderIcon, SharedString)> { LanguageModelRegistry::read_global(cx) .providers() .iter() .filter(|provider| { provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID }) - .map(|provider| (provider.icon(), provider.name().0)) + .map(|provider| { + ( + 
ProviderIcon::from_provider(provider.as_ref()), + provider.name().0, + ) + }) .collect() } } @@ -47,7 +68,14 @@ impl Render for ApiKeysWithProviders { .map(|(icon, name)| { h_flex() .gap_1p5() - .child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted)) + .child(match icon { + ProviderIcon::Name(icon_name) => Icon::new(icon_name) + .size(IconSize::XSmall) + .color(Color::Muted), + ProviderIcon::Path(icon_path) => Icon::from_external_svg(icon_path) + .size(IconSize::XSmall) + .color(Color::Muted), + }) .child(Label::new(name)) }); div() diff --git a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs index 3c8ffc1663e066..ae92268ff4db45 100644 --- a/crates/ai_onboarding/src/agent_panel_onboarding_content.rs +++ b/crates/ai_onboarding/src/agent_panel_onboarding_content.rs @@ -11,7 +11,7 @@ use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding}; pub struct AgentPanelOnboarding { user_store: Entity, client: Arc, - configured_providers: Vec<(IconName, SharedString)>, + has_configured_providers: bool, continue_with_zed_ai: Arc, } @@ -28,7 +28,7 @@ impl AgentPanelOnboarding { language_model::Event::ProviderStateChanged(_) | language_model::Event::AddedProvider(_) | language_model::Event::RemovedProvider(_) => { - this.configured_providers = Self::compute_available_providers(cx) + this.has_configured_providers = Self::has_configured_providers(cx) } _ => {} }, @@ -38,20 +38,16 @@ impl AgentPanelOnboarding { Self { user_store, client, - configured_providers: Self::compute_available_providers(cx), + has_configured_providers: Self::has_configured_providers(cx), continue_with_zed_ai: Arc::new(continue_with_zed_ai), } } - fn compute_available_providers(cx: &App) -> Vec<(IconName, SharedString)> { + fn has_configured_providers(cx: &App) -> bool { LanguageModelRegistry::read_global(cx) .providers() .iter() - .filter(|provider| { - provider.is_authenticated(cx) && provider.id() != 
ZED_CLOUD_PROVIDER_ID - }) - .map(|provider| (provider.icon(), provider.name().0)) - .collect() + .any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID) } } @@ -81,7 +77,7 @@ impl Render for AgentPanelOnboarding { }), ) .map(|this| { - if enrolled_in_trial || is_pro_user || !self.configured_providers.is_empty() { + if enrolled_in_trial || is_pro_user || self.has_configured_providers { this } else { this.child(ApiKeysWithoutProviders::new()) diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 041401418c4272..06e25253ee626b 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -8,7 +8,7 @@ use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::B use http_client::http::{self, HeaderMap, HeaderValue}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode}; use serde::{Deserialize, Serialize}; -pub use settings::{AnthropicAvailableModel as AvailableModel, ModelMode}; +pub use settings::ModelMode; use strum::{EnumIter, EnumString}; use thiserror::Error; diff --git a/crates/extension/src/extension_host_proxy.rs b/crates/extension/src/extension_host_proxy.rs index 6a24e3ba3f496b..70b2da14b47d3a 100644 --- a/crates/extension/src/extension_host_proxy.rs +++ b/crates/extension/src/extension_host_proxy.rs @@ -29,6 +29,7 @@ pub struct ExtensionHostProxy { slash_command_proxy: RwLock>>, context_server_proxy: RwLock>>, debug_adapter_provider_proxy: RwLock>>, + language_model_provider_proxy: RwLock>>, } impl ExtensionHostProxy { @@ -54,6 +55,7 @@ impl ExtensionHostProxy { slash_command_proxy: RwLock::default(), context_server_proxy: RwLock::default(), debug_adapter_provider_proxy: RwLock::default(), + language_model_provider_proxy: RwLock::default(), } } @@ -90,6 +92,15 @@ impl ExtensionHostProxy { .write() .replace(Arc::new(proxy)); } + + pub fn register_language_model_provider_proxy( + &self, + proxy: impl 
ExtensionLanguageModelProviderProxy, + ) { + self.language_model_provider_proxy + .write() + .replace(Arc::new(proxy)); + } } pub trait ExtensionThemeProxy: Send + Sync + 'static { @@ -375,6 +386,49 @@ pub trait ExtensionContextServerProxy: Send + Sync + 'static { fn unregister_context_server(&self, server_id: Arc, cx: &mut App); } +/// A function that registers a language model provider with the registry. +/// This allows extension_host to create the provider (which requires WasmExtension) +/// and pass a registration closure to the language_models crate. +pub type LanguageModelProviderRegistration = Box; + +pub trait ExtensionLanguageModelProviderProxy: Send + Sync + 'static { + /// Register an LLM provider from an extension. + /// The `register_fn` closure will be called with the App context and should + /// register the provider with the LanguageModelRegistry. + fn register_language_model_provider( + &self, + provider_id: Arc, + register_fn: LanguageModelProviderRegistration, + cx: &mut App, + ); + + /// Unregister an LLM provider when an extension is unloaded. 
+ fn unregister_language_model_provider(&self, provider_id: Arc, cx: &mut App); +} + +impl ExtensionLanguageModelProviderProxy for ExtensionHostProxy { + fn register_language_model_provider( + &self, + provider_id: Arc, + register_fn: LanguageModelProviderRegistration, + cx: &mut App, + ) { + let Some(proxy) = self.language_model_provider_proxy.read().clone() else { + return; + }; + + proxy.register_language_model_provider(provider_id, register_fn, cx) + } + + fn unregister_language_model_provider(&self, provider_id: Arc, cx: &mut App) { + let Some(proxy) = self.language_model_provider_proxy.read().clone() else { + return; + }; + + proxy.unregister_language_model_provider(provider_id, cx) + } +} + impl ExtensionContextServerProxy for ExtensionHostProxy { fn register_context_server( &self, diff --git a/crates/extension/src/extension_manifest.rs b/crates/extension/src/extension_manifest.rs index 4ecdd378ca86db..3a09a602d5b461 100644 --- a/crates/extension/src/extension_manifest.rs +++ b/crates/extension/src/extension_manifest.rs @@ -93,6 +93,8 @@ pub struct ExtensionManifest { pub debug_adapters: BTreeMap, DebugAdapterManifestEntry>, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] pub debug_locators: BTreeMap, DebugLocatorManifestEntry>, + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + pub language_model_providers: BTreeMap, LanguageModelProviderManifestEntry>, } impl ExtensionManifest { @@ -288,6 +290,68 @@ pub struct DebugAdapterManifestEntry { #[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)] pub struct DebugLocatorManifestEntry {} +/// Manifest entry for a language model provider. +#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)] +pub struct LanguageModelProviderManifestEntry { + /// Display name for the provider. + pub name: String, + /// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg"). 
+ #[serde(default)] + pub icon: Option, + /// Default models to show even before API connection. + #[serde(default)] + pub models: Vec, + /// Authentication configuration. + #[serde(default)] + pub auth: Option, +} + +/// Manifest entry for a language model. +#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)] +pub struct LanguageModelManifestEntry { + /// Unique identifier for the model. + pub id: String, + /// Display name for the model. + pub name: String, + /// Maximum input token count. + #[serde(default)] + pub max_token_count: u64, + /// Maximum output tokens (optional). + #[serde(default)] + pub max_output_tokens: Option, + /// Whether the model supports image inputs. + #[serde(default)] + pub supports_images: bool, + /// Whether the model supports tool/function calling. + #[serde(default)] + pub supports_tools: bool, + /// Whether the model supports extended thinking/reasoning. + #[serde(default)] + pub supports_thinking: bool, +} + +/// Authentication configuration for a language model provider. +#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)] +pub struct LanguageModelAuthConfig { + /// Environment variable name for the API key. + #[serde(default)] + pub env_var: Option, + /// Human-readable name for the credential shown in the UI input field (e.g., "API Key", "Access Token"). + #[serde(default)] + pub credential_label: Option, + /// OAuth configuration for web-based authentication flows. + #[serde(default)] + pub oauth: Option, +} + +/// OAuth configuration for web-based authentication. +#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)] +pub struct OAuthConfig { + /// The text to display on the sign-in button (e.g., "Sign in with GitHub"). 
+ #[serde(default)] + pub sign_in_button_label: Option, +} + impl ExtensionManifest { pub async fn load(fs: Arc, extension_dir: &Path) -> Result { let extension_name = extension_dir @@ -358,6 +422,7 @@ fn manifest_from_old_manifest( capabilities: Vec::new(), debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: Default::default(), } } @@ -391,6 +456,7 @@ mod tests { capabilities: vec![], debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: BTreeMap::default(), } } diff --git a/crates/extension_api/src/extension_api.rs b/crates/extension_api/src/extension_api.rs index 9418623224289f..555ba6dcc260b6 100644 --- a/crates/extension_api/src/extension_api.rs +++ b/crates/extension_api/src/extension_api.rs @@ -29,6 +29,27 @@ pub use wit::{ GithubRelease, GithubReleaseAsset, GithubReleaseOptions, github_release_by_tag_name, latest_github_release, }, + zed::extension::llm_provider::{ + CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent, + CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType, + ImageData as LlmImageData, MessageContent as LlmMessageContent, + MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities, + ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest, + OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig, + OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo, + RequestMessage as LlmRequestMessage, StopReason as LlmStopReason, + ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage, + ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition, + ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult, + ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse, + ToolUseJsonParseError as LlmToolUseJsonParseError, + delete_credential as llm_delete_credential, get_credential as 
llm_get_credential, + get_env_var as llm_get_env_var, oauth_open_browser as llm_oauth_open_browser, + oauth_start_web_auth as llm_oauth_start_web_auth, + request_credential as llm_request_credential, + send_oauth_http_request as llm_oauth_http_request, + store_credential as llm_store_credential, + }, zed::extension::nodejs::{ node_binary_path, npm_install_package, npm_package_installed_version, npm_package_latest_version, @@ -259,6 +280,101 @@ pub trait Extension: Send + Sync { ) -> Result { Err("`run_dap_locator` not implemented".to_string()) } + + /// Returns information about language model providers offered by this extension. + fn llm_providers(&self) -> Vec { + Vec::new() + } + + /// Returns the models available for a provider. + fn llm_provider_models(&self, _provider_id: &str) -> Result, String> { + Ok(Vec::new()) + } + + /// Returns markdown content to display in the provider's settings UI. + /// This can include setup instructions, links to documentation, etc. + fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option { + None + } + + /// Check if the provider is authenticated. + fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool { + false + } + + /// Attempt to authenticate the provider. + /// This is called for background credential checks - it should check for + /// existing credentials and return Ok if found, or an error if not. + fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> { + Err("`llm_provider_authenticate` not implemented".to_string()) + } + + /// Start an OAuth device flow sign-in. + /// This is called when the user explicitly clicks "Sign in with GitHub" or similar. + /// Opens the browser to the verification URL and returns the user code that should + /// be displayed to the user. 
+ fn llm_provider_start_device_flow_sign_in( + &mut self, + _provider_id: &str, + ) -> Result { + Err("`llm_provider_start_device_flow_sign_in` not implemented".to_string()) + } + + /// Poll for device flow sign-in completion. + /// This is called after llm_provider_start_device_flow_sign_in returns the user code. + /// The extension should poll the OAuth provider until the user authorizes or the flow times out. + fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> { + Err("`llm_provider_poll_device_flow_sign_in` not implemented".to_string()) + } + + /// Reset credentials for the provider. + fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> { + Err("`llm_provider_reset_credentials` not implemented".to_string()) + } + + /// Count tokens for a request. + fn llm_count_tokens( + &self, + _provider_id: &str, + _model_id: &str, + _request: &LlmCompletionRequest, + ) -> Result { + Err("`llm_count_tokens` not implemented".to_string()) + } + + /// Start streaming a completion from the model. + /// Returns a stream ID that can be used with `llm_stream_completion_next` and `llm_stream_completion_close`. + fn llm_stream_completion_start( + &mut self, + _provider_id: &str, + _model_id: &str, + _request: &LlmCompletionRequest, + ) -> Result { + Err("`llm_stream_completion_start` not implemented".to_string()) + } + + /// Get the next event from a completion stream. + /// Returns `Ok(None)` when the stream is complete. + fn llm_stream_completion_next( + &mut self, + _stream_id: &str, + ) -> Result, String> { + Err("`llm_stream_completion_next` not implemented".to_string()) + } + + /// Close a completion stream and release its resources. + fn llm_stream_completion_close(&mut self, _stream_id: &str) { + // Default implementation does nothing + } + + /// Get cache configuration for a model (if prompt caching is supported). 
+ fn llm_cache_configuration( + &self, + _provider_id: &str, + _model_id: &str, + ) -> Option { + None + } } /// Registers the provided type as a Zed extension. @@ -518,6 +634,69 @@ impl wit::Guest for Component { ) -> Result { extension().run_dap_locator(locator_name, build_task) } + + fn llm_providers() -> Vec { + extension().llm_providers() + } + + fn llm_provider_models(provider_id: String) -> Result, String> { + extension().llm_provider_models(&provider_id) + } + + fn llm_provider_settings_markdown(provider_id: String) -> Option { + extension().llm_provider_settings_markdown(&provider_id) + } + + fn llm_provider_is_authenticated(provider_id: String) -> bool { + extension().llm_provider_is_authenticated(&provider_id) + } + + fn llm_provider_authenticate(provider_id: String) -> Result<(), String> { + extension().llm_provider_authenticate(&provider_id) + } + + fn llm_provider_start_device_flow_sign_in(provider_id: String) -> Result { + extension().llm_provider_start_device_flow_sign_in(&provider_id) + } + + fn llm_provider_poll_device_flow_sign_in(provider_id: String) -> Result<(), String> { + extension().llm_provider_poll_device_flow_sign_in(&provider_id) + } + + fn llm_provider_reset_credentials(provider_id: String) -> Result<(), String> { + extension().llm_provider_reset_credentials(&provider_id) + } + + fn llm_count_tokens( + provider_id: String, + model_id: String, + request: LlmCompletionRequest, + ) -> Result { + extension().llm_count_tokens(&provider_id, &model_id, &request) + } + + fn llm_stream_completion_start( + provider_id: String, + model_id: String, + request: LlmCompletionRequest, + ) -> Result { + extension().llm_stream_completion_start(&provider_id, &model_id, &request) + } + + fn llm_stream_completion_next(stream_id: String) -> Result, String> { + extension().llm_stream_completion_next(&stream_id) + } + + fn llm_stream_completion_close(stream_id: String) { + extension().llm_stream_completion_close(&stream_id) + } + + fn llm_cache_configuration( 
+ provider_id: String, + model_id: String, + ) -> Option { + extension().llm_cache_configuration(&provider_id, &model_id) + } } /// The ID of a language server. diff --git a/crates/extension_api/wit/since_v0.8.0/extension.wit b/crates/extension_api/wit/since_v0.8.0/extension.wit index 8195162b89a420..ef9f464d29d802 100644 --- a/crates/extension_api/wit/since_v0.8.0/extension.wit +++ b/crates/extension_api/wit/since_v0.8.0/extension.wit @@ -8,6 +8,7 @@ world extension { import platform; import process; import nodejs; + import llm-provider; use common.{env-vars, range}; use context-server.{context-server-configuration}; @@ -15,6 +16,10 @@ world extension { use lsp.{completion, symbol}; use process.{command}; use slash-command.{slash-command, slash-command-argument-completion, slash-command-output}; + use llm-provider.{ + provider-info, model-info, completion-request, + credential-type, cache-configuration, completion-event, token-usage + }; /// Initializes the extension. export init-extension: func(); @@ -164,4 +169,80 @@ world extension { export dap-config-to-scenario: func(config: debug-config) -> result; export dap-locator-create-scenario: func(locator-name: string, build-config-template: build-task-template, resolved-label: string, debug-adapter-name: string) -> option; export run-dap-locator: func(locator-name: string, config: resolved-task) -> result; + + /// Returns information about language model providers offered by this extension. + export llm-providers: func() -> list; + + /// Returns the models available for a provider. + export llm-provider-models: func(provider-id: string) -> result, string>; + + /// Returns markdown content to display in the provider's settings UI. + /// This can include setup instructions, links to documentation, etc. + export llm-provider-settings-markdown: func(provider-id: string) -> option; + + /// Check if the provider is authenticated. 
+ export llm-provider-is-authenticated: func(provider-id: string) -> bool; + + /// Attempt to authenticate the provider. + /// This is called for background credential checks - it should check for + /// existing credentials and return Ok if found, or an error if not. + /// For interactive OAuth flows, use the device flow functions instead. + export llm-provider-authenticate: func(provider-id: string) -> result<_, string>; + + /// Start an OAuth device flow sign-in. + /// This is called when the user explicitly clicks "Sign in with GitHub" or similar. + /// + /// The device flow works as follows: + /// 1. Extension requests a device code from the OAuth provider + /// 2. Extension opens the verification URL in the browser + /// 3. Extension returns the user code to display to the user + /// 4. Host displays the user code and calls llm-provider-poll-device-flow-sign-in + /// 5. Extension polls for the access token while user authorizes in browser + /// 6. Once authorized, extension stores the credential and returns success + /// + /// Returns the user code that should be displayed to the user while they + /// complete authorization in the browser. + export llm-provider-start-device-flow-sign-in: func(provider-id: string) -> result; + + /// Poll for device flow sign-in completion. + /// This is called after llm-provider-start-device-flow-sign-in returns the user code. + /// The extension should poll the OAuth provider until the user authorizes or the flow times out. + /// Returns Ok(()) on successful authentication, or an error message on failure. + export llm-provider-poll-device-flow-sign-in: func(provider-id: string) -> result<_, string>; + + /// Reset credentials for the provider. + export llm-provider-reset-credentials: func(provider-id: string) -> result<_, string>; + + /// Count tokens for a request. 
+    export llm-count-tokens: func(
+        provider-id: string,
+        model-id: string,
+        request: completion-request
+    ) -> result<u64, string>;
+
+    /// Start streaming a completion from the model.
+    /// Returns a stream ID that can be used with llm-stream-completion-next
+    /// and llm-stream-completion-close.
+    export llm-stream-completion-start: func(
+        provider-id: string,
+        model-id: string,
+        request: completion-request
+    ) -> result<string, string>;
+
+    /// Get the next event from a completion stream.
+    /// Returns None when the stream is complete.
+    export llm-stream-completion-next: func(
+        stream-id: string
+    ) -> result<option<completion-event>, string>;
+
+    /// Close a completion stream and release its resources.
+    export llm-stream-completion-close: func(
+        stream-id: string
+    );
+
+    /// Get cache configuration for a model (if prompt caching is supported).
+    export llm-cache-configuration: func(
+        provider-id: string,
+        model-id: string
+    ) -> option<cache-configuration>;
+}
diff --git a/crates/extension_api/wit/since_v0.8.0/llm-provider.wit b/crates/extension_api/wit/since_v0.8.0/llm-provider.wit
new file mode 100644
index 00000000000000..a3f1258fc78850
--- /dev/null
+++ b/crates/extension_api/wit/since_v0.8.0/llm-provider.wit
@@ -0,0 +1,348 @@
+interface llm-provider {
+    /// Information about a language model provider.
+    record provider-info {
+        /// Unique identifier for the provider (e.g., "my-extension.my-provider").
+        id: string,
+        /// Display name for the provider.
+        name: string,
+        /// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg").
+        icon: option<string>,
+    }
+
+    /// Capabilities of a language model.
+    record model-capabilities {
+        /// Whether the model supports image inputs.
+        supports-images: bool,
+        /// Whether the model supports tool/function calling.
+        supports-tools: bool,
+        /// Whether the model supports the "auto" tool choice.
+        supports-tool-choice-auto: bool,
+        /// Whether the model supports the "any" tool choice.
+        supports-tool-choice-any: bool,
+        /// Whether the model supports the "none" tool choice.
+ supports-tool-choice-none: bool, + /// Whether the model supports extended thinking/reasoning. + supports-thinking: bool, + /// The format for tool input schemas. + tool-input-format: tool-input-format, + } + + /// Format for tool input schemas. + enum tool-input-format { + /// Standard JSON Schema format. + json-schema, + /// Simplified schema format for certain providers. + simplified, + } + + /// Information about a specific model. + record model-info { + /// Unique identifier for the model. + id: string, + /// Display name for the model. + name: string, + /// Maximum input token count. + max-token-count: u64, + /// Maximum output tokens (optional). + max-output-tokens: option, + /// Model capabilities. + capabilities: model-capabilities, + /// Whether this is the default model for the provider. + is-default: bool, + /// Whether this is the default fast model. + is-default-fast: bool, + } + + /// The role of a message participant. + enum message-role { + /// User message. + user, + /// Assistant message. + assistant, + /// System message. + system, + } + + /// A message in a completion request. + record request-message { + /// The role of the message sender. + role: message-role, + /// The content of the message. + content: list, + /// Whether to cache this message for prompt caching. + cache: bool, + } + + /// Content within a message. + variant message-content { + /// Plain text content. + text(string), + /// Image content. + image(image-data), + /// A tool use request from the assistant. + tool-use(tool-use), + /// A tool result from the user. + tool-result(tool-result), + /// Thinking/reasoning content. + thinking(thinking-content), + /// Redacted/encrypted thinking content. + redacted-thinking(string), + } + + /// Image data for vision models. + record image-data { + /// Base64-encoded image data. + source: string, + /// Image width in pixels (optional). + width: option, + /// Image height in pixels (optional). 
+ height: option, + } + + /// A tool use request from the model. + record tool-use { + /// Unique identifier for this tool use. + id: string, + /// The name of the tool being used. + name: string, + /// JSON string of the tool input arguments. + input: string, + /// Thought signature for providers that support it (e.g., Anthropic). + thought-signature: option, + } + + /// A tool result to send back to the model. + record tool-result { + /// The ID of the tool use this is a result for. + tool-use-id: string, + /// The name of the tool. + tool-name: string, + /// Whether this result represents an error. + is-error: bool, + /// The content of the result. + content: tool-result-content, + } + + /// Content of a tool result. + variant tool-result-content { + /// Text result. + text(string), + /// Image result. + image(image-data), + } + + /// Thinking/reasoning content from models that support extended thinking. + record thinking-content { + /// The thinking text. + text: string, + /// Signature for the thinking block (provider-specific). + signature: option, + } + + /// A tool definition for function calling. + record tool-definition { + /// The name of the tool. + name: string, + /// Description of what the tool does. + description: string, + /// JSON Schema for input parameters. + input-schema: string, + } + + /// Tool choice preference for the model. + enum tool-choice { + /// Let the model decide whether to use tools. + auto, + /// Force the model to use at least one tool. + any, + /// Prevent the model from using tools. + none, + } + + /// A completion request to send to the model. + record completion-request { + /// The messages in the conversation. + messages: list, + /// Available tools for the model to use. + tools: list, + /// Tool choice preference. + tool-choice: option, + /// Stop sequences to end generation. + stop-sequences: list, + /// Temperature for sampling (0.0-1.0). + temperature: option, + /// Whether thinking/reasoning is allowed. 
+ thinking-allowed: bool, + /// Maximum tokens to generate. + max-tokens: option, + } + + /// Events emitted during completion streaming. + variant completion-event { + /// Completion has started. + started, + /// Text content chunk. + text(string), + /// Thinking/reasoning content chunk. + thinking(thinking-content), + /// Redacted thinking (encrypted) chunk. + redacted-thinking(string), + /// Tool use request from the model. + tool-use(tool-use), + /// JSON parse error when parsing tool input. + tool-use-json-parse-error(tool-use-json-parse-error), + /// Completion stopped. + stop(stop-reason), + /// Token usage update. + usage(token-usage), + /// Reasoning details (provider-specific JSON). + reasoning-details(string), + } + + /// Error information when tool use JSON parsing fails. + record tool-use-json-parse-error { + /// The tool use ID. + id: string, + /// The tool name. + tool-name: string, + /// The raw input that failed to parse. + raw-input: string, + /// The parse error message. + error: string, + } + + /// Reason the completion stopped. + enum stop-reason { + /// The model finished generating. + end-turn, + /// Maximum tokens reached. + max-tokens, + /// The model wants to use a tool. + tool-use, + /// The model refused to respond. + refusal, + } + + /// Token usage statistics. + record token-usage { + /// Number of input tokens used. + input-tokens: u64, + /// Number of output tokens generated. + output-tokens: u64, + /// Tokens used for cache creation (if supported). + cache-creation-input-tokens: option, + /// Tokens read from cache (if supported). + cache-read-input-tokens: option, + } + + /// Credential types that can be requested. + enum credential-type { + /// An API key. + api-key, + /// An OAuth token. + oauth-token, + } + + /// Cache configuration for prompt caching. + record cache-configuration { + /// Maximum number of cache anchors. + max-cache-anchors: u32, + /// Whether caching should be applied to tool definitions. 
+ should-cache-tool-definitions: bool, + /// Minimum token count for a message to be cached. + min-total-token-count: u64, + } + + /// Configuration for starting an OAuth web authentication flow. + record oauth-web-auth-config { + /// The URL to open in the user's browser to start authentication. + /// This should include client_id, redirect_uri, scope, state, etc. + /// Use `{port}` as a placeholder in the URL - it will be replaced with + /// the actual localhost port before opening the browser. + /// Example: "https://example.com/oauth?redirect_uri=http://127.0.0.1:{port}/callback" + auth-url: string, + /// The path to listen on for the OAuth callback (e.g., "/callback"). + /// A localhost server will be started to receive the redirect. + callback-path: string, + /// Timeout in seconds to wait for the callback (default: 300 = 5 minutes). + timeout-secs: option, + } + + /// Result of an OAuth web authentication flow. + record oauth-web-auth-result { + /// The full callback URL that was received, including query parameters. + /// The extension is responsible for parsing the code, state, etc. + callback-url: string, + /// The port that was used for the localhost callback server. + port: u32, + } + + /// A generic HTTP request for OAuth token exchange. + record oauth-http-request { + /// The URL to request. + url: string, + /// HTTP method (e.g., "POST", "GET"). + method: string, + /// Request headers as key-value pairs. + headers: list>, + /// Request body as a string (for form-encoded or JSON bodies). + body: string, + } + + /// Response from an OAuth HTTP request. + record oauth-http-response { + /// HTTP status code. + status: u16, + /// Response headers as key-value pairs. + headers: list>, + /// Response body as a string. + body: string, + } + + /// Request a credential from the user. + /// Returns true if the credential was provided, false if the user cancelled. 
+ request-credential: func( + provider-id: string, + credential-type: credential-type, + label: string, + placeholder: string + ) -> result; + + /// Get a stored credential for this provider. + get-credential: func(provider-id: string) -> option; + + /// Store a credential for this provider. + store-credential: func(provider-id: string, value: string) -> result<_, string>; + + /// Delete a stored credential for this provider. + delete-credential: func(provider-id: string) -> result<_, string>; + + /// Read an environment variable. + get-env-var: func(name: string) -> option; + + /// Start an OAuth web authentication flow. + /// + /// This will: + /// 1. Start a localhost server to receive the OAuth callback + /// 2. Open the auth URL in the user's default browser + /// 3. Wait for the callback (up to the timeout) + /// 4. Return the callback URL with query parameters + /// + /// The extension is responsible for: + /// - Constructing the auth URL with client_id, redirect_uri, scope, state, etc. + /// - Parsing the callback URL to extract the authorization code + /// - Exchanging the code for tokens using oauth-http-request + oauth-start-web-auth: func(config: oauth-web-auth-config) -> result; + + /// Make an HTTP request for OAuth token exchange. + /// + /// This is a simple HTTP client for OAuth flows, allowing the extension + /// to handle token exchange with full control over serialization. + send-oauth-http-request: func(request: oauth-http-request) -> result; + + /// Open a URL in the user's default browser. + /// + /// Useful for OAuth flows that need to open a browser but handle the + /// callback differently (e.g., polling-based flows). 
+ oauth-open-browser: func(url: string) -> result<_, string>; +} diff --git a/crates/extension_cli/src/main.rs b/crates/extension_cli/src/main.rs index 524e14b0cedceb..24eb696b1dcba8 100644 --- a/crates/extension_cli/src/main.rs +++ b/crates/extension_cli/src/main.rs @@ -254,6 +254,21 @@ async fn copy_extension_resources( } } + for (_, provider_entry) in &manifest.language_model_providers { + if let Some(icon_path) = &provider_entry.icon { + let source_icon = extension_path.join(icon_path); + let dest_icon = output_dir.join(icon_path); + + // Create parent directory if needed + if let Some(parent) = dest_icon.parent() { + fs::create_dir_all(parent)?; + } + + fs::copy(&source_icon, &dest_icon) + .with_context(|| format!("failed to copy LLM provider icon '{}'", icon_path))?; + } + } + if !manifest.languages.is_empty() { let output_languages_dir = output_dir.join("languages"); fs::create_dir_all(&output_languages_dir)?; diff --git a/crates/extension_host/Cargo.toml b/crates/extension_host/Cargo.toml index 328b808b1310e3..0f3d1eefee9e04 100644 --- a/crates/extension_host/Cargo.toml +++ b/crates/extension_host/Cargo.toml @@ -22,7 +22,10 @@ async-tar.workspace = true async-trait.workspace = true client.workspace = true collections.workspace = true +credentials_provider.workspace = true dap.workspace = true +dirs.workspace = true +editor.workspace = true extension.workspace = true fs.workspace = true futures.workspace = true @@ -30,8 +33,11 @@ gpui.workspace = true gpui_tokio.workspace = true http_client.workspace = true language.workspace = true +language_model.workspace = true log.workspace = true +markdown.workspace = true lsp.workspace = true +menu.workspace = true moka.workspace = true node_runtime.workspace = true paths.workspace = true @@ -43,10 +49,13 @@ serde.workspace = true serde_json.workspace = true serde_json_lenient.workspace = true settings.workspace = true +smol.workspace = true task.workspace = true telemetry.workspace = true tempfile.workspace = true 
+theme.workspace = true toml.workspace = true +ui.workspace = true url.workspace = true util.workspace = true wasmparser.workspace = true diff --git a/crates/extension_host/benches/extension_compilation_benchmark.rs b/crates/extension_host/benches/extension_compilation_benchmark.rs index c3459cf116b551..2a77be06a88d5d 100644 --- a/crates/extension_host/benches/extension_compilation_benchmark.rs +++ b/crates/extension_host/benches/extension_compilation_benchmark.rs @@ -143,6 +143,7 @@ fn manifest() -> ExtensionManifest { )], debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: BTreeMap::default(), } } diff --git a/crates/extension_host/src/capability_granter.rs b/crates/extension_host/src/capability_granter.rs index 9f27b5e480bc3c..6278deef0a7d41 100644 --- a/crates/extension_host/src/capability_granter.rs +++ b/crates/extension_host/src/capability_granter.rs @@ -113,6 +113,7 @@ mod tests { capabilities: vec![], debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: BTreeMap::default(), } } diff --git a/crates/extension_host/src/copilot_migration.rs b/crates/extension_host/src/copilot_migration.rs new file mode 100644 index 00000000000000..27d78f6db3ed0c --- /dev/null +++ b/crates/extension_host/src/copilot_migration.rs @@ -0,0 +1,161 @@ +use credentials_provider::CredentialsProvider; +use gpui::App; +use std::path::PathBuf; + +const COPILOT_CHAT_EXTENSION_ID: &str = "copilot-chat"; +const COPILOT_CHAT_PROVIDER_ID: &str = "copilot-chat"; + +pub fn migrate_copilot_credentials_if_needed(extension_id: &str, cx: &mut App) { + if extension_id != COPILOT_CHAT_EXTENSION_ID { + return; + } + + let credential_key = format!( + "extension-llm-{}:{}", + COPILOT_CHAT_EXTENSION_ID, COPILOT_CHAT_PROVIDER_ID + ); + + let credentials_provider = ::global(cx); + + cx.spawn(async move |cx| { + let existing_credential = credentials_provider + .read_credentials(&credential_key, &cx) + .await + 
.ok() + .flatten(); + + if existing_credential.is_some() { + log::debug!("Copilot Chat extension already has credentials, skipping migration"); + return; + } + + let oauth_token = match read_copilot_oauth_token().await { + Some(token) => token, + None => { + log::debug!("No existing Copilot OAuth token found to migrate"); + return; + } + }; + + log::info!("Migrating existing Copilot OAuth token to Copilot Chat extension"); + + match credentials_provider + .write_credentials(&credential_key, "api_key", oauth_token.as_bytes(), &cx) + .await + { + Ok(()) => { + log::info!("Successfully migrated Copilot OAuth token to Copilot Chat extension"); + } + Err(err) => { + log::error!("Failed to migrate Copilot OAuth token: {}", err); + } + } + }) + .detach(); +} + +async fn read_copilot_oauth_token() -> Option { + let config_paths = copilot_config_paths(); + + for path in config_paths { + if let Some(token) = read_oauth_token_from_file(&path).await { + return Some(token); + } + } + + None +} + +fn copilot_config_paths() -> Vec { + let config_dir = if cfg!(target_os = "windows") { + dirs::data_local_dir() + } else { + std::env::var("XDG_CONFIG_HOME") + .map(PathBuf::from) + .ok() + .or_else(|| dirs::home_dir().map(|h| h.join(".config"))) + }; + + let Some(config_dir) = config_dir else { + return Vec::new(); + }; + + let copilot_dir = config_dir.join("github-copilot"); + + vec![ + copilot_dir.join("hosts.json"), + copilot_dir.join("apps.json"), + ] +} + +async fn read_oauth_token_from_file(path: &PathBuf) -> Option { + let contents = match smol::fs::read_to_string(path).await { + Ok(contents) => contents, + Err(_) => return None, + }; + + extract_oauth_token(&contents, "github.com") +} + +fn extract_oauth_token(contents: &str, domain: &str) -> Option { + let value: serde_json::Value = serde_json::from_str(contents).ok()?; + let obj = value.as_object()?; + + for (key, value) in obj.iter() { + if key.starts_with(domain) { + if let Some(token) = 
value.get("oauth_token").and_then(|v| v.as_str()) { + return Some(token.to_string()); + } + } + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_oauth_token() { + let contents = r#"{ + "github.com": { + "oauth_token": "ghu_test_token_12345" + } + }"#; + + let token = extract_oauth_token(contents, "github.com"); + assert_eq!(token, Some("ghu_test_token_12345".to_string())); + } + + #[test] + fn test_extract_oauth_token_with_prefix() { + let contents = r#"{ + "github.com:user": { + "oauth_token": "ghu_another_token" + } + }"#; + + let token = extract_oauth_token(contents, "github.com"); + assert_eq!(token, Some("ghu_another_token".to_string())); + } + + #[test] + fn test_extract_oauth_token_missing() { + let contents = r#"{ + "gitlab.com": { + "oauth_token": "some_token" + } + }"#; + + let token = extract_oauth_token(contents, "github.com"); + assert_eq!(token, None); + } + + #[test] + fn test_extract_oauth_token_invalid_json() { + let contents = "not valid json"; + let token = extract_oauth_token(contents, "github.com"); + assert_eq!(token, None); + } +} diff --git a/crates/extension_host/src/extension_host.rs b/crates/extension_host/src/extension_host.rs index c1c598f1895f68..ea6d52418fe693 100644 --- a/crates/extension_host/src/extension_host.rs +++ b/crates/extension_host/src/extension_host.rs @@ -1,4 +1,5 @@ mod capability_granter; +mod copilot_migration; pub mod extension_settings; pub mod headless_host; pub mod wasm_host; @@ -16,9 +17,9 @@ pub use extension::ExtensionManifest; use extension::extension_builder::{CompileExtensionOptions, ExtensionBuilder}; use extension::{ ExtensionContextServerProxy, ExtensionDebugAdapterProviderProxy, ExtensionEvents, - ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageProxy, - ExtensionLanguageServerProxy, ExtensionSlashCommandProxy, ExtensionSnippetProxy, - ExtensionThemeProxy, + ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageModelProviderProxy, + 
ExtensionLanguageProxy, ExtensionLanguageServerProxy, ExtensionSlashCommandProxy, + ExtensionSnippetProxy, ExtensionThemeProxy, }; use fs::{Fs, RemoveOptions}; use futures::future::join_all; @@ -32,8 +33,8 @@ use futures::{ select_biased, }; use gpui::{ - App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task, WeakEntity, - actions, + App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, SharedString, Task, + WeakEntity, actions, }; use http_client::{AsyncBody, HttpClient, HttpClientWithUrl}; use language::{ @@ -57,11 +58,20 @@ use std::{ }; use url::Url; use util::{ResultExt, paths::RemotePathBuf}; +use wasm_host::llm_provider::ExtensionLanguageModelProvider; use wasm_host::{ WasmExtension, WasmHost, - wit::{is_supported_wasm_api_version, wasm_api_version_range}, + wit::{LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version, wasm_api_version_range}, }; +struct LlmProviderWithModels { + provider_info: LlmProviderInfo, + models: Vec, + is_authenticated: bool, + icon_path: Option, + auth_config: Option, +} + pub use extension::{ ExtensionLibraryKind, GrammarManifestEntry, OldExtensionManifest, SchemaVersion, }; @@ -70,6 +80,98 @@ pub use extension_settings::ExtensionSettings; pub const RELOAD_DEBOUNCE_DURATION: Duration = Duration::from_millis(200); const FS_WATCH_LATENCY: Duration = Duration::from_millis(100); +/// Extension IDs that are being migrated from hardcoded LLM providers. +/// For backwards compatibility, if the user has the corresponding env var set, +/// we automatically enable env var reading for these extensions on first install. +const LEGACY_LLM_EXTENSION_IDS: &[&str] = &[ + "anthropic", + "copilot_chat", + "google-ai", + "open_router", + "openai", +]; + +/// Migrates legacy LLM provider extensions by auto-enabling env var reading +/// if the env var is currently present in the environment. 
+/// +/// This migration only runs once per provider - we track which providers have been +/// migrated in `migrated_llm_providers` to avoid overriding user preferences. +fn migrate_legacy_llm_provider_env_var(manifest: &ExtensionManifest, cx: &mut App) { + // Only apply migration to known legacy LLM extensions + if !LEGACY_LLM_EXTENSION_IDS.contains(&manifest.id.as_ref()) { + return; + } + + // Check each provider in the manifest + for (provider_id, provider_entry) in &manifest.language_model_providers { + let Some(auth_config) = &provider_entry.auth else { + continue; + }; + let Some(env_var_name) = &auth_config.env_var else { + continue; + }; + + let full_provider_id: Arc = format!("{}:{}", manifest.id, provider_id).into(); + + // Check if we've already run migration for this provider (regardless of outcome) + let already_migrated = ExtensionSettings::get_global(cx) + .migrated_llm_providers + .contains(full_provider_id.as_ref()); + + if already_migrated { + continue; + } + + // Check if the env var is present and non-empty + let env_var_is_set = std::env::var(env_var_name) + .map(|v| !v.is_empty()) + .unwrap_or(false); + + // Mark as migrated regardless of whether we enable env var reading + settings::update_settings_file(::global(cx), cx, { + let full_provider_id = full_provider_id.clone(); + let env_var_is_set = env_var_is_set; + move |settings, _| { + // Always mark as migrated + let migrated = settings + .extension + .migrated_llm_providers + .get_or_insert_with(Vec::new); + + if !migrated + .iter() + .any(|id| id.as_ref() == full_provider_id.as_ref()) + { + migrated.push(full_provider_id.clone()); + } + + // Only enable env var reading if the env var is set + if env_var_is_set { + let providers = settings + .extension + .allowed_env_var_providers + .get_or_insert_with(Vec::new); + + if !providers + .iter() + .any(|id| id.as_ref() == full_provider_id.as_ref()) + { + providers.push(full_provider_id); + } + } + } + }); + + if env_var_is_set { + log::info!( + 
"Migrating legacy LLM provider {}: auto-enabling {} env var reading", + full_provider_id, + env_var_name + ); + } + } +} + /// The current extension [`SchemaVersion`] supported by Zed. const CURRENT_SCHEMA_VERSION: SchemaVersion = SchemaVersion(1); @@ -771,6 +873,11 @@ impl ExtensionStore { if let ExtensionOperation::Install = operation { this.update(cx, |this, cx| { + // Check for legacy LLM provider migration + if let Some(manifest) = this.extension_manifest_for_id(&extension_id) { + migrate_legacy_llm_provider_env_var(&manifest, cx); + } + cx.emit(Event::ExtensionInstalled(extension_id.clone())); if let Some(events) = ExtensionEvents::try_global(cx) && let Some(manifest) = this.extension_manifest_for_id(&extension_id) @@ -779,6 +886,9 @@ impl ExtensionStore { this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx) }); } + + // Run extension-specific migrations + copilot_migration::migrate_copilot_credentials_if_needed(&extension_id, cx); }) .ok(); } @@ -1217,6 +1327,11 @@ impl ExtensionStore { for command_name in extension.manifest.slash_commands.keys() { self.proxy.unregister_slash_command(command_name.clone()); } + for provider_id in extension.manifest.language_model_providers.keys() { + let full_provider_id: Arc = format!("{}:{}", extension_id, provider_id).into(); + self.proxy + .unregister_language_model_provider(full_provider_id, cx); + } } self.wasm_extensions @@ -1355,7 +1470,11 @@ impl ExtensionStore { }) .await; - let mut wasm_extensions = Vec::new(); + let mut wasm_extensions: Vec<( + Arc, + WasmExtension, + Vec, + )> = Vec::new(); for extension in extension_entries { if extension.manifest.lib.kind.is_none() { continue; @@ -1373,7 +1492,122 @@ impl ExtensionStore { match wasm_extension { Ok(wasm_extension) => { - wasm_extensions.push((extension.manifest.clone(), wasm_extension)) + // Query for LLM providers if the manifest declares any + let mut llm_providers_with_models = Vec::new(); + if 
!extension.manifest.language_model_providers.is_empty() { + let providers_result = wasm_extension + .call(|ext, store| { + async move { ext.call_llm_providers(store).await }.boxed() + }) + .await; + + if let Ok(Ok(providers)) = providers_result { + for provider_info in providers { + let models_result = wasm_extension + .call({ + let provider_id = provider_info.id.clone(); + |ext, store| { + async move { + ext.call_llm_provider_models(store, &provider_id) + .await + } + .boxed() + } + }) + .await; + + let models: Vec = match models_result { + Ok(Ok(Ok(models))) => models, + Ok(Ok(Err(e))) => { + log::error!( + "Failed to get models for LLM provider {} in extension {}: {}", + provider_info.id, + extension.manifest.id, + e + ); + Vec::new() + } + Ok(Err(e)) => { + log::error!( + "Wasm error calling llm_provider_models for {} in extension {}: {:?}", + provider_info.id, + extension.manifest.id, + e + ); + Vec::new() + } + Err(e) => { + log::error!( + "Extension call failed for llm_provider_models {} in extension {}: {:?}", + provider_info.id, + extension.manifest.id, + e + ); + Vec::new() + } + }; + + // Query initial authentication state + let is_authenticated = wasm_extension + .call({ + let provider_id = provider_info.id.clone(); + |ext, store| { + async move { + ext.call_llm_provider_is_authenticated( + store, + &provider_id, + ) + .await + } + .boxed() + } + }) + .await + .unwrap_or(Ok(false)) + .unwrap_or(false); + + // Resolve icon path if provided + let icon_path = provider_info.icon.as_ref().map(|icon| { + let icon_file_path = extension_path.join(icon); + // Canonicalize to resolve symlinks (dev extensions are symlinked) + let absolute_icon_path = icon_file_path + .canonicalize() + .unwrap_or(icon_file_path) + .to_string_lossy() + .to_string(); + SharedString::from(absolute_icon_path) + }); + + let provider_id_arc: Arc = + provider_info.id.as_str().into(); + let auth_config = extension + .manifest + .language_model_providers + .get(&provider_id_arc) + 
.and_then(|entry| entry.auth.clone()); + + llm_providers_with_models.push(LlmProviderWithModels { + provider_info, + models, + is_authenticated, + icon_path, + auth_config, + }); + } + } else { + log::error!( + "Failed to get LLM providers from extension {}: {:?}", + extension.manifest.id, + providers_result + ); + } + } + + wasm_extensions.push(( + extension.manifest.clone(), + wasm_extension, + llm_providers_with_models, + )) } Err(e) => { log::error!( @@ -1392,7 +1626,7 @@ impl ExtensionStore { this.update(cx, |this, cx| { this.reload_complete_senders.clear(); - for (manifest, wasm_extension) in &wasm_extensions { + for (manifest, wasm_extension, llm_providers_with_models) in &wasm_extensions { let extension = Arc::new(wasm_extension.clone()); for (language_server_id, language_server_config) in &manifest.language_servers { @@ -1446,9 +1680,41 @@ impl ExtensionStore { this.proxy .register_debug_locator(extension.clone(), debug_adapter.clone()); } + + // Register LLM providers + for llm_provider in llm_providers_with_models { + let provider_id: Arc = + format!("{}:{}", manifest.id, llm_provider.provider_info.id).into(); + let wasm_ext = extension.as_ref().clone(); + let pinfo = llm_provider.provider_info.clone(); + let mods = llm_provider.models.clone(); + let auth = llm_provider.is_authenticated; + let icon = llm_provider.icon_path.clone(); + let auth_config = llm_provider.auth_config.clone(); + + this.proxy.register_language_model_provider( + provider_id.clone(), + Box::new(move |cx: &mut App| { + let provider = Arc::new(ExtensionLanguageModelProvider::new( + wasm_ext, pinfo, mods, auth, icon, auth_config, cx, + )); + language_model::LanguageModelRegistry::global(cx).update( + cx, + |registry, cx| { + registry.register_provider(provider, cx); + }, + ); + }), + cx, + ); + } } - this.wasm_extensions.extend(wasm_extensions); + let wasm_extensions_without_llm: Vec<_> = wasm_extensions + .into_iter() + .map(|(manifest, ext, _)| (manifest, ext)) + .collect(); + 
this.wasm_extensions.extend(wasm_extensions_without_llm); this.proxy.set_extensions_loaded(); this.proxy.reload_current_theme(cx); this.proxy.reload_current_icon_theme(cx); diff --git a/crates/extension_host/src/extension_settings.rs b/crates/extension_host/src/extension_settings.rs index 736dd6b87ae53a..36777cb1727c95 100644 --- a/crates/extension_host/src/extension_settings.rs +++ b/crates/extension_host/src/extension_settings.rs @@ -1,4 +1,4 @@ -use collections::HashMap; +use collections::{HashMap, HashSet}; use extension::{ DownloadFileCapability, ExtensionCapability, NpmInstallPackageCapability, ProcessExecCapability, }; @@ -16,6 +16,13 @@ pub struct ExtensionSettings { pub auto_install_extensions: HashMap, bool>, pub auto_update_extensions: HashMap, bool>, pub granted_capabilities: Vec, + /// The extension language model providers that are allowed to read API keys + /// from environment variables. Each entry is a provider ID in the format + /// "extension_id:provider_id". + pub allowed_env_var_providers: HashSet>, + /// Tracks which legacy LLM providers have been migrated. + /// This prevents the migration from running multiple times and overriding user preferences. 
+ pub migrated_llm_providers: HashSet>, } impl ExtensionSettings { @@ -60,6 +67,20 @@ impl Settings for ExtensionSettings { } }) .collect(), + allowed_env_var_providers: content + .extension + .allowed_env_var_providers + .clone() + .unwrap_or_default() + .into_iter() + .collect(), + migrated_llm_providers: content + .extension + .migrated_llm_providers + .clone() + .unwrap_or_default() + .into_iter() + .collect(), } } } diff --git a/crates/extension_host/src/extension_store_test.rs b/crates/extension_host/src/extension_store_test.rs index 85a3a720ce8c62..b3275ff52ff7fe 100644 --- a/crates/extension_host/src/extension_store_test.rs +++ b/crates/extension_host/src/extension_store_test.rs @@ -165,6 +165,7 @@ async fn test_extension_store(cx: &mut TestAppContext) { capabilities: Vec::new(), debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: BTreeMap::default(), }), dev: false, }, @@ -196,6 +197,7 @@ async fn test_extension_store(cx: &mut TestAppContext) { capabilities: Vec::new(), debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: BTreeMap::default(), }), dev: false, }, @@ -376,6 +378,7 @@ async fn test_extension_store(cx: &mut TestAppContext) { capabilities: Vec::new(), debug_adapters: Default::default(), debug_locators: Default::default(), + language_model_providers: BTreeMap::default(), }), dev: false, }, diff --git a/crates/extension_host/src/wasm_host.rs b/crates/extension_host/src/wasm_host.rs index cecaf2039bc6dc..5194cafec2601d 100644 --- a/crates/extension_host/src/wasm_host.rs +++ b/crates/extension_host/src/wasm_host.rs @@ -1,9 +1,11 @@ +pub mod llm_provider; pub mod wit; use crate::capability_granter::CapabilityGranter; use crate::{ExtensionManifest, ExtensionSettings}; use anyhow::{Context as _, Result, anyhow, bail}; use async_trait::async_trait; + use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest}; use extension::{ CodeLabel, Command, 
Completion, ContextServerConfiguration, DebugAdapterBinary, @@ -64,7 +66,7 @@ pub struct WasmHost { #[derive(Clone, Debug)] pub struct WasmExtension { - tx: UnboundedSender, + tx: Arc>, pub manifest: Arc, pub work_dir: Arc, #[allow(unused)] @@ -74,7 +76,10 @@ pub struct WasmExtension { impl Drop for WasmExtension { fn drop(&mut self) { - self.tx.close_channel(); + // Only close the channel when this is the last clone holding the sender + if Arc::strong_count(&self.tx) == 1 { + self.tx.close_channel(); + } } } @@ -671,7 +676,7 @@ impl WasmHost { Ok(WasmExtension { manifest, work_dir, - tx, + tx: Arc::new(tx), zed_api_version, _task: task, }) diff --git a/crates/extension_host/src/wasm_host/llm_provider.rs b/crates/extension_host/src/wasm_host/llm_provider.rs new file mode 100644 index 00000000000000..acec25b8258c16 --- /dev/null +++ b/crates/extension_host/src/wasm_host/llm_provider.rs @@ -0,0 +1,1425 @@ +use crate::ExtensionSettings; +use crate::wasm_host::WasmExtension; + +use crate::wasm_host::wit::{ + LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole, + LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent, + LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent, + LlmToolUse, +}; +use anyhow::{Result, anyhow}; +use credentials_provider::CredentialsProvider; +use editor::Editor; +use extension::{LanguageModelAuthConfig, OAuthConfig}; +use futures::future::BoxFuture; +use futures::stream::BoxStream; +use futures::{FutureExt, StreamExt}; +use gpui::Focusable; +use gpui::{ + AnyView, App, AppContext as _, AsyncApp, ClipboardItem, Context, Entity, EventEmitter, + MouseButton, Subscription, Task, TextStyleRefinement, UnderlineStyle, Window, px, +}; +use language_model::tool_schema::LanguageModelToolSchemaFormat; +use language_model::{ + AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, + LanguageModelCacheConfiguration, LanguageModelCompletionError, 
LanguageModelCompletionEvent, + LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, + LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, + LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, StopReason, TokenUsage, +}; +use markdown::{Markdown, MarkdownElement, MarkdownStyle}; +use settings::Settings; +use std::sync::Arc; +use theme::ThemeSettings; +use ui::{Label, LabelSize, prelude::*}; +use util::ResultExt as _; + +/// An extension-based language model provider. +pub struct ExtensionLanguageModelProvider { + pub extension: WasmExtension, + pub provider_info: LlmProviderInfo, + icon_path: Option, + auth_config: Option, + state: Entity, +} + +pub struct ExtensionLlmProviderState { + is_authenticated: bool, + available_models: Vec, + env_var_allowed: bool, + api_key_from_env: bool, +} + +impl EventEmitter<()> for ExtensionLlmProviderState {} + +impl ExtensionLanguageModelProvider { + pub fn new( + extension: WasmExtension, + provider_info: LlmProviderInfo, + models: Vec, + is_authenticated: bool, + icon_path: Option, + auth_config: Option, + cx: &mut App, + ) -> Self { + let provider_id_string = format!("{}:{}", extension.manifest.id, provider_info.id); + let env_var_allowed = ExtensionSettings::get_global(cx) + .allowed_env_var_providers + .contains(provider_id_string.as_str()); + + let (is_authenticated, api_key_from_env) = + if env_var_allowed && auth_config.as_ref().is_some_and(|c| c.env_var.is_some()) { + let env_var_name = auth_config.as_ref().unwrap().env_var.as_ref().unwrap(); + if let Ok(value) = std::env::var(env_var_name) { + if !value.is_empty() { + (true, true) + } else { + (is_authenticated, false) + } + } else { + (is_authenticated, false) + } + } else { + (is_authenticated, false) + }; + + let state = cx.new(|_| ExtensionLlmProviderState { + is_authenticated, + available_models: models, + env_var_allowed, + api_key_from_env, + }); + + Self { + extension, + provider_info, + 
icon_path, + auth_config, + state, + } + } + + fn provider_id_string(&self) -> String { + format!("{}:{}", self.extension.manifest.id, self.provider_info.id) + } + + /// The credential key used for storing the API key in the system keychain. + fn credential_key(&self) -> String { + format!("extension-llm-{}", self.provider_id_string()) + } +} + +impl LanguageModelProvider for ExtensionLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId::from(self.provider_id_string()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName::from(self.provider_info.name.clone()) + } + + fn icon(&self) -> ui::IconName { + ui::IconName::ZedAssistant + } + + fn icon_path(&self) -> Option { + self.icon_path.clone() + } + + fn default_model(&self, cx: &App) -> Option> { + let state = self.state.read(cx); + state + .available_models + .iter() + .find(|m| m.is_default) + .or_else(|| state.available_models.first()) + .map(|model_info| { + Arc::new(ExtensionLanguageModel { + extension: self.extension.clone(), + model_info: model_info.clone(), + provider_id: self.id(), + provider_name: self.name(), + provider_info: self.provider_info.clone(), + }) as Arc + }) + } + + fn default_fast_model(&self, cx: &App) -> Option> { + let state = self.state.read(cx); + state + .available_models + .iter() + .find(|m| m.is_default_fast) + .map(|model_info| { + Arc::new(ExtensionLanguageModel { + extension: self.extension.clone(), + model_info: model_info.clone(), + provider_id: self.id(), + provider_name: self.name(), + provider_info: self.provider_info.clone(), + }) as Arc + }) + } + + fn provided_models(&self, cx: &App) -> Vec> { + let state = self.state.read(cx); + state + .available_models + .iter() + .map(|model_info| { + Arc::new(ExtensionLanguageModel { + extension: self.extension.clone(), + model_info: model_info.clone(), + provider_id: self.id(), + provider_name: self.name(), + provider_info: self.provider_info.clone(), + }) as Arc + 
}) + .collect() + } + + fn is_authenticated(&self, cx: &App) -> bool { + self.state.read(cx).is_authenticated + } + + fn authenticate(&self, cx: &mut App) -> Task> { + let extension = self.extension.clone(); + let provider_id = self.provider_info.id.clone(); + let state = self.state.clone(); + + cx.spawn(async move |cx| { + let result = extension + .call(|extension, store| { + async move { + extension + .call_llm_provider_authenticate(store, &provider_id) + .await + } + .boxed() + }) + .await; + + match result { + Ok(Ok(Ok(()))) => { + cx.update(|cx| { + state.update(cx, |state, _| { + state.is_authenticated = true; + }); + })?; + Ok(()) + } + Ok(Ok(Err(e))) => Err(AuthenticateError::Other(anyhow!("{}", e))), + Ok(Err(e)) => Err(AuthenticateError::Other(e)), + Err(e) => Err(AuthenticateError::Other(e)), + } + }) + } + + fn configuration_view( + &self, + _target_agent: ConfigurationViewTargetAgent, + window: &mut Window, + cx: &mut App, + ) -> AnyView { + let credential_key = self.credential_key(); + let extension = self.extension.clone(); + let extension_provider_id = self.provider_info.id.clone(); + let full_provider_id = self.provider_id_string(); + let state = self.state.clone(); + let auth_config = self.auth_config.clone(); + + cx.new(|cx| { + ExtensionProviderConfigurationView::new( + credential_key, + extension, + extension_provider_id, + full_provider_id, + auth_config, + state, + window, + cx, + ) + }) + .into() + } + + fn reset_credentials(&self, cx: &mut App) -> Task> { + let extension = self.extension.clone(); + let provider_id = self.provider_info.id.clone(); + let state = self.state.clone(); + let credential_key = self.credential_key(); + + let credentials_provider = ::global(cx); + + cx.spawn(async move |cx| { + // Delete from system keychain + credentials_provider + .delete_credentials(&credential_key, cx) + .await + .log_err(); + + // Call extension's reset_credentials + let result = extension + .call(|extension, store| { + async move { + extension 
+ .call_llm_provider_reset_credentials(store, &provider_id) + .await + } + .boxed() + }) + .await; + + // Update state + cx.update(|cx| { + state.update(cx, |state, _| { + state.is_authenticated = false; + }); + })?; + + match result { + Ok(Ok(Ok(()))) => Ok(()), + Ok(Ok(Err(e))) => Err(anyhow!("{}", e)), + Ok(Err(e)) => Err(e), + Err(e) => Err(e), + } + }) + } +} + +impl LanguageModelProviderState for ExtensionLanguageModelProvider { + type ObservableEntity = ExtensionLlmProviderState; + + fn observable_entity(&self) -> Option> { + Some(self.state.clone()) + } + + fn subscribe( + &self, + cx: &mut Context, + callback: impl Fn(&mut T, &mut Context) + 'static, + ) -> Option { + Some(cx.subscribe(&self.state, move |this, _, _, cx| callback(this, cx))) + } +} + +/// Configuration view for extension-based LLM providers. +struct ExtensionProviderConfigurationView { + credential_key: String, + extension: WasmExtension, + extension_provider_id: String, + full_provider_id: String, + auth_config: Option, + state: Entity, + settings_markdown: Option>, + api_key_editor: Entity, + loading_settings: bool, + loading_credentials: bool, + oauth_in_progress: bool, + oauth_error: Option, + device_user_code: Option, + _subscriptions: Vec, +} + +impl ExtensionProviderConfigurationView { + fn new( + credential_key: String, + extension: WasmExtension, + extension_provider_id: String, + full_provider_id: String, + auth_config: Option, + state: Entity, + window: &mut Window, + cx: &mut Context, + ) -> Self { + // Subscribe to state changes + let state_subscription = cx.subscribe(&state, |_, _, _, cx| { + cx.notify(); + }); + + // Create API key editor + let api_key_editor = cx.new(|cx| { + let mut editor = Editor::single_line(window, cx); + editor.set_placeholder_text("Enter API key...", window, cx); + editor + }); + + let mut this = Self { + credential_key, + extension, + extension_provider_id, + full_provider_id, + auth_config, + state, + settings_markdown: None, + api_key_editor, + 
loading_settings: true, + loading_credentials: true, + oauth_in_progress: false, + oauth_error: None, + device_user_code: None, + _subscriptions: vec![state_subscription], + }; + + // Load settings text from extension + this.load_settings_text(cx); + + // Load existing credentials + this.load_credentials(cx); + + this + } + + fn load_settings_text(&mut self, cx: &mut Context) { + let extension = self.extension.clone(); + let provider_id = self.extension_provider_id.clone(); + + cx.spawn(async move |this, cx| { + let result = extension + .call({ + let provider_id = provider_id.clone(); + |ext, store| { + async move { + ext.call_llm_provider_settings_markdown(store, &provider_id) + .await + } + .boxed() + } + }) + .await; + + let settings_text = result.ok().and_then(|inner| inner.ok()).flatten(); + + this.update(cx, |this, cx| { + this.loading_settings = false; + if let Some(text) = settings_text { + let markdown = cx.new(|cx| Markdown::new(text.into(), None, None, cx)); + this.settings_markdown = Some(markdown); + } + cx.notify(); + }) + .log_err(); + }) + .detach(); + } + + fn load_credentials(&mut self, cx: &mut Context) { + let credential_key = self.credential_key.clone(); + let credentials_provider = ::global(cx); + let state = self.state.clone(); + + // Check if we should use env var (already set in state during provider construction) + let api_key_from_env = self.state.read(cx).api_key_from_env; + + cx.spawn(async move |this, cx| { + // If using env var, we're already authenticated + if api_key_from_env { + this.update(cx, |this, cx| { + this.loading_credentials = false; + cx.notify(); + }) + .log_err(); + return; + } + + let credentials = credentials_provider + .read_credentials(&credential_key, cx) + .await + .log_err() + .flatten(); + + let has_credentials = credentials.is_some(); + + // Update authentication state based on stored credentials + let _ = cx.update(|cx| { + state.update(cx, |state, cx| { + state.is_authenticated = has_credentials; + 
cx.notify(); + }); + }); + + this.update(cx, |this, cx| { + this.loading_credentials = false; + cx.notify(); + }) + .log_err(); + }) + .detach(); + } + + fn toggle_env_var_permission(&mut self, cx: &mut Context) { + let full_provider_id: Arc = self.full_provider_id.clone().into(); + let env_var_name = match &self.auth_config { + Some(config) => config.env_var.clone(), + None => return, + }; + + let state = self.state.clone(); + let currently_allowed = self.state.read(cx).env_var_allowed; + + // Update settings file + settings::update_settings_file(::global(cx), cx, move |settings, _| { + let providers = settings + .extension + .allowed_env_var_providers + .get_or_insert_with(Vec::new); + + if currently_allowed { + providers.retain(|id| id.as_ref() != full_provider_id.as_ref()); + } else { + if !providers + .iter() + .any(|id| id.as_ref() == full_provider_id.as_ref()) + { + providers.push(full_provider_id.clone()); + } + } + }); + + // Update local state + let new_allowed = !currently_allowed; + let new_from_env = if new_allowed { + if let Some(var_name) = &env_var_name { + if let Ok(value) = std::env::var(var_name) { + !value.is_empty() + } else { + false + } + } else { + false + } + } else { + false + }; + + state.update(cx, |state, cx| { + state.env_var_allowed = new_allowed; + state.api_key_from_env = new_from_env; + if new_from_env { + state.is_authenticated = true; + } + cx.notify(); + }); + + // If env var is being enabled, clear any stored keychain credentials + // so there's only one source of truth for the API key + if new_allowed { + let credential_key = self.credential_key.clone(); + let credentials_provider = ::global(cx); + cx.spawn(async move |_this, cx| { + credentials_provider + .delete_credentials(&credential_key, cx) + .await + .log_err(); + }) + .detach(); + } + + // If env var is being disabled, reload credentials from keychain + if !new_allowed { + self.reload_keychain_credentials(cx); + } + + cx.notify(); + } + + fn 
reload_keychain_credentials(&mut self, cx: &mut Context) { + let credential_key = self.credential_key.clone(); + let credentials_provider = ::global(cx); + let state = self.state.clone(); + + cx.spawn(async move |_this, cx| { + let credentials = credentials_provider + .read_credentials(&credential_key, cx) + .await + .log_err() + .flatten(); + + let has_credentials = credentials.is_some(); + + let _ = cx.update(|cx| { + state.update(cx, |state, cx| { + state.is_authenticated = has_credentials; + cx.notify(); + }); + }); + }) + .detach(); + } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + let api_key = self.api_key_editor.read(cx).text(cx); + if api_key.is_empty() { + return; + } + + // Clear the editor + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let credential_key = self.credential_key.clone(); + let credentials_provider = ::global(cx); + let state = self.state.clone(); + + cx.spawn(async move |_this, cx| { + // Store in system keychain + credentials_provider + .write_credentials(&credential_key, "Bearer", api_key.as_bytes(), cx) + .await + .log_err(); + + // Update state to authenticated + let _ = cx.update(|cx| { + state.update(cx, |state, cx| { + state.is_authenticated = true; + cx.notify(); + }); + }); + }) + .detach(); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { + // Clear the editor + self.api_key_editor + .update(cx, |editor, cx| editor.set_text("", window, cx)); + + let credential_key = self.credential_key.clone(); + let credentials_provider = ::global(cx); + let state = self.state.clone(); + + cx.spawn(async move |_this, cx| { + // Delete from system keychain + credentials_provider + .delete_credentials(&credential_key, cx) + .await + .log_err(); + + // Update state to unauthenticated + let _ = cx.update(|cx| { + state.update(cx, |state, cx| { + state.is_authenticated = false; + cx.notify(); + }); + }); + }) + .detach(); + } + + fn 
start_oauth_sign_in(&mut self, cx: &mut Context) { + if self.oauth_in_progress { + return; + } + + self.oauth_in_progress = true; + self.oauth_error = None; + self.device_user_code = None; + cx.notify(); + + let extension = self.extension.clone(); + let provider_id = self.extension_provider_id.clone(); + let state = self.state.clone(); + + cx.spawn(async move |this, cx| { + // Step 1: Start device flow - opens browser and returns user code + let start_result = extension + .call({ + let provider_id = provider_id.clone(); + |ext, store| { + async move { + ext.call_llm_provider_start_device_flow_sign_in(store, &provider_id) + .await + } + .boxed() + } + }) + .await; + + let user_code = match start_result { + Ok(Ok(Ok(code))) => code, + Ok(Ok(Err(e))) => { + log::error!("Device flow start failed: {}", e); + this.update(cx, |this, cx| { + this.oauth_in_progress = false; + this.oauth_error = Some(e); + cx.notify(); + }) + .log_err(); + return; + } + Ok(Err(e)) | Err(e) => { + log::error!("Device flow start error: {}", e); + this.update(cx, |this, cx| { + this.oauth_in_progress = false; + this.oauth_error = Some(e.to_string()); + cx.notify(); + }) + .log_err(); + return; + } + }; + + // Update UI to show the user code before polling + this.update(cx, |this, cx| { + this.device_user_code = Some(user_code); + cx.notify(); + }) + .log_err(); + + // Step 2: Poll for authentication completion + let poll_result = extension + .call({ + let provider_id = provider_id.clone(); + |ext, store| { + async move { + ext.call_llm_provider_poll_device_flow_sign_in(store, &provider_id) + .await + } + .boxed() + } + }) + .await; + + let error_message = match poll_result { + Ok(Ok(Ok(()))) => { + let _ = cx.update(|cx| { + state.update(cx, |state, cx| { + state.is_authenticated = true; + cx.notify(); + }); + }); + None + } + Ok(Ok(Err(e))) => { + log::error!("Device flow poll failed: {}", e); + Some(e) + } + Ok(Err(e)) | Err(e) => { + log::error!("Device flow poll error: {}", e); + 
Some(e.to_string()) + } + }; + + this.update(cx, |this, cx| { + this.oauth_in_progress = false; + this.oauth_error = error_message; + this.device_user_code = None; + cx.notify(); + }) + .log_err(); + }) + .detach(); + } + + fn is_authenticated(&self, cx: &Context) -> bool { + self.state.read(cx).is_authenticated + } + + fn has_oauth_config(&self) -> bool { + self.auth_config.as_ref().is_some_and(|c| c.oauth.is_some()) + } + + fn oauth_config(&self) -> Option<&OAuthConfig> { + self.auth_config.as_ref().and_then(|c| c.oauth.as_ref()) + } + + fn has_api_key_config(&self) -> bool { + // API key is available if there's a credential_label or no oauth-only config + self.auth_config + .as_ref() + .map(|c| c.credential_label.is_some() || c.oauth.is_none()) + .unwrap_or(true) + } +} + +impl gpui::Render for ExtensionProviderConfigurationView { + fn render(&mut self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { + let is_loading = self.loading_settings || self.loading_credentials; + let is_authenticated = self.is_authenticated(cx); + let env_var_allowed = self.state.read(cx).env_var_allowed; + let api_key_from_env = self.state.read(cx).api_key_from_env; + let has_oauth = self.has_oauth_config(); + let has_api_key = self.has_api_key_config(); + + if is_loading { + return v_flex() + .gap_2() + .child(Label::new("Loading...").color(Color::Muted)) + .into_any_element(); + } + + let mut content = v_flex().gap_4().size_full(); + + // Render settings markdown if available + if let Some(markdown) = &self.settings_markdown { + let style = settings_markdown_style(_window, cx); + content = content.child( + div() + .p_2() + .rounded_md() + .bg(cx.theme().colors().surface_background) + .child(MarkdownElement::new(markdown.clone(), style)), + ); + } + + // Render env var checkbox if the extension specifies an env var + if let Some(auth_config) = &self.auth_config { + if let Some(env_var_name) = &auth_config.env_var { + let env_var_name = env_var_name.clone(); + let 
checkbox_label = + format!("Read API key from {} environment variable", env_var_name); + + content = content.child( + h_flex() + .gap_2() + .child( + ui::Checkbox::new("env-var-permission", env_var_allowed.into()) + .on_click(cx.listener(|this, _, _window, cx| { + this.toggle_env_var_permission(cx); + })), + ) + .child(Label::new(checkbox_label).size(LabelSize::Small)), + ); + + // Show status if env var is allowed + if env_var_allowed { + if api_key_from_env { + content = content.child( + h_flex() + .gap_2() + .child( + ui::Icon::new(ui::IconName::Check) + .color(Color::Success) + .size(ui::IconSize::Small), + ) + .child( + Label::new(format!("API key loaded from {}", env_var_name)) + .color(Color::Success), + ), + ); + return content.into_any_element(); + } else { + content = content.child( + h_flex() + .gap_2() + .child( + ui::Icon::new(ui::IconName::Warning) + .color(Color::Warning) + .size(ui::IconSize::Small), + ) + .child( + Label::new(format!( + "{} is not set or empty. You can set it and restart Zed, or use another authentication method below.", + env_var_name + )) + .color(Color::Warning) + .size(LabelSize::Small), + ), + ); + } + } + } + } + + // If authenticated, show success state with sign out option + if is_authenticated && !api_key_from_env { + let reset_label = if has_oauth && !has_api_key { + "Sign Out" + } else { + "Reset Credentials" + }; + + let status_label = if has_oauth && !has_api_key { + "Signed in" + } else { + "Authenticated" + }; + + content = content.child( + v_flex() + .gap_2() + .child( + h_flex() + .gap_2() + .child( + ui::Icon::new(ui::IconName::Check) + .color(Color::Success) + .size(ui::IconSize::Small), + ) + .child(Label::new(status_label).color(Color::Success)), + ) + .child( + ui::Button::new("reset-credentials", reset_label) + .style(ui::ButtonStyle::Subtle) + .on_click(cx.listener(|this, _, window, cx| { + this.reset_api_key(window, cx); + })), + ), + ); + + return content.into_any_element(); + } + + // Not authenticated - 
show available auth options + if !api_key_from_env { + // Render OAuth sign-in button if configured + if has_oauth { + let oauth_config = self.oauth_config(); + let button_label = oauth_config + .and_then(|c| c.sign_in_button_label.clone()) + .unwrap_or_else(|| "Sign In".to_string()); + + let oauth_in_progress = self.oauth_in_progress; + + let oauth_error = self.oauth_error.clone(); + + content = content.child( + v_flex() + .gap_2() + .child( + ui::Button::new("oauth-sign-in", button_label) + .style(ui::ButtonStyle::Filled) + .disabled(oauth_in_progress) + .on_click(cx.listener(|this, _, _window, cx| { + this.start_oauth_sign_in(cx); + })), + ) + .when(oauth_in_progress, |this| { + let user_code = self.device_user_code.clone(); + this.child( + v_flex() + .gap_1() + .when_some(user_code, |this, code| { + let copied = cx + .read_from_clipboard() + .map(|item| item.text().as_ref() == Some(&code)) + .unwrap_or(false); + let code_for_click = code.clone(); + this.child( + h_flex() + .gap_1() + .child( + Label::new("Enter code:") + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child( + h_flex() + .gap_1() + .px_1() + .border_1() + .border_color(cx.theme().colors().border) + .rounded_sm() + .cursor_pointer() + .on_mouse_down( + MouseButton::Left, + move |_, window, cx| { + cx.write_to_clipboard( + ClipboardItem::new_string( + code_for_click.clone(), + ), + ); + window.refresh(); + }, + ) + .child( + Label::new(code) + .size(LabelSize::Small) + .color(Color::Accent), + ) + .child( + ui::Icon::new(if copied { + ui::IconName::Check + } else { + ui::IconName::Copy + }) + .size(ui::IconSize::Small) + .color(if copied { + Color::Success + } else { + Color::Muted + }), + ), + ), + ) + }) + .child( + Label::new("Waiting for authorization in browser...") + .size(LabelSize::Small) + .color(Color::Muted), + ), + ) + }) + .when_some(oauth_error, |this, error| { + this.child( + v_flex() + .gap_1() + .child( + h_flex() + .gap_2() + .child( + ui::Icon::new(ui::IconName::Warning) 
+ .color(Color::Error) + .size(ui::IconSize::Small), + ) + .child( + Label::new("Authentication failed") + .color(Color::Error) + .size(LabelSize::Small), + ), + ) + .child( + div().pl_6().child( + Label::new(error) + .color(Color::Error) + .size(LabelSize::Small), + ), + ), + ) + }), + ); + } + + // Render API key input if configured (and we have both options, show a separator) + if has_api_key { + if has_oauth { + content = content.child( + h_flex() + .gap_2() + .items_center() + .child(div().h_px().flex_1().bg(cx.theme().colors().border)) + .child(Label::new("or").size(LabelSize::Small).color(Color::Muted)) + .child(div().h_px().flex_1().bg(cx.theme().colors().border)), + ); + } + + let credential_label = self + .auth_config + .as_ref() + .and_then(|c| c.credential_label.clone()) + .unwrap_or_else(|| "API Key".to_string()); + + content = content.child( + v_flex() + .gap_2() + .on_action(cx.listener(Self::save_api_key)) + .child( + Label::new(credential_label) + .size(LabelSize::Small) + .color(Color::Muted), + ) + .child(self.api_key_editor.clone()) + .child( + Label::new("Enter your API key and press Enter to save") + .size(LabelSize::Small) + .color(Color::Muted), + ), + ); + } + } + + content.into_any_element() + } +} + +impl Focusable for ExtensionProviderConfigurationView { + fn focus_handle(&self, cx: &App) -> gpui::FocusHandle { + self.api_key_editor.focus_handle(cx) + } +} + +fn settings_markdown_style(window: &Window, cx: &App) -> MarkdownStyle { + let theme_settings = ThemeSettings::get_global(cx); + let colors = cx.theme().colors(); + let mut text_style = window.text_style(); + text_style.refine(&TextStyleRefinement { + font_family: Some(theme_settings.ui_font.family.clone()), + font_fallbacks: theme_settings.ui_font.fallbacks.clone(), + font_features: Some(theme_settings.ui_font.features.clone()), + color: Some(colors.text), + ..Default::default() + }); + + MarkdownStyle { + base_text_style: text_style, + selection_background_color: 
colors.element_selection_background, + inline_code: TextStyleRefinement { + background_color: Some(colors.editor_background), + ..Default::default() + }, + link: TextStyleRefinement { + color: Some(colors.text_accent), + underline: Some(UnderlineStyle { + color: Some(colors.text_accent.opacity(0.5)), + thickness: px(1.), + ..Default::default() + }), + ..Default::default() + }, + syntax: cx.theme().syntax().clone(), + ..Default::default() + } +} + +/// An extension-based language model. +pub struct ExtensionLanguageModel { + extension: WasmExtension, + model_info: LlmModelInfo, + provider_id: LanguageModelProviderId, + provider_name: LanguageModelProviderName, + provider_info: LlmProviderInfo, +} + +impl LanguageModel for ExtensionLanguageModel { + fn id(&self) -> LanguageModelId { + LanguageModelId::from(self.model_info.id.clone()) + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model_info.name.clone()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + self.provider_id.clone() + } + + fn provider_name(&self) -> LanguageModelProviderName { + self.provider_name.clone() + } + + fn telemetry_id(&self) -> String { + format!("extension-{}", self.model_info.id) + } + + fn supports_images(&self) -> bool { + self.model_info.capabilities.supports_images + } + + fn supports_tools(&self) -> bool { + self.model_info.capabilities.supports_tools + } + + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + match choice { + LanguageModelToolChoice::Auto => self.model_info.capabilities.supports_tool_choice_auto, + LanguageModelToolChoice::Any => self.model_info.capabilities.supports_tool_choice_any, + LanguageModelToolChoice::None => self.model_info.capabilities.supports_tool_choice_none, + } + } + + fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { + match self.model_info.capabilities.tool_input_format { + LlmToolInputFormat::JsonSchema => LanguageModelToolSchemaFormat::JsonSchema, + 
LlmToolInputFormat::Simplified => LanguageModelToolSchemaFormat::JsonSchema, + } + } + + fn max_token_count(&self) -> u64 { + self.model_info.max_token_count + } + + fn max_output_tokens(&self) -> Option { + self.model_info.max_output_tokens + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + let extension = self.extension.clone(); + let provider_id = self.provider_info.id.clone(); + let model_id = self.model_info.id.clone(); + + let wit_request = convert_request_to_wit(request); + + cx.background_spawn(async move { + extension + .call({ + let provider_id = provider_id.clone(); + let model_id = model_id.clone(); + let wit_request = wit_request.clone(); + |ext, store| { + async move { + let count = ext + .call_llm_count_tokens(store, &provider_id, &model_id, &wit_request) + .await? + .map_err(|e| anyhow!("{}", e))?; + Ok(count) + } + .boxed() + } + }) + .await? + }) + .boxed() + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + _cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream<'static, Result>, + LanguageModelCompletionError, + >, + > { + let extension = self.extension.clone(); + let provider_id = self.provider_info.id.clone(); + let model_id = self.model_info.id.clone(); + + let wit_request = convert_request_to_wit(request); + + async move { + // Start the stream + let stream_id_result = extension + .call({ + let provider_id = provider_id.clone(); + let model_id = model_id.clone(); + let wit_request = wit_request.clone(); + |ext, store| { + async move { + let id = ext + .call_llm_stream_completion_start( + store, + &provider_id, + &model_id, + &wit_request, + ) + .await? + .map_err(|e| anyhow!("{}", e))?; + Ok(id) + } + .boxed() + } + }) + .await; + + let stream_id = stream_id_result + .map_err(LanguageModelCompletionError::Other)? 
+ .map_err(LanguageModelCompletionError::Other)?; + + // Create a stream that polls for events + let stream = futures::stream::unfold( + (extension.clone(), stream_id, false), + move |(extension, stream_id, done)| async move { + if done { + return None; + } + + let result = extension + .call({ + let stream_id = stream_id.clone(); + |ext, store| { + async move { + let event = ext + .call_llm_stream_completion_next(store, &stream_id) + .await? + .map_err(|e| anyhow!("{}", e))?; + Ok(event) + } + .boxed() + } + }) + .await + .and_then(|inner| inner); + + match result { + Ok(Some(event)) => { + let converted = convert_completion_event(event); + let is_done = + matches!(&converted, Ok(LanguageModelCompletionEvent::Stop(_))); + Some((converted, (extension, stream_id, is_done))) + } + Ok(None) => { + // Stream complete, close it + let _ = extension + .call({ + let stream_id = stream_id.clone(); + |ext, store| { + async move { + ext.call_llm_stream_completion_close(store, &stream_id) + .await?; + Ok::<(), anyhow::Error>(()) + } + .boxed() + } + }) + .await; + None + } + Err(e) => Some(( + Err(LanguageModelCompletionError::Other(e)), + (extension, stream_id, true), + )), + } + }, + ); + + Ok(stream.boxed()) + } + .boxed() + } + + fn cache_configuration(&self) -> Option { + // Extensions can implement this via llm_cache_configuration + None + } +} + +fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest { + use language_model::{MessageContent, Role}; + + let messages: Vec = request + .messages + .into_iter() + .map(|msg| { + let role = match msg.role { + Role::User => LlmMessageRole::User, + Role::Assistant => LlmMessageRole::Assistant, + Role::System => LlmMessageRole::System, + }; + + let content: Vec = msg + .content + .into_iter() + .map(|c| match c { + MessageContent::Text(text) => LlmMessageContent::Text(text), + MessageContent::Image(image) => LlmMessageContent::Image(LlmImageData { + source: image.source.to_string(), + width: 
Some(image.size.width.0 as u32), + height: Some(image.size.height.0 as u32), + }), + MessageContent::ToolUse(tool_use) => LlmMessageContent::ToolUse(LlmToolUse { + id: tool_use.id.to_string(), + name: tool_use.name.to_string(), + input: serde_json::to_string(&tool_use.input).unwrap_or_default(), + thought_signature: tool_use.thought_signature, + }), + MessageContent::ToolResult(tool_result) => { + let content = match tool_result.content { + language_model::LanguageModelToolResultContent::Text(text) => { + LlmToolResultContent::Text(text.to_string()) + } + language_model::LanguageModelToolResultContent::Image(image) => { + LlmToolResultContent::Image(LlmImageData { + source: image.source.to_string(), + width: Some(image.size.width.0 as u32), + height: Some(image.size.height.0 as u32), + }) + } + }; + LlmMessageContent::ToolResult(LlmToolResult { + tool_use_id: tool_result.tool_use_id.to_string(), + tool_name: tool_result.tool_name.to_string(), + is_error: tool_result.is_error, + content, + }) + } + MessageContent::Thinking { text, signature } => { + LlmMessageContent::Thinking(LlmThinkingContent { text, signature }) + } + MessageContent::RedactedThinking(data) => { + LlmMessageContent::RedactedThinking(data) + } + }) + .collect(); + + LlmRequestMessage { + role, + content, + cache: msg.cache, + } + }) + .collect(); + + let tools: Vec = request + .tools + .into_iter() + .map(|tool| LlmToolDefinition { + name: tool.name, + description: tool.description, + input_schema: serde_json::to_string(&tool.input_schema).unwrap_or_default(), + }) + .collect(); + + let tool_choice = request.tool_choice.map(|tc| match tc { + LanguageModelToolChoice::Auto => LlmToolChoice::Auto, + LanguageModelToolChoice::Any => LlmToolChoice::Any, + LanguageModelToolChoice::None => LlmToolChoice::None, + }); + + LlmCompletionRequest { + messages, + tools, + tool_choice, + stop_sequences: request.stop, + temperature: request.temperature, + thinking_allowed: false, + max_tokens: None, + } +} + +fn 
convert_completion_event( + event: LlmCompletionEvent, +) -> Result { + match event { + LlmCompletionEvent::Started => Ok(LanguageModelCompletionEvent::StartMessage { + message_id: String::new(), + }), + LlmCompletionEvent::Text(text) => Ok(LanguageModelCompletionEvent::Text(text)), + LlmCompletionEvent::Thinking(thinking) => Ok(LanguageModelCompletionEvent::Thinking { + text: thinking.text, + signature: thinking.signature, + }), + LlmCompletionEvent::RedactedThinking(data) => { + Ok(LanguageModelCompletionEvent::RedactedThinking { data }) + } + LlmCompletionEvent::ToolUse(tool_use) => { + let raw_input = tool_use.input.clone(); + let input = serde_json::from_str(&tool_use.input).unwrap_or(serde_json::Value::Null); + Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(tool_use.id), + name: tool_use.name.into(), + raw_input, + input, + is_input_complete: true, + thought_signature: tool_use.thought_signature, + }, + )) + } + LlmCompletionEvent::ToolUseJsonParseError(error) => { + Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: LanguageModelToolUseId::from(error.id), + tool_name: error.tool_name.into(), + raw_input: error.raw_input.into(), + json_parse_error: error.error, + }) + } + LlmCompletionEvent::Stop(reason) => { + let stop_reason = match reason { + LlmStopReason::EndTurn => StopReason::EndTurn, + LlmStopReason::MaxTokens => StopReason::MaxTokens, + LlmStopReason::ToolUse => StopReason::ToolUse, + LlmStopReason::Refusal => StopReason::Refusal, + }; + Ok(LanguageModelCompletionEvent::Stop(stop_reason)) + } + LlmCompletionEvent::Usage(usage) => { + Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0), + cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0), + })) + } + LlmCompletionEvent::ReasoningDetails(json) => { + 
Ok(LanguageModelCompletionEvent::ReasoningDetails( + serde_json::from_str(&json).unwrap_or(serde_json::Value::Null), + )) + } + } +} diff --git a/crates/extension_host/src/wasm_host/wit.rs b/crates/extension_host/src/wasm_host/wit.rs index 5058c63365021a..c2b22d2ad02278 100644 --- a/crates/extension_host/src/wasm_host/wit.rs +++ b/crates/extension_host/src/wasm_host/wit.rs @@ -16,7 +16,7 @@ use lsp::LanguageServerName; use release_channel::ReleaseChannel; use task::{DebugScenario, SpawnInTerminal, TaskTemplate, ZedDebugConfig}; -use crate::wasm_host::wit::since_v0_6_0::dap::StartDebuggingRequestArgumentsRequest; +use crate::wasm_host::wit::since_v0_8_0::dap::StartDebuggingRequestArgumentsRequest; use super::{WasmState, wasm_engine}; use anyhow::{Context as _, Result, anyhow}; @@ -33,6 +33,19 @@ pub use latest::CodeLabelSpanLiteral; pub use latest::{ CodeLabel, CodeLabelSpan, Command, DebugAdapterBinary, ExtensionProject, Range, SlashCommand, zed::extension::context_server::ContextServerConfiguration, + zed::extension::llm_provider::{ + CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent, + CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType, + ImageData as LlmImageData, MessageContent as LlmMessageContent, + MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities, + ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo, + RequestMessage as LlmRequestMessage, StopReason as LlmStopReason, + ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage, + ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition, + ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult, + ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse, + ToolUseJsonParseError as LlmToolUseJsonParseError, + }, zed::extension::lsp::{ Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind, }, @@ -1007,6 +1020,20 @@ impl Extension { resource: 
Resource>, ) -> Result> { match self { + Extension::V0_8_0(ext) => { + let dap_binary = ext + .call_get_dap_binary( + store, + &adapter_name, + &task.try_into()?, + user_installed_path.as_ref().and_then(|p| p.to_str()), + resource, + ) + .await? + .map_err(|e| anyhow!("{e:?}"))?; + + Ok(Ok(dap_binary)) + } Extension::V0_6_0(ext) => { let dap_binary = ext .call_get_dap_binary( @@ -1032,6 +1059,16 @@ impl Extension { config: serde_json::Value, ) -> Result> { match self { + Extension::V0_8_0(ext) => { + let config = + serde_json::to_string(&config).context("Adapter config is not a valid JSON")?; + let result = ext + .call_dap_request_kind(store, &adapter_name, &config) + .await? + .map_err(|e| anyhow!("{e:?}"))?; + + Ok(Ok(result)) + } Extension::V0_6_0(ext) => { let config = serde_json::to_string(&config).context("Adapter config is not a valid JSON")?; @@ -1052,6 +1089,15 @@ impl Extension { config: ZedDebugConfig, ) -> Result> { match self { + Extension::V0_8_0(ext) => { + let config = config.into(); + let result = ext + .call_dap_config_to_scenario(store, &config) + .await? + .map_err(|e| anyhow!("{e:?}"))?; + + Ok(Ok(result.try_into()?)) + } Extension::V0_6_0(ext) => { let config = config.into(); let dap_binary = ext @@ -1074,6 +1120,20 @@ impl Extension { debug_adapter_name: String, ) -> Result> { match self { + Extension::V0_8_0(ext) => { + let build_config_template = build_config_template.into(); + let result = ext + .call_dap_locator_create_scenario( + store, + &locator_name, + &build_config_template, + &resolved_label, + &debug_adapter_name, + ) + .await?; + + Ok(result.map(TryInto::try_into).transpose()?) 
+ } Extension::V0_6_0(ext) => { let build_config_template = build_config_template.into(); let dap_binary = ext @@ -1099,6 +1159,15 @@ impl Extension { resolved_build_task: SpawnInTerminal, ) -> Result> { match self { + Extension::V0_8_0(ext) => { + let build_config_template = resolved_build_task.try_into()?; + let dap_request = ext + .call_run_dap_locator(store, &locator_name, &build_config_template) + .await? + .map_err(|e| anyhow!("{e:?}"))?; + + Ok(Ok(dap_request.into())) + } Extension::V0_6_0(ext) => { let build_config_template = resolved_build_task.try_into()?; let dap_request = ext @@ -1111,6 +1180,185 @@ impl Extension { _ => anyhow::bail!("`dap_locator_create_scenario` not available prior to v0.6.0"), } } + + pub async fn call_llm_providers( + &self, + store: &mut Store, + ) -> Result> { + match self { + Extension::V0_8_0(ext) => ext.call_llm_providers(store).await, + _ => Ok(Vec::new()), + } + } + + pub async fn call_llm_provider_models( + &self, + store: &mut Store, + provider_id: &str, + ) -> Result, String>> { + match self { + Extension::V0_8_0(ext) => ext.call_llm_provider_models(store, provider_id).await, + _ => anyhow::bail!("`llm_provider_models` not available prior to v0.8.0"), + } + } + + pub async fn call_llm_provider_settings_markdown( + &self, + store: &mut Store, + provider_id: &str, + ) -> Result> { + match self { + Extension::V0_8_0(ext) => { + ext.call_llm_provider_settings_markdown(store, provider_id) + .await + } + _ => Ok(None), + } + } + + pub async fn call_llm_provider_is_authenticated( + &self, + store: &mut Store, + provider_id: &str, + ) -> Result { + match self { + Extension::V0_8_0(ext) => { + ext.call_llm_provider_is_authenticated(store, provider_id) + .await + } + _ => Ok(false), + } + } + + pub async fn call_llm_provider_authenticate( + &self, + store: &mut Store, + provider_id: &str, + ) -> Result> { + match self { + Extension::V0_8_0(ext) => ext.call_llm_provider_authenticate(store, provider_id).await, + _ => 
anyhow::bail!("`llm_provider_authenticate` not available prior to v0.8.0"), + } + } + + pub async fn call_llm_provider_start_device_flow_sign_in( + &self, + store: &mut Store, + provider_id: &str, + ) -> Result> { + match self { + Extension::V0_8_0(ext) => { + ext.call_llm_provider_start_device_flow_sign_in(store, provider_id) + .await + } + _ => { + anyhow::bail!( + "`llm_provider_start_device_flow_sign_in` not available prior to v0.8.0" + ) + } + } + } + + pub async fn call_llm_provider_poll_device_flow_sign_in( + &self, + store: &mut Store, + provider_id: &str, + ) -> Result> { + match self { + Extension::V0_8_0(ext) => { + ext.call_llm_provider_poll_device_flow_sign_in(store, provider_id) + .await + } + _ => { + anyhow::bail!( + "`llm_provider_poll_device_flow_sign_in` not available prior to v0.8.0" + ) + } + } + } + + pub async fn call_llm_provider_reset_credentials( + &self, + store: &mut Store, + provider_id: &str, + ) -> Result> { + match self { + Extension::V0_8_0(ext) => { + ext.call_llm_provider_reset_credentials(store, provider_id) + .await + } + _ => anyhow::bail!("`llm_provider_reset_credentials` not available prior to v0.8.0"), + } + } + + pub async fn call_llm_count_tokens( + &self, + store: &mut Store, + provider_id: &str, + model_id: &str, + request: &latest::llm_provider::CompletionRequest, + ) -> Result> { + match self { + Extension::V0_8_0(ext) => { + ext.call_llm_count_tokens(store, provider_id, model_id, request) + .await + } + _ => anyhow::bail!("`llm_count_tokens` not available prior to v0.8.0"), + } + } + + pub async fn call_llm_stream_completion_start( + &self, + store: &mut Store, + provider_id: &str, + model_id: &str, + request: &latest::llm_provider::CompletionRequest, + ) -> Result> { + match self { + Extension::V0_8_0(ext) => { + ext.call_llm_stream_completion_start(store, provider_id, model_id, request) + .await + } + _ => anyhow::bail!("`llm_stream_completion_start` not available prior to v0.8.0"), + } + } + + pub async fn 
call_llm_stream_completion_next( + &self, + store: &mut Store, + stream_id: &str, + ) -> Result, String>> { + match self { + Extension::V0_8_0(ext) => ext.call_llm_stream_completion_next(store, stream_id).await, + _ => anyhow::bail!("`llm_stream_completion_next` not available prior to v0.8.0"), + } + } + + pub async fn call_llm_stream_completion_close( + &self, + store: &mut Store, + stream_id: &str, + ) -> Result<()> { + match self { + Extension::V0_8_0(ext) => ext.call_llm_stream_completion_close(store, stream_id).await, + _ => anyhow::bail!("`llm_stream_completion_close` not available prior to v0.8.0"), + } + } + + pub async fn call_llm_cache_configuration( + &self, + store: &mut Store, + provider_id: &str, + model_id: &str, + ) -> Result> { + match self { + Extension::V0_8_0(ext) => { + ext.call_llm_cache_configuration(store, provider_id, model_id) + .await + } + _ => Ok(None), + } + } } trait ToWasmtimeResult { diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs index 8595c278b95a43..45bec57ee376aa 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_6_0.rs @@ -32,8 +32,6 @@ wasmtime::component::bindgen!({ }, }); -pub use self::zed::extension::*; - mod settings { #![allow(dead_code)] include!(concat!(env!("OUT_DIR"), "/since_v0.6.0/settings.rs")); diff --git a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs index a2776f9f3b5b05..a7fc76ffb6d489 100644 --- a/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs +++ b/crates/extension_host/src/wasm_host/wit/since_v0_8_0.rs @@ -1,11 +1,11 @@ -use crate::wasm_host::wit::since_v0_6_0::{ +use crate::wasm_host::wit::since_v0_8_0::{ dap::{ AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, LaunchRequest, StartDebuggingRequestArguments, TcpArguments, TcpArgumentsTemplate, }, + 
lsp::{CompletionKind, CompletionLabelDetails, InsertTextFormat, SymbolKind}, slash_command::SlashCommandOutputSection, }; -use crate::wasm_host::wit::{CompletionKind, CompletionLabelDetails, InsertTextFormat, SymbolKind}; use crate::wasm_host::{WasmState, wit::ToWasmtimeResult}; use ::http_client::{AsyncBody, HttpRequestExt}; use ::settings::{Settings, WorktreeId}; @@ -13,6 +13,7 @@ use anyhow::{Context as _, Result, bail}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; use async_trait::async_trait; +use credentials_provider::CredentialsProvider; use extension::{ ExtensionLanguageServerProxy, KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate, }; @@ -22,12 +23,14 @@ use gpui::{BackgroundExecutor, SharedString}; use language::{BinaryStatus, LanguageName, language_settings::AllLanguageSettings}; use project::project_settings::ProjectSettings; use semver::Version; +use smol::net::TcpListener; use std::{ env, net::Ipv4Addr, path::{Path, PathBuf}, str::FromStr, sync::{Arc, OnceLock}, + time::Duration, }; use task::{SpawnInTerminal, ZedDebugConfig}; use url::Url; @@ -1107,3 +1110,361 @@ impl ExtensionImports for WasmState { .to_wasmtime_result() } } + +impl llm_provider::Host for WasmState { + async fn request_credential( + &mut self, + _provider_id: String, + _credential_type: llm_provider::CredentialType, + _label: String, + _placeholder: String, + ) -> wasmtime::Result> { + // For now, credential requests return false (not provided) + // Extensions should use get_env_var to check for env vars first, + // then store_credential/get_credential for manual storage + // Full UI credential prompting will be added in a future phase + Ok(Ok(false)) + } + + async fn get_credential(&mut self, provider_id: String) -> wasmtime::Result> { + let extension_id = self.manifest.id.clone(); + + // Check if this provider has an env var configured and if the user has allowed it + let env_var_name = self + .manifest + .language_model_providers + 
.get(&Arc::::from(provider_id.as_str())) + .and_then(|entry| entry.auth.as_ref()) + .and_then(|auth| auth.env_var.clone()); + + if let Some(env_var_name) = env_var_name { + let full_provider_id: Arc = format!("{}:{}", extension_id, provider_id).into(); + // Read settings dynamically to get current allowed_env_var_providers + let is_allowed = self + .on_main_thread({ + let full_provider_id = full_provider_id.clone(); + move |cx| { + async move { + cx.update(|cx| { + crate::extension_settings::ExtensionSettings::get_global(cx) + .allowed_env_var_providers + .contains(&full_provider_id) + }) + } + .boxed_local() + } + }) + .await + .unwrap_or(false); + + if is_allowed { + if let Ok(value) = env::var(&env_var_name) { + if !value.is_empty() { + return Ok(Some(value)); + } + } + } + } + + // Fall back to credential store + let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id); + + self.on_main_thread(move |cx| { + async move { + let credentials_provider = cx.update(|cx| ::global(cx))?; + let result = credentials_provider + .read_credentials(&credential_key, cx) + .await + .ok() + .flatten(); + Ok(result.map(|(_, password)| String::from_utf8_lossy(&password).to_string())) + } + .boxed_local() + }) + .await + } + + async fn store_credential( + &mut self, + provider_id: String, + value: String, + ) -> wasmtime::Result> { + let extension_id = self.manifest.id.clone(); + let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id); + + self.on_main_thread(move |cx| { + async move { + let credentials_provider = cx.update(|cx| ::global(cx))?; + credentials_provider + .write_credentials(&credential_key, "api_key", value.as_bytes(), cx) + .await + .map_err(|e| anyhow::anyhow!("{}", e)) + } + .boxed_local() + }) + .await + .to_wasmtime_result() + } + + async fn delete_credential( + &mut self, + provider_id: String, + ) -> wasmtime::Result> { + let extension_id = self.manifest.id.clone(); + let credential_key = 
format!("extension-llm-{}:{}", extension_id, provider_id); + + self.on_main_thread(move |cx| { + async move { + let credentials_provider = cx.update(|cx| ::global(cx))?; + credentials_provider + .delete_credentials(&credential_key, cx) + .await + .map_err(|e| anyhow::anyhow!("{}", e)) + } + .boxed_local() + }) + .await + .to_wasmtime_result() + } + + async fn get_env_var(&mut self, name: String) -> wasmtime::Result> { + let extension_id = self.manifest.id.clone(); + + // Find which provider (if any) declares this env var in its auth config + let mut allowed_provider_id: Option> = None; + for (provider_id, provider_entry) in &self.manifest.language_model_providers { + if let Some(auth_config) = &provider_entry.auth { + if auth_config.env_var.as_deref() == Some(&name) { + allowed_provider_id = Some(provider_id.clone()); + break; + } + } + } + + // If no provider declares this env var, deny access + let Some(provider_id) = allowed_provider_id else { + log::warn!( + "Extension {} attempted to read env var {} which is not declared in any provider auth config", + extension_id, + name + ); + return Ok(None); + }; + + // Check if the user has allowed this provider to read env vars + // Read settings dynamically to get current allowed_env_var_providers + let full_provider_id: Arc = format!("{}:{}", extension_id, provider_id).into(); + let is_allowed = self + .on_main_thread({ + let full_provider_id = full_provider_id.clone(); + move |cx| { + async move { + cx.update(|cx| { + crate::extension_settings::ExtensionSettings::get_global(cx) + .allowed_env_var_providers + .contains(&full_provider_id) + }) + } + .boxed_local() + } + }) + .await + .unwrap_or(false); + + if !is_allowed { + log::debug!( + "Extension {} provider {} is not allowed to read env var {}", + extension_id, + provider_id, + name + ); + return Ok(None); + } + + Ok(env::var(&name).ok()) + } + + async fn oauth_start_web_auth( + &mut self, + config: llm_provider::OauthWebAuthConfig, + ) -> wasmtime::Result> { + 
let auth_url = config.auth_url; + let callback_path = config.callback_path; + let timeout_secs = config.timeout_secs.unwrap_or(300); + + self.on_main_thread(move |cx| { + async move { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .map_err(|e| anyhow::anyhow!("Failed to bind localhost server: {}", e))?; + let port = listener + .local_addr() + .map_err(|e| anyhow::anyhow!("Failed to get local address: {}", e))? + .port(); + + let auth_url_with_port = auth_url.replace("{port}", &port.to_string()); + cx.update(|cx| { + cx.open_url(&auth_url_with_port); + })?; + + let accept_future = async { + let (mut stream, _) = listener + .accept() + .await + .map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?; + + let mut request_line = String::new(); + { + let mut reader = smol::io::BufReader::new(&mut stream); + smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line) + .await + .map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?; + } + + let callback_url = if let Some(path_start) = request_line.find(' ') { + if let Some(path_end) = request_line[path_start + 1..].find(' ') { + let path = &request_line[path_start + 1..path_start + 1 + path_end]; + if path.starts_with(&callback_path) || path.starts_with(&format!("/{}", callback_path.trim_start_matches('/'))) { + format!("http://localhost:{}{}", port, path) + } else { + return Err(anyhow::anyhow!( + "Unexpected callback path: {}", + path + )); + } + } else { + return Err(anyhow::anyhow!("Malformed HTTP request")); + } + } else { + return Err(anyhow::anyhow!("Malformed HTTP request")); + }; + + let response = "HTTP/1.1 200 OK\r\n\ + Content-Type: text/html\r\n\ + Connection: close\r\n\ + \r\n\ + \ + Authentication Complete\ + \ +
\ +

Authentication Complete

\ +

You can close this window and return to Zed.

\ +
"; + + smol::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes()) + .await + .ok(); + smol::io::AsyncWriteExt::flush(&mut stream).await.ok(); + + Ok(callback_url) + }; + + let timeout_duration = Duration::from_secs(timeout_secs as u64); + let callback_url = smol::future::or( + accept_future, + async { + smol::Timer::after(timeout_duration).await; + Err(anyhow::anyhow!( + "OAuth callback timed out after {} seconds", + timeout_secs + )) + }, + ) + .await?; + + Ok(llm_provider::OauthWebAuthResult { + callback_url, + port: port as u32, + }) + } + .boxed_local() + }) + .await + .to_wasmtime_result() + } + + async fn send_oauth_http_request( + &mut self, + request: llm_provider::OauthHttpRequest, + ) -> wasmtime::Result> { + let http_client = self.host.http_client.clone(); + + self.on_main_thread(move |_cx| { + async move { + let method = match request.method.to_uppercase().as_str() { + "GET" => ::http_client::Method::GET, + "POST" => ::http_client::Method::POST, + "PUT" => ::http_client::Method::PUT, + "DELETE" => ::http_client::Method::DELETE, + "PATCH" => ::http_client::Method::PATCH, + _ => { + return Err(anyhow::anyhow!( + "Unsupported HTTP method: {}", + request.method + )); + } + }; + + let mut builder = ::http_client::Request::builder() + .method(method) + .uri(&request.url); + + for (key, value) in &request.headers { + builder = builder.header(key.as_str(), value.as_str()); + } + + let body = if request.body.is_empty() { + AsyncBody::empty() + } else { + AsyncBody::from(request.body.into_bytes()) + }; + + let http_request = builder + .body(body) + .map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?; + + let mut response = http_client + .send(http_request) + .await + .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?; + + let status = response.status().as_u16(); + let headers: Vec<(String, String)> = response + .headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + + let mut 
body_bytes = Vec::new(); + futures::AsyncReadExt::read_to_end(response.body_mut(), &mut body_bytes) + .await + .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?; + + let body = String::from_utf8_lossy(&body_bytes).to_string(); + + Ok(llm_provider::OauthHttpResponse { + status, + headers, + body, + }) + } + .boxed_local() + }) + .await + .to_wasmtime_result() + } + + async fn oauth_open_browser(&mut self, url: String) -> wasmtime::Result> { + self.on_main_thread(move |cx| { + async move { + cx.update(|cx| { + cx.open_url(&url); + })?; + Ok(()) + } + .boxed_local() + }) + .await + .to_wasmtime_result() + } +} diff --git a/crates/icons/src/icons.rs b/crates/icons/src/icons.rs index d28e2c1030c3c2..ce4ba4d3fa2aa3 100644 --- a/crates/icons/src/icons.rs +++ b/crates/icons/src/icons.rs @@ -9,7 +9,6 @@ use strum::{EnumIter, EnumString, IntoStaticStr}; #[strum(serialize_all = "snake_case")] pub enum IconName { Ai, - AiAnthropic, AiBedrock, AiClaude, AiDeepSeek, diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index c9b6391136da1a..60fe4226ac3b45 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -746,6 +746,11 @@ pub trait LanguageModelProvider: 'static { fn icon(&self) -> IconName { IconName::ZedAssistant } + /// Returns the path to an external SVG icon for this provider, if any. + /// When present, this takes precedence over `icon()`. 
+ fn icon_path(&self) -> Option { + None + } fn default_model(&self, cx: &App) -> Option>; fn default_fast_model(&self, cx: &App) -> Option>; fn provided_models(&self, cx: &App) -> Vec>; diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index 6c5704312d94e2..decb32c5aa4000 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -28,6 +28,7 @@ convert_case.workspace = true copilot.workspace = true credentials_provider.workspace = true deepseek = { workspace = true, features = ["schemars"] } +extension.workspace = true fs.workspace = true futures.workspace = true google_ai = { workspace = true, features = ["schemars"] } diff --git a/crates/language_models/src/extension.rs b/crates/language_models/src/extension.rs new file mode 100644 index 00000000000000..9af6f41bd59955 --- /dev/null +++ b/crates/language_models/src/extension.rs @@ -0,0 +1,34 @@ +use extension::{ExtensionLanguageModelProviderProxy, LanguageModelProviderRegistration}; +use gpui::{App, Entity}; +use language_model::{LanguageModelProviderId, LanguageModelRegistry}; +use std::sync::Arc; + +/// Proxy implementation that registers extension-based language model providers +/// with the LanguageModelRegistry. 
+pub struct ExtensionLanguageModelProxy { + registry: Entity, +} + +impl ExtensionLanguageModelProxy { + pub fn new(registry: Entity) -> Self { + Self { registry } + } +} + +impl ExtensionLanguageModelProviderProxy for ExtensionLanguageModelProxy { + fn register_language_model_provider( + &self, + provider_id: Arc, + register_fn: LanguageModelProviderRegistration, + cx: &mut App, + ) { + let _ = provider_id; + register_fn(cx); + } + + fn unregister_language_model_provider(&self, provider_id: Arc, cx: &mut App) { + self.registry.update(cx, |registry, cx| { + registry.unregister_provider(LanguageModelProviderId::from(provider_id), cx); + }); + } +} diff --git a/crates/language_models/src/language_models.rs b/crates/language_models/src/language_models.rs index d771dba3733540..8b8ca1e2912e2a 100644 --- a/crates/language_models/src/language_models.rs +++ b/crates/language_models/src/language_models.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use ::extension::ExtensionHostProxy; use ::settings::{Settings, SettingsStore}; use client::{Client, UserStore}; use collections::HashSet; @@ -8,11 +9,11 @@ use language_model::{LanguageModelProviderId, LanguageModelRegistry}; use provider::deepseek::DeepSeekLanguageModelProvider; mod api_key; +mod extension; pub mod provider; mod settings; pub mod ui; -use crate::provider::anthropic::AnthropicLanguageModelProvider; use crate::provider::bedrock::BedrockLanguageModelProvider; use crate::provider::cloud::CloudLanguageModelProvider; use crate::provider::copilot_chat::CopilotChatLanguageModelProvider; @@ -33,6 +34,12 @@ pub fn init(user_store: Entity, client: Arc, cx: &mut App) { register_language_model_providers(registry, user_store, client.clone(), cx); }); + // Register the extension language model provider proxy + let extension_proxy = ExtensionHostProxy::default_global(cx); + extension_proxy.register_language_model_provider_proxy( + extension::ExtensionLanguageModelProxy::new(registry.clone()), + ); + let mut 
openai_compatible_providers = AllLanguageModelSettings::get_global(cx) .openai_compatible .keys() @@ -111,13 +118,6 @@ fn register_language_model_providers( )), cx, ); - registry.register_provider( - Arc::new(AnthropicLanguageModelProvider::new( - client.http_client(), - cx, - )), - cx, - ); registry.register_provider( Arc::new(OpenAiLanguageModelProvider::new(client.http_client(), cx)), cx, diff --git a/crates/language_models/src/provider.rs b/crates/language_models/src/provider.rs index d780195c66ec0d..e585fc06f6b523 100644 --- a/crates/language_models/src/provider.rs +++ b/crates/language_models/src/provider.rs @@ -1,4 +1,3 @@ -pub mod anthropic; pub mod bedrock; pub mod cloud; pub mod copilot_chat; diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs deleted file mode 100644 index 1affe38a08d22e..00000000000000 --- a/crates/language_models/src/provider/anthropic.rs +++ /dev/null @@ -1,1045 +0,0 @@ -use anthropic::{ - ANTHROPIC_API_URL, AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent, - ToolResultContent, ToolResultPart, Usage, -}; -use anyhow::{Result, anyhow}; -use collections::{BTreeMap, HashMap}; -use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture, stream::BoxStream}; -use gpui::{AnyView, App, AsyncApp, Context, Entity, Task}; -use http_client::HttpClient; -use language_model::{ - AuthenticateError, ConfigurationViewTargetAgent, LanguageModel, - LanguageModelCacheConfiguration, LanguageModelCompletionError, LanguageModelId, - LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, - LanguageModelToolResultContent, MessageContent, RateLimiter, Role, -}; -use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason}; -use settings::{Settings, SettingsStore}; -use std::pin::Pin; -use std::str::FromStr; -use std::sync::{Arc, 
LazyLock}; -use strum::IntoEnumIterator; -use ui::{List, prelude::*}; -use ui_input::InputField; -use util::ResultExt; -use zed_env_vars::{EnvVar, env_var}; - -use crate::api_key::ApiKeyState; -use crate::ui::{ConfiguredApiCard, InstructionListItem}; - -pub use settings::AnthropicAvailableModel as AvailableModel; - -const PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID; -const PROVIDER_NAME: LanguageModelProviderName = language_model::ANTHROPIC_PROVIDER_NAME; - -#[derive(Default, Clone, Debug, PartialEq)] -pub struct AnthropicSettings { - pub api_url: String, - /// Extend Zed's list of Anthropic models. - pub available_models: Vec, -} - -pub struct AnthropicLanguageModelProvider { - http_client: Arc, - state: Entity, -} - -const API_KEY_ENV_VAR_NAME: &str = "ANTHROPIC_API_KEY"; -static API_KEY_ENV_VAR: LazyLock = env_var!(API_KEY_ENV_VAR_NAME); - -pub struct State { - api_key_state: ApiKeyState, -} - -impl State { - fn is_authenticated(&self) -> bool { - self.api_key_state.has_key() - } - - fn set_api_key(&mut self, api_key: Option, cx: &mut Context) -> Task> { - let api_url = AnthropicLanguageModelProvider::api_url(cx); - self.api_key_state - .store(api_url, api_key, |this| &mut this.api_key_state, cx) - } - - fn authenticate(&mut self, cx: &mut Context) -> Task> { - let api_url = AnthropicLanguageModelProvider::api_url(cx); - self.api_key_state.load_if_needed( - api_url, - &API_KEY_ENV_VAR, - |this| &mut this.api_key_state, - cx, - ) - } -} - -impl AnthropicLanguageModelProvider { - pub fn new(http_client: Arc, cx: &mut App) -> Self { - let state = cx.new(|cx| { - cx.observe_global::(|this: &mut State, cx| { - let api_url = Self::api_url(cx); - this.api_key_state.handle_url_change( - api_url, - &API_KEY_ENV_VAR, - |this| &mut this.api_key_state, - cx, - ); - cx.notify(); - }) - .detach(); - State { - api_key_state: ApiKeyState::new(Self::api_url(cx)), - } - }); - - Self { http_client, state } - } - - fn create_language_model(&self, 
model: anthropic::Model) -> Arc { - Arc::new(AnthropicModel { - id: LanguageModelId::from(model.id().to_string()), - model, - state: self.state.clone(), - http_client: self.http_client.clone(), - request_limiter: RateLimiter::new(4), - }) - } - - fn settings(cx: &App) -> &AnthropicSettings { - &crate::AllLanguageModelSettings::get_global(cx).anthropic - } - - fn api_url(cx: &App) -> SharedString { - let api_url = &Self::settings(cx).api_url; - if api_url.is_empty() { - ANTHROPIC_API_URL.into() - } else { - SharedString::new(api_url.as_str()) - } - } -} - -impl LanguageModelProviderState for AnthropicLanguageModelProvider { - type ObservableEntity = State; - - fn observable_entity(&self) -> Option> { - Some(self.state.clone()) - } -} - -impl LanguageModelProvider for AnthropicLanguageModelProvider { - fn id(&self) -> LanguageModelProviderId { - PROVIDER_ID - } - - fn name(&self) -> LanguageModelProviderName { - PROVIDER_NAME - } - - fn icon(&self) -> IconName { - IconName::AiAnthropic - } - - fn default_model(&self, _cx: &App) -> Option> { - Some(self.create_language_model(anthropic::Model::default())) - } - - fn default_fast_model(&self, _cx: &App) -> Option> { - Some(self.create_language_model(anthropic::Model::default_fast())) - } - - fn recommended_models(&self, _cx: &App) -> Vec> { - [ - anthropic::Model::ClaudeSonnet4_5, - anthropic::Model::ClaudeSonnet4_5Thinking, - ] - .into_iter() - .map(|model| self.create_language_model(model)) - .collect() - } - - fn provided_models(&self, cx: &App) -> Vec> { - let mut models = BTreeMap::default(); - - // Add base models from anthropic::Model::iter() - for model in anthropic::Model::iter() { - if !matches!(model, anthropic::Model::Custom { .. 
}) { - models.insert(model.id().to_string(), model); - } - } - - // Override with available models from settings - for model in &AnthropicLanguageModelProvider::settings(cx).available_models { - models.insert( - model.name.clone(), - anthropic::Model::Custom { - name: model.name.clone(), - display_name: model.display_name.clone(), - max_tokens: model.max_tokens, - tool_override: model.tool_override.clone(), - cache_configuration: model.cache_configuration.as_ref().map(|config| { - anthropic::AnthropicModelCacheConfiguration { - max_cache_anchors: config.max_cache_anchors, - should_speculate: config.should_speculate, - min_total_token: config.min_total_token, - } - }), - max_output_tokens: model.max_output_tokens, - default_temperature: model.default_temperature, - extra_beta_headers: model.extra_beta_headers.clone(), - mode: model.mode.unwrap_or_default().into(), - }, - ); - } - - models - .into_values() - .map(|model| self.create_language_model(model)) - .collect() - } - - fn is_authenticated(&self, cx: &App) -> bool { - self.state.read(cx).is_authenticated() - } - - fn authenticate(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.authenticate(cx)) - } - - fn configuration_view( - &self, - target_agent: ConfigurationViewTargetAgent, - window: &mut Window, - cx: &mut App, - ) -> AnyView { - cx.new(|cx| ConfigurationView::new(self.state.clone(), target_agent, window, cx)) - .into() - } - - fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state - .update(cx, |state, cx| state.set_api_key(None, cx)) - } -} - -pub struct AnthropicModel { - id: LanguageModelId, - model: anthropic::Model, - state: Entity, - http_client: Arc, - request_limiter: RateLimiter, -} - -pub fn count_anthropic_tokens( - request: LanguageModelRequest, - cx: &App, -) -> BoxFuture<'static, Result> { - cx.background_spawn(async move { - let messages = request.messages; - let mut tokens_from_images = 0; - let mut string_messages = Vec::with_capacity(messages.len()); 
- - for message in messages { - use language_model::MessageContent; - - let mut string_contents = String::new(); - - for content in message.content { - match content { - MessageContent::Text(text) => { - string_contents.push_str(&text); - } - MessageContent::Thinking { .. } => { - // Thinking blocks are not included in the input token count. - } - MessageContent::RedactedThinking(_) => { - // Thinking blocks are not included in the input token count. - } - MessageContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - MessageContent::ToolUse(_tool_use) => { - // TODO: Estimate token usage from tool uses. - } - MessageContent::ToolResult(tool_result) => match &tool_result.content { - LanguageModelToolResultContent::Text(text) => { - string_contents.push_str(text); - } - LanguageModelToolResultContent::Image(image) => { - tokens_from_images += image.estimate_tokens(); - } - }, - } - } - - if !string_contents.is_empty() { - string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some(string_contents), - name: None, - function_call: None, - }); - } - } - - // Tiktoken doesn't yet support these models, so we manually use the - // same tokenizer as GPT-4. 
- tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) - .map(|tokens| (tokens + tokens_from_images) as u64) - }) - .boxed() -} - -impl AnthropicModel { - fn stream_completion( - &self, - request: anthropic::Request, - cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream<'static, Result>, - LanguageModelCompletionError, - >, - > { - let http_client = self.http_client.clone(); - - let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { - let api_url = AnthropicLanguageModelProvider::api_url(cx); - (state.api_key_state.key(&api_url), api_url) - }) else { - return future::ready(Err(anyhow!("App state dropped").into())).boxed(); - }; - - let beta_headers = self.model.beta_headers(); - - async move { - let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { - provider: PROVIDER_NAME, - }); - }; - let request = anthropic::stream_completion( - http_client.as_ref(), - &api_url, - &api_key, - request, - beta_headers, - ); - request.await.map_err(Into::into) - } - .boxed() - } -} - -impl LanguageModel for AnthropicModel { - fn id(&self) -> LanguageModelId { - self.id.clone() - } - - fn name(&self) -> LanguageModelName { - LanguageModelName::from(self.model.display_name().to_string()) - } - - fn provider_id(&self) -> LanguageModelProviderId { - PROVIDER_ID - } - - fn provider_name(&self) -> LanguageModelProviderName { - PROVIDER_NAME - } - - fn supports_tools(&self) -> bool { - true - } - - fn supports_images(&self) -> bool { - true - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - match choice { - LanguageModelToolChoice::Auto - | LanguageModelToolChoice::Any - | LanguageModelToolChoice::None => true, - } - } - - fn telemetry_id(&self) -> String { - format!("anthropic/{}", self.model.id()) - } - - fn api_key(&self, cx: &App) -> Option { - self.state.read_with(cx, |state, cx| { - let api_url = AnthropicLanguageModelProvider::api_url(cx); - 
state.api_key_state.key(&api_url).map(|key| key.to_string()) - }) - } - - fn max_token_count(&self) -> u64 { - self.model.max_token_count() - } - - fn max_output_tokens(&self) -> Option { - Some(self.model.max_output_tokens()) - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - count_anthropic_tokens(request, cx) - } - - fn stream_completion( - &self, - request: LanguageModelRequest, - cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream<'static, Result>, - LanguageModelCompletionError, - >, - > { - let request = into_anthropic( - request, - self.model.request_id().into(), - self.model.default_temperature(), - self.model.max_output_tokens(), - self.model.mode(), - ); - let request = self.stream_completion(request, cx); - let future = self.request_limiter.stream(async move { - let response = request.await?; - Ok(AnthropicEventMapper::new().map_stream(response)) - }); - async move { Ok(future.await?.boxed()) }.boxed() - } - - fn cache_configuration(&self) -> Option { - self.model - .cache_configuration() - .map(|config| LanguageModelCacheConfiguration { - max_cache_anchors: config.max_cache_anchors, - should_speculate: config.should_speculate, - min_total_token: config.min_total_token, - }) - } -} - -pub fn into_anthropic( - request: LanguageModelRequest, - model: String, - default_temperature: f32, - max_output_tokens: u64, - mode: AnthropicModelMode, -) -> anthropic::Request { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages { - if message.contents_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - let mut anthropic_message_content: Vec = message - .content - .into_iter() - .filter_map(|content| match content { - MessageContent::Text(text) => { - let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) { - text.trim_end().to_string() - } else { - text - }; - if 
!text.is_empty() { - Some(anthropic::RequestContent::Text { - text, - cache_control: None, - }) - } else { - None - } - } - MessageContent::Thinking { - text: thinking, - signature, - } => { - if !thinking.is_empty() { - Some(anthropic::RequestContent::Thinking { - thinking, - signature: signature.unwrap_or_default(), - cache_control: None, - }) - } else { - None - } - } - MessageContent::RedactedThinking(data) => { - if !data.is_empty() { - Some(anthropic::RequestContent::RedactedThinking { data }) - } else { - None - } - } - MessageContent::Image(image) => Some(anthropic::RequestContent::Image { - source: anthropic::ImageSource { - source_type: "base64".to_string(), - media_type: "image/png".to_string(), - data: image.source.to_string(), - }, - cache_control: None, - }), - MessageContent::ToolUse(tool_use) => { - Some(anthropic::RequestContent::ToolUse { - id: tool_use.id.to_string(), - name: tool_use.name.to_string(), - input: tool_use.input, - cache_control: None, - }) - } - MessageContent::ToolResult(tool_result) => { - Some(anthropic::RequestContent::ToolResult { - tool_use_id: tool_result.tool_use_id.to_string(), - is_error: tool_result.is_error, - content: match tool_result.content { - LanguageModelToolResultContent::Text(text) => { - ToolResultContent::Plain(text.to_string()) - } - LanguageModelToolResultContent::Image(image) => { - ToolResultContent::Multipart(vec![ToolResultPart::Image { - source: anthropic::ImageSource { - source_type: "base64".to_string(), - media_type: "image/png".to_string(), - data: image.source.to_string(), - }, - }]) - } - }, - cache_control: None, - }) - } - }) - .collect(); - let anthropic_role = match message.role { - Role::User => anthropic::Role::User, - Role::Assistant => anthropic::Role::Assistant, - Role::System => unreachable!("System role should never occur here"), - }; - if let Some(last_message) = new_messages.last_mut() - && last_message.role == anthropic_role - { - 
last_message.content.extend(anthropic_message_content); - continue; - } - - // Mark the last segment of the message as cached - if message.cache { - let cache_control_value = Some(anthropic::CacheControl { - cache_type: anthropic::CacheControlType::Ephemeral, - }); - for message_content in anthropic_message_content.iter_mut().rev() { - match message_content { - anthropic::RequestContent::RedactedThinking { .. } => { - // Caching is not possible, fallback to next message - } - anthropic::RequestContent::Text { cache_control, .. } - | anthropic::RequestContent::Thinking { cache_control, .. } - | anthropic::RequestContent::Image { cache_control, .. } - | anthropic::RequestContent::ToolUse { cache_control, .. } - | anthropic::RequestContent::ToolResult { cache_control, .. } => { - *cache_control = cache_control_value; - break; - } - } - } - } - - new_messages.push(anthropic::Message { - role: anthropic_role, - content: anthropic_message_content, - }); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.string_contents()); - } - } - } - - anthropic::Request { - model, - messages: new_messages, - max_tokens: max_output_tokens, - system: if system_message.is_empty() { - None - } else { - Some(anthropic::StringOrContents::String(system_message)) - }, - thinking: if request.thinking_allowed - && let AnthropicModelMode::Thinking { budget_tokens } = mode - { - Some(anthropic::Thinking::Enabled { budget_tokens }) - } else { - None - }, - tools: request - .tools - .into_iter() - .map(|tool| anthropic::Tool { - name: tool.name, - description: tool.description, - input_schema: tool.input_schema, - }) - .collect(), - tool_choice: request.tool_choice.map(|choice| match choice { - LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto, - LanguageModelToolChoice::Any => anthropic::ToolChoice::Any, - LanguageModelToolChoice::None => anthropic::ToolChoice::None, - }), - metadata: None, - 
stop_sequences: Vec::new(), - temperature: request.temperature.or(Some(default_temperature)), - top_k: None, - top_p: None, - } -} - -pub struct AnthropicEventMapper { - tool_uses_by_index: HashMap, - usage: Usage, - stop_reason: StopReason, -} - -impl AnthropicEventMapper { - pub fn new() -> Self { - Self { - tool_uses_by_index: HashMap::default(), - usage: Usage::default(), - stop_reason: StopReason::EndTurn, - } - } - - pub fn map_stream( - mut self, - events: Pin>>>, - ) -> impl Stream> - { - events.flat_map(move |event| { - futures::stream::iter(match event { - Ok(event) => self.map_event(event), - Err(error) => vec![Err(error.into())], - }) - }) - } - - pub fn map_event( - &mut self, - event: Event, - ) -> Vec> { - match event { - Event::ContentBlockStart { - index, - content_block, - } => match content_block { - ResponseContent::Text { text } => { - vec![Ok(LanguageModelCompletionEvent::Text(text))] - } - ResponseContent::Thinking { thinking } => { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: thinking, - signature: None, - })] - } - ResponseContent::RedactedThinking { data } => { - vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })] - } - ResponseContent::ToolUse { id, name, .. 
} => { - self.tool_uses_by_index.insert( - index, - RawToolUse { - id, - name, - input_json: String::new(), - }, - ); - Vec::new() - } - }, - Event::ContentBlockDelta { index, delta } => match delta { - ContentDelta::TextDelta { text } => { - vec![Ok(LanguageModelCompletionEvent::Text(text))] - } - ContentDelta::ThinkingDelta { thinking } => { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: thinking, - signature: None, - })] - } - ContentDelta::SignatureDelta { signature } => { - vec![Ok(LanguageModelCompletionEvent::Thinking { - text: "".to_string(), - signature: Some(signature), - })] - } - ContentDelta::InputJsonDelta { partial_json } => { - if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) { - tool_use.input_json.push_str(&partial_json); - - // Try to convert invalid (incomplete) JSON into - // valid JSON that serde can accept, e.g. by closing - // unclosed delimiters. This way, we can update the - // UI with whatever has been streamed back so far. - if let Ok(input) = serde_json::Value::from_str( - &partial_json_fixer::fix_json(&tool_use.input_json), - ) { - return vec![Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_use.id.clone().into(), - name: tool_use.name.clone().into(), - is_input_complete: false, - raw_input: tool_use.input_json.clone(), - input, - thought_signature: None, - }, - ))]; - } - } - vec![] - } - }, - Event::ContentBlockStop { index } => { - if let Some(tool_use) = self.tool_uses_by_index.remove(&index) { - let input_json = tool_use.input_json.trim(); - let input_value = if input_json.is_empty() { - Ok(serde_json::Value::Object(serde_json::Map::default())) - } else { - serde_json::Value::from_str(input_json) - }; - let event_result = match input_value { - Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( - LanguageModelToolUse { - id: tool_use.id.into(), - name: tool_use.name.into(), - is_input_complete: true, - input, - raw_input: tool_use.input_json.clone(), - thought_signature: 
None, - }, - )), - Err(json_parse_err) => { - Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { - id: tool_use.id.into(), - tool_name: tool_use.name.into(), - raw_input: input_json.into(), - json_parse_error: json_parse_err.to_string(), - }) - } - }; - - vec![event_result] - } else { - Vec::new() - } - } - Event::MessageStart { message } => { - update_usage(&mut self.usage, &message.usage); - vec![ - Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( - &self.usage, - ))), - Ok(LanguageModelCompletionEvent::StartMessage { - message_id: message.id, - }), - ] - } - Event::MessageDelta { delta, usage } => { - update_usage(&mut self.usage, &usage); - if let Some(stop_reason) = delta.stop_reason.as_deref() { - self.stop_reason = match stop_reason { - "end_turn" => StopReason::EndTurn, - "max_tokens" => StopReason::MaxTokens, - "tool_use" => StopReason::ToolUse, - "refusal" => StopReason::Refusal, - _ => { - log::error!("Unexpected anthropic stop_reason: {stop_reason}"); - StopReason::EndTurn - } - }; - } - vec![Ok(LanguageModelCompletionEvent::UsageUpdate( - convert_usage(&self.usage), - ))] - } - Event::MessageStop => { - vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))] - } - Event::Error { error } => { - vec![Err(error.into())] - } - _ => Vec::new(), - } - } -} - -struct RawToolUse { - id: String, - name: String, - input_json: String, -} - -/// Updates usage data by preferring counts from `new`. 
-fn update_usage(usage: &mut Usage, new: &Usage) { - if let Some(input_tokens) = new.input_tokens { - usage.input_tokens = Some(input_tokens); - } - if let Some(output_tokens) = new.output_tokens { - usage.output_tokens = Some(output_tokens); - } - if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens { - usage.cache_creation_input_tokens = Some(cache_creation_input_tokens); - } - if let Some(cache_read_input_tokens) = new.cache_read_input_tokens { - usage.cache_read_input_tokens = Some(cache_read_input_tokens); - } -} - -fn convert_usage(usage: &Usage) -> language_model::TokenUsage { - language_model::TokenUsage { - input_tokens: usage.input_tokens.unwrap_or(0), - output_tokens: usage.output_tokens.unwrap_or(0), - cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0), - cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0), - } -} - -struct ConfigurationView { - api_key_editor: Entity, - state: Entity, - load_credentials_task: Option>, - target_agent: ConfigurationViewTargetAgent, -} - -impl ConfigurationView { - const PLACEHOLDER_TEXT: &'static str = "sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; - - fn new( - state: Entity, - target_agent: ConfigurationViewTargetAgent, - window: &mut Window, - cx: &mut Context, - ) -> Self { - cx.observe(&state, |_, _, cx| { - cx.notify(); - }) - .detach(); - - let load_credentials_task = Some(cx.spawn({ - let state = state.clone(); - async move |this, cx| { - if let Some(task) = state - .update(cx, |state, cx| state.authenticate(cx)) - .log_err() - { - // We don't log an error, because "not signed in" is also an error. 
- let _ = task.await; - } - this.update(cx, |this, cx| { - this.load_credentials_task = None; - cx.notify(); - }) - .log_err(); - } - })); - - Self { - api_key_editor: cx.new(|cx| InputField::new(window, cx, Self::PLACEHOLDER_TEXT)), - state, - load_credentials_task, - target_agent, - } - } - - fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { - let api_key = self.api_key_editor.read(cx).text(cx); - if api_key.is_empty() { - return; - } - - // url changes can cause the editor to be displayed again - self.api_key_editor - .update(cx, |editor, cx| editor.set_text("", window, cx)); - - let state = self.state.clone(); - cx.spawn_in(window, async move |_, cx| { - state - .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))? - .await - }) - .detach_and_log_err(cx); - } - - fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { - self.api_key_editor - .update(cx, |editor, cx| editor.set_text("", window, cx)); - - let state = self.state.clone(); - cx.spawn_in(window, async move |_, cx| { - state - .update(cx, |state, cx| state.set_api_key(None, cx))? 
- .await - }) - .detach_and_log_err(cx); - } - - fn should_render_editor(&self, cx: &mut Context) -> bool { - !self.state.read(cx).is_authenticated() - } -} - -impl Render for ConfigurationView { - fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { - let env_var_set = self.state.read(cx).api_key_state.is_from_env_var(); - let configured_card_label = if env_var_set { - format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable") - } else { - let api_url = AnthropicLanguageModelProvider::api_url(cx); - if api_url == ANTHROPIC_API_URL { - "API key configured".to_string() - } else { - format!("API key configured for {}", api_url) - } - }; - - if self.load_credentials_task.is_some() { - div() - .child(Label::new("Loading credentials...")) - .into_any_element() - } else if self.should_render_editor(cx) { - v_flex() - .size_full() - .on_action(cx.listener(Self::save_api_key)) - .child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent { - ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic".into(), - ConfigurationViewTargetAgent::Other(agent) => agent.clone(), - }))) - .child( - List::new() - .child( - InstructionListItem::new( - "Create one by visiting", - Some("Anthropic's settings"), - Some("https://console.anthropic.com/settings/keys") - ) - ) - .child( - InstructionListItem::text_only("Paste your API key below and hit enter to start using the agent") - ) - ) - .child(self.api_key_editor.clone()) - .child( - Label::new( - format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."), - ) - .size(LabelSize::Small) - .color(Color::Muted), - ) - .into_any_element() - } else { - ConfiguredApiCard::new(configured_card_label) - .disabled(env_var_set) - .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))) - .when(env_var_set, |this| { - this.tooltip_label(format!( - "To reset your API key, unset the 
{API_KEY_ENV_VAR_NAME} environment variable." - )) - }) - .into_any_element() - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use anthropic::AnthropicModelMode; - use language_model::{LanguageModelRequestMessage, MessageContent}; - - #[test] - fn test_cache_control_only_on_last_segment() { - let request = LanguageModelRequest { - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![ - MessageContent::Text("Some prompt".to_string()), - MessageContent::Image(language_model::LanguageModelImage::empty()), - MessageContent::Image(language_model::LanguageModelImage::empty()), - MessageContent::Image(language_model::LanguageModelImage::empty()), - MessageContent::Image(language_model::LanguageModelImage::empty()), - ], - cache: true, - reasoning_details: None, - }], - thread_id: None, - prompt_id: None, - intent: None, - mode: None, - stop: vec![], - temperature: None, - tools: vec![], - tool_choice: None, - thinking_allowed: true, - }; - - let anthropic_request = into_anthropic( - request, - "claude-3-5-sonnet".to_string(), - 0.7, - 4096, - AnthropicModelMode::Default, - ); - - assert_eq!(anthropic_request.messages.len(), 1); - - let message = &anthropic_request.messages[0]; - assert_eq!(message.content.len(), 5); - - assert!(matches!( - message.content[0], - anthropic::RequestContent::Text { - cache_control: None, - .. - } - )); - for i in 1..3 { - assert!(matches!( - message.content[i], - anthropic::RequestContent::Image { - cache_control: None, - .. - } - )); - } - - assert!(matches!( - message.content[4], - anthropic::RequestContent::Image { - cache_control: Some(anthropic::CacheControl { - cache_type: anthropic::CacheControlType::Ephemeral, - }), - .. 
- } - )); - } -} diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index a19a427dbacb32..3730db2b42654b 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,5 +1,8 @@ use ai_onboarding::YoungAccountBanner; -use anthropic::AnthropicModelMode; +use anthropic::{ + AnthropicModelMode, ContentDelta, Event, ResponseContent, ToolResultContent, ToolResultPart, + Usage, +}; use anyhow::{Context as _, Result, anyhow}; use chrono::{DateTime, Utc}; use client::{Client, ModelRequestUsage, UserStore, zed_urls}; @@ -23,8 +26,9 @@ use language_model::{ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, - LanguageModelToolSchemaFormat, LlmApiToken, ModelRequestLimitReachedError, - PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, + LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelToolUse, + LanguageModelToolUseId, LlmApiToken, MessageContent, ModelRequestLimitReachedError, + PaymentRequiredError, RateLimiter, RefreshLlmTokenListener, Role, StopReason, }; use release_channel::AppVersion; use schemars::JsonSchema; @@ -42,7 +46,6 @@ use thiserror::Error; use ui::{TintColor, prelude::*}; use util::{ResultExt as _, maybe}; -use crate::provider::anthropic::{AnthropicEventMapper, count_anthropic_tokens, into_anthropic}; use crate::provider::google::{GoogleEventMapper, into_google}; use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_open_ai}; use crate::provider::x_ai::count_xai_tokens; @@ -1394,3 +1397,434 @@ mod tests { } } } + +fn count_anthropic_tokens( + request: LanguageModelRequest, + cx: &App, +) -> BoxFuture<'static, Result> { + use gpui::AppContext as _; + cx.background_spawn(async move { + let messages = 
request.messages; + let mut tokens_from_images = 0; + let mut string_messages = Vec::with_capacity(messages.len()); + + for message in messages { + let mut string_contents = String::new(); + + for content in message.content { + match content { + MessageContent::Text(text) => { + string_contents.push_str(&text); + } + MessageContent::Thinking { .. } => {} + MessageContent::RedactedThinking(_) => {} + MessageContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + MessageContent::ToolUse(_tool_use) => {} + MessageContent::ToolResult(tool_result) => match &tool_result.content { + LanguageModelToolResultContent::Text(text) => { + string_contents.push_str(text); + } + LanguageModelToolResultContent::Image(image) => { + tokens_from_images += image.estimate_tokens(); + } + }, + } + } + + if !string_contents.is_empty() { + string_messages.push(tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(string_contents), + name: None, + function_call: None, + }); + } + } + + tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages) + .map(|tokens| (tokens + tokens_from_images) as u64) + }) + .boxed() +} + +fn into_anthropic( + request: LanguageModelRequest, + model: String, + default_temperature: f32, + max_output_tokens: u64, + mode: AnthropicModelMode, +) -> anthropic::Request { + let mut new_messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for message in request.messages { + if message.contents_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + let mut anthropic_message_content: Vec = message + .content + .into_iter() + .filter_map(|content| match content { + MessageContent::Text(text) => { + let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) { + text.trim_end().to_string() + } else { + text + }; + if !text.is_empty() { + 
Some(anthropic::RequestContent::Text { + text, + cache_control: None, + }) + } else { + None + } + } + MessageContent::Thinking { + text: thinking, + signature, + } => { + if !thinking.is_empty() { + Some(anthropic::RequestContent::Thinking { + thinking, + signature: signature.unwrap_or_default(), + cache_control: None, + }) + } else { + None + } + } + MessageContent::RedactedThinking(data) => { + if !data.is_empty() { + Some(anthropic::RequestContent::RedactedThinking { data }) + } else { + None + } + } + MessageContent::Image(image) => Some(anthropic::RequestContent::Image { + source: anthropic::ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + cache_control: None, + }), + MessageContent::ToolUse(tool_use) => { + Some(anthropic::RequestContent::ToolUse { + id: tool_use.id.to_string(), + name: tool_use.name.to_string(), + input: tool_use.input, + cache_control: None, + }) + } + MessageContent::ToolResult(tool_result) => { + Some(anthropic::RequestContent::ToolResult { + tool_use_id: tool_result.tool_use_id.to_string(), + is_error: tool_result.is_error, + content: match tool_result.content { + LanguageModelToolResultContent::Text(text) => { + ToolResultContent::Plain(text.to_string()) + } + LanguageModelToolResultContent::Image(image) => { + ToolResultContent::Multipart(vec![ToolResultPart::Image { + source: anthropic::ImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: image.source.to_string(), + }, + }]) + } + }, + cache_control: None, + }) + } + }) + .collect(); + let anthropic_role = match message.role { + Role::User => anthropic::Role::User, + Role::Assistant => anthropic::Role::Assistant, + Role::System => unreachable!("System role should never occur here"), + }; + if let Some(last_message) = new_messages.last_mut() + && last_message.role == anthropic_role + { + last_message.content.extend(anthropic_message_content); + continue; + } + 
+ if message.cache { + let cache_control_value = Some(anthropic::CacheControl { + cache_type: anthropic::CacheControlType::Ephemeral, + }); + for message_content in anthropic_message_content.iter_mut().rev() { + match message_content { + anthropic::RequestContent::RedactedThinking { .. } => {} + anthropic::RequestContent::Text { cache_control, .. } + | anthropic::RequestContent::Thinking { cache_control, .. } + | anthropic::RequestContent::Image { cache_control, .. } + | anthropic::RequestContent::ToolUse { cache_control, .. } + | anthropic::RequestContent::ToolResult { cache_control, .. } => { + *cache_control = cache_control_value; + break; + } + } + } + } + + new_messages.push(anthropic::Message { + role: anthropic_role, + content: anthropic_message_content, + }); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.string_contents()); + } + } + } + + anthropic::Request { + model, + messages: new_messages, + max_tokens: max_output_tokens, + system: if system_message.is_empty() { + None + } else { + Some(anthropic::StringOrContents::String(system_message)) + }, + thinking: if request.thinking_allowed + && let AnthropicModelMode::Thinking { budget_tokens } = mode + { + Some(anthropic::Thinking::Enabled { budget_tokens }) + } else { + None + }, + tools: request + .tools + .into_iter() + .map(|tool| anthropic::Tool { + name: tool.name, + description: tool.description, + input_schema: tool.input_schema, + }) + .collect(), + tool_choice: request.tool_choice.map(|choice| match choice { + LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto, + LanguageModelToolChoice::Any => anthropic::ToolChoice::Any, + LanguageModelToolChoice::None => anthropic::ToolChoice::None, + }), + metadata: None, + stop_sequences: Vec::new(), + temperature: request.temperature.or(Some(default_temperature)), + top_k: None, + top_p: None, + } +} + +struct AnthropicEventMapper { + tool_uses_by_index: 
collections::HashMap, + usage: Usage, + stop_reason: StopReason, +} + +impl AnthropicEventMapper { + fn new() -> Self { + Self { + tool_uses_by_index: collections::HashMap::default(), + usage: Usage::default(), + stop_reason: StopReason::EndTurn, + } + } + + fn map_event( + &mut self, + event: Event, + ) -> Vec> { + match event { + Event::ContentBlockStart { + index, + content_block, + } => match content_block { + ResponseContent::Text { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ResponseContent::Thinking { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ResponseContent::RedactedThinking { data } => { + vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })] + } + ResponseContent::ToolUse { id, name, .. } => { + self.tool_uses_by_index.insert( + index, + RawToolUse { + id, + name, + input_json: String::new(), + }, + ); + Vec::new() + } + }, + Event::ContentBlockDelta { index, delta } => match delta { + ContentDelta::TextDelta { text } => { + vec![Ok(LanguageModelCompletionEvent::Text(text))] + } + ContentDelta::ThinkingDelta { thinking } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: thinking, + signature: None, + })] + } + ContentDelta::SignatureDelta { signature } => { + vec![Ok(LanguageModelCompletionEvent::Thinking { + text: "".to_string(), + signature: Some(signature), + })] + } + ContentDelta::InputJsonDelta { partial_json } => { + if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) { + tool_use.input_json.push_str(&partial_json); + + let event = serde_json::from_str::(&tool_use.input_json) + .ok() + .and_then(|input| { + let input_json_roundtripped = serde_json::to_string(&input).ok()?; + + if !tool_use.input_json.starts_with(&input_json_roundtripped) { + return None; + } + + Some(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(tool_use.id.clone()), + name: 
tool_use.name.clone().into(), + raw_input: tool_use.input_json.clone(), + input, + is_input_complete: false, + thought_signature: None, + }, + )) + }); + + if let Some(event) = event { + vec![Ok(event)] + } else { + Vec::new() + } + } else { + Vec::new() + } + } + }, + Event::ContentBlockStop { index } => { + if let Some(tool_use) = self.tool_uses_by_index.remove(&index) { + let event_result = match serde_json::from_str(&tool_use.input_json) { + Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(tool_use.id), + name: tool_use.name.into(), + raw_input: tool_use.input_json, + input, + is_input_complete: true, + thought_signature: None, + }, + )), + Err(json_parse_err) => { + Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + id: LanguageModelToolUseId::from(tool_use.id), + tool_name: tool_use.name.into(), + raw_input: tool_use.input_json.into(), + json_parse_error: json_parse_err.to_string(), + }) + } + }; + + vec![event_result] + } else { + Vec::new() + } + } + Event::MessageStart { message } => { + update_anthropic_usage(&mut self.usage, &message.usage); + vec![ + Ok(LanguageModelCompletionEvent::UsageUpdate( + convert_anthropic_usage(&self.usage), + )), + Ok(LanguageModelCompletionEvent::StartMessage { + message_id: message.id, + }), + ] + } + Event::MessageDelta { delta, usage } => { + update_anthropic_usage(&mut self.usage, &usage); + if let Some(stop_reason) = delta.stop_reason.as_deref() { + self.stop_reason = match stop_reason { + "end_turn" => StopReason::EndTurn, + "max_tokens" => StopReason::MaxTokens, + "tool_use" => StopReason::ToolUse, + "refusal" => StopReason::Refusal, + _ => { + log::error!("Unexpected anthropic stop_reason: {stop_reason}"); + StopReason::EndTurn + } + }; + } + vec![Ok(LanguageModelCompletionEvent::UsageUpdate( + convert_anthropic_usage(&self.usage), + ))] + } + Event::MessageStop => { + vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))] + } + 
Event::Error { error } => { + vec![Err(error.into())] + } + _ => Vec::new(), + } + } +} + +struct RawToolUse { + id: String, + name: String, + input_json: String, +} + +fn update_anthropic_usage(usage: &mut Usage, new: &Usage) { + if let Some(input_tokens) = new.input_tokens { + usage.input_tokens = Some(input_tokens); + } + if let Some(output_tokens) = new.output_tokens { + usage.output_tokens = Some(output_tokens); + } + if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens { + usage.cache_creation_input_tokens = Some(cache_creation_input_tokens); + } + if let Some(cache_read_input_tokens) = new.cache_read_input_tokens { + usage.cache_read_input_tokens = Some(cache_read_input_tokens); + } +} + +fn convert_anthropic_usage(usage: &Usage) -> language_model::TokenUsage { + language_model::TokenUsage { + input_tokens: usage.input_tokens.unwrap_or(0), + output_tokens: usage.output_tokens.unwrap_or(0), + cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0), + cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0), + } +} diff --git a/crates/language_models/src/settings.rs b/crates/language_models/src/settings.rs index 43a8e7334a744c..15a3c936705194 100644 --- a/crates/language_models/src/settings.rs +++ b/crates/language_models/src/settings.rs @@ -4,16 +4,14 @@ use collections::HashMap; use settings::RegisterSetting; use crate::provider::{ - anthropic::AnthropicSettings, bedrock::AmazonBedrockSettings, cloud::ZedDotDevSettings, - deepseek::DeepSeekSettings, google::GoogleSettings, lmstudio::LmStudioSettings, - mistral::MistralSettings, ollama::OllamaSettings, open_ai::OpenAiSettings, - open_ai_compatible::OpenAiCompatibleSettings, open_router::OpenRouterSettings, - vercel::VercelSettings, x_ai::XAiSettings, + bedrock::AmazonBedrockSettings, cloud::ZedDotDevSettings, deepseek::DeepSeekSettings, + google::GoogleSettings, lmstudio::LmStudioSettings, mistral::MistralSettings, + ollama::OllamaSettings, 
open_ai::OpenAiSettings, open_ai_compatible::OpenAiCompatibleSettings, + open_router::OpenRouterSettings, vercel::VercelSettings, x_ai::XAiSettings, }; #[derive(Debug, RegisterSetting)] pub struct AllLanguageModelSettings { - pub anthropic: AnthropicSettings, pub bedrock: AmazonBedrockSettings, pub deepseek: DeepSeekSettings, pub google: GoogleSettings, @@ -33,7 +31,6 @@ impl settings::Settings for AllLanguageModelSettings { fn from_settings(content: &settings::SettingsContent) -> Self { let language_models = content.language_models.clone().unwrap(); - let anthropic = language_models.anthropic.unwrap(); let bedrock = language_models.bedrock.unwrap(); let deepseek = language_models.deepseek.unwrap(); let google = language_models.google.unwrap(); @@ -47,10 +44,6 @@ impl settings::Settings for AllLanguageModelSettings { let x_ai = language_models.x_ai.unwrap(); let zed_dot_dev = language_models.zed_dot_dev.unwrap(); Self { - anthropic: AnthropicSettings { - api_url: anthropic.api_url.unwrap(), - available_models: anthropic.available_models.unwrap_or_default(), - }, bedrock: AmazonBedrockSettings { available_models: bedrock.available_models.unwrap_or_default(), region: bedrock.region, diff --git a/crates/settings/src/settings_content/agent.rs b/crates/settings/src/settings_content/agent.rs index 2ea9f0cd5788f3..e875c95f1c89b6 100644 --- a/crates/settings/src/settings_content/agent.rs +++ b/crates/settings/src/settings_content/agent.rs @@ -255,7 +255,6 @@ impl JsonSchema for LanguageModelProviderSetting { "type": "string", "enum": [ "amazon-bedrock", - "anthropic", "copilot_chat", "deepseek", "google", diff --git a/crates/settings/src/settings_content/extension.rs b/crates/settings/src/settings_content/extension.rs index 2fefd4ef38aeb9..b405103e8c311d 100644 --- a/crates/settings/src/settings_content/extension.rs +++ b/crates/settings/src/settings_content/extension.rs @@ -20,6 +20,15 @@ pub struct ExtensionSettingsContent { pub auto_update_extensions: HashMap, bool>, 
/// The capabilities granted to extensions. pub granted_extension_capabilities: Option>, + /// Extension language model providers that are allowed to read API keys from + /// environment variables. Each entry is a provider ID in the format + /// "extension_id:provider_id" (e.g., "openai:openai"). + /// + /// Default: [] + pub allowed_env_var_providers: Option>>, + /// Tracks which legacy LLM providers have been migrated. This is an internal + /// setting used to prevent the migration from running multiple times. + pub migrated_llm_providers: Option>>, } /// A capability for an extension. diff --git a/crates/settings/src/settings_content/language_model.rs b/crates/settings/src/settings_content/language_model.rs index 48f5a463a4b8d8..f99e1687130d80 100644 --- a/crates/settings/src/settings_content/language_model.rs +++ b/crates/settings/src/settings_content/language_model.rs @@ -8,7 +8,6 @@ use std::sync::Arc; #[with_fallible_options] #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)] pub struct AllLanguageModelSettingsContent { - pub anthropic: Option, pub bedrock: Option, pub deepseek: Option, pub google: Option, @@ -24,35 +23,6 @@ pub struct AllLanguageModelSettingsContent { pub zed_dot_dev: Option, } -#[with_fallible_options] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)] -pub struct AnthropicSettingsContent { - pub api_url: Option, - pub available_models: Option>, -} - -#[with_fallible_options] -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, MergeFrom)] -pub struct AnthropicAvailableModel { - /// The model's name in the Anthropic API. e.g. claude-3-5-sonnet-latest, claude-3-opus-20240229, etc - pub name: String, - /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel. - pub display_name: Option, - /// The model's context window size. 
- pub max_tokens: u64, - /// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling. - pub tool_override: Option, - /// Configuration of Anthropic's caching API. - pub cache_configuration: Option, - pub max_output_tokens: Option, - #[serde(serialize_with = "crate::serialize_optional_f32_with_two_decimal_places")] - pub default_temperature: Option, - #[serde(default)] - pub extra_beta_headers: Vec, - /// The model's mode (e.g. thinking) - pub mode: Option, -} - #[with_fallible_options] #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema, MergeFrom)] pub struct AmazonBedrockSettingsContent { diff --git a/docs/src/configuring-zed.md b/docs/src/configuring-zed.md index 477885a4537580..63ad054bfae1a2 100644 --- a/docs/src/configuring-zed.md +++ b/docs/src/configuring-zed.md @@ -2626,9 +2626,6 @@ These values take in the same options as the root-level settings with the same n ```json [settings] { "language_models": { - "anthropic": { - "api_url": "https://api.anthropic.com" - }, "google": { "api_url": "https://generativelanguage.googleapis.com" }, diff --git a/extensions/anthropic/Cargo.lock b/extensions/anthropic/Cargo.lock new file mode 100644 index 00000000000000..8bd00ffdbe3292 --- /dev/null +++ b/extensions/anthropic/Cargo.lock @@ -0,0 +1,823 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "anthropic" +version = "0.1.0" +dependencies = [ + "serde", + "serde_json", + "zed_extension_api", +] + +[[package]] +name = "anyhow" +version = "1.0.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" + +[[package]] +name = "auditable-serde" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7bf8143dfc3c0258df908843e169b5cc5fcf76c7718bd66135ef4a9cd558c5" +dependencies = [ + "semver", + "serde", + "serde_json", + "topological-sort", +] + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "flate2" +version = "1.1.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source 
= "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "icu_collections" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.1.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" + +[[package]] +name = "icu_properties" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899" + +[[package]] +name = "icu_provider" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "litemap" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "potential_utf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +dependencies = [ + "zerovec", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "semver" +version = "1.0.27" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +dependencies = [ + "serde", + "serde_core", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.145" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", + "serde_core", +] + +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "spdx" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c3e17e880bafaeb362a7b751ec46bdc5b61445a188f80e0606e68167cd540fa3" +dependencies = [ + "smallvec", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "syn" +version = "2.0.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinystr" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "topological-sort" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea68304e134ecd095ac6c3574494fc62b909f416c4fca77e440530221e549d3d" + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "url" +version = "2.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = 
"utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "wasm-encoder" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80bb72f02e7fbf07183443b27b0f3d4144abf8c114189f2e088ed95b696a7822" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce1ef0faabbbba6674e97a56bee857ccddf942785a336c8b47b42373c922a91d" +dependencies = [ + "anyhow", + "auditable-serde", + "flate2", + "indexmap", + "serde", + "serde_derive", + "serde_json", + "spdx", + "url", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f51cad774fb3c9461ab9bccc9c62dfb7388397b5deda31bf40e8108ccd678b2" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "wit-bindgen" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10fb6648689b3929d56bbc7eb1acf70c9a42a29eb5358c67c10f54dbd5d695de" +dependencies = [ + "wit-bindgen-rt", + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92fa781d4f2ff6d3f27f3cc9b74a73327b31ca0dc4a3ef25a0ce2983e0e5af9b" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rt" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db52a11d4dfb0a59f194c064055794ee6564eb1ced88c25da2cf76e50c5621" +dependencies = [ + "bitflags", + "futures", + "once_cell", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.41.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d0809dc5ba19e2e98661bf32fc0addc5a3ca5bf3a6a7083aa6ba484085ff3ce" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad19eec017904e04c60719592a803ee5da76cb51c81e3f6fbf9457f59db49799" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "635c3adc595422cbf2341a17fb73a319669cc8d33deed3a48368a841df86b676" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddf445ed5157046e4baf56f9138c124a0824d4d1657e7204d71886ad8ce2fc11" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "yoke" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zed_extension_api" +version = "0.8.0" +dependencies = [ + "serde", + "serde_json", + "wit-bindgen", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerotrie" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/extensions/anthropic/Cargo.toml b/extensions/anthropic/Cargo.toml new file mode 100644 index 00000000000000..25dfe72b0e92ca --- /dev/null +++ b/extensions/anthropic/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "anthropic" +version = "0.1.0" +edition = "2021" +publish = false +license = "Apache-2.0" + +[workspace] + +[lib] +path = "src/anthropic.rs" +crate-type = ["cdylib"] + +[dependencies] 
+zed_extension_api = { path = "../../crates/extension_api" } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" diff --git a/extensions/anthropic/extension.toml b/extensions/anthropic/extension.toml new file mode 100644 index 00000000000000..c37b8aca34f6cf --- /dev/null +++ b/extensions/anthropic/extension.toml @@ -0,0 +1,13 @@ +id = "anthropic" +name = "Anthropic" +description = "Anthropic Claude LLM provider for Zed." +version = "0.1.0" +schema_version = 1 +authors = ["Zed Team"] +repository = "https://github.com/zed-industries/zed" + +[language_model_providers.anthropic] +name = "Anthropic" + +[language_model_providers.anthropic.auth] +env_var = "ANTHROPIC_API_KEY" \ No newline at end of file diff --git a/assets/icons/ai_anthropic.svg b/extensions/anthropic/icons/anthropic.svg similarity index 88% rename from assets/icons/ai_anthropic.svg rename to extensions/anthropic/icons/anthropic.svg index 12d731fb0b4438..75c1a7e0014e7d 100644 --- a/assets/icons/ai_anthropic.svg +++ b/extensions/anthropic/icons/anthropic.svg @@ -1,7 +1,7 @@ - - + + diff --git a/extensions/anthropic/src/anthropic.rs b/extensions/anthropic/src/anthropic.rs new file mode 100644 index 00000000000000..26d364cf90acbc --- /dev/null +++ b/extensions/anthropic/src/anthropic.rs @@ -0,0 +1,803 @@ +use std::collections::HashMap; +use std::sync::Mutex; + +use serde::{Deserialize, Serialize}; +use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy}; +use zed_extension_api::{self as zed, *}; + +struct AnthropicProvider { + streams: Mutex>, + next_stream_id: Mutex, +} + +struct StreamState { + response_stream: Option, + buffer: String, + started: bool, + current_tool_use: Option, + stop_reason: Option, + pending_signature: Option, +} + +struct ToolUseState { + id: String, + name: String, + input_json: String, +} + +struct ModelDefinition { + real_id: &'static str, + display_name: &'static str, + max_tokens: u64, + max_output_tokens: u64, + 
supports_images: bool, + supports_thinking: bool, + is_default: bool, + is_default_fast: bool, +} + +const MODELS: &[ModelDefinition] = &[ + ModelDefinition { + real_id: "claude-opus-4-5-20251101", + display_name: "Claude Opus 4.5", + max_tokens: 200_000, + max_output_tokens: 8_192, + supports_images: true, + supports_thinking: false, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "claude-opus-4-5-20251101", + display_name: "Claude Opus 4.5 Thinking", + max_tokens: 200_000, + max_output_tokens: 8_192, + supports_images: true, + supports_thinking: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "claude-sonnet-4-5-20250929", + display_name: "Claude Sonnet 4.5", + max_tokens: 200_000, + max_output_tokens: 8_192, + supports_images: true, + supports_thinking: false, + is_default: true, + is_default_fast: false, + }, + ModelDefinition { + real_id: "claude-sonnet-4-5-20250929", + display_name: "Claude Sonnet 4.5 Thinking", + max_tokens: 200_000, + max_output_tokens: 8_192, + supports_images: true, + supports_thinking: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "claude-sonnet-4-20250514", + display_name: "Claude Sonnet 4", + max_tokens: 200_000, + max_output_tokens: 8_192, + supports_images: true, + supports_thinking: false, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "claude-sonnet-4-20250514", + display_name: "Claude Sonnet 4 Thinking", + max_tokens: 200_000, + max_output_tokens: 8_192, + supports_images: true, + supports_thinking: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "claude-haiku-4-5-20251001", + display_name: "Claude Haiku 4.5", + max_tokens: 200_000, + max_output_tokens: 64_000, + supports_images: true, + supports_thinking: false, + is_default: false, + is_default_fast: true, + }, + ModelDefinition { + real_id: "claude-haiku-4-5-20251001", + display_name: 
"Claude Haiku 4.5 Thinking", + max_tokens: 200_000, + max_output_tokens: 64_000, + supports_images: true, + supports_thinking: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "claude-3-5-sonnet-latest", + display_name: "Claude 3.5 Sonnet", + max_tokens: 200_000, + max_output_tokens: 8_192, + supports_images: true, + supports_thinking: false, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "claude-3-5-haiku-latest", + display_name: "Claude 3.5 Haiku", + max_tokens: 200_000, + max_output_tokens: 8_192, + supports_images: true, + supports_thinking: false, + is_default: false, + is_default_fast: false, + }, +]; + +fn get_model_definition(display_name: &str) -> Option<&'static ModelDefinition> { + MODELS.iter().find(|m| m.display_name == display_name) +} + +// Anthropic API Request Types + +#[derive(Serialize)] +struct AnthropicRequest { + model: String, + max_tokens: u64, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + system: Option, + #[serde(skip_serializing_if = "Option::is_none")] + thinking: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + tools: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + stop_sequences: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + stream: bool, +} + +#[derive(Serialize)] +struct AnthropicThinking { + #[serde(rename = "type")] + thinking_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + budget_tokens: Option, +} + +#[derive(Serialize)] +struct AnthropicMessage { + role: String, + content: Vec, +} + +#[derive(Serialize, Clone)] +#[serde(tag = "type")] +enum AnthropicContent { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "thinking")] + Thinking { thinking: String, signature: String }, + #[serde(rename = "redacted_thinking")] + RedactedThinking { data: String }, + 
#[serde(rename = "image")] + Image { source: AnthropicImageSource }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + #[serde(rename = "tool_result")] + ToolResult { + tool_use_id: String, + is_error: bool, + content: String, + }, +} + +#[derive(Serialize, Clone)] +struct AnthropicImageSource { + #[serde(rename = "type")] + source_type: String, + media_type: String, + data: String, +} + +#[derive(Serialize)] +struct AnthropicTool { + name: String, + description: String, + input_schema: serde_json::Value, +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "lowercase")] +enum AnthropicToolChoice { + Auto, + Any, + None, +} + +// Anthropic API Response Types + +#[derive(Deserialize, Debug)] +#[serde(tag = "type")] +#[allow(dead_code)] +enum AnthropicEvent { + #[serde(rename = "message_start")] + MessageStart { message: AnthropicMessageResponse }, + #[serde(rename = "content_block_start")] + ContentBlockStart { + index: usize, + content_block: AnthropicContentBlock, + }, + #[serde(rename = "content_block_delta")] + ContentBlockDelta { index: usize, delta: AnthropicDelta }, + #[serde(rename = "content_block_stop")] + ContentBlockStop { index: usize }, + #[serde(rename = "message_delta")] + MessageDelta { + delta: AnthropicMessageDelta, + usage: AnthropicUsage, + }, + #[serde(rename = "message_stop")] + MessageStop, + #[serde(rename = "ping")] + Ping, + #[serde(rename = "error")] + Error { error: AnthropicApiError }, +} + +#[derive(Deserialize, Debug)] +struct AnthropicMessageResponse { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + role: String, + #[serde(default)] + usage: AnthropicUsage, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type")] +enum AnthropicContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "thinking")] + Thinking { thinking: String }, + #[serde(rename = "redacted_thinking")] + RedactedThinking { data: String }, + #[serde(rename 
= "tool_use")] + ToolUse { id: String, name: String }, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type")] +enum AnthropicDelta { + #[serde(rename = "text_delta")] + TextDelta { text: String }, + #[serde(rename = "thinking_delta")] + ThinkingDelta { thinking: String }, + #[serde(rename = "signature_delta")] + SignatureDelta { signature: String }, + #[serde(rename = "input_json_delta")] + InputJsonDelta { partial_json: String }, +} + +#[derive(Deserialize, Debug)] +struct AnthropicMessageDelta { + stop_reason: Option, +} + +#[derive(Deserialize, Debug, Default)] +struct AnthropicUsage { + #[serde(default)] + input_tokens: Option, + #[serde(default)] + output_tokens: Option, + #[serde(default)] + cache_creation_input_tokens: Option, + #[serde(default)] + cache_read_input_tokens: Option, +} + +#[derive(Deserialize, Debug)] +struct AnthropicApiError { + #[serde(rename = "type")] + #[allow(dead_code)] + error_type: String, + message: String, +} + +fn convert_request( + model_id: &str, + request: &LlmCompletionRequest, +) -> Result { + let model_def = + get_model_definition(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?; + + let mut messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for msg in &request.messages { + match msg.role { + LlmMessageRole::System => { + for content in &msg.content { + if let LlmMessageContent::Text(text) = content { + if !system_message.is_empty() { + system_message.push('\n'); + } + system_message.push_str(text); + } + } + } + LlmMessageRole::User => { + let mut contents: Vec = Vec::new(); + + for content in &msg.content { + match content { + LlmMessageContent::Text(text) => { + if !text.is_empty() { + contents.push(AnthropicContent::Text { text: text.clone() }); + } + } + LlmMessageContent::Image(img) => { + contents.push(AnthropicContent::Image { + source: AnthropicImageSource { + source_type: "base64".to_string(), + media_type: "image/png".to_string(), + data: img.source.clone(), + }, + }); + 
} + LlmMessageContent::ToolResult(result) => { + let content_text = match &result.content { + LlmToolResultContent::Text(t) => t.clone(), + LlmToolResultContent::Image(_) => "[Image]".to_string(), + }; + contents.push(AnthropicContent::ToolResult { + tool_use_id: result.tool_use_id.clone(), + is_error: result.is_error, + content: content_text, + }); + } + _ => {} + } + } + + if !contents.is_empty() { + messages.push(AnthropicMessage { + role: "user".to_string(), + content: contents, + }); + } + } + LlmMessageRole::Assistant => { + let mut contents: Vec = Vec::new(); + + for content in &msg.content { + match content { + LlmMessageContent::Text(text) => { + if !text.is_empty() { + contents.push(AnthropicContent::Text { text: text.clone() }); + } + } + LlmMessageContent::ToolUse(tool_use) => { + let input: serde_json::Value = + serde_json::from_str(&tool_use.input).unwrap_or_default(); + contents.push(AnthropicContent::ToolUse { + id: tool_use.id.clone(), + name: tool_use.name.clone(), + input, + }); + } + LlmMessageContent::Thinking(thinking) => { + if !thinking.text.is_empty() { + contents.push(AnthropicContent::Thinking { + thinking: thinking.text.clone(), + signature: thinking.signature.clone().unwrap_or_default(), + }); + } + } + LlmMessageContent::RedactedThinking(data) => { + if !data.is_empty() { + contents.push(AnthropicContent::RedactedThinking { + data: data.clone(), + }); + } + } + _ => {} + } + } + + if !contents.is_empty() { + messages.push(AnthropicMessage { + role: "assistant".to_string(), + content: contents, + }); + } + } + } + } + + let tools: Vec = request + .tools + .iter() + .map(|t| AnthropicTool { + name: t.name.clone(), + description: t.description.clone(), + input_schema: serde_json::from_str(&t.input_schema) + .unwrap_or(serde_json::Value::Object(Default::default())), + }) + .collect(); + + let tool_choice = request.tool_choice.as_ref().map(|tc| match tc { + LlmToolChoice::Auto => AnthropicToolChoice::Auto, + LlmToolChoice::Any => 
AnthropicToolChoice::Any, + LlmToolChoice::None => AnthropicToolChoice::None, + }); + + let thinking = if model_def.supports_thinking && request.thinking_allowed { + Some(AnthropicThinking { + thinking_type: "enabled".to_string(), + budget_tokens: Some(4096), + }) + } else { + None + }; + + Ok(AnthropicRequest { + model: model_def.real_id.to_string(), + max_tokens: model_def.max_output_tokens, + messages, + system: if system_message.is_empty() { + None + } else { + Some(system_message) + }, + thinking, + tools, + tool_choice, + stop_sequences: request.stop_sequences.clone(), + temperature: request.temperature, + stream: true, + }) +} + +fn parse_sse_line(line: &str) -> Option { + let data = line.strip_prefix("data: ")?; + serde_json::from_str(data).ok() +} + +impl zed::Extension for AnthropicProvider { + fn new() -> Self { + Self { + streams: Mutex::new(HashMap::new()), + next_stream_id: Mutex::new(0), + } + } + + fn llm_providers(&self) -> Vec { + vec![LlmProviderInfo { + id: "anthropic".into(), + name: "Anthropic".into(), + icon: Some("icons/anthropic.svg".into()), + }] + } + + fn llm_provider_models(&self, _provider_id: &str) -> Result, String> { + Ok(MODELS + .iter() + .map(|m| LlmModelInfo { + id: m.display_name.to_string(), + name: m.display_name.to_string(), + max_token_count: m.max_tokens, + max_output_tokens: Some(m.max_output_tokens), + capabilities: LlmModelCapabilities { + supports_images: m.supports_images, + supports_tools: true, + supports_tool_choice_auto: true, + supports_tool_choice_any: true, + supports_tool_choice_none: true, + supports_thinking: m.supports_thinking, + tool_input_format: LlmToolInputFormat::JsonSchema, + }, + is_default: m.is_default, + is_default_fast: m.is_default_fast, + }) + .collect()) + } + + fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool { + llm_get_credential("anthropic").is_some() + } + + fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option { + Some( + r#"# Anthropic Setup + 
+Welcome to **Anthropic**! This extension provides access to Claude models. + +## Configuration + +Enter your Anthropic API key below. You can get your API key at [console.anthropic.com](https://console.anthropic.com/). + +## Available Models + +| Display Name | Real Model | Context | Output | +|--------------|------------|---------|--------| +| Claude Opus 4.5 | claude-opus-4-5 | 200K | 8K | +| Claude Opus 4.5 Thinking | claude-opus-4-5 | 200K | 8K | +| Claude Sonnet 4.5 | claude-sonnet-4-5 | 200K | 8K | +| Claude Sonnet 4.5 Thinking | claude-sonnet-4-5 | 200K | 8K | +| Claude Sonnet 4 | claude-sonnet-4 | 200K | 8K | +| Claude Sonnet 4 Thinking | claude-sonnet-4 | 200K | 8K | +| Claude Haiku 4.5 | claude-haiku-4-5 | 200K | 64K | +| Claude Haiku 4.5 Thinking | claude-haiku-4-5 | 200K | 64K | +| Claude 3.5 Sonnet | claude-3-5-sonnet | 200K | 8K | +| Claude 3.5 Haiku | claude-3-5-haiku | 200K | 8K | + +## Features + +- ✅ Full streaming support +- ✅ Tool/function calling +- ✅ Vision (image inputs) +- ✅ Extended thinking support +- ✅ All Claude models + +## Pricing + +Uses your Anthropic API credits. See [Anthropic pricing](https://www.anthropic.com/pricing) for details. +"# + .to_string(), + ) + } + + fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> { + let provided = llm_request_credential( + "anthropic", + LlmCredentialType::ApiKey, + "Anthropic API Key", + "sk-ant-...", + )?; + if provided { + Ok(()) + } else { + Err("Authentication cancelled".to_string()) + } + } + + fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> { + llm_delete_credential("anthropic") + } + + fn llm_stream_completion_start( + &mut self, + _provider_id: &str, + model_id: &str, + request: &LlmCompletionRequest, + ) -> Result { + let api_key = llm_get_credential("anthropic").ok_or_else(|| { + "No API key configured. 
Please add your Anthropic API key in settings.".to_string() + })?; + + let anthropic_request = convert_request(model_id, request)?; + + let body = serde_json::to_vec(&anthropic_request) + .map_err(|e| format!("Failed to serialize request: {}", e))?; + + let http_request = HttpRequest { + method: HttpMethod::Post, + url: "https://api.anthropic.com/v1/messages".to_string(), + headers: vec![ + ("Content-Type".to_string(), "application/json".to_string()), + ("x-api-key".to_string(), api_key), + ("anthropic-version".to_string(), "2023-06-01".to_string()), + ], + body: Some(body), + redirect_policy: RedirectPolicy::FollowAll, + }; + + let response_stream = http_request + .fetch_stream() + .map_err(|e| format!("HTTP request failed: {}", e))?; + + let stream_id = { + let mut id_counter = self.next_stream_id.lock().unwrap(); + let id = format!("anthropic-stream-{}", *id_counter); + *id_counter += 1; + id + }; + + self.streams.lock().unwrap().insert( + stream_id.clone(), + StreamState { + response_stream: Some(response_stream), + buffer: String::new(), + started: false, + current_tool_use: None, + stop_reason: None, + pending_signature: None, + }, + ); + + Ok(stream_id) + } + + fn llm_stream_completion_next( + &mut self, + stream_id: &str, + ) -> Result, String> { + let mut streams = self.streams.lock().unwrap(); + let state = streams + .get_mut(stream_id) + .ok_or_else(|| format!("Unknown stream: {}", stream_id))?; + + if !state.started { + state.started = true; + return Ok(Some(LlmCompletionEvent::Started)); + } + + let response_stream = state + .response_stream + .as_mut() + .ok_or_else(|| "Stream already closed".to_string())?; + + loop { + if let Some(newline_pos) = state.buffer.find('\n') { + let line = state.buffer[..newline_pos].to_string(); + state.buffer = state.buffer[newline_pos + 1..].to_string(); + + if line.trim().is_empty() || line.starts_with("event:") { + continue; + } + + if let Some(event) = parse_sse_line(&line) { + match event { + 
AnthropicEvent::MessageStart { message } => { + if let (Some(input), Some(output)) = + (message.usage.input_tokens, message.usage.output_tokens) + { + return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage { + input_tokens: input, + output_tokens: output, + cache_creation_input_tokens: message + .usage + .cache_creation_input_tokens, + cache_read_input_tokens: message.usage.cache_read_input_tokens, + }))); + } + } + AnthropicEvent::ContentBlockStart { content_block, .. } => { + match content_block { + AnthropicContentBlock::Text { text } => { + if !text.is_empty() { + return Ok(Some(LlmCompletionEvent::Text(text))); + } + } + AnthropicContentBlock::Thinking { thinking } => { + return Ok(Some(LlmCompletionEvent::Thinking( + LlmThinkingContent { + text: thinking, + signature: None, + }, + ))); + } + AnthropicContentBlock::RedactedThinking { data } => { + return Ok(Some(LlmCompletionEvent::RedactedThinking(data))); + } + AnthropicContentBlock::ToolUse { id, name } => { + state.current_tool_use = Some(ToolUseState { + id, + name, + input_json: String::new(), + }); + } + } + } + AnthropicEvent::ContentBlockDelta { delta, .. } => match delta { + AnthropicDelta::TextDelta { text } => { + if !text.is_empty() { + return Ok(Some(LlmCompletionEvent::Text(text))); + } + } + AnthropicDelta::ThinkingDelta { thinking } => { + return Ok(Some(LlmCompletionEvent::Thinking( + LlmThinkingContent { + text: thinking, + signature: None, + }, + ))); + } + AnthropicDelta::SignatureDelta { signature } => { + state.pending_signature = Some(signature.clone()); + return Ok(Some(LlmCompletionEvent::Thinking( + LlmThinkingContent { + text: String::new(), + signature: Some(signature), + }, + ))); + } + AnthropicDelta::InputJsonDelta { partial_json } => { + if let Some(ref mut tool_use) = state.current_tool_use { + tool_use.input_json.push_str(&partial_json); + } + } + }, + AnthropicEvent::ContentBlockStop { .. 
} => { + if let Some(tool_use) = state.current_tool_use.take() { + return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse { + id: tool_use.id, + name: tool_use.name, + input: tool_use.input_json, + thought_signature: state.pending_signature.take(), + }))); + } + } + AnthropicEvent::MessageDelta { delta, usage } => { + if let Some(reason) = delta.stop_reason { + state.stop_reason = Some(match reason.as_str() { + "end_turn" => LlmStopReason::EndTurn, + "max_tokens" => LlmStopReason::MaxTokens, + "tool_use" => LlmStopReason::ToolUse, + _ => LlmStopReason::EndTurn, + }); + } + if let Some(output) = usage.output_tokens { + return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage { + input_tokens: usage.input_tokens.unwrap_or(0), + output_tokens: output, + cache_creation_input_tokens: usage.cache_creation_input_tokens, + cache_read_input_tokens: usage.cache_read_input_tokens, + }))); + } + } + AnthropicEvent::MessageStop => { + if let Some(stop_reason) = state.stop_reason.take() { + return Ok(Some(LlmCompletionEvent::Stop(stop_reason))); + } + return Ok(Some(LlmCompletionEvent::Stop(LlmStopReason::EndTurn))); + } + AnthropicEvent::Ping => {} + AnthropicEvent::Error { error } => { + return Err(format!("API error: {}", error.message)); + } + } + } + + continue; + } + + match response_stream.next_chunk() { + Ok(Some(chunk)) => { + let text = String::from_utf8_lossy(&chunk); + state.buffer.push_str(&text); + } + Ok(None) => { + if let Some(stop_reason) = state.stop_reason.take() { + return Ok(Some(LlmCompletionEvent::Stop(stop_reason))); + } + return Ok(None); + } + Err(e) => { + return Err(format!("Stream error: {}", e)); + } + } + } + } + + fn llm_stream_completion_close(&mut self, stream_id: &str) { + self.streams.lock().unwrap().remove(stream_id); + } +} + +zed::register_extension!(AnthropicProvider); diff --git a/extensions/copilot-chat/Cargo.lock b/extensions/copilot-chat/Cargo.lock new file mode 100644 index 00000000000000..4b78fda143f8ec --- /dev/null +++ 
b/extensions/copilot-chat/Cargo.lock @@ -0,0 +1,823 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "anyhow" +version = "1.0.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" + +[[package]] +name = "auditable-serde" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7bf8143dfc3c0258df908843e169b5cc5fcf76c7718bd66135ef4a9cd558c5" +dependencies = [ + "semver", + "serde", + "serde_json", + "topological-sort", +] + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "copilot-chat" +version = "0.1.0" +dependencies = [ + "serde", + "serde_json", + "zed_extension_api", +] + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "flate2" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "icu_collections" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +dependencies = [ + "displaydoc", + 
"potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" + +[[package]] +name = "icu_properties" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899" + +[[package]] +name = "icu_provider" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" + +[[package]] +name = "idna" 
+version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "litemap" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + 
"simd-adler32", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "potential_utf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +dependencies = [ + "zerovec", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + 
+[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +dependencies = [ + "serde", + "serde_core", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.145" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", + "serde_core", +] + +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "spdx" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "c3e17e880bafaeb362a7b751ec46bdc5b61445a188f80e0606e68167cd540fa3" +dependencies = [ + "smallvec", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "syn" +version = "2.0.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinystr" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "topological-sort" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea68304e134ecd095ac6c3574494fc62b909f416c4fca77e440530221e549d3d" + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "url" +version = "2.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] 
+name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "wasm-encoder" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80bb72f02e7fbf07183443b27b0f3d4144abf8c114189f2e088ed95b696a7822" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce1ef0faabbbba6674e97a56bee857ccddf942785a336c8b47b42373c922a91d" +dependencies = [ + "anyhow", + "auditable-serde", + "flate2", + "indexmap", + "serde", + "serde_derive", + "serde_json", + "spdx", + "url", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f51cad774fb3c9461ab9bccc9c62dfb7388397b5deda31bf40e8108ccd678b2" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "wit-bindgen" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10fb6648689b3929d56bbc7eb1acf70c9a42a29eb5358c67c10f54dbd5d695de" +dependencies = [ + "wit-bindgen-rt", + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92fa781d4f2ff6d3f27f3cc9b74a73327b31ca0dc4a3ef25a0ce2983e0e5af9b" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rt" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db52a11d4dfb0a59f194c064055794ee6564eb1ced88c25da2cf76e50c5621" +dependencies = [ + "bitflags", + "futures", + "once_cell", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.41.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d0809dc5ba19e2e98661bf32fc0addc5a3ca5bf3a6a7083aa6ba484085ff3ce" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad19eec017904e04c60719592a803ee5da76cb51c81e3f6fbf9457f59db49799" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "635c3adc595422cbf2341a17fb73a319669cc8d33deed3a48368a841df86b676" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddf445ed5157046e4baf56f9138c124a0824d4d1657e7204d71886ad8ce2fc11" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "yoke" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zed_extension_api" +version = "0.8.0" +dependencies = [ + "serde", + "serde_json", + "wit-bindgen", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerotrie" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/extensions/copilot-chat/Cargo.toml b/extensions/copilot-chat/Cargo.toml new file mode 100644 index 00000000000000..189c1db9fc37b9 --- /dev/null +++ b/extensions/copilot-chat/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "copilot-chat" +version = "0.1.0" +edition = "2021" +publish = false +license = "Apache-2.0" + +[workspace] + +[lib] +path = "src/copilot_chat.rs" +crate-type = ["cdylib"] + +[dependencies] 
+zed_extension_api = { path = "../../crates/extension_api" } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" diff --git a/extensions/copilot-chat/extension.toml b/extensions/copilot-chat/extension.toml new file mode 100644 index 00000000000000..5e77c6dda4144f --- /dev/null +++ b/extensions/copilot-chat/extension.toml @@ -0,0 +1,16 @@ +id = "copilot-chat" +name = "Copilot Chat" +description = "GitHub Copilot Chat LLM provider for Zed." +version = "0.1.0" +schema_version = 1 +authors = ["Zed Team"] +repository = "https://github.com/zed-industries/zed" + +[language_model_providers.copilot-chat] +name = "Copilot Chat" + +[language_model_providers.copilot-chat.auth] +env_var = "GH_COPILOT_TOKEN" + +[language_model_providers.copilot-chat.auth.oauth] +sign_in_button_label = "Sign in with GitHub" \ No newline at end of file diff --git a/extensions/copilot-chat/icons/copilot.svg b/extensions/copilot-chat/icons/copilot.svg new file mode 100644 index 00000000000000..2584cd631006c1 --- /dev/null +++ b/extensions/copilot-chat/icons/copilot.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/extensions/copilot-chat/src/copilot_chat.rs b/extensions/copilot-chat/src/copilot_chat.rs new file mode 100644 index 00000000000000..9d5730e85055a2 --- /dev/null +++ b/extensions/copilot-chat/src/copilot_chat.rs @@ -0,0 +1,1004 @@ +use std::collections::HashMap; +use std::sync::Mutex; +use std::thread; +use std::time::Duration; + +use serde::{Deserialize, Serialize}; +use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy}; +use zed_extension_api::{self as zed, *}; + +const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code"; +const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token"; +const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token"; +const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98"; + +struct DeviceFlowState { + device_code: 
String, + interval: u64, + expires_in: u64, +} + +#[derive(Clone)] +struct ApiToken { + api_key: String, + api_endpoint: String, +} + +#[derive(Clone, Deserialize)] +struct CopilotModel { + id: String, + name: String, + #[serde(default)] + is_chat_default: bool, + #[serde(default)] + is_chat_fallback: bool, + #[serde(default)] + model_picker_enabled: bool, + #[serde(default)] + capabilities: ModelCapabilities, + #[serde(default)] + policy: Option, +} + +#[derive(Clone, Default, Deserialize)] +struct ModelCapabilities { + #[serde(default)] + family: String, + #[serde(default)] + limits: ModelLimits, + #[serde(default)] + supports: ModelSupportedFeatures, + #[serde(rename = "type", default)] + model_type: String, +} + +#[derive(Clone, Default, Deserialize)] +struct ModelLimits { + #[serde(default)] + max_context_window_tokens: u64, + #[serde(default)] + max_output_tokens: u64, +} + +#[derive(Clone, Default, Deserialize)] +struct ModelSupportedFeatures { + #[serde(default)] + streaming: bool, + #[serde(default)] + tool_calls: bool, + #[serde(default)] + vision: bool, +} + +#[derive(Clone, Deserialize)] +struct ModelPolicy { + state: String, +} + +struct CopilotChatProvider { + streams: Mutex>, + next_stream_id: Mutex, + device_flow_state: Mutex>, + api_token: Mutex>, + cached_models: Mutex>>, +} + +struct StreamState { + response_stream: Option, + buffer: String, + started: bool, + tool_calls: HashMap, + tool_calls_emitted: bool, +} + +#[derive(Clone, Default)] +struct AccumulatedToolCall { + id: String, + name: String, + arguments: String, +} + +#[derive(Serialize)] +struct OpenAiRequest { + model: String, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + max_tokens: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + tools: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + stop: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, 
+ stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + stream_options: Option, +} + +#[derive(Serialize)] +struct StreamOptions { + include_usage: bool, +} + +#[derive(Serialize)] +struct OpenAiMessage { + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, +} + +#[derive(Serialize, Clone)] +#[serde(untagged)] +enum OpenAiContent { + Text(String), + Parts(Vec), +} + +#[derive(Serialize, Clone)] +#[serde(tag = "type")] +enum OpenAiContentPart { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image_url")] + ImageUrl { image_url: ImageUrl }, +} + +#[derive(Serialize, Clone)] +struct ImageUrl { + url: String, +} + +#[derive(Serialize, Clone)] +struct OpenAiToolCall { + id: String, + #[serde(rename = "type")] + call_type: String, + function: OpenAiFunctionCall, +} + +#[derive(Serialize, Clone)] +struct OpenAiFunctionCall { + name: String, + arguments: String, +} + +#[derive(Serialize)] +struct OpenAiTool { + #[serde(rename = "type")] + tool_type: String, + function: OpenAiFunctionDef, +} + +#[derive(Serialize)] +struct OpenAiFunctionDef { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Deserialize, Debug)] +struct OpenAiStreamResponse { + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Deserialize, Debug)] +struct OpenAiStreamChoice { + delta: OpenAiDelta, + finish_reason: Option, +} + +#[derive(Deserialize, Debug, Default)] +struct OpenAiDelta { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + +#[derive(Deserialize, Debug)] +struct OpenAiToolCallDelta { + index: usize, + #[serde(default)] + id: Option, + #[serde(default)] + function: Option, +} + +#[derive(Deserialize, Debug, Default)] +struct OpenAiFunctionDelta { + #[serde(default)] + name: Option, + 
#[serde(default)] + arguments: Option, +} + +#[derive(Deserialize, Debug)] +struct OpenAiUsage { + prompt_tokens: u64, + completion_tokens: u64, +} + +fn convert_request( + model_id: &str, + request: &LlmCompletionRequest, +) -> Result { + let mut messages: Vec = Vec::new(); + + for msg in &request.messages { + match msg.role { + LlmMessageRole::System => { + let mut text_content = String::new(); + for content in &msg.content { + if let LlmMessageContent::Text(text) = content { + if !text_content.is_empty() { + text_content.push('\n'); + } + text_content.push_str(text); + } + } + if !text_content.is_empty() { + messages.push(OpenAiMessage { + role: "system".to_string(), + content: Some(OpenAiContent::Text(text_content)), + tool_calls: None, + tool_call_id: None, + }); + } + } + LlmMessageRole::User => { + let mut parts: Vec = Vec::new(); + let mut tool_result_messages: Vec = Vec::new(); + + for content in &msg.content { + match content { + LlmMessageContent::Text(text) => { + if !text.is_empty() { + parts.push(OpenAiContentPart::Text { text: text.clone() }); + } + } + LlmMessageContent::Image(img) => { + let data_url = format!("data:image/png;base64,{}", img.source); + parts.push(OpenAiContentPart::ImageUrl { + image_url: ImageUrl { url: data_url }, + }); + } + LlmMessageContent::ToolResult(result) => { + let content_text = match &result.content { + LlmToolResultContent::Text(t) => t.clone(), + LlmToolResultContent::Image(_) => "[Image]".to_string(), + }; + tool_result_messages.push(OpenAiMessage { + role: "tool".to_string(), + content: Some(OpenAiContent::Text(content_text)), + tool_calls: None, + tool_call_id: Some(result.tool_use_id.clone()), + }); + } + _ => {} + } + } + + if !parts.is_empty() { + let content = if parts.len() == 1 { + if let OpenAiContentPart::Text { text } = &parts[0] { + OpenAiContent::Text(text.clone()) + } else { + OpenAiContent::Parts(parts) + } + } else { + OpenAiContent::Parts(parts) + }; + + messages.push(OpenAiMessage { + role: 
"user".to_string(), + content: Some(content), + tool_calls: None, + tool_call_id: None, + }); + } + + messages.extend(tool_result_messages); + } + LlmMessageRole::Assistant => { + let mut text_content = String::new(); + let mut tool_calls: Vec = Vec::new(); + + for content in &msg.content { + match content { + LlmMessageContent::Text(text) => { + if !text.is_empty() { + if !text_content.is_empty() { + text_content.push('\n'); + } + text_content.push_str(text); + } + } + LlmMessageContent::ToolUse(tool_use) => { + tool_calls.push(OpenAiToolCall { + id: tool_use.id.clone(), + call_type: "function".to_string(), + function: OpenAiFunctionCall { + name: tool_use.name.clone(), + arguments: tool_use.input.clone(), + }, + }); + } + _ => {} + } + } + + messages.push(OpenAiMessage { + role: "assistant".to_string(), + content: if text_content.is_empty() { + None + } else { + Some(OpenAiContent::Text(text_content)) + }, + tool_calls: if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }, + tool_call_id: None, + }); + } + } + } + + let tools: Vec = request + .tools + .iter() + .map(|t| OpenAiTool { + tool_type: "function".to_string(), + function: OpenAiFunctionDef { + name: t.name.clone(), + description: t.description.clone(), + parameters: serde_json::from_str(&t.input_schema) + .unwrap_or(serde_json::Value::Object(Default::default())), + }, + }) + .collect(); + + let tool_choice = request.tool_choice.as_ref().map(|tc| match tc { + LlmToolChoice::Auto => "auto".to_string(), + LlmToolChoice::Any => "required".to_string(), + LlmToolChoice::None => "none".to_string(), + }); + + let max_tokens = request.max_tokens; + + Ok(OpenAiRequest { + model: model_id.to_string(), + messages, + max_tokens, + tools, + tool_choice, + stop: request.stop_sequences.clone(), + temperature: request.temperature, + stream: true, + stream_options: Some(StreamOptions { + include_usage: true, + }), + }) +} + +fn parse_sse_line(line: &str) -> Option { + let data = line.strip_prefix("data: 
")?; + if data.trim() == "[DONE]" { + return None; + } + serde_json::from_str(data).ok() +} + +impl zed::Extension for CopilotChatProvider { + fn new() -> Self { + Self { + streams: Mutex::new(HashMap::new()), + next_stream_id: Mutex::new(0), + device_flow_state: Mutex::new(None), + api_token: Mutex::new(None), + cached_models: Mutex::new(None), + } + } + + fn llm_providers(&self) -> Vec { + vec![LlmProviderInfo { + id: "copilot-chat".into(), + name: "Copilot Chat".into(), + icon: Some("icons/copilot.svg".into()), + }] + } + + fn llm_provider_models(&self, _provider_id: &str) -> Result, String> { + // Try to get models from cache first + if let Some(models) = self.cached_models.lock().unwrap().as_ref() { + return Ok(convert_models_to_llm_info(models)); + } + + // Need to fetch models - requires authentication + let oauth_token = match llm_get_credential("copilot-chat") { + Some(token) => token, + None => return Ok(Vec::new()), // Not authenticated, return empty + }; + + // Get API token + let api_token = self.get_api_token(&oauth_token)?; + + // Fetch models from API + let models = self.fetch_models(&api_token)?; + + // Cache the models + *self.cached_models.lock().unwrap() = Some(models.clone()); + + Ok(convert_models_to_llm_info(&models)) + } + + fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool { + llm_get_credential("copilot-chat").is_some() + } + + fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option { + Some( + r#"# Copilot Chat Setup + +Welcome to **Copilot Chat**! This extension provides access to GitHub Copilot's chat models. + +## Authentication + +Click **Sign in with GitHub** to authenticate with your GitHub account. You'll be redirected to GitHub to authorize access. This requires an active GitHub Copilot subscription. + +Alternatively, you can set the `GH_COPILOT_TOKEN` environment variable with your token. 
+ +## Available Models + +| Model | Context | Output | +|-------|---------|--------| +| GPT-4o | 128K | 16K | +| GPT-4o Mini | 128K | 16K | +| GPT-4.1 | 1M | 32K | +| o1 | 200K | 100K | +| o3-mini | 200K | 100K | +| Claude 3.5 Sonnet | 200K | 8K | +| Claude 3.7 Sonnet | 200K | 8K | +| Gemini 2.0 Flash | 1M | 8K | + +## Features + +- ✅ Full streaming support +- ✅ Tool/function calling +- ✅ Vision (image inputs) +- ✅ Multiple model providers via Copilot + +## Note + +This extension requires an active GitHub Copilot subscription. +"# + .to_string(), + ) + } + + fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> { + // Check if we have existing credentials + if llm_get_credential("copilot-chat").is_some() { + return Ok(()); + } + + // No credentials found - return error for background auth checks. + // The device flow will be triggered by the host when the user clicks + // the "Sign in with GitHub" button, which calls llm_provider_start_device_flow_sign_in. + Err("CredentialsNotFound".to_string()) + } + + fn llm_provider_start_device_flow_sign_in( + &mut self, + _provider_id: &str, + ) -> Result { + // Step 1: Request device and user verification codes + let device_code_response = llm_oauth_http_request(&LlmOauthHttpRequest { + url: GITHUB_DEVICE_CODE_URL.to_string(), + method: "POST".to_string(), + headers: vec![ + ("Accept".to_string(), "application/json".to_string()), + ( + "Content-Type".to_string(), + "application/x-www-form-urlencoded".to_string(), + ), + ], + body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID), + })?; + + if device_code_response.status != 200 { + return Err(format!( + "Failed to get device code: HTTP {}", + device_code_response.status + )); + } + + #[derive(Deserialize)] + struct DeviceCodeResponse { + device_code: String, + user_code: String, + verification_uri: String, + #[serde(default)] + verification_uri_complete: Option, + expires_in: u64, + interval: u64, + } + + let device_info: 
DeviceCodeResponse = serde_json::from_str(&device_code_response.body) + .map_err(|e| format!("Failed to parse device code response: {}", e))?; + + // Store device flow state for polling + *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState { + device_code: device_info.device_code, + interval: device_info.interval, + expires_in: device_info.expires_in, + }); + + // Step 2: Open browser to verification URL + // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL + let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| { + format!( + "{}?user_code={}", + device_info.verification_uri, &device_info.user_code + ) + }); + llm_oauth_open_browser(&verification_url)?; + + // Return the user code for the host to display + Ok(device_info.user_code) + } + + fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> { + let state = self + .device_flow_state + .lock() + .unwrap() + .take() + .ok_or("No device flow in progress")?; + + let poll_interval = Duration::from_secs(state.interval.max(5)); + let max_attempts = (state.expires_in / state.interval.max(5)) as usize; + + for _ in 0..max_attempts { + thread::sleep(poll_interval); + + let token_response = llm_oauth_http_request(&LlmOauthHttpRequest { + url: GITHUB_ACCESS_TOKEN_URL.to_string(), + method: "POST".to_string(), + headers: vec![ + ("Accept".to_string(), "application/json".to_string()), + ( + "Content-Type".to_string(), + "application/x-www-form-urlencoded".to_string(), + ), + ], + body: format!( + "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code", + GITHUB_COPILOT_CLIENT_ID, state.device_code + ), + })?; + + #[derive(Deserialize)] + struct TokenResponse { + access_token: Option, + error: Option, + error_description: Option, + } + + let token_json: TokenResponse = serde_json::from_str(&token_response.body) + .map_err(|e| format!("Failed to parse token response: {}", e))?; + + if 
let Some(access_token) = token_json.access_token { + llm_store_credential("copilot-chat", &access_token)?; + return Ok(()); + } + + if let Some(error) = &token_json.error { + match error.as_str() { + "authorization_pending" => { + // User hasn't authorized yet, keep polling + continue; + } + "slow_down" => { + // Need to slow down polling + thread::sleep(Duration::from_secs(5)); + continue; + } + "expired_token" => { + return Err("Device code expired. Please try again.".to_string()); + } + "access_denied" => { + return Err("Authorization was denied.".to_string()); + } + _ => { + let description = token_json.error_description.unwrap_or_default(); + return Err(format!("OAuth error: {} - {}", error, description)); + } + } + } + } + + Err("Authorization timed out. Please try again.".to_string()) + } + + fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> { + // Clear cached API token and models + *self.api_token.lock().unwrap() = None; + *self.cached_models.lock().unwrap() = None; + llm_delete_credential("copilot-chat") + } + + fn llm_stream_completion_start( + &mut self, + _provider_id: &str, + model_id: &str, + request: &LlmCompletionRequest, + ) -> Result { + let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| { + "No token configured. 
Please add your GitHub Copilot token in settings.".to_string() + })?; + + // Get or refresh API token + let api_token = self.get_api_token(&oauth_token)?; + + let openai_request = convert_request(model_id, request)?; + + let body = serde_json::to_vec(&openai_request) + .map_err(|e| format!("Failed to serialize request: {}", e))?; + + let completions_url = format!("{}/chat/completions", api_token.api_endpoint); + + let http_request = HttpRequest { + method: HttpMethod::Post, + url: completions_url, + headers: vec![ + ("Content-Type".to_string(), "application/json".to_string()), + ( + "Authorization".to_string(), + format!("Bearer {}", api_token.api_key), + ), + ( + "Copilot-Integration-Id".to_string(), + "vscode-chat".to_string(), + ), + ("Editor-Version".to_string(), "Zed/1.0.0".to_string()), + ], + body: Some(body), + redirect_policy: RedirectPolicy::FollowAll, + }; + + let response_stream = http_request + .fetch_stream() + .map_err(|e| format!("HTTP request failed: {}", e))?; + + let stream_id = { + let mut id_counter = self.next_stream_id.lock().unwrap(); + let id = format!("copilot-stream-{}", *id_counter); + *id_counter += 1; + id + }; + + self.streams.lock().unwrap().insert( + stream_id.clone(), + StreamState { + response_stream: Some(response_stream), + buffer: String::new(), + started: false, + tool_calls: HashMap::new(), + tool_calls_emitted: false, + }, + ); + + Ok(stream_id) + } + + fn llm_stream_completion_next( + &mut self, + stream_id: &str, + ) -> Result, String> { + let mut streams = self.streams.lock().unwrap(); + let state = streams + .get_mut(stream_id) + .ok_or_else(|| format!("Unknown stream: {}", stream_id))?; + + if !state.started { + state.started = true; + return Ok(Some(LlmCompletionEvent::Started)); + } + + let response_stream = state + .response_stream + .as_mut() + .ok_or_else(|| "Stream already closed".to_string())?; + + loop { + if let Some(newline_pos) = state.buffer.find('\n') { + let line = state.buffer[..newline_pos].to_string(); 
+ state.buffer = state.buffer[newline_pos + 1..].to_string(); + + if line.trim().is_empty() { + continue; + } + + if let Some(response) = parse_sse_line(&line) { + if let Some(choice) = response.choices.first() { + if let Some(content) = &choice.delta.content { + if !content.is_empty() { + return Ok(Some(LlmCompletionEvent::Text(content.clone()))); + } + } + + if let Some(tool_calls) = &choice.delta.tool_calls { + for tc in tool_calls { + let entry = state + .tool_calls + .entry(tc.index) + .or_insert_with(AccumulatedToolCall::default); + + if let Some(id) = &tc.id { + entry.id = id.clone(); + } + if let Some(func) = &tc.function { + if let Some(name) = &func.name { + entry.name = name.clone(); + } + if let Some(args) = &func.arguments { + entry.arguments.push_str(args); + } + } + } + } + + if let Some(finish_reason) = &choice.finish_reason { + if !state.tool_calls.is_empty() && !state.tool_calls_emitted { + state.tool_calls_emitted = true; + let mut tool_calls: Vec<_> = state.tool_calls.drain().collect(); + tool_calls.sort_by_key(|(idx, _)| *idx); + + if let Some((_, tc)) = tool_calls.into_iter().next() { + return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse { + id: tc.id, + name: tc.name, + input: tc.arguments, + thought_signature: None, + }))); + } + } + + let stop_reason = match finish_reason.as_str() { + "stop" => LlmStopReason::EndTurn, + "length" => LlmStopReason::MaxTokens, + "tool_calls" => LlmStopReason::ToolUse, + "content_filter" => LlmStopReason::Refusal, + _ => LlmStopReason::EndTurn, + }; + return Ok(Some(LlmCompletionEvent::Stop(stop_reason))); + } + } + + if let Some(usage) = response.usage { + return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage { + input_tokens: usage.prompt_tokens, + output_tokens: usage.completion_tokens, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }))); + } + } + + continue; + } + + match response_stream.next_chunk() { + Ok(Some(chunk)) => { + let text = String::from_utf8_lossy(&chunk); + 
state.buffer.push_str(&text); + } + Ok(None) => { + return Ok(None); + } + Err(e) => { + return Err(format!("Stream error: {}", e)); + } + } + } + } + + fn llm_stream_completion_close(&mut self, stream_id: &str) { + self.streams.lock().unwrap().remove(stream_id); + } +} + +impl CopilotChatProvider { + fn get_api_token(&self, oauth_token: &str) -> Result { + // Check if we have a cached token + if let Some(token) = self.api_token.lock().unwrap().clone() { + return Ok(token); + } + + // Request a new API token + let http_request = HttpRequest { + method: HttpMethod::Get, + url: GITHUB_COPILOT_TOKEN_URL.to_string(), + headers: vec![ + ( + "Authorization".to_string(), + format!("token {}", oauth_token), + ), + ("Accept".to_string(), "application/json".to_string()), + ], + body: None, + redirect_policy: RedirectPolicy::FollowAll, + }; + + let response = http_request + .fetch() + .map_err(|e| format!("Failed to request API token: {}", e))?; + + #[derive(Deserialize)] + struct ApiTokenResponse { + token: String, + endpoints: ApiEndpoints, + } + + #[derive(Deserialize)] + struct ApiEndpoints { + api: String, + } + + let token_response: ApiTokenResponse = + serde_json::from_slice(&response.body).map_err(|e| { + format!( + "Failed to parse API token response: {} - body: {}", + e, + String::from_utf8_lossy(&response.body) + ) + })?; + + let api_token = ApiToken { + api_key: token_response.token, + api_endpoint: token_response.endpoints.api, + }; + + // Cache the token + *self.api_token.lock().unwrap() = Some(api_token.clone()); + + Ok(api_token) + } + + fn fetch_models(&self, api_token: &ApiToken) -> Result, String> { + let models_url = format!("{}/models", api_token.api_endpoint); + + let http_request = HttpRequest { + method: HttpMethod::Get, + url: models_url, + headers: vec![ + ( + "Authorization".to_string(), + format!("Bearer {}", api_token.api_key), + ), + ("Content-Type".to_string(), "application/json".to_string()), + ( + "Copilot-Integration-Id".to_string(), + 
"vscode-chat".to_string(), + ), + ("Editor-Version".to_string(), "Zed/1.0.0".to_string()), + ("x-github-api-version".to_string(), "2025-05-01".to_string()), + ], + body: None, + redirect_policy: RedirectPolicy::FollowAll, + }; + + let response = http_request + .fetch() + .map_err(|e| format!("Failed to fetch models: {}", e))?; + + #[derive(Deserialize)] + struct ModelsResponse { + data: Vec, + } + + let models_response: ModelsResponse = + serde_json::from_slice(&response.body).map_err(|e| { + format!( + "Failed to parse models response: {} - body: {}", + e, + String::from_utf8_lossy(&response.body) + ) + })?; + + // Filter models like the built-in Copilot Chat does + let mut models: Vec = models_response + .data + .into_iter() + .filter(|model| { + model.model_picker_enabled + && model.capabilities.model_type == "chat" + && model + .policy + .as_ref() + .map(|p| p.state == "enabled") + .unwrap_or(true) + }) + .collect(); + + // Sort so default model is first + if let Some(pos) = models.iter().position(|m| m.is_chat_default) { + let default_model = models.remove(pos); + models.insert(0, default_model); + } + + Ok(models) + } +} + +fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec { + models + .iter() + .map(|m| { + let max_tokens = if m.capabilities.limits.max_context_window_tokens > 0 { + m.capabilities.limits.max_context_window_tokens + } else { + 128_000 // Default fallback + }; + let max_output = if m.capabilities.limits.max_output_tokens > 0 { + Some(m.capabilities.limits.max_output_tokens) + } else { + None + }; + + LlmModelInfo { + id: m.id.clone(), + name: m.name.clone(), + max_token_count: max_tokens, + max_output_tokens: max_output, + capabilities: LlmModelCapabilities { + supports_images: m.capabilities.supports.vision, + supports_tools: m.capabilities.supports.tool_calls, + supports_tool_choice_auto: m.capabilities.supports.tool_calls, + supports_tool_choice_any: m.capabilities.supports.tool_calls, + supports_tool_choice_none: 
m.capabilities.supports.tool_calls, + supports_thinking: false, + tool_input_format: LlmToolInputFormat::JsonSchema, + }, + is_default: m.is_chat_default, + is_default_fast: m.is_chat_fallback, + } + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_device_flow_request_body() { + let body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID); + assert!(body.contains("client_id=Iv1.b507a08c87ecfe98")); + assert!(body.contains("scope=read:user")); + } + + #[test] + fn test_token_poll_request_body() { + let device_code = "test_device_code_123"; + let body = format!( + "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code", + GITHUB_COPILOT_CLIENT_ID, device_code + ); + assert!(body.contains("client_id=Iv1.b507a08c87ecfe98")); + assert!(body.contains("device_code=test_device_code_123")); + assert!(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code")); + } +} + +zed::register_extension!(CopilotChatProvider); diff --git a/extensions/google-ai/Cargo.lock b/extensions/google-ai/Cargo.lock new file mode 100644 index 00000000000000..2389ff51da0c24 --- /dev/null +++ b/extensions/google-ai/Cargo.lock @@ -0,0 +1,823 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "anyhow" +version = "1.0.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" + +[[package]] +name = "auditable-serde" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7bf8143dfc3c0258df908843e169b5cc5fcf76c7718bd66135ef4a9cd558c5" +dependencies = [ + "semver", + "serde", + "serde_json", + "topological-sort", +] + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "flate2" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" +dependencies = [ + 
"crc32fast", + "miniz_oxide", +] + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "foogle" +version = "0.1.0" +dependencies = [ + "serde", + "serde_json", + "zed_extension_api", +] + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "icu_collections" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.1.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" + +[[package]] +name = "icu_properties" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899" + +[[package]] +name = "icu_provider" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "litemap" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "potential_utf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +dependencies = [ + "zerovec", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "semver" +version = "1.0.27" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +dependencies = [ + "serde", + "serde_core", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.145" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", + "serde_core", +] + +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "spdx" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c3e17e880bafaeb362a7b751ec46bdc5b61445a188f80e0606e68167cd540fa3" +dependencies = [ + "smallvec", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "syn" +version = "2.0.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinystr" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "topological-sort" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea68304e134ecd095ac6c3574494fc62b909f416c4fca77e440530221e549d3d" + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "url" +version = "2.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = 
"utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "wasm-encoder" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80bb72f02e7fbf07183443b27b0f3d4144abf8c114189f2e088ed95b696a7822" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce1ef0faabbbba6674e97a56bee857ccddf942785a336c8b47b42373c922a91d" +dependencies = [ + "anyhow", + "auditable-serde", + "flate2", + "indexmap", + "serde", + "serde_derive", + "serde_json", + "spdx", + "url", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f51cad774fb3c9461ab9bccc9c62dfb7388397b5deda31bf40e8108ccd678b2" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "wit-bindgen" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10fb6648689b3929d56bbc7eb1acf70c9a42a29eb5358c67c10f54dbd5d695de" +dependencies = [ + "wit-bindgen-rt", + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92fa781d4f2ff6d3f27f3cc9b74a73327b31ca0dc4a3ef25a0ce2983e0e5af9b" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rt" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db52a11d4dfb0a59f194c064055794ee6564eb1ced88c25da2cf76e50c5621" +dependencies = [ + "bitflags", + "futures", + "once_cell", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.41.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d0809dc5ba19e2e98661bf32fc0addc5a3ca5bf3a6a7083aa6ba484085ff3ce" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad19eec017904e04c60719592a803ee5da76cb51c81e3f6fbf9457f59db49799" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "635c3adc595422cbf2341a17fb73a319669cc8d33deed3a48368a841df86b676" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddf445ed5157046e4baf56f9138c124a0824d4d1657e7204d71886ad8ce2fc11" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "yoke" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zed_extension_api" +version = "0.7.0" +dependencies = [ + "serde", + "serde_json", + "wit-bindgen", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerotrie" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/extensions/google-ai/Cargo.toml b/extensions/google-ai/Cargo.toml new file mode 100644 index 00000000000000..f6de35d4066938 --- /dev/null +++ b/extensions/google-ai/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "google-ai" +version = "0.1.0" +edition = "2021" +publish = false +license = "Apache-2.0" + +[workspace] + +[lib] +path = "src/google_ai.rs" +crate-type = ["cdylib"] + +[dependencies] 
+zed_extension_api = { path = "../../crates/extension_api" } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" diff --git a/extensions/google-ai/extension.toml b/extensions/google-ai/extension.toml new file mode 100644 index 00000000000000..1b1cb382a7835d --- /dev/null +++ b/extensions/google-ai/extension.toml @@ -0,0 +1,13 @@ +id = "google-ai" +name = "Google AI" +description = "Google Gemini LLM provider for Zed." +version = "0.1.0" +schema_version = 1 +authors = ["Zed Team"] +repository = "https://github.com/zed-industries/zed" + +[language_model_providers.google-ai] +name = "Google AI" + +[language_model_providers.google-ai.auth] +env_var = "GEMINI_API_KEY" \ No newline at end of file diff --git a/extensions/google-ai/icons/google-ai.svg b/extensions/google-ai/icons/google-ai.svg new file mode 100644 index 00000000000000..bdde44ed247531 --- /dev/null +++ b/extensions/google-ai/icons/google-ai.svg @@ -0,0 +1,3 @@ + + + diff --git a/extensions/google-ai/src/google_ai.rs b/extensions/google-ai/src/google_ai.rs new file mode 100644 index 00000000000000..61baca80b19d9d --- /dev/null +++ b/extensions/google-ai/src/google_ai.rs @@ -0,0 +1,840 @@ +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Mutex; + +use serde::{Deserialize, Serialize}; +use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy}; +use zed_extension_api::{self as zed, *}; + +static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0); + +struct GoogleAiProvider { + streams: Mutex>, + next_stream_id: Mutex, +} + +struct StreamState { + response_stream: Option, + buffer: String, + started: bool, + stop_reason: Option, + wants_tool_use: bool, +} + +struct ModelDefinition { + real_id: &'static str, + display_name: &'static str, + max_tokens: u64, + max_output_tokens: Option, + supports_images: bool, + supports_thinking: bool, + is_default: bool, + is_default_fast: bool, +} + +const MODELS: 
&[ModelDefinition] = &[ + ModelDefinition { + real_id: "gemini-2.5-flash-lite", + display_name: "Gemini 2.5 Flash-Lite", + max_tokens: 1_048_576, + max_output_tokens: Some(65_536), + supports_images: true, + supports_thinking: true, + is_default: false, + is_default_fast: true, + }, + ModelDefinition { + real_id: "gemini-2.5-flash", + display_name: "Gemini 2.5 Flash", + max_tokens: 1_048_576, + max_output_tokens: Some(65_536), + supports_images: true, + supports_thinking: true, + is_default: true, + is_default_fast: false, + }, + ModelDefinition { + real_id: "gemini-2.5-pro", + display_name: "Gemini 2.5 Pro", + max_tokens: 1_048_576, + max_output_tokens: Some(65_536), + supports_images: true, + supports_thinking: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "gemini-3-pro-preview", + display_name: "Gemini 3 Pro", + max_tokens: 1_048_576, + max_output_tokens: Some(65_536), + supports_images: true, + supports_thinking: true, + is_default: false, + is_default_fast: false, + }, +]; + +fn get_real_model_id(display_name: &str) -> Option<&'static str> { + MODELS + .iter() + .find(|m| m.display_name == display_name) + .map(|m| m.real_id) +} + +fn get_model_supports_thinking(display_name: &str) -> bool { + MODELS + .iter() + .find(|m| m.display_name == display_name) + .map(|m| m.supports_thinking) + .unwrap_or(false) +} + +/// Adapts a JSON schema to be compatible with Google's API subset. +/// Google only supports a specific subset of JSON Schema fields. 
+/// See: https://ai.google.dev/api/caching#Schema +fn adapt_schema_for_google(json: &mut serde_json::Value) { + adapt_schema_for_google_impl(json, true); +} + +fn adapt_schema_for_google_impl(json: &mut serde_json::Value, is_schema: bool) { + if let serde_json::Value::Object(obj) = json { + // Google's Schema only supports these fields: + // type, format, title, description, nullable, enum, maxItems, minItems, + // properties, required, minProperties, maxProperties, minLength, maxLength, + // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum + const ALLOWED_KEYS: &[&str] = &[ + "type", + "format", + "title", + "description", + "nullable", + "enum", + "maxItems", + "minItems", + "properties", + "required", + "minProperties", + "maxProperties", + "minLength", + "maxLength", + "pattern", + "example", + "anyOf", + "propertyOrdering", + "default", + "items", + "minimum", + "maximum", + ]; + + // Convert oneOf to anyOf before filtering keys + if let Some(one_of) = obj.remove("oneOf") { + obj.insert("anyOf".to_string(), one_of); + } + + // If type is an array (e.g., ["string", "null"]), take just the first type + if let Some(type_field) = obj.get_mut("type") { + if let serde_json::Value::Array(types) = type_field { + if let Some(first_type) = types.first().cloned() { + *type_field = first_type; + } + } + } + + // Only filter keys if this is a schema object, not a properties map + if is_schema { + obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str())); + } + + // Recursively process nested values + // "properties" contains a map of property names -> schemas + // "items" and "anyOf" contain schemas directly + for (key, value) in obj.iter_mut() { + if key == "properties" { + // properties is a map of property_name -> schema + if let serde_json::Value::Object(props) = value { + for (_, prop_schema) in props.iter_mut() { + adapt_schema_for_google_impl(prop_schema, true); + } + } + } else if key == "items" { + // items is a schema + 
adapt_schema_for_google_impl(value, true); + } else if key == "anyOf" { + // anyOf is an array of schemas + if let serde_json::Value::Array(arr) = value { + for item in arr.iter_mut() { + adapt_schema_for_google_impl(item, true); + } + } + } + } + } else if let serde_json::Value::Array(arr) = json { + for item in arr.iter_mut() { + adapt_schema_for_google_impl(item, true); + } + } +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct GoogleRequest { + contents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + system_instruction: Option, + #[serde(skip_serializing_if = "Option::is_none")] + generation_config: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_config: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct GoogleSystemInstruction { + parts: Vec, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +struct GoogleContent { + parts: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + role: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +enum GooglePart { + Text(GoogleTextPart), + InlineData(GoogleInlineDataPart), + FunctionCall(GoogleFunctionCallPart), + FunctionResponse(GoogleFunctionResponsePart), + Thought(GoogleThoughtPart), +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +struct GoogleTextPart { + text: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +struct GoogleInlineDataPart { + inline_data: GoogleBlob, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +struct GoogleBlob { + mime_type: String, + data: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +struct GoogleFunctionCallPart { + function_call: GoogleFunctionCall, + #[serde(skip_serializing_if = 
"Option::is_none")] + thought_signature: Option, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +struct GoogleFunctionCall { + name: String, + args: serde_json::Value, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +struct GoogleFunctionResponsePart { + function_response: GoogleFunctionResponse, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +struct GoogleFunctionResponse { + name: String, + response: serde_json::Value, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +struct GoogleThoughtPart { + thought: bool, + thought_signature: String, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct GoogleGenerationConfig { + #[serde(skip_serializing_if = "Option::is_none")] + candidate_count: Option, + #[serde(skip_serializing_if = "Option::is_none")] + stop_sequences: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + max_output_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + thinking_config: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct GoogleThinkingConfig { + thinking_budget: u32, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct GoogleTool { + function_declarations: Vec, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct GoogleFunctionDeclaration { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct GoogleToolConfig { + function_calling_config: GoogleFunctionCallingConfig, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct GoogleFunctionCallingConfig { + mode: String, + #[serde(skip_serializing_if = "Option::is_none")] + allowed_function_names: Option>, +} + +#[derive(Deserialize, Debug)] 
+#[serde(rename_all = "camelCase")] +struct GoogleStreamResponse { + #[serde(default)] + candidates: Vec, + #[serde(default)] + usage_metadata: Option, +} + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +struct GoogleCandidate { + #[serde(default)] + content: Option, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +struct GoogleUsageMetadata { + #[serde(default)] + prompt_token_count: u64, + #[serde(default)] + candidates_token_count: u64, +} + +fn convert_request( + model_id: &str, + request: &LlmCompletionRequest, +) -> Result<(GoogleRequest, String), String> { + let real_model_id = + get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?; + + let supports_thinking = get_model_supports_thinking(model_id); + + let mut contents: Vec = Vec::new(); + let mut system_parts: Vec = Vec::new(); + + for msg in &request.messages { + match msg.role { + LlmMessageRole::System => { + for content in &msg.content { + if let LlmMessageContent::Text(text) = content { + if !text.is_empty() { + system_parts + .push(GooglePart::Text(GoogleTextPart { text: text.clone() })); + } + } + } + } + LlmMessageRole::User => { + let mut parts: Vec = Vec::new(); + + for content in &msg.content { + match content { + LlmMessageContent::Text(text) => { + if !text.is_empty() { + parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() })); + } + } + LlmMessageContent::Image(img) => { + parts.push(GooglePart::InlineData(GoogleInlineDataPart { + inline_data: GoogleBlob { + mime_type: "image/png".to_string(), + data: img.source.clone(), + }, + })); + } + LlmMessageContent::ToolResult(result) => { + let response_value = match &result.content { + LlmToolResultContent::Text(t) => { + serde_json::json!({ "output": t }) + } + LlmToolResultContent::Image(_) => { + serde_json::json!({ "output": "Tool responded with an image" }) + } + }; + 
parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart { + function_response: GoogleFunctionResponse { + name: result.tool_name.clone(), + response: response_value, + }, + })); + } + _ => {} + } + } + + if !parts.is_empty() { + contents.push(GoogleContent { + parts, + role: Some("user".to_string()), + }); + } + } + LlmMessageRole::Assistant => { + let mut parts: Vec = Vec::new(); + + for content in &msg.content { + match content { + LlmMessageContent::Text(text) => { + if !text.is_empty() { + parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() })); + } + } + LlmMessageContent::ToolUse(tool_use) => { + let thought_signature = + tool_use.thought_signature.clone().filter(|s| !s.is_empty()); + + let args: serde_json::Value = + serde_json::from_str(&tool_use.input).unwrap_or_default(); + + parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart { + function_call: GoogleFunctionCall { + name: tool_use.name.clone(), + args, + }, + thought_signature, + })); + } + LlmMessageContent::Thinking(thinking) => { + if let Some(ref signature) = thinking.signature { + if !signature.is_empty() { + parts.push(GooglePart::Thought(GoogleThoughtPart { + thought: true, + thought_signature: signature.clone(), + })); + } + } + } + _ => {} + } + } + + if !parts.is_empty() { + contents.push(GoogleContent { + parts, + role: Some("model".to_string()), + }); + } + } + } + } + + let system_instruction = if system_parts.is_empty() { + None + } else { + Some(GoogleSystemInstruction { + parts: system_parts, + }) + }; + + let tools: Option> = if request.tools.is_empty() { + None + } else { + let declarations: Vec = request + .tools + .iter() + .map(|t| { + let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema) + .unwrap_or(serde_json::Value::Object(Default::default())); + adapt_schema_for_google(&mut parameters); + GoogleFunctionDeclaration { + name: t.name.clone(), + description: t.description.clone(), + parameters, + } + }) + .collect(); + 
Some(vec![GoogleTool { + function_declarations: declarations, + }]) + }; + + let tool_config = request.tool_choice.as_ref().map(|tc| { + let mode = match tc { + LlmToolChoice::Auto => "AUTO", + LlmToolChoice::Any => "ANY", + LlmToolChoice::None => "NONE", + }; + GoogleToolConfig { + function_calling_config: GoogleFunctionCallingConfig { + mode: mode.to_string(), + allowed_function_names: None, + }, + } + }); + + let thinking_config = if supports_thinking && request.thinking_allowed { + Some(GoogleThinkingConfig { + thinking_budget: 8192, + }) + } else { + None + }; + + let generation_config = Some(GoogleGenerationConfig { + candidate_count: Some(1), + stop_sequences: if request.stop_sequences.is_empty() { + None + } else { + Some(request.stop_sequences.clone()) + }, + max_output_tokens: None, + temperature: request.temperature.map(|t| t as f64).or(Some(1.0)), + thinking_config, + }); + + Ok(( + GoogleRequest { + contents, + system_instruction, + generation_config, + tools, + tool_config, + }, + real_model_id.to_string(), + )) +} + +fn parse_stream_line(line: &str) -> Option { + let trimmed = line.trim(); + if trimmed.is_empty() || trimmed == "[" || trimmed == "]" || trimmed == "," { + return None; + } + + let json_str = trimmed.strip_prefix("data: ").unwrap_or(trimmed); + let json_str = json_str.trim_start_matches(',').trim(); + + if json_str.is_empty() { + return None; + } + + serde_json::from_str(json_str).ok() +} + +impl zed::Extension for GoogleAiProvider { + fn new() -> Self { + Self { + streams: Mutex::new(HashMap::new()), + next_stream_id: Mutex::new(0), + } + } + + fn llm_providers(&self) -> Vec { + vec![LlmProviderInfo { + id: "google-ai".into(), + name: "Google AI".into(), + icon: Some("icons/google-ai.svg".into()), + }] + } + + fn llm_provider_models(&self, _provider_id: &str) -> Result, String> { + Ok(MODELS + .iter() + .map(|m| LlmModelInfo { + id: m.display_name.to_string(), + name: m.display_name.to_string(), + max_token_count: m.max_tokens, + 
max_output_tokens: m.max_output_tokens, + capabilities: LlmModelCapabilities { + supports_images: m.supports_images, + supports_tools: true, + supports_tool_choice_auto: true, + supports_tool_choice_any: true, + supports_tool_choice_none: true, + supports_thinking: m.supports_thinking, + tool_input_format: LlmToolInputFormat::JsonSchema, + }, + is_default: m.is_default, + is_default_fast: m.is_default_fast, + }) + .collect()) + } + + fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool { + llm_get_credential("google-ai").is_some() + } + + fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option { + Some( + r#"# Google AI Setup + +Welcome to **Google AI**! This extension provides access to Google Gemini models. + +## Configuration + +Enter your Google AI API key below. You can get your API key at [aistudio.google.com/apikey](https://aistudio.google.com/apikey). + +## Available Models + +| Display Name | Real Model | Context | Output | +|--------------|------------|---------|--------| +| Gemini 2.5 Flash-Lite | gemini-2.5-flash-lite | 1M | 65K | +| Gemini 2.5 Flash | gemini-2.5-flash | 1M | 65K | +| Gemini 2.5 Pro | gemini-2.5-pro | 1M | 65K | +| Gemini 3 Pro | gemini-3-pro-preview | 1M | 65K | + +## Features + +- ✅ Full streaming support +- ✅ Tool/function calling with thought signatures +- ✅ Vision (image inputs) +- ✅ Extended thinking support +- ✅ All Gemini models + +## Pricing + +Uses your Google AI API credits. See [Google AI pricing](https://ai.google.dev/pricing) for details. 
+"# + .to_string(), + ) + } + + fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> { + let provided = llm_request_credential( + "google-ai", + LlmCredentialType::ApiKey, + "Google AI API Key", + "AIza...", + )?; + if provided { + Ok(()) + } else { + Err("Authentication cancelled".to_string()) + } + } + + fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> { + llm_delete_credential("google-ai") + } + + fn llm_stream_completion_start( + &mut self, + _provider_id: &str, + model_id: &str, + request: &LlmCompletionRequest, + ) -> Result { + let api_key = llm_get_credential("google-ai").ok_or_else(|| { + "No API key configured. Please add your Google AI API key in settings.".to_string() + })?; + + let (google_request, real_model_id) = convert_request(model_id, request)?; + + let body = serde_json::to_vec(&google_request) + .map_err(|e| format!("Failed to serialize request: {}", e))?; + + let url = format!( + "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}", + real_model_id, api_key + ); + + let http_request = HttpRequest { + method: HttpMethod::Post, + url, + headers: vec![("Content-Type".to_string(), "application/json".to_string())], + body: Some(body), + redirect_policy: RedirectPolicy::FollowAll, + }; + + let response_stream = http_request + .fetch_stream() + .map_err(|e| format!("HTTP request failed: {}", e))?; + + let stream_id = { + let mut id_counter = self.next_stream_id.lock().unwrap(); + let id = format!("google-ai-stream-{}", *id_counter); + *id_counter += 1; + id + }; + + self.streams.lock().unwrap().insert( + stream_id.clone(), + StreamState { + response_stream: Some(response_stream), + buffer: String::new(), + started: false, + stop_reason: None, + wants_tool_use: false, + }, + ); + + Ok(stream_id) + } + + fn llm_stream_completion_next( + &mut self, + stream_id: &str, + ) -> Result, String> { + let mut streams = self.streams.lock().unwrap(); + 
let state = streams + .get_mut(stream_id) + .ok_or_else(|| format!("Unknown stream: {}", stream_id))?; + + if !state.started { + state.started = true; + return Ok(Some(LlmCompletionEvent::Started)); + } + + let response_stream = state + .response_stream + .as_mut() + .ok_or_else(|| "Stream already closed".to_string())?; + + loop { + if let Some(newline_pos) = state.buffer.find('\n') { + let line = state.buffer[..newline_pos].to_string(); + state.buffer = state.buffer[newline_pos + 1..].to_string(); + + if let Some(response) = parse_stream_line(&line) { + for candidate in response.candidates { + if let Some(finish_reason) = &candidate.finish_reason { + state.stop_reason = Some(match finish_reason.as_str() { + "STOP" => { + if state.wants_tool_use { + LlmStopReason::ToolUse + } else { + LlmStopReason::EndTurn + } + } + "MAX_TOKENS" => LlmStopReason::MaxTokens, + "SAFETY" => LlmStopReason::Refusal, + _ => LlmStopReason::EndTurn, + }); + } + + if let Some(content) = candidate.content { + for part in content.parts { + match part { + GooglePart::Text(text_part) => { + if !text_part.text.is_empty() { + return Ok(Some(LlmCompletionEvent::Text( + text_part.text, + ))); + } + } + GooglePart::FunctionCall(fc_part) => { + state.wants_tool_use = true; + let next_tool_id = + TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst); + let id = format!( + "{}-{}", + fc_part.function_call.name, next_tool_id + ); + + let thought_signature = + fc_part.thought_signature.filter(|s| !s.is_empty()); + + return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse { + id, + name: fc_part.function_call.name, + input: fc_part.function_call.args.to_string(), + thought_signature, + }))); + } + GooglePart::Thought(thought_part) => { + return Ok(Some(LlmCompletionEvent::Thinking( + LlmThinkingContent { + text: "(Encrypted thought)".to_string(), + signature: Some(thought_part.thought_signature), + }, + ))); + } + _ => {} + } + } + } + } + + if let Some(usage) = response.usage_metadata { + return 
Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage { + input_tokens: usage.prompt_token_count, + output_tokens: usage.candidates_token_count, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }))); + } + } + + continue; + } + + match response_stream.next_chunk() { + Ok(Some(chunk)) => { + let text = String::from_utf8_lossy(&chunk); + state.buffer.push_str(&text); + } + Ok(None) => { + // Stream ended - check if we have a stop reason + if let Some(stop_reason) = state.stop_reason.take() { + return Ok(Some(LlmCompletionEvent::Stop(stop_reason))); + } + + // No stop reason - this is unexpected. Check if buffer contains error info + let mut error_msg = String::from("Stream ended unexpectedly."); + + // Try to parse remaining buffer as potential error response + if !state.buffer.is_empty() { + error_msg.push_str(&format!( + "\nRemaining buffer: {}", + &state.buffer[..state.buffer.len().min(1000)] + )); + } + + return Err(error_msg); + } + Err(e) => { + return Err(format!("Stream error: {}", e)); + } + } + } + } + + fn llm_stream_completion_close(&mut self, stream_id: &str) { + self.streams.lock().unwrap().remove(stream_id); + } +} + +zed::register_extension!(GoogleAiProvider); diff --git a/extensions/open-router/Cargo.lock b/extensions/open-router/Cargo.lock new file mode 100644 index 00000000000000..4dea7c7a8a9cd8 --- /dev/null +++ b/extensions/open-router/Cargo.lock @@ -0,0 +1,823 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "anyhow" +version = "1.0.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" + +[[package]] +name = "auditable-serde" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7bf8143dfc3c0258df908843e169b5cc5fcf76c7718bd66135ef4a9cd558c5" +dependencies = [ + "semver", + "serde", + "serde_json", + "topological-sort", +] + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "flate2" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" +dependencies = [ + 
"crc32fast", + "miniz_oxide", +] + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ 
+ "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "icu_collections" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +dependencies = [ + 
"displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" + +[[package]] +name = "icu_properties" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899" + +[[package]] +name = "icu_provider" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = 
"1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "litemap" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "open_router" +version = "0.1.0" 
+dependencies = [ + "serde", + "serde_json", + "zed_extension_api", +] + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "potential_utf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +dependencies = [ + "zerovec", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +dependencies = [ + "serde", + "serde_core", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.145" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", + "serde_core", +] + +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "spdx" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e17e880bafaeb362a7b751ec46bdc5b61445a188f80e0606e68167cd540fa3" +dependencies = [ + "smallvec", +] + +[[package]] +name = 
"stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "syn" +version = "2.0.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinystr" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "topological-sort" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea68304e134ecd095ac6c3574494fc62b909f416c4fca77e440530221e549d3d" + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "url" +version = "2.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "wasm-encoder" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80bb72f02e7fbf07183443b27b0f3d4144abf8c114189f2e088ed95b696a7822" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce1ef0faabbbba6674e97a56bee857ccddf942785a336c8b47b42373c922a91d" +dependencies = [ + "anyhow", + "auditable-serde", + "flate2", + "indexmap", + "serde", + "serde_derive", + "serde_json", + "spdx", + "url", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f51cad774fb3c9461ab9bccc9c62dfb7388397b5deda31bf40e8108ccd678b2" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "wit-bindgen" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10fb6648689b3929d56bbc7eb1acf70c9a42a29eb5358c67c10f54dbd5d695de" +dependencies = [ + "wit-bindgen-rt", + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92fa781d4f2ff6d3f27f3cc9b74a73327b31ca0dc4a3ef25a0ce2983e0e5af9b" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rt" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db52a11d4dfb0a59f194c064055794ee6564eb1ced88c25da2cf76e50c5621" +dependencies = [ + "bitflags", + "futures", + "once_cell", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9d0809dc5ba19e2e98661bf32fc0addc5a3ca5bf3a6a7083aa6ba484085ff3ce" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad19eec017904e04c60719592a803ee5da76cb51c81e3f6fbf9457f59db49799" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "635c3adc595422cbf2341a17fb73a319669cc8d33deed3a48368a841df86b676" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddf445ed5157046e4baf56f9138c124a0824d4d1657e7204d71886ad8ce2fc11" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "yoke" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +dependencies = [ + 
"proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zed_extension_api" +version = "0.8.0" +dependencies = [ + "serde", + "serde_json", + "wit-bindgen", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerotrie" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/extensions/open-router/Cargo.toml b/extensions/open-router/Cargo.toml new file mode 100644 index 00000000000000..5c5af5ad7ff9e7 --- /dev/null +++ b/extensions/open-router/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "open-router" +version = "0.1.0" +edition = "2021" +publish = false +license = "Apache-2.0" + +[workspace] + +[lib] +path = "src/open_router.rs" +crate-type = ["cdylib"] + +[dependencies] +zed_extension_api = { path = "../../crates/extension_api" } +serde = { version = "1.0", 
features = ["derive"] } +serde_json = "1.0" diff --git a/extensions/open-router/extension.toml b/extensions/open-router/extension.toml new file mode 100644 index 00000000000000..6c1b8c087d016e --- /dev/null +++ b/extensions/open-router/extension.toml @@ -0,0 +1,13 @@ +id = "open-router" +name = "OpenRouter" +description = "OpenRouter LLM provider - access multiple AI models through a unified API." +version = "0.1.0" +schema_version = 1 +authors = ["Zed Team"] +repository = "https://github.com/zed-industries/zed" + +[language_model_providers.open-router] +name = "OpenRouter" + +[language_model_providers.open-router.auth] +env_var = "OPENROUTER_API_KEY" \ No newline at end of file diff --git a/extensions/open-router/icons/open-router.svg b/extensions/open-router/icons/open-router.svg new file mode 100644 index 00000000000000..b6f5164e0b385f --- /dev/null +++ b/extensions/open-router/icons/open-router.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/extensions/open-router/src/open_router.rs b/extensions/open-router/src/open_router.rs new file mode 100644 index 00000000000000..8d8b143cd70a3c --- /dev/null +++ b/extensions/open-router/src/open_router.rs @@ -0,0 +1,830 @@ +use std::collections::HashMap; +use std::sync::Mutex; + +use serde::{Deserialize, Serialize}; +use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy}; +use zed_extension_api::{self as zed, *}; + +struct OpenRouterProvider { + streams: Mutex>, + next_stream_id: Mutex, +} + +struct StreamState { + response_stream: Option, + buffer: String, + started: bool, + tool_calls: HashMap, + tool_calls_emitted: bool, +} + +#[derive(Clone, Default)] +struct AccumulatedToolCall { + id: String, + name: String, + arguments: String, +} + +struct ModelDefinition { + id: &'static str, + display_name: &'static str, + max_tokens: u64, + max_output_tokens: Option, + supports_images: bool, + supports_tools: bool, + is_default: bool, + is_default_fast: bool, +} + +const MODELS: 
&[ModelDefinition] = &[ + // Anthropic Models + ModelDefinition { + id: "anthropic/claude-sonnet-4", + display_name: "Claude Sonnet 4", + max_tokens: 200_000, + max_output_tokens: Some(8_192), + supports_images: true, + supports_tools: true, + is_default: true, + is_default_fast: false, + }, + ModelDefinition { + id: "anthropic/claude-opus-4", + display_name: "Claude Opus 4", + max_tokens: 200_000, + max_output_tokens: Some(8_192), + supports_images: true, + supports_tools: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + id: "anthropic/claude-haiku-4", + display_name: "Claude Haiku 4", + max_tokens: 200_000, + max_output_tokens: Some(8_192), + supports_images: true, + supports_tools: true, + is_default: false, + is_default_fast: true, + }, + ModelDefinition { + id: "anthropic/claude-3.5-sonnet", + display_name: "Claude 3.5 Sonnet", + max_tokens: 200_000, + max_output_tokens: Some(8_192), + supports_images: true, + supports_tools: true, + is_default: false, + is_default_fast: false, + }, + // OpenAI Models + ModelDefinition { + id: "openai/gpt-4o", + display_name: "GPT-4o", + max_tokens: 128_000, + max_output_tokens: Some(16_384), + supports_images: true, + supports_tools: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + id: "openai/gpt-4o-mini", + display_name: "GPT-4o Mini", + max_tokens: 128_000, + max_output_tokens: Some(16_384), + supports_images: true, + supports_tools: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + id: "openai/o1", + display_name: "o1", + max_tokens: 200_000, + max_output_tokens: Some(100_000), + supports_images: true, + supports_tools: false, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + id: "openai/o3-mini", + display_name: "o3-mini", + max_tokens: 200_000, + max_output_tokens: Some(100_000), + supports_images: false, + supports_tools: false, + is_default: false, + is_default_fast: false, + }, + // Google Models + 
ModelDefinition { + id: "google/gemini-2.0-flash-001", + display_name: "Gemini 2.0 Flash", + max_tokens: 1_000_000, + max_output_tokens: Some(8_192), + supports_images: true, + supports_tools: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + id: "google/gemini-2.5-pro-preview", + display_name: "Gemini 2.5 Pro", + max_tokens: 1_000_000, + max_output_tokens: Some(8_192), + supports_images: true, + supports_tools: true, + is_default: false, + is_default_fast: false, + }, + // Meta Models + ModelDefinition { + id: "meta-llama/llama-3.3-70b-instruct", + display_name: "Llama 3.3 70B", + max_tokens: 128_000, + max_output_tokens: Some(4_096), + supports_images: false, + supports_tools: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + id: "meta-llama/llama-4-maverick", + display_name: "Llama 4 Maverick", + max_tokens: 128_000, + max_output_tokens: Some(4_096), + supports_images: true, + supports_tools: true, + is_default: false, + is_default_fast: false, + }, + // Mistral Models + ModelDefinition { + id: "mistralai/mistral-large-2411", + display_name: "Mistral Large", + max_tokens: 128_000, + max_output_tokens: Some(4_096), + supports_images: false, + supports_tools: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + id: "mistralai/codestral-latest", + display_name: "Codestral", + max_tokens: 32_000, + max_output_tokens: Some(4_096), + supports_images: false, + supports_tools: true, + is_default: false, + is_default_fast: false, + }, + // DeepSeek Models + ModelDefinition { + id: "deepseek/deepseek-chat-v3-0324", + display_name: "DeepSeek V3", + max_tokens: 64_000, + max_output_tokens: Some(8_192), + supports_images: false, + supports_tools: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + id: "deepseek/deepseek-r1", + display_name: "DeepSeek R1", + max_tokens: 64_000, + max_output_tokens: Some(8_192), + supports_images: false, + supports_tools: false, + 
is_default: false, + is_default_fast: false, + }, + // Qwen Models + ModelDefinition { + id: "qwen/qwen3-235b-a22b", + display_name: "Qwen 3 235B", + max_tokens: 40_000, + max_output_tokens: Some(8_192), + supports_images: false, + supports_tools: true, + is_default: false, + is_default_fast: false, + }, +]; + +fn get_model_definition(model_id: &str) -> Option<&'static ModelDefinition> { + MODELS.iter().find(|m| m.id == model_id) +} + +#[derive(Serialize)] +struct OpenRouterRequest { + model: String, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + max_tokens: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + tools: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + stop: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + stream: bool, +} + +#[derive(Serialize)] +struct OpenRouterMessage { + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, +} + +#[derive(Serialize, Clone)] +#[serde(untagged)] +enum OpenRouterContent { + Text(String), + Parts(Vec), +} + +#[derive(Serialize, Clone)] +#[serde(tag = "type")] +enum OpenRouterContentPart { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image_url")] + ImageUrl { image_url: ImageUrl }, +} + +#[derive(Serialize, Clone)] +struct ImageUrl { + url: String, +} + +#[derive(Serialize, Clone)] +struct OpenRouterToolCall { + id: String, + #[serde(rename = "type")] + call_type: String, + function: OpenRouterFunctionCall, +} + +#[derive(Serialize, Clone)] +struct OpenRouterFunctionCall { + name: String, + arguments: String, +} + +#[derive(Serialize)] +struct OpenRouterTool { + #[serde(rename = "type")] + tool_type: String, + function: OpenRouterFunctionDef, +} + +#[derive(Serialize)] 
+struct OpenRouterFunctionDef { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Deserialize, Debug)] +struct OpenRouterStreamResponse { + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Deserialize, Debug)] +struct OpenRouterStreamChoice { + delta: OpenRouterDelta, + finish_reason: Option, +} + +#[derive(Deserialize, Debug, Default)] +struct OpenRouterDelta { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + +#[derive(Deserialize, Debug)] +struct OpenRouterToolCallDelta { + index: usize, + #[serde(default)] + id: Option, + #[serde(default)] + function: Option, +} + +#[derive(Deserialize, Debug, Default)] +struct OpenRouterFunctionDelta { + #[serde(default)] + name: Option, + #[serde(default)] + arguments: Option, +} + +#[derive(Deserialize, Debug)] +struct OpenRouterUsage { + prompt_tokens: u64, + completion_tokens: u64, +} + +fn convert_request( + model_id: &str, + request: &LlmCompletionRequest, +) -> Result { + let mut messages: Vec = Vec::new(); + + for msg in &request.messages { + match msg.role { + LlmMessageRole::System => { + let mut text_content = String::new(); + for content in &msg.content { + if let LlmMessageContent::Text(text) = content { + if !text_content.is_empty() { + text_content.push('\n'); + } + text_content.push_str(text); + } + } + if !text_content.is_empty() { + messages.push(OpenRouterMessage { + role: "system".to_string(), + content: Some(OpenRouterContent::Text(text_content)), + tool_calls: None, + tool_call_id: None, + }); + } + } + LlmMessageRole::User => { + let mut parts: Vec = Vec::new(); + let mut tool_result_messages: Vec = Vec::new(); + + for content in &msg.content { + match content { + LlmMessageContent::Text(text) => { + if !text.is_empty() { + parts.push(OpenRouterContentPart::Text { text: text.clone() }); + } + } + LlmMessageContent::Image(img) => { + let data_url = format!("data:image/png;base64,{}", img.source); + 
parts.push(OpenRouterContentPart::ImageUrl { + image_url: ImageUrl { url: data_url }, + }); + } + LlmMessageContent::ToolResult(result) => { + let content_text = match &result.content { + LlmToolResultContent::Text(t) => t.clone(), + LlmToolResultContent::Image(_) => "[Image]".to_string(), + }; + tool_result_messages.push(OpenRouterMessage { + role: "tool".to_string(), + content: Some(OpenRouterContent::Text(content_text)), + tool_calls: None, + tool_call_id: Some(result.tool_use_id.clone()), + }); + } + _ => {} + } + } + + if !parts.is_empty() { + let content = if parts.len() == 1 { + if let OpenRouterContentPart::Text { text } = &parts[0] { + OpenRouterContent::Text(text.clone()) + } else { + OpenRouterContent::Parts(parts) + } + } else { + OpenRouterContent::Parts(parts) + }; + + messages.push(OpenRouterMessage { + role: "user".to_string(), + content: Some(content), + tool_calls: None, + tool_call_id: None, + }); + } + + messages.extend(tool_result_messages); + } + LlmMessageRole::Assistant => { + let mut text_content = String::new(); + let mut tool_calls: Vec = Vec::new(); + + for content in &msg.content { + match content { + LlmMessageContent::Text(text) => { + if !text.is_empty() { + if !text_content.is_empty() { + text_content.push('\n'); + } + text_content.push_str(text); + } + } + LlmMessageContent::ToolUse(tool_use) => { + tool_calls.push(OpenRouterToolCall { + id: tool_use.id.clone(), + call_type: "function".to_string(), + function: OpenRouterFunctionCall { + name: tool_use.name.clone(), + arguments: tool_use.input.clone(), + }, + }); + } + _ => {} + } + } + + messages.push(OpenRouterMessage { + role: "assistant".to_string(), + content: if text_content.is_empty() { + None + } else { + Some(OpenRouterContent::Text(text_content)) + }, + tool_calls: if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }, + tool_call_id: None, + }); + } + } + } + + let model_def = get_model_definition(model_id); + let supports_tools = model_def.map(|m| 
m.supports_tools).unwrap_or(true); + + let tools: Vec = if supports_tools { + request + .tools + .iter() + .map(|t| OpenRouterTool { + tool_type: "function".to_string(), + function: OpenRouterFunctionDef { + name: t.name.clone(), + description: t.description.clone(), + parameters: serde_json::from_str(&t.input_schema) + .unwrap_or(serde_json::Value::Object(Default::default())), + }, + }) + .collect() + } else { + Vec::new() + }; + + let tool_choice = if supports_tools { + request.tool_choice.as_ref().map(|tc| match tc { + LlmToolChoice::Auto => "auto".to_string(), + LlmToolChoice::Any => "required".to_string(), + LlmToolChoice::None => "none".to_string(), + }) + } else { + None + }; + + let max_tokens = request + .max_tokens + .or(model_def.and_then(|m| m.max_output_tokens)); + + Ok(OpenRouterRequest { + model: model_id.to_string(), + messages, + max_tokens, + tools, + tool_choice, + stop: request.stop_sequences.clone(), + temperature: request.temperature, + stream: true, + }) +} + +fn parse_sse_line(line: &str) -> Option { + let data = line.strip_prefix("data: ")?; + if data.trim() == "[DONE]" { + return None; + } + serde_json::from_str(data).ok() +} + +impl zed::Extension for OpenRouterProvider { + fn new() -> Self { + Self { + streams: Mutex::new(HashMap::new()), + next_stream_id: Mutex::new(0), + } + } + + fn llm_providers(&self) -> Vec { + vec![LlmProviderInfo { + id: "open_router".into(), + name: "OpenRouter".into(), + icon: Some("icons/open-router.svg".into()), + }] + } + + fn llm_provider_models(&self, _provider_id: &str) -> Result, String> { + Ok(MODELS + .iter() + .map(|m| LlmModelInfo { + id: m.id.to_string(), + name: m.display_name.to_string(), + max_token_count: m.max_tokens, + max_output_tokens: m.max_output_tokens, + capabilities: LlmModelCapabilities { + supports_images: m.supports_images, + supports_tools: m.supports_tools, + supports_tool_choice_auto: m.supports_tools, + supports_tool_choice_any: m.supports_tools, + supports_tool_choice_none: 
m.supports_tools, + supports_thinking: false, + tool_input_format: LlmToolInputFormat::JsonSchema, + }, + is_default: m.is_default, + is_default_fast: m.is_default_fast, + }) + .collect()) + } + + fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool { + llm_get_credential("open_router").is_some() + } + + fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option { + Some( + r#"# OpenRouter Setup + +Welcome to **OpenRouter**! Access multiple AI models through a single API. + +## Configuration + +Enter your OpenRouter API key below. Get your API key at [openrouter.ai/keys](https://openrouter.ai/keys). + +## Available Models + +### Anthropic +| Model | Context | Output | +|-------|---------|--------| +| Claude Sonnet 4 | 200K | 8K | +| Claude Opus 4 | 200K | 8K | +| Claude Haiku 4 | 200K | 8K | +| Claude 3.5 Sonnet | 200K | 8K | + +### OpenAI +| Model | Context | Output | +|-------|---------|--------| +| GPT-4o | 128K | 16K | +| GPT-4o Mini | 128K | 16K | +| o1 | 200K | 100K | +| o3-mini | 200K | 100K | + +### Google +| Model | Context | Output | +|-------|---------|--------| +| Gemini 2.0 Flash | 1M | 8K | +| Gemini 2.5 Pro | 1M | 8K | + +### Meta +| Model | Context | Output | +|-------|---------|--------| +| Llama 3.3 70B | 128K | 4K | +| Llama 4 Maverick | 128K | 4K | + +### Mistral +| Model | Context | Output | +|-------|---------|--------| +| Mistral Large | 128K | 4K | +| Codestral | 32K | 4K | + +### DeepSeek +| Model | Context | Output | +|-------|---------|--------| +| DeepSeek V3 | 64K | 8K | +| DeepSeek R1 | 64K | 8K | + +### Qwen +| Model | Context | Output | +|-------|---------|--------| +| Qwen 3 235B | 40K | 8K | + +## Features + +- ✅ Full streaming support +- ✅ Tool/function calling (model dependent) +- ✅ Vision (model dependent) +- ✅ Access to 200+ models +- ✅ Unified billing + +## Pricing + +Pay-per-use based on model. See [openrouter.ai/models](https://openrouter.ai/models) for pricing. 
+"# + .to_string(), + ) + } + + fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> { + let provided = llm_request_credential( + "open_router", + LlmCredentialType::ApiKey, + "OpenRouter API Key", + "sk-or-v1-...", + )?; + if provided { + Ok(()) + } else { + Err("Authentication cancelled".to_string()) + } + } + + fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> { + llm_delete_credential("open_router") + } + + fn llm_stream_completion_start( + &mut self, + _provider_id: &str, + model_id: &str, + request: &LlmCompletionRequest, + ) -> Result { + let api_key = llm_get_credential("open_router").ok_or_else(|| { + "No API key configured. Please add your OpenRouter API key in settings.".to_string() + })?; + + let openrouter_request = convert_request(model_id, request)?; + + let body = serde_json::to_vec(&openrouter_request) + .map_err(|e| format!("Failed to serialize request: {}", e))?; + + let http_request = HttpRequest { + method: HttpMethod::Post, + url: "https://openrouter.ai/api/v1/chat/completions".to_string(), + headers: vec![ + ("Content-Type".to_string(), "application/json".to_string()), + ("Authorization".to_string(), format!("Bearer {}", api_key)), + ("HTTP-Referer".to_string(), "https://zed.dev".to_string()), + ("X-Title".to_string(), "Zed Editor".to_string()), + ], + body: Some(body), + redirect_policy: RedirectPolicy::FollowAll, + }; + + let response_stream = http_request + .fetch_stream() + .map_err(|e| format!("HTTP request failed: {}", e))?; + + let stream_id = { + let mut id_counter = self.next_stream_id.lock().unwrap(); + let id = format!("openrouter-stream-{}", *id_counter); + *id_counter += 1; + id + }; + + self.streams.lock().unwrap().insert( + stream_id.clone(), + StreamState { + response_stream: Some(response_stream), + buffer: String::new(), + started: false, + tool_calls: HashMap::new(), + tool_calls_emitted: false, + }, + ); + + Ok(stream_id) + } + + fn 
llm_stream_completion_next( + &mut self, + stream_id: &str, + ) -> Result, String> { + let mut streams = self.streams.lock().unwrap(); + let state = streams + .get_mut(stream_id) + .ok_or_else(|| format!("Unknown stream: {}", stream_id))?; + + if !state.started { + state.started = true; + return Ok(Some(LlmCompletionEvent::Started)); + } + + let response_stream = state + .response_stream + .as_mut() + .ok_or_else(|| "Stream already closed".to_string())?; + + loop { + if let Some(newline_pos) = state.buffer.find('\n') { + let line = state.buffer[..newline_pos].to_string(); + state.buffer = state.buffer[newline_pos + 1..].to_string(); + + if line.trim().is_empty() { + continue; + } + + if let Some(response) = parse_sse_line(&line) { + if let Some(choice) = response.choices.first() { + if let Some(content) = &choice.delta.content { + if !content.is_empty() { + return Ok(Some(LlmCompletionEvent::Text(content.clone()))); + } + } + + if let Some(tool_calls) = &choice.delta.tool_calls { + for tc in tool_calls { + let entry = state + .tool_calls + .entry(tc.index) + .or_insert_with(AccumulatedToolCall::default); + + if let Some(id) = &tc.id { + entry.id = id.clone(); + } + if let Some(func) = &tc.function { + if let Some(name) = &func.name { + entry.name = name.clone(); + } + if let Some(args) = &func.arguments { + entry.arguments.push_str(args); + } + } + } + } + + if let Some(finish_reason) = &choice.finish_reason { + if !state.tool_calls.is_empty() && !state.tool_calls_emitted { + state.tool_calls_emitted = true; + let mut tool_calls: Vec<_> = state.tool_calls.drain().collect(); + tool_calls.sort_by_key(|(idx, _)| *idx); + + if let Some((_, tc)) = tool_calls.into_iter().next() { + return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse { + id: tc.id, + name: tc.name, + input: tc.arguments, + thought_signature: None, + }))); + } + } + + let stop_reason = match finish_reason.as_str() { + "stop" => LlmStopReason::EndTurn, + "length" => LlmStopReason::MaxTokens, + 
"tool_calls" => LlmStopReason::ToolUse, + "content_filter" => LlmStopReason::Refusal, + _ => LlmStopReason::EndTurn, + }; + return Ok(Some(LlmCompletionEvent::Stop(stop_reason))); + } + } + + if let Some(usage) = response.usage { + return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage { + input_tokens: usage.prompt_tokens, + output_tokens: usage.completion_tokens, + cache_creation_input_tokens: None, + cache_read_input_tokens: None, + }))); + } + } + + continue; + } + + match response_stream.next_chunk() { + Ok(Some(chunk)) => { + let text = String::from_utf8_lossy(&chunk); + state.buffer.push_str(&text); + } + Ok(None) => { + return Ok(None); + } + Err(e) => { + return Err(format!("Stream error: {}", e)); + } + } + } + } + + fn llm_stream_completion_close(&mut self, stream_id: &str) { + self.streams.lock().unwrap().remove(stream_id); + } +} + +zed::register_extension!(OpenRouterProvider); diff --git a/extensions/openai/Cargo.lock b/extensions/openai/Cargo.lock new file mode 100644 index 00000000000000..2ef354a2892b23 --- /dev/null +++ b/extensions/openai/Cargo.lock @@ -0,0 +1,823 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "anyhow" +version = "1.0.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" + +[[package]] +name = "auditable-serde" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7bf8143dfc3c0258df908843e169b5cc5fcf76c7718bd66135ef4a9cd558c5" +dependencies = [ + "semver", + "serde", + "serde_json", + "topological-sort", +] + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "flate2" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" +dependencies = [ + 
"crc32fast", + "miniz_oxide", +] + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "fopenai" +version = "0.1.0" +dependencies = [ + "serde", + "serde_json", + "zed_extension_api", +] + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "icu_collections" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.1.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" + +[[package]] +name = "icu_properties" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899" + +[[package]] +name = "icu_provider" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "litemap" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "potential_utf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +dependencies = [ + "zerovec", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "semver" +version = "1.0.27" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +dependencies = [ + "serde", + "serde_core", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.145" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", + "serde_core", +] + +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "spdx" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c3e17e880bafaeb362a7b751ec46bdc5b61445a188f80e0606e68167cd540fa3" +dependencies = [ + "smallvec", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "syn" +version = "2.0.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinystr" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "topological-sort" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea68304e134ecd095ac6c3574494fc62b909f416c4fca77e440530221e549d3d" + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "url" +version = "2.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = 
"utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "wasm-encoder" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80bb72f02e7fbf07183443b27b0f3d4144abf8c114189f2e088ed95b696a7822" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce1ef0faabbbba6674e97a56bee857ccddf942785a336c8b47b42373c922a91d" +dependencies = [ + "anyhow", + "auditable-serde", + "flate2", + "indexmap", + "serde", + "serde_derive", + "serde_json", + "spdx", + "url", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f51cad774fb3c9461ab9bccc9c62dfb7388397b5deda31bf40e8108ccd678b2" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "wit-bindgen" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10fb6648689b3929d56bbc7eb1acf70c9a42a29eb5358c67c10f54dbd5d695de" +dependencies = [ + "wit-bindgen-rt", + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92fa781d4f2ff6d3f27f3cc9b74a73327b31ca0dc4a3ef25a0ce2983e0e5af9b" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rt" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db52a11d4dfb0a59f194c064055794ee6564eb1ced88c25da2cf76e50c5621" +dependencies = [ + "bitflags", + "futures", + "once_cell", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.41.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d0809dc5ba19e2e98661bf32fc0addc5a3ca5bf3a6a7083aa6ba484085ff3ce" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.41.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad19eec017904e04c60719592a803ee5da76cb51c81e3f6fbf9457f59db49799" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "635c3adc595422cbf2341a17fb73a319669cc8d33deed3a48368a841df86b676" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.227.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddf445ed5157046e4baf56f9138c124a0824d4d1657e7204d71886ad8ce2fc11" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "yoke" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zed_extension_api" +version = "0.7.0" +dependencies = [ + "serde", + "serde_json", + "wit-bindgen", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerotrie" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/extensions/openai/Cargo.toml b/extensions/openai/Cargo.toml new file mode 100644 index 00000000000000..f81809e426ef69 --- /dev/null +++ b/extensions/openai/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "openai" +version = "0.1.0" +edition = "2021" +publish = false +license = "Apache-2.0" + +[workspace] + +[lib] +path = "src/openai.rs" +crate-type = ["cdylib"] + +[dependencies] +zed_extension_api = { path = 
"../../crates/extension_api" } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" diff --git a/extensions/openai/extension.toml b/extensions/openai/extension.toml new file mode 100644 index 00000000000000..94788688716f1d --- /dev/null +++ b/extensions/openai/extension.toml @@ -0,0 +1,13 @@ +id = "openai" +name = "OpenAI" +description = "OpenAI GPT LLM provider for Zed." +version = "0.1.0" +schema_version = 1 +authors = ["Zed Team"] +repository = "https://github.com/zed-industries/zed" + +[language_model_providers.openai] +name = "OpenAI" + +[language_model_providers.openai.auth] +env_var = "OPENAI_API_KEY" \ No newline at end of file diff --git a/extensions/openai/icons/openai.svg b/extensions/openai/icons/openai.svg new file mode 100644 index 00000000000000..e45ac315a01185 --- /dev/null +++ b/extensions/openai/icons/openai.svg @@ -0,0 +1,3 @@ + + + diff --git a/extensions/openai/src/openai.rs b/extensions/openai/src/openai.rs new file mode 100644 index 00000000000000..40a99352abd5da --- /dev/null +++ b/extensions/openai/src/openai.rs @@ -0,0 +1,727 @@ +use std::collections::HashMap; +use std::sync::Mutex; + +use serde::{Deserialize, Serialize}; +use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy}; +use zed_extension_api::{self as zed, *}; + +struct OpenAiProvider { + streams: Mutex>, + next_stream_id: Mutex, +} + +struct StreamState { + response_stream: Option, + buffer: String, + started: bool, + tool_calls: HashMap, + tool_calls_emitted: bool, +} + +#[derive(Clone, Default)] +struct AccumulatedToolCall { + id: String, + name: String, + arguments: String, +} + +struct ModelDefinition { + real_id: &'static str, + display_name: &'static str, + max_tokens: u64, + max_output_tokens: Option, + supports_images: bool, + is_default: bool, + is_default_fast: bool, +} + +const MODELS: &[ModelDefinition] = &[ + ModelDefinition { + real_id: "gpt-4o", + display_name: "GPT-4o", + max_tokens: 128_000, + 
max_output_tokens: Some(16_384), + supports_images: true, + is_default: true, + is_default_fast: false, + }, + ModelDefinition { + real_id: "gpt-4o-mini", + display_name: "GPT-4o-mini", + max_tokens: 128_000, + max_output_tokens: Some(16_384), + supports_images: true, + is_default: false, + is_default_fast: true, + }, + ModelDefinition { + real_id: "gpt-4.1", + display_name: "GPT-4.1", + max_tokens: 1_047_576, + max_output_tokens: Some(32_768), + supports_images: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "gpt-4.1-mini", + display_name: "GPT-4.1-mini", + max_tokens: 1_047_576, + max_output_tokens: Some(32_768), + supports_images: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "gpt-4.1-nano", + display_name: "GPT-4.1-nano", + max_tokens: 1_047_576, + max_output_tokens: Some(32_768), + supports_images: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "gpt-5", + display_name: "GPT-5", + max_tokens: 272_000, + max_output_tokens: Some(32_768), + supports_images: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "gpt-5-mini", + display_name: "GPT-5-mini", + max_tokens: 272_000, + max_output_tokens: Some(32_768), + supports_images: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "o1", + display_name: "o1", + max_tokens: 200_000, + max_output_tokens: Some(100_000), + supports_images: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "o3", + display_name: "o3", + max_tokens: 200_000, + max_output_tokens: Some(100_000), + supports_images: true, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: "o3-mini", + display_name: "o3-mini", + max_tokens: 200_000, + max_output_tokens: Some(100_000), + supports_images: false, + is_default: false, + is_default_fast: false, + }, + ModelDefinition { + real_id: 
"o4-mini", + display_name: "o4-mini", + max_tokens: 200_000, + max_output_tokens: Some(100_000), + supports_images: true, + is_default: false, + is_default_fast: false, + }, +]; + +fn get_real_model_id(display_name: &str) -> Option<&'static str> { + MODELS + .iter() + .find(|m| m.display_name == display_name) + .map(|m| m.real_id) +} + +#[derive(Serialize)] +struct OpenAiRequest { + model: String, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, + #[serde(skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + max_tokens: Option, + #[serde(skip_serializing_if = "Vec::is_empty")] + stop: Vec, + stream: bool, + stream_options: Option, +} + +#[derive(Serialize)] +struct StreamOptions { + include_usage: bool, +} + +#[derive(Serialize)] +#[serde(tag = "role")] +enum OpenAiMessage { + #[serde(rename = "system")] + System { content: String }, + #[serde(rename = "user")] + User { content: Vec }, + #[serde(rename = "assistant")] + Assistant { + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + }, + #[serde(rename = "tool")] + Tool { + tool_call_id: String, + content: String, + }, +} + +#[derive(Serialize)] +#[serde(tag = "type")] +enum OpenAiContentPart { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image_url")] + ImageUrl { image_url: ImageUrl }, +} + +#[derive(Serialize)] +struct ImageUrl { + url: String, +} + +#[derive(Serialize, Deserialize, Clone)] +struct OpenAiToolCall { + id: String, + #[serde(rename = "type")] + call_type: String, + function: OpenAiFunctionCall, +} + +#[derive(Serialize, Deserialize, Clone)] +struct OpenAiFunctionCall { + name: String, + arguments: String, +} + +#[derive(Serialize)] +struct OpenAiTool { + #[serde(rename = "type")] + tool_type: String, + 
function: OpenAiFunctionDef, +} + +#[derive(Serialize)] +struct OpenAiFunctionDef { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Deserialize, Debug)] +struct OpenAiStreamEvent { + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Deserialize, Debug)] +struct OpenAiChoice { + delta: OpenAiDelta, + finish_reason: Option, +} + +#[derive(Deserialize, Debug, Default)] +struct OpenAiDelta { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + +#[derive(Deserialize, Debug)] +struct OpenAiToolCallDelta { + index: usize, + #[serde(default)] + id: Option, + #[serde(default)] + function: Option, +} + +#[derive(Deserialize, Debug)] +struct OpenAiFunctionDelta { + #[serde(default)] + name: Option, + #[serde(default)] + arguments: Option, +} + +#[derive(Deserialize, Debug)] +struct OpenAiUsage { + prompt_tokens: u64, + completion_tokens: u64, +} + +#[allow(dead_code)] +#[derive(Deserialize, Debug)] +struct OpenAiError { + error: OpenAiErrorDetail, +} + +#[allow(dead_code)] +#[derive(Deserialize, Debug)] +struct OpenAiErrorDetail { + message: String, +} + +fn convert_request( + model_id: &str, + request: &LlmCompletionRequest, +) -> Result { + let real_model_id = + get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?; + + let mut messages = Vec::new(); + + for msg in &request.messages { + match msg.role { + LlmMessageRole::System => { + let text: String = msg + .content + .iter() + .filter_map(|c| match c { + LlmMessageContent::Text(t) => Some(t.as_str()), + _ => None, + }) + .collect::>() + .join("\n"); + if !text.is_empty() { + messages.push(OpenAiMessage::System { content: text }); + } + } + LlmMessageRole::User => { + let parts: Vec = msg + .content + .iter() + .filter_map(|c| match c { + LlmMessageContent::Text(t) => { + Some(OpenAiContentPart::Text { text: t.clone() }) + } + LlmMessageContent::Image(img) => Some(OpenAiContentPart::ImageUrl { + image_url: 
ImageUrl {
                        url: format!("data:image/png;base64,{}", img.source),
                    },
                }),
                // Tool results are hoisted into top-level `tool` role
                // messages below, not embedded as user content parts.
                LlmMessageContent::ToolResult(_) => None,
                _ => None,
            })
            .collect();

                // OpenAI expects each tool result as its own `tool` role
                // message keyed by the originating tool_call_id.
                for content in &msg.content {
                    if let LlmMessageContent::ToolResult(result) = content {
                        let content_text = match &result.content {
                            LlmToolResultContent::Text(t) => t.clone(),
                            LlmToolResultContent::Image(_) => "[Image]".to_string(),
                        };
                        messages.push(OpenAiMessage::Tool {
                            tool_call_id: result.tool_use_id.clone(),
                            content: content_text,
                        });
                    }
                }

                if !parts.is_empty() {
                    messages.push(OpenAiMessage::User { content: parts });
                }
            }
            LlmMessageRole::Assistant => {
                let mut content_text: Option<String> = None;
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for c in &msg.content {
                    match c {
                        LlmMessageContent::Text(t) => {
                            // Concatenate so a message with several text
                            // parts keeps all of them (previously only the
                            // last part survived).
                            content_text.get_or_insert_with(String::new).push_str(t);
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage::Assistant {
                    content: content_text,
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                });
            }
        }
    }

    // Omit `tools` entirely when the request defines none; some
    // OpenAI-compatible servers reject an empty array.
    let tools: Option<Vec<OpenAiTool>> = if request.tools.is_empty() {
        None
    } else {
        Some(
            request
                .tools
                .iter()
                .map(|t| OpenAiTool {
                    tool_type: "function".to_string(),
                    function: OpenAiFunctionDef {
                        name: t.name.clone(),
                        description: t.description.clone(),
                        // Fall back to an empty schema on malformed JSON
                        // rather than failing the whole request.
                        parameters: serde_json::from_str(&t.input_schema)
                            .unwrap_or(serde_json::Value::Object(Default::default())),
                    },
                })
                .collect(),
        )
    };

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    Ok(OpenAiRequest {
        model: real_model_id.to_string(),
        messages,
        tools,
        tool_choice,
        temperature: request.temperature,
        max_tokens: request.max_tokens,
        stop: request.stop_sequences.clone(),
        stream: true,
        // Ask for a trailing chunk (with empty `choices`) carrying token
        // usage for the whole completion.
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}

/// Parses one server-sent-events line into a streaming chunk.
///
/// Returns `None` for non-`data:` lines, the `[DONE]` sentinel, and
/// payloads that fail to deserialize (unknown shapes are skipped, not
/// treated as fatal).
// NOTE(review): the event type name was reconstructed; confirm it matches
// the struct deserialized from OpenAI stream chunks earlier in this file.
fn parse_sse_line(line: &str) -> Option<OpenAiStreamEvent> {
    let data = line.strip_prefix("data: ")?;
    if data == "[DONE]" {
        return None;
    }
    serde_json::from_str(data).ok()
}

impl zed::Extension for OpenAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    /// Advertises the single "openai" provider to the host.
    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "openai".into(),
            name: "OpenAI".into(),
            icon: Some("icons/openai.svg".into()),
        }]
    }

    /// Lists the static model catalog declared in `MODELS`.
    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                // The display name doubles as the model id; it is mapped
                // back to the real API model name in `convert_request`.
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    /// Authenticated iff an API key is stored in the credential store.
    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("openai").is_some()
    }

    /// Markdown shown on the provider settings page.
    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            r#"# OpenAI Setup

Welcome to **OpenAI**! This extension provides access to OpenAI GPT models.

## Configuration

Enter your OpenAI API key below. You can find your API key at [platform.openai.com/api-keys](https://platform.openai.com/api-keys).

## Available Models

| Display Name | Real Model | Context | Output |
|--------------|------------|---------|--------|
| GPT-4o | gpt-4o | 128K | 16K |
| GPT-4o-mini | gpt-4o-mini | 128K | 16K |
| GPT-4.1 | gpt-4.1 | 1M | 32K |
| GPT-4.1-mini | gpt-4.1-mini | 1M | 32K |
| GPT-5 | gpt-5 | 272K | 32K |
| GPT-5-mini | gpt-5-mini | 272K | 32K |
| o1 | o1 | 200K | 100K |
| o3 | o3 | 200K | 100K |
| o3-mini | o3-mini | 200K | 100K |

## Features

- ✅ Full streaming support
- ✅ Tool/function calling
- ✅ Vision (image inputs)
- ✅ All OpenAI models

## Pricing

Uses your OpenAI API credits. See [OpenAI pricing](https://openai.com/pricing) for details.
"#
            .to_string(),
        )
    }

    /// Prompts the user for an API key via the host credential dialog.
    fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
        let provided = llm_request_credential(
            "openai",
            LlmCredentialType::ApiKey,
            "OpenAI API Key",
            "sk-...",
        )?;
        if provided {
            Ok(())
        } else {
            Err("Authentication cancelled".to_string())
        }
    }

    /// Removes the stored API key.
    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("openai")
    }

    /// Starts a streaming chat completion; returns an opaque stream id for
    /// subsequent `llm_stream_completion_next`/`_close` calls.
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("openai").ok_or_else(|| {
            "No API key configured. Please add your OpenAI API key in settings.".to_string()
        })?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://api.openai.com/v1/chat/completions".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("Authorization".to_string(), format!("Bearer {}", api_key)),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        // Process-unique id so concurrent completions can be polled and
        // closed independently.
        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("openai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                tool_calls_emitted: false,
            },
        );

        Ok(stream_id)
    }

    /// Pulls the next completion event for `stream_id`.
    ///
    /// Event order: `Started`, then `Text`/`ToolUse`/`Usage` events as SSE
    /// chunks arrive, then a final `Stop`. Returns `Ok(None)` once the
    /// stream is exhausted.
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        // Drain mode: a `tool_calls` finish (or EOF with queued calls) was
        // observed. Emit the accumulated tool calls one per poll, lowest
        // index first, then a `Stop(ToolUse)`. The previous implementation
        // emitted only the FIRST parallel tool call and never emitted the
        // Stop event after a tool call.
        if state.tool_calls_emitted {
            if let Some(index) = state.tool_calls.keys().min().copied() {
                let tool_call = state
                    .tool_calls
                    .remove(&index)
                    .expect("key observed above");
                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                    id: tool_call.id,
                    name: tool_call.name,
                    input: tool_call.arguments,
                    thought_signature: None,
                })));
            }
            state.tool_calls_emitted = false;
            return Ok(Some(LlmCompletionEvent::Stop(LlmStopReason::ToolUse)));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            // Process one complete SSE line from the buffer, if available.
            if let Some(newline_pos) = state.buffer.find('\n') {
                // `drain` removes the line in place instead of reallocating
                // the whole remaining buffer per line (previously O(n²)).
                let raw: String = state.buffer.drain(..=newline_pos).collect();
                let line = raw.trim();

                if line.is_empty() {
                    continue;
                }

                if let Some(event) = parse_sse_line(line) {
                    if let Some(choice) = event.choices.first() {
                        // Accumulate streamed tool-call fragments by index:
                        // ids/names arrive once, arguments arrive in pieces.
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state.tool_calls.entry(tc.index).or_default();
                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }
                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(reason) = &choice.finish_reason {
                            if reason == "tool_calls" && !state.tool_calls.is_empty() {
                                // Enter drain mode (see top of function) and
                                // emit the first queued call now.
                                state.tool_calls_emitted = true;
                                let index = state
                                    .tool_calls
                                    .keys()
                                    .min()
                                    .copied()
                                    .expect("checked non-empty");
                                let tool_call = state
                                    .tool_calls
                                    .remove(&index)
                                    .expect("key observed above");
                                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                    id: tool_call.id,
                                    name: tool_call.name,
                                    input: tool_call.arguments,
                                    thought_signature: None,
                                })));
                            }

                            let stop_reason = match reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };

                            // Always emit Stop here. With
                            // `stream_options.include_usage` set, usage
                            // arrives afterwards in a dedicated empty-
                            // `choices` chunk (handled below); returning
                            // Usage instead of Stop on this chunk — as the
                            // old code did — swallowed the Stop event.
                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }

                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }
                    } else if let Some(usage) = event.usage {
                        // Trailing usage chunk: `choices` is empty and
                        // `usage` holds totals for the whole completion.
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_tokens,
                            output_tokens: usage.completion_tokens,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }

                continue;
            }

            // No complete line buffered; pull more bytes from the wire.
            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // EOF without a finish chunk: flush any queued tool
                    // calls via drain mode before ending the stream.
                    if !state.tool_calls.is_empty() {
                        state.tool_calls_emitted = true;
                        let index = state
                            .tool_calls
                            .keys()
                            .min()
                            .copied()
                            .expect("checked non-empty");
                        let tool_call = state
                            .tool_calls
                            .remove(&index)
                            .expect("key observed above");
                        return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                            id: tool_call.id,
                            name: tool_call.name,
                            input: tool_call.arguments,
                            thought_signature: None,
                        })));
                    }
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    /// Drops the stream state; the underlying HTTP stream is dropped too.
    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(OpenAiProvider);