diff --git a/.changeset/full-donuts-allow.md b/.changeset/full-donuts-allow.md new file mode 100644 index 00000000..a02552de --- /dev/null +++ b/.changeset/full-donuts-allow.md @@ -0,0 +1,5 @@ +--- +"nostream": minor +--- + +added NIP-45 COUNT support with end-to-end handling (validation, handler routing, DB counting, and tests). diff --git a/CONFIGURATION.md b/CONFIGURATION.md index 7e898cd6..9698e4b1 100644 --- a/CONFIGURATION.md +++ b/CONFIGURATION.md @@ -163,6 +163,7 @@ The settings below are listed in alphabetical order by name. Please keep this ta | nip05.mode | NIP-05 verification mode: `enabled` requires verification, `passive` verifies without blocking, `disabled` does nothing. Defaults to `disabled`. | | nip05.verifyExpiration | Time in milliseconds before a successful NIP-05 verification expires and needs re-checking. Defaults to 604800000 (1 week). | | nip05.verifyUpdateFrequency | Minimum interval in milliseconds between re-verification attempts for a given author. Defaults to 86400000 (24 hours). | +| nip45.enabled | Enable or disable NIP-45 COUNT handling. Defaults to true. | | paymentProcessors.lnbits.baseURL | Base URL of your Lnbits instance. | | paymentProcessors.lnbits.callbackBaseURL | Public-facing Nostream's Lnbits Callback URL. (e.g. https://relay.your-domain.com/callbacks/lnbits) | | paymentProcessors.lnurl.invoiceURL | [LUD-06 Pay Request](https://github.com/lnurl/luds/blob/luds/06.md) provider URL. (e.g. https://getalby.com/lnurlp/your-username) | diff --git a/README.md b/README.md index 61d3c873..6344b31a 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ NIPs with a relay-specific implementation are listed here. - [x] NIP-33: Parameterized Replaceable Events - [x] NIP-40: Expiration Timestamp - [x] NIP-44: Encrypted Payloads (Versioned) +- [x] NIP-45: Event Counts - [x] NIP-62: Request to Vanish ## Requirements diff --git a/package.json b/package.json index cb228d21..b25f5241 100644 --- a/package.json +++ b/package.json @@ -18,7 +18,8 @@ 28, 33, 40, - 44 + 44, + 45 ], "supportedNipExtensions": [ "11a" diff --git a/resources/default-settings.yaml b/resources/default-settings.yaml index d42326c5..5a1ed5d9 100755 --- a/resources/default-settings.yaml +++ b/resources/default-settings.yaml @@ -53,6 +53,8 @@ nip05: domainWhitelist: [] # Block authors with NIP-05 at these domains domainBlacklist: [] +nip45: + enabled: true network: maxPayloadSize: 524288 # Uncomment only when using a trusted reverse proxy and configuring trustedProxies. diff --git a/src/@types/messages.ts b/src/@types/messages.ts index c87d04ce..f95538f8 100644 --- a/src/@types/messages.ts +++ b/src/@types/messages.ts @@ -10,13 +10,15 @@ export enum MessageType { NOTICE = 'NOTICE', EOSE = 'EOSE', OK = 'OK', + COUNT = 'COUNT', + CLOSED = 'CLOSED', } -export type IncomingMessage = (SubscribeMessage | IncomingEventMessage | UnsubscribeMessage) & { +export type IncomingMessage = (SubscribeMessage | IncomingEventMessage | UnsubscribeMessage | CountMessage) & { [ContextMetadataKey]?: ContextMetadata } -export type OutgoingMessage = OutgoingEventMessage | EndOfStoredEventsNotice | NoticeMessage | CommandResult +export type OutgoingMessage = OutgoingEventMessage | EndOfStoredEventsNotice | NoticeMessage | CommandResult | CountResultMessage | ClosedMessage export type SubscribeMessage = { [index in Range<2, 100>]: SubscriptionFilter @@ -25,6 +27,13 @@ export type SubscribeMessage = { 1: SubscriptionId } & Array +export type CountMessage = { + [index in Range<2, 100>]: SubscriptionFilter +} & { + 0: MessageType.COUNT + 1: SubscriptionId +} & Array + export type IncomingEventMessage = EventMessage & [MessageType.EVENT, Event] export type IncomingRelayedEventMessage = [MessageType.EVENT, RelayedEvent, Secret] @@ -62,3 +71,21 @@ export interface EndOfStoredEventsNotice { 0: MessageType.EOSE 1: SubscriptionId } + +export interface CountResultPayload { + count: number + approximate?: boolean + hll?: string +} + +export interface CountResultMessage { + 0: MessageType.COUNT + 1: SubscriptionId + 2: CountResultPayload +} + +export interface ClosedMessage { + 0: MessageType.CLOSED + 1: SubscriptionId + 2: string +} diff --git a/src/@types/repositories.ts b/src/@types/repositories.ts index 23dff30d..743812a7 100644 --- a/src/@types/repositories.ts +++ b/src/@types/repositories.ts @@ -33,6 +33,7 @@ export interface IEventRepository { upsert(event: Event): Promise upsertMany(events: Event[]): Promise findByFilters(filters: SubscriptionFilter[]): IQueryResult + countByFilters(filters: SubscriptionFilter[]): Promise deleteByPubkeyAndIds(pubkey: Pubkey, ids: EventId[]): Promise deleteByPubkeyExceptKinds(pubkey: Pubkey, excludedKinds: number[]): Promise hasActiveRequestToVanish(pubkey: Pubkey): Promise diff --git a/src/@types/settings.ts b/src/@types/settings.ts index 67f28536..4e8a2075 100644 --- a/src/@types/settings.ts +++ b/src/@types/settings.ts @@ -236,6 +236,10 @@ export interface Mirroring { export type Nip05Mode = 'enabled' | 'passive' | 'disabled' +export interface Nip45Settings { + enabled?: boolean +} + export interface Nip05Settings { mode: Nip05Mode /** @@ -266,4 +270,5 @@ export interface Settings { limits?: Limits mirroring?: Mirroring nip05?: Nip05Settings + nip45?: Nip45Settings } diff --git a/src/factories/message-handler-factory.ts b/src/factories/message-handler-factory.ts index 9caf4001..273b5b37 100644 --- a/src/factories/message-handler-factory.ts +++ b/src/factories/message-handler-factory.ts @@ -2,6 +2,7 @@ import { ICacheAdapter, IWebSocketAdapter } from '../@types/adapters' import { IEventRepository, INip05VerificationRepository, IUserRepository } from '../@types/repositories' import { IncomingMessage, MessageType } from '../@types/messages' import { createSettings } from './settings-factory' +import { CountMessageHandler } from '../handlers/count-message-handler' import { EventMessageHandler } from '../handlers/event-message-handler' import { eventStrategyFactory } from './event-strategy-factory' import { getCacheClient } from '../cache/client' @@ -42,6 +43,8 @@ export const messageHandlerFactory = return new SubscribeMessageHandler(adapter, eventRepository, createSettings) case MessageType.CLOSE: return new UnsubscribeMessageHandler(adapter) + case MessageType.COUNT: + return new CountMessageHandler(adapter, eventRepository, createSettings) default: throw new Error(`Unknown message type: ${String(message[0]).substring(0, 64)}`) } diff --git a/src/handlers/count-message-handler.ts b/src/handlers/count-message-handler.ts new file mode 100644 index 00000000..74aa8ee8 --- /dev/null +++ b/src/handlers/count-message-handler.ts @@ -0,0 +1,67 @@ +import { equals, uniqWith } from 'ramda' + +import { IWebSocketAdapter } from '../@types/adapters' +import { IMessageHandler } from '../@types/message-handlers' +import { CountMessage } from '../@types/messages' +import { IEventRepository } from '../@types/repositories' +import { Settings } from '../@types/settings' +import { SubscriptionFilter, SubscriptionId } from '../@types/subscription' +import { WebSocketAdapterEvent } from '../constants/adapter' +import { createLogger } from '../factories/logger-factory' +import { createClosedMessage, createCountResultMessage } from '../utils/messages' + +const debug = createLogger('count-message-handler') + +export class CountMessageHandler implements IMessageHandler { + public constructor( + private readonly webSocket: IWebSocketAdapter, + private readonly eventRepository: IEventRepository, + private readonly settings: () => Settings, + ) {} + + public async handleMessage(message: CountMessage): Promise { + const queryId = message[1] + const countEnabled = this.settings().nip45?.enabled ?? true + if (!countEnabled) { + this.webSocket.emit(WebSocketAdapterEvent.Message, createClosedMessage(queryId, 'COUNT is disabled by relay configuration')) + return + } + + // Some clients send the same filter more than once. + // We remove duplicates so we do less DB work. + const filters = uniqWith(equals, message.slice(2)) as SubscriptionFilter[] + + const reason = this.canCount(queryId, filters) + if (reason) { + debug('count request %s with %o rejected: %s', queryId, filters, reason) + // NIP-45 says we should close rejected COUNT requests with a reason. + this.webSocket.emit(WebSocketAdapterEvent.Message, createClosedMessage(queryId, reason)) + return + } + + try { + const count = await this.eventRepository.countByFilters(filters) + this.webSocket.emit(WebSocketAdapterEvent.Message, createCountResultMessage(queryId, { count })) + } catch (error) { + debug('count request %s failed: %o', queryId, error) + // Keep this message generic so internal errors are not leaked to clients. + this.webSocket.emit(WebSocketAdapterEvent.Message, createClosedMessage(queryId, 'error: unable to count events')) + } + } + + private canCount(queryId: SubscriptionId, filters: SubscriptionFilter[]): string | undefined { + const subscriptionLimits = this.settings().limits?.client?.subscription + const maxFilters = subscriptionLimits?.maxFilters ?? 0 + + if (maxFilters > 0 && filters.length > maxFilters) { + return `Too many filters: Number of filters per count query must be less than or equal to ${maxFilters}` + } + + if ( + typeof subscriptionLimits?.maxSubscriptionIdLength === 'number' && + queryId.length > subscriptionLimits.maxSubscriptionIdLength + ) { + return `Query ID too long: Query ID must be less than or equal to ${subscriptionLimits.maxSubscriptionIdLength}` + } + } +} diff --git a/src/repositories/event-repository.ts b/src/repositories/event-repository.ts index dba91177..648e5184 100644 --- a/src/repositories/event-repository.ts +++ b/src/repositories/event-repository.ts @@ -71,92 +71,55 @@ export class EventRepository implements IEventRepository { const queries = filters.map((currentFilter) => { const builder = this.readReplicaDbClient('events') - forEachObjIndexed((tableFields: string[], filterName: string | number) => { - builder.andWhere((bd) => { - cond([ - [isEmpty, () => void bd.whereRaw('1 = 0')], - [ - complement(isNil), - pipe( - groupByLengthSpec, - evolve({ - exact: (pubkeys: string[]) => - tableFields.forEach((tableField) => bd.orWhereIn(tableField, pubkeys.map(toBuffer))), - even: forEach((prefix: string) => - tableFields.forEach((tableField) => - bd.orWhereRaw(`substring("${tableField}" from 1 for ?) = ?`, [ - prefix.length >> 1, - toBuffer(prefix), - ]), - ), - ), - odd: forEach((prefix: string) => - tableFields.forEach((tableField) => - bd.orWhereRaw(`substring("${tableField}" from 1 for ?) BETWEEN ? AND ?`, [ - (prefix.length >> 1) + 1, - `\\x${prefix}0`, - `\\x${prefix}f`, - ]), - ), - ), - } as any), - ), - ], - ])(currentFilter[filterName] as string[]) - }) - })({ - authors: ['event_pubkey'], - ids: ['event_id'], - }) + const isTagQuery = this.applyFilterConditions(builder, currentFilter) - if (Array.isArray(currentFilter.kinds)) { - builder.whereIn('event_kind', currentFilter.kinds) + if (typeof currentFilter.limit === 'number') { + builder.limit(currentFilter.limit).orderBy('event_created_at', 'DESC').orderBy('event_id', 'asc') + } else { + builder.limit(500).orderBy('event_created_at', 'asc').orderBy('event_id', 'asc') } - if (typeof currentFilter.since === 'number') { - builder.where('event_created_at', '>=', currentFilter.since) + if (isTagQuery) { + builder.select('events.*') } - if (typeof currentFilter.until === 'number') { - builder.where('event_created_at', '<=', currentFilter.until) - } + return builder + }) + + const [query, ...subqueries] = queries + if (subqueries.length) { + query.union(subqueries, true) + } + + return query + } + + public async countByFilters(filters: SubscriptionFilter[]): Promise { + logger('counting events for %o', filters) + + if (!Array.isArray(filters) || !filters.length) { + throw new Error('Filters cannot be empty') + } + + const now = Math.floor(Date.now() / 1000) + + const queries = filters.map((currentFilter) => { + const builder = this.readReplicaDbClient('events').select('events.event_id') + + const isTagQuery = this.applyFilterConditions(builder, currentFilter) if (typeof currentFilter.limit === 'number') { builder.limit(currentFilter.limit).orderBy('event_created_at', 'DESC').orderBy('event_id', 'asc') - } else { - builder.limit(500).orderBy('event_created_at', 'asc').orderBy('event_id', 'asc') } - const andWhereRaw = invoker(1, 'andWhereRaw') - const orWhereRaw = invoker(2, 'orWhereRaw') - - let isTagQuery = false - pipe( - toPairs, - filter(pipe(nth(0) as () => string, isGenericTagQuery)) as any, - forEach(([filterName, criteria]: [string, string[]]) => { - isTagQuery = true - builder.andWhere((bd) => { - ifElse( - isEmpty, - () => andWhereRaw('1 = 0', bd), - forEach( - (criterion: string) => - void orWhereRaw( - 'event_tags.tag_name = ? AND event_tags.tag_value = ?', - [filterName[1], criterion], - bd, - ), - ), - )(criteria) - }) - }), - )(currentFilter as any) - if (isTagQuery) { - builder.leftJoin('event_tags', 'events.event_id', 'event_tags.event_id').select('events.*') + builder.select('events.event_id') } + builder.whereNull('events.deleted_at').andWhere((bd) => { + bd.whereNull('events.expires_at').orWhere('events.expires_at', '>', now) + }) + return builder }) @@ -165,7 +128,83 @@ export class EventRepository implements IEventRepository { query.union(subqueries, true) } - return query + const result = await this.readReplicaDbClient.from(query.as('matching_events')).countDistinct({ count: 'event_id' }).first() + + return Number(result?.count ?? 0) + } + + private applyFilterConditions(builder: any, currentFilter: SubscriptionFilter): boolean { + forEachObjIndexed((tableFields: string[], filterName: string | number) => { + builder.andWhere((bd) => { + cond([ + [isEmpty, () => void bd.whereRaw('1 = 0')], + [ + complement(isNil), + pipe( + groupByLengthSpec, + evolve({ + exact: (pubkeys: string[]) => + tableFields.forEach((tableField) => bd.orWhereIn(tableField, pubkeys.map(toBuffer))), + even: forEach((prefix: string) => + tableFields.forEach((tableField) => + bd.orWhereRaw(`substring("${tableField}" from 1 for ?) = ?`, [prefix.length >> 1, toBuffer(prefix)]), + ), + ), + odd: forEach((prefix: string) => + tableFields.forEach((tableField) => + bd.orWhereRaw(`substring("${tableField}" from 1 for ?) BETWEEN ? AND ?`, [ + (prefix.length >> 1) + 1, + `\\x${prefix}0`, + `\\x${prefix}f`, + ]), + ), + ), + } as any), + ), + ], + ])(currentFilter[filterName] as string[]) + }) + })({ authors: ['event_pubkey'], ids: ['event_id'] }) + + if (Array.isArray(currentFilter.kinds)) { + builder.whereIn('event_kind', currentFilter.kinds) + } + + if (typeof currentFilter.since === 'number') { + builder.where('event_created_at', '>=', currentFilter.since) + } + + if (typeof currentFilter.until === 'number') { + builder.where('event_created_at', '<=', currentFilter.until) + } + + const andWhereRaw = invoker(1, 'andWhereRaw') + const orWhereRaw = invoker(2, 'orWhereRaw') + + let isTagQuery = false + pipe( + toPairs, + filter(pipe(nth(0) as () => string, isGenericTagQuery)) as any, + forEach(([filterName, criteria]: [string, string[]]) => { + isTagQuery = true + builder.andWhere((bd) => { + ifElse( + isEmpty, + () => andWhereRaw('1 = 0', bd), + forEach( + (criterion: string) => + void orWhereRaw('event_tags.tag_name = ? AND event_tags.tag_value = ?', [filterName[1], criterion], bd), + ), + )(criteria) + }) + }), + )(currentFilter as any) + + if (isTagQuery) { + builder.leftJoin('event_tags', 'events.event_id', 'event_tags.event_id') + } + + return isTagQuery } public async create(event: Event): Promise { diff --git a/src/schemas/message-schema.ts b/src/schemas/message-schema.ts index 28a1ae75..53b8f09f 100644 --- a/src/schemas/message-schema.ts +++ b/src/schemas/message-schema.ts @@ -30,6 +30,29 @@ export const reqMessageSchema = z } }) +export const countMessageSchema = z + .tuple([z.literal(MessageType.COUNT), z.string().max(256).min(1)]) + .rest(filterSchema) + .superRefine((val, ctx) => { + if (val.length < 3) { + ctx.addIssue({ + code: z.ZodIssueCode.too_small, + minimum: 3, + type: 'array', + inclusive: true, + message: 'COUNT message must contain at least one filter', + }) + } else if (val.length > 12) { + ctx.addIssue({ + code: z.ZodIssueCode.too_big, + maximum: 12, + type: 'array', + inclusive: true, + message: 'COUNT message must contain at most 12 elements', + }) + } + }) + export const closeMessageSchema = z.tuple([z.literal(MessageType.CLOSE), subscriptionSchema]) -export const messageSchema = z.union([eventMessageSchema, reqMessageSchema, closeMessageSchema]) +export const messageSchema = z.union([eventMessageSchema, reqMessageSchema, closeMessageSchema, countMessageSchema]) diff --git a/src/utils/messages.ts b/src/utils/messages.ts index ae6bb694..0b98d5f4 100644 --- a/src/utils/messages.ts +++ b/src/utils/messages.ts @@ -1,4 +1,7 @@ import { + ClosedMessage, + CountResultMessage, + CountResultPayload, EndOfStoredEventsNotice, IncomingEventMessage, IncomingRelayedEventMessage, @@ -29,6 +32,15 @@ export const createCommandResult = (eventId: EventId, successful: boolean, messa return [MessageType.OK, eventId, successful, message] } +// NIP-45 +export const createCountResultMessage = (queryId: SubscriptionId, payload: CountResultPayload): CountResultMessage => { + return [MessageType.COUNT, queryId, payload] +} + +export const createClosedMessage = (queryId: SubscriptionId, reason: string): ClosedMessage => { + return [MessageType.CLOSED, queryId, reason] +} + export const createSubscriptionMessage = ( subscriptionId: SubscriptionId, filters: SubscriptionFilter[], diff --git a/test/unit/factories/message-handler-factory.spec.ts b/test/unit/factories/message-handler-factory.spec.ts index be6713cc..41124f4b 100644 --- a/test/unit/factories/message-handler-factory.spec.ts +++ b/test/unit/factories/message-handler-factory.spec.ts @@ -6,6 +6,7 @@ import { Event } from '../../../src/@types/event' import { EventMessageHandler } from '../../../src/handlers/event-message-handler' import { IWebSocketAdapter } from '../../../src/@types/adapters' import { messageHandlerFactory } from '../../../src/factories/message-handler-factory' +import { CountMessageHandler } from '../../../src/handlers/count-message-handler' import { SubscribeMessageHandler } from '../../../src/handlers/subscribe-message-handler' import { UnsubscribeMessageHandler } from '../../../src/handlers/unsubscribe-message-handler' import * as cacheModule from '../../../src/cache/client' @@ -67,6 +68,12 @@ describe('messageHandlerFactory', () => { expect(factory([message, adapter])).to.be.an.instanceOf(UnsubscribeMessageHandler) }) + it('returns CountMessageHandler when given a COUNT message', () => { + message = [MessageType.COUNT, 'q1', {}] as any + + expect(factory([message, adapter])).to.be.an.instanceOf(CountMessageHandler) + }) + it('throws when given an invalid message', () => { message = [] as any diff --git a/test/unit/handlers/count-message-handler.spec.ts b/test/unit/handlers/count-message-handler.spec.ts new file mode 100644 index 00000000..05befd01 --- /dev/null +++ b/test/unit/handlers/count-message-handler.spec.ts @@ -0,0 +1,148 @@ +import chai from 'chai' +import EventEmitter from 'events' +import Sinon from 'sinon' +import sinonChai from 'sinon-chai' + +import { IWebSocketAdapter } from '../../../src/@types/adapters' +import { MessageType } from '../../../src/@types/messages' +import { IEventRepository } from '../../../src/@types/repositories' +import { Settings } from '../../../src/@types/settings' +import { WebSocketAdapterEvent } from '../../../src/constants/adapter' +import { CountMessageHandler } from '../../../src/handlers/count-message-handler' + +chai.use(sinonChai) +const { expect } = chai + +describe('CountMessageHandler', () => { + let webSocket: IWebSocketAdapter + let handler: CountMessageHandler + let eventRepository: IEventRepository + let sandbox: Sinon.SinonSandbox + + beforeEach(() => { + sandbox = Sinon.createSandbox() + + eventRepository = { + countByFilters: sandbox.stub().resolves(7), + } as any + + webSocket = new EventEmitter() as any + + handler = new CountMessageHandler(webSocket, eventRepository, () => ({ + limits: { + client: { + subscription: { + maxFilters: 10, + maxSubscriptionIdLength: 256, + }, + }, + }, + }) as Settings) + }) + + afterEach(() => { + webSocket.removeAllListeners() + sandbox.restore() + }) + + describe('handleMessage()', () => { + let webSocketOnMessageStub: Sinon.SinonStub + + beforeEach(() => { + webSocketOnMessageStub = sandbox.stub() + webSocket.on(WebSocketAdapterEvent.Message, webSocketOnMessageStub) + }) + + it('returns COUNT with the result when counting works', async () => { + const message = [MessageType.COUNT, 'q1', {}] as any + + await handler.handleMessage(message) + + expect(eventRepository.countByFilters).to.have.been.calledOnceWithExactly([{}]) + expect(webSocketOnMessageStub).to.have.been.calledOnceWithExactly([MessageType.COUNT, 'q1', { count: 7 }]) + }) + + it('drops duplicate filters before querying the repository', async () => { + const repeatedFilter = { kinds: [1] } + const message = [MessageType.COUNT, 'q1', repeatedFilter, repeatedFilter] as any + + await handler.handleMessage(message) + + expect(eventRepository.countByFilters).to.have.been.calledOnceWithExactly([repeatedFilter]) + expect(webSocketOnMessageStub).to.have.been.calledOnceWithExactly([MessageType.COUNT, 'q1', { count: 7 }]) + }) + + it('returns CLOSED when the request has too many filters', async () => { + handler = new CountMessageHandler(webSocket, eventRepository, () => ({ + limits: { + client: { + subscription: { + maxFilters: 1, + maxSubscriptionIdLength: 256, + }, + }, + }, + }) as Settings) + + const message = [MessageType.COUNT, 'q1', { kinds: [1] }, { kinds: [2] }] as any + + await handler.handleMessage(message) + + expect(eventRepository.countByFilters).to.not.have.been.called + expect(webSocketOnMessageStub).to.have.been.calledOnce + expect(webSocketOnMessageStub.firstCall.args[0][0]).to.equal(MessageType.CLOSED) + expect(webSocketOnMessageStub.firstCall.args[0][1]).to.equal('q1') + }) + + it('returns CLOSED when the query ID is too long', async () => { + handler = new CountMessageHandler(webSocket, eventRepository, () => ({ + limits: { + client: { + subscription: { + maxFilters: 10, + maxSubscriptionIdLength: 2, + }, + }, + }, + }) as Settings) + + const message = [MessageType.COUNT, 'q123', {}] as any + + await handler.handleMessage(message) + + expect(eventRepository.countByFilters).to.not.have.been.called + expect(webSocketOnMessageStub).to.have.been.calledOnce + expect(webSocketOnMessageStub.firstCall.args[0][0]).to.equal(MessageType.CLOSED) + expect(webSocketOnMessageStub.firstCall.args[0][1]).to.equal('q123') + }) + + it('returns CLOSED when counting fails in the repository', async () => { + const countByFiltersStub = eventRepository.countByFilters as Sinon.SinonStub + countByFiltersStub.rejects(new Error('boom')) + const message = [MessageType.COUNT, 'q1', {}] as any + + await handler.handleMessage(message) + + expect(webSocketOnMessageStub).to.have.been.calledOnceWithExactly([ + MessageType.CLOSED, + 'q1', + 'error: unable to count events', + ]) + }) + + it('returns CLOSED when COUNT is disabled in settings', async () => { + handler = new CountMessageHandler(webSocket, eventRepository, () => ({ nip45: { enabled: false } }) as Settings) + + const message = [MessageType.COUNT, 'q1', {}] as any + + await handler.handleMessage(message) + + expect(eventRepository.countByFilters).to.not.have.been.called + expect(webSocketOnMessageStub).to.have.been.calledOnceWithExactly([ + MessageType.CLOSED, + 'q1', + 'COUNT is disabled by relay configuration', + ]) + }) + }) +}) diff --git a/test/unit/repositories/event-repository.spec.ts b/test/unit/repositories/event-repository.spec.ts index f2c71cb2..73a2929e 100644 --- a/test/unit/repositories/event-repository.spec.ts +++ b/test/unit/repositories/event-repository.spec.ts @@ -417,6 +417,97 @@ describe('EventRepository', () => { }) }) + describe('.countByFilters', () => { + it('throws error if filters is empty', async () => { + try { + await repository.countByFilters([]) + expect.fail('Expected countByFilters to throw') + } catch (error) { + expect((error as Error).message).to.equal('Filters cannot be empty') + } + }) + + it('returns count value from query result', async () => { + sandbox.stub(rrDbClient, 'from').returns({ + countDistinct: () => ({ + first: async () => ({ count: '42' }), + }), + } as any) + + const result = await repository.countByFilters([{}]) + + expect(result).to.equal(42) + }) + + it('uses countDistinct on event_id to avoid duplicate counts', async () => { + const countDistinctStub = sandbox.stub().returns({ + first: async () => ({ count: '1' }), + }) + + sandbox.stub(rrDbClient, 'from').returns({ countDistinct: countDistinctStub } as any) + + await repository.countByFilters([{ '#e': ['aaaaaa'] } as any]) + + expect(countDistinctStub).to.have.been.calledOnceWithExactly({ count: 'event_id' }) + }) + + it('builds union query when there are multiple filters', async () => { + const fromStub = sandbox.stub(rrDbClient, 'from').returns({ + countDistinct: () => ({ + first: async () => ({ count: '1' }), + }), + } as any) + + await repository.countByFilters([{ kinds: [1] }, { authors: ['22e804d26ed16b68db5259e78449e96dab5d464c8f470bda3eb1a70467f2c793'] }]) + + const sql = fromStub.firstCall.args[0].toString() + expect(sql).to.include(' union ') + }) + + it('joins tags table for generic tag filters', async () => { + const fromStub = sandbox.stub(rrDbClient, 'from').returns({ + countDistinct: () => ({ + first: async () => ({ count: '1' }), + }), + } as any) + + await repository.countByFilters([{ '#e': ['aaaaaa'] } as any]) + + const sql = fromStub.firstCall.args[0].toString() + expect(sql).to.include('left join "event_tags"') + expect(sql).to.include('event_tags.tag_name') + expect(sql).to.include('event_tags.tag_value') + }) + + it('applies limit ordering when a filter includes limit', async () => { + const fromStub = sandbox.stub(rrDbClient, 'from').returns({ + countDistinct: () => ({ + first: async () => ({ count: '1' }), + }), + } as any) + + await repository.countByFilters([{ limit: 3 }]) + + const sql = fromStub.firstCall.args[0].toString() + expect(sql).to.include('order by "event_created_at" DESC, "event_id" asc limit 3') + }) + + it('filters out deleted and expired events', async () => { + const fromStub = sandbox.stub(rrDbClient, 'from').returns({ + countDistinct: () => ({ + first: async () => ({ count: '1' }), + }), + } as any) + + await repository.countByFilters([{ kinds: [1] }]) + + const sql = fromStub.firstCall.args[0].toString() + expect(sql).to.include('"events"."deleted_at" is null') + expect(sql).to.include('"events"."expires_at" is null') + expect(sql).to.include('"events"."expires_at" >') + }) + }) + describe('.create', () => { let insertStub: sinon.SinonStub beforeEach(() => { diff --git a/test/unit/schemas/message-schema.spec.ts b/test/unit/schemas/message-schema.spec.ts index e3913591..0b5e00a2 100644 --- a/test/unit/schemas/message-schema.spec.ts +++ b/test/unit/schemas/message-schema.spec.ts @@ -112,5 +112,60 @@ describe('NIP-01', () => { expect(result).to.have.property('error').that.is.not.undefined }) }) + + describe('COUNT', () => { + beforeEach(() => { + message = [ + 'COUNT', + 'id', + { + ids: ['aaaa', 'bbbb', 'cccc'], + authors: ['aaaa', 'bbbb', 'cccc'], + kinds: [0, 1, 2, 3], + since: 1000, + until: 1000, + limit: 100, + '#e': ['aa', 'bb', 'cc'], + '#p': ['dd', 'ee', 'ff'], + '#r': ['00', '11', '22'], + }, + ] as any + }) + + it('returns same message if valid', () => { + const result = validateSchema(messageSchema)(message) + expect(result.error).to.be.undefined + expect(result).to.have.deep.property('value', message) + }) + + it('returns error if query ID is missing', () => { + message[1] = null + + const result = validateSchema(messageSchema)(message) + expect(result).to.have.property('error').that.is.not.undefined + }) + + it('returns error if filter is missing', () => { + ;(message as any[]).splice(2, 1) + + const result = validateSchema(messageSchema)(message) + expect(result).to.have.property('error').that.is.not.undefined + }) + + it('returns error if filter is not an object', () => { + message[2] = null + + const result = validateSchema(messageSchema)(message) + expect(result).to.have.property('error').that.is.not.undefined + }) + + it('returns error if there are too many filters', () => { + ;(message as any[]).splice(2, 1) + ;(message as any[]).push(...range(0, 11).map(() => ({}))) + + const result = validateSchema(messageSchema)(message) + expect(result).to.have.property('error').that.is.not.undefined + }) + }) }) })