diff --git a/src/lmstudio/_ws_impl.py b/src/lmstudio/_ws_impl.py index 8a16f09..1ae9b6d 100644 --- a/src/lmstudio/_ws_impl.py +++ b/src/lmstudio/_ws_impl.py @@ -282,6 +282,8 @@ def __init__( ws_url: str, auth_details: DictObject, log_context: LogEventContext | None = None, + max_reconnect_retries: int = 3, + initial_retry_delay: float = 1.0, ) -> None: self._auth_details = auth_details self._connection_attempted = asyncio.Event() @@ -295,6 +297,10 @@ def __init__( self._logger = logger = new_logger(type(self).__name__) logger.update_context(log_context, ws_url=ws_url) self._mux = MultiplexingManager(logger) + # Reconnection configuration + self._max_reconnect_retries = max_reconnect_retries + self._initial_retry_delay = initial_retry_delay + self._consecutive_failures = 0 async def connect(self) -> bool: """Connect websocket from the task manager's event loop.""" @@ -515,15 +521,48 @@ async def _process_next_message(self) -> bool: return await self._enqueue_message(message) async def _receive_messages(self) -> None: - """Process received messages until task is cancelled.""" + """Process received messages with automatic reconnection on failure.""" while True: try: await self._process_next_message() - except (LMStudioWebsocketError, HTTPXWSException): - if self._ws is not None and not self._ws_disconnected.is_set(): - # Websocket failed unexpectedly (rather than due to client shutdown) - self._logger.error("Websocket failed, terminating session.") - break + # this Reset failure counter on successful yeah + self._consecutive_failures = 0 + except (LMStudioWebsocketError, HTTPXWSException) as exc: + # and it will check if this was an intentional disconnect + if self._ws_disconnected.is_set(): + self._logger.debug("Websocket disconnected intentionally") + break + + # and this is for Increment failure counter + self._consecutive_failures += 1 + + # this wiill Check if we should attempt reconnection + if self._consecutive_failures > self._max_reconnect_retries: + self._logger.error( + f"Websocket failed after {self._max_reconnect_retries} reconnection attempts, " + "terminating session.", + consecutive_failures=self._consecutive_failures, + ) + break + + # Calculate exponential backoff delay + retry_delay = self._initial_retry_delay * (2 ** (self._consecutive_failures - 1)) + retry_delay = min(retry_delay, 30.0) # Cap at 30 seconds + + self._logger.warning( + f"Websocket error (attempt {self._consecutive_failures}/{self._max_reconnect_retries}), " + f"retrying in {retry_delay:.1f}s: {exc}", + consecutive_failures=self._consecutive_failures, + retry_delay=retry_delay, + error=str(exc), + ) + + # Wait before attempting to reconnect + await asyncio.sleep(retry_delay) + + # there is a note like The actual reconnection happens at a higher level + # This code allows the message loop to continue, giving the + # connection a chance to reestablish itself async def _enqueue_message(self, message: Any) -> bool: if message is None: diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index 7c33104..1ed6b1c 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -605,12 +605,22 @@ def acquire_channel_id(self, rx_queue: RxQueue) -> int: def release_channel_id(self, channel_id: int, rx_queue: RxQueue) -> None: """Release a previously acquired streaming channel ID.""" open_channels = self._open_channels - assigned_queue = open_channels.get(channel_id) - if rx_queue is not assigned_queue: - raise LMStudioRuntimeError( - f"Unexpected change to reply queue for channel ({channel_id} in {self!r})" + # this Use pop to safely remove the channel, even if already gone + assigned_queue = open_channels.pop(channel_id, None) + + # Make cleanup more forgiving log warnings instead of raising + if assigned_queue is None: + self._logger.warning( + f"Channel {channel_id} already released or never acquired", + channel_id=channel_id, + ) + elif rx_queue is not assigned_queue: + # Queue mismatch is suspicious but shouldn't prevent cleanup + self._logger.warning( + f"Channel {channel_id} queue mismatch during release " + f"(expected {rx_queue!r}, found {assigned_queue!r})", + channel_id=channel_id, ) - del open_channels[channel_id] @contextmanager def assign_channel_id(self, rx_queue: RxQueue) -> Generator[int, None, None]: @@ -636,12 +646,22 @@ def acquire_call_id(self, rx_queue: RxQueue) -> int: def release_call_id(self, call_id: int, rx_queue: RxQueue) -> None: """Release a previously acquired remote call ID.""" pending_calls = self._pending_calls - assigned_queue = pending_calls.get(call_id) - if rx_queue is not assigned_queue: - raise LMStudioRuntimeError( - f"Unexpected change to reply queue for remote call ({call_id} in {self!r})" + # Use pop to safely remove the call, even if already gone + assigned_queue = pending_calls.pop(call_id, None) + + # Make cleanup more forgiving log warnings instead of raising + if assigned_queue is None: + self._logger.warning( + f"Remote call {call_id} already released or never acquired", + call_id=call_id, + ) + elif rx_queue is not assigned_queue: + # Queue mismatch is suspicious but shouldn't prevent cleanup + self._logger.warning( + f"Remote call {call_id} queue mismatch during release " + f"(expected {rx_queue!r}, found {assigned_queue!r})", + call_id=call_id, ) - del pending_calls[call_id] @contextmanager def assign_call_id(self, rx_queue: RxQueue) -> Generator[int, None, None]: