Skip to content

Commit ef73b80

Browse files
committed
fix: add a health manager for restarting unhealthy mqtt connections
1 parent ba0d287 commit ef73b80

File tree

4 files changed

+139
-1
lines changed

4 files changed

+139
-1
lines changed

roborock/devices/v1_rpc_channel.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from roborock.data import RoborockBase
1515
from roborock.exceptions import RoborockException
16+
from roborock.mqtt.health_manager import HealthManager
1617
from roborock.protocols.v1_protocol import (
1718
CommandType,
1819
MapResponse,
@@ -125,12 +126,14 @@ def __init__(
125126
channel: MqttChannel | LocalChannel,
126127
payload_encoder: Callable[[RequestMessage], RoborockMessage],
127128
decoder: Callable[[RoborockMessage], ResponseMessage] | Callable[[RoborockMessage], MapResponse | None],
129+
health_manager: HealthManager | None = None,
128130
) -> None:
129131
"""Initialize the channel with a raw channel and an encoder function."""
130132
self._name = name
131133
self._channel = channel
132134
self._payload_encoder = payload_encoder
133135
self._decoder = decoder
136+
self._health_manager = health_manager
134137

135138
async def _send_raw_command(
136139
self,
@@ -165,13 +168,19 @@ def find_response(response_message: RoborockMessage) -> None:
165168
unsub = await self._channel.subscribe(find_response)
166169
try:
167170
await self._channel.publish(message)
168-
return await asyncio.wait_for(future, timeout=_TIMEOUT)
171+
result = await asyncio.wait_for(future, timeout=_TIMEOUT)
169172
except TimeoutError as ex:
173+
if self._health_manager:
174+
await self._health_manager.on_timeout()
170175
future.cancel()
171176
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
172177
finally:
173178
unsub()
174179

180+
if self._health_manager:
181+
await self._health_manager.on_success()
182+
return result
183+
175184

176185
def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel:
177186
"""Create a V1 RPC channel using an MQTT channel."""
@@ -180,6 +189,7 @@ def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityDa
180189
mqtt_channel,
181190
lambda x: x.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data),
182191
decode_rpc_response,
192+
health_manager=HealthManager(mqtt_channel.restart),
183193
)
184194

185195

