diff --git a/modules/ensemble/lib/framework/apiproviders/sse_api_provider.dart b/modules/ensemble/lib/framework/apiproviders/sse_api_provider.dart index 33df0d418..19a6af49b 100644 --- a/modules/ensemble/lib/framework/apiproviders/sse_api_provider.dart +++ b/modules/ensemble/lib/framework/apiproviders/sse_api_provider.dart @@ -27,6 +27,7 @@ class SSEAPIProvider extends APIProvider with LiveAPIProvider { final Map _subscriptions = {}; final Map _activeClients = {}; final Set _manuallyDisconnected = {}; + bool _disposed = false; @override Future init(String appId, Map config) async { @@ -206,7 +207,7 @@ class SSEAPIProvider extends APIProvider with LiveAPIProvider { ResponseListener listener, SSEOptions options, DataContext eContext) async { - int reconnectAttempts = 0; + final List reconnectAttempts = [0]; String? lastEventId; Future connect() async { @@ -236,7 +237,7 @@ class SSEAPIProvider extends APIProvider with LiveAPIProvider { } // Reset reconnect attempts on successful connection - reconnectAttempts = 0; + reconnectAttempts[0] = 0; // Parse SSE stream String? currentEventType; @@ -296,18 +297,20 @@ class SSEAPIProvider extends APIProvider with LiveAPIProvider { onError: (error) { log('SSE stream error: $error'); _handleSSEError(error, apiName, listener, options, - reconnectAttempts, () => connect(), url, headers, eContext); + reconnectAttempts, () => connect()); }, onDone: () { log('SSE stream closed'); // Attempt reconnection if enabled - if (options.autoReconnect && - reconnectAttempts < options.maxReconnectAttempts) { - reconnectAttempts++; + if (_shouldReconnect(apiName, options, reconnectAttempts[0])) { + reconnectAttempts[0]++; final delay = Duration( - milliseconds: options.reconnectDelay * reconnectAttempts); + milliseconds: options.reconnectDelay * reconnectAttempts[0]); Future.delayed(delay, () { - log('Reconnecting SSE (attempt $reconnectAttempts)...'); + if (!_shouldReconnect(apiName, options, reconnectAttempts[0])) { + return; + } + log('Reconnecting SSE (attempt ${reconnectAttempts[0]})...'); connect(); }); } else { @@ -328,7 +331,7 @@ class SSEAPIProvider extends APIProvider with LiveAPIProvider { _subscriptions[apiName] = subscription; } catch (error) { _handleSSEError(error, apiName, listener, options, reconnectAttempts, - () => connect(), url, headers, eContext); + () => connect()); } } @@ -374,16 +377,20 @@ class SSEAPIProvider extends APIProvider with LiveAPIProvider { } } + bool _shouldReconnect( + String apiName, SSEOptions options, int reconnectAttempts) => + !_disposed && + !_manuallyDisconnected.contains(apiName) && + options.autoReconnect && + reconnectAttempts < options.maxReconnectAttempts; + void _handleSSEError( dynamic error, String apiName, ResponseListener listener, SSEOptions options, - int reconnectAttempts, - VoidCallback reconnect, - String url, - Map headers, - DataContext eContext) { + List reconnectAttempts, + VoidCallback reconnect) { String errorMessage; if (error is HandshakeException || error is TlsException) { errorMessage = @@ -408,13 +415,15 @@ class SSEAPIProvider extends APIProvider with LiveAPIProvider { listener(errorResponse); // Attempt reconnection if enabled - if (options.autoReconnect && - reconnectAttempts < options.maxReconnectAttempts) { - reconnectAttempts++; - final delay = - Duration(milliseconds: options.reconnectDelay * reconnectAttempts); + if (_shouldReconnect(apiName, options, reconnectAttempts[0])) { + reconnectAttempts[0]++; + final delay = Duration( + milliseconds: options.reconnectDelay * reconnectAttempts[0]); Future.delayed(delay, () { - log('Reconnecting SSE after error (attempt $reconnectAttempts)...'); + if (!_shouldReconnect(apiName, options, reconnectAttempts[0])) { + return; + } + log('Reconnecting SSE after error (attempt ${reconnectAttempts[0]})...'); reconnect(); }); } @@ -530,6 +539,7 @@ class SSEAPIProvider extends APIProvider with LiveAPIProvider { @override dispose() { + _disposed = true; for (final apiName in _subscriptions.keys.toList()) { _manuallyDisconnected.add(apiName); } @@ -543,10 +553,13 @@ class SSEAPIProvider extends APIProvider with LiveAPIProvider { client.close(); } _activeClients.clear(); - - _manuallyDisconnected.clear(); } + @visibleForTesting + bool shouldReconnectForTesting( + String apiName, SSEOptions options, int reconnectAttempts) => + _shouldReconnect(apiName, options, reconnectAttempts); + @visibleForTesting int get subscriptionCountForTesting => _subscriptions.length; diff --git a/modules/ensemble/test/sse_provider_dispose_test.dart b/modules/ensemble/test/sse_provider_dispose_test.dart index e1201f786..a162179c2 100644 --- a/modules/ensemble/test/sse_provider_dispose_test.dart +++ b/modules/ensemble/test/sse_provider_dispose_test.dart @@ -49,4 +49,54 @@ void main() { expect(identical(providers.getProvider('sse'), sse), isTrue); }); }); + + group('SSEAPIProvider reconnect guards', () { + test('disconnect prevents auto-reconnect', () async { + final provider = SSEAPIProvider(); + await provider.disconnect('liveFeed'); + + expect( + provider.shouldReconnectForTesting('liveFeed', SSEOptions(), 0), + isFalse, + ); + }); + + test('dispose prevents auto-reconnect', () { + final provider = SSEAPIProvider(); + provider.dispose(); + + expect( + provider.shouldReconnectForTesting('liveFeed', SSEOptions(), 0), + isFalse, + ); + }); + + test('honors maxReconnectAttempts', () { + final provider = SSEAPIProvider(); + final options = SSEOptions(maxReconnectAttempts: 3); + + expect(provider.shouldReconnectForTesting('api', options, 0), isTrue); + expect(provider.shouldReconnectForTesting('api', options, 2), isTrue); + expect(provider.shouldReconnectForTesting('api', options, 3), isFalse); + }); + + test('shared reconnect counter stops after max error retries', () { + final provider = SSEAPIProvider(); + final options = SSEOptions(maxReconnectAttempts: 3); + final attempts = [0]; + + for (var i = 0; i < 3; i++) { + expect( + provider.shouldReconnectForTesting('api', options, attempts[0]), + isTrue, + ); + attempts[0]++; + } + + expect( + provider.shouldReconnectForTesting('api', options, attempts[0]), + isFalse, + ); + }); + }); }