Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src-tauri/src/commands/ai/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ pub mod acp;
pub mod auth;
pub mod chat_history;
pub mod claude;
pub mod ollama;
pub mod tokens;

pub use acp::*;
pub use auth::*;
pub use chat_history::*;
pub use claude::*;
pub use ollama::*;
pub use tokens::*;
173 changes: 173 additions & 0 deletions src-tauri/src/commands/ai/ollama.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tauri::command;
use url::Url;

const DEFAULT_OLLAMA_BASE_URL: &str = "http://localhost:11434";

/// Result of a successful probe of an Ollama server, sent back to the frontend.
/// Serialized with camelCase keys to match the TypeScript `parseOllamaProbeResponse` shape.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct OllamaProbeResponse {
    // Base URL after trimming/defaulting, echoed back so the UI can persist it.
    normalized_url: String,
    // Models reported by the server's `/api/tags` endpoint.
    models: Vec<OllamaModel>,
}

/// A single model advertised by an Ollama server.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct OllamaModel {
    // Same as `name` — the tags endpoint only exposes a name, which doubles as the id.
    id: String,
    name: String,
    // Context-window size; the tags endpoint does not report it, so a default is used.
    max_tokens: usize,
}

/// Wire shape of Ollama's `GET /api/tags` response (only the fields we need).
#[derive(Debug, Deserialize)]
struct OllamaTagsResponse {
    // `default` tolerates a response with no `models` key, yielding an empty list.
    #[serde(default)]
    models: Vec<OllamaTagModel>,
}

/// One entry in the `/api/tags` model list; other fields in the payload are ignored.
#[derive(Debug, Deserialize)]
struct OllamaTagModel {
    name: String,
}

/// Canonicalizes a user-supplied Ollama base URL.
///
/// Surrounding whitespace and trailing `/` characters are stripped; an input
/// that is empty after cleanup falls back to [`DEFAULT_OLLAMA_BASE_URL`].
fn normalize_ollama_base_url(input: &str) -> String {
    let cleaned = input.trim().trim_end_matches('/');
    match cleaned {
        "" => DEFAULT_OLLAMA_BASE_URL.to_string(),
        url => url.to_string(),
    }
}

/// Checks that `input` is an absolute `http`/`https` URL with a host.
///
/// Any failure — unparseable URL, unsupported scheme, missing host — maps to
/// the same "Invalid Ollama URL" message so the UI shows one consistent error.
fn validate_ollama_base_url(input: &str) -> Result<(), String> {
    let invalid = || "Invalid Ollama URL".to_string();

    let parsed = Url::parse(input).map_err(|_| invalid())?;
    let scheme_ok = matches!(parsed.scheme(), "http" | "https");
    let has_host = parsed.host_str().is_some();

    if scheme_ok && has_host {
        Ok(())
    } else {
        Err(invalid())
    }
}

/// Tauri command: probes the Ollama server at `base_url` and lists its models.
///
/// The URL is normalized (whitespace/trailing slashes stripped, default applied
/// when empty) and validated before issuing `GET <base>/api/tags` with a short
/// timeout. On success returns the normalized URL plus the advertised models;
/// every failure mode is mapped to a user-facing `String` error.
#[command]
pub async fn probe_ollama_endpoint(base_url: String) -> Result<OllamaProbeResponse, String> {
    let normalized_url = normalize_ollama_base_url(&base_url);
    validate_ollama_base_url(&normalized_url)?;

    // A 3-second timeout keeps the UI responsive when the endpoint is unreachable.
    let http = Client::builder()
        .timeout(Duration::from_secs(3))
        .build()
        .map_err(|error| format!("Failed to create Ollama client: {}", error))?;

    let endpoint = format!("{}/api/tags", normalized_url);
    let response = http
        .get(&endpoint)
        .send()
        .await
        .map_err(|error| format!("Failed to connect to Ollama: {}", error))?;

    let status = response.status();
    if !status.is_success() {
        return Err(format!("Ollama endpoint returned HTTP {}", status));
    }

    let tags: OllamaTagsResponse = response
        .json()
        .await
        .map_err(|error| format!("Failed to read Ollama response: {}", error))?;

    // The tags endpoint only reports names, so the name doubles as the id;
    // the context window is not reported, so a conservative default is assumed.
    let models = tags
        .models
        .into_iter()
        .map(|tag| OllamaModel {
            id: tag.name.clone(),
            name: tag.name,
            max_tokens: 4096,
        })
        .collect();

    Ok(OllamaProbeResponse {
        normalized_url,
        models,
    })
}

