From defc726bf724368e2f6445c77effab687a042d66 Mon Sep 17 00:00:00 2001 From: Finesssee <90105158+Finesssee@users.noreply.github.com> Date: Wed, 11 Mar 2026 03:14:12 +0700 Subject: [PATCH 1/2] Fix Ollama endpoint validation and model loading --- src-tauri/src/commands/ai/mod.rs | 2 + src-tauri/src/commands/ai/ollama.rs | 173 ++++++++++++++++++ src-tauri/src/main.rs | 1 + .../selectors/ai-model-selector.tsx | 37 ++-- .../settings/components/tabs/ai-settings.tsx | 26 ++- src/utils/ollama.test.ts | 50 +++++ src/utils/ollama.ts | 98 ++++++++++ src/utils/providers/ollama-provider.ts | 15 +- 8 files changed, 367 insertions(+), 35 deletions(-) create mode 100644 src-tauri/src/commands/ai/ollama.rs create mode 100644 src/utils/ollama.test.ts create mode 100644 src/utils/ollama.ts diff --git a/src-tauri/src/commands/ai/mod.rs b/src-tauri/src/commands/ai/mod.rs index ee0cac96..87b34fed 100644 --- a/src-tauri/src/commands/ai/mod.rs +++ b/src-tauri/src/commands/ai/mod.rs @@ -2,10 +2,12 @@ pub mod acp; pub mod auth; pub mod chat_history; pub mod claude; +pub mod ollama; pub mod tokens; pub use acp::*; pub use auth::*; pub use chat_history::*; pub use claude::*; +pub use ollama::*; pub use tokens::*; diff --git a/src-tauri/src/commands/ai/ollama.rs b/src-tauri/src/commands/ai/ollama.rs new file mode 100644 index 00000000..5c5ff644 --- /dev/null +++ b/src-tauri/src/commands/ai/ollama.rs @@ -0,0 +1,173 @@ +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tauri::command; +use url::Url; + +const DEFAULT_OLLAMA_BASE_URL: &str = "http://localhost:11434"; + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct OllamaProbeResponse { + normalized_url: String, + models: Vec, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct OllamaModel { + id: String, + name: String, + max_tokens: usize, +} + +#[derive(Debug, Deserialize)] +struct OllamaTagsResponse { + #[serde(default)] + models: Vec, +} + +#[derive(Debug, Deserialize)] +struct OllamaTagModel { + name: String, +} + +fn normalize_ollama_base_url(input: &str) -> String { + let trimmed = input.trim().trim_end_matches('/'); + if trimmed.is_empty() { + DEFAULT_OLLAMA_BASE_URL.to_string() + } else { + trimmed.to_string() + } +} + +fn validate_ollama_base_url(input: &str) -> Result<(), String> { + let parsed = Url::parse(input).map_err(|_| "Invalid Ollama URL".to_string())?; + + match parsed.scheme() { + "http" | "https" => {} + _ => return Err("Invalid Ollama URL".to_string()), + } + + if parsed.host_str().is_none() { + return Err("Invalid Ollama URL".to_string()); + } + + Ok(()) +} + +#[command] +pub async fn probe_ollama_endpoint(base_url: String) -> Result { + let normalized_url = normalize_ollama_base_url(&base_url); + validate_ollama_base_url(&normalized_url)?; + + let client = Client::builder() + .timeout(Duration::from_secs(3)) + .build() + .map_err(|error| format!("Failed to create Ollama client: {}", error))?; + + let tags_url = format!("{}/api/tags", normalized_url); + let response = client + .get(&tags_url) + .send() + .await + .map_err(|error| format!("Failed to connect to Ollama: {}", error))?; + + if !response.status().is_success() { + return Err(format!( + "Ollama endpoint returned HTTP {}", + response.status() + )); + } + + let payload = response + .json::() + .await + .map_err(|error| format!("Failed to read Ollama response: {}", error))?; + + let models = payload + .models + .into_iter() + .map(|model| OllamaModel { + id: model.name.clone(), + name: model.name, + max_tokens: 4096, + }) + .collect(); + + Ok(OllamaProbeResponse { + normalized_url, + models, + }) +} + +#[cfg(test)] +mod tests { + use std::{ + io::{Read, Write}, + net::TcpListener, + thread, + }; + + use super::{DEFAULT_OLLAMA_BASE_URL, normalize_ollama_base_url, validate_ollama_base_url}; + use crate::commands::ai::ollama::probe_ollama_endpoint; + + #[test] + fn normalizes_empty_base_url_to_default() { + assert_eq!(normalize_ollama_base_url(""), DEFAULT_OLLAMA_BASE_URL); + assert_eq!(normalize_ollama_base_url(" "), DEFAULT_OLLAMA_BASE_URL); + } + + #[test] + fn trims_trailing_slashes() { + assert_eq!( + normalize_ollama_base_url("http://localhost:11434///"), + DEFAULT_OLLAMA_BASE_URL + ); + } + + #[test] + fn accepts_http_and_https_urls() { + assert!(validate_ollama_base_url("http://localhost:11434").is_ok()); + assert!(validate_ollama_base_url("https://ollama.example.com/base").is_ok()); + } + + #[test] + fn rejects_invalid_urls() { + assert!(validate_ollama_base_url("localhost:11434").is_err()); + assert!(validate_ollama_base_url("ftp://localhost:11434").is_err()); + assert!(validate_ollama_base_url("http://").is_err()); + } + + #[tokio::test] + async fn probes_custom_port_and_returns_models() { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server"); + let address = listener.local_addr().expect("read local addr"); + + thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept request"); + let mut buffer = [0_u8; 1024]; + let _ = stream.read(&mut buffer); + + let body = r#"{"models":[{"name":"llama3.2"}]}"#; + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}", + body.len(), + body + ); + + stream + .write_all(response.as_bytes()) + .expect("write response"); + }); + + let result = probe_ollama_endpoint(format!("http://{}", address)) + .await + .expect("probe endpoint"); + + assert_eq!(result.normalized_url, format!("http://{}", address)); + assert_eq!(result.models.len(), 1); + assert_eq!(result.models[0].id, "llama3.2"); + assert_eq!(result.models[0].name, "llama3.2"); + } +} diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 6ef4d67c..be308f17 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -420,6 +420,7 @@ fn main() { store_ai_provider_token, get_ai_provider_token, remove_ai_provider_token, + probe_ollama_endpoint, // Auth token commands store_auth_token, get_auth_token, diff --git a/src/features/ai/components/selectors/ai-model-selector.tsx b/src/features/ai/components/selectors/ai-model-selector.tsx index 7114f10d..ba22f3a9 100644 --- a/src/features/ai/components/selectors/ai-model-selector.tsx +++ b/src/features/ai/components/selectors/ai-model-selector.tsx @@ -24,6 +24,11 @@ import { } from "@/features/ai/types/providers"; import { useSettingsStore } from "@/features/settings/store"; import { cn } from "@/utils/cn"; +import { + DEFAULT_OLLAMA_BASE_URL, + getOllamaProbeErrorMessage, + probeOllamaEndpoint, +} from "@/utils/ollama"; import { getProvider, setOllamaBaseUrl } from "@/utils/providers"; interface AIModelSelectorProps { @@ -82,11 +87,12 @@ export function AIModelSelector({ const { settings, updateSetting } = useSettingsStore(); const [ollamaUrlInput, setOllamaUrlInput] = useState( - settings.ollamaBaseUrl || "http://localhost:11434", + settings.ollamaBaseUrl || DEFAULT_OLLAMA_BASE_URL, ); const [ollamaUrlStatus, setOllamaUrlStatus] = useState<"idle" | "checking" | "ok" | "error">( "idle", ); + const [ollamaUrlMessage, setOllamaUrlMessage] = useState(null); const triggerRef = useRef(null); const dropdownRef = useRef(null); @@ -403,22 +409,19 @@ export function AIModelSelector({ }; const handleSaveOllamaUrl = async (url: string) => { - const trimmed = url.replace(/\/+$/, "") || "http://localhost:11434"; setOllamaUrlStatus("checking"); + setOllamaUrlMessage(null); try { - const response = await fetch(`${trimmed}/api/tags`, { signal: AbortSignal.timeout(3000) }); - if (response.ok) { - setOllamaUrlStatus("ok"); - updateSetting("ollamaBaseUrl", trimmed); - setOllamaBaseUrl(trimmed); - setOllamaUrlInput(trimmed); - fetchDynamicModels(); - setTimeout(() => cancelEditing(), 1000); - } else { - setOllamaUrlStatus("error"); - } - } catch { + const { normalizedUrl } = await probeOllamaEndpoint(url); + setOllamaUrlStatus("ok"); + updateSetting("ollamaBaseUrl", normalizedUrl); + setOllamaBaseUrl(normalizedUrl); + setOllamaUrlInput(normalizedUrl); + await fetchDynamicModels(); + setTimeout(() => cancelEditing(), 1000); + } catch (error) { setOllamaUrlStatus("error"); + setOllamaUrlMessage(getOllamaProbeErrorMessage(error)); } }; @@ -611,6 +614,7 @@ export function AIModelSelector({ onChange={(e) => { setOllamaUrlInput(e.target.value); setOllamaUrlStatus("idle"); + setOllamaUrlMessage(null); }} onKeyDown={(e) => { if (e.key === "Enter") { @@ -657,7 +661,10 @@ export function AIModelSelector({ {ollamaUrlStatus === "error" && (
- Could not connect to Ollama at this URL + + {ollamaUrlMessage || + "Could not connect to Ollama at this URL"} +
)} diff --git a/src/features/settings/components/tabs/ai-settings.tsx b/src/features/settings/components/tabs/ai-settings.tsx index 2add3b8d..eef9d13a 100644 --- a/src/features/settings/components/tabs/ai-settings.tsx +++ b/src/features/settings/components/tabs/ai-settings.tsx @@ -14,9 +14,13 @@ import Section, { SettingRow } from "@/ui/section"; import Switch from "@/ui/switch"; import { fetchAutocompleteModels } from "@/utils/autocomplete"; import { cn } from "@/utils/cn"; +import { + DEFAULT_OLLAMA_BASE_URL, + getOllamaProbeErrorMessage, + probeOllamaEndpoint, +} from "@/utils/ollama"; import { setOllamaBaseUrl } from "@/utils/providers"; -const DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434"; const DEFAULT_AUTOCOMPLETE_MODEL_ID = "mistralai/devstral-small"; const DEFAULT_AUTOCOMPLETE_MODELS = [ @@ -51,6 +55,7 @@ export const AISettings = () => { // Ollama URL state const [ollamaUrl, setOllamaUrl] = useState(settings.ollamaBaseUrl || DEFAULT_OLLAMA_BASE_URL); const [ollamaStatus, setOllamaStatus] = useState<"idle" | "checking" | "ok" | "error">("idle"); + const [ollamaStatusMessage, setOllamaStatusMessage] = useState(null); const ollamaDebounceRef = useRef>(undefined); useEffect(() => { @@ -81,19 +86,26 @@ export const AISettings = () => { const checkOllamaConnection = useCallback(async (url: string) => { setOllamaStatus("checking"); + setOllamaStatusMessage(null); try { - const response = await fetch(`${url}/api/tags`, { - signal: AbortSignal.timeout(3000), - }); - setOllamaStatus(response.ok ? "ok" : "error"); - } catch { + await probeOllamaEndpoint(url); + setOllamaStatus("ok"); + } catch (error) { setOllamaStatus("error"); + setOllamaStatusMessage(getOllamaProbeErrorMessage(error)); } }, []); + useEffect(() => { + return () => { + if (ollamaDebounceRef.current) clearTimeout(ollamaDebounceRef.current); + }; + }, []); + const handleOllamaUrlChange = (value: string) => { setOllamaUrl(value); setOllamaStatus("idle"); + setOllamaStatusMessage(null); if (ollamaDebounceRef.current) clearTimeout(ollamaDebounceRef.current); ollamaDebounceRef.current = setTimeout(() => { @@ -208,7 +220,7 @@ export const AISettings = () => { {ollamaStatus === "error" && (
- Could not connect. Check that Ollama is running at this address. + {ollamaStatusMessage || "Could not connect. Check the Ollama endpoint."}
)} diff --git a/src/utils/ollama.test.ts b/src/utils/ollama.test.ts new file mode 100644 index 00000000..5f7b4c21 --- /dev/null +++ b/src/utils/ollama.test.ts @@ -0,0 +1,50 @@ +import { describe, expect, it } from "bun:test"; +import { __test__, DEFAULT_OLLAMA_BASE_URL, getOllamaProbeErrorMessage } from "./ollama"; + +describe("ollama helpers", () => { + it("normalizes empty and trailing-slash URLs", () => { + expect(__test__.normalizeOllamaBaseUrl("")).toBe(DEFAULT_OLLAMA_BASE_URL); + expect(__test__.normalizeOllamaBaseUrl(" http://localhost:11434/// ")).toBe( + DEFAULT_OLLAMA_BASE_URL, + ); + expect(__test__.normalizeOllamaBaseUrl("https://ollama.example.com/base/")).toBe( + "https://ollama.example.com/base", + ); + }); + + it("parses a valid probe response", () => { + expect( + __test__.parseOllamaProbeResponse({ + normalizedUrl: "http://localhost:11434", + models: [{ id: "llama3.2", name: "llama3.2", maxTokens: 8192 }], + }), + ).toEqual({ + normalizedUrl: "http://localhost:11434", + models: [{ id: "llama3.2", name: "llama3.2", maxTokens: 8192 }], + }); + }); + + it("filters malformed models out of the probe response", () => { + expect( + __test__.parseOllamaProbeResponse({ + normalizedUrl: "http://localhost:11434", + models: [{ id: "ok", name: "ok" }, { id: 42 }, null], + }), + ).toEqual({ + normalizedUrl: "http://localhost:11434", + models: [{ id: "ok", name: "ok", maxTokens: 4096 }], + }); + }); + + it("maps low-level probe failures to UI-friendly messages", () => { + expect(getOllamaProbeErrorMessage(new Error("Invalid Ollama URL"))).toBe( + "Enter a valid http:// or https:// Ollama URL.", + ); + expect(getOllamaProbeErrorMessage("Ollama endpoint returned HTTP 404")).toBe( + "Ollama endpoint returned HTTP 404", + ); + expect(getOllamaProbeErrorMessage(new Error("socket hang up"))).toBe( + "Could not connect to Ollama at this URL", + ); + }); +}); diff --git a/src/utils/ollama.ts b/src/utils/ollama.ts new file mode 100644 index 00000000..9a9a77b3 --- /dev/null +++ b/src/utils/ollama.ts @@ -0,0 +1,98 @@ +import { invoke } from "@tauri-apps/api/core"; +import type { ProviderModel } from "@/utils/providers/provider-interface"; + +export const DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434"; + +interface OllamaProbeResponse { + normalizedUrl: string; + models: ProviderModel[]; +} + +function normalizeOllamaBaseUrl(baseUrl: string): string { + const normalized = baseUrl.trim().replace(/\/+$/, ""); + return normalized || DEFAULT_OLLAMA_BASE_URL; +} + +function parseProviderModel(value: unknown): ProviderModel | null { + if (!value || typeof value !== "object") return null; + + const candidate = value as { + id?: unknown; + name?: unknown; + maxTokens?: unknown; + }; + + if (typeof candidate.id !== "string" || typeof candidate.name !== "string") { + return null; + } + + return { + id: candidate.id, + name: candidate.name, + maxTokens: typeof candidate.maxTokens === "number" ? candidate.maxTokens : 4096, + }; +} + +function parseOllamaProbeResponse(value: unknown): OllamaProbeResponse | null { + if (!value || typeof value !== "object") return null; + + const candidate = value as { + normalizedUrl?: unknown; + models?: unknown; + }; + + if (typeof candidate.normalizedUrl !== "string" || !Array.isArray(candidate.models)) { + return null; + } + + return { + normalizedUrl: candidate.normalizedUrl, + models: candidate.models + .map((model) => parseProviderModel(model)) + .filter((model): model is ProviderModel => Boolean(model)), + }; +} + +export async function probeOllamaEndpoint(baseUrl: string): Promise { + const normalizedUrl = normalizeOllamaBaseUrl(baseUrl); + const response = await invoke("probe_ollama_endpoint", { + baseUrl: normalizedUrl, + }); + + const parsed = parseOllamaProbeResponse(response); + if (!parsed) { + throw new Error("Invalid Ollama probe response"); + } + + return parsed; +} + +export async function listOllamaModels(baseUrl: string): Promise { + const response = await probeOllamaEndpoint(baseUrl); + return response.models; +} + +export function getOllamaProbeErrorMessage(error: unknown): string { + const message = + typeof error === "string" + ? error + : error instanceof Error + ? error.message + : "Could not connect to Ollama at this URL"; + + if (message === "Invalid Ollama URL") { + return "Enter a valid http:// or https:// Ollama URL."; + } + + if (message.startsWith("Ollama endpoint returned HTTP ")) { + return message; + } + + return "Could not connect to Ollama at this URL"; +} + +export const __test__ = { + normalizeOllamaBaseUrl, + parseOllamaProbeResponse, + parseProviderModel, +}; diff --git a/src/utils/providers/ollama-provider.ts b/src/utils/providers/ollama-provider.ts index e6484ee9..10f26e09 100644 --- a/src/utils/providers/ollama-provider.ts +++ b/src/utils/providers/ollama-provider.ts @@ -1,8 +1,7 @@ +import { DEFAULT_OLLAMA_BASE_URL, listOllamaModels } from "@/utils/ollama"; import type { ProviderModel } from "./provider-interface"; import { AIProvider, type ProviderHeaders, type StreamRequest } from "./provider-interface"; -const DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434"; - export class OllamaProvider extends AIProvider { private baseUrl: string = DEFAULT_OLLAMA_BASE_URL; @@ -40,17 +39,7 @@ export class OllamaProvider extends AIProvider { async getModels(): Promise { try { - const response = await fetch(`${this.baseUrl}/api/tags`, { - signal: AbortSignal.timeout(3000), - }); - if (!response.ok) return []; - - const data = await response.json(); - return data.models.map((model: { name: string }) => ({ - id: model.name, - name: model.name, - maxTokens: 4096, - })); + return await listOllamaModels(this.baseUrl); } catch { return []; } From c11f4b11bc85b39dfbd77f6d05cf07e435c61b95 Mon Sep 17 00:00:00 2001 From: Finesssee <90105158+Finesssee@users.noreply.github.com> Date: Wed, 11 Mar 2026 03:24:48 +0700 Subject: [PATCH 2/2] Fix Rust formatting for Ollama probe tests --- src-tauri/src/commands/ai/ollama.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src-tauri/src/commands/ai/ollama.rs b/src-tauri/src/commands/ai/ollama.rs index 5c5ff644..504dcbf3 100644 --- a/src-tauri/src/commands/ai/ollama.rs +++ b/src-tauri/src/commands/ai/ollama.rs @@ -103,15 +103,14 @@ pub async fn probe_ollama_endpoint(base_url: String) -> Result