diff --git a/packages/core/realtime-js/src/RealtimeChannel.ts b/packages/core/realtime-js/src/RealtimeChannel.ts index 4bf49e686..1e48eacd4 100644 --- a/packages/core/realtime-js/src/RealtimeChannel.ts +++ b/packages/core/realtime-js/src/RealtimeChannel.ts @@ -308,7 +308,10 @@ export default class RealtimeChannel { this.joinPush .receive('ok', async ({ postgres_changes }: PostgresChangesFilters) => { - this.socket.setAuth() + // Only refresh auth if using callback-based tokens + if (!this.socket._isManualToken()) { + this.socket.setAuth() + } if (postgres_changes === undefined) { callback?.(REALTIME_SUBSCRIBE_STATES.SUBSCRIBED) return @@ -531,7 +534,7 @@ export default class RealtimeChannel { 'channel', `resubscribe to ${this.topic} due to change in presence callbacks on joined channel` ) - this.unsubscribe().then(() => this.subscribe()) + this.unsubscribe().then(async () => await this.subscribe()) } return this._on(type, filter, callback) } diff --git a/packages/core/realtime-js/src/RealtimeClient.ts b/packages/core/realtime-js/src/RealtimeClient.ts index 19db54fa6..91880472b 100755 --- a/packages/core/realtime-js/src/RealtimeClient.ts +++ b/packages/core/realtime-js/src/RealtimeClient.ts @@ -102,6 +102,7 @@ const WORKER_SCRIPT = ` export default class RealtimeClient { accessTokenValue: string | null = null apiKey: string | null = null + private _manuallySetToken: boolean = false channels: RealtimeChannel[] = new Array() endPoint: string = '' httpEndpoint: string = '' @@ -416,7 +417,18 @@ export default class RealtimeClient { * * On callback used, it will set the value of the token internal to the client. * + * When a token is explicitly provided, it will be preserved across channel operations + * (including removeChannel and resubscribe). The `accessToken` callback will not be + * invoked until `setAuth()` is called without arguments. + * * @param token A JWT string to override the token set on the client. + * + * @example + * // Use a manual token (preserved across resubscribes, ignores accessToken callback) + * client.realtime.setAuth('my-custom-jwt') + * + * // Switch back to using the accessToken callback + * client.realtime.setAuth() */ async setAuth(token: string | null = null): Promise { this._authPromise = this._performAuth(token) @@ -426,6 +438,16 @@ export default class RealtimeClient { this._authPromise = null } } + + /** + * Returns true if the current access token was explicitly set via setAuth(token), + * false if it was obtained via the accessToken callback. + * @internal + */ + _isManualToken(): boolean { + return this._manuallySetToken + } + /** * Sends a heartbeat message if the socket is connected. */ @@ -779,16 +801,33 @@ export default class RealtimeClient { */ private async _performAuth(token: string | null = null): Promise { let tokenToSend: string | null + let isManualToken = false if (token) { tokenToSend = token + // Track if this is a manually-provided token + isManualToken = true } else if (this.accessToken) { - // Always call the accessToken callback to get fresh token - tokenToSend = await this.accessToken() + // Call the accessToken callback to get fresh token + try { + tokenToSend = await this.accessToken() + } catch (e) { + this.log('error', 'Error fetching access token from callback', e) + // Fall back to cached value if callback fails + tokenToSend = this.accessTokenValue + } } else { tokenToSend = this.accessTokenValue } + // Track whether this token was manually set or fetched via callback + if (isManualToken) { + this._manuallySetToken = true + } else if (this.accessToken) { + // If we used the callback, clear the manual flag + this._manuallySetToken = false + } + if (this.accessTokenValue != tokenToSend) { this.accessTokenValue = tokenToSend this.channels.forEach((channel) => { @@ -823,9 +862,12 @@ export default class RealtimeClient { * @internal */ private _setAuthSafely(context = 'general'): void { - this.setAuth().catch((e) => { - this.log('error', `error setting auth in ${context}`, e) - }) + // Only refresh auth if using callback-based tokens + if (!this._isManualToken()) { + this.setAuth().catch((e) => { + this.log('error', `Error setting auth in ${context}`, e) + }) + } } /** diff --git a/packages/core/realtime-js/test/RealtimeChannel.lifecycle.test.ts b/packages/core/realtime-js/test/RealtimeChannel.lifecycle.test.ts index 376f931bf..404011151 100644 --- a/packages/core/realtime-js/test/RealtimeChannel.lifecycle.test.ts +++ b/packages/core/realtime-js/test/RealtimeChannel.lifecycle.test.ts @@ -229,10 +229,10 @@ describe('Channel Lifecycle Management', () => { assert.equal(channel.state, CHANNEL_STATES.joining) }) - test('updates join push payload access token', () => { + test('updates join push payload access token', async () => { testSetup.socket.accessTokenValue = 'token123' - channel.subscribe() + await channel.subscribe() assert.deepEqual(channel.joinPush.payload, { access_token: 'token123', @@ -257,7 +257,7 @@ describe('Channel Lifecycle Management', () => { }) const channel = testSocket.channel('topic') - channel.subscribe() + await channel.subscribe() await new Promise((resolve) => setTimeout(resolve, 50)) assert.equal(channel.socket.accessTokenValue, tokens[0]) @@ -265,7 +265,7 @@ describe('Channel Lifecycle Management', () => { // Wait for disconnect to complete (including fallback timer) await new Promise((resolve) => setTimeout(resolve, 150)) - channel.subscribe() + await channel.subscribe() await new Promise((resolve) => setTimeout(resolve, 50)) assert.equal(channel.socket.accessTokenValue, tokens[1]) }) diff --git a/packages/core/realtime-js/test/RealtimeClient.auth.resubscribe.test.ts b/packages/core/realtime-js/test/RealtimeClient.auth.resubscribe.test.ts new file mode 100644 index 000000000..9aed0ecac --- /dev/null +++ b/packages/core/realtime-js/test/RealtimeClient.auth.resubscribe.test.ts @@ -0,0 +1,190 @@ +import assert from 'assert' +import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest' +import { testBuilders, EnhancedTestSetup } from './helpers/setup' +import { utils } from './helpers/auth' +import { CHANNEL_STATES } from '../src/lib/constants' + +let testSetup: EnhancedTestSetup + +beforeEach(() => { + testSetup = testBuilders.standardClient() +}) + +afterEach(() => { + testSetup.cleanup() + testSetup.socket.removeAllChannels() +}) + +describe('Custom JWT token preservation', () => { + test('preserves access token when resubscribing after removeChannel', async () => { + // Test scenario: + // 1. Set custom JWT via setAuth (not using accessToken callback) + // 2. Subscribe to private channel + // 3. removeChannel + // 4. Create new channel with same topic and subscribe + + const customToken = utils.generateJWT('1h') + + // Step 1: Set auth with custom token (mimics user's setup) + await testSetup.socket.setAuth(customToken) + + // Verify token was set + assert.strictEqual(testSetup.socket.accessTokenValue, customToken) + + // Step 2: Create and subscribe to private channel (first time) + const channel1 = testSetup.socket.channel('conversation:dc3fb8c1-ceef-4c00-9f92-e496acd03593', { + config: { private: true }, + }) + + // Spy on the push to verify join payload + const pushSpy = vi.spyOn(testSetup.socket, 'push') + + // Simulate successful subscription + channel1.state = CHANNEL_STATES.closed // Start from closed + await channel1.subscribe() + + // Verify first join includes access_token + const firstJoinCall = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join') + expect(firstJoinCall).toBeDefined() + expect(firstJoinCall![0].payload).toHaveProperty('access_token', customToken) + + // Step 3: Remove channel (mimics user cleanup) + await testSetup.socket.removeChannel(channel1) + + // Verify channel was removed + expect(testSetup.socket.getChannels()).not.toContain(channel1) + + // Step 4: Create NEW channel with SAME topic and subscribe + pushSpy.mockClear() + const channel2 = testSetup.socket.channel('conversation:dc3fb8c1-ceef-4c00-9f92-e496acd03593', { + config: { private: true }, + }) + + // This should be a different channel instance + expect(channel2).not.toBe(channel1) + + // Subscribe to the new channel + channel2.state = CHANNEL_STATES.closed + await channel2.subscribe() + + // Verify second join also includes access token + const secondJoinCall = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join') + + expect(secondJoinCall).toBeDefined() + expect(secondJoinCall![0].payload).toHaveProperty('access_token', customToken) + }) + + test('supports accessToken callback for token rotation', async () => { + // Verify that callback-based token fetching works correctly + const customToken = utils.generateJWT('1h') + let callCount = 0 + + const clientWithCallback = testBuilders.standardClient({ + accessToken: async () => { + callCount++ + return customToken + }, + }) + + // Set initial auth + await clientWithCallback.socket.setAuth() + + // Create and subscribe to first channel + const channel1 = clientWithCallback.socket.channel('conversation:test', { + config: { private: true }, + }) + + const pushSpy = vi.spyOn(clientWithCallback.socket, 'push') + channel1.state = CHANNEL_STATES.closed + await channel1.subscribe() + + const firstJoin = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join') + expect(firstJoin![0].payload).toHaveProperty('access_token', customToken) + + // Remove and recreate + await clientWithCallback.socket.removeChannel(channel1) + pushSpy.mockClear() + + const channel2 = clientWithCallback.socket.channel('conversation:test', { + config: { private: true }, + }) + + channel2.state = CHANNEL_STATES.closed + await channel2.subscribe() + + const secondJoin = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join') + + // Callback should provide token for both subscriptions + expect(secondJoin![0].payload).toHaveProperty('access_token', customToken) + + clientWithCallback.cleanup() + }) + + test('preserves token when subscribing to different topics', async () => { + const customToken = utils.generateJWT('1h') + await testSetup.socket.setAuth(customToken) + + // Subscribe to first topic + const channel1 = testSetup.socket.channel('topic1', { config: { private: true } }) + channel1.state = CHANNEL_STATES.closed + await channel1.subscribe() + + await testSetup.socket.removeChannel(channel1) + + // Subscribe to DIFFERENT topic + const pushSpy = vi.spyOn(testSetup.socket, 'push') + const channel2 = testSetup.socket.channel('topic2', { config: { private: true } }) + channel2.state = CHANNEL_STATES.closed + await channel2.subscribe() + + const joinCall = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join') + expect(joinCall![0].payload).toHaveProperty('access_token', customToken) + }) + + test('handles accessToken callback errors gracefully during subscribe', async () => { + const errorMessage = 'Token fetch failed during subscribe' + let callCount = 0 + const tokens = ['initial-token', null] // Second call will throw + + const accessToken = vi.fn(() => { + if (callCount++ === 0) { + return Promise.resolve(tokens[0]) + } + return Promise.reject(new Error(errorMessage)) + }) + + const logSpy = vi.fn() + + const client = testBuilders.standardClient({ + accessToken, + logger: logSpy, + }) + + // First subscribe should work + await client.socket.setAuth() + const channel1 = client.socket.channel('test', { config: { private: true } }) + channel1.state = CHANNEL_STATES.closed + await channel1.subscribe() + + expect(client.socket.accessTokenValue).toBe(tokens[0]) + + // Remove and resubscribe - callback will fail but should fall back + await client.socket.removeChannel(channel1) + + const channel2 = client.socket.channel('test', { config: { private: true } }) + channel2.state = CHANNEL_STATES.closed + await channel2.subscribe() + + // Verify error was logged + expect(logSpy).toHaveBeenCalledWith( + 'error', + 'Error fetching access token from callback', + expect.any(Error) + ) + + // Verify subscription still succeeded with cached token + expect(client.socket.accessTokenValue).toBe(tokens[0]) + + client.cleanup() + }) +}) diff --git a/packages/core/realtime-js/test/RealtimeClient.auth.test.ts b/packages/core/realtime-js/test/RealtimeClient.auth.test.ts index 60f140fdb..8aa6f22d5 100644 --- a/packages/core/realtime-js/test/RealtimeClient.auth.test.ts +++ b/packages/core/realtime-js/test/RealtimeClient.auth.test.ts @@ -140,8 +140,12 @@ describe('auth during connection states', () => { await new Promise((resolve) => setTimeout(() => resolve(undefined), 100)) - // Verify that the error was logged - expect(logSpy).toHaveBeenCalledWith('error', 'error setting auth in connect', expect.any(Error)) + // Verify that the error was logged with more specific message + expect(logSpy).toHaveBeenCalledWith( + 'error', + 'Error fetching access token from callback', + expect.any(Error) + ) // Verify that the connection was still established despite the error assert.ok(socketWithError.conn, 'connection should still exist') @@ -199,7 +203,7 @@ describe('auth during connection states', () => { expect(socket.accessTokenValue).toBe(tokens[0]) // Call the callback and wait for async operations to complete - await socket.reconnectTimer.callback() + await socket.reconnectTimer?.callback() await new Promise((resolve) => setTimeout(resolve, 100)) expect(socket.accessTokenValue).toBe(tokens[1]) expect(accessToken).toHaveBeenCalledTimes(2) diff --git a/packages/core/realtime-js/test/RealtimeClient.channels.test.ts b/packages/core/realtime-js/test/RealtimeClient.channels.test.ts index 6bac3dc00..8cfb4f437 100644 --- a/packages/core/realtime-js/test/RealtimeClient.channels.test.ts +++ b/packages/core/realtime-js/test/RealtimeClient.channels.test.ts @@ -104,7 +104,8 @@ describe('channel', () => { const connectStub = vi.spyOn(testSetup.socket, 'connect') const disconnectStub = vi.spyOn(testSetup.socket, 'disconnect') - channel = testSetup.socket.channel('topic').subscribe() + channel = testSetup.socket.channel('topic') + await channel.subscribe() assert.equal(testSetup.socket.getChannels().length, 1) expect(connectStub).toHaveBeenCalled() @@ -118,11 +119,11 @@ describe('channel', () => { test('does not remove other channels when removing one', async () => { const connectStub = vi.spyOn(testSetup.socket, 'connect') const disconnectStub = vi.spyOn(testSetup.socket, 'disconnect') - const channel1 = testSetup.socket.channel('chan1').subscribe() - const channel2 = testSetup.socket.channel('chan2').subscribe() + const channel1 = testSetup.socket.channel('chan1') + const channel2 = testSetup.socket.channel('chan2') - channel1.subscribe() - channel2.subscribe() + await channel1.subscribe() + await channel2.subscribe() assert.equal(testSetup.socket.getChannels().length, 2) expect(connectStub).toHaveBeenCalled()