#[cfg(test)]
mod tests {
    use super::{DEFAULT_OLLAMA_BASE_URL, normalize_ollama_base_url, validate_ollama_base_url};
    use crate::commands::ai::ollama::probe_ollama_endpoint;
    use std::{
        io::{Read, Write},
        net::TcpListener,
        thread,
    };

    // Empty (or whitespace-only) input falls back to the default local endpoint.
    #[test]
    fn normalizes_empty_base_url_to_default() {
        assert_eq!(normalize_ollama_base_url(""), DEFAULT_OLLAMA_BASE_URL);
        assert_eq!(normalize_ollama_base_url(" "), DEFAULT_OLLAMA_BASE_URL);
    }

    // Any number of trailing slashes is stripped during normalization.
    #[test]
    fn trims_trailing_slashes() {
        assert_eq!(
            normalize_ollama_base_url("http://localhost:11434///"),
            DEFAULT_OLLAMA_BASE_URL
        );
    }

    // Both supported schemes validate, including URLs that carry a base path.
    #[test]
    fn accepts_http_and_https_urls() {
        assert!(validate_ollama_base_url("http://localhost:11434").is_ok());
        assert!(validate_ollama_base_url("https://ollama.example.com/base").is_ok());
    }

    // Missing scheme, unsupported scheme, and missing host are all rejected.
    #[test]
    fn rejects_invalid_urls() {
        assert!(validate_ollama_base_url("localhost:11434").is_err());
        assert!(validate_ollama_base_url("ftp://localhost:11434").is_err());
        assert!(validate_ollama_base_url("http://").is_err());
    }

