Skip to content

Commit f3487e8

Browse files
authored
chore: refactor v1 rpc channels (#609)
* chore: refactor v1 rpc channels This merges the two packages v1_channel and v1_rpc_channel since they are very closely related. The differences between the RPC types are now handled with `RpcStrategy` (encoding, decoding, which channel, health checking, etc). The "PayloadEncodedV1RpcChannel" is now `_send_rpc` which accepts the rpc strategy. The `V1RpcChannel` interface is moved to the `v1_protocol` module and now has a single implementation that handles everyting (the `PickFirstAvailable` logic as well as the response parsing). Overall the code is now less generalized, but is probably easier to understand since its all in a single place. Notably, all the rpc code was already tested via the v1_channel interface. * chore: fix v1 channel typing and improve readability * chore: remove unnecessary pydoc on private members * chore: remove unnecessary pydoc to make the code more compact * chore: Improve doc string readability and grammar * chore: remove unnecessary docstrings * fix: remove python 3.11 incompatibility
1 parent 212dc24 commit f3487e8

File tree

8 files changed

+228
-268
lines changed

8 files changed

+228
-268
lines changed

roborock/devices/mqtt_channel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,5 +90,5 @@ async def restart(self) -> None:
9090
def create_mqtt_channel(
9191
user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice
9292
) -> MqttChannel:
93-
"""Create a V1Channel for the given device."""
93+
"""Create a MQTT channel for the given device."""
9494
return MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)

roborock/devices/traits/v1/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838
from roborock.data.v1.v1_code_mappings import RoborockDockTypeCode
3939
from roborock.devices.cache import Cache
4040
from roborock.devices.traits import Trait
41-
from roborock.devices.v1_rpc_channel import V1RpcChannel
4241
from roborock.map.map_parser import MapParserConfig
42+
from roborock.protocols.v1_protocol import V1RpcChannel
4343
from roborock.web_api import UserWebApiClient
4444

4545
from .child_lock import ChildLockTrait

roborock/devices/traits/v1/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import ClassVar, Self
1010

1111
from roborock.data import RoborockBase
12-
from roborock.devices.v1_rpc_channel import V1RpcChannel
12+
from roborock.protocols.v1_protocol import V1RpcChannel
1313
from roborock.roborock_typing import RoborockCommand
1414

1515
_LOGGER = logging.getLogger(__name__)

roborock/devices/v1_channel.py

Lines changed: 167 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,43 @@
88
import datetime
99
import logging
1010
from collections.abc import Callable
11-
from typing import TypeVar
11+
from dataclasses import dataclass
12+
from typing import Any, TypeVar
1213

1314
from roborock.data import HomeDataDevice, NetworkInfo, RoborockBase, UserData
1415
from roborock.exceptions import RoborockException
16+
from roborock.mqtt.health_manager import HealthManager
1517
from roborock.mqtt.session import MqttParams, MqttSession
1618
from roborock.protocols.v1_protocol import (
19+
CommandType,
20+
MapResponse,
21+
ParamsType,
22+
RequestMessage,
23+
ResponseData,
24+
ResponseMessage,
1725
SecurityData,
26+
V1RpcChannel,
27+
create_map_response_decoder,
1828
create_security_data,
29+
decode_rpc_response,
1930
)
20-
from roborock.roborock_message import RoborockMessage
31+
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
2132
from roborock.roborock_typing import RoborockCommand
2233

2334
from .cache import Cache
2435
from .channel import Channel
2536
from .local_channel import LocalChannel, LocalSession, create_local_session
2637
from .mqtt_channel import MqttChannel
27-
from .v1_rpc_channel import (
28-
PickFirstAvailable,
29-
V1RpcChannel,
30-
create_local_rpc_channel,
31-
create_map_rpc_channel,
32-
create_mqtt_rpc_channel,
33-
)
3438

3539
_LOGGER = logging.getLogger(__name__)
3640

3741
__all__ = [
38-
"V1Channel",
42+
"create_v1_channel",
3943
]
4044

4145
_T = TypeVar("_T", bound=RoborockBase)
46+
_TIMEOUT = 10.0
47+
4248

4349
# Exponential backoff parameters for reconnecting to local
4450
MIN_RECONNECT_INTERVAL = datetime.timedelta(minutes=1)
@@ -50,6 +56,106 @@
5056
LOCAL_CONNECTION_CHECK_INTERVAL = datetime.timedelta(seconds=15)
5157

5258