roborock/mqtt/health_manager.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""A health manager for monitoring MQTT connections to Roborock devices.
2+
3+
We observe a problem where sometimes the MQTT connection appears to be alive but
4+
no messages are being received. To mitigate this, we track consecutive timeouts
5+
and restart the connection if too many timeouts occur in succession.
6+
"""
7+
8+
import datetime
9+
from collections.abc import Awaitable, Callable
10+
11+
# Number of consecutive timeouts before considering the connection unhealthy.
12+
TIMEOUT_THRESHOLD = 3
13+
14+
# We won't restart the session more often than this interval.
15+
RESTART_COOLDOWN = datetime.timedelta(minutes=30)
16+
17+
18+
class HealthManager:
    """Manager for monitoring the health of MQTT connections.

    This tracks communication timeouts and can trigger restarts of the MQTT
    session if too many timeouts occur in succession. Restarts are rate
    limited: once one has been started, no further restart is started until
    ``RESTART_COOLDOWN`` has elapsed.
    """

    def __init__(self, restart: Callable[[], Awaitable[None]]) -> None:
        """Initialize the health manager.

        Args:
            restart: A callable to restart the MQTT session.
        """
        self._consecutive_timeouts = 0
        self._restart = restart
        # Start time of the most recent restart; None until the first one.
        self._last_restart: datetime.datetime | None = None

    async def on_success(self) -> None:
        """Record a successful communication event."""
        self._consecutive_timeouts = 0

    async def on_timeout(self) -> None:
        """Record a timeout event.

        This may trigger a restart of the MQTT session if too many timeouts
        have occurred in succession. The timeout counter is only cleared by
        a completed restart or by ``on_success``; timeouts seen while the
        cooldown is active keep accumulating.

        Raises:
            Whatever the ``restart`` callable raises. In that case the
            cooldown is not started, so a later timeout may retry.
        """
        self._consecutive_timeouts += 1
        if self._consecutive_timeouts < TIMEOUT_THRESHOLD:
            return

        # Use an aware UTC timestamp so that DST or local timezone changes
        # cannot stretch or shrink the cooldown window.
        now = datetime.datetime.now(datetime.timezone.utc)
        if self._last_restart is not None and now - self._last_restart < RESTART_COOLDOWN:
            return

        # Claim the cooldown slot *before* awaiting. Several commands can
        # time out while the restart is still in flight, and each of them
        # would otherwise pass the cooldown check and start its own restart.
        previous_restart = self._last_restart
        self._last_restart = now
        try:
            await self._restart()
        except BaseException:
            # The restart did not complete; release the slot again.
            self._last_restart = previous_restart
            raise
        self._consecutive_timeouts = 0

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,13 +363,17 @@ def __init__(self):
363363
self.connect = AsyncMock(side_effect=self._connect)
364364
self.close = MagicMock(side_effect=self._close)
365365
self.protocol_version = LocalProtocolVersion.V1
366+
self.restart = MagicMock(side_effect=self._restart)
366367

367368
async def _connect(self) -> None:
    # Side effect wired to the ``connect`` AsyncMock: sets the connected flag.
    self._is_connected = True
369370

370371
def _close(self) -> None:
    # Side effect wired to the ``close`` MagicMock: clears the connected flag.
    self._is_connected = False
372373

374+
def _restart(self) -> None:
    # Side effect wired to the ``restart`` MagicMock: clears the connected flag.
    self._is_connected = False
376+
373377
@property
374378
def is_connected(self) -> bool:
375379
"""Return true if connected."""

tests/mqtt/test_health_manager.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Tests for the health manager."""
2+
3+
import datetime
4+
from unittest.mock import AsyncMock, patch
5+
6+
from roborock.mqtt.health_manager import HealthManager
7+
8+
9+
async def test_health_manager_restart_called_after_timeouts() -> None:
    """Verify restart fires only once the third consecutive timeout is reported."""
    restart_mock = AsyncMock()
    manager = HealthManager(restart=restart_mock)

    # Two timeouts in a row must not trigger a restart yet.
    for _ in range(2):
        await manager.on_timeout()
    restart_mock.assert_not_called()

    # The third consecutive timeout triggers exactly one restart.
    await manager.on_timeout()
    restart_mock.assert_called_once()
20+
21+
22+
async def test_health_manager_success_resets_counter() -> None:
    """Verify a success report clears the run of consecutive timeouts."""
    restart_mock = AsyncMock()
    manager = HealthManager(restart=restart_mock)

    # Build up two timeouts, then break the streak with a success.
    for _ in range(2):
        await manager.on_timeout()
    restart_mock.assert_not_called()

    await manager.on_success()

    # Two further timeouts start a fresh streak, so still no restart.
    for _ in range(2):
        await manager.on_timeout()
    restart_mock.assert_not_called()

    # The third timeout of the fresh streak triggers the restart.
    await manager.on_timeout()
    restart_mock.assert_called_once()
39+
40+
41+
async def test_cooldown() -> None:
    """Verify restarts are suppressed until the restart cooldown has elapsed."""
    restart_mock = AsyncMock()
    manager = HealthManager(restart=restart_mock)

    async def report_three_timeouts() -> None:
        # Three timeouts in a row, matching the pattern used for each phase.
        for _ in range(3):
            await manager.on_timeout()

    with patch("roborock.mqtt.health_manager.datetime") as mock_datetime:
        start = datetime.datetime(2023, 1, 1, 12, 0, 0)
        mock_datetime.datetime.now.return_value = start

        # Phase 1: the first run of timeouts triggers a restart.
        await report_three_timeouts()
        restart_mock.assert_called_once()
        restart_mock.reset_mock()

        # Phase 2: ten minutes later is still inside the 30 minute cooldown.
        ten_minutes_later = start + datetime.timedelta(minutes=10)
        mock_datetime.datetime.now.return_value = ten_minutes_later
        await report_three_timeouts()
        restart_mock.assert_not_called()

        # Phase 3: thirty-one minutes later the cooldown has expired.
        after_cooldown = start + datetime.timedelta(minutes=31)
        mock_datetime.datetime.now.return_value = after_cooldown
        await report_three_timeouts()
        restart_mock.assert_called_once()

0 commit comments

Comments
 (0)