Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions packages/core/realtime-js/src/RealtimeChannel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,10 @@ export default class RealtimeChannel {

this.joinPush
.receive('ok', async ({ postgres_changes }: PostgresChangesFilters) => {
this.socket.setAuth()
// Only refresh auth if using callback-based tokens
if (!this.socket._isManualToken()) {
this.socket.setAuth()
}
if (postgres_changes === undefined) {
callback?.(REALTIME_SUBSCRIBE_STATES.SUBSCRIBED)
return
Expand Down Expand Up @@ -531,7 +534,7 @@ export default class RealtimeChannel {
'channel',
`resubscribe to ${this.topic} due to change in presence callbacks on joined channel`
)
this.unsubscribe().then(() => this.subscribe())
this.unsubscribe().then(async () => await this.subscribe())
}
return this._on(type, filter, callback)
}
Expand Down
52 changes: 47 additions & 5 deletions packages/core/realtime-js/src/RealtimeClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ const WORKER_SCRIPT = `
export default class RealtimeClient {
accessTokenValue: string | null = null
apiKey: string | null = null
private _manuallySetToken: boolean = false
channels: RealtimeChannel[] = new Array()
endPoint: string = ''
httpEndpoint: string = ''
Expand Down Expand Up @@ -416,7 +417,18 @@ export default class RealtimeClient {
*
* On callback used, it will set the value of the token internal to the client.
*
* When a token is explicitly provided, it will be preserved across channel operations
* (including removeChannel and resubscribe). The `accessToken` callback will not be
* invoked until `setAuth()` is called without arguments.
*
* @param token A JWT string to override the token set on the client.
*
* @example
* // Use a manual token (preserved across resubscribes, ignores accessToken callback)
* client.realtime.setAuth('my-custom-jwt')
*
* // Switch back to using the accessToken callback
* client.realtime.setAuth()
*/
async setAuth(token: string | null = null): Promise<void> {
this._authPromise = this._performAuth(token)
Expand All @@ -426,6 +438,16 @@ export default class RealtimeClient {
this._authPromise = null
}
}

/**
* Returns true if the current access token was explicitly set via setAuth(token),
* false if it was obtained via the accessToken callback.
* @internal
*/
_isManualToken(): boolean {
return this._manuallySetToken
}

/**
* Sends a heartbeat message if the socket is connected.
*/
Expand Down Expand Up @@ -779,16 +801,33 @@ export default class RealtimeClient {
*/
private async _performAuth(token: string | null = null): Promise<void> {
let tokenToSend: string | null
let isManualToken = false

if (token) {
tokenToSend = token
// Track if this is a manually-provided token
isManualToken = true
} else if (this.accessToken) {
// Always call the accessToken callback to get fresh token
tokenToSend = await this.accessToken()
// Call the accessToken callback to get fresh token
try {
tokenToSend = await this.accessToken()
} catch (e) {
this.log('error', 'Error fetching access token from callback', e)
// Fall back to cached value if callback fails
tokenToSend = this.accessTokenValue
}
} else {
tokenToSend = this.accessTokenValue
}

// Track whether this token was manually set or fetched via callback
if (isManualToken) {
this._manuallySetToken = true
} else if (this.accessToken) {
// If we used the callback, clear the manual flag
this._manuallySetToken = false
}

if (this.accessTokenValue != tokenToSend) {
this.accessTokenValue = tokenToSend
this.channels.forEach((channel) => {
Expand Down Expand Up @@ -823,9 +862,12 @@ export default class RealtimeClient {
* @internal
*/
private _setAuthSafely(context = 'general'): void {
this.setAuth().catch((e) => {
this.log('error', `error setting auth in ${context}`, e)
})
// Only refresh auth if using callback-based tokens
if (!this._isManualToken()) {
this.setAuth().catch((e) => {
this.log('error', `Error setting auth in ${context}`, e)
})
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,10 @@ describe('Channel Lifecycle Management', () => {
assert.equal(channel.state, CHANNEL_STATES.joining)
})

test('updates join push payload access token', () => {
test('updates join push payload access token', async () => {
testSetup.socket.accessTokenValue = 'token123'

channel.subscribe()
await channel.subscribe()

assert.deepEqual(channel.joinPush.payload, {
access_token: 'token123',
Expand All @@ -257,15 +257,15 @@ describe('Channel Lifecycle Management', () => {
})
const channel = testSocket.channel('topic')

channel.subscribe()
await channel.subscribe()
await new Promise((resolve) => setTimeout(resolve, 50))
assert.equal(channel.socket.accessTokenValue, tokens[0])

testSocket.disconnect()
// Wait for disconnect to complete (including fallback timer)
await new Promise((resolve) => setTimeout(resolve, 150))

channel.subscribe()
await channel.subscribe()
await new Promise((resolve) => setTimeout(resolve, 50))
assert.equal(channel.socket.accessTokenValue, tokens[1])
})
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import assert from 'assert'
import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest'
import { testBuilders, EnhancedTestSetup } from './helpers/setup'
import { utils } from './helpers/auth'
import { CHANNEL_STATES } from '../src/lib/constants'

let testSetup: EnhancedTestSetup

beforeEach(() => {
testSetup = testBuilders.standardClient()
})

afterEach(() => {
testSetup.cleanup()
testSetup.socket.removeAllChannels()
})

describe('Custom JWT token preservation', () => {
test('preserves access token when resubscribing after removeChannel', async () => {
// Test scenario:
// 1. Set custom JWT via setAuth (not using accessToken callback)
// 2. Subscribe to private channel
// 3. removeChannel
// 4. Create new channel with same topic and subscribe

const customToken = utils.generateJWT('1h')

// Step 1: Set auth with custom token (mimics user's setup)
await testSetup.socket.setAuth(customToken)

// Verify token was set
assert.strictEqual(testSetup.socket.accessTokenValue, customToken)

// Step 2: Create and subscribe to private channel (first time)
const channel1 = testSetup.socket.channel('conversation:dc3fb8c1-ceef-4c00-9f92-e496acd03593', {
config: { private: true },
})

// Spy on the push to verify join payload
const pushSpy = vi.spyOn(testSetup.socket, 'push')

// Simulate successful subscription
channel1.state = CHANNEL_STATES.closed // Start from closed
await channel1.subscribe()

// Verify first join includes access_token
const firstJoinCall = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
expect(firstJoinCall).toBeDefined()
expect(firstJoinCall![0].payload).toHaveProperty('access_token', customToken)

// Step 3: Remove channel (mimics user cleanup)
await testSetup.socket.removeChannel(channel1)

// Verify channel was removed
expect(testSetup.socket.getChannels()).not.toContain(channel1)

// Step 4: Create NEW channel with SAME topic and subscribe
pushSpy.mockClear()
const channel2 = testSetup.socket.channel('conversation:dc3fb8c1-ceef-4c00-9f92-e496acd03593', {
config: { private: true },
})

// This should be a different channel instance
expect(channel2).not.toBe(channel1)

// Subscribe to the new channel
channel2.state = CHANNEL_STATES.closed
await channel2.subscribe()

// Verify second join also includes access token
const secondJoinCall = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')

expect(secondJoinCall).toBeDefined()
expect(secondJoinCall![0].payload).toHaveProperty('access_token', customToken)
})

test('supports accessToken callback for token rotation', async () => {
// Verify that callback-based token fetching works correctly
const customToken = utils.generateJWT('1h')
let callCount = 0

const clientWithCallback = testBuilders.standardClient({
accessToken: async () => {
callCount++
return customToken
},
})

// Set initial auth
await clientWithCallback.socket.setAuth()

// Create and subscribe to first channel
const channel1 = clientWithCallback.socket.channel('conversation:test', {
config: { private: true },
})

const pushSpy = vi.spyOn(clientWithCallback.socket, 'push')
channel1.state = CHANNEL_STATES.closed
await channel1.subscribe()

const firstJoin = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
expect(firstJoin![0].payload).toHaveProperty('access_token', customToken)

// Remove and recreate
await clientWithCallback.socket.removeChannel(channel1)
pushSpy.mockClear()

const channel2 = clientWithCallback.socket.channel('conversation:test', {
config: { private: true },
})

channel2.state = CHANNEL_STATES.closed
await channel2.subscribe()

const secondJoin = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')

// Callback should provide token for both subscriptions
expect(secondJoin![0].payload).toHaveProperty('access_token', customToken)

clientWithCallback.cleanup()
})

test('preserves token when subscribing to different topics', async () => {
const customToken = utils.generateJWT('1h')
await testSetup.socket.setAuth(customToken)

// Subscribe to first topic
const channel1 = testSetup.socket.channel('topic1', { config: { private: true } })
channel1.state = CHANNEL_STATES.closed
await channel1.subscribe()

await testSetup.socket.removeChannel(channel1)

// Subscribe to DIFFERENT topic
const pushSpy = vi.spyOn(testSetup.socket, 'push')
const channel2 = testSetup.socket.channel('topic2', { config: { private: true } })
channel2.state = CHANNEL_STATES.closed
await channel2.subscribe()

const joinCall = pushSpy.mock.calls.find((call) => call[0]?.event === 'phx_join')
expect(joinCall![0].payload).toHaveProperty('access_token', customToken)
})

test('handles accessToken callback errors gracefully during subscribe', async () => {
const errorMessage = 'Token fetch failed during subscribe'
let callCount = 0
const tokens = ['initial-token', null] // Second call will throw

const accessToken = vi.fn(() => {
if (callCount++ === 0) {
return Promise.resolve(tokens[0])
}
return Promise.reject(new Error(errorMessage))
})

const logSpy = vi.fn()

const client = testBuilders.standardClient({
accessToken,
logger: logSpy,
})

// First subscribe should work
await client.socket.setAuth()
const channel1 = client.socket.channel('test', { config: { private: true } })
channel1.state = CHANNEL_STATES.closed
await channel1.subscribe()

expect(client.socket.accessTokenValue).toBe(tokens[0])

// Remove and resubscribe - callback will fail but should fall back
await client.socket.removeChannel(channel1)

const channel2 = client.socket.channel('test', { config: { private: true } })
channel2.state = CHANNEL_STATES.closed
await channel2.subscribe()

// Verify error was logged
expect(logSpy).toHaveBeenCalledWith(
'error',
'Error fetching access token from callback',
expect.any(Error)
)

// Verify subscription still succeeded with cached token
expect(client.socket.accessTokenValue).toBe(tokens[0])

client.cleanup()
})
})
10 changes: 7 additions & 3 deletions packages/core/realtime-js/test/RealtimeClient.auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,12 @@ describe('auth during connection states', () => {

await new Promise((resolve) => setTimeout(() => resolve(undefined), 100))

// Verify that the error was logged
expect(logSpy).toHaveBeenCalledWith('error', 'error setting auth in connect', expect.any(Error))
// Verify that the error was logged with more specific message
expect(logSpy).toHaveBeenCalledWith(
'error',
'Error fetching access token from callback',
expect.any(Error)
)

// Verify that the connection was still established despite the error
assert.ok(socketWithError.conn, 'connection should still exist')
Expand Down Expand Up @@ -199,7 +203,7 @@ describe('auth during connection states', () => {
expect(socket.accessTokenValue).toBe(tokens[0])

// Call the callback and wait for async operations to complete
await socket.reconnectTimer.callback()
await socket.reconnectTimer?.callback()
await new Promise((resolve) => setTimeout(resolve, 100))
expect(socket.accessTokenValue).toBe(tokens[1])
expect(accessToken).toHaveBeenCalledTimes(2)
Expand Down
11 changes: 6 additions & 5 deletions packages/core/realtime-js/test/RealtimeClient.channels.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ describe('channel', () => {
const connectStub = vi.spyOn(testSetup.socket, 'connect')
const disconnectStub = vi.spyOn(testSetup.socket, 'disconnect')

channel = testSetup.socket.channel('topic').subscribe()
channel = testSetup.socket.channel('topic')
await channel.subscribe()

assert.equal(testSetup.socket.getChannels().length, 1)
expect(connectStub).toHaveBeenCalled()
Expand All @@ -118,11 +119,11 @@ describe('channel', () => {
test('does not remove other channels when removing one', async () => {
const connectStub = vi.spyOn(testSetup.socket, 'connect')
const disconnectStub = vi.spyOn(testSetup.socket, 'disconnect')
const channel1 = testSetup.socket.channel('chan1').subscribe()
const channel2 = testSetup.socket.channel('chan2').subscribe()
const channel1 = testSetup.socket.channel('chan1')
const channel2 = testSetup.socket.channel('chan2')

channel1.subscribe()
channel2.subscribe()
await channel1.subscribe()
await channel2.subscribe()
assert.equal(testSetup.socket.getChannels().length, 2)
expect(connectStub).toHaveBeenCalled()

Expand Down