59+
@dataclass(frozen=True)
60+
class RpcStrategy:
61+
"""Strategy for encoding/sending/decoding RPC commands."""
62+
63+
name: str # For debug logging
64+
channel: LocalChannel | MqttChannel
65+
encoder: Callable[[RequestMessage], RoborockMessage]
66+
decoder: Callable[[RoborockMessage], ResponseMessage | MapResponse | None]
67+
health_manager: HealthManager | None = None
68+
69+
70+
class RpcChannel(V1RpcChannel):
71+
"""Provides an RPC interface around a pub/sub transport channel."""
72+
73+
def __init__(self, rpc_strategies: list[RpcStrategy]) -> None:
74+
"""Initialize the RpcChannel with on ordered list of strategies."""
75+
self._rpc_strategies = rpc_strategies
76+
77+
async def send_command(
78+
self,
79+
method: CommandType,
80+
*,
81+
response_type: type[_T] | None = None,
82+
params: ParamsType = None,
83+
) -> _T | Any:
84+
"""Send a command and return either a decoded or parsed response."""
85+
request = RequestMessage(method, params=params)
86+
87+
# Try each channel in order until one succeeds
88+
last_exception = None
89+
for strategy in self._rpc_strategies:
90+
try:
91+
decoded_response = await self._send_rpc(strategy, request)
92+
except RoborockException as e:
93+
_LOGGER.warning("Command %s failed on %s channel: %s", method, strategy.name, e)
94+
last_exception = e
95+
except Exception as e:
96+
_LOGGER.exception("Unexpected error sending command %s on %s channel", method, strategy.name)
97+
last_exception = RoborockException(f"Unexpected error: {e}")
98+
else:
99+
if response_type is not None:
100+
if not isinstance(decoded_response, dict):
101+
raise RoborockException(
102+
f"Expected dict response to parse {response_type.__name__}, got {type(decoded_response)}"
103+
)
104+
return response_type.from_dict(decoded_response)
105+
return decoded_response
106+
107+
raise last_exception or RoborockException("No available connection to send command")
108+
109+
@staticmethod
110+
async def _send_rpc(strategy: RpcStrategy, request: RequestMessage) -> ResponseData | bytes:
111+
"""Send a command and return a decoded response type.
112+
113+
This provides an RPC interface over a given channel strategy. The device
114+
channel only supports publish and subscribe, so this function handles
115+
associating requests with their corresponding responses.
116+
"""
117+
future: asyncio.Future[ResponseData | bytes] = asyncio.Future()
118+
_LOGGER.debug(
119+
"Sending command (%s, request_id=%s): %s, params=%s",
120+
strategy.name,
121+
request.request_id,
122+
request.method,
123+
request.params,
124+
)
125+
126+
message = strategy.encoder(request)
127+
128+
def find_response(response_message: RoborockMessage) -> None:
129+
try:
130+
decoded = strategy.decoder(response_message)
131+
except RoborockException as ex:
132+
_LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex)
133+
return
134+
if decoded is None:
135+
return
136+
_LOGGER.debug("Received response (%s, request_id=%s)", strategy.name, decoded.request_id)
137+
if decoded.request_id == request.request_id:
138+
if isinstance(decoded, ResponseMessage) and decoded.api_error:
139+
future.set_exception(decoded.api_error)
140+
else:
141+
future.set_result(decoded.data)
142+
143+
unsub = await strategy.channel.subscribe(find_response)
144+
try:
145+
await strategy.channel.publish(message)
146+
result = await asyncio.wait_for(future, timeout=_TIMEOUT)
147+
except TimeoutError as ex:
148+
if strategy.health_manager:
149+
await strategy.health_manager.on_timeout()
150+
future.cancel()
151+
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
152+
finally:
153+
unsub()
154+
if strategy.health_manager:
155+
await strategy.health_manager.on_success()
156+
return result
157+
158+
53159
class V1Channel(Channel):
54160
"""Unified V1 protocol channel with automatic MQTT/local connection handling.
55161
@@ -66,23 +172,13 @@ def __init__(
66172
local_session: LocalSession,
67173
cache: Cache,
68174
) -> None:
69-
"""Initialize the V1Channel.
70-
71-
Args:
72-
mqtt_channel: MQTT channel for cloud communication
73-
local_session: Factory that creates LocalChannels for a hostname.
74-
"""
175+
"""Initialize the V1Channel."""
75176
self._device_uid = device_uid
177+
self._security_data = security_data
76178
self._mqtt_channel = mqtt_channel
77-
self._mqtt_rpc_channel = create_mqtt_rpc_channel(mqtt_channel, security_data)
179+
self._mqtt_health_manager = HealthManager(self._mqtt_channel.restart)
78180
self._local_session = local_session
79181
self._local_channel: LocalChannel | None = None
80-
self._local_rpc_channel: V1RpcChannel | None = None
81-
# Prefer local, fallback to MQTT
82-
self._combined_rpc_channel = PickFirstAvailable(
83-
[lambda: self._local_rpc_channel, lambda: self._mqtt_rpc_channel]
84-
)
85-
self._map_rpc_channel = create_map_rpc_channel(mqtt_channel, security_data)
86182
self._mqtt_unsub: Callable[[], None] | None = None
87183
self._local_unsub: Callable[[], None] | None = None
88184
self._callback: Callable[[RoborockMessage], None] | None = None
@@ -107,18 +203,60 @@ def is_mqtt_connected(self) -> bool:
107203

108204
@property
109205
def rpc_channel(self) -> V1RpcChannel:
110-
"""Return the combined RPC channel prefers local with a fallback to MQTT."""
111-
return self._combined_rpc_channel
206+
"""Return the combined RPC channel that prefers local with a fallback to MQTT."""
207+
strategies = []
208+
if local_rpc_strategy := self._create_local_rpc_strategy():
209+
strategies.append(local_rpc_strategy)
210+
strategies.append(self._create_mqtt_rpc_strategy())
211+
return RpcChannel(strategies)
112212

113213
@property
114214
def mqtt_rpc_channel(self) -> V1RpcChannel:
115-
"""Return the MQTT RPC channel."""
116-
return self._mqtt_rpc_channel
215+
"""Return the MQTT-only RPC channel."""
216+
return RpcChannel([self._create_mqtt_rpc_strategy()])
117217

118218
@property
119219
def map_rpc_channel(self) -> V1RpcChannel:
120220
"""Return the map RPC channel used for fetching map content."""
121-
return self._map_rpc_channel
221+
decoder = create_map_response_decoder(security_data=self._security_data)
222+
return RpcChannel([self._create_mqtt_rpc_strategy(decoder)])
223+
224+
def _create_local_rpc_strategy(self) -> RpcStrategy | None:
225+
"""Create the RPC strategy for local transport."""
226+
if self._local_channel is None or not self.is_local_connected:
227+
return None
228+
return RpcStrategy(
229+
name="local",
230+
channel=self._local_channel,
231+
encoder=self._local_encoder,
232+
decoder=decode_rpc_response,
233+
)
234+
235+
def _local_encoder(self, x: RequestMessage) -> RoborockMessage:
236+
"""Encode a request message for local transport.
237+
238+
This will read the current local channel's protocol version which
239+
changes as the protocol version is discovered.
240+
"""
241+
if self._local_channel is None:
242+
raise ValueError("Local channel unavailable for encoding")
243+
return x.encode_message(
244+
RoborockMessageProtocol.GENERAL_REQUEST,
245+
version=self._local_channel.protocol_version,
246+
)
247+
248+
def _create_mqtt_rpc_strategy(self, decoder: Callable[[RoborockMessage], Any] = decode_rpc_response) -> RpcStrategy:
249+
"""Create the RPC strategy for MQTT transport with optional custom decoder."""
250+
return RpcStrategy(
251+
name="mqtt",
252+
channel=self._mqtt_channel,
253+
encoder=lambda x: x.encode_message(
254+
RoborockMessageProtocol.RPC_REQUEST,
255+
security_data=self._security_data,
256+
),
257+
decoder=decoder,
258+
health_manager=self._mqtt_health_manager,
259+
)
122260

123261
async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
124262
"""Subscribe to all messages from the device.
@@ -185,7 +323,7 @@ async def _get_networking_info(self, *, prefer_cache: bool = True) -> NetworkInf
185323
_LOGGER.debug("Using cached network info for device %s", self._device_uid)
186324
return network_info
187325
try:
188-
network_info = await self._mqtt_rpc_channel.send_command(
326+
network_info = await self.mqtt_rpc_channel.send_command(
189327
RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo
190328
)
191329
except RoborockException as e:
@@ -216,7 +354,6 @@ async def _local_connect(self, *, prefer_cache: bool = True) -> None:
216354
raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e
217355
# Wire up the new channel
218356
self._local_channel = local_channel
219-
self._local_rpc_channel = create_local_rpc_channel(self._local_channel)
220357
self._local_unsub = await self._local_channel.subscribe(self._on_local_message)
221358
_LOGGER.info("Successfully connected to local device %s", self._device_uid)
222359

0 commit comments

Comments
 (0)