diff --git a/PROTOCOL.md b/PROTOCOL.md index e35a4cdd..54d76da5 100644 --- a/PROTOCOL.md +++ b/PROTOCOL.md @@ -196,7 +196,7 @@ type ProtocolError = UncaughtError | InvalidRequestError | CancelError; `ProtocolError`s, just like service-level errors, are wrapped with a `Result`, which is further wrapped with `TransportMessage` and MUST have a `StreamCancelBit` flag. Please note that these are separate from user-defined errors, which should be treated just like any response message. -There are 4 `Control` payloads: +There are 6 `Control` payloads: ```ts // Used in cases where we want to send a close without @@ -241,11 +241,28 @@ interface ControlHandshakeResponse { }; } +// Sent by the server to ask the client to re-handshake — re-construct its +// handshake metadata (e.g. fetch a fresh token) over the live connection. Sent +// on the reserved `rehandshake` streamId with no control flags. +interface ControlRehandshakeRequest { + type: 'REHANDSHAKE_REQ'; +} + +// Sent by the client in response to a ControlRehandshakeRequest, carrying +// freshly constructed handshake metadata for the server to re-validate. Sent on +// the reserved `rehandshake` streamId with no control flags. +interface ControlRehandshakeResponse { + type: 'REHANDSHAKE_RESP'; + metadata?: unknown; +} + type Control = | ControlClose | ControlAck | ControlHandshakeRequest - | ControlHandshakeResponse; + | ControlHandshakeResponse + | ControlRehandshakeRequest + | ControlRehandshakeResponse; ``` `Control` is a payload that is wrapped with `TransportMessage`. @@ -305,6 +322,7 @@ When a message is received, it MUST be validated before being processed. - Match the JSON schema for the `TransportMessage` type. - Have an existing session for the transport `clientId` in the `from` field (see the 'Transports, Sessions, and Connections' heading for more information on sessions and transports). +- The `from` field MUST match the authenticated peer of the session/connection that delivered the message (the identity established at handshake). A message whose `from` names a different client is a protocol violation — it MUST NOT be processed, and the delivering connection MUST be torn down. Without this check a connected client could spoof `from` to act as another client (using that client's metadata/identity). - The `to` field of the message MUST match the transport's `clientId`. - Have the expected `seq` number (see the 'Handling Transparent Reconnections' heading for more information on seq/ack). - Is not an explicit heartbeat (i.e. the `AckBit` is not set). @@ -583,6 +601,18 @@ The server will send an error response if either: When the client receives a status with `ok: false`, it should consider the handshake failed and close the connection. +### Re-handshaking (live credential refresh) + +Handshake metadata (e.g. an auth token) can be refreshed over an already-established connection without dropping the session — a "follow-up handshake". This lets a server keep long-lived sessions alive across short-lived credentials. + +The exchange reuses the heartbeat mechanism: control messages on the reserved `rehandshake` streamId that update transport-level bookkeeping like any other message but are consumed by the transport and never surfaced to procedure handlers. + +1. The server sends a `ControlRehandshakeRequest` to ask the client to re-handshake. When to do this is up to the server (e.g. shortly before a token's expiry). +2. The client re-runs the same metadata construction it used during the original handshake and replies with a `ControlRehandshakeResponse` carrying the new metadata. +3. The server re-validates the metadata exactly as it would during a handshake. On success it replaces the metadata stored for the session (which the application surfaces to its handlers); on failure (malformed, rejected, or no response within the handshake timeout) it tears the session down. + +Because the metadata is re-validated on every (re)handshake as well, the re-handshake schedule naturally re-establishes itself after a transparent reconnect. + ### Transparent reconnections River handles disconnections and reconnections in a transparent manner wherever possible when diff --git a/README.md b/README.md index c07fdc7c..b2f6e467 100644 --- a/README.md +++ b/README.md @@ -713,6 +713,38 @@ async handler(ctx, ...args) { } ``` +#### Re-handshaking (refreshing handshake metadata) + +For long-lived sessions with short-lived credentials (e.g. a JWT), the server can ask a connected client to re-handshake — a follow-up handshake that refreshes its metadata without dropping the session. The client re-runs the same `construct` function it used during the original handshake, and the server re-runs `validate` on the result and replaces the stored metadata. + +`ctx.metadata` is live: a handler that re-reads it (including a long-running stream or subscription that was already in flight when the re-handshake happened) observes the new value. This is the point of re-handshaking — the operations that outlive a token are exactly the ones that need the new token. If you want a value fixed for the lifetime of a call, destructure it once (e.g. `const { token } = ctx.metadata`). Because `validate` receives the previous parsed metadata, it can enforce that a re-handshake stays the same principal. + +A re-handshake can be triggered manually from the server transport: + +```ts +serverTransport.requestRehandshake('client-id'); +``` + +More commonly, let the server schedule re-handshakes automatically by passing a third argument to `createServerHandshakeOptions` — a `Date` for when the credential expires. The server re-handshakes shortly before it: + +```ts +createServerHandshakeOptions( + handshakeSchema, + (metadata) => ({ + parsedToken: metadata.token, + expiresAt: getExpiry(metadata.token), + }), + // when this credential expires + (parsed) => parsed.expiresAt, +); +``` + +The server fires the re-handshake one `handshakeTimeoutMs` before the expiry you return, so the exchange resolves by then: either the refresh lands, or — if the client never answers — its deadline elapses and the session is torn down. Net effect: the session never serves past expiry. (This assumes `handshakeTimeoutMs` is comfortably shorter than the credential's lifetime, which holds for any realistic token.) + +If the client fails to return valid metadata (rejected by `validate`, malformed, or no response before the handshake timeout), the server tears the session down. The client then reconnects with a fresh handshake, which re-establishes the schedule. + +Re-handshaking is scheduling, not request gating — it doesn't pause in-flight requests — so still reject already-expired credentials in `validate` and/or by checking the live `ctx.metadata` in your handlers. + ## Protobuf Services (Experimental) River also supports defining services using Protocol Buffers. Instead of TypeBox schemas, you define your service in a `.proto` file and use the generated descriptors directly. diff --git a/__tests__/e2e.test.ts b/__tests__/e2e.test.ts index 63454e32..be360d73 100644 --- a/__tests__/e2e.test.ts +++ b/__tests__/e2e.test.ts @@ -5,6 +5,7 @@ import { isReadableDone, numberOfConnections, readNextResult, + testingSessionOptions, } from '../testUtil'; import { createServer } from '../router/server'; import { createClient } from '../router/client'; @@ -29,7 +30,7 @@ import { waitFor, } from '../testUtil/fixtures/cleanup'; import { testMatrix } from '../testUtil/fixtures/matrix'; -import { Type } from 'typebox'; +import { Static, Type } from 'typebox'; import { Procedure, createServiceSchema, @@ -42,6 +43,7 @@ import { createClientHandshakeOptions, createServerHandshakeOptions, } from '../router/handshake'; +import { RehandshakeStreamId } from '../transport/message'; import { TestSetupHelpers } from '../testUtil/fixtures/transports'; describe.each(testMatrix())( @@ -1315,6 +1317,438 @@ describe.each(testMatrix())( }); }); + test('server can refresh handshake metadata over a live connection', async () => { + const requestSchema = Type.Object({ token: Type.String() }); + + type ParsedMetadata = Static; + + let token = 'token-v1'; + const construct = vi.fn(() => ({ token })); + const clientTransport = getClientTransport( + 'client', + createClientHandshakeOptions(requestSchema, construct), + ); + const validate = vi.fn( + ( + metadata: ParsedMetadata, + _prev?: ParsedMetadata, + _from?: string, + ): ParsedMetadata => ({ token: metadata.token }), + ); + const serverTransport = getServerTransport( + 'SERVER', + createServerHandshakeOptions( + requestSchema, + validate, + ), + ); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const ServiceSchema = createServiceSchema< + MaybeDisposable, + ParsedMetadata + >(); + const services = { + test: ServiceSchema.define({ + getToken: Procedure.rpc({ + requestInit: Type.Object({}), + responseData: Type.Object({ token: Type.String() }), + handler: async ({ ctx }) => Ok({ token: ctx.metadata.token }), + }), + }), + }; + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + // establish the session with the initial token + const before = await client.test.getToken.rpc({}); + expect(before).toStrictEqual({ + ok: true, + payload: { token: 'token-v1' }, + }); + + // ask the client to refresh; construct now hands back the new token + token = 'token-v2'; + expect(serverTransport.requestRehandshake('client')).toBe(true); + + await waitFor(() => + expect( + serverTransport.sessionHandshakeMetadata.get('client'), + ).toStrictEqual({ token: 'token-v2' }), + ); + + // the initial handshake and the re-handshake both bind to the client id + expect(validate.mock.calls.map((call) => call[2])).toEqual([ + 'client', + 'client', + ]); + + // subsequent calls observe the refreshed metadata + const after = await client.test.getToken.rpc({}); + expect(after).toStrictEqual({ ok: true, payload: { token: 'token-v2' } }); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('server proactively re-handshakes via expiry', async () => { + const requestSchema = Type.Object({ token: Type.String() }); + + type ParsedMetadata = Static; + + let token = 'token-v1'; + const construct = vi.fn(() => ({ token })); + const clientTransport = getClientTransport( + 'client', + createClientHandshakeOptions(requestSchema, construct), + ); + const validate = vi.fn((metadata: ParsedMetadata) => ({ + token: metadata.token, + })); + const serverTransport = getServerTransport( + 'SERVER', + createServerHandshakeOptions( + requestSchema, + validate, + // expire the first token soon, so the server re-handshakes shortly + // after connecting (one handshake window before this), then stop + (parsed) => + parsed.token === 'token-v1' + ? new Date( + Date.now() + testingSessionOptions.handshakeTimeoutMs + 100, + ) + : undefined, + ), + ); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const ServiceSchema = createServiceSchema< + MaybeDisposable, + ParsedMetadata + >(); + const services = { + test: ServiceSchema.define({ + getToken: Procedure.rpc({ + requestInit: Type.Object({}), + responseData: Type.Object({ token: Type.String() }), + handler: async ({ ctx }) => Ok({ token: ctx.metadata.token }), + }), + }), + }; + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const before = await client.test.getToken.rpc({}); + expect(before).toStrictEqual({ + ok: true, + payload: { token: 'token-v1' }, + }); + + // the scheduled refresh fires on its own; construct now returns v2 + token = 'token-v2'; + await waitFor(() => + expect( + serverTransport.sessionHandshakeMetadata.get('client'), + ).toStrictEqual({ token: 'token-v2' }), + ); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('a rejected metadata refresh tears down the session', async () => { + const requestSchema = Type.Object({ token: Type.String() }); + + type ParsedMetadata = Static; + + let token = 'token-v1'; + const construct = vi.fn(() => ({ token })); + const clientTransport = getClientTransport( + 'client', + createClientHandshakeOptions(requestSchema, construct), + ); + const validate = vi.fn( + ( + metadata: ParsedMetadata, + ): ParsedMetadata | 'REJECTED_BY_CUSTOM_HANDLER' => + metadata.token === 'token-v1' + ? { token: metadata.token } + : 'REJECTED_BY_CUSTOM_HANDLER', + ); + const serverTransport = getServerTransport< + typeof requestSchema, + ParsedMetadata + >( + 'SERVER', + createServerHandshakeOptions( + requestSchema, + validate, + ), + ); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const ServiceSchema = createServiceSchema< + MaybeDisposable, + ParsedMetadata + >(); + const services = { + test: ServiceSchema.define({ + getToken: Procedure.rpc({ + requestInit: Type.Object({}), + responseData: Type.Object({ token: Type.String() }), + handler: async ({ ctx }) => Ok({ token: ctx.metadata.token }), + }), + }), + }; + createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + await client.test.getToken.rpc({}); + expect(numberOfConnections(serverTransport)).toEqual(1); + + // the client would otherwise reconnect with the same bad token; keep it + // offline so we can assert the teardown deterministically + clientTransport.reconnectOnConnectionDrop = false; + + // the refreshed token is rejected, so the server tears the session down + token = 'token-v2'; + serverTransport.requestRehandshake('client'); + + await waitFor(() => + expect(serverTransport.sessions.has('client')).toBe(false), + ); + await waitFor(() => expect(numberOfConnections(clientTransport)).toBe(0)); + + // let the client's now-disconnected session lapse before cleanup + await advanceFakeTimersBySessionGrace(); + }); + + test('an in-flight handler observes refreshed metadata mid-stream', async () => { + const requestSchema = Type.Object({ token: Type.String() }); + + type ParsedMetadata = Static; + + let token = 'token-v1'; + const construct = vi.fn(() => ({ token })); + const clientTransport = getClientTransport( + 'client', + createClientHandshakeOptions(requestSchema, construct), + ); + const validate = vi.fn((metadata: ParsedMetadata) => ({ + token: metadata.token, + })); + const serverTransport = getServerTransport( + 'SERVER', + createServerHandshakeOptions(requestSchema, validate), + ); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const ServiceSchema = createServiceSchema< + MaybeDisposable, + ParsedMetadata + >(); + const services = { + test: ServiceSchema.define({ + // echoes the current metadata token for every request it receives, + // so a single long-lived handler can be observed across a refresh + echoToken: Procedure.stream({ + requestInit: Type.Object({}), + requestData: Type.Object({}), + responseData: Type.Object({ token: Type.String() }), + handler: async ({ ctx, reqReadable, resWritable }) => { + for await (const msg of reqReadable) { + if (!msg.ok) break; + resWritable.write(Ok({ token: ctx.metadata.token })); + } + resWritable.close(); + }, + }), + }), + }; + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const { reqWritable, resReadable } = client.test.echoToken.stream({}); + + reqWritable.write({}); + expect(await readNextResult(resReadable)).toStrictEqual({ + ok: true, + payload: { token: 'token-v1' }, + }); + + // refresh while the stream handler is still running + token = 'token-v2'; + expect(serverTransport.requestRehandshake('client')).toBe(true); + await waitFor(() => + expect( + serverTransport.sessionHandshakeMetadata.get('client'), + ).toStrictEqual({ token: 'token-v2' }), + ); + + // the same handler now sees the refreshed token + reqWritable.write({}); + expect(await readNextResult(resReadable)).toStrictEqual({ + ok: true, + payload: { token: 'token-v2' }, + }); + + reqWritable.close(); + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('a refresh the client never answers tears the session down', async () => { + const requestSchema = Type.Object({ token: Type.String() }); + + type ParsedMetadata = Static; + + let failRefresh = false; + const construct = vi.fn(() => { + if (failRefresh) { + // a client that refuses to hand back a fresh token + throw new Error('client refuses to refresh'); + } + + return { token: 'token-v1' }; + }); + const clientTransport = getClientTransport( + 'client', + createClientHandshakeOptions(requestSchema, construct), + ); + const validate = vi.fn((metadata: ParsedMetadata) => ({ + token: metadata.token, + })); + const serverTransport = getServerTransport( + 'SERVER', + createServerHandshakeOptions(requestSchema, validate), + ); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const ServiceSchema = createServiceSchema< + MaybeDisposable, + ParsedMetadata + >(); + const services = { + test: ServiceSchema.define({ + getToken: Procedure.rpc({ + requestInit: Type.Object({}), + responseData: Type.Object({ token: Type.String() }), + handler: async ({ ctx }) => Ok({ token: ctx.metadata.token }), + }), + }), + }; + createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + await client.test.getToken.rpc({}); + expect(numberOfConnections(serverTransport)).toEqual(1); + + // the client will ignore the refresh; keep it offline so the teardown is + // observable rather than racing a reconnect + failRefresh = true; + clientTransport.reconnectOnConnectionDrop = false; + + serverTransport.requestRehandshake('client'); + + // with no valid response, the deadline elapses and the server tears the + // session down rather than trusting the stale metadata indefinitely + await vi.advanceTimersByTimeAsync( + testingSessionOptions.handshakeTimeoutMs + 1, + ); + await waitFor(() => + expect(serverTransport.sessions.has('client')).toBe(false), + ); + + await advanceFakeTimersBySessionGrace(); + }); + + test('a malformed re-handshake frame tears the session down', async () => { + const requestSchema = Type.Object({ token: Type.String() }); + + type ParsedMetadata = Static; + + const construct = vi.fn(() => ({ token: 'token-v1' })); + const clientTransport = getClientTransport( + 'client', + createClientHandshakeOptions(requestSchema, construct), + ); + const validate = vi.fn((metadata: ParsedMetadata) => ({ + token: metadata.token, + })); + const serverTransport = getServerTransport( + 'SERVER', + createServerHandshakeOptions(requestSchema, validate), + ); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + clientTransport.connect(serverTransport.clientId); + await waitFor(() => { + expect(serverTransport.sessions.has('client')).toBe(true); + expect(numberOfConnections(clientTransport)).toBe(1); + }); + + // keep the client offline so the teardown is observable, not racing a reconnect + clientTransport.reconnectOnConnectionDrop = false; + + // a connected peer sends a garbage payload on the reserved re-handshake + // stream; the server treats the protocol violation as a failed re-handshake + const clientSession = clientTransport.sessions.get( + serverTransport.clientId, + ); + assert(clientSession); + const send = clientTransport.getSessionBoundSendFn( + serverTransport.clientId, + clientSession.id, + ); + send({ + streamId: RehandshakeStreamId, + controlFlags: 0, + payload: { type: 'NOT_A_REHANDSHAKE_RESPONSE' }, + }); + + await waitFor(() => + expect(serverTransport.sessions.has('client')).toBe(false), + ); + + await advanceFakeTimersBySessionGrace(); + }); + test('validate receives the connecting client id', async () => { const requestSchema = Type.Object({}); @@ -1322,18 +1756,21 @@ describe.each(testMatrix())( seenFrom: string; } + const construct = vi.fn(() => ({})); const clientTransport = getClientTransport( 'client', - createClientHandshakeOptions(requestSchema, () => ({})), + createClientHandshakeOptions(requestSchema, construct), + ); + const validate = vi.fn( + ( + _metadata: Static, + _prev?: ParsedMetadata, + from?: string, + ): ParsedMetadata => ({ seenFrom: from ?? '' }), ); const serverTransport = getServerTransport( 'SERVER', - createServerHandshakeOptions( - requestSchema, - (_metadata, _prev, from) => ({ - seenFrom: from ?? '', - }), - ), + createServerHandshakeOptions(requestSchema, validate), ); addPostTestCleanup(async () => { await cleanupTransports([clientTransport, serverTransport]); diff --git a/protobuf/handshake.ts b/protobuf/handshake.ts index 686d5eb8..06693985 100644 --- a/protobuf/handshake.ts +++ b/protobuf/handshake.ts @@ -62,6 +62,7 @@ export function createServerHandshakeOptions< >( schema: Schema, validate: ValidateHandshake, + expiry?: (parsedMetadata: ParsedMetadata) => Date | undefined, ): ServerHandshakeOptions { return createTransportServerHandshakeOptions( HandshakeBytesSchema, @@ -75,5 +76,6 @@ export function createServerHandshakeOptions< return await validate(decoded, previousParsedMetadata, from); }, + expiry, ); } diff --git a/protobuf/server.ts b/protobuf/server.ts index 5fad2f0e..45d443a9 100644 --- a/protobuf/server.ts +++ b/protobuf/server.ts @@ -612,13 +612,29 @@ class ProtobufServer< }, }; + // metadata is live: handlers re-reading ctx.metadata observe values + // refreshed mid-stream, scoped to this stream's session so a hard reconnect + // can't surface another session's metadata. See the TypeBox router for the + // fuller explanation. + const transport = this.transport; + const currentMetadata = (): object => { + const session = transport.sessions.get(from); + if (session?.id === sessionId) { + return transport.sessionHandshakeMetadata.get(from) ?? sessionMetadata; + } + + return sessionMetadata; + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any, @typescript-eslint/no-unsafe-assignment const handlerContext: ProtobufHandlerContext = { ...serviceContext, state: serviceState, from, sessionId, - metadata: sessionMetadata, + get metadata() { + return currentMetadata(); + }, span, service, method, @@ -635,6 +651,9 @@ class ProtobufServer< signal: finishedController.signal, }; + // middleware runs once at stream start, before any re-handshake; the spread + // copies metadata by value (the live getter above is evaluated once), so it + // sees the metadata as of invocation. The handler ctx is the live view. // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment const middlewareContext: MiddlewareContext = { ...handlerContext, diff --git a/router/handshake.ts b/router/handshake.ts index 821f2af4..801690c0 100644 --- a/router/handshake.ts +++ b/router/handshake.ts @@ -58,6 +58,18 @@ export interface ServerHandshakeOptions< * returning parsed metadata. */ validate: ValidateHandshake; + + /** + * When the credential expires (or undefined if it never does). The server + * re-handshakes one `handshakeTimeoutMs` beforehand — re-validating fresh + * metadata and live-replacing the stored value — so the session never serves + * past expiry: a refresh lands first, or an unanswered re-handshake tears the + * session down by then. Re-evaluated on every (re)validation. + * + * Scheduling only — it does not gate requests, so reject already-expired + * credentials in {@link validate} or against the live `ctx.metadata`. + */ + expiry?: (parsedMetadata: ParsedMetadata) => Date | undefined; } export function createClientHandshakeOptions< @@ -75,6 +87,7 @@ export function createServerHandshakeOptions< >( schema: MetadataSchema, validate: ValidateHandshake, + expiry?: (parsedMetadata: ParsedMetadata) => Date | undefined, ): ServerHandshakeOptions { - return { schema, validate }; + return { schema, validate, expiry }; } diff --git a/router/server.ts b/router/server.ts index cda6e270..b7b642f7 100644 --- a/router/server.ts +++ b/router/server.ts @@ -637,6 +637,22 @@ class RiverServer< closeReadable(); } + // metadata is live: handlers re-reading ctx.metadata observe values + // refreshed mid-stream (see requestRehandshake). We scope the lookup to + // this stream's session so a handler never observes a different session's + // metadata after a hard reconnect; once the session changes or ends we fall + // back to the snapshot captured when the stream opened. Handlers that want a + // frozen value can destructure (e.g. `const { token } = ctx.metadata`). + const transport = this.transport; + const currentMetadata = (): ParsedMetadata => { + const session = transport.sessions.get(from); + if (session?.id === sessionId) { + return transport.sessionHandshakeMetadata.get(from) ?? sessionMetadata; + } + + return sessionMetadata; + }; + const handlerContextWithSpan: ProcedureHandlerContext< object, object, @@ -645,7 +661,9 @@ class RiverServer< ...serviceContext, from: from, sessionId, - metadata: sessionMetadata, + get metadata() { + return currentMetadata(); + }, span, cancel: (message?: string) => { const errRes = { @@ -665,6 +683,9 @@ class RiverServer< ...serviceContext, sessionId, from, + // middleware runs once at stream start, before any re-handshake, so it + // sees the metadata as of invocation; the handler ctx above is the live + // view metadata: sessionMetadata, span, deferCleanup, diff --git a/transport/client.ts b/transport/client.ts index e56ae027..4ab0d531 100644 --- a/transport/client.ts +++ b/transport/client.ts @@ -3,11 +3,13 @@ import { ClientHandshakeOptions } from '../router/handshake'; import { validationErrorToRiverErrors } from '../router/errors'; import { ControlMessageHandshakeResponseSchema, + ControlMessageRehandshakeRequestSchema, HandshakeErrorRetriableResponseCodes, OpaqueTransportMessage, TransportClientId, currentProtocolVersion, handshakeRequestMessage, + rehandshakeResponseMessage, } from './message'; import { ClientTransportOptions, @@ -75,6 +77,75 @@ export abstract class ClientTransport< this.handshakeExtensions = options; } + protected handleRehandshakeMessage(message: OpaqueTransportMessage): void { + if (!Value.Check(ControlMessageRehandshakeRequestSchema, message.payload)) { + this.log?.warn( + `ignoring malformed re-handshake request from ${message.from}`, + { clientId: this.clientId, connectedTo: message.from }, + ); + + return; + } + + void this.sendRehandshake(message.from); + } + + /** + * Re-constructs handshake metadata via the configured handshake extension and + * sends it back to the server so it can replace the metadata for this session. + * Triggered by a server {@link ControlMessageRehandshakeRequestSchema}. + */ + private async sendRehandshake(to: TransportClientId) { + if (!this.handshakeExtensions) { + this.log?.warn( + `got re-handshake request from ${to} but no handshake extensions are configured, ignoring`, + { clientId: this.clientId, connectedTo: to }, + ); + + return; + } + + const session = this.sessions.get(to); + if (!session || session.state !== SessionState.Connected) { + return; + } + + const loggingMetadata = session.loggingMetadata; + const sessionId = session.id; + + let metadata: unknown; + try { + metadata = await this.handshakeExtensions.construct(); + } catch (err) { + this.log?.error( + `failed to construct re-handshake metadata for ${to}: ${coerceErrorString( + err, + )}`, + loggingMetadata, + ); + + return; + } + + // bind the send to the session that asked to re-handshake: a hard reconnect + // during construct() makes the send throw rather than delivering stale + // metadata to the freshly handshaked session that replaced it + try { + const send = this.getSessionBoundSendFn(to, sessionId); + send(rehandshakeResponseMessage(metadata)); + } catch (err) { + const reason = coerceErrorString(err); + this.log?.error( + `failed to send re-handshake metadata to ${to}: ${reason}`, + loggingMetadata, + ); + this.protocolError({ + type: ProtocolError.MessageSendFailure, + message: reason, + }); + } + } + /** * Abstract method that creates a new {@link Connection} object. * @@ -330,6 +401,9 @@ export abstract class ClientTransport< onMessage: (msg) => { this.handleMsg(msg); }, + onRehandshake: (msg) => { + this.handleRehandshakeMessage(msg); + }, onInvalidMessage: (reason) => { this.log?.error(`invalid message: ${reason}`, { ...connectedSession.loggingMetadata, diff --git a/transport/message.ts b/transport/message.ts index 1033a52a..a55a3145 100644 --- a/transport/message.ts +++ b/transport/message.ts @@ -137,11 +137,38 @@ export const ControlMessageHandshakeResponseSchema = Type.Object({ ]), }); +/** + * Reserved stream id for the follow-up handshake (re-handshake) control + * messages, analogous to the reserved `heartbeat` stream id used for acks. + * Messages on this stream are consumed by the transport itself and never + * surface to the router. + */ +export const RehandshakeStreamId = 'rehandshake'; + +/** + * Sent by the server over a live connection to ask the client to re-handshake, + * i.e. re-construct its handshake metadata (e.g. fetch a fresh token). + */ +export const ControlMessageRehandshakeRequestSchema = Type.Object({ + type: Type.Literal('REHANDSHAKE_REQ'), +}); + +/** + * Sent by the client in response to a {@link ControlMessageRehandshakeRequestSchema}, + * carrying freshly constructed handshake metadata for the server to re-validate. + */ +export const ControlMessageRehandshakeResponseSchema = Type.Object({ + type: Type.Literal('REHANDSHAKE_RESP'), + metadata: Type.Optional(Type.Unknown()), +}); + export const ControlMessagePayloadSchema = Type.Union([ ControlMessageCloseSchema, ControlMessageAckSchema, ControlMessageHandshakeRequestSchema, ControlMessageHandshakeResponseSchema, + ControlMessageRehandshakeRequestSchema, + ControlMessageRehandshakeResponseSchema, ]); /** @@ -263,6 +290,29 @@ export function closeStreamMessage(streamId: string): PartialTransportMessage { }; } +export function rehandshakeRequestMessage(): PartialTransportMessage { + return { + streamId: RehandshakeStreamId, + controlFlags: 0, + payload: { + type: 'REHANDSHAKE_REQ' as const, + } satisfies Static, + }; +} + +export function rehandshakeResponseMessage( + metadata: unknown, +): PartialTransportMessage { + return { + streamId: RehandshakeStreamId, + controlFlags: 0, + payload: { + type: 'REHANDSHAKE_RESP' as const, + metadata, + } satisfies Static, + }; +} + export function cancelMessage( streamId: string, payload: ErrResult, diff --git a/transport/server.ts b/transport/server.ts index 8bb015b5..ef42d22d 100644 --- a/transport/server.ts +++ b/transport/server.ts @@ -3,6 +3,7 @@ import { ServerHandshakeOptions } from '../router/handshake'; import { validationErrorToRiverErrors } from '../router/errors'; import { ControlMessageHandshakeRequestSchema, + ControlMessageRehandshakeResponseSchema, HandshakeErrorCustomHandlerFatalResponseCodes, HandshakeErrorResponseCodes, OpaqueTransportMessage, @@ -94,6 +95,176 @@ export abstract class ServerTransport< super.deleteSession(session, options); } + /** + * Asks the connected client to re-handshake — re-construct and resend its + * handshake metadata (e.g. a refreshed token). Returns false if there is no + * live connection to send the request over. + * + * On a successful re-handshake the stored metadata is replaced and observed by + * subsequent procedure calls. If the client does not respond with metadata that + * re-validates before the response deadline — the shorter of + * {@link SessionOptions.handshakeTimeoutMs} and the credential's remaining + * lifetime — the session is torn down. + */ + requestRehandshake(to: TransportClientId): boolean { + const session = this.sessions.get(to); + if (!session || session.state !== SessionState.Connected) { + return false; + } + + return session.requestRehandshakeNow(); + } + + /** + * Stores freshly validated handshake metadata for a client and hands the + * credential's expiry to the session, which schedules and runs the next + * re-handshake itself. Called on every successful (re)handshake, so the schedule + * perpetuates itself and survives transparent reconnects. + */ + private storeSessionMetadata( + session: ServerSession, + parsed: ParsedMetadata, + ) { + this.sessionHandshakeMetadata.set(session.to, parsed); + + if (session.state === SessionState.Connected) { + session.scheduleRehandshake( + this.handshakeExtensions?.expiry?.(parsed)?.getTime(), + ); + } + } + + protected handleRehandshakeMessage(message: OpaqueTransportMessage): void { + if ( + !Value.Check(ControlMessageRehandshakeResponseSchema, message.payload) + ) { + // a frame on the reserved re-handshake stream that isn't a valid response + // is a protocol violation by the authenticated peer; fail the refresh + const session = this.sessions.get(message.from); + if (session) { + this.teardownForFailedRehandshake( + session, + 'received malformed re-handshake control message', + ); + } + + return; + } + + void this.onRehandshakeResponse(message.from, message.payload.metadata); + } + + /** + * Re-validates handshake metadata sent by the client during a re-handshake and + * replaces the stored metadata on success. Any failure (malformed metadata, + * rejection, or a thrown validator) tears the session down. + */ + private async onRehandshakeResponse( + from: TransportClientId, + metadata: unknown, + ) { + const handshakeExtensions = this.handshakeExtensions; + if (!handshakeExtensions) { + return; + } + + const session = this.sessions.get(from); + if (!session) { + return; + } + + if (!Value.Check(handshakeExtensions.schema, metadata)) { + this.teardownForFailedRehandshake( + session, + 'received malformed handshake metadata during re-handshake', + ); + + return; + } + + const previousParsedMetadata = this.sessionHandshakeMetadata.get(from); + + let parsedMetadataOrFailureCode; + try { + parsedMetadataOrFailureCode = await handshakeExtensions.validate( + metadata, + previousParsedMetadata, + from, + ); + } catch (err) { + // teardownForFailedRehandshake no-ops if this session was already replaced + // (e.g. a transparent reconnect) while we awaited validation + this.teardownForFailedRehandshake( + session, + `handshake validation threw during re-handshake: ${coerceErrorString( + err, + )}`, + ); + + return; + } + + if ( + Value.Check( + HandshakeErrorCustomHandlerFatalResponseCodes, + parsedMetadataOrFailureCode, + ) + ) { + this.teardownForFailedRehandshake( + session, + 're-handshake metadata rejected by handshake handler', + ); + + return; + } + + // a reconnect/teardown during validation may have replaced this exact session + // (a transparent reconnect keeps the same id); only store and reschedule if + // it's still the one we validated against, so we don't clobber fresher metadata + if (this.sessions.get(from) !== session) { + return; + } + + this.storeSessionMetadata( + session, + parsedMetadataOrFailureCode as ParsedMetadata, + ); + + this.log?.info(`re-handshake from ${from} ok`, { + ...session.loggingMetadata, + connectedTo: from, + }); + } + + /** + * Tears down a session whose re-handshake failed (rejected, malformed, timed + * out, or a thrown validator). No-ops if {@link session} is no longer the live + * session for its peer — a transparent reconnect keeps the same id, so callers + * reaching here after an async gap can't accidentally close the session that + * replaced it. + */ + private teardownForFailedRehandshake( + session: ServerSession, + reason: string, + ) { + if (this.sessions.get(session.to) !== session) { + return; + } + + const to = session.to; + this.log?.warn(`tearing down session to ${to}: ${reason}`, { + ...session.loggingMetadata, + connectedTo: to, + }); + + this.protocolError({ + type: ProtocolError.HandshakeFailed, + code: 'REJECTED_BY_CUSTOM_HANDLER', + message: reason, + }); + this.deleteSession(session, { unhealthy: true }); + } + protected handleConnection(conn: ConnType) { if (this.getStatus() !== 'open') return; @@ -547,6 +718,15 @@ export abstract class ServerTransport< onMessage: (msg) => { this.handleMsg(msg); }, + onRehandshake: (msg) => { + this.handleRehandshakeMessage(msg); + }, + onRehandshakeTimeout: () => { + this.teardownForFailedRehandshake( + connectedSession, + 're-handshake timed out', + ); + }, onInvalidMessage: (reason) => { this.log?.error(`invalid message: ${reason}`, { ...connectedSession.loggingMetadata, @@ -580,7 +760,7 @@ export abstract class ServerTransport< return; } - this.sessionHandshakeMetadata.set(connectedSession.to, parsedMetadata); + this.storeSessionMetadata(connectedSession, parsedMetadata); if (oldSession) { this.updateSession(connectedSession); } else { diff --git a/transport/sessionStateMachine/SessionConnected.ts b/transport/sessionStateMachine/SessionConnected.ts index e4875532..c84e01ff 100644 --- a/transport/sessionStateMachine/SessionConnected.ts +++ b/transport/sessionStateMachine/SessionConnected.ts @@ -5,6 +5,8 @@ import { EncodedTransportMessage, OpaqueTransportMessage, PartialTransportMessage, + RehandshakeStreamId, + rehandshakeRequestMessage, isAck, } from '../message'; import { @@ -21,6 +23,17 @@ export interface SessionConnectedListeners extends IdentifiedSessionListeners { onConnectionErrored: (err: unknown) => void; onConnectionClosed: () => void; onMessage: (msg: OpaqueTransportMessage) => void; + /** + * A frame arrived on the reserved re-handshake stream. The transport consumes + * it to drive the follow-up handshake rather than surfacing it to the router. + */ + onRehandshake: (msg: OpaqueTransportMessage) => void; + /** + * A scheduled re-handshake went unanswered within its deadline. Only the server + * arms this (via {@link SessionConnected.scheduleRehandshake}); it tears the + * session down rather than keep serving a credential past its expiry. + */ + onRehandshakeTimeout?: () => void; onInvalidMessage: (reason: string) => void; } @@ -44,6 +57,8 @@ export class SessionConnected< private heartbeatHandle?: ReturnType | undefined; private heartbeatMissTimeout?: ReturnType | undefined; private isActivelyHeartbeating = false; + private rehandshakeTimer?: ReturnType | undefined; + private credentialExpiry?: number | undefined; updateBookkeeping(ack: number, seq: number) { this.sendBuffer = this.sendBuffer.filter((unacked) => unacked.seq >= ack); @@ -180,6 +195,77 @@ export class SessionConnected< this.send(heartbeat); } + /** + * Schedules the next proactive re-handshake from the credential's expiry. The + * server calls this after each (re)validation, mirroring {@link startActiveHeartbeat}: + * once armed the session drives the exchange itself — one handshake window before + * expiry it sends a re-handshake request and waits for the response, firing + * {@link SessionConnectedListeners.onRehandshakeTimeout} if none arrives in time. + * Passing `undefined` (a credential that never expires) cancels any schedule. + */ + scheduleRehandshake(expiry: number | undefined) { + this.clearRehandshakeTimer(); + this.credentialExpiry = expiry; + if (expiry === undefined) { + return; + } + + // re-handshake one window before expiry so the exchange resolves (a refresh + // lands, or the deadline below tears the session down) by the time it expires + const delayMs = expiry - this.options.handshakeTimeoutMs - Date.now(); + this.rehandshakeTimer = setTimeout( + () => { + this.rehandshakeTimer = undefined; + this.sendRehandshakeRequest(); + }, + Math.max(0, delayMs), + ); + } + + /** + * Sends a re-handshake request immediately and arms the response deadline, + * bypassing the expiry schedule. Returns false if the request couldn't be sent. + */ + requestRehandshakeNow(): boolean { + this.clearRehandshakeTimer(); + + return this.sendRehandshakeRequest(); + } + + private sendRehandshakeRequest(): boolean { + const res = this.send(rehandshakeRequestMessage()); + if (!res.ok) { + // the send failure already tore the session down via onMessageSendFailure + return false; + } + + // clamp the deadline to the credential's remaining life, so one validated with + // little time left is still torn down by expiry rather than a full window later + const deadlineMs = + this.credentialExpiry !== undefined + ? Math.min( + this.options.handshakeTimeoutMs, + this.credentialExpiry - Date.now(), + ) + : this.options.handshakeTimeoutMs; + this.rehandshakeTimer = setTimeout( + () => { + this.rehandshakeTimer = undefined; + this.listeners.onRehandshakeTimeout?.(); + }, + Math.max(0, deadlineMs), + ); + + return true; + } + + clearRehandshakeTimer() { + if (this.rehandshakeTimer) { + clearTimeout(this.rehandshakeTimer); + this.rehandshakeTimer = undefined; + } + } + onMessageData = (msg: Uint8Array) => { const parsedMsgRes = this.codec.fromBuffer(msg); if (!parsedMsgRes.ok) { @@ -242,6 +328,14 @@ export class SessionConnected< // dispatch directly if its not an explicit ack if (!isAck(parsedMsg.controlFlags)) { + // re-handshake frames ride a reserved stream and are consumed by the + // transport, never surfaced to the router (same as the acks handled below) + if (parsedMsg.streamId === RehandshakeStreamId) { + this.listeners.onRehandshake(parsedMsg); + + return; + } + this.listeners.onMessage(parsedMsg); return; @@ -275,6 +369,8 @@ export class SessionConnected< clearTimeout(this.heartbeatMissTimeout); this.heartbeatMissTimeout = undefined; } + + this.clearRehandshakeTimer(); } _handleClose(): void { diff --git a/transport/sessionStateMachine/stateMachine.test.ts b/transport/sessionStateMachine/stateMachine.test.ts index e2e55206..b71f754b 100644 --- a/transport/sessionStateMachine/stateMachine.test.ts +++ b/transport/sessionStateMachine/stateMachine.test.ts @@ -136,6 +136,7 @@ function createSessionHandshakingListeners(): SessionHandshakingListeners { function createSessionConnectedListeners(): SessionConnectedListeners { return { onMessage: vi.fn(), + onRehandshake: vi.fn(), onConnectionClosed: vi.fn(), onConnectionErrored: vi.fn(), onInvalidMessage: vi.fn(), diff --git a/transport/transport.ts b/transport/transport.ts index 85e80083..7da4a119 100644 --- a/transport/transport.ts +++ b/transport/transport.ts @@ -131,6 +131,7 @@ export abstract class Transport { */ protected handleMsg(message: OpaqueTransportMessage) { if (this.getStatus() !== 'open') return; + this.eventDispatcher.dispatchEvent('message', message); }