diff --git a/src-tauri/src/services/chat_service.rs b/src-tauri/src/services/chat_service.rs index 09b34b6..9c427d1 100644 --- a/src-tauri/src/services/chat_service.rs +++ b/src-tauri/src/services/chat_service.rs @@ -33,6 +33,42 @@ fn needs_visual_guide(content: &str) -> bool { VISUAL_KEYWORDS.iter().any(|kw| lower.contains(kw)) } +#[derive(Debug, Clone, serde::Serialize)] +#[serde(rename_all = "camelCase")] +pub struct TokenUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Debug, Clone, serde::Serialize)] +#[serde(rename_all = "camelCase")] +struct ChatUsageEvent { + session_id: String, + usage: TokenUsage, +} + +#[derive(Debug, Default)] +struct UsageAccumulator { + prompt_tokens: u32, + completion_tokens: u32, +} + +impl UsageAccumulator { + fn finish(self) -> Option { + let total = self.prompt_tokens + self.completion_tokens; + if total == 0 { + None + } else { + Some(TokenUsage { + prompt_tokens: self.prompt_tokens, + completion_tokens: self.completion_tokens, + total_tokens: total, + }) + } + } +} + pub async fn get_messages(db: &SqlitePool, session_id: &str) -> AppResult> { let messages = sqlx::query_as::<_, Message>( "SELECT id, session_id, role, content, created_at FROM messages \ @@ -280,7 +316,7 @@ async fn send_message_inner( if provider.uses_anthropic_format() { "anthropic" } else { "openai" }, ); - let assistant_output = if provider.uses_anthropic_format() { + let (assistant_output, token_usage) = if provider.uses_anthropic_format() { send_anthropic( history, model, @@ -292,10 +328,12 @@ async fn send_message_inner( ) .await? } else { + let supports_stream_usage = provider.provider_type == "openai"; send_openai_compatible( &provider.base_url, model, provider.api_key.as_deref(), + supports_stream_usage, history, &on_token, &cancel_token, @@ -303,6 +341,16 @@ async fn send_message_inner( .await? }; + if let Some(usage) = token_usage { + let _ = app_handle.emit( + "chat-usage", + ChatUsageEvent { + session_id: session_id.to_string(), + usage, + }, + ); + } + let assistant_message = Message { id: Uuid::new_v4().to_string(), session_id: session_id.to_string(), @@ -330,10 +378,11 @@ async fn send_openai_compatible( base_url: &str, model: &str, api_key: Option<&str>, + include_usage: bool, history: Vec, on_token: &Channel, cancel_token: &CancellationToken, -) -> AppResult { +) -> AppResult<(String, Option)> { let client = http_client::streaming_client()?; let endpoint = format!("{}/chat/completions", base_url.trim_end_matches('/')); @@ -342,7 +391,7 @@ async fn send_openai_compatible( .map(|m| serde_json::json!({ "role": m.role, "content": m.content })) .collect(); - let payload = serde_json::json!({ + let mut payload = serde_json::json!({ "model": model, "messages": messages, "temperature": 0.2, @@ -351,6 +400,10 @@ async fn send_openai_compatible( "stream": true, }); + if include_usage { + payload["stream_options"] = serde_json::json!({ "include_usage": true }); + } + // Lazy system prompt: only inject full preview guide when user asks for visuals let last_user_content = history.iter().rev().find(|m| m.role == "user").map(|m| m.content.as_str()).unwrap_or(""); let system_instructions = if needs_visual_guide(last_user_content) { @@ -412,7 +465,7 @@ async fn send_anthropic( base_url: &str, on_token: &Channel, cancel_token: &CancellationToken, -) -> AppResult { +) -> AppResult<(String, Option)> { let client = http_client::streaming_client()?; let (system_msgs, chat_msgs): (Vec<_>, Vec<_>) = @@ -531,7 +584,7 @@ async fn send_anthropic( return Err(AppError::Http(format!("Anthropic {status}: {body}"))); } - let output = stream_anthropic_sse(response, on_token, cancel_token).await?; + let (output, usage) = stream_anthropic_sse(response, on_token, cancel_token).await?; // Fallback: some gateways return message_start โ†’ message_stop without any // content_block events for certain models. Retry non-streaming. @@ -569,23 +622,24 @@ async fn send_anthropic( .and_then(Value::as_str) { let _ = on_token.send(text.to_string()); - return Ok(text.to_string()); + return Ok((text.to_string(), usage)); } - return Ok(String::new()); + return Ok((String::new(), usage)); } - Ok(output) + Ok((output, usage)) } async fn stream_openai_sse( response: reqwest::Response, on_token: &Channel, cancel_token: &CancellationToken, -) -> AppResult { +) -> AppResult<(String, Option)> { let mut stream = response.bytes_stream(); let mut line_buffer = String::new(); let mut output = String::new(); + let mut usage = UsageAccumulator::default(); loop { tokio::select! { @@ -604,8 +658,8 @@ async fn stream_openai_sse( line.pop(); } - if parse_openai_sse_line(&line, on_token, &mut output)? { - return Ok(output); + if parse_openai_sse_line(&line, on_token, &mut output, &mut usage)? { + return Ok((output, usage.finish())); } } } @@ -617,16 +671,17 @@ async fn stream_openai_sse( } if !line_buffer.is_empty() { - parse_openai_sse_line(&line_buffer, on_token, &mut output)?; + parse_openai_sse_line(&line_buffer, on_token, &mut output, &mut usage)?; } - Ok(output) + Ok((output, usage.finish())) } fn parse_openai_sse_line( line: &str, on_token: &Channel, output: &mut String, + usage: &mut UsageAccumulator, ) -> AppResult { let trimmed = line.trim(); if trimmed.is_empty() { @@ -642,6 +697,16 @@ fn parse_openai_sse_line( } let value: Value = serde_json::from_str(payload)?; + + if let Some(u) = value.get("usage") { + if let Some(pt) = u.get("prompt_tokens").and_then(Value::as_u64) { + usage.prompt_tokens = pt as u32; + } + if let Some(ct) = u.get("completion_tokens").and_then(Value::as_u64) { + usage.completion_tokens = ct as u32; + } + } + if let Some(token) = value .get("choices") .and_then(Value::as_array) @@ -661,7 +726,7 @@ async fn stream_anthropic_sse( response: reqwest::Response, on_token: &Channel, cancel_token: &CancellationToken, -) -> AppResult { +) -> AppResult<(String, Option)> { let mut stream = response.bytes_stream(); let mut line_buffer = String::new(); let mut output = String::new(); @@ -669,6 +734,7 @@ async fn stream_anthropic_sse( // subsequent `data:` line. Some gateways omit the `"type"` field from // the JSON payload, so we fall back to the SSE event name. let mut current_event = String::new(); + let mut usage = UsageAccumulator::default(); let mut message_stop_received = false; 'outer: loop { @@ -688,7 +754,13 @@ async fn stream_anthropic_sse( line.pop(); } - if parse_anthropic_sse_line(&line, &mut current_event, on_token, &mut output)? { + if parse_anthropic_sse_line( + &line, + &mut current_event, + on_token, + &mut output, + &mut usage, + )? { message_stop_received = true; break 'outer; } @@ -707,7 +779,7 @@ async fn stream_anthropic_sse( )); } - Ok(output) + Ok((output, usage.finish())) } /// Parse a single SSE line from an Anthropic-format stream. @@ -720,6 +792,7 @@ fn parse_anthropic_sse_line( current_event: &mut String, on_token: &Channel, output: &mut String, + usage: &mut UsageAccumulator, ) -> AppResult { let trimmed = line.trim(); if trimmed.is_empty() { @@ -752,6 +825,25 @@ fn parse_anthropic_sse_line( .unwrap_or(current_event.as_str()); match event_type { + "message_start" => { + if let Some(pt) = value + .get("message") + .and_then(|m| m.get("usage")) + .and_then(|u| u.get("input_tokens")) + .and_then(Value::as_u64) + { + usage.prompt_tokens = pt as u32; + } + } + "message_delta" => { + if let Some(ct) = value + .get("usage") + .and_then(|u| u.get("output_tokens")) + .and_then(Value::as_u64) + { + usage.completion_tokens = ct as u32; + } + } "content_block_delta" => { if let Some(token) = value .get("delta") diff --git a/src/components/layout/AppShell.tsx b/src/components/layout/AppShell.tsx index eedd79f..d5feb78 100644 --- a/src/components/layout/AppShell.tsx +++ b/src/components/layout/AppShell.tsx @@ -16,11 +16,11 @@ import { useUIStore } from '@/stores/useUIStore'; import { useAgentStore } from '@/stores/useAgentStore'; import { SettingsModal } from '@/components/settings/SettingsModal'; import { ExcalidrawCanvas } from '@/components/canvas/ExcalidrawCanvas'; -import { AgentConfig, AgentRunWithTools, AgentType, Message, PermissionRequest, Project, Provider, ProviderModelConfig, Session, ToolCall } from '@/types'; +import { AgentConfig, AgentRunWithTools, AgentType, ChatUsageEvent, Message, PermissionRequest, Project, Provider, ProviderModelConfig, Session, ToolCall } from '@/types'; import { cn } from '@/lib/utils'; export const AppShell: React.FC = () => { - const { addMessage, appendStreamToken, setStreaming, clearStreaming, setMessages } = useChatStore(); + const { addMessage, appendStreamToken, setStreaming, clearStreaming, setMessages, addTokenUsage } = useChatStore(); const setProjects = useProjectStore((s) => s.setProjects); const addProject = useProjectStore((s) => s.addProject); const setActiveProjectId = useProjectStore((s) => s.setActiveProjectId); @@ -182,6 +182,10 @@ export const AppShell: React.FC = () => { clearStreaming(); }); + const unlistenChatUsage = await listen('chat-usage', (event) => { + addTokenUsage(event.payload.sessionId, event.payload.usage); + }); + const unlistenAgentStarted = await listen<{ agentRunId: string; agentType: string; @@ -326,6 +330,7 @@ export const AppShell: React.FC = () => { localUnlisten.push( unlistenChatDone, unlistenChatError, + unlistenChatUsage, unlistenAgentStarted, unlistenAgentToken, unlistenAgentToolCall, @@ -348,6 +353,7 @@ export const AppShell: React.FC = () => { }; }, [ clearStreaming, + addTokenUsage, addAgentRun, appendAgentToken, flushThinkingBlock, diff --git a/src/components/layout/ChatHeader.tsx b/src/components/layout/ChatHeader.tsx index df8094f..5d4cbf1 100644 --- a/src/components/layout/ChatHeader.tsx +++ b/src/components/layout/ChatHeader.tsx @@ -25,9 +25,14 @@ export const ChatHeader: React.FC = ({ onToggleLeftSidebar }) = const mainView = useUIStore((s) => s.mainView); const setMainView = useUIStore((s) => s.setMainView); const { activeProjectId } = useProjectStore(); + const sessionUsage = useChatStore((s) => s.sessionUsage); const { theme, toggleTheme } = useUIStore(); const activeSession = sessions.find(s => s.id === activeSessionId); + const currentUsage = activeSessionId ? sessionUsage[activeSessionId] : undefined; + + const formatTokens = (n: number) => + n >= 1000 ? `${(n / 1000).toFixed(1)}k` : String(n); const [isRenaming, setIsRenaming] = useState(false); const [renameValue, setRenameValue] = useState(''); @@ -165,6 +170,14 @@ export const ChatHeader: React.FC = ({ onToggleLeftSidebar }) = )} + {currentUsage && ( + + {formatTokens(currentUsage.totalTokens)} tokens + + )} )} diff --git a/src/stores/useChatStore.ts b/src/stores/useChatStore.ts index 2256568..0d29500 100644 --- a/src/stores/useChatStore.ts +++ b/src/stores/useChatStore.ts @@ -1,24 +1,47 @@ import { create } from 'zustand'; -import { Message } from '@/types'; +import { Message, TokenUsage } from '@/types'; interface ChatState { messages: Message[]; streamingText: string; isStreaming: boolean; + sessionUsage: Record; setMessages: (messages: Message[]) => void; addMessage: (message: Message) => void; appendStreamToken: (token: string) => void; setStreaming: (isStreaming: boolean) => void; clearStreaming: () => void; + addTokenUsage: (sessionId: string, usage: TokenUsage) => void; + clearSessionUsage: (sessionId: string) => void; } export const useChatStore = create((set) => ({ messages: [], streamingText: '', isStreaming: false, + sessionUsage: {}, setMessages: (messages) => set({ messages }), addMessage: (message) => set((state) => ({ messages: [...state.messages, message] })), appendStreamToken: (token) => set((state) => ({ streamingText: state.streamingText + token })), setStreaming: (isStreaming) => set({ isStreaming }), clearStreaming: () => set({ streamingText: '', isStreaming: false }), + addTokenUsage: (sessionId, usage) => + set((state) => { + const prev = state.sessionUsage[sessionId]; + return { + sessionUsage: { + ...state.sessionUsage, + [sessionId]: { + promptTokens: (prev?.promptTokens ?? 0) + usage.promptTokens, + completionTokens: (prev?.completionTokens ?? 0) + usage.completionTokens, + totalTokens: (prev?.totalTokens ?? 0) + usage.totalTokens, + }, + }, + }; + }), + clearSessionUsage: (sessionId) => + set((state) => { + const { [sessionId]: _removed, ...rest } = state.sessionUsage; + return { sessionUsage: rest }; + }), })); diff --git a/src/types/index.ts b/src/types/index.ts index 27fcad2..498c51f 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -126,6 +126,17 @@ export interface AgentRunWithTools extends AgentRun { projectPath: string | null; } +export interface TokenUsage { + promptTokens: number; + completionTokens: number; + totalTokens: number; +} + +export interface ChatUsageEvent { + sessionId: string; + usage: TokenUsage; +} + export interface PermissionRequest { type: 'sensitive_file' | 'outside_sandbox' | 'shell_command'; path: string;