diff --git a/docs/README.md b/docs/README.md index 21e3b0c22..719cf70b3 100644 --- a/docs/README.md +++ b/docs/README.md @@ -178,34 +178,51 @@ paths, networking flags, etc.) — only the visual layout. ## Developer notes: Using a custom thv binary (dev only) -During development, you can test the UI with a custom `thv` binary by running it -manually: +The studio talks to its managed `thv` over a UNIX domain socket on macOS/Linux +and a Windows named pipe on Windows. To test the UI with a custom `thv` binary, +run it manually with the same `--socket` flag the studio uses internally and +point the studio at it via `THV_SOCKET`: 1. Start your custom `thv` binary with the serve command: + **macOS / Linux** + ```bash thv serve \ --openapi \ - --host=127.0.0.1 --port=50000 \ + --socket=/tmp/thv-dev.sock \ --experimental-mcp \ --experimental-mcp-host=127.0.0.1 \ --experimental-mcp-port=50001 ``` -2. Set the `THV_PORT` and `THV_MCP_PORT` environment variables and start the dev - server. + **Windows (PowerShell)** + + ```powershell + thv.exe serve ` + --openapi ` + --socket='\\.\pipe\thv-dev' ` + --experimental-mcp ` + --experimental-mcp-host=127.0.0.1 ` + --experimental-mcp-port=50001 + ``` + +2. Set `THV_SOCKET` (and `THV_MCP_PORT` if you also need the experimental MCP + backend) and start the dev server: ```bash - THV_PORT=50000 THV_MCP_PORT=50001 pnpm start + THV_SOCKET=/tmp/thv-dev.sock THV_MCP_PORT=50001 pnpm start ``` -The UI displays a banner with the HTTP address when using a custom port. This -works in development mode only; packaged builds use the embedded binary. + On Windows: + + ```powershell + $env:THV_SOCKET = '\\.\pipe\thv-dev'; $env:THV_MCP_PORT = '50001'; pnpm start + ``` -> Note on MCP Optimizer If you plan to use the MCP Optimizer with an external -> `thv`, ensure `THV_PORT` is within the range `50000-50100`. The app starts its -> embedded server in this range, and the optimizer expects the ToolHive API to -> be reachable there. +The UI displays a banner with the socket / pipe path when `THV_SOCKET` is set. +This works in development mode only; packaged builds use the embedded binary and +an auto-generated per-process socket path. ## Code signing diff --git a/e2e-tests/helpers/app-relaunch.ts b/e2e-tests/helpers/app-relaunch.ts index 23abe0426..142ef60eb 100644 --- a/e2e-tests/helpers/app-relaunch.ts +++ b/e2e-tests/helpers/app-relaunch.ts @@ -1,3 +1,4 @@ +import http from 'node:http' import path from 'path' import { _electron as electron, @@ -29,7 +30,7 @@ function getExecutablePath(): string { export interface LaunchedApp { app: ElectronApplication window: Page - baseUrl: string + socketPath: string /** * Terminate the app without waiting on the renderer's before-quit teardown. * @@ -74,18 +75,20 @@ export async function launchApp(userDataDir: string): Promise { await window.getByRole('link', { name: /mcp servers/i }).waitFor() - const port = await window.evaluate(async () => { + const socketPath = await window.evaluate(async () => { // eslint-disable-next-line @typescript-eslint/no-explicit-any - return (await (globalThis as any).electronAPI.getToolhivePort()) as number + return (await (globalThis as any).electronAPI.getToolhiveSocketPath()) as + | string + | undefined }) - if (!port) { - throw new Error('Failed to resolve ToolHive port from the launched app') + if (!socketPath) { + throw new Error( + 'Failed to resolve ToolHive socket path from the launched app' + ) } - const baseUrl = `http://127.0.0.1:${port}` - - await waitForThvReady(baseUrl) + await waitForThvReady(socketPath) const close = async () => { // Force an immediate exit via Electron's app.exit(), bypassing before-quit @@ -113,35 +116,71 @@ export async function launchApp(userDataDir: string): Promise { } } - return { app, window, baseUrl, close } + return { app, window, socketPath, close } +} + +/** + * Performs an HTTP request against the thv server over its UNIX socket. Used + * by e2e helpers to seed/inspect state out-of-band from the app UI, mirroring + * the transport the production renderer uses (via the IPC bridge). + */ +function socketRequest( + socketPath: string, + apiPath: string, + init?: { method?: string; headers?: Record; body?: string } +): Promise<{ status: number; text: string }> { + return new Promise((resolve, reject) => { + const req = http.request( + { + socketPath, + path: apiPath, + method: init?.method ?? 'GET', + headers: { + 'content-type': 'application/json', + ...(init?.headers ?? {}), + }, + }, + (res) => { + const chunks: Buffer[] = [] + res.on('data', (chunk: Buffer) => chunks.push(chunk)) + res.on('end', () => + resolve({ + status: res.statusCode ?? 500, + text: Buffer.concat(chunks).toString('utf-8'), + }) + ) + res.on('error', reject) + } + ) + req.on('error', reject) + if (init?.body) req.write(init.body) + req.end() + }) } /** - * Thin wrapper around `fetch` that raises on non-2xx/4xx responses the caller - * wants to treat as failures, optionally returning parsed JSON. + * Thin wrapper around the thv UNIX socket transport that raises on unexpected + * statuses and parses JSON when present. */ export async function thvFetch( - baseUrl: string, + socketPath: string, apiPath: string, - init?: RequestInit & { expectStatus?: number[] } + init?: { + method?: string + headers?: Record + body?: string + expectStatus?: number[] + } ): Promise<{ status: number; json: T | null }> { const { expectStatus, ...rest } = init ?? {} - const res = await fetch(`${baseUrl}${apiPath}`, { - ...rest, - headers: { - 'content-type': 'application/json', - ...(rest.headers ?? {}), - }, - }) + const { status, text } = await socketRequest(socketPath, apiPath, rest) - if (expectStatus && !expectStatus.includes(res.status)) { - const body = await res.text() + if (expectStatus && !expectStatus.includes(status)) { throw new Error( - `thvFetch ${apiPath} expected status in [${expectStatus.join(',')}], got ${res.status}: ${body}` + `thvFetch ${apiPath} expected status in [${expectStatus.join(',')}], got ${status}: ${text}` ) } - const text = await res.text() let json: T | null = null if (text) { try { @@ -150,24 +189,24 @@ export async function thvFetch( json = null } } - return { status: res.status, json } + return { status, json } } async function waitForThvReady( - baseUrl: string, + socketPath: string, { timeoutMs = 30_000 } = {} ): Promise { const deadline = Date.now() + timeoutMs while (Date.now() < deadline) { try { - const res = await fetch(`${baseUrl}/api/v1beta/groups`) - if (res.ok) return + const { status } = await socketRequest(socketPath, '/api/v1beta/groups') + if (status >= 200 && status < 300) return } catch { // keep polling } await new Promise((resolve) => setTimeout(resolve, 250)) } throw new Error( - `ToolHive API at ${baseUrl} did not become ready within ${timeoutMs}ms` + `ToolHive API at socket ${socketPath} did not become ready within ${timeoutMs}ms` ) } diff --git a/e2e-tests/mcp-optimizer-startup-cleanup.spec.ts b/e2e-tests/mcp-optimizer-startup-cleanup.spec.ts index 796336429..7bf7504ee 100644 --- a/e2e-tests/mcp-optimizer-startup-cleanup.spec.ts +++ b/e2e-tests/mcp-optimizer-startup-cleanup.spec.ts @@ -57,16 +57,16 @@ async function createGroupViaUi( } async function seedOptimizerState( - baseUrl: string, + socketPath: string, testServer: TestMcpServer ): Promise { - await thvFetch(baseUrl, '/api/v1beta/groups', { + await thvFetch(socketPath, '/api/v1beta/groups', { method: 'POST', body: JSON.stringify({ name: OPTIMIZER_GROUP }), expectStatus: [200, 201], }) - await thvFetch(baseUrl, '/api/v1beta/clients/register', { + await thvFetch(socketPath, '/api/v1beta/clients/register', { method: 'POST', body: JSON.stringify({ names: [TEST_CLIENT], @@ -78,7 +78,7 @@ async function seedOptimizerState( // Create a remote meta-mcp workload so GET /workloads/meta-mcp later returns // the ALLOWED_GROUPS env var that drives the restoration path. A remote // workload avoids any Docker image pull complications. - await thvFetch(baseUrl, '/api/v1beta/workloads', { + await thvFetch(socketPath, '/api/v1beta/workloads', { method: 'POST', body: JSON.stringify({ name: META_MCP_SERVER, @@ -93,13 +93,13 @@ async function seedOptimizerState( }) } -async function waitForOptimizerCleanup(baseUrl: string): Promise { +async function waitForOptimizerCleanup(socketPath: string): Promise { await expect .poll( async () => { const { json } = await thvFetch<{ groups?: Array<{ name?: string; registered_clients?: string[] }> - }>(baseUrl, '/api/v1beta/groups', { expectStatus: [200] }) + }>(socketPath, '/api/v1beta/groups', { expectStatus: [200] }) const groups = json?.groups ?? [] const optimizerGroup = groups.find((g) => g.name === OPTIMIZER_GROUP) const customGroup = groups.find((g) => g.name === CUSTOM_GROUP) @@ -147,12 +147,12 @@ test.describe('MCP Optimizer startup cleanup', () => { const firstLaunch = await launchApp(userDataDir) try { await createGroupViaUi(firstLaunch, CUSTOM_GROUP) - await seedOptimizerState(firstLaunch.baseUrl, testServer) + await seedOptimizerState(firstLaunch.socketPath, testServer) // Sanity: both groups exist and optimizer has the registered client. const { json: seeded } = await thvFetch<{ groups?: Array<{ name?: string; registered_clients?: string[] }> - }>(firstLaunch.baseUrl, '/api/v1beta/groups', { expectStatus: [200] }) + }>(firstLaunch.socketPath, '/api/v1beta/groups', { expectStatus: [200] }) const seededOptimizer = seeded?.groups?.find( (g) => g.name === OPTIMIZER_GROUP ) @@ -166,11 +166,11 @@ test.describe('MCP Optimizer startup cleanup', () => { // startup cleanup hook, which restores clients and deletes the group. const secondLaunch = await launchApp(userDataDir) try { - await waitForOptimizerCleanup(secondLaunch.baseUrl) + await waitForOptimizerCleanup(secondLaunch.socketPath) // The meta-mcp workload is deleted as part of ?with-workloads=true. const { status: workloadStatus } = await thvFetch( - secondLaunch.baseUrl, + secondLaunch.socketPath, `/api/v1beta/workloads/${META_MCP_SERVER}` ) expect(workloadStatus).toBe(404) @@ -178,7 +178,9 @@ test.describe('MCP Optimizer startup cleanup', () => { // The user's custom group is preserved. const { json: finalGroups } = await thvFetch<{ groups?: Array<{ name?: string }> - }>(secondLaunch.baseUrl, '/api/v1beta/groups', { expectStatus: [200] }) + }>(secondLaunch.socketPath, '/api/v1beta/groups', { + expectStatus: [200], + }) expect(finalGroups?.groups?.some((g) => g.name === CUSTOM_GROUP)).toBe( true ) diff --git a/main/src/app-events/block-quit.ts b/main/src/app-events/block-quit.ts index a046701cb..1f6beaff1 100644 --- a/main/src/app-events/block-quit.ts +++ b/main/src/app-events/block-quit.ts @@ -8,8 +8,9 @@ import { recreateMainWindowForShutdown, sendToMainWindowRenderer, } from '../main-window' -import { getToolhivePort, stopToolhive, binPath } from '../toolhive-manager' +import { stopToolhive, binPath } from '../toolhive-manager' import { stopAllServers } from '../graceful-exit' +import { createMainProcessFetch } from '../unix-socket-fetch' import { safeTrayDestroy } from '../system-tray' import { delay } from '../../../utils/delay' import log from '../logger' @@ -39,10 +40,7 @@ export async function blockQuit(source: string, event?: Electron.Event) { } try { - const port = getToolhivePort() - if (port) { - await stopAllServers(binPath, port) - } + await stopAllServers(binPath, { createFetch: createMainProcessFetch }) } catch (err) { log.error('Teardown failed: ', err) } finally { diff --git a/main/src/app-events/process-signals.ts b/main/src/app-events/process-signals.ts index 859208d58..b83e78713 100644 --- a/main/src/app-events/process-signals.ts +++ b/main/src/app-events/process-signals.ts @@ -3,8 +3,9 @@ import { setTearingDownState, setQuittingState, } from '../app-state' -import { getToolhivePort, stopToolhive, binPath } from '../toolhive-manager' +import { stopToolhive, binPath } from '../toolhive-manager' import { stopAllServers } from '../graceful-exit' +import { createMainProcessFetch } from '../unix-socket-fetch' import { safeTrayDestroy } from '../system-tray' import log from '../logger' @@ -17,10 +18,7 @@ export function register() { setQuittingState(true) log.info(`[${sig}] delaying exit for teardown...`) try { - const port = getToolhivePort() - if (port) { - await stopAllServers(binPath, port) - } + await stopAllServers(binPath, { createFetch: createMainProcessFetch }) } finally { stopToolhive() safeTrayDestroy() diff --git a/main/src/app-events/when-ready.ts b/main/src/app-events/when-ready.ts index cf1220462..88591fb53 100644 --- a/main/src/app-events/when-ready.ts +++ b/main/src/app-events/when-ready.ts @@ -10,10 +10,10 @@ import { initTray, safeTrayDestroy } from '../system-tray' import { createApplicationMenu } from '../menu' import { startToolhive, - getToolhivePort, isToolhiveRunning, stopToolhive, } from '../toolhive-manager' +import { registerApiFetchHandlers } from '../unix-socket-fetch' import { getMainWindow, createMainWindow, hideMainWindow } from '../main-window' import { extractDeepLinkFromArgs, handleDeepLink } from '../deep-links' import { getCspString } from '../csp' @@ -71,6 +71,9 @@ export function register() { // Start ToolHive with tray reference await startToolhive() + // Register IPC handlers for renderer -> main -> thv API bridge + registerApiFetchHandlers() + // Create main window try { const mainWindow = await createMainWindow() @@ -128,20 +131,15 @@ export function register() { } } - // Setup CSP headers session.defaultSession.webRequest.onHeadersReceived((details, callback) => { if (process.env.NODE_ENV === 'development') { return callback({ responseHeaders: details.responseHeaders }) } - const port = getToolhivePort() - if (port == null) { - throw new Error('[content-security-policy] ToolHive port is not set') - } return callback({ responseHeaders: { ...details.responseHeaders, 'Content-Security-Policy': [ - getCspString(port, import.meta.env.VITE_SENTRY_DSN), + getCspString(import.meta.env.VITE_SENTRY_DSN), ], }, }) diff --git a/main/src/auto-update.ts b/main/src/auto-update.ts index 56aac4042..ad7eddfca 100644 --- a/main/src/auto-update.ts +++ b/main/src/auto-update.ts @@ -2,12 +2,8 @@ import { app, autoUpdater, dialog, ipcMain, type BrowserWindow } from 'electron' import { updateElectronApp, UpdateSourceType } from 'update-electron-app' import * as Sentry from '@sentry/electron/main' import { stopAllServers } from './graceful-exit' -import { - stopToolhive, - getToolhivePort, - binPath, - isToolhiveRunning, -} from './toolhive-manager' +import { stopToolhive, binPath, isToolhiveRunning } from './toolhive-manager' +import { createMainProcessFetch } from './unix-socket-fetch' import { safeTrayDestroy } from './system-tray' import { getAppVersion, pollWindowReady } from './util' import { delay } from '../../utils/delay' @@ -36,14 +32,7 @@ let updateState: UpdateState = 'none' async function safeServerShutdown(): Promise { try { - const port = getToolhivePort() - if (!port) { - log.info('[update] No ToolHive port available, skipping server shutdown') - return true - } - - await stopAllServers(binPath, port) - + await stopAllServers(binPath, { createFetch: createMainProcessFetch }) log.info('[update] All servers stopped successfully') return true } catch (error) { diff --git a/main/src/chat/__tests__/mcp-tools.test.ts b/main/src/chat/__tests__/mcp-tools.test.ts index 3b08b30ae..938267bf6 100644 --- a/main/src/chat/__tests__/mcp-tools.test.ts +++ b/main/src/chat/__tests__/mcp-tools.test.ts @@ -4,10 +4,10 @@ import { describe, it, expect, vi, beforeEach } from 'vitest' // Hoisted mock factories — must be defined before vi.mock() calls // --------------------------------------------------------------------------- -const mockGetToolhivePort = vi.hoisted(() => vi.fn().mockReturnValue(3000)) const mockGetToolhiveMcpPort = vi.hoisted(() => vi.fn().mockReturnValue(3001)) -const mockGetHeaders = vi.hoisted(() => vi.fn().mockReturnValue({})) -const mockCreateApiClient = vi.hoisted(() => vi.fn().mockReturnValue({})) +const mockCreateMainProcessApiClient = vi.hoisted(() => + vi.fn().mockReturnValue({}) +) const mockGetApiV1BetaWorkloads = vi.hoisted(() => vi.fn().mockResolvedValue({ data: { workloads: [] } }) ) @@ -49,16 +49,11 @@ const mockReplaceAllMcpAppUiMetadata = vi.hoisted(() => vi.fn()) // --------------------------------------------------------------------------- vi.mock('../../toolhive-manager', () => ({ - getToolhivePort: mockGetToolhivePort, getToolhiveMcpPort: mockGetToolhiveMcpPort, })) -vi.mock('../../headers', () => ({ - getHeaders: mockGetHeaders, -})) - -vi.mock('@common/api/generated/client', () => ({ - createClient: mockCreateApiClient, +vi.mock('../../unix-socket-fetch', () => ({ + createMainProcessApiClient: mockCreateMainProcessApiClient, })) vi.mock('@common/api/generated/sdk.gen', () => ({ @@ -168,13 +163,11 @@ const makeToolDef = (overrides: Record = {}) => ({ beforeEach(() => { vi.clearAllMocks() - // Ports - mockGetToolhivePort.mockReturnValue(3000) + // MCP backend port (still TCP) mockGetToolhiveMcpPort.mockReturnValue(3001) // API client chain (fetchWorkloads) - mockGetHeaders.mockReturnValue({}) - mockCreateApiClient.mockReturnValue({}) + mockCreateMainProcessApiClient.mockReturnValue({}) mockGetApiV1BetaWorkloads.mockResolvedValue({ data: { workloads: [] } }) // Settings diff --git a/main/src/chat/agents/builtin-agent-tools/__tests__/skills.test.ts b/main/src/chat/agents/builtin-agent-tools/__tests__/skills.test.ts index d476671cf..3f4b1e4ce 100644 --- a/main/src/chat/agents/builtin-agent-tools/__tests__/skills.test.ts +++ b/main/src/chat/agents/builtin-agent-tools/__tests__/skills.test.ts @@ -6,11 +6,10 @@ import { Buffer } from 'node:buffer' import { nanoid } from 'nanoid' const mockGetApiV1BetaSkills = vi.hoisted(() => vi.fn()) -const mockCreateClient = vi.hoisted(() => +const mockCreateMainProcessApiClient = vi.hoisted(() => vi.fn(() => ({}) as { __fake__: true }) ) -const mockGetToolhivePort = vi.hoisted(() => vi.fn()) -const mockGetHeaders = vi.hoisted(() => vi.fn(() => ({}))) +const mockHasToolhiveConnection = vi.hoisted(() => vi.fn(() => true)) const mockLog = vi.hoisted(() => ({ debug: vi.fn(), info: vi.fn(), @@ -31,14 +30,9 @@ vi.mock('@common/api/generated/sdk.gen', () => ({ postApiV1BetaSkillsBuild: mockPostApiV1BetaSkillsBuild, getApiV1BetaSkillsBuilds: mockGetApiV1BetaSkillsBuilds, })) -vi.mock('@common/api/generated/client', () => ({ - createClient: mockCreateClient, -})) -vi.mock('../../../../toolhive-manager', () => ({ - getToolhivePort: mockGetToolhivePort, -})) -vi.mock('../../../../headers', () => ({ - getHeaders: mockGetHeaders, +vi.mock('../../../../unix-socket-fetch', () => ({ + createMainProcessApiClient: mockCreateMainProcessApiClient, + hasToolhiveConnection: mockHasToolhiveConnection, })) vi.mock('../../../../logger', () => ({ default: mockLog })) vi.mock('../../../settings-storage', () => ({ @@ -73,7 +67,8 @@ async function ensureClientSkill( beforeEach(async () => { mockGetApiV1BetaSkills.mockReset() - mockGetToolhivePort.mockReset() + mockHasToolhiveConnection.mockReset() + mockHasToolhiveConnection.mockReturnValue(true) mockGetEnabledSkills.mockReset() mockGetEnabledSkills.mockImplementation(() => []) mockPruneEnabledSkillsTo.mockReset() @@ -133,7 +128,7 @@ async function buildHandle(options: BuildOptions = {}) { describe('skills bundle — list_skills', () => { it('reports a friendly error when ToolHive is not running', async () => { - mockGetToolhivePort.mockReturnValue(null) + mockHasToolhiveConnection.mockReturnValue(false) const handle = await createSkillsAgentTools({ buildClient: () => null, homeDir: homeRoot, diff --git a/main/src/chat/agents/builtin-agent-tools/skills.ts b/main/src/chat/agents/builtin-agent-tools/skills.ts index 687451086..c7be95e77 100644 --- a/main/src/chat/agents/builtin-agent-tools/skills.ts +++ b/main/src/chat/agents/builtin-agent-tools/skills.ts @@ -6,7 +6,7 @@ import { nanoid } from 'nanoid' import { z } from 'zod' import { tool, type ToolSet } from 'ai' import log from '../../../logger' -import { createClient, type Client } from '@common/api/generated/client' +import { type Client } from '@common/api/generated/client' import { getApiV1BetaSkills, getApiV1BetaSkillsBuilds, @@ -16,8 +16,10 @@ import type { GithubComStacklokToolhivePkgSkillsInstalledSkill as InstalledSkill, GithubComStacklokToolhivePkgSkillsLocalBuild as LocalBuild, } from '@common/api/generated/types.gen' -import { getToolhivePort } from '../../../toolhive-manager' -import { getHeaders } from '../../../headers' +import { + createMainProcessApiClient, + hasToolhiveConnection, +} from '../../../unix-socket-fetch' import { getEnabledSkills as defaultGetEnabledSkills, pruneEnabledSkillsTo as defaultPruneEnabledSkillsTo, @@ -395,12 +397,8 @@ function renderInstructionsSuffix( } function defaultBuildClient(): Client | null { - const port = getToolhivePort() - if (!port) return null - return createClient({ - baseUrl: `http://localhost:${port}`, - headers: getHeaders(), - }) + if (!hasToolhiveConnection()) return null + return createMainProcessApiClient() } export interface SkillsAgentToolsHandle { diff --git a/main/src/chat/mcp-tools.ts b/main/src/chat/mcp-tools.ts index 9944207b0..fdb3daf07 100644 --- a/main/src/chat/mcp-tools.ts +++ b/main/src/chat/mcp-tools.ts @@ -11,10 +11,9 @@ import type { McpUiResourceCsp, McpUiResourcePermissions, } from '@modelcontextprotocol/ext-apps/app-bridge' -import { createClient } from '@common/api/generated/client' import { getApiV1BetaWorkloads } from '@common/api/generated/sdk.gen' -import { getHeaders } from '../headers' -import { getToolhivePort, getToolhiveMcpPort } from '../toolhive-manager' +import { getToolhiveMcpPort } from '../toolhive-manager' +import { createMainProcessApiClient } from '../unix-socket-fetch' import log from '../logger' import type { AvailableServer } from './types' import { getEnabledMcpTools } from './settings-storage' @@ -86,11 +85,7 @@ function createToolhiveMcpTransport(): StreamableHTTPClientTransport { /** Fetches all workloads from the ToolHive API. */ async function fetchWorkloads(): Promise { - const port = getToolhivePort() - const client = createClient({ - baseUrl: `http://localhost:${port}`, - headers: getHeaders(), - }) + const client = createMainProcessApiClient() const { data } = await getApiV1BetaWorkloads({ client }) return data?.workloads ?? [] } diff --git a/main/src/chat/settings-storage.ts b/main/src/chat/settings-storage.ts index b0172a4a6..7dac1c3f8 100644 --- a/main/src/chat/settings-storage.ts +++ b/main/src/chat/settings-storage.ts @@ -1,10 +1,9 @@ import Store from 'electron-store' import log from '../logger' -import { getToolhivePort, isToolhiveRunning } from '../toolhive-manager' -import { createClient } from '@common/api/generated/client' +import { isToolhiveRunning } from '../toolhive-manager' import { getApiV1BetaWorkloads } from '@common/api/generated/sdk.gen' import type { GithubComStacklokToolhivePkgCoreWorkload as CoreWorkload } from '@common/api/generated/types.gen' -import { getHeaders } from '../headers' +import { createMainProcessApiClient } from '../unix-socket-fetch' import { getTearingDownState } from '../app-state' import { getToolhiveMcpInfo } from './mcp-tools' import { TOOLHIVE_MCP_SERVER_NAME } from '../utils/constants' @@ -252,12 +251,8 @@ export async function getEnabledMcpTools(): Promise< } // Get running servers to filter out tools from stopped servers - const port = getToolhivePort() try { - const client = createClient({ - baseUrl: `http://localhost:${port}`, - headers: getHeaders(), - }) + const client = createMainProcessApiClient() const { data } = await getApiV1BetaWorkloads({ client, diff --git a/main/src/csp.ts b/main/src/csp.ts index a3a8103e7..9f5a08936 100644 --- a/main/src/csp.ts +++ b/main/src/csp.ts @@ -1,29 +1,33 @@ -const getCspMap = (port: number, sentryDsn?: string) => { - // In production with Sentry enabled, allow blob workers for replay +const getCspMap = (sentryDsn?: string) => { const hasSentry = Boolean(sentryDsn) const workerSrc = hasSentry ? "'self' blob:" : "'self'" + // The renderer never makes direct HTTP requests to thv — they are forwarded + // over IPC to the main process, which dials the UNIX socket / named pipe — + // so no localhost entry is needed in connect-src. + const connectParts = ["'self'", 'https://api.hsforms.com'] + if (hasSentry) connectParts.push('https://*.sentry.io') + return { 'default-src': "'self'", 'script-src': "'self'", 'style-src': "'self' 'unsafe-inline'", 'img-src': "'self' data: blob:", 'font-src': "'self' data:", - 'connect-src': `'self' http://localhost:${port} https://api.hsforms.com${hasSentry ? ' https://*.sentry.io' : ''}`, - 'frame-src': "'self' blob:", + 'connect-src': connectParts.join(' '), + 'frame-src': "'none'", 'object-src': "'none'", 'base-uri': "'self'", 'form-action': "'self'", 'frame-ancestors': "'none'", 'manifest-src': "'self'", 'media-src': "'self' blob: data:", - // Allow blob: workers only when Sentry is configured 'worker-src': workerSrc, 'child-src': "'self' blob:", } } -export const getCspString = (port: number, sentryDsn?: string) => - Object.entries(getCspMap(port, sentryDsn)) +export const getCspString = (sentryDsn?: string) => + Object.entries(getCspMap(sentryDsn)) .map(([key, value]) => `${key} ${value}`) .join('; ') diff --git a/main/src/graceful-exit.ts b/main/src/graceful-exit.ts index 731adddb5..70f2cc795 100644 --- a/main/src/graceful-exit.ts +++ b/main/src/graceful-exit.ts @@ -22,11 +22,15 @@ export const shutdownStore = new Store({ }, }) -/** Create API client for the given port */ -function createApiClient(port: number) { +/** + * Create API client. When a custom fetch is provided (UNIX socket transport), + * the baseUrl is a dummy since the custom fetch handles routing. + */ +function createApiClient(opts: { port?: number; customFetch?: typeof fetch }) { return createClient({ - baseUrl: `http://localhost:${port}`, + baseUrl: opts.port ? `http://localhost:${opts.port}` : 'http://localhost', headers: getHeaders(), + ...(opts.customFetch ? { fetch: opts.customFetch } : {}), }) } @@ -114,10 +118,11 @@ async function pollUntilAllStopped( /** Stop every running server in parallel and wait until *all* are down. */ export async function stopAllServers( - _binPath: string, // Kept for backward compatibility - port: number + _binPath: string, + opts: { port?: number; createFetch?: () => typeof fetch } ): Promise { - const client = createApiClient(port) + const customFetch = opts.createFetch?.() + const client = createApiClient({ port: opts.port, customFetch }) const servers = await getRunningServers(client) log.info( `Found ${servers.length} running servers: `, diff --git a/main/src/ipc-handlers/__tests__/toolhive.test.ts b/main/src/ipc-handlers/__tests__/toolhive.test.ts new file mode 100644 index 000000000..71ef440f0 --- /dev/null +++ b/main/src/ipc-handlers/__tests__/toolhive.test.ts @@ -0,0 +1,216 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' + +vi.mock('electron', () => ({ + ipcMain: { + handle: vi.fn(), + }, +})) + +vi.mock('../../toolhive-manager', () => ({ + restartToolhive: vi.fn(), + getToolhiveSocketPath: vi.fn(), + isToolhiveRunning: vi.fn(), + getToolhiveStatus: vi.fn(), + getToolhiveMcpPort: vi.fn(), + isUsingCustomSocket: vi.fn(), +})) + +vi.mock('../../container-engine', () => ({ + checkContainerEngine: vi.fn(), +})) + +vi.mock('../../graceful-exit', () => ({ + getLastShutdownServers: vi.fn(), + clearShutdownHistory: vi.fn(), +})) + +vi.mock('../../unix-socket-fetch', () => ({ + registerApiFetchHandlers: vi.fn(), +})) + +vi.mock('../../logger', () => ({ + default: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, +})) + +import { ipcMain } from 'electron' +import { + restartToolhive, + getToolhiveSocketPath, + isToolhiveRunning, + getToolhiveStatus, + getToolhiveMcpPort, + isUsingCustomSocket, +} from '../../toolhive-manager' +import { checkContainerEngine } from '../../container-engine' +import { + getLastShutdownServers, + clearShutdownHistory, +} from '../../graceful-exit' +import { registerApiFetchHandlers } from '../../unix-socket-fetch' +import log from '../../logger' +import { register } from '../toolhive' + +const mockHandle = vi.mocked(ipcMain.handle) +const mockRegisterApiFetchHandlers = vi.mocked(registerApiFetchHandlers) +const mockRestartToolhive = vi.mocked(restartToolhive) +const mockGetSocketPath = vi.mocked(getToolhiveSocketPath) +const mockIsRunning = vi.mocked(isToolhiveRunning) +const mockGetStatus = vi.mocked(getToolhiveStatus) +const mockGetMcpPort = vi.mocked(getToolhiveMcpPort) +const mockIsUsingCustomSocket = vi.mocked(isUsingCustomSocket) +const mockCheckContainerEngine = vi.mocked(checkContainerEngine) +const mockGetLastShutdownServers = vi.mocked(getLastShutdownServers) +const mockClearShutdownHistory = vi.mocked(clearShutdownHistory) +const mockLogError = vi.mocked(log.error) + +type Handler = (event: unknown, ...args: unknown[]) => unknown + +function getHandler(channel: string): Handler { + const call = mockHandle.mock.calls.find(([c]) => c === channel) + if (!call) throw new Error(`handler for ${channel} not registered`) + return call[1] as Handler +} + +describe('ipc-handlers/toolhive register()', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('registers exactly the expected channel set', () => { + register() + + const channels = mockHandle.mock.calls.map(([c]) => c) + expect(channels.sort()).toEqual( + [ + 'get-toolhive-mcp-port', + 'get-toolhive-socket-path', + 'is-toolhive-running', + 'get-toolhive-status', + 'is-using-custom-socket', + 'check-container-engine', + 'restart-toolhive', + 'shutdown-store:get-last-servers', + 'shutdown-store:clear-history', + ].sort() + ) + }) + + it('wires registerApiFetchHandlers exactly once', () => { + register() + + expect(mockRegisterApiFetchHandlers).toHaveBeenCalledTimes(1) + }) + + it('get-toolhive-mcp-port returns the value from the manager', () => { + mockGetMcpPort.mockReturnValue(12345) + register() + + expect(getHandler('get-toolhive-mcp-port')({})).toBe(12345) + }) + + it('get-toolhive-socket-path returns the value from the manager', () => { + mockGetSocketPath.mockReturnValue('/tmp/foo.sock') + register() + + expect(getHandler('get-toolhive-socket-path')({})).toBe('/tmp/foo.sock') + }) + + it('is-toolhive-running returns the value from the manager', () => { + mockIsRunning.mockReturnValue(true) + register() + + expect(getHandler('is-toolhive-running')({})).toBe(true) + }) + + it('get-toolhive-status returns the value from the manager', () => { + const status = { isRunning: true } as unknown as ReturnType< + typeof getToolhiveStatus + > + mockGetStatus.mockReturnValue(status) + register() + + expect(getHandler('get-toolhive-status')({})).toBe(status) + }) + + it('is-using-custom-socket returns the value from the manager', () => { + mockIsUsingCustomSocket.mockReturnValue(true) + register() + + expect(getHandler('is-using-custom-socket')({})).toBe(true) + + mockIsUsingCustomSocket.mockReturnValue(false) + expect(getHandler('is-using-custom-socket')({})).toBe(false) + }) + + it('check-container-engine resolves with the engine result', async () => { + const result = { + available: true, + docker: true, + podman: false, + rancherDesktop: false, + } + mockCheckContainerEngine.mockResolvedValue(result) + register() + + await expect(getHandler('check-container-engine')({})).resolves.toEqual( + result + ) + }) + + it('restart-toolhive resolves { success: true } on success', async () => { + mockRestartToolhive.mockResolvedValue(undefined) + register() + + await expect(getHandler('restart-toolhive')({})).resolves.toEqual({ + success: true, + }) + }) + + it('restart-toolhive resolves { success: false, error } on rejection and logs', async () => { + mockRestartToolhive.mockRejectedValue(new Error('boom')) + register() + + await expect(getHandler('restart-toolhive')({})).resolves.toEqual({ + success: false, + error: 'boom', + }) + expect(mockLogError).toHaveBeenCalledWith( + 'Failed to restart ToolHive: ', + expect.any(Error) + ) + }) + + it('restart-toolhive returns "Unknown error" when rejection is not an Error', async () => { + mockRestartToolhive.mockRejectedValue('plain string') + register() + + await expect(getHandler('restart-toolhive')({})).resolves.toEqual({ + success: false, + error: 'Unknown error', + }) + }) + + it('shutdown-store:get-last-servers returns the workloads from graceful-exit', () => { + const workloads = [{ name: 'srv', status: 'running' as const }] + mockGetLastShutdownServers.mockReturnValue( + workloads as unknown as ReturnType + ) + register() + + expect(getHandler('shutdown-store:get-last-servers')({})).toEqual(workloads) + }) + + it('shutdown-store:clear-history calls clearShutdownHistory and returns success', () => { + register() + + const result = getHandler('shutdown-store:clear-history')({}) + + expect(mockClearShutdownHistory).toHaveBeenCalledTimes(1) + expect(result).toEqual({ success: true }) + }) +}) diff --git a/main/src/ipc-handlers/toolhive.ts b/main/src/ipc-handlers/toolhive.ts index 1b40f56a7..81c21c0f7 100644 --- a/main/src/ipc-handlers/toolhive.ts +++ b/main/src/ipc-handlers/toolhive.ts @@ -1,22 +1,23 @@ import { ipcMain } from 'electron' import { restartToolhive, - getToolhivePort, + getToolhiveSocketPath, isToolhiveRunning, getToolhiveStatus, getToolhiveMcpPort, - isUsingCustomPort, + isUsingCustomSocket, } from '../toolhive-manager' import { checkContainerEngine } from '../container-engine' import { getLastShutdownServers, clearShutdownHistory } from '../graceful-exit' +import { registerApiFetchHandlers } from '../unix-socket-fetch' import log from '../logger' export function register() { - ipcMain.handle('get-toolhive-port', () => getToolhivePort()) ipcMain.handle('get-toolhive-mcp-port', () => getToolhiveMcpPort()) + ipcMain.handle('get-toolhive-socket-path', () => getToolhiveSocketPath()) ipcMain.handle('is-toolhive-running', () => isToolhiveRunning()) ipcMain.handle('get-toolhive-status', () => getToolhiveStatus()) - ipcMain.handle('is-using-custom-port', () => isUsingCustomPort()) + ipcMain.handle('is-using-custom-socket', () => isUsingCustomSocket()) ipcMain.handle('check-container-engine', async () => { return await checkContainerEngine() @@ -43,4 +44,6 @@ export function register() { clearShutdownHistory() return { success: true } }) + + registerApiFetchHandlers() } diff --git a/main/src/tests/auto-update.test.ts b/main/src/tests/auto-update.test.ts index df7d8077c..ea5ec567f 100644 --- a/main/src/tests/auto-update.test.ts +++ b/main/src/tests/auto-update.test.ts @@ -119,11 +119,14 @@ vi.mock('../graceful-exit', () => ({ vi.mock('../toolhive-manager', () => ({ stopToolhive: vi.fn(), - getToolhivePort: vi.fn(() => 3000), isToolhiveRunning: vi.fn(() => true), binPath: '/mock/bin/path', })) +vi.mock('../unix-socket-fetch', () => ({ + createMainProcessFetch: vi.fn(() => vi.fn()), +})) + vi.mock('../system-tray', () => ({ safeTrayDestroy: vi.fn(), })) @@ -156,7 +159,7 @@ vi.mock('../app-state', () => ({ })) import { stopAllServers } from '../graceful-exit' -import { stopToolhive, getToolhivePort } from '../toolhive-manager' +import { stopToolhive } from '../toolhive-manager' import { safeTrayDestroy } from '../system-tray' import { pollWindowReady } from '../util' import { delay } from '../../../utils/delay' @@ -199,7 +202,6 @@ describe('auto-update', () => { // Setup default mocks vi.mocked(stopAllServers).mockResolvedValue(undefined) vi.mocked(stopToolhive).mockReturnValue(undefined) - vi.mocked(getToolhivePort).mockReturnValue(3000) vi.mocked(pollWindowReady).mockResolvedValue(undefined) vi.mocked(delay).mockResolvedValue(undefined) vi.mocked(dialog.showMessageBox).mockResolvedValue({ @@ -803,8 +805,7 @@ describe('auto-update', () => { expect(vi.mocked(autoUpdater).quitAndInstall).toHaveBeenCalled() }) - it('integrates with toolhive manager port detection', async () => { - vi.mocked(getToolhivePort).mockReturnValue(undefined) + it('always attempts server shutdown via IPC fetch bridge', async () => { vi.mocked(dialog.showMessageBox).mockResolvedValue({ response: 0, checkboxChecked: false, @@ -823,13 +824,14 @@ describe('auto-update', () => { await updatePromise - // Should skip server shutdown when no port is available - expect(vi.mocked(getToolhivePort)).toHaveBeenCalled() - expect(vi.mocked(stopAllServers)).not.toHaveBeenCalled() + // Always attempts server shutdown (connection errors handled internally) + expect(vi.mocked(stopAllServers)).toHaveBeenCalled() }) - it('handles missing toolhive port gracefully', async () => { - vi.mocked(getToolhivePort).mockReturnValue(undefined) + it('handles server shutdown failure gracefully', async () => { + vi.mocked(stopAllServers).mockRejectedValueOnce( + new Error('No ToolHive connection available') + ) vi.mocked(dialog.showMessageBox).mockResolvedValue({ response: 0, checkboxChecked: false, @@ -848,8 +850,9 @@ describe('auto-update', () => { await updatePromise - expect(vi.mocked(log).info).toHaveBeenCalledWith( - '[update] No ToolHive port available, skipping server shutdown' + expect(vi.mocked(log).error).toHaveBeenCalledWith( + expect.stringContaining('[update] Server shutdown failed'), + expect.anything() ) }) diff --git a/main/src/tests/csp.test.ts b/main/src/tests/csp.test.ts new file mode 100644 index 000000000..173d0c7ea --- /dev/null +++ b/main/src/tests/csp.test.ts @@ -0,0 +1,75 @@ +import { describe, it, expect } from 'vitest' +import { getCspString } from '../csp' + +function parseCsp(csp: string): Record { + return Object.fromEntries( + csp + .split(';') + .map((part) => part.trim()) + .filter(Boolean) + .map((part) => { + const idx = part.indexOf(' ') + if (idx === -1) return [part, ''] + return [part.slice(0, idx), part.slice(idx + 1)] as const + }) + ) +} + +describe('getCspString', () => { + it('omits localhost from connect-src (transport is now over IPC, not HTTP)', () => { + const csp = getCspString() + + expect(csp).not.toMatch(/localhost/) + expect(csp).not.toMatch(/127\.0\.0\.1/) + }) + + it('without a Sentry DSN, connect-src is just self + hsforms', () => { + const map = parseCsp(getCspString()) + + expect(map['connect-src']).toBe("'self' https://api.hsforms.com") + expect(map['connect-src']).not.toContain('sentry.io') + }) + + it('with a Sentry DSN, connect-src additionally allows *.sentry.io and worker-src allows blob:', () => { + const map = parseCsp(getCspString('https://sentry.example/dsn')) + + expect(map['connect-src']).toBe( + "'self' https://api.hsforms.com https://*.sentry.io" + ) + expect(map['worker-src']).toBe("'self' blob:") + }) + + it("without a Sentry DSN, worker-src is 'self' only", () => { + const map = parseCsp(getCspString()) + + expect(map['worker-src']).toBe("'self'") + }) + + it("always emits frame-src 'none', object-src 'none', and frame-ancestors 'none'", () => { + const map = parseCsp(getCspString()) + + expect(map['frame-src']).toBe("'none'") + expect(map['object-src']).toBe("'none'") + expect(map['frame-ancestors']).toBe("'none'") + }) + + it('emits a `; `-joined string that round-trips through the parser', () => { + const csp = getCspString() + expect(csp).toContain('; ') + + const map = parseCsp(csp) + + expect(map).toMatchObject({ + 'default-src': "'self'", + 'script-src': "'self'", + 'style-src': "'self' 'unsafe-inline'", + 'img-src': "'self' data: blob:", + 'font-src': "'self' data:", + 'base-uri': "'self'", + 'form-action': "'self'", + 'manifest-src': "'self'", + 'media-src': "'self' blob: data:", + 'child-src': "'self' blob:", + }) + }) +}) diff --git a/main/src/tests/graceful-exit.test.ts b/main/src/tests/graceful-exit.test.ts index 165aec7ca..6ec7d0126 100644 --- a/main/src/tests/graceful-exit.test.ts +++ b/main/src/tests/graceful-exit.test.ts @@ -117,7 +117,7 @@ describe('graceful-exit', () => { createMockWorkloadsResponse([]) ) - await stopAllServers('', 3000) + await stopAllServers('', { port: 3000 }) expect(mockLog.info).toHaveBeenCalledWith( 'No running servers – teardown complete' @@ -140,7 +140,7 @@ describe('graceful-exit', () => { mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse()) - await stopAllServers('', 3000) + await stopAllServers('', { port: 3000 }) expect(mockPostApiV1BetaWorkloadsStop).toHaveBeenCalledTimes(1) expect(mockPostApiV1BetaWorkloadsStop).toHaveBeenCalledWith({ @@ -165,7 +165,7 @@ describe('graceful-exit', () => { mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse()) - await stopAllServers('', 3000) + await stopAllServers('', { port: 3000 }) expect(mockLog.info).toHaveBeenCalledWith( 'All servers have reached final state' @@ -182,7 +182,9 @@ describe('graceful-exit', () => { new Error('Stop failed') ) - await expect(stopAllServers('', 3000)).rejects.toThrow('Stop failed') + await expect(stopAllServers('', { port: 3000 })).rejects.toThrow( + 'Stop failed' + ) }) it('handles timeout when servers do not stop', async () => { @@ -201,7 +203,7 @@ describe('graceful-exit', () => { mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse()) - await expect(stopAllServers('', 3000)).rejects.toThrow( + await expect(stopAllServers('', { port: 3000 })).rejects.toThrow( 'Some servers failed to stop within timeout' ) }) @@ -213,7 +215,7 @@ describe('graceful-exit', () => { mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse()) - await stopAllServers('', 3000) + await stopAllServers('', { port: 3000 }) expect(mockWriteShutdownServers).toHaveBeenCalledWith(mockRunningServers) }) @@ -234,7 +236,7 @@ describe('graceful-exit', () => { mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse()) - await stopAllServers('', 3000) + await stopAllServers('', { port: 3000 }) // Should only include the server with a name in the batch call expect(mockPostApiV1BetaWorkloadsStop).toHaveBeenCalledTimes(1) @@ -300,7 +302,7 @@ describe('graceful-exit', () => { mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse()) - await stopAllServers('', 3000) + await stopAllServers('', { port: 3000 }) expect(mockLog.info).toHaveBeenCalledWith( 'Still waiting for 1 servers to reach final state: server1(stopping)' @@ -326,10 +328,106 @@ describe('graceful-exit', () => { mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse()) - await stopAllServers('', 3000) + await stopAllServers('', { port: 3000 }) // Should call delay between polling attempts (not on first attempt) expect(mockDelay).toHaveBeenCalledWith(2000) }) }) + + // The new socket transport calls stopAllServers with a `createFetch` factory + // instead of a `port`. The client's baseUrl becomes the sentinel + // 'http://localhost' and the custom fetch handles routing to the UNIX socket. + describe('socket transport (createFetch branch)', () => { + const mockRunningServers: CoreWorkload[] = [ + { name: 'server1', status: 'running', port: 3001 }, + ] + const mockStoppedServers: CoreWorkload[] = [ + { name: 'server1', status: 'stopped', port: 3001 }, + ] + + it('passes a sentinel baseUrl (no port) and the produced custom fetch to createClient', async () => { + mockGetApiV1BetaWorkloads.mockResolvedValue( + createMockWorkloadsResponse([]) + ) + + const customFetch = vi.fn() as unknown as typeof fetch + const createFetch = vi.fn(() => customFetch) + + await stopAllServers('', { createFetch }) + + expect(createFetch).toHaveBeenCalledTimes(1) + expect(mockCreateClient).toHaveBeenCalledWith({ + baseUrl: 'http://localhost', + headers: mockHeaders, + fetch: customFetch, + }) + }) + + it('falls back to http://localhost (no port, no fetch) when neither is provided', async () => { + mockGetApiV1BetaWorkloads.mockResolvedValue( + createMockWorkloadsResponse([]) + ) + + await stopAllServers('', {}) + + const cfg = mockCreateClient.mock.calls.at(-1)?.[0] as { + baseUrl: string + fetch?: unknown + } + expect(cfg.baseUrl).toBe('http://localhost') + expect(cfg.fetch).toBeUndefined() + }) + + it('uses the port baseUrl when both port and createFetch are provided, but the custom fetch overrides transport', async () => { + mockGetApiV1BetaWorkloads.mockResolvedValue( + createMockWorkloadsResponse([]) + ) + + const customFetch = vi.fn() as unknown as typeof fetch + const createFetch = vi.fn(() => customFetch) + + await stopAllServers('', { port: 3000, createFetch }) + + const cfg = mockCreateClient.mock.calls.at(-1)?.[0] as { + baseUrl: string + fetch?: unknown + } + expect(cfg.baseUrl).toBe('http://localhost:3000') + expect(cfg.fetch).toBe(customFetch) + }) + + it('completes the no-running-servers fast path when using createFetch', async () => { + mockGetApiV1BetaWorkloads.mockResolvedValue( + createMockWorkloadsResponse([]) + ) + + const customFetch = vi.fn() as unknown as typeof fetch + await expect( + stopAllServers('', { createFetch: () => customFetch }) + ).resolves.toBeUndefined() + + expect(mockPostApiV1BetaWorkloadsStop).not.toHaveBeenCalled() + expect(mockLog.info).toHaveBeenCalledWith( + 'No running servers – teardown complete' + ) + }) + + it('completes the polling loop when using createFetch', async () => { + mockGetApiV1BetaWorkloads + .mockResolvedValueOnce(createMockWorkloadsResponse(mockRunningServers)) + .mockResolvedValueOnce(createMockWorkloadsResponse(mockStoppedServers)) + mockPostApiV1BetaWorkloadsStop.mockResolvedValue(createMockStopResponse()) + + const customFetch = vi.fn() as unknown as typeof fetch + await stopAllServers('', { createFetch: () => customFetch }) + + expect(mockPostApiV1BetaWorkloadsStop).toHaveBeenCalledTimes(1) + expect(mockPostApiV1BetaWorkloadsStop).toHaveBeenCalledWith({ + client: mockClient, + body: { names: ['server1'] }, + }) + expect(mockLog.info).toHaveBeenCalledWith('All servers stopped cleanly') + }) + }) }) diff --git a/main/src/tests/toolhive-manager.test.ts b/main/src/tests/toolhive-manager.test.ts index 550a43ac7..e9d0ca65c 100644 --- a/main/src/tests/toolhive-manager.test.ts +++ b/main/src/tests/toolhive-manager.test.ts @@ -7,8 +7,8 @@ import { platform } from 'node:os' import { app } from 'electron' import { startToolhive, - getToolhivePort, getToolhiveMcpPort, + getToolhiveSocketPath, getToolhiveStatus, isToolhiveRunning, stopToolhive, @@ -227,8 +227,7 @@ describe('toolhive-manager', () => { '--experimental-mcp', '--experimental-mcp-host=127.0.0.1', expect.stringMatching(/--experimental-mcp-port=\d+/), - '--host=127.0.0.1', - expect.stringMatching(/--port=\d+/), + expect.stringMatching(/--socket=.+/), ]), { stdio: ['ignore', 'ignore', 'pipe'], @@ -244,7 +243,7 @@ describe('toolhive-manager', () => { expect.stringContaining('Starting ToolHive from:') ) expect(isToolhiveRunning()).toBe(true) - expect(getToolhivePort()).toBeTypeOf('number') + expect(getToolhiveSocketPath()).toBeTypeOf('string') expect(getToolhiveMcpPort()).toBeTypeOf('number') }) @@ -373,39 +372,86 @@ describe('toolhive-manager', () => { expect(mockCaptureMessage).not.toHaveBeenCalled() }) - it('assigns different ports to main and MCP services', async () => { - const startPromise = startToolhive() + it('uses a UNIX socket for main service and a port for MCP on non-Windows', async () => { + const originalPlatform = process.platform + Object.defineProperty(process, 'platform', { + value: 'linux', + configurable: true, + }) - await vi.advanceTimersByTimeAsync(50) - await startPromise + try { + const startPromise = startToolhive() - const toolhivePort = getToolhivePort() - const mcpPort = getToolhiveMcpPort() + await vi.advanceTimersByTimeAsync(50) + await startPromise - expect(toolhivePort).toBeTypeOf('number') - expect(mcpPort).toBeTypeOf('number') - expect(toolhivePort).not.toBe(mcpPort) + const socketPath = getToolhiveSocketPath() + const mcpPort = getToolhiveMcpPort() + + expect(socketPath).toBeTypeOf('string') + expect(socketPath).toMatch(/toolhive-\d+\.sock$/) + expect(mcpPort).toBeTypeOf('number') + + const spawnArgs = mockSpawn.mock.calls[0]![1] as string[] + expect(spawnArgs).toEqual( + expect.arrayContaining([ + expect.stringMatching(/^--socket=.*toolhive-\d+\.sock$/), + ]) + ) + } finally { + Object.defineProperty(process, 'platform', { + value: originalPlatform, + configurable: true, + }) + } }) - it('uses range for main port but any port for MCP', async () => { + it('uses a Windows named pipe for main service when process.platform is win32', async () => { + const originalPlatform = process.platform + Object.defineProperty(process, 'platform', { + value: 'win32', + configurable: true, + }) + + try { + const startPromise = startToolhive() + + await vi.advanceTimersByTimeAsync(50) + await startPromise + + const socketPath = getToolhiveSocketPath() + expect(socketPath).toBeTypeOf('string') + // Named pipe shape: \\.\pipe\toolhive- + expect(socketPath).toMatch(/^\\\\\.\\pipe\\toolhive-\d+$/) + expect(getToolhiveMcpPort()).toBeTypeOf('number') + + const spawnArgs = mockSpawn.mock.calls[0]![1] as string[] + expect(spawnArgs).toEqual( + expect.arrayContaining([ + expect.stringMatching(/^--socket=\\\\\.\\pipe\\toolhive-\d+$/), + ]) + ) + } finally { + Object.defineProperty(process, 'platform', { + value: originalPlatform, + configurable: true, + }) + } + }) + + it('logs socket path and MCP port on startup', async () => { const startPromise = startToolhive() await vi.advanceTimersByTimeAsync(50) await startPromise - const toolhivePort = getToolhivePort() - const mcpPort = getToolhiveMcpPort() - - // Main port should be in preferred range (when available) or fallback - expect(toolhivePort).toBeTypeOf('number') - // MCP port can be any available port - expect(mcpPort).toBeTypeOf('number') - expect(toolhivePort).not.toBe(mcpPort) + expect(getToolhiveSocketPath()).toBeTypeOf('string') + expect(getToolhiveMcpPort()).toBeTypeOf('number') - // Verify the log message includes both ports + // Verify the log message includes the socket path and MCP port expect(mockLog.info).toHaveBeenCalledWith( expect.stringMatching( - /Starting ToolHive from: .+ on port \d+, MCP on port \d+/ + /Starting ToolHive from: .+ on socket .+, MCP on port \d+/ ) ) }) @@ -424,8 +470,7 @@ describe('toolhive-manager', () => { '--experimental-mcp', '--experimental-mcp-host=127.0.0.1', expect.stringMatching(/--experimental-mcp-port=\d+/), - '--host=127.0.0.1', - expect.stringMatching(/--port=\d+/), + expect.stringMatching(/--socket=.+/), ]), { stdio: ['ignore', 'ignore', 'pipe'], @@ -548,49 +593,21 @@ describe('toolhive-manager', () => { }) }) - describe('port finding with fallback', () => { - it('falls back to random port when preferred range is unavailable', async () => { - // Mock all ports in range to be unavailable, then allow random port - mockNet.createServer.mockImplementation(function createServer() { - const server = new MockServer() as unknown as net.Server - const originalListen = server.listen.bind(server) - - server.listen = vi.fn(function listen( - port: number, - callback?: () => void - ) { - if (port >= 50000 && port <= 50100) { - // Simulate all ports in range being unavailable - setTimeout(() => { - server.emit('error', { code: 'EADDRINUSE' }) - }, 5) - } else if (port === 0) { - // Allow OS assignment (fallback) - setTimeout(() => { - originalListen(port, callback) - }, 5) - } else { - // Any other specific port - originalListen(port, callback) - } - return server - }) as unknown as typeof server.listen - - return server - }) - + describe('port finding', () => { + it('uses an OS-assigned random port for the MCP service', async () => { const startPromise = startToolhive() - // Advance timers to complete all async operations including fallback attempts - await vi.advanceTimersByTimeAsync(1000) + await vi.advanceTimersByTimeAsync(50) await startPromise - expect(mockLog.warn).toHaveBeenCalledWith( - expect.stringContaining( - 'No free port found in range 50000-50100, falling back to random port' - ) + // findFreePort() is called without a range, so it uses OS assignment + // (server.listen(0, ...)) which always succeeds. No fallback warning is + // expected on the happy path. + expect(mockLog.warn).not.toHaveBeenCalledWith( + expect.stringContaining('falling back to random port') ) expect(isToolhiveRunning()).toBe(true) + expect(getToolhiveMcpPort()).toBeTypeOf('number') }) }) @@ -834,4 +851,61 @@ describe('toolhive-manager', () => { ) }) }) + + describe('external thv via THV_SOCKET', () => { + const originalSocket = process.env.THV_SOCKET + const originalMcpPort = process.env.THV_MCP_PORT + + afterEach(() => { + if (originalSocket === undefined) delete process.env.THV_SOCKET + else process.env.THV_SOCKET = originalSocket + if (originalMcpPort === undefined) delete process.env.THV_MCP_PORT + else process.env.THV_MCP_PORT = originalMcpPort + }) + + it('reports running and exposes the socket without spawning thv', async () => { + process.env.THV_SOCKET = '/tmp/external-thv.sock' + delete process.env.THV_MCP_PORT + + await startToolhive() + + expect(mockSpawn).not.toHaveBeenCalled() + expect(isToolhiveRunning()).toBe(true) + expect(getToolhiveSocketPath()).toBe('/tmp/external-thv.sock') + expect(getToolhiveMcpPort()).toBeUndefined() + }) + + it('parses a valid THV_MCP_PORT', async () => { + process.env.THV_SOCKET = '/tmp/external-thv.sock' + process.env.THV_MCP_PORT = '40000' + + await startToolhive() + + expect(getToolhiveMcpPort()).toBe(40000) + }) + + it('falls back to undefined and warns on a non-numeric THV_MCP_PORT', async () => { + process.env.THV_SOCKET = '/tmp/external-thv.sock' + process.env.THV_MCP_PORT = 'not-a-port' + + await startToolhive() + + expect(getToolhiveMcpPort()).toBeUndefined() + expect(mockLog.warn).toHaveBeenCalledWith( + expect.stringContaining('Ignoring invalid THV_MCP_PORT=not-a-port') + ) + }) + + it('rejects out-of-range THV_MCP_PORT values', async () => { + process.env.THV_SOCKET = '/tmp/external-thv.sock' + process.env.THV_MCP_PORT = '70000' + + await startToolhive() + + expect(getToolhiveMcpPort()).toBeUndefined() + expect(mockLog.warn).toHaveBeenCalledWith( + expect.stringContaining('Ignoring invalid THV_MCP_PORT=70000') + ) + }) + }) }) diff --git a/main/src/tests/unix-socket-fetch.test.ts b/main/src/tests/unix-socket-fetch.test.ts new file mode 100644 index 000000000..6531871d8 --- /dev/null +++ b/main/src/tests/unix-socket-fetch.test.ts @@ -0,0 +1,463 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { EventEmitter } from 'node:events' +import http from 'node:http' + +vi.mock('node:http', () => { + const requestFn = vi.fn() + return { + default: { request: requestFn }, + request: requestFn, + } +}) + +vi.mock('electron', () => ({ + ipcMain: { + handle: vi.fn(), + removeHandler: vi.fn(), + }, +})) + +vi.mock('@sentry/electron/main', () => ({ + // suppressTracing should call its callback through and return its result. + suppressTracing: vi.fn((cb: () => T): T => cb()), +})) + +vi.mock('../toolhive-manager', () => ({ + getToolhiveSocketPath: vi.fn(() => '/tmp/toolhive.sock'), +})) + +vi.mock('../headers', () => ({ + getHeaders: vi.fn(() => ({ + 'X-Client-Type': 'studio', + 'X-Client-Version': '1.0.0', + 'X-Client-Platform': 'darwin' as NodeJS.Platform, + 'X-Client-Release-Build': false, + })), +})) + +vi.mock('../logger', () => ({ + default: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, +})) + +vi.mock('@common/api/generated/client', () => ({ + createClient: vi.fn((cfg: unknown) => ({ __client: true, cfg })), +})) + +import * as Sentry from '@sentry/electron/main' +import { ipcMain } from 'electron' +import { getToolhiveSocketPath } from '../toolhive-manager' +import { getHeaders } from '../headers' +import log from '../logger' +import { createClient } from '@common/api/generated/client' +import { + hasToolhiveConnection, + createMainProcessFetch, + createMainProcessApiClient, + registerApiFetchHandlers, +} from '../unix-socket-fetch' + +const stubClientHeaders: ReturnType = { + 'X-Client-Type': 'studio', + 'X-Client-Version': '1.0.0', + 'X-Client-Platform': 'darwin', + 'X-Client-Release-Build': false, +} + +const mockHttpRequest = vi.mocked(http.request) as unknown as ReturnType< + typeof vi.fn +> +const mockGetSocketPath = vi.mocked(getToolhiveSocketPath) +const mockGetHeaders = vi.mocked(getHeaders) +const mockSuppressTracing = vi.mocked(Sentry.suppressTracing) +const mockIpcHandle = vi.mocked(ipcMain.handle) +const mockIpcRemoveHandler = vi.mocked(ipcMain.removeHandler) +const mockCreateClient = vi.mocked(createClient) +const mockLog = vi.mocked(log) + +interface FakeRequest extends EventEmitter { + write: ReturnType + end: ReturnType + destroy: ReturnType +} + +interface FakeResponse extends EventEmitter { + statusCode?: number + headers: http.IncomingHttpHeaders +} + +/** + * Installs an http.request stub for a single request that resolves with the + * given response shape. Returns the captured request options + the fake req + * object so tests can assert wire shape. + */ +function setupHttpRequest(opts: { + status?: number + headers?: http.IncomingHttpHeaders + body?: string + error?: Error + /** + * If true, the fake request never emits a response or error; the test is + * responsible for driving it (e.g. for abort tests). + */ + manualDrive?: boolean +}) { + const fakeReq = Object.assign(new EventEmitter(), { + write: vi.fn(), + end: vi.fn(), + destroy: vi.fn(), + }) as FakeRequest + + let capturedOpts: http.RequestOptions | undefined + let capturedCb: ((res: FakeResponse) => void) | undefined + + mockHttpRequest.mockImplementationOnce( + ( + requestOpts: http.RequestOptions, + cb: (res: FakeResponse) => void + ): FakeRequest => { + capturedOpts = requestOpts + capturedCb = cb + + if (opts.manualDrive) return fakeReq + + if (opts.error) { + // Defer error emit to next microtask. + queueMicrotask(() => fakeReq.emit('error', opts.error)) + return fakeReq + } + + const fakeRes = Object.assign(new EventEmitter(), { + statusCode: opts.status, + headers: opts.headers ?? {}, + }) as FakeResponse + + queueMicrotask(() => { + cb(fakeRes) + if (opts.body) fakeRes.emit('data', Buffer.from(opts.body)) + fakeRes.emit('end') + }) + return fakeReq + } + ) + + return { + fakeReq, + getCapturedOpts: () => capturedOpts, + getCapturedCb: () => capturedCb, + } +} + +beforeEach(() => { + vi.clearAllMocks() + mockGetSocketPath.mockReturnValue('/tmp/toolhive.sock') + mockGetHeaders.mockReturnValue(stubClientHeaders) +}) + +afterEach(() => { + vi.resetAllMocks() +}) + +describe('hasToolhiveConnection', () => { + it('returns true when getToolhiveSocketPath returns a string', () => { + mockGetSocketPath.mockReturnValue('/tmp/foo.sock') + expect(hasToolhiveConnection()).toBe(true) + }) + + it('returns false when getToolhiveSocketPath returns null/undefined/empty', () => { + mockGetSocketPath.mockReturnValue(null as unknown as string) + expect(hasToolhiveConnection()).toBe(false) + mockGetSocketPath.mockReturnValue(undefined as unknown as string) + expect(hasToolhiveConnection()).toBe(false) + mockGetSocketPath.mockReturnValue('') + expect(hasToolhiveConnection()).toBe(false) + }) +}) + +describe('createMainProcessFetch', () => { + it('forwards method, path+search, headers, and body to http.request via the socket', async () => { + const { getCapturedOpts, fakeReq } = setupHttpRequest({ + status: 200, + headers: { 'content-type': 'application/json' }, + body: '{"ok":true}', + }) + + const f = createMainProcessFetch() + const response = await f('http://localhost/api/v1/foo?x=1', { + method: 'POST', + headers: { 'X-Custom': 'value' }, + body: 'payload', + }) + + expect(mockHttpRequest).toHaveBeenCalledTimes(1) + const captured = getCapturedOpts() + expect(captured).toMatchObject({ + socketPath: '/tmp/toolhive.sock', + method: 'POST', + path: '/api/v1/foo?x=1', + }) + expect(captured?.headers).toMatchObject({ 'x-custom': 'value' }) + expect(fakeReq.write).toHaveBeenCalledWith('payload') + expect(fakeReq.end).toHaveBeenCalledTimes(1) + + expect(response.status).toBe(200) + await expect(response.json()).resolves.toEqual({ ok: true }) + }) + + it.each([204, 205, 304])( + 'returns a null-body Response for status %i even when the socket returns a body', + async (status) => { + setupHttpRequest({ + status, + headers: {}, + body: 'should-be-ignored', + }) + + const f = createMainProcessFetch() + const response = await f('http://localhost/no-content') + + expect(response.status).toBe(status) + expect(response.body).toBeNull() + } + ) + + it('rejects with "No ToolHive socket available" when no socket is configured', async () => { + mockGetSocketPath.mockReturnValue(undefined as unknown as string) + + const f = createMainProcessFetch() + + await expect(f('http://localhost/anything')).rejects.toThrow( + 'No ToolHive socket available' + ) + expect(mockHttpRequest).not.toHaveBeenCalled() + }) + + it('does not call req.write when there is no body', async () => { + const { fakeReq } = setupHttpRequest({ + status: 200, + headers: {}, + body: '', + }) + + const f = createMainProcessFetch() + await f('http://localhost/get', { method: 'GET' }) + + expect(fakeReq.write).not.toHaveBeenCalled() + expect(fakeReq.end).toHaveBeenCalledTimes(1) + }) + + it('rejects when http.request emits an error', async () => { + setupHttpRequest({ error: new Error('ECONNREFUSED') }) + + const f = createMainProcessFetch() + await expect(f('http://localhost/foo')).rejects.toThrow('ECONNREFUSED') + }) + + it('serializes array-valued response headers by joining with ", "', async () => { + setupHttpRequest({ + status: 200, + headers: { 'set-cookie': ['a=1', 'b=2'] }, + body: '', + }) + + const f = createMainProcessFetch() + const response = await f('http://localhost/foo') + + expect(response.headers.get('set-cookie')).toBe('a=1, b=2') + }) + + it('drops undefined response headers', async () => { + setupHttpRequest({ + status: 200, + headers: { + 'x-keep': 'yes', + 'x-drop': undefined, + } as unknown as http.IncomingHttpHeaders, + body: '', + }) + + const f = createMainProcessFetch() + const response = await f('http://localhost/foo') + + expect(response.headers.get('x-keep')).toBe('yes') + expect(response.headers.get('x-drop')).toBeNull() + }) + + it('defaults status to 500 when res.statusCode is undefined', async () => { + setupHttpRequest({ + status: undefined, + headers: {}, + body: '', + }) + + const f = createMainProcessFetch() + const response = await f('http://localhost/foo') + + expect(response.status).toBe(500) + }) +}) + +describe('createMainProcessApiClient', () => { + it('delegates to createClient with sentinel http://localhost baseUrl, headers, and a custom fetch', () => { + mockGetHeaders.mockReturnValue(stubClientHeaders) + + const client = createMainProcessApiClient() + + expect(mockCreateClient).toHaveBeenCalledTimes(1) + const cfg = mockCreateClient.mock.calls[0]?.[0] as { + baseUrl: string + headers: ReturnType + fetch: unknown + } + expect(cfg.baseUrl).toBe('http://localhost') + expect(cfg.headers).toEqual(stubClientHeaders) + expect(typeof cfg.fetch).toBe('function') + expect(client).toEqual({ __client: true, cfg }) + }) +}) + +describe('registerApiFetchHandlers', () => { + type Handler = (event: unknown, ...args: unknown[]) => unknown + function getHandler(channel: string): Handler { + const call = mockIpcHandle.mock.calls.find(([c]) => c === channel) + if (!call) throw new Error(`handler for ${channel} not registered`) + return call[1] as Handler + } + + it('removes prior api-fetch and api-fetch-abort handlers before registering', () => { + registerApiFetchHandlers() + + expect(mockIpcRemoveHandler).toHaveBeenCalledWith('api-fetch') + expect(mockIpcRemoveHandler).toHaveBeenCalledWith('api-fetch-abort') + expect(mockIpcHandle).toHaveBeenCalledWith( + 'api-fetch', + expect.any(Function) + ) + expect(mockIpcHandle).toHaveBeenCalledWith( + 'api-fetch-abort', + expect.any(Function) + ) + }) + + it('api-fetch handler merges getHeaders() into the request, with renderer-supplied headers winning', async () => { + mockGetHeaders.mockReturnValue(stubClientHeaders) + + const { getCapturedOpts } = setupHttpRequest({ + status: 200, + headers: {}, + body: '{}', + }) + + registerApiFetchHandlers() + const handler = getHandler('api-fetch') + + await handler( + {}, + { + requestId: 'req-1', + method: 'GET', + path: '/foo', + headers: { + 'X-Client-Version': 'override', + 'X-Custom': 'hello', + }, + } + ) + + const captured = getCapturedOpts() + expect(captured?.headers).toMatchObject({ + 'X-Client-Type': 'studio', + // Renderer-supplied headers must win (spread order). + 'X-Client-Version': 'override', + 'X-Custom': 'hello', + }) + }) + + it('api-fetch handler runs inside Sentry.suppressTracing', async () => { + setupHttpRequest({ status: 200, headers: {}, body: '{}' }) + + registerApiFetchHandlers() + const handler = getHandler('api-fetch') + await handler( + {}, + { + requestId: 'req-2', + method: 'GET', + path: '/foo', + headers: {}, + } + ) + + expect(mockSuppressTracing).toHaveBeenCalledTimes(1) + expect(mockSuppressTracing.mock.calls[0]?.[0]).toBeInstanceOf(Function) + }) + + it('api-fetch handler logs and rethrows when the request errors', async () => { + setupHttpRequest({ error: new Error('socket gone') }) + + registerApiFetchHandlers() + const handler = getHandler('api-fetch') + + await expect( + handler( + {}, + { + requestId: 'req-3', + method: 'GET', + path: '/foo', + headers: {}, + } + ) + ).rejects.toThrow('socket gone') + + expect(mockLog.error).toHaveBeenCalledWith( + '[api-fetch] Request failed: GET /foo', + expect.any(Error) + ) + }) + + it('api-fetch-abort handler destroys the inflight request and is a no-op for unknown ids', async () => { + // Drive a request that we keep open via manualDrive so we can abort it. + const { fakeReq } = setupHttpRequest({ manualDrive: true }) + + registerApiFetchHandlers() + const apiHandler = getHandler('api-fetch') + const abortHandler = getHandler('api-fetch-abort') + + // Kick off the request without awaiting (it will be aborted). + const inflight = apiHandler( + {}, + { + requestId: 'req-abort', + method: 'POST', + path: '/foo', + headers: {}, + body: 'payload', + } + ) + + // Yield so the registration in the inflight map has a chance to land. + await new Promise((r) => setImmediate(r)) + + abortHandler({}, 'req-abort') + expect(fakeReq.destroy).toHaveBeenCalledTimes(1) + expect(mockLog.info).toHaveBeenCalledWith( + '[api-fetch] Aborted request req-abort' + ) + + // Surfacing the destroyed request as an error keeps the promise from leaking. + fakeReq.emit('error', new Error('aborted')) + await expect(inflight).rejects.toThrow('aborted') + + // Aborting a request that doesn't exist is a no-op (no destroy, no log). + fakeReq.destroy.mockClear() + mockLog.info.mockClear() + abortHandler({}, 'unknown-id') + expect(fakeReq.destroy).not.toHaveBeenCalled() + expect(mockLog.info).not.toHaveBeenCalled() + }) +}) diff --git a/main/src/toolhive-manager.ts b/main/src/toolhive-manager.ts index 27806b163..aa43d08b4 100644 --- a/main/src/toolhive-manager.ts +++ b/main/src/toolhive-manager.ts @@ -1,5 +1,5 @@ import { spawn } from 'node:child_process' -import { existsSync } from 'node:fs' +import { existsSync, unlinkSync } from 'node:fs' import path from 'node:path' import net from 'node:net' import { app } from 'electron' @@ -34,21 +34,26 @@ const binPath = app.isPackaged ) let toolhiveProcess: ReturnType | undefined -let toolhivePort: number | undefined let toolhiveMcpPort: number | undefined +let toolhiveSocketPath: string | undefined let isRestarting = false let killTimer: NodeJS.Timeout | undefined let processError: ToolhiveProcessError | undefined -export function getToolhivePort(): number | undefined { - return toolhivePort -} - export function getToolhiveMcpPort(): number | undefined { return toolhiveMcpPort } +export function getToolhiveSocketPath(): string | undefined { + return toolhiveSocketPath +} + export function isToolhiveRunning(): boolean { + // When THV_SOCKET points at an externally managed thv we never spawn a + // child process, but the API is still reachable. Treat that as "running" + // so renderer guards (e.g. setupSecretProvider) and tray UI behave the + // same as in the bundled-binary case. + if (isUsingCustomSocket()) return true const isRunning = !!toolhiveProcess && !toolhiveProcess.killed return isRunning } @@ -61,10 +66,28 @@ export function getToolhiveStatus(): ToolhiveStatus { } /** - * Returns whether the app is using a custom ToolHive port (externally managed thv). + * Returns whether the app is using an externally managed thv reachable over + * a custom UNIX socket / Windows named pipe (THV_SOCKET env var). + */ +export function isUsingCustomSocket(): boolean { + return !app.isPackaged && !!process.env.THV_SOCKET +} + +/** + * Parses THV_MCP_PORT into a positive integer. Returns `undefined` for + * unset / blank / non-numeric / out-of-range values and logs a warning so + * a typo doesn't silently turn into NaN being treated as a real port. */ -export function isUsingCustomPort(): boolean { - return !app.isPackaged && !!process.env.THV_PORT +function parseMcpPortEnv(raw: string | undefined): number | undefined { + if (!raw) return undefined + const port = Number(raw) + if (!Number.isInteger(port) || port <= 0 || port > 65535) { + log.warn( + `Ignoring invalid THV_MCP_PORT=${raw}; expected an integer in 1..65535` + ) + return undefined + } + return port } async function findFreePort( @@ -124,21 +147,37 @@ async function findFreePort( return await getRandomPort() } +function generateSocketPath(): string { + // Windows AF_UNIX sockets created in %TEMP% hit EACCES on connect due to + // DACL handling. Named pipes are the canonical Windows IPC and are + // supported natively by Node's http.request({ socketPath }) and Go's + // Microsoft/go-winio. + if (process.platform === 'win32') { + return `\\\\.\\pipe\\toolhive-${process.pid}` + } + const socketName = `toolhive-${process.pid}.sock` + return path.join(app.getPath('temp'), socketName) +} + +function cleanupSocketFile(socketPath: string): void { + // Named pipes are released by the kernel when the listener exits; there's + // no filesystem entry to remove. + if (process.platform === 'win32') return + try { + if (existsSync(socketPath)) { + unlinkSync(socketPath) + } + } catch { + // Ignore cleanup errors + } +} + export async function startToolhive(): Promise { Sentry.withScope>(async (scope) => { - if (isUsingCustomPort()) { - const customPort = parseInt(process.env.THV_PORT!, 10) - if (isNaN(customPort)) { - log.error( - `Invalid THV_PORT environment variable: ${process.env.THV_PORT}` - ) - return - } - toolhivePort = customPort - toolhiveMcpPort = process.env.THV_MCP_PORT - ? parseInt(process.env.THV_MCP_PORT!, 10) - : undefined - log.info(`Using external ToolHive on port ${toolhivePort}`) + if (isUsingCustomSocket()) { + toolhiveSocketPath = process.env.THV_SOCKET! + toolhiveMcpPort = parseMcpPortEnv(process.env.THV_MCP_PORT) + log.info(`Using external ToolHive on socket ${toolhiveSocketPath}`) return } @@ -149,9 +188,11 @@ export async function startToolhive(): Promise { processError = undefined toolhiveMcpPort = await findFreePort() - toolhivePort = await findFreePort(50000, 50100) + toolhiveSocketPath = generateSocketPath() + cleanupSocketFile(toolhiveSocketPath) + log.info( - `Starting ToolHive from: ${binPath} on port ${toolhivePort}, MCP on port ${toolhiveMcpPort}` + `Starting ToolHive from: ${binPath} on socket ${toolhiveSocketPath}, MCP on port ${toolhiveMcpPort}` ) const serveArgs = [ @@ -160,8 +201,7 @@ export async function startToolhive(): Promise { '--experimental-mcp', '--experimental-mcp-host=127.0.0.1', `--experimental-mcp-port=${toolhiveMcpPort}`, - '--host=127.0.0.1', - `--port=${toolhivePort}`, + `--socket=${toolhiveSocketPath}`, ] const isE2E = process.env.TOOLHIVE_E2E === 'true' @@ -187,11 +227,12 @@ export async function startToolhive(): Promise { TOOLHIVE_SKIP_DESKTOP_CHECK: 'true', }, }) + log.info(`[startToolhive] Process spawned with PID: ${toolhiveProcess.pid}`) scope.addBreadcrumb({ category: 'debug', - message: `Starting ToolHive from: ${binPath} on port ${toolhivePort}, MCP on port ${toolhiveMcpPort}, PID: ${toolhiveProcess.pid}`, + message: `Starting ToolHive from: ${binPath} on socket ${toolhiveSocketPath}, MCP on port ${toolhiveMcpPort}, PID: ${toolhiveProcess.pid}`, }) updateTrayStatus(!!toolhiveProcess) @@ -349,6 +390,10 @@ export function stopToolhive(options?: { force?: boolean }): void { scheduleForceKill(processToKill, pidToKill) } + if (toolhiveSocketPath) { + cleanupSocketFile(toolhiveSocketPath) + } + log.info(`[stopToolhive] Process cleanup completed`) } diff --git a/main/src/unix-socket-fetch.ts b/main/src/unix-socket-fetch.ts new file mode 100644 index 000000000..632864144 --- /dev/null +++ b/main/src/unix-socket-fetch.ts @@ -0,0 +1,206 @@ +import http from 'node:http' +import { ipcMain } from 'electron' +import * as Sentry from '@sentry/electron/main' +import log from './logger' +import { getToolhiveSocketPath } from './toolhive-manager' +import { getHeaders } from './headers' +import { createClient, type Client } from '@common/api/generated/client' + +interface ApiFetchRequest { + requestId: string + method: string + path: string + headers: Record + body?: string +} + +interface ApiFetchResponse { + status: number + headers: Record + body: string +} + +// Status codes where the Fetch spec forbids a response body. Constructing +// `new Response(body, { status })` with a non-null body for any of these +// throws `TypeError: Response with null body status cannot have body`. +const NULL_BODY_STATUSES = new Set([101, 204, 205, 304]) + +const inflightRequests = new Map() + +function serializeResponseHeaders( + raw: http.IncomingHttpHeaders +): Record { + const headers: Record = {} + for (const [key, value] of Object.entries(raw)) { + if (value !== undefined) { + headers[key] = Array.isArray(value) ? value.join(', ') : value + } + } + return headers +} + +function performRequest( + socketPath: string, + opts: { + method: string + path: string + headers: Record + body?: string + }, + requestId?: string +): Promise { + return new Promise((resolve, reject) => { + const req = http.request( + { + socketPath, + method: opts.method, + path: opts.path, + headers: opts.headers, + }, + (res) => { + const chunks: Buffer[] = [] + res.on('data', (chunk: Buffer) => chunks.push(chunk)) + res.on('end', () => { + if (requestId) inflightRequests.delete(requestId) + resolve({ + status: res.statusCode ?? 500, + headers: serializeResponseHeaders(res.headers), + body: Buffer.concat(chunks).toString('utf-8'), + }) + }) + } + ) + + if (requestId) inflightRequests.set(requestId, req) + + req.on('error', (err) => { + if (requestId) inflightRequests.delete(requestId) + reject(err) + }) + + if (opts.body) req.write(opts.body) + req.end() + }) +} + +function requireSocketPath(): string { + const socketPath = getToolhiveSocketPath() + if (!socketPath) { + throw new Error('No ToolHive socket available') + } + return socketPath +} + +/** + * Returns whether a thv socket is currently available. Use this to + * short-circuit code paths that would otherwise build a client eagerly during + * bootstrap when thv has not started yet. + */ +export function hasToolhiveConnection(): boolean { + return !!getToolhiveSocketPath() +} + +/** + * Creates a hey-api Client configured to talk to the local thv API. Routing + * is handled by `createMainProcessFetch`, so callers do not need to know + * whether the transport is a UNIX socket or TCP. The `baseUrl` is a sentinel + * — the custom fetch ignores it and dials the live transport. + */ +export function createMainProcessApiClient(): Client { + return createClient({ + baseUrl: 'http://localhost', + headers: getHeaders(), + fetch: createMainProcessFetch(), + }) +} + +/** + * Creates a `fetch`-compatible function that routes requests through the + * thv UNIX socket / Windows named pipe. Intended for use in the main process + * (e.g. the graceful-exit client) where Node.js APIs are available. + */ +export function createMainProcessFetch(): typeof fetch { + return async ( + input: RequestInfo | URL, + init?: RequestInit + ): Promise => { + const request = new Request(input, init) + const url = new URL(request.url) + const body = request.body + ? await new Response(request.body).text() + : undefined + + const result = await performRequest(requireSocketPath(), { + method: request.method, + path: url.pathname + url.search, + headers: Object.fromEntries(request.headers), + body, + }) + + const responseBody = NULL_BODY_STATUSES.has(result.status) + ? null + : result.body + + return new Response(responseBody, { + status: result.status, + headers: result.headers, + }) + } +} + +/** + * Registers IPC handlers that let the renderer make API requests through the + * main process. The main process then forwards them over the thv UNIX socket + * / Windows named pipe. + */ +export function registerApiFetchHandlers(): void { + ipcMain.removeHandler('api-fetch') + ipcMain.removeHandler('api-fetch-abort') + + ipcMain.handle( + 'api-fetch', + async (_event, opts: ApiFetchRequest): Promise => { + const rawHeaders = getHeaders() + const telemetryHeaders: Record = {} + for (const [k, v] of Object.entries(rawHeaders)) { + telemetryHeaders[k] = String(v) + } + const mergedHeaders = { ...telemetryHeaders, ...opts.headers } + + // Suppress main's Sentry tracing for this socket call. The renderer is + // the only owner of the trace context: it injects `sentry-trace` and + // `baggage` into the IPC payload from its own `getTraceData()` and + // those headers must reach thv unmodified. Without this, the + // `httpIntegration` from `@sentry/electron/main` would auto-instrument + // `http.request` and overwrite the headers with main's active scope, + // breaking distributed tracing across the studio + thv projects. + return Sentry.suppressTracing(() => + performRequest( + requireSocketPath(), + { + method: opts.method, + path: opts.path, + headers: mergedHeaders, + body: opts.body, + }, + opts.requestId + ).catch((err) => { + log.error( + `[api-fetch] Request failed: ${opts.method} ${opts.path}`, + err + ) + throw err + }) + ) + } + ) + + ipcMain.handle('api-fetch-abort', (_event, requestId: string) => { + const req = inflightRequests.get(requestId) + if (req) { + req.destroy() + inflightRequests.delete(requestId) + log.info(`[api-fetch] Aborted request ${requestId}`) + } + }) +} diff --git a/package.json b/package.json index a72327cc9..5b91f34b1 100644 --- a/package.json +++ b/package.json @@ -10,6 +10,7 @@ }, "scripts": { "start": "electron-forge start", + "start:inspect": "electron-forge start -- --inspect --experimental-network-inspection", "rebuild": "electron-rebuild -f -w better-sqlite3", "postinstall": "pnpm run rebuild", "e2e": "pnpm package && pnpm e2e:prebuilt", diff --git a/preload/src/api/toolhive.ts b/preload/src/api/toolhive.ts index f99632fbf..b192196f1 100644 --- a/preload/src/api/toolhive.ts +++ b/preload/src/api/toolhive.ts @@ -4,15 +4,25 @@ import { TOOLHIVE_VERSION } from '../../../utils/constants' import type { ToolhiveStatus } from '../../../common/types/toolhive-status' export const toolhiveApi = { - getToolhivePort: () => ipcRenderer.invoke('get-toolhive-port'), getToolhiveMcpPort: () => ipcRenderer.invoke('get-toolhive-mcp-port'), + getToolhiveSocketPath: () => ipcRenderer.invoke('get-toolhive-socket-path'), getToolhiveVersion: () => TOOLHIVE_VERSION, isToolhiveRunning: () => ipcRenderer.invoke('is-toolhive-running'), getToolhiveStatus: () => ipcRenderer.invoke('get-toolhive-status'), - isUsingCustomPort: () => ipcRenderer.invoke('is-using-custom-port'), + isUsingCustomSocket: () => ipcRenderer.invoke('is-using-custom-socket'), checkContainerEngine: () => ipcRenderer.invoke('check-container-engine'), restartToolhive: () => ipcRenderer.invoke('restart-toolhive'), + apiFetch: (req: { + requestId: string + method: string + path: string + headers: Record + body?: string + }) => ipcRenderer.invoke('api-fetch', req), + apiFetchAbort: (requestId: string) => + ipcRenderer.invoke('api-fetch-abort', requestId), + shutdownStore: { getLastShutdownServers: () => ipcRenderer.invoke('shutdown-store:get-last-servers'), @@ -22,12 +32,12 @@ export const toolhiveApi = { } export interface ToolhiveAPI { - getToolhivePort: () => Promise getToolhiveMcpPort: () => Promise + getToolhiveSocketPath: () => Promise getToolhiveVersion: () => string isToolhiveRunning: () => Promise getToolhiveStatus: () => Promise - isUsingCustomPort: () => Promise + isUsingCustomSocket: () => Promise checkContainerEngine: () => Promise<{ docker: boolean podman: boolean @@ -38,6 +48,18 @@ export interface ToolhiveAPI { success: boolean error?: string }> + apiFetch: (req: { + requestId: string + method: string + path: string + headers: Record + body?: string + }) => Promise<{ + status: number + headers: Record + body: string + }> + apiFetchAbort: (requestId: string) => Promise shutdownStore: { getLastShutdownServers: () => Promise clearShutdownHistory: () => Promise<{ success: boolean }> diff --git a/renderer/src/common/components/custom-port-banner.tsx b/renderer/src/common/components/custom-socket-banner.tsx similarity index 54% rename from renderer/src/common/components/custom-port-banner.tsx rename to renderer/src/common/components/custom-socket-banner.tsx index aee6b4cea..2d535a5f9 100644 --- a/renderer/src/common/components/custom-port-banner.tsx +++ b/renderer/src/common/components/custom-socket-banner.tsx @@ -4,34 +4,32 @@ import { Alert, AlertDescription } from './ui/alert' import log from 'electron-log/renderer' /** - * Banner that displays a warning when using a custom ToolHive port in development mode. - * Only visible when THV_PORT environment variable is set. + * Banner that warns the developer when the studio is talking to an + * externally-managed `thv` over a custom UNIX socket / Windows named pipe + * (THV_SOCKET env var). Visible in development only. */ -export function CustomPortBanner() { - const [isCustomPort, setIsCustomPort] = useState(false) - const [port, setPort] = useState(undefined) +export function CustomSocketBanner() { + const [isCustomSocket, setIsCustomSocket] = useState(false) + const [socketPath, setSocketPath] = useState(undefined) useEffect(() => { Promise.all([ - window.electronAPI.isUsingCustomPort(), - window.electronAPI.getToolhivePort(), + window.electronAPI.isUsingCustomSocket(), + window.electronAPI.getToolhiveSocketPath(), ]) - .then(([usingCustom, toolhivePort]) => { - setIsCustomPort(usingCustom) - setPort(toolhivePort) + .then(([usingCustom, path]) => { + setIsCustomSocket(usingCustom) + setSocketPath(path) }) .catch((error: unknown) => { - log.error('Failed to get custom port info:', error) + log.error('Failed to get custom socket info:', error) }) }, []) - // Don't render if not using custom port or port is not available - if (!isCustomPort || !port) { + if (!isCustomSocket || !socketPath) { return null } - const httpAddress = `http://127.0.0.1:${port}` - return ( - {httpAddress} + {socketPath} diff --git a/renderer/src/common/lib/__tests__/ipc-fetch.test.ts b/renderer/src/common/lib/__tests__/ipc-fetch.test.ts new file mode 100644 index 000000000..d41a1955e --- /dev/null +++ b/renderer/src/common/lib/__tests__/ipc-fetch.test.ts @@ -0,0 +1,248 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest' +import * as Sentry from '@sentry/electron/renderer' +import { ipcFetch } from '../ipc-fetch' + +vi.mock('@sentry/electron/renderer', () => ({ + getTraceData: vi.fn(() => ({})), +})) + +const getTraceDataMock = vi.mocked(Sentry.getTraceData) + +function mockApiFetch(impl?: Parameters[0]) { + const fn = impl ? vi.fn(impl) : vi.fn() + window.electronAPI.apiFetch = + fn as unknown as typeof window.electronAPI.apiFetch + return fn +} + +function mockApiFetchAbort() { + const fn = vi.fn().mockResolvedValue(undefined) + window.electronAPI.apiFetchAbort = + fn as unknown as typeof window.electronAPI.apiFetchAbort + return fn +} + +const okResult = { + status: 200, + headers: { 'content-type': 'application/json' }, + body: '{"hello":"world"}', +} + +describe('ipcFetch', () => { + beforeEach(() => { + getTraceDataMock.mockReturnValue({}) + }) + + it('forwards method/path/headers/body to apiFetch and returns a Response', async () => { + const apiFetch = mockApiFetch(async () => okResult) + + const response = await ipcFetch( + 'http://localhost/api/v1/workloads?all=true', + { + method: 'POST', + headers: { 'X-Custom': 'value', 'content-type': 'application/json' }, + body: '{"foo":"bar"}', + } + ) + + expect(apiFetch).toHaveBeenCalledTimes(1) + const args = apiFetch.mock.calls[0]?.[0] + expect(args).toMatchObject({ + method: 'POST', + path: '/api/v1/workloads?all=true', + body: '{"foo":"bar"}', + headers: expect.objectContaining({ + 'x-custom': 'value', + 'content-type': 'application/json', + }), + }) + expect(typeof args.requestId).toBe('string') + expect(args.requestId).toMatch(/^req-/) + + expect(response).toBeInstanceOf(Response) + expect(response.status).toBe(200) + expect(response.headers.get('content-type')).toBe('application/json') + await expect(response.json()).resolves.toEqual({ hello: 'world' }) + }) + + it('generates a fresh requestId per call', async () => { + const apiFetch = mockApiFetch(async () => okResult) + + await ipcFetch('http://localhost/a') + await ipcFetch('http://localhost/b') + + const idA = apiFetch.mock.calls[0]?.[0].requestId + const idB = apiFetch.mock.calls[1]?.[0].requestId + expect(idA).toBeTruthy() + expect(idB).toBeTruthy() + expect(idA).not.toBe(idB) + }) + + it('merges Sentry trace headers (sentry-trace, baggage, traceparent)', async () => { + getTraceDataMock.mockReturnValue({ + 'sentry-trace': 'trace-123', + baggage: 'sentry-trace_id=abc', + traceparent: '00-trace-123-span-456-01', + }) + const apiFetch = mockApiFetch(async () => okResult) + + await ipcFetch('http://localhost/foo') + + expect(getTraceDataMock).toHaveBeenCalledWith({ + propagateTraceparent: true, + }) + const headers = apiFetch.mock.calls[0]?.[0].headers + expect(headers['sentry-trace']).toBe('trace-123') + expect(headers.baggage).toBe('sentry-trace_id=abc') + expect(headers.traceparent).toBe('00-trace-123-span-456-01') + }) + + it('omits Sentry trace headers that are not returned by getTraceData', async () => { + getTraceDataMock.mockReturnValue({ + 'sentry-trace': 'trace-only', + }) + const apiFetch = mockApiFetch(async () => okResult) + + await ipcFetch('http://localhost/foo') + + const headers = apiFetch.mock.calls[0]?.[0].headers + expect(headers['sentry-trace']).toBe('trace-only') + expect(headers.baggage).toBeUndefined() + expect(headers.traceparent).toBeUndefined() + }) + + it('lets Sentry trace headers win over caller-supplied trace headers', async () => { + getTraceDataMock.mockReturnValue({ + 'sentry-trace': 'sentry-wins', + traceparent: '00-fresh-trace-id-fresh-span-01', + }) + const apiFetch = mockApiFetch(async () => okResult) + + await ipcFetch('http://localhost/foo', { + headers: { + 'sentry-trace': 'caller-supplied', + traceparent: 'caller-supplied-traceparent', + }, + }) + + const headers = apiFetch.mock.calls[0]?.[0].headers + expect(headers['sentry-trace']).toBe('sentry-wins') + expect(headers.traceparent).toBe('00-fresh-trace-id-fresh-span-01') + }) + + // 101 is also in NULL_BODY_STATUSES but jsdom's Response constructor rejects + // statuses outside 200-599, so we exercise the remaining null-body codes here. + it.each([204, 205, 304])( + 'returns a null-body Response for status %i even if main returns a body string', + async (status) => { + mockApiFetch(async () => ({ + status, + headers: {}, + body: 'should-be-ignored', + })) + + const response = await ipcFetch('http://localhost/no-content') + + expect(response.status).toBe(status) + expect(response.body).toBeNull() + } + ) + + it('returns the body from main for non-null-body statuses (e.g. 200)', async () => { + mockApiFetch(async () => ({ + status: 200, + headers: { 'content-type': 'text/plain' }, + body: 'hello', + })) + + const response = await ipcFetch('http://localhost/text') + + await expect(response.text()).resolves.toBe('hello') + }) + + it('throws AbortError synchronously when the signal is already aborted, without calling apiFetch', async () => { + const apiFetch = mockApiFetch(async () => okResult) + const controller = new AbortController() + controller.abort() + + await expect( + ipcFetch('http://localhost/x', { signal: controller.signal }) + ).rejects.toMatchObject({ + name: 'AbortError', + }) + expect(apiFetch).not.toHaveBeenCalled() + }) + + it('aborting in-flight calls apiFetchAbort with the requestId and surfaces an AbortError', async () => { + // In real life, apiFetchAbort destroys the underlying http.ClientRequest + // which causes apiFetch to reject. We simulate that here so that the + // catch branch with `signal.aborted` runs. + let rejectApiFetch: (err: unknown) => void = () => {} + const apiFetch = mockApiFetch( + () => + new Promise((_, reject) => { + rejectApiFetch = reject + }) + ) + const apiFetchAbort = mockApiFetchAbort() + + const controller = new AbortController() + const promise = ipcFetch('http://localhost/x', { + signal: controller.signal, + }) + + // Wait a tick so the abort handler is wired up. + await new Promise((r) => setTimeout(r, 0)) + + controller.abort() + rejectApiFetch(new Error('Request aborted by client')) + + await expect(promise).rejects.toMatchObject({ name: 'AbortError' }) + expect(apiFetch).toHaveBeenCalledTimes(1) + expect(apiFetchAbort).toHaveBeenCalledTimes(1) + expect(apiFetchAbort).toHaveBeenCalledWith( + apiFetch.mock.calls[0]?.[0].requestId + ) + }) + + // Behavioral check: aborting the controller AFTER ipcFetch settles must not + // call apiFetchAbort, which proves the listener was removed in `finally`. + it('removes the abort listener on success', async () => { + mockApiFetch(async () => okResult) + const apiFetchAbort = mockApiFetchAbort() + const controller = new AbortController() + + await ipcFetch('http://localhost/x', { signal: controller.signal }) + + controller.abort() + await new Promise((r) => setTimeout(r, 0)) + + expect(apiFetchAbort).not.toHaveBeenCalled() + }) + + it('removes the abort listener when apiFetch rejects', async () => { + mockApiFetch(async () => { + throw new Error('boom') + }) + const apiFetchAbort = mockApiFetchAbort() + const controller = new AbortController() + + await expect( + ipcFetch('http://localhost/x', { signal: controller.signal }) + ).rejects.toThrow('boom') + + controller.abort() + await new Promise((r) => setTimeout(r, 0)) + + expect(apiFetchAbort).not.toHaveBeenCalled() + }) + + it('propagates non-abort errors from apiFetch unchanged', async () => { + const failure = new Error('IPC bridge unavailable') + mockApiFetch(async () => { + throw failure + }) + + await expect(ipcFetch('http://localhost/x')).rejects.toBe(failure) + }) +}) diff --git a/renderer/src/common/lib/ipc-fetch.ts b/renderer/src/common/lib/ipc-fetch.ts new file mode 100644 index 000000000..4206e6a65 --- /dev/null +++ b/renderer/src/common/lib/ipc-fetch.ts @@ -0,0 +1,101 @@ +import * as Sentry from '@sentry/electron/renderer' + +let requestCounter = 0 + +function nextRequestId(): string { + return `req-${Date.now()}-${++requestCounter}` +} + +// Status codes where the browser forbids a response body (fetch spec). +const NULL_BODY_STATUSES = new Set([101, 204, 205, 304]) + +/** + * Returns the trace-propagation headers for the renderer's active span: + * + * - `sentry-trace` / `baggage` — Sentry's native format (used by Sentry SDKs + * that read incoming requests directly). + * - `traceparent` — W3C trace context (used by OTEL-based middleware, e.g. + * the thv API server's `otelhttp` middleware). + * + * `getTraceData({ propagateTraceparent: true })` is Sentry's documented + * helper for emitting both formats; passing the option avoids hand-rolling + * the `sentry-trace -> traceparent` conversion. + */ +function getSentryTraceHeaders(): Record { + const traceData = Sentry.getTraceData({ propagateTraceparent: true }) + const headers: Record = {} + if (traceData['sentry-trace']) { + headers['sentry-trace'] = traceData['sentry-trace'] + } + if (traceData.baggage) { + headers.baggage = traceData.baggage + } + if (traceData.traceparent) { + headers.traceparent = traceData.traceparent + } + return headers +} + +/** + * A `fetch`-compatible function that routes HTTP requests through the Electron + * IPC bridge. The main process forwards them to the thv server over a UNIX + * socket / Windows named pipe. + * + * Plug this into the hey-api client via `client.setConfig({ fetch: ipcFetch })` + * so all generated SDK calls transparently use the IPC transport. + */ +export const ipcFetch: typeof fetch = async ( + input: RequestInfo | URL, + init?: RequestInit +): Promise => { + const request = new Request(input, init) + const url = new URL(request.url) + const requestId = nextRequestId() + + const body = request.body + ? await new Response(request.body).text() + : undefined + + if (request.signal?.aborted) { + throw new DOMException('The operation was aborted.', 'AbortError') + } + + const abortHandler = () => { + window.electronAPI.apiFetchAbort(requestId) + } + + request.signal?.addEventListener('abort', abortHandler, { once: true }) + + const requestHeaders = Object.fromEntries(request.headers) + + try { + const result = await window.electronAPI.apiFetch({ + requestId, + method: request.method, + path: url.pathname + url.search, + headers: { + ...requestHeaders, + ...getSentryTraceHeaders(), + }, + body, + }) + + // The browser's Response constructor throws if you provide a body for + // null-body status codes (204, 304, etc.). + const responseBody = NULL_BODY_STATUSES.has(result.status) + ? null + : result.body + + return new Response(responseBody, { + status: result.status, + headers: result.headers, + }) + } catch (err) { + if (request.signal?.aborted) { + throw new DOMException('The operation was aborted.', 'AbortError') + } + throw err + } finally { + request.signal?.removeEventListener('abort', abortHandler) + } +} diff --git a/renderer/src/common/mocks/electronAPI.ts b/renderer/src/common/mocks/electronAPI.ts index e736a8892..6a0f20fd7 100644 --- a/renderer/src/common/mocks/electronAPI.ts +++ b/renderer/src/common/mocks/electronAPI.ts @@ -46,6 +46,10 @@ function createElectronStub(): Partial { getPageSize: vi.fn().mockResolvedValue(undefined), setPageSize: vi.fn().mockResolvedValue(undefined), } as ElectronAPI['uiPreferences'], + apiFetch: vi + .fn() + .mockResolvedValue({ status: 200, headers: {}, body: '{}' }), + apiFetchAbort: vi.fn().mockResolvedValue(undefined), chat: { stream: vi.fn(), resumeStream: vi.fn().mockResolvedValue(null), diff --git a/renderer/src/lib/client-config.ts b/renderer/src/lib/client-config.ts index 25a4d7dd0..b4c93f50b 100644 --- a/renderer/src/lib/client-config.ts +++ b/renderer/src/lib/client-config.ts @@ -1,18 +1,23 @@ import { client } from '@common/api/generated/client.gen' +import { ipcFetch } from '../common/lib/ipc-fetch' import log from 'electron-log/renderer' export async function configureClient() { try { - const port = await window.electronAPI.getToolhivePort() + // All API requests are routed through the main process via IPC. The main + // process forwards them to the thv server over a UNIX socket (or TCP + // fallback). The baseUrl is a dummy used only for URL construction inside + // the hey-api client; the ipcFetch adapter strips it and sends only the + // path + query to the main process. const telemetryHeaders = await window.electronAPI.getTelemetryHeaders() - const baseUrl = `http://localhost:${port}` client.setConfig({ - baseUrl, + baseUrl: 'http://localhost', + fetch: ipcFetch, headers: telemetryHeaders, }) } catch (e) { - log.error('Failed to get ToolHive port from main process: ', e) + log.error('Failed to configure ToolHive API client: ', e) throw e } } diff --git a/renderer/src/renderer.tsx b/renderer/src/renderer.tsx index 14d3b4aca..3c95d29d4 100644 --- a/renderer/src/renderer.tsx +++ b/renderer/src/renderer.tsx @@ -20,8 +20,8 @@ import './common/lib/os-design' initSentry() -if (!window.electronAPI || !window.electronAPI.getToolhivePort) { - log.error('ToolHive port API not available in renderer') +if (!window.electronAPI || !window.electronAPI.apiFetch) { + log.error('ToolHive API bridge not available in renderer') } ;(async () => { diff --git a/renderer/src/routes/__root.tsx b/renderer/src/routes/__root.tsx index 8ade01eea..1f1efe448 100644 --- a/renderer/src/routes/__root.tsx +++ b/renderer/src/routes/__root.tsx @@ -20,7 +20,7 @@ import '@fontsource/atkinson-hyperlegible/700-italic.css' import '@fontsource-variable/inter/wght.css' import '@fontsource-variable/merriweather/wght.css' import log from 'electron-log/renderer' -import { CustomPortBanner } from '@/common/components/custom-port-banner' +import { CustomSocketBanner } from '@/common/components/custom-socket-banner' import { NewsletterModal } from '@/common/components/newsletter-modal' import { NewsletterModalProvider } from '@/common/contexts/newsletter-modal-provider' import { ExpertConsultationBanner } from '@/common/components/expert-consultation-banner' @@ -47,7 +47,7 @@ function RootComponent() { return ( {!hideNav && } - {!hideNav && import.meta.env.DEV && } + {!hideNav && import.meta.env.DEV && }
-interface ImportBaseApiEnv { +interface ImportMetaEnv { readonly VITE_BASE_API_URL: string } - -// Extend renderer env typings for custom development flag -interface ImportMetaEnv extends ImportBaseApiEnv { - readonly THV_PORT?: string -} diff --git a/tsconfig.app.json b/tsconfig.app.json index ee0fc3db3..ce15e0a9c 100644 --- a/tsconfig.app.json +++ b/tsconfig.app.json @@ -3,7 +3,7 @@ "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo", "target": "ES2023", "useDefineForClassFields": true, - "lib": ["ES2023", "DOM"], + "lib": ["ES2023", "DOM", "DOM.Iterable"], "module": "ESNext", "skipLibCheck": true, /* Bundler mode */ diff --git a/tsconfig.node.json b/tsconfig.node.json index 21f2e3490..508bb6635 100644 --- a/tsconfig.node.json +++ b/tsconfig.node.json @@ -2,7 +2,7 @@ "compilerOptions": { "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo", "target": "ES2023", - "lib": ["ES2023", "DOM"], + "lib": ["ES2023", "DOM", "DOM.Iterable"], "module": "ESNext", "skipLibCheck": true, /* Bundler mode */