Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions backend/src/controllers/sse.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,27 +47,33 @@ export const subscribe = async (req: Request, res: Response) => {
// Scope: only streams where the authenticated user is sender or recipient
const ownedStreams = await prisma.stream.findMany({
where: { OR: [{ sender: publicKey }, { recipient: publicKey }] },
select: { streamId: true },
select: { streamId: true, sender: true, recipient: true },
});
const ownedIds = new Set(ownedStreams.map((s: { streamId: number }) => String(s.streamId)));
const allowedUserKeys = new Set<string>([publicKey]);
for (const stream of ownedStreams) {
allowedUserKeys.add(stream.sender);
allowedUserKeys.add(stream.recipient);
}

let subscriptions: string[];
if (all) {
// "all" still scoped to the user's own streams
subscriptions = [...ownedIds] as string[];
subscriptions = [...ownedIds];
} else if (streams.length > 0) {
// Only allow subscribing to streams the user owns
subscriptions = streams.filter((id) => ownedIds.has(id));
} else {
subscriptions = [...ownedIds] as string[];
subscriptions = [...ownedIds];
}

subscriptions.push(...users.map((userKey) => `user:${userKey}`));

// Always add user-scoped subscription key
subscriptions.push(`user:${publicKey}`);
const userSubscriptions = new Set<string>([`user:${publicKey}`]);
for (const key of users.filter((k) => allowedUserKeys.has(k))) {
userSubscriptions.add(`user:${key}`);
}
subscriptions.push(...userSubscriptions);

const clientId = `${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
const clientId = `${Date.now()}-${Math.random().toString(36).slice(2, 11)}`;

res.writeHead(200, {
'Content-Type': 'text/event-stream',
Expand Down
38 changes: 31 additions & 7 deletions backend/tests/sse.controller.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,18 @@ describe('SSE Controller', () => {
(sseService.isShuttingDown as any).mockReturnValue(false);
(sseService.checkCapacity as any).mockReturnValue({ allowed: true });
(req as any).user = { publicKey: 'GUSER1' };
(prisma.stream.findMany as any).mockResolvedValue([{ streamId: 'stream-1' }]);
(prisma.stream.findMany as any).mockResolvedValue([
{ streamId: 'stream-1', sender: 'GUSER1', recipient: 'GUSER2' },
]);
req.query = { users: ['GUSER2', 'GUSER3'] };

await subscribe(req as Request, res as Response);

expect(sseService.addClient).toHaveBeenCalledWith(
expect.any(String),
expect.any(Object),
expect.arrayContaining(['stream-1', 'user:GUSER2', 'user:GUSER3', 'user:GUSER1']),
expect.any(String),
);
const subscriptions = (sseService.addClient as any).mock.calls[0][2] as string[];
expect(subscriptions).toContain('stream-1');
expect(subscriptions).toContain('user:GUSER1');
expect(subscriptions).toContain('user:GUSER2');
expect(subscriptions).not.toContain('user:GUSER3');
});

it('should handle zod validation error for query params', async () => {
Expand All @@ -97,5 +98,28 @@ describe('SSE Controller', () => {
await subscribe(req as Request, res as Response);

expect(res.status).toHaveBeenCalledWith(400);
expect(res.json).toHaveBeenCalledWith(
expect.objectContaining({
message: 'Invalid subscription parameters',
errors: expect.arrayContaining([expect.objectContaining({ code: expect.any(String) })]),
}),
);
});

it('should include allowed users query subscriptions', async () => {
(sseService.isShuttingDown as any).mockReturnValue(false);
(sseService.checkCapacity as any).mockReturnValue({ allowed: true });
(req as any).user = { publicKey: 'GUSER1' };
req.query = { users: ['GCOUNTER', 'GOTHER'] };
(prisma.stream.findMany as any).mockResolvedValue([
{ streamId: 1, sender: 'GUSER1', recipient: 'GCOUNTER' },
]);

await subscribe(req as Request, res as Response);

const subscriptions = (sseService.addClient as any).mock.calls[0][2] as string[];
expect(subscriptions).toContain('user:GUSER1');
expect(subscriptions).toContain('user:GCOUNTER');
expect(subscriptions).not.toContain('user:GOTHER');
});
});
Loading