    // End-to-end probe against a one-shot fake HTTP server on an OS-assigned
    // port, verifying both URL normalization and model parsing.
    #[tokio::test]
    async fn probes_custom_port_and_returns_models() {
        // Bind to port 0 so the OS picks a free port; the probe targets it below.
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
        let address = listener.local_addr().expect("read local addr");

        // Serve exactly one request on a background thread: read (and discard)
        // the request bytes, then write a canned /api/tags JSON response.
        thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept request");
            let mut buffer = [0_u8; 1024];
            let _ = stream.read(&mut buffer);

            let body = r#"{"models":[{"name":"llama3.2"}]}"#;
            // `connection: close` + content-length let the client finish cleanly.
            let response = format!(
                "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: \
                {}\r\nconnection: close\r\n\r\n{}",
                body.len(),
                body
            );

            stream
                .write_all(response.as_bytes())
                .expect("write response");
        });

        let result = probe_ollama_endpoint(format!("http://{}", address))
            .await
            .expect("probe endpoint");

        assert_eq!(result.normalized_url, format!("http://{}", address));
        assert_eq!(result.models.len(), 1);
        assert_eq!(result.models[0].id, "llama3.2");
        assert_eq!(result.models[0].name, "llama3.2");
    }
}
1 change: 1 addition & 0 deletions src-tauri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ fn main() {
store_ai_provider_token,
get_ai_provider_token,
remove_ai_provider_token,
probe_ollama_endpoint,
// Auth token commands
store_auth_token,
get_auth_token,
Expand Down
37 changes: 22 additions & 15 deletions src/features/ai/components/selectors/ai-model-selector.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ import {
} from "@/features/ai/types/providers";
import { useSettingsStore } from "@/features/settings/store";
import { cn } from "@/utils/cn";
import {
DEFAULT_OLLAMA_BASE_URL,
getOllamaProbeErrorMessage,
probeOllamaEndpoint,
} from "@/utils/ollama";
import { getProvider, setOllamaBaseUrl } from "@/utils/providers";

interface AIModelSelectorProps {
Expand Down Expand Up @@ -82,11 +87,12 @@ export function AIModelSelector({
const { settings, updateSetting } = useSettingsStore();

const [ollamaUrlInput, setOllamaUrlInput] = useState(
settings.ollamaBaseUrl || "http://localhost:11434",
settings.ollamaBaseUrl || DEFAULT_OLLAMA_BASE_URL,
);
const [ollamaUrlStatus, setOllamaUrlStatus] = useState<"idle" | "checking" | "ok" | "error">(
"idle",
);
const [ollamaUrlMessage, setOllamaUrlMessage] = useState<string | null>(null);

const triggerRef = useRef<HTMLButtonElement>(null);
const dropdownRef = useRef<HTMLDivElement>(null);
Expand Down Expand Up @@ -403,22 +409,19 @@ export function AIModelSelector({
};

const handleSaveOllamaUrl = async (url: string) => {
const trimmed = url.replace(/\/+$/, "") || "http://localhost:11434";
setOllamaUrlStatus("checking");
setOllamaUrlMessage(null);
try {
const response = await fetch(`${trimmed}/api/tags`, { signal: AbortSignal.timeout(3000) });
if (response.ok) {
setOllamaUrlStatus("ok");
updateSetting("ollamaBaseUrl", trimmed);
setOllamaBaseUrl(trimmed);
setOllamaUrlInput(trimmed);
fetchDynamicModels();
setTimeout(() => cancelEditing(), 1000);
} else {
setOllamaUrlStatus("error");
}
} catch {
const { normalizedUrl } = await probeOllamaEndpoint(url);
setOllamaUrlStatus("ok");
updateSetting("ollamaBaseUrl", normalizedUrl);
setOllamaBaseUrl(normalizedUrl);
setOllamaUrlInput(normalizedUrl);
await fetchDynamicModels();
setTimeout(() => cancelEditing(), 1000);
} catch (error) {
setOllamaUrlStatus("error");
setOllamaUrlMessage(getOllamaProbeErrorMessage(error));
}
};

Expand Down Expand Up @@ -611,6 +614,7 @@ export function AIModelSelector({
onChange={(e) => {
setOllamaUrlInput(e.target.value);
setOllamaUrlStatus("idle");
setOllamaUrlMessage(null);
}}
onKeyDown={(e) => {
if (e.key === "Enter") {
Expand Down Expand Up @@ -657,7 +661,10 @@ export function AIModelSelector({
</div>
{ollamaUrlStatus === "error" && (
<div className="flex items-center gap-1 px-3 pb-2 text-[10px] text-red-400">
<span>Could not connect to Ollama at this URL</span>
<span>
{ollamaUrlMessage ||
"Could not connect to Ollama at this URL"}
</span>
</div>
)}
</>
Expand Down
26 changes: 19 additions & 7 deletions src/features/settings/components/tabs/ai-settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@ import Section, { SettingRow } from "@/ui/section";
import Switch from "@/ui/switch";
import { fetchAutocompleteModels } from "@/utils/autocomplete";
import { cn } from "@/utils/cn";
import {
DEFAULT_OLLAMA_BASE_URL,
getOllamaProbeErrorMessage,
probeOllamaEndpoint,
} from "@/utils/ollama";
import { setOllamaBaseUrl } from "@/utils/providers";

const DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434";
const DEFAULT_AUTOCOMPLETE_MODEL_ID = "mistralai/devstral-small";

const DEFAULT_AUTOCOMPLETE_MODELS = [
Expand Down Expand Up @@ -51,6 +55,7 @@ export const AISettings = () => {
// Ollama URL state
const [ollamaUrl, setOllamaUrl] = useState(settings.ollamaBaseUrl || DEFAULT_OLLAMA_BASE_URL);
const [ollamaStatus, setOllamaStatus] = useState<"idle" | "checking" | "ok" | "error">("idle");
const [ollamaStatusMessage, setOllamaStatusMessage] = useState<string | null>(null);
const ollamaDebounceRef = useRef<ReturnType<typeof setTimeout>>(undefined);

useEffect(() => {
Expand Down Expand Up @@ -81,19 +86,26 @@ export const AISettings = () => {

const checkOllamaConnection = useCallback(async (url: string) => {
setOllamaStatus("checking");
setOllamaStatusMessage(null);
try {
const response = await fetch(`${url}/api/tags`, {
signal: AbortSignal.timeout(3000),
});
setOllamaStatus(response.ok ? "ok" : "error");
} catch {
await probeOllamaEndpoint(url);
setOllamaStatus("ok");
} catch (error) {
setOllamaStatus("error");
setOllamaStatusMessage(getOllamaProbeErrorMessage(error));
}
}, []);

useEffect(() => {
return () => {
if (ollamaDebounceRef.current) clearTimeout(ollamaDebounceRef.current);
};
}, []);

const handleOllamaUrlChange = (value: string) => {
setOllamaUrl(value);
setOllamaStatus("idle");
setOllamaStatusMessage(null);

if (ollamaDebounceRef.current) clearTimeout(ollamaDebounceRef.current);
ollamaDebounceRef.current = setTimeout(() => {
Expand Down Expand Up @@ -208,7 +220,7 @@ export const AISettings = () => {
{ollamaStatus === "error" && (
<div className="flex items-center gap-1.5 px-1 text-red-400 text-xs">
<AlertCircle size={11} className="shrink-0" />
<span>Could not connect. Check that Ollama is running at this address.</span>
<span>{ollamaStatusMessage || "Could not connect. Check the Ollama endpoint."}</span>
</div>
)}
</Section>
Expand Down
50 changes: 50 additions & 0 deletions src/utils/ollama.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { describe, expect, it } from "bun:test";
import { __test__, DEFAULT_OLLAMA_BASE_URL, getOllamaProbeErrorMessage } from "./ollama";

// Unit tests for the frontend Ollama helpers; internals are exposed via `__test__`.
describe("ollama helpers", () => {
  // Whitespace/trailing-slash cleanup must mirror the Rust normalizer.
  it("normalizes empty and trailing-slash URLs", () => {
    expect(__test__.normalizeOllamaBaseUrl("")).toBe(DEFAULT_OLLAMA_BASE_URL);
    expect(__test__.normalizeOllamaBaseUrl(" http://localhost:11434/// ")).toBe(
      DEFAULT_OLLAMA_BASE_URL,
    );
    expect(__test__.normalizeOllamaBaseUrl("https://ollama.example.com/base/")).toBe(
      "https://ollama.example.com/base",
    );
  });

  // A well-formed backend payload passes through unchanged.
  it("parses a valid probe response", () => {
    expect(
      __test__.parseOllamaProbeResponse({
        normalizedUrl: "http://localhost:11434",
        models: [{ id: "llama3.2", name: "llama3.2", maxTokens: 8192 }],
      }),
    ).toEqual({
      normalizedUrl: "http://localhost:11434",
      models: [{ id: "llama3.2", name: "llama3.2", maxTokens: 8192 }],
    });
  });

  // Entries with wrong types or missing fields are dropped; a missing
  // maxTokens falls back to the 4096 default.
  it("filters malformed models out of the probe response", () => {
    expect(
      __test__.parseOllamaProbeResponse({
        normalizedUrl: "http://localhost:11434",
        models: [{ id: "ok", name: "ok" }, { id: 42 }, null],
      }),
    ).toEqual({
      normalizedUrl: "http://localhost:11434",
      models: [{ id: "ok", name: "ok", maxTokens: 4096 }],
    });
  });

  // Backend error strings are rewritten into messages suitable for the UI:
  // known validation/HTTP errors pass through or get friendlier wording;
  // anything unrecognized collapses to a generic connection failure.
  it("maps low-level probe failures to UI-friendly messages", () => {
    expect(getOllamaProbeErrorMessage(new Error("Invalid Ollama URL"))).toBe(
      "Enter a valid http:// or https:// Ollama URL.",
    );
    expect(getOllamaProbeErrorMessage("Ollama endpoint returned HTTP 404")).toBe(
      "Ollama endpoint returned HTTP 404",
    );
    expect(getOllamaProbeErrorMessage(new Error("socket hang up"))).toBe(
      "Could not connect to Ollama at this URL",
    );
  });
});
Loading