diff --git a/test/unit/adapters/redis-adapter.spec.ts b/test/unit/adapters/redis-adapter.spec.ts new file mode 100644 index 00000000..d37c4b96 --- /dev/null +++ b/test/unit/adapters/redis-adapter.spec.ts @@ -0,0 +1,196 @@ +import chai from 'chai' +import chaiAsPromised from 'chai-as-promised' +import Sinon from 'sinon' +import sinonChai from 'sinon-chai' + +chai.use(sinonChai) +chai.use(chaiAsPromised) + +const { expect } = chai + +import { RedisAdapter } from '../../../src/adapters/redis-adapter' + +describe('RedisAdapter', () => { + let sandbox: Sinon.SinonSandbox + let client: any + let adapter: RedisAdapter + + let originalConsoleError: typeof console.error + + beforeEach(() => { + sandbox = Sinon.createSandbox() + originalConsoleError = console.error + console.error = () => undefined + + client = { + connect: sandbox.stub().resolves(), + on: sandbox.stub().returnsThis(), + exists: sandbox.stub(), + get: sandbox.stub(), + set: sandbox.stub(), + zRemRangeByScore: sandbox.stub(), + zRange: sandbox.stub(), + expire: sandbox.stub(), + zAdd: sandbox.stub(), + removeListener: sandbox.stub(), + once: sandbox.stub(), + } + + adapter = new RedisAdapter(client) + }) + + afterEach(() => { + console.error = originalConsoleError + sandbox.restore() + }) + + describe('constructor', () => { + it('calls client.connect()', () => { + expect(client.connect).to.have.been.calledOnce + }) + + it('registers event listeners for connect, ready, error, and reconnecting', () => { + expect(client.on).to.have.been.calledWith('connect') + expect(client.on).to.have.been.calledWith('ready') + expect(client.on).to.have.been.calledWith('error') + expect(client.on).to.have.been.calledWith('reconnecting') + }) + }) + + describe('constructor error handling', () => { + it('handles connection rejection without throwing', () => { + const failingClient = { + connect: sandbox.stub().rejects(new Error('connection refused')), + on: sandbox.stub().returnsThis(), + } + + expect(() => new RedisAdapter(failingClient as any)).not.to.throw() + }) + }) + + describe('hasKey', () => { + it('awaits connection and calls client.exists with the key', async () => { + client.exists.returns(1) + + const result = await adapter.hasKey('test-key') + + expect(client.exists).to.have.been.calledOnceWithExactly('test-key') + expect(result).to.be.true + }) + + it('returns false when key does not exist', async () => { + client.exists.returns(0) + + const result = await adapter.hasKey('missing-key') + + expect(result).to.be.false + }) + }) + + describe('getKey', () => { + it('awaits connection and calls client.get with the key', async () => { + client.get.resolves('test-value') + + const result = await adapter.getKey('test-key') + + expect(client.get).to.have.been.calledOnceWithExactly('test-key') + expect(result).to.equal('test-value') + }) + + it('returns null when key does not exist', async () => { + client.get.resolves(null) + + const result = await adapter.getKey('missing-key') + + expect(result).to.be.null + }) + }) + + describe('setKey', () => { + it('returns true when client.set returns OK', async () => { + client.set.resolves('OK') + + const result = await adapter.setKey('key', 'value') + + expect(client.set).to.have.been.calledOnceWithExactly('key', 'value') + expect(result).to.be.true + }) + + it('returns false when client.set does not return OK', async () => { + client.set.resolves(null) + + const result = await adapter.setKey('key', 'value') + + expect(result).to.be.false + }) + }) + + describe('removeRangeByScoreFromSortedSet', () => { + it('calls client.zRemRangeByScore with correct arguments', async () => { + client.zRemRangeByScore.resolves(3) + + const result = await adapter.removeRangeByScoreFromSortedSet('sorted-key', 10, 20) + + expect(client.zRemRangeByScore).to.have.been.calledOnceWithExactly('sorted-key', 10, 20) + expect(result).to.equal(3) + }) + }) + + describe('getRangeFromSortedSet', () => { + it('calls client.zRange with correct arguments', async () => { + client.zRange.resolves(['a', 'b', 'c']) + + const result = await adapter.getRangeFromSortedSet('sorted-key', 0, 10) + + expect(client.zRange).to.have.been.calledOnceWithExactly('sorted-key', 0, 10) + expect(result).to.deep.equal(['a', 'b', 'c']) + }) + + it('returns empty array when set is empty', async () => { + client.zRange.resolves([]) + + const result = await adapter.getRangeFromSortedSet('empty-key', 0, 10) + + expect(result).to.deep.equal([]) + }) + }) + + describe('setKeyExpiry', () => { + it('calls client.expire with correct arguments', async () => { + client.expire.resolves(true) + + await adapter.setKeyExpiry('key', 3600) + + expect(client.expire).to.have.been.calledOnceWithExactly('key', 3600) + }) + }) + + describe('addToSortedSet', () => { + it('transforms record entries to score/value members and calls client.zAdd', async () => { + client.zAdd.resolves(2) + + const set = { 'member1': '100', 'member2': '200' } + const result = await adapter.addToSortedSet('sorted-key', set) + + expect(client.zAdd).to.have.been.calledOnce + const callArgs = client.zAdd.firstCall.args + expect(callArgs[0]).to.equal('sorted-key') + expect(callArgs[1]).to.deep.include.members([ + { score: 100, value: 'member1' }, + { score: 200, value: 'member2' }, + ]) + expect(result).to.equal(2) + }) + + it('handles a single entry', async () => { + client.zAdd.resolves(1) + + const set = { 'only-member': '50' } + const result = await adapter.addToSortedSet('sorted-key', set) + + const callArgs = client.zAdd.firstCall.args + expect(callArgs[1]).to.deep.equal([{ score: 50, value: 'only-member' }]) + expect(result).to.equal(1) + }) + }) +}) diff --git a/test/unit/adapters/web-server-adapter.spec.ts b/test/unit/adapters/web-server-adapter.spec.ts new file mode 100644 index 00000000..b9fdb02d --- /dev/null +++ b/test/unit/adapters/web-server-adapter.spec.ts @@ -0,0 +1,155 @@ +import chai from 'chai' +import Sinon from 'sinon' +import sinonChai from 'sinon-chai' + +chai.use(sinonChai) + +const { expect } = chai + +import { WebServerAdapter } from '../../../src/adapters/web-server-adapter' + +describe('WebServerAdapter', () => { + let sandbox: Sinon.SinonSandbox + let webServer: any + let adapter: WebServerAdapter + + let originalConsoleError: typeof console.error + + beforeEach(() => { + sandbox = Sinon.createSandbox() + originalConsoleError = console.error + console.error = () => undefined + + webServer = { + on: sandbox.stub().returnsThis(), + once: sandbox.stub().returnsThis(), + listen: sandbox.stub(), + close: sandbox.stub(), + removeAllListeners: sandbox.stub(), + } + + adapter = new WebServerAdapter(webServer) + }) + + afterEach(() => { + console.error = originalConsoleError + sandbox.restore() + adapter.removeAllListeners() + }) + + describe('constructor', () => { + it('registers error event listener on webServer', () => { + expect(webServer.on).to.have.been.calledWith('error') + }) + + it('registers clientError event listener on webServer', () => { + expect(webServer.on).to.have.been.calledWith('clientError') + }) + + it('registers close event listener on webServer', () => { + expect(webServer.once).to.have.been.calledWith('close') + }) + + it('registers listening event listener on webServer', () => { + expect(webServer.once).to.have.been.calledWith('listening') + }) + }) + + describe('listen', () => { + it('calls webServer.listen with the given port', () => { + adapter.listen(8080) + + expect(webServer.listen).to.have.been.calledOnceWithExactly(8080) + }) + }) + + describe('close', () => { + it('calls webServer.close', () => { + adapter.close() + + expect(webServer.close).to.have.been.calledOnce + }) + + it('invokes callback after close completes', () => { + const callback = sandbox.stub() + webServer.close.callsFake((cb: () => void) => cb()) + + adapter.close(callback) + + expect(callback).to.have.been.calledOnce + }) + + it('removes all listeners from webServer after close', () => { + webServer.close.callsFake((cb: () => void) => cb()) + + adapter.close() + + expect(webServer.removeAllListeners).to.have.been.calledOnce + }) + + it('does not throw if callback is undefined', () => { + webServer.close.callsFake((cb: () => void) => cb()) + + expect(() => adapter.close()).not.to.throw() + }) + }) + + describe('onClientError', () => { + it('ignores ECONNRESET errors', () => { + const error: any = new Error('connection reset') + error.code = 'ECONNRESET' + const socket: any = { writable: true, end: sandbox.stub() } + + // Access private method through event handler + // Find the clientError handler registered in constructor + const clientErrorCall = webServer.on.getCalls().find( + (call: any) => call.args[0] === 'clientError' + ) + const handler = clientErrorCall.args[1] + + handler(error, socket) + + expect(socket.end).not.to.have.been.called + }) + + it('ignores errors when socket is not writable', () => { + const error = new Error('some error') + const socket: any = { writable: false, end: sandbox.stub() } + + const clientErrorCall = webServer.on.getCalls().find( + (call: any) => call.args[0] === 'clientError' + ) + const handler = clientErrorCall.args[1] + + handler(error, socket) + + expect(socket.end).not.to.have.been.called + }) + + it('sends 400 response for other client errors', () => { + const error = new Error('bad request') + const socket: any = { writable: true, end: sandbox.stub() } + + const clientErrorCall = webServer.on.getCalls().find( + (call: any) => call.args[0] === 'clientError' + ) + const handler = clientErrorCall.args[1] + + handler(error, socket) + + expect(socket.end).to.have.been.calledOnce + expect(socket.end.firstCall.args[0]).to.include('400 Bad Request') + }) + }) + + describe('onError', () => { + it('handles server errors without throwing', () => { + const errorCall = webServer.on.getCalls().find( + (call: any) => call.args[0] === 'error' + ) + const handler = errorCall.args[1] + + expect(() => handler(new Error('server error'))).not.to.throw() + }) + }) +}) diff --git a/test/unit/adapters/web-socket-adapter.spec.ts b/test/unit/adapters/web-socket-adapter.spec.ts new file mode 100644 index 00000000..d293f3b8 --- /dev/null +++ b/test/unit/adapters/web-socket-adapter.spec.ts @@ -0,0 +1,638 @@ +import EventEmitter from 'events' +import { WebSocket } from 'ws' + +import chai from 'chai' +import chaiAsPromised from 'chai-as-promised' +import Sinon from 'sinon' +import sinonChai from 'sinon-chai' + +chai.use(sinonChai) +chai.use(chaiAsPromised) + +const { expect } = chai + +import { WebSocketAdapterEvent, WebSocketServerAdapterEvent } from '../../../src/constants/adapter' +import { IWebSocketServerAdapter } from '../../../src/@types/adapters' +import { WebSocketAdapter } from '../../../src/adapters/web-socket-adapter' + +describe('WebSocketAdapter', () => { + let sandbox: Sinon.SinonSandbox + let client: any + let request: any + let webSocketServer: any + let createMessageHandler: Sinon.SinonStub + let slidingWindowRateLimiter: Sinon.SinonStub + let settingsFactory: Sinon.SinonStub + let adapter: WebSocketAdapter + + let originalConsoleError: typeof console.error + + beforeEach(() => { + sandbox = Sinon.createSandbox() + originalConsoleError = console.error + console.error = () => undefined + + client = { + on: sandbox.stub().returnsThis(), + send: sandbox.stub(), + close: sandbox.stub(), + ping: sandbox.stub(), + pong: sandbox.stub(), + readyState: WebSocket.OPEN, + removeAllListeners: sandbox.stub(), + } + + request = { + headers: { + 'sec-websocket-key': Buffer.from('test-key-123', 'utf8').toString('base64'), + }, + socket: { + remoteAddress: '127.0.0.1', + }, + } + + webSocketServer = new EventEmitter() as IWebSocketServerAdapter + + createMessageHandler = sandbox.stub() + slidingWindowRateLimiter = sandbox.stub().returns({ + hit: sandbox.stub().resolves(false), + }) + settingsFactory = sandbox.stub().returns({ + network: { + remoteIpHeader: '', + }, + limits: { + message: { + rateLimits: [], + ipWhitelist: [], + }, + }, + }) + + adapter = new WebSocketAdapter( + client, + request, + webSocketServer as any, + createMessageHandler, + slidingWindowRateLimiter, + settingsFactory, + ) + }) + + afterEach(() => { + console.error = originalConsoleError + adapter.removeAllListeners() + webSocketServer.removeAllListeners() + sandbox.restore() + }) + + describe('constructor', () => { + it('extracts clientId from sec-websocket-key header', () => { + const expectedId = Buffer.from( + Buffer.from('test-key-123', 'utf8').toString('base64'), + 'base64', + ).toString('hex') + + expect(adapter.getClientId()).to.equal(expectedId) + }) + + it('resolves client address from request', () => { + expect(adapter.getClientAddress()).to.equal('127.0.0.1') + }) + + it('registers WebSocket event listeners', () => { + expect(client.on).to.have.been.calledWith('error') + expect(client.on).to.have.been.calledWith('message') + expect(client.on).to.have.been.calledWith('close') + expect(client.on).to.have.been.calledWith('pong') + expect(client.on).to.have.been.calledWith('ping') + }) + + it('registers internal event listeners', () => { + expect(adapter.listenerCount(WebSocketAdapterEvent.Heartbeat)).to.be.greaterThan(0) + expect(adapter.listenerCount(WebSocketAdapterEvent.Subscribe)).to.be.greaterThan(0) + expect(adapter.listenerCount(WebSocketAdapterEvent.Unsubscribe)).to.be.greaterThan(0) + expect(adapter.listenerCount(WebSocketAdapterEvent.Event)).to.be.greaterThan(0) + expect(adapter.listenerCount(WebSocketAdapterEvent.Broadcast)).to.be.greaterThan(0) + expect(adapter.listenerCount(WebSocketAdapterEvent.Message)).to.be.greaterThan(0) + }) + }) + + describe('getClientId', () => { + it('returns the client ID', () => { + expect(adapter.getClientId()).to.be.a('string') + expect(adapter.getClientId().length).to.be.greaterThan(0) + }) + }) + + describe('getClientAddress', () => { + it('returns the client IP address', () => { + expect(adapter.getClientAddress()).to.equal('127.0.0.1') + }) + }) + + describe('getSubscriptions', () => { + it('returns an empty map when no subscriptions', () => { + const subs = adapter.getSubscriptions() + + expect(subs).to.be.instanceOf(Map) + expect(subs.size).to.equal(0) + }) + + it('returns a copy of subscriptions map', () => { + adapter.onSubscribed('sub-1', [{ kinds: [1] }]) + + const subs = adapter.getSubscriptions() + + expect(subs.size).to.equal(1) + expect(subs.get('sub-1')).to.deep.equal([{ kinds: [1] }]) + }) + }) + + describe('onSubscribed', () => { + it('adds subscription to the map', () => { + const filters = [{ kinds: [1] }, { authors: ['abc'] }] + + adapter.onSubscribed('sub-1', filters) + + const subs = adapter.getSubscriptions() + expect(subs.get('sub-1')).to.deep.equal(filters) + }) + + it('overwrites existing subscription with same id', () => { + adapter.onSubscribed('sub-1', [{ kinds: [1] }]) + adapter.onSubscribed('sub-1', [{ kinds: [2] }]) + + const subs = adapter.getSubscriptions() + expect(subs.size).to.equal(1) + expect(subs.get('sub-1')).to.deep.equal([{ kinds: [2] }]) + }) + }) + + describe('onUnsubscribed', () => { + it('removes subscription from the map', () => { + adapter.onSubscribed('sub-1', [{ kinds: [1] }]) + adapter.onUnsubscribed('sub-1') + + const subs = adapter.getSubscriptions() + expect(subs.size).to.equal(0) + }) + + it('does not throw when removing non-existent subscription', () => { + expect(() => adapter.onUnsubscribed('non-existent')).not.to.throw() + }) + }) + + describe('onBroadcast', () => { + it('emits broadcast event on the WebSocket server adapter', () => { + const emitSpy = sandbox.spy(webSocketServer, 'emit') + const event = { + id: 'a'.repeat(64), + pubkey: 'b'.repeat(64), + kind: 1, + content: 'test', + created_at: 1000000, + sig: 'c'.repeat(128), + tags: [], + } + + adapter.onBroadcast(event) + + expect(emitSpy).to.have.been.calledWith(WebSocketServerAdapterEvent.Broadcast, event) + }) + }) + + describe('onHeartbeat', () => { + it('pings the client and sets alive to false', () => { + // Adapter starts with alive = true + adapter.emit(WebSocketAdapterEvent.Heartbeat) + + expect(client.ping).to.have.been.calledOnce + }) + + it('closes connection when client is not alive and has no subscriptions', () => { + // First heartbeat: sets alive to false, pings + adapter.emit(WebSocketAdapterEvent.Heartbeat) + // Second heartbeat: alive is false, no subs -> close + adapter.emit(WebSocketAdapterEvent.Heartbeat) + + expect(client.close).to.have.been.calledOnce + }) + + it('does not close when client is not alive but has subscriptions', () => { + adapter.onSubscribed('sub-1', [{ kinds: [1] }]) + + // First heartbeat: sets alive to false, pings + adapter.emit(WebSocketAdapterEvent.Heartbeat) + // Second heartbeat: alive is false, but has subs -> keep alive + adapter.emit(WebSocketAdapterEvent.Heartbeat) + + expect(client.close).not.to.have.been.called + }) + }) + + describe('sendMessage (via Message event)', () => { + it('sends JSON-serialized message when WebSocket is OPEN', () => { + client.readyState = WebSocket.OPEN + const message = ['NOTICE', 'hello'] + + adapter.emit(WebSocketAdapterEvent.Message, message) + + expect(client.send).to.have.been.calledOnceWithExactly(JSON.stringify(message)) + }) + + it('does nothing when WebSocket is not OPEN', () => { + client.readyState = WebSocket.CLOSED + const message = ['NOTICE', 'hello'] + + adapter.emit(WebSocketAdapterEvent.Message, message) + + expect(client.send).not.to.have.been.called + }) + }) + + describe('onSendEvent (via Event event)', () => { + it('sends event matching subscription filters', () => { + client.readyState = WebSocket.OPEN + adapter.onSubscribed('sub-1', [{ kinds: [1] }]) + + const event = { + id: 'a'.repeat(64), + pubkey: 'b'.repeat(64), + kind: 1, + content: 'hello', + created_at: 1000000, + sig: 'c'.repeat(128), + tags: [], + } + + adapter.emit(WebSocketAdapterEvent.Event, event) + + expect(client.send).to.have.been.calledOnce + const sent = JSON.parse(client.send.firstCall.args[0]) + expect(sent[0]).to.equal('EVENT') + expect(sent[1]).to.equal('sub-1') + expect(sent[2]).to.deep.equal(event) + }) + + it('does not send event not matching any filter', () => { + client.readyState = WebSocket.OPEN + adapter.onSubscribed('sub-1', [{ kinds: [999] }]) + + const event = { + id: 'a'.repeat(64), + pubkey: 'b'.repeat(64), + kind: 1, + content: 'hello', + created_at: 1000000, + sig: 'c'.repeat(128), + tags: [], + } + + adapter.emit(WebSocketAdapterEvent.Event, event) + + expect(client.send).not.to.have.been.called + }) + }) + + describe('onClientClose', () => { + it('clears all subscriptions when client disconnects', () => { + adapter.onSubscribed('sub-1', [{ kinds: [1] }]) + adapter.onSubscribed('sub-2', [{ kinds: [2] }]) + + // Trigger the close handler + const closeCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'close' + ) + const onClose = closeCall.args[1] + onClose() + + expect(adapter.getSubscriptions().size).to.equal(0) + }) + + it('removes all listeners from client', () => { + const closeCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'close' + ) + const onClose = closeCall.args[1] + onClose() + + expect(client.removeAllListeners).to.have.been.calledOnce + }) + }) + + describe('onClientPong', () => { + it('marks client as alive', () => { + // First heartbeat sets alive = false + adapter.emit(WebSocketAdapterEvent.Heartbeat) + + // Trigger pong handler - should set alive = true + const pongCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'pong' + ) + const onPong = pongCall.args[1] + onPong() + + // Next heartbeat should not close (alive was reset to true) + adapter.emit(WebSocketAdapterEvent.Heartbeat) + + expect(client.close).not.to.have.been.called + }) + }) + + describe('onClientPing', () => { + it('responds with pong', () => { + const pingCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'ping' + ) + const onPing = pingCall.args[1] + const data = Buffer.from('ping-data') + + onPing(data) + + expect(client.pong).to.have.been.calledOnceWithExactly(data) + }) + }) + + describe('error handling', () => { + it('closes client on RangeError with max payload exceeded', () => { + const errorCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'error' + ) + const onError = errorCall.args[1] + + const error = new RangeError('Max payload size exceeded') + + onError(error) + + expect(client.close).to.have.been.calledOnce + }) + + it('closes client on RSV1 compression error', () => { + const errorCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'error' + ) + const onError = errorCall.args[1] + + const error = new RangeError('Invalid WebSocket frame: RSV1 must be clear') + + onError(error) + + expect(client.close).to.have.been.calledOnce + }) + + it('closes client on generic errors', () => { + const errorCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'error' + ) + const onError = errorCall.args[1] + + const error = new Error('something went wrong') + + onError(error) + + expect(client.close).to.have.been.calledOnce + }) + }) + + describe('onClientMessage', () => { + it('handles invalid JSON gracefully', async () => { + client.readyState = WebSocket.OPEN + + const messageCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'message' + ) + const onMessage = messageCall.args[1] + + await onMessage(Buffer.from('not-json')) + + // Should send a NOTICE about invalid message + expect(client.send).to.have.been.calledOnce + const sent = JSON.parse(client.send.firstCall.args[0]) + expect(sent[0]).to.equal('NOTICE') + }) + + it('sends rate-limited notice when rate limited', async () => { + client.readyState = WebSocket.OPEN + + // Configure rate limiting to be active + settingsFactory.returns({ + network: { remoteIpHeader: '' }, + limits: { + message: { + rateLimits: [{ period: 60000, rate: 1 }], + ipWhitelist: [], + }, + }, + }) + + slidingWindowRateLimiter.returns({ + hit: sandbox.stub().resolves(true), + }) + + const messageCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'message' + ) + const onMessage = messageCall.args[1] + + // Valid JSON message that would pass parsing + await onMessage(Buffer.from(JSON.stringify(['EVENT', {}]))) + + expect(client.send).to.have.been.called + }) + + it('does not rate limit when no rateLimits are configured', async () => { + client.readyState = WebSocket.OPEN + + settingsFactory.returns({ + network: { remoteIpHeader: '' }, + limits: { + message: {}, + }, + }) + + const messageCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'message' + ) + const onMessage = messageCall.args[1] + + // Invalid JSON will cause a parsing NOTICE, not a rate-limit NOTICE + await onMessage(Buffer.from('invalid')) + + expect(client.send).to.have.been.calledOnce + const sent = JSON.parse(client.send.firstCall.args[0]) + expect(sent[0]).to.equal('NOTICE') + expect(sent[1]).not.to.include('rate limited') + }) + + it('does not rate limit when client IP is whitelisted', async () => { + client.readyState = WebSocket.OPEN + + settingsFactory.returns({ + network: { remoteIpHeader: '' }, + limits: { + message: { + rateLimits: [{ period: 60000, rate: 0 }], + ipWhitelist: ['127.0.0.1'], + }, + }, + }) + + const messageCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'message' + ) + const onMessage = messageCall.args[1] + + await onMessage(Buffer.from('invalid')) + + // Should get a parsing error NOTICE, not a rate-limit NOTICE + expect(client.send).to.have.been.calledOnce + const sent = JSON.parse(client.send.firstCall.args[0]) + expect(sent[1]).not.to.include('rate limited') + }) + + it('sets alive to true when message is received', async () => { + // First heartbeat sets alive = false + adapter.emit(WebSocketAdapterEvent.Heartbeat) + + const messageCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'message' + ) + const onMessage = messageCall.args[1] + + // Receiving any message sets alive = true + await onMessage(Buffer.from('invalid')) + + // Next heartbeat should NOT close (alive was reset by message) + adapter.emit(WebSocketAdapterEvent.Heartbeat) + + expect(client.close).not.to.have.been.called + }) + + it('handles AbortError without sending notice', async () => { + client.readyState = WebSocket.OPEN + + settingsFactory.returns({ + network: { remoteIpHeader: '' }, + limits: { message: {} }, + }) + + const abortError = new Error('aborted') + abortError.name = 'AbortError' + + createMessageHandler.returns({ + handleMessage: sandbox.stub().rejects(abortError), + }) + + const messageCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'message' + ) + const onMessage = messageCall.args[1] + + await onMessage(Buffer.from(JSON.stringify(['REQ', 'sub-1', {}]))) + + // AbortError should NOT send a NOTICE to client + expect(client.send).not.to.have.been.called + }) + + it('returns early when no handler is found for message', async () => { + client.readyState = WebSocket.OPEN + + settingsFactory.returns({ + network: { remoteIpHeader: '' }, + limits: { message: {} }, + }) + + createMessageHandler.returns(null) + + const messageCall = client.on.getCalls().find( + (call: any) => call.args[0] === 'message' + ) + const onMessage = messageCall.args[1] + + // Should not throw and should not send any message + await onMessage(Buffer.from(JSON.stringify(['REQ', 'sub-1', {}]))) + + expect(client.send).not.to.have.been.called + }) + }) + + describe('onSendEvent edge cases', () => { + it('sends event to multiple matching subscriptions', () => { + client.readyState = WebSocket.OPEN + adapter.onSubscribed('sub-1', [{ kinds: [1] }]) + adapter.onSubscribed('sub-2', [{ kinds: [1] }]) + + const event = { + id: 'a'.repeat(64), + pubkey: 'b'.repeat(64), + kind: 1, + content: 'hello', + created_at: 1000000, + sig: 'c'.repeat(128), + tags: [], + } + + adapter.emit(WebSocketAdapterEvent.Event, event) + + expect(client.send).to.have.been.calledTwice + }) + + it('does not send when socket is not OPEN', () => { + client.readyState = WebSocket.CLOSED + adapter.onSubscribed('sub-1', [{ kinds: [1] }]) + + const event = { + id: 'a'.repeat(64), + pubkey: 'b'.repeat(64), + kind: 1, + content: 'hello', + created_at: 1000000, + sig: 'c'.repeat(128), + tags: [], + } + + adapter.emit(WebSocketAdapterEvent.Event, event) + + expect(client.send).not.to.have.been.called + }) + }) + + describe('getSubscriptions edge cases', () => { + it('returns a copy that does not affect internal state', () => { + adapter.onSubscribed('sub-1', [{ kinds: [1] }]) + + const subs = adapter.getSubscriptions() + subs.delete('sub-1') + + // Internal state should not be affected + expect(adapter.getSubscriptions().size).to.equal(1) + }) + }) + + describe('IPv6 support', () => { + it('handles IPv6 client address', () => { + const ipv6Request = { + headers: { + 'sec-websocket-key': Buffer.from('ipv6-key', 'utf8').toString('base64'), + }, + socket: { + remoteAddress: '::1', + }, + } + + const ipv6Adapter = new WebSocketAdapter( + client, + ipv6Request as any, + webSocketServer as any, + createMessageHandler, + slidingWindowRateLimiter, + settingsFactory, + ) + + expect(ipv6Adapter.getClientAddress()).to.equal('::1') + ipv6Adapter.removeAllListeners() + }) + }) +}) + diff --git a/test/unit/adapters/web-socket-server-adapter.spec.ts b/test/unit/adapters/web-socket-server-adapter.spec.ts new file mode 100644 index 00000000..393c501d --- /dev/null +++ b/test/unit/adapters/web-socket-server-adapter.spec.ts @@ -0,0 +1,292 @@ +import * as rateLimiterMiddleware from '../../../src/handlers/request-handlers/rate-limiter-middleware' + +import chai from 'chai' +import chaiAsPromised from 'chai-as-promised' +import Sinon from 'sinon' +import sinonChai from 'sinon-chai' + +chai.use(sinonChai) +chai.use(chaiAsPromised) + +const { expect } = chai + +import { WebSocketAdapterEvent, WebSocketServerAdapterEvent } from '../../../src/constants/adapter' +import { WebSocketServerAdapter } from '../../../src/adapters/web-socket-server-adapter' + +describe('WebSocketServerAdapter', () => { + let sandbox: Sinon.SinonSandbox + let webServer: any + let webSocketServer: any + let createWebSocketAdapter: Sinon.SinonStub + let settings: any + let adapter: WebSocketServerAdapter + let isRateLimitedStub: Sinon.SinonStub + + let originalConsoleError: typeof console.error + + beforeEach(() => { + sandbox = Sinon.createSandbox() + sandbox.useFakeTimers() + originalConsoleError = console.error + console.error = () => undefined + + isRateLimitedStub = sandbox.stub(rateLimiterMiddleware, 'isRateLimited').resolves(false) + + webServer = { + on: sandbox.stub().returnsThis(), + once: sandbox.stub().returnsThis(), + close: sandbox.stub(), + removeAllListeners: sandbox.stub(), + listen: sandbox.stub(), + } + + webSocketServer = { + on: sandbox.stub().returnsThis(), + clients: new Set(), + close: sandbox.stub(), + removeAllListeners: sandbox.stub(), + } + + createWebSocketAdapter = sandbox.stub() + + settings = () => ({ + network: { + remoteIpHeader: '', + }, + limits: { + connection: { + rateLimits: [], + }, + }, + }) as any + + adapter = new WebSocketServerAdapter( + webServer, + webSocketServer, + createWebSocketAdapter, + settings, + ) + }) + + afterEach(() => { + console.error = originalConsoleError + webServer.close.callsFake((cb: () => void) => cb()) + webSocketServer.clients = new Set() + webSocketServer.close.callsFake((cb: () => void) => cb()) + adapter.close() + sandbox.restore() + }) + + describe('constructor', () => { + it('registers broadcast event listener on itself', () => { + expect(adapter.listenerCount(WebSocketServerAdapterEvent.Broadcast)).to.be.greaterThan(0) + }) + + it('registers connection event listener on webSocketServer', () => { + expect(webSocketServer.on).to.have.been.calledWith(WebSocketServerAdapterEvent.Connection) + }) + + it('registers error event listener on webSocketServer', () => { + expect(webSocketServer.on).to.have.been.calledWith('error') + }) + }) + + describe('getConnectedClients', () => { + it('returns 0 when no clients are connected', () => { + webSocketServer.clients = new Set() + + expect(adapter.getConnectedClients()).to.equal(0) + }) + + it('counts only clients with OPEN readyState', () => { + const OPEN = 1 + const CLOSING = 2 + + webSocketServer.clients = new Set([ + { readyState: OPEN }, + { readyState: OPEN }, + { readyState: CLOSING }, + ] as any) + + expect(adapter.getConnectedClients()).to.equal(2) + }) + }) + + describe('close', () => { + it('calls parent close which closes webServer', () => { + adapter.close() + + expect(webServer.close).to.have.been.calledOnce + }) + + it('terminates all connected WebSocket clients', () => { + const terminateStub1 = sandbox.stub() + const terminateStub2 = sandbox.stub() + + webSocketServer.clients = new Set([ + { terminate: terminateStub1 }, + { terminate: terminateStub2 }, + ] as any) + + webServer.close.callsFake((cb: () => void) => cb()) + webSocketServer.close.callsFake((cb: () => void) => cb()) + + adapter.close() + + expect(terminateStub1).to.have.been.calledOnce + expect(terminateStub2).to.have.been.calledOnce + }) + + it('closes the webSocketServer after terminating clients', () => { + webSocketServer.clients = new Set() + webServer.close.callsFake((cb: () => void) => cb()) + webSocketServer.close.callsFake((cb: () => void) => cb()) + + adapter.close() + + expect(webSocketServer.close).to.have.been.calledOnce + }) + + it('invokes callback after full close', () => { + const callback = sandbox.stub() + webSocketServer.clients = new Set() + webServer.close.callsFake((cb: () => void) => cb()) + webSocketServer.close.callsFake((cb: () => void) => cb()) + + adapter.close(callback) + + expect(callback).to.have.been.calledOnce + }) + + it('removes all listeners from webSocketServer after close', () => { + webSocketServer.clients = new Set() + webServer.close.callsFake((cb: () => void) => cb()) + webSocketServer.close.callsFake((cb: () => void) => cb()) + + adapter.close() + + expect(webSocketServer.removeAllListeners).to.have.been.calledOnce + }) + }) + + describe('onBroadcast', () => { + it('emits event to adapters of all OPEN clients', async () => { + const OPEN = 1 + const emitStub = sandbox.stub() + + const mockClient = { readyState: OPEN } + webSocketServer.clients = new Set([mockClient] as any) + + const mockAdapter = { + emit: emitStub, + getClientId: () => 'test-id', + getClientAddress: () => '127.0.0.1', + } + createWebSocketAdapter.returns(mockAdapter) + + // Populate the WeakMap by invoking onConnection + const connectionCall = webSocketServer.on.getCalls().find( + (call: any) => call.args[0] === WebSocketServerAdapterEvent.Connection + ) + const onConnection = connectionCall.args[1] + const mockReq = { + headers: {}, + socket: { remoteAddress: '127.0.0.1' }, + } + await onConnection(mockClient, mockReq) + + const event = { id: 'test', pubkey: 'test', kind: 1, content: '', created_at: 0, sig: 'test', tags: [] } + adapter.emit(WebSocketServerAdapterEvent.Broadcast, event) + + expect(emitStub).to.have.been.calledWith(WebSocketAdapterEvent.Event, event) + }) + + it('skips clients that are not in OPEN state', () => { + const CLOSING = 2 + + const mockClient = { readyState: CLOSING } + webSocketServer.clients = new Set([mockClient] as any) + + const event = { id: 'test', pubkey: 'test', kind: 1, content: '', created_at: 0, sig: 'test', tags: [] } + + // Should not throw when skipping non-OPEN clients + expect(() => adapter.emit(WebSocketServerAdapterEvent.Broadcast, event)).not.to.throw() + }) + }) + + describe('onHeartbeat', () => { + it('emits heartbeat to connected adapters on the heartbeat interval', async () => { + const emitStub = sandbox.stub() + const OPEN = 1 + + const mockClient = { readyState: OPEN } + webSocketServer.clients = new Set([mockClient] as any) + + const mockWsAdapter = { + emit: emitStub, + getClientId: () => 'test-id', + getClientAddress: () => '127.0.0.1', + } + createWebSocketAdapter.returns(mockWsAdapter) + + // Populate the WeakMap via onConnection + const connectionCall = webSocketServer.on.getCalls().find( + (call: any) => call.args[0] === WebSocketServerAdapterEvent.Connection + ) + const onConnection = connectionCall.args[1] + const mockReq = { + headers: {}, + socket: { remoteAddress: '127.0.0.1' }, + } + await onConnection(mockClient, mockReq) + + // Advance past the heartbeat interval (WSS_CLIENT_HEALTH_PROBE_INTERVAL = 120000ms) + sandbox.clock.tick(120000) + + expect(emitStub).to.have.been.calledWith(WebSocketAdapterEvent.Heartbeat) + }) + }) + + describe('onConnection', () => { + it('creates a WebSocketAdapter for new connection', async () => { + const mockWsAdapter = { getClientId: () => 'test-id', getClientAddress: () => '127.0.0.1' } + createWebSocketAdapter.returns(mockWsAdapter) + + const connectionCall = webSocketServer.on.getCalls().find( + (call: any) => call.args[0] === WebSocketServerAdapterEvent.Connection + ) + const onConnection = connectionCall.args[1] + + const mockClient = {} + const mockReq = { + headers: {}, + socket: { remoteAddress: '127.0.0.1' }, + } + + await onConnection(mockClient, mockReq) + + expect(createWebSocketAdapter).to.have.been.calledOnce + }) + + it('terminates rate-limited connections', async () => { + const terminateStub = sandbox.stub() + isRateLimitedStub.resolves(true) + + const connectionCall = webSocketServer.on.getCalls().find( + (call: any) => call.args[0] === WebSocketServerAdapterEvent.Connection + ) + const onConnection = connectionCall.args[1] + + const mockClient = { terminate: terminateStub } + const mockReq = { + headers: {}, + socket: { remoteAddress: '127.0.0.1' }, + } + + await onConnection(mockClient, mockReq) + + expect(terminateStub).to.have.been.calledOnce + expect(createWebSocketAdapter).not.to.have.been.called + }) + }) +})