Skip to content

Commit 90b64ab

Browse files
committed
feat: add COUNT message validation, response helpers, repository counting, and handler routing
1 parent 7b92f78 commit 90b64ab

6 files changed

Lines changed: 204 additions & 1 deletion

File tree

src/@types/repositories.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ export interface IEventRepository {
3333
upsert(event: Event): Promise<number>
3434
upsertMany(events: Event[]): Promise<number>
3535
findByFilters(filters: SubscriptionFilter[]): IQueryResult<DBEvent[]>
36+
countByFilters(filters: SubscriptionFilter[]): Promise<number>
3637
deleteByPubkeyAndIds(pubkey: Pubkey, ids: EventId[]): Promise<number>
3738
deleteByPubkeyExceptKinds(pubkey: Pubkey, excludedKinds: number[]): Promise<number>
3839
hasActiveRequestToVanish(pubkey: Pubkey): Promise<boolean>

src/factories/message-handler-factory.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { ICacheAdapter, IWebSocketAdapter } from '../@types/adapters'
22
import { IEventRepository, INip05VerificationRepository, IUserRepository } from '../@types/repositories'
33
import { IncomingMessage, MessageType } from '../@types/messages'
44
import { createSettings } from './settings-factory'
5+
import { CountMessageHandler } from '../handlers/count-message-handler'
56
import { EventMessageHandler } from '../handlers/event-message-handler'
67
import { eventStrategyFactory } from './event-strategy-factory'
78
import { getCacheClient } from '../cache/client'
@@ -42,6 +43,8 @@ export const messageHandlerFactory =
4243
return new SubscribeMessageHandler(adapter, eventRepository, createSettings)
4344
case MessageType.CLOSE:
4445
return new UnsubscribeMessageHandler(adapter)
46+
case MessageType.COUNT:
47+
return new CountMessageHandler(adapter, eventRepository, createSettings)
4548
default:
4649
throw new Error(`Unknown message type: ${String(message[0]).substring(0, 64)}`)
4750
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import { equals, uniqWith } from 'ramda'
2+
3+
import { IWebSocketAdapter } from '../@types/adapters'
4+
import { IMessageHandler } from '../@types/message-handlers'
5+
import { CountMessage } from '../@types/messages'
6+
import { IEventRepository } from '../@types/repositories'
7+
import { Settings } from '../@types/settings'
8+
import { SubscriptionFilter } from '../@types/subscription'
9+
import { WebSocketAdapterEvent } from '../constants/adapter'
10+
import { createLogger } from '../factories/logger-factory'
11+
import { createClosedMessage, createCountResultMessage } from '../utils/messages'
12+
13+
const debug = createLogger('count-message-handler')
14+
15+
export class CountMessageHandler implements IMessageHandler {
16+
public constructor(
17+
private readonly webSocket: IWebSocketAdapter,
18+
private readonly eventRepository: IEventRepository,
19+
private readonly settings: () => Settings,
20+
) {}
21+
22+
public async handleMessage(message: CountMessage): Promise<void> {
23+
const queryId = message[1]
24+
// Some clients send the same filter more than once.
25+
// We remove duplicates so we do less DB work.
26+
const filters = uniqWith(equals, message.slice(2)) as SubscriptionFilter[]
27+
28+
const reason = this.canCount(queryId, filters)
29+
if (reason) {
30+
debug('count request %s with %o rejected: %s', queryId, filters, reason)
31+
// NIP-45 says we should close rejected COUNT requests with a reason.
32+
this.webSocket.emit(WebSocketAdapterEvent.Message, createClosedMessage(queryId, reason))
33+
return
34+
}
35+
36+
try {
37+
const count = await this.eventRepository.countByFilters(filters)
38+
this.webSocket.emit(WebSocketAdapterEvent.Message, createCountResultMessage(queryId, { count }))
39+
} catch (error) {
40+
debug('count request %s failed: %o', queryId, error)
41+
// Keep this message generic so internal errors are not leaked to clients.
42+
this.webSocket.emit(WebSocketAdapterEvent.Message, createClosedMessage(queryId, 'error: unable to count events'))
43+
}
44+
}
45+
46+
private canCount(queryId: string, filters: SubscriptionFilter[]): string | undefined {
47+
const subscriptionLimits = this.settings().limits?.client?.subscription
48+
const maxFilters = subscriptionLimits?.maxFilters ?? 0
49+
50+
if (maxFilters > 0 && filters.length > maxFilters) {
51+
return `Too many filters: Number of filters per count query must be less then or equal to ${maxFilters}`
52+
}
53+
54+
if (
55+
typeof subscriptionLimits?.maxSubscriptionIdLength === 'number' &&
56+
queryId.length > subscriptionLimits.maxSubscriptionIdLength
57+
) {
58+
return `Query ID too long: Query ID must be less or equal to ${subscriptionLimits.maxSubscriptionIdLength}`
59+
}
60+
}
61+
}

src/repositories/event-repository.ts

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,109 @@ export class EventRepository implements IEventRepository {
168168
return query
169169
}
170170

171+
public async countByFilters(filters: SubscriptionFilter[]): Promise<number> {
172+
debug('counting events for %o', filters)
173+
174+
if (!Array.isArray(filters) || !filters.length) {
175+
throw new Error('Filters cannot be empty')
176+
}
177+
178+
const now = Math.floor(Date.now() / 1000)
179+
180+
const queries = filters.map((currentFilter) => {
181+
const builder = this.readReplicaDbClient<DBEvent>('events').select('events.event_id')
182+
183+
forEachObjIndexed((tableFields: string[], filterName: string | number) => {
184+
builder.andWhere((bd) => {
185+
cond([
186+
[isEmpty, () => void bd.whereRaw('1 = 0')],
187+
[
188+
complement(isNil),
189+
pipe(
190+
groupByLengthSpec,
191+
evolve({
192+
exact: (pubkeys: string[]) =>
193+
tableFields.forEach((tableField) => bd.orWhereIn(tableField, pubkeys.map(toBuffer))),
194+
even: forEach((prefix: string) =>
195+
tableFields.forEach((tableField) =>
196+
bd.orWhereRaw(`substring("${tableField}" from 1 for ?) = ?`, [prefix.length >> 1, toBuffer(prefix)]),
197+
),
198+
),
199+
odd: forEach((prefix: string) =>
200+
tableFields.forEach((tableField) =>
201+
bd.orWhereRaw(`substring("${tableField}" from 1 for ?) BETWEEN ? AND ?`, [
202+
(prefix.length >> 1) + 1,
203+
`\\x${prefix}0`,
204+
`\\x${prefix}f`,
205+
]),
206+
),
207+
),
208+
} as any),
209+
),
210+
],
211+
])(currentFilter[filterName] as string[])
212+
})
213+
})({ authors: ['event_pubkey'], ids: ['event_id'] })
214+
215+
if (Array.isArray(currentFilter.kinds)) {
216+
builder.whereIn('event_kind', currentFilter.kinds)
217+
}
218+
219+
if (typeof currentFilter.since === 'number') {
220+
builder.where('event_created_at', '>=', currentFilter.since)
221+
}
222+
223+
if (typeof currentFilter.until === 'number') {
224+
builder.where('event_created_at', '<=', currentFilter.until)
225+
}
226+
227+
if (typeof currentFilter.limit === 'number') {
228+
builder.limit(currentFilter.limit).orderBy('event_created_at', 'DESC').orderBy('event_id', 'asc')
229+
}
230+
231+
const andWhereRaw = invoker(1, 'andWhereRaw')
232+
const orWhereRaw = invoker(2, 'orWhereRaw')
233+
234+
let isTagQuery = false
235+
pipe(
236+
toPairs,
237+
filter(pipe(nth(0) as () => string, isGenericTagQuery)) as any,
238+
forEach(([filterName, criteria]: [string, string[]]) => {
239+
isTagQuery = true
240+
builder.andWhere((bd) => {
241+
ifElse(
242+
isEmpty,
243+
() => andWhereRaw('1 = 0', bd),
244+
forEach(
245+
(criterion: string) =>
246+
void orWhereRaw('event_tags.tag_name = ? AND event_tags.tag_value = ?', [filterName[1], criterion], bd),
247+
),
248+
)(criteria)
249+
})
250+
}),
251+
)(currentFilter as any)
252+
253+
if (isTagQuery) {
254+
builder.leftJoin('event_tags', 'events.event_id', 'event_tags.event_id').select('events.event_id')
255+
}
256+
257+
builder.whereNull('events.deleted_at').andWhere((bd) => {
258+
bd.whereNull('events.expires_at').orWhere('events.expires_at', '>', now)
259+
})
260+
261+
return builder
262+
})
263+
264+
const [query, ...subqueries] = queries
265+
if (subqueries.length) {
266+
query.union(subqueries, true)
267+
}
268+
269+
const result = await this.readReplicaDbClient.from(query.as('matching_events')).countDistinct({ count: 'event_id' }).first()
270+
271+
return Number(result?.count ?? 0)
272+
}
273+
171274
public async create(event: Event): Promise<number> {
172275
return this.insert(event).then(prop('rowCount') as () => number, () => 0)
173276
}

src/schemas/message-schema.ts

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,29 @@ export const reqMessageSchema = z
3030
}
3131
})
3232

