diff --git a/src/lib/socket.ts b/src/lib/socket.ts index aedc8110..3f9431d7 100644 --- a/src/lib/socket.ts +++ b/src/lib/socket.ts @@ -14,6 +14,8 @@ export interface ISocket { readyState: number; } +export type SocketFactory = (url: string) => ISocket; + export class TCPSocket implements ISocket { private clientId: string; private isConnected = false; @@ -161,7 +163,7 @@ export class WebSocketWrapper implements ISocket { } } -export function createSocket(url: string): ISocket { +const defaultSocketFactory: SocketFactory = (url: string) => { if (url.startsWith("wss://")) { return new WebSocketWrapper(url); } @@ -169,4 +171,18 @@ export function createSocket(url: string): ISocket { return new TCPSocket(url); } throw new Error("Unsupported socket protocol"); +}; + +let socketFactory: SocketFactory = defaultSocketFactory; + +export function createSocket(url: string): ISocket { + return socketFactory(url); +} + +export function setSocketFactory(factory: SocketFactory): void { + socketFactory = factory; +} + +export function resetSocketFactory(): void { + socketFactory = defaultSocketFactory; } diff --git a/tests/lib/socketFactory.test.ts b/tests/lib/socketFactory.test.ts new file mode 100644 index 00000000..6a8e30d5 --- /dev/null +++ b/tests/lib/socketFactory.test.ts @@ -0,0 +1,91 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; + +vi.mock("@tauri-apps/api/core", () => ({ + invoke: vi.fn(() => Promise.resolve()), +})); + +vi.mock("@tauri-apps/api/event", () => ({ + listen: vi.fn(() => Promise.resolve(() => {})), +})); + +class MockWebSocket { + static readonly CONNECTING = 0; + static readonly OPEN = 1; + static readonly CLOSING = 2; + static readonly CLOSED = 3; + + readonly CONNECTING = 0; + readonly OPEN = 1; + readonly CLOSING = 2; + readonly CLOSED = 3; + + readyState = MockWebSocket.CONNECTING; + bufferedAmount = 0; + extensions = ""; + protocol = ""; + binaryType: BinaryType = "blob"; + url: string; + onopen: ((event: Event) => void) | null = null; + onmessage: ((event: MessageEvent) => void) | null = null; + onclose: ((event: CloseEvent) => void) | null = null; + onerror: ((event: Event) => void) | null = null; + + constructor(url: string) { + this.url = url; + } + + send(): void {} + + close(): void { + this.readyState = MockWebSocket.CLOSED; + } +} + +vi.stubGlobal("WebSocket", MockWebSocket); + +import { + createSocket, + resetSocketFactory, + setSocketFactory, + TCPSocket, + WebSocketWrapper, +} from "../../src/lib/socket"; + +describe("socket factory", () => { + afterEach(() => { + resetSocketFactory(); + vi.clearAllMocks(); + }); + + it("allows a custom socket factory to be injected", () => { + const fakeSocket = { + onopen: null, + onmessage: null, + onerror: null, + onclose: null, + send: vi.fn(), + close: vi.fn(), + readyState: 1, + }; + const factory = vi.fn(() => fakeSocket); + + setSocketFactory(factory); + + const socket = createSocket("wss://irc.example.com:443"); + + expect(factory).toHaveBeenCalledWith("wss://irc.example.com:443"); + expect(socket).toBe(fakeSocket); + }); + + it("keeps websocket routing as the default behavior", () => { + const socket = createSocket("wss://irc.example.com:443"); + + expect(socket).toBeInstanceOf(WebSocketWrapper); + }); + + it("keeps tauri tcp/tls routing as the default behavior", () => { + const socket = createSocket("ircs://irc.example.com:6697"); + + expect(socket).toBeInstanceOf(TCPSocket); + }); +});