Skip to content

Commit 9ea6b03

Browse files
committed
chore: Apply co-pilot feedback
1 parent c19abb2 commit 9ea6b03

File tree

6 files changed

+67
-8
lines changed

6 files changed

+67
-8
lines changed

roborock/devices/rpc/v1_channel.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -308,14 +308,17 @@ async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callab
308308
loop = asyncio.get_running_loop()
309309
self._reconnect_task = loop.create_task(self._background_reconnect())
310310

311-
# Always subscribe to MQTT to receive protocol updates (data points)
311+
# Always attempt to subscribe to MQTT to receive protocol updates (data points)
312312
# even if we have a local connection. Protocol updates only come via cloud/MQTT.
313313
# Local connection is used for RPC commands, but push notifications come via MQTT.
314-
self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
315-
if self.is_local_connected:
316-
self._logger.debug("V1Channel connected via local and MQTT (for protocol updates)")
317-
else:
318-
self._logger.debug("V1Channel connected via MQTT only")
314+
try:
315+
self._mqtt_unsub = await self._mqtt_channel.subscribe(self._on_mqtt_message)
316+
except RoborockException as err:
317+
if not self.is_local_connected:
318+
# Propagate error if both local and MQTT failed
319+
self._logger.debug("MQTT connection also failed: %s", err)
320+
raise
321+
self._logger.debug("MQTT subscription failed, continuing with local-only connection: %s", err)
319322

320323
def unsub() -> None:
321324
"""Unsubscribe from all messages."""

roborock/devices/traits/v1/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,15 @@ def _get_rpc_channel(self, trait: V1TraitMixin) -> V1RpcChannel:
235235

236236
async def start(self) -> None:
237237
"""Start the properties API and discover features."""
238+
if self._unsub:
239+
return
238240
await self.discover_features()
239241
self._unsub = self._add_dps_listener(self._on_dps_update)
240242

241243
def close(self) -> None:
242244
if self._unsub:
243245
self._unsub()
246+
self._unsub = None
244247

245248
def _on_dps_update(self, dps: dict[RoborockDataProtocol, Any]) -> None:
246249
"""Handle incoming messages from the device.

roborock/devices/traits/v1/consumeable.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
periodically, such as filters, brushes, etc.
55
"""
66

7+
import logging
78
from enum import StrEnum
89
from typing import Any, Self
910

@@ -18,6 +19,8 @@
1819
"ConsumableTrait",
1920
]
2021

22+
_LOGGER = logging.getLogger(__name__)
23+
2124
_DPS_CONVERTER = common.DpsDataConverter.from_dataclass(Consumable)
2225

2326

@@ -48,6 +51,11 @@ class ConsumableTrait(Consumable, common.V1TraitMixin, TraitUpdateListener):
4851
command = RoborockCommand.GET_CONSUMABLE
4952
converter = common.DefaultConverter(Consumable)
5053

54+
def __init__(self) -> None:
55+
"""Initialize the consumable trait."""
56+
super().__init__()
57+
TraitUpdateListener.__init__(self, logger=_LOGGER)
58+
5159
async def reset_consumable(self, consumable: ConsumableAttribute) -> None:
5260
"""Reset a specific consumable attribute on the device."""
5361
await self.rpc_channel.send_command(RoborockCommand.RESET_CONSUMABLE, params=[consumable.value])

roborock/protocols/v1_protocol.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,14 @@ def decode_rpc_response(message: RoborockMessage) -> ResponseMessage:
170170
response, as long as we can extract the request ID. This is so we can
171171
associate an API response with a request even if there was an error.
172172
"""
173-
if not (datapoints := _decode_dps_message(message)):
173+
if not message.payload:
174174
return ResponseMessage(request_id=message.seq, data={})
175175

176+
if (datapoints := _decode_dps_message(message)) is None:
177+
raise RoborockException(
178+
f"Invalid V1 message format: missing or invalid 'dps' in payload for {message.payload!r}"
179+
)
180+
176181
if not (data_point := datapoints.get(RoborockMessageProtocol.RPC_RESPONSE)):
177182
raise RoborockException(
178183
f"Invalid V1 message format: missing '{RoborockMessageProtocol.RPC_RESPONSE}' data point"

tests/devices/rpc/test_v1_channel.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
create_mqtt_encoder,
2424
)
2525
from roborock.protocols.v1_protocol import MapResponse, SecurityData, V1RpcChannel
26-
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
26+
from roborock.roborock_message import RoborockDataProtocol, RoborockMessage, RoborockMessageProtocol
2727
from roborock.roborock_typing import RoborockCommand
2828
from tests import mock_data
2929
from tests.fixtures.channel_fixtures import FakeChannel
@@ -580,3 +580,35 @@ async def test_v1_channel_send_map_command(
580580

581581
# Verify the result is the data from our mocked decoder
582582
assert result == decompressed_map_data
583+
584+
585+
async def test_v1_channel_add_dps_listener(
586+
v1_channel: V1Channel,
587+
mock_mqtt_channel: FakeChannel,
588+
) -> None:
589+
"""Test that DPS listeners receive decoded protocol updates from MQTT."""
590+
mock_mqtt_channel.response_queue.append(TEST_NETWORK_INFO_RESPONSE)
591+
await v1_channel.subscribe(Mock())
592+
593+
# Create a mock listener for DPS updates
594+
dps_listener = Mock()
595+
unsub_dps = v1_channel.add_dps_listener(dps_listener)
596+
597+
# Simulate an incoming MQTT message with data protocol payload.
598+
dps_payload = json.dumps({"dps": {"121": 5}}).encode()
599+
push_message = RoborockMessage(
600+
protocol=RoborockMessageProtocol.GENERAL_REQUEST,
601+
payload=dps_payload,
602+
)
603+
mock_mqtt_channel.notify_subscribers(push_message)
604+
605+
dps_listener.assert_called_once()
606+
called_args = dps_listener.call_args[0][0]
607+
assert called_args[RoborockDataProtocol.STATE] == 5
608+
609+
unsub_dps()
610+
611+
# Verify unsubscribe works
612+
dps_listener.reset_mock()
613+
v1_channel._on_mqtt_message(push_message)
614+
dps_listener.assert_not_called()

tests/fixtures/channel_fixtures.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,11 @@ async def _subscribe(self, callback: Callable[[RoborockMessage], None]) -> Calla
5151
"""Simulate subscribing to messages."""
5252
self.subscribers.append(callback)
5353
return lambda: self.subscribers.remove(callback)
54+
55+
def notify_subscribers(self, message: RoborockMessage) -> None:
56+
"""Notify subscribers of a message.
57+
58+
This can be used by tests to simulate the channel receiving a message.
59+
"""
60+
for subscriber in list(self.subscribers):
61+
subscriber(message)

0 commit comments

Comments
 (0)