diff --git a/backend/src/controllers/sse.controller.ts b/backend/src/controllers/sse.controller.ts index 7678e65..c5db102 100644 --- a/backend/src/controllers/sse.controller.ts +++ b/backend/src/controllers/sse.controller.ts @@ -47,27 +47,33 @@ export const subscribe = async (req: Request, res: Response) => { // Scope: only streams where the authenticated user is sender or recipient const ownedStreams = await prisma.stream.findMany({ where: { OR: [{ sender: publicKey }, { recipient: publicKey }] }, - select: { streamId: true }, + select: { streamId: true, sender: true, recipient: true }, }); const ownedIds = new Set(ownedStreams.map((s: { streamId: number }) => String(s.streamId))); + const allowedUserKeys = new Set([publicKey]); + for (const stream of ownedStreams) { + allowedUserKeys.add(stream.sender); + allowedUserKeys.add(stream.recipient); + } let subscriptions: string[]; if (all) { // "all" still scoped to the user's own streams - subscriptions = [...ownedIds] as string[]; + subscriptions = [...ownedIds]; } else if (streams.length > 0) { // Only allow subscribing to streams the user owns subscriptions = streams.filter((id) => ownedIds.has(id)); } else { - subscriptions = [...ownedIds] as string[]; + subscriptions = [...ownedIds]; } - subscriptions.push(...users.map((userKey) => `user:${userKey}`)); - - // Always add user-scoped subscription key - subscriptions.push(`user:${publicKey}`); + const userSubscriptions = new Set([`user:${publicKey}`]); + for (const key of users.filter((k) => allowedUserKeys.has(k))) { + userSubscriptions.add(`user:${key}`); + } + subscriptions.push(...userSubscriptions); - const clientId = `${Date.now()}-${Math.random().toString(36).substr(2, 9)}`; + const clientId = `${Date.now()}-${Math.random().toString(36).slice(2, 11)}`; res.writeHead(200, { 'Content-Type': 'text/event-stream', diff --git a/backend/tests/sse.controller.test.ts b/backend/tests/sse.controller.test.ts index 4533785..3d31bab 100644 --- a/backend/tests/sse.controller.test.ts +++ b/backend/tests/sse.controller.test.ts @@ -75,17 +75,18 @@ describe('SSE Controller', () => { (sseService.isShuttingDown as any).mockReturnValue(false); (sseService.checkCapacity as any).mockReturnValue({ allowed: true }); (req as any).user = { publicKey: 'GUSER1' }; - (prisma.stream.findMany as any).mockResolvedValue([{ streamId: 'stream-1' }]); + (prisma.stream.findMany as any).mockResolvedValue([ + { streamId: 'stream-1', sender: 'GUSER1', recipient: 'GUSER2' }, + ]); req.query = { users: ['GUSER2', 'GUSER3'] }; await subscribe(req as Request, res as Response); - expect(sseService.addClient).toHaveBeenCalledWith( - expect.any(String), - expect.any(Object), - expect.arrayContaining(['stream-1', 'user:GUSER2', 'user:GUSER3', 'user:GUSER1']), - expect.any(String), - ); + const subscriptions = (sseService.addClient as any).mock.calls[0][2] as string[]; + expect(subscriptions).toContain('stream-1'); + expect(subscriptions).toContain('user:GUSER1'); + expect(subscriptions).toContain('user:GUSER2'); + expect(subscriptions).not.toContain('user:GUSER3'); }); it('should handle zod validation error for query params', async () => { @@ -97,5 +98,28 @@ describe('SSE Controller', () => { await subscribe(req as Request, res as Response); expect(res.status).toHaveBeenCalledWith(400); + expect(res.json).toHaveBeenCalledWith( + expect.objectContaining({ + message: 'Invalid subscription parameters', + errors: expect.arrayContaining([expect.objectContaining({ code: expect.any(String) })]), + }), + ); + }); + + it('should include allowed users query subscriptions', async () => { + (sseService.isShuttingDown as any).mockReturnValue(false); + (sseService.checkCapacity as any).mockReturnValue({ allowed: true }); + (req as any).user = { publicKey: 'GUSER1' }; + req.query = { users: ['GCOUNTER', 'GOTHER'] }; + (prisma.stream.findMany as any).mockResolvedValue([ + { streamId: 1, sender: 'GUSER1', recipient: 'GCOUNTER' }, + ]); + + await subscribe(req as Request, res as Response); + + const subscriptions = (sseService.addClient as any).mock.calls[0][2] as string[]; + expect(subscriptions).toContain('user:GUSER1'); + expect(subscriptions).toContain('user:GCOUNTER'); + expect(subscriptions).not.toContain('user:GOTHER'); }); });