diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index ec0cd1a38a..199383e62a 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -12,7 +12,7 @@ use convert_case::{Case, Casing}; use forge_api::{ API, AgentId, AnyProvider, ApiKeyRequest, AuthContextRequest, AuthContextResponse, ChatRequest, ChatResponse, CodeRequest, Conversation, ConversationId, DeviceCodeRequest, Event, - InterruptionReason, Model, ModelId, Provider, ProviderId, TextMessage, UserPrompt, + InterruptionReason, ModelId, Provider, ProviderId, TextMessage, UserPrompt, }; use forge_app::utils::{format_display_path, truncate_key}; use forge_app::{CommitResult, ToolResolver}; @@ -127,14 +127,6 @@ impl A + Send + Sync> UI { self.spinner.ewrite_ln(title) } - /// Retrieve available models - async fn get_models(&mut self) -> Result> { - self.spinner.start(Some("Loading"))?; - let models = self.api.get_models().await?; - self.spinner.stop(None)?; - Ok(models) - } - /// Helper to get provider for an optional agent, defaulting to the current /// active agent's provider async fn get_provider(&self, agent_id: Option) -> Result> { @@ -2027,7 +2019,9 @@ impl A + Send + Sync> UI { ) -> Result> { // Check if provider is set otherwise first ask to select a provider if self.api.get_default_provider().await.is_err() { - self.on_provider_selection().await?; + if !self.on_provider_selection().await? { + return Ok(None); + } // Check if a model was already selected during provider activation // Return None to signal the model selection is complete and message was already @@ -2686,15 +2680,16 @@ impl A + Send + Sync> UI { Ok(Some(model)) } - async fn on_provider_selection(&mut self) -> Result<()> { + async fn on_provider_selection(&mut self) -> Result { // Select a provider // If no provider was selected (user canceled), return early let any_provider = match self.select_provider().await? { Some(provider) => provider, - None => return Ok(()), + None => return Ok(false), }; - self.activate_provider(any_provider).await + self.activate_provider(any_provider).await?; + Ok(true) } /// Activates a provider by configuring it if needed, setting it as default, @@ -2741,20 +2736,22 @@ impl A + Send + Sync> UI { provider: Provider, model: Option, ) -> Result<()> { - // Set the provider via API - self.api.set_default_provider(provider.id.clone()).await?; - - self.writeln_title( - TitleFormat::action(format!("{}", provider.id)) - .sub_title("is now the default provider"), - )?; - // If a model was pre-selected (e.g. from :model), validate and set it // directly without prompting if let Some(model) = model { let model_id = self .validate_model(model.as_str(), Some(&provider.id)) .await?; + + //set provider + self.api.set_default_provider(provider.id.clone()).await?; + + self.writeln_title( + TitleFormat::action(format!("{}", provider.id)) + .sub_title("is now the default provider"), + )?; + + //set model self.api.set_default_model(model_id.clone()).await?; self.writeln_title( TitleFormat::action(model_id.as_str()).sub_title("is now the default model"), @@ -2763,20 +2760,35 @@ impl A + Send + Sync> UI { } // Check if the current model is available for the new provider + let current_model = self.api.get_default_model().await; - if let Some(current_model) = current_model { - let models = self.get_models().await?; - let model_available = models.iter().any(|m| m.id == current_model); - if !model_available { - // Prompt user to select a new model, scoped to the activated provider - self.writeln_title(TitleFormat::info("Please select a new model"))?; - self.on_model_selection(Some(provider.id.clone())).await?; + let needs_model_selection = match current_model { + None => true, + Some(current_model) => { + let provider_models = self.api.get_all_provider_models().await?; + !provider_models + .iter() + .find(|pm| pm.provider_id == provider.id) + .map(|pm| pm.models.iter().any(|m| m.id == current_model)) + .unwrap_or(false) + } + }; + + if needs_model_selection { + self.writeln_title(TitleFormat::info("Please select a new model"))?; + let selected = self.on_model_selection(Some(provider.id.clone())).await?; + if selected.is_none() { + // User cancelled — preserve existing config untouched + return Ok(()); } - } else { - // No model set, select one now scoped to the activated provider - self.on_model_selection(Some(provider.id.clone())).await?; } + // Only reaches here if model is confirmed — safe to write provider now + self.api.set_default_provider(provider.id.clone()).await?; + self.writeln_title( + TitleFormat::action(format!("{}", provider.id)) + .sub_title("is now the default provider"), + )?; Ok(()) } @@ -2883,19 +2895,19 @@ impl A + Send + Sync> UI { // Ensure we have a model selected before proceeding with initialization let active_agent = self.api.get_active_agent().await; + // Validate provider is configured before loading agents + // If provider is set in config but not configured (no credentials), prompt user + // to login + if self.api.get_default_provider().await.is_err() && !self.on_provider_selection().await? { + return Ok(()); + } + let mut operating_model = self.get_agent_model(active_agent.clone()).await; if operating_model.is_none() { // Use the model returned from selection instead of re-fetching operating_model = self.on_model_selection(None).await?; } - // Validate provider is configured before loading agents - // If provider is set in config but not configured (no credentials), prompt user - // to login - if self.api.get_default_provider().await.is_err() { - self.on_provider_selection().await?; - } - if first { // For chat, we are trying to get active agent or setting it to default. // So for default values, `/info` doesn't show active provider, model, etc.