33+
export const countMessageSchema = z
34+
.tuple([z.literal(MessageType.COUNT), z.string().max(256).min(1)])
35+
.rest(filterSchema)
36+
.superRefine((val, ctx) => {
37+
if (val.length < 3) {
38+
ctx.addIssue({
39+
code: z.ZodIssueCode.too_small,
40+
minimum: 3,
41+
type: 'array',
42+
inclusive: true,
43+
message: 'COUNT message must contain at least one filter',
44+
})
45+
} else if (val.length > 12) {
46+
ctx.addIssue({
47+
code: z.ZodIssueCode.too_big,
48+
maximum: 12,
49+
type: 'array',
50+
inclusive: true,
51+
message: 'COUNT message must contain at most 12 elements',
52+
})
53+
}
54+
})
55+
3356
export const closeMessageSchema = z.tuple([z.literal(MessageType.CLOSE), subscriptionSchema])
3457

35-
export const messageSchema = z.union([eventMessageSchema, reqMessageSchema, closeMessageSchema])
58+
export const messageSchema = z.union([eventMessageSchema, reqMessageSchema, closeMessageSchema, countMessageSchema])

src/utils/messages.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import {
2+
ClosedMessage,
3+
CountResultMessage,
4+
CountResultPayload,
25
EndOfStoredEventsNotice,
36
IncomingEventMessage,
47
IncomingRelayedEventMessage,
@@ -29,6 +32,15 @@ export const createCommandResult = (eventId: EventId, successful: boolean, messa
2932
return [MessageType.OK, eventId, successful, message]
3033
}
3134

35+
// NIP-45
36+
export const createCountResultMessage = (queryId: SubscriptionId, payload: CountResultPayload): CountResultMessage => {
37+
return [MessageType.COUNT, queryId, payload]
38+
}
39+
40+
export const createClosedMessage = (queryId: SubscriptionId, reason: string): ClosedMessage => {
41+
return [MessageType.CLOSED, queryId, reason]
42+
}
43+
3244
export const createSubscriptionMessage = (
3345
subscriptionId: SubscriptionId,
3446
filters: SubscriptionFilter[],

0 commit comments

Comments
 (0)