88import datetime
99import logging
1010from collections .abc import Callable
11- from typing import TypeVar
11+ from dataclasses import dataclass
12+ from typing import Any , TypeVar
1213
1314from roborock .data import HomeDataDevice , NetworkInfo , RoborockBase , UserData
1415from roborock .exceptions import RoborockException
16+ from roborock .mqtt .health_manager import HealthManager
1517from roborock .mqtt .session import MqttParams , MqttSession
1618from roborock .protocols .v1_protocol import (
19+ CommandType ,
20+ MapResponse ,
21+ ParamsType ,
22+ RequestMessage ,
23+ ResponseData ,
24+ ResponseMessage ,
1725 SecurityData ,
26+ V1RpcChannel ,
27+ create_map_response_decoder ,
1828 create_security_data ,
29+ decode_rpc_response ,
1930)
20- from roborock .roborock_message import RoborockMessage
31+ from roborock .roborock_message import RoborockMessage , RoborockMessageProtocol
2132from roborock .roborock_typing import RoborockCommand
2233
2334from .cache import Cache
2435from .channel import Channel
2536from .local_channel import LocalChannel , LocalSession , create_local_session
2637from .mqtt_channel import MqttChannel
27- from .v1_rpc_channel import (
28- PickFirstAvailable ,
29- V1RpcChannel ,
30- create_local_rpc_channel ,
31- create_map_rpc_channel ,
32- create_mqtt_rpc_channel ,
33- )
3438
3539_LOGGER = logging .getLogger (__name__ )
3640
3741__all__ = [
38- "V1Channel " ,
42+ "create_v1_channel " ,
3943]
4044
4145_T = TypeVar ("_T" , bound = RoborockBase )
46+ _TIMEOUT = 10.0
47+
4248
4349# Exponential backoff parameters for reconnecting to local
4450MIN_RECONNECT_INTERVAL = datetime .timedelta (minutes = 1 )
5056LOCAL_CONNECTION_CHECK_INTERVAL = datetime .timedelta (seconds = 15 )
5157
5258
59+ @dataclass (frozen = True )
60+ class RpcStrategy :
61+ """Strategy for encoding/sending/decoding RPC commands."""
62+
63+ name : str # For debug logging
64+ channel : LocalChannel | MqttChannel
65+ encoder : Callable [[RequestMessage ], RoborockMessage ]
66+ decoder : Callable [[RoborockMessage ], ResponseMessage | MapResponse | None ]
67+ health_manager : HealthManager | None = None
68+
69+
70+ class RpcChannel (V1RpcChannel ):
71+ """Provides an RPC interface around a pub/sub transport channel."""
72+
73+ def __init__ (self , rpc_strategies : list [RpcStrategy ]) -> None :
74+ """Initialize the RpcChannel with on ordered list of strategies."""
75+ self ._rpc_strategies = rpc_strategies
76+
77+ async def send_command (
78+ self ,
79+ method : CommandType ,
80+ * ,
81+ response_type : type [_T ] | None = None ,
82+ params : ParamsType = None ,
83+ ) -> _T | Any :
84+ """Send a command and return either a decoded or parsed response."""
85+ request = RequestMessage (method , params = params )
86+
87+ # Try each channel in order until one succeeds
88+ last_exception = None
89+ for strategy in self ._rpc_strategies :
90+ try :
91+ decoded_response = await self ._send_rpc (strategy , request )
92+ except RoborockException as e :
93+ _LOGGER .warning ("Command %s failed on %s channel: %s" , method , strategy .name , e )
94+ last_exception = e
95+ except Exception as e :
96+ _LOGGER .exception ("Unexpected error sending command %s on %s channel" , method , strategy .name )
97+ last_exception = RoborockException (f"Unexpected error: { e } " )
98+ else :
99+ if response_type is not None :
100+ if not isinstance (decoded_response , dict ):
101+ raise RoborockException (
102+ f"Expected dict response to parse { response_type .__name__ } , got { type (decoded_response )} "
103+ )
104+ return response_type .from_dict (decoded_response )
105+ return decoded_response
106+
107+ raise last_exception or RoborockException ("No available connection to send command" )
108+
109+ @staticmethod
110+ async def _send_rpc (strategy : RpcStrategy , request : RequestMessage ) -> ResponseData | bytes :
111+ """Send a command and return a decoded response type.
112+
113+ This provides an RPC interface over a given channel strategy. The device
114+ channel only supports publish and subscribe, so this function handles
115+ associating requests with their corresponding responses.
116+ """
117+ future : asyncio .Future [ResponseData | bytes ] = asyncio .Future ()
118+ _LOGGER .debug (
119+ "Sending command (%s, request_id=%s): %s, params=%s" ,
120+ strategy .name ,
121+ request .request_id ,
122+ request .method ,
123+ request .params ,
124+ )
125+
126+ message = strategy .encoder (request )
127+
128+ def find_response (response_message : RoborockMessage ) -> None :
129+ try :
130+ decoded = strategy .decoder (response_message )
131+ except RoborockException as ex :
132+ _LOGGER .debug ("Exception while decoding message (%s): %s" , response_message , ex )
133+ return
134+ if decoded is None :
135+ return
136+ _LOGGER .debug ("Received response (%s, request_id=%s)" , strategy .name , decoded .request_id )
137+ if decoded .request_id == request .request_id :
138+ if isinstance (decoded , ResponseMessage ) and decoded .api_error :
139+ future .set_exception (decoded .api_error )
140+ else :
141+ future .set_result (decoded .data )
142+
143+ unsub = await strategy .channel .subscribe (find_response )
144+ try :
145+ await strategy .channel .publish (message )
146+ result = await asyncio .wait_for (future , timeout = _TIMEOUT )
147+ except TimeoutError as ex :
148+ if strategy .health_manager :
149+ await strategy .health_manager .on_timeout ()
150+ future .cancel ()
151+ raise RoborockException (f"Command timed out after { _TIMEOUT } s" ) from ex
152+ finally :
153+ unsub ()
154+ if strategy .health_manager :
155+ await strategy .health_manager .on_success ()
156+ return result
157+
158+
53159class V1Channel (Channel ):
54160 """Unified V1 protocol channel with automatic MQTT/local connection handling.
55161
@@ -66,23 +172,13 @@ def __init__(
66172 local_session : LocalSession ,
67173 cache : Cache ,
68174 ) -> None :
69- """Initialize the V1Channel.
70-
71- Args:
72- mqtt_channel: MQTT channel for cloud communication
73- local_session: Factory that creates LocalChannels for a hostname.
74- """
175+ """Initialize the V1Channel."""
75176 self ._device_uid = device_uid
177+ self ._security_data = security_data
76178 self ._mqtt_channel = mqtt_channel
77- self ._mqtt_rpc_channel = create_mqtt_rpc_channel ( mqtt_channel , security_data )
179+ self ._mqtt_health_manager = HealthManager ( self . _mqtt_channel . restart )
78180 self ._local_session = local_session
79181 self ._local_channel : LocalChannel | None = None
80- self ._local_rpc_channel : V1RpcChannel | None = None
81- # Prefer local, fallback to MQTT
82- self ._combined_rpc_channel = PickFirstAvailable (
83- [lambda : self ._local_rpc_channel , lambda : self ._mqtt_rpc_channel ]
84- )
85- self ._map_rpc_channel = create_map_rpc_channel (mqtt_channel , security_data )
86182 self ._mqtt_unsub : Callable [[], None ] | None = None
87183 self ._local_unsub : Callable [[], None ] | None = None
88184 self ._callback : Callable [[RoborockMessage ], None ] | None = None
@@ -107,18 +203,60 @@ def is_mqtt_connected(self) -> bool:
107203
108204 @property
109205 def rpc_channel (self ) -> V1RpcChannel :
110- """Return the combined RPC channel prefers local with a fallback to MQTT."""
111- return self ._combined_rpc_channel
206+ """Return the combined RPC channel that prefers local with a fallback to MQTT."""
207+ strategies = []
208+ if local_rpc_strategy := self ._create_local_rpc_strategy ():
209+ strategies .append (local_rpc_strategy )
210+ strategies .append (self ._create_mqtt_rpc_strategy ())
211+ return RpcChannel (strategies )
112212
113213 @property
114214 def mqtt_rpc_channel (self ) -> V1RpcChannel :
115- """Return the MQTT RPC channel."""
116- return self ._mqtt_rpc_channel
215+ """Return the MQTT-only RPC channel."""
216+ return RpcChannel ([ self ._create_mqtt_rpc_strategy ()])
117217
118218 @property
119219 def map_rpc_channel (self ) -> V1RpcChannel :
120220 """Return the map RPC channel used for fetching map content."""
121- return self ._map_rpc_channel
221+ decoder = create_map_response_decoder (security_data = self ._security_data )
222+ return RpcChannel ([self ._create_mqtt_rpc_strategy (decoder )])
223+
224+ def _create_local_rpc_strategy (self ) -> RpcStrategy | None :
225+ """Create the RPC strategy for local transport."""
226+ if self ._local_channel is None or not self .is_local_connected :
227+ return None
228+ return RpcStrategy (
229+ name = "local" ,
230+ channel = self ._local_channel ,
231+ encoder = self ._local_encoder ,
232+ decoder = decode_rpc_response ,
233+ )
234+
235+ def _local_encoder (self , x : RequestMessage ) -> RoborockMessage :
236+ """Encode a request message for local transport.
237+
238+ This will read the current local channel's protocol version which
239+ changes as the protocol version is discovered.
240+ """
241+ if self ._local_channel is None :
242+ raise ValueError ("Local channel unavailable for encoding" )
243+ return x .encode_message (
244+ RoborockMessageProtocol .GENERAL_REQUEST ,
245+ version = self ._local_channel .protocol_version ,
246+ )
247+
248+ def _create_mqtt_rpc_strategy (self , decoder : Callable [[RoborockMessage ], Any ] = decode_rpc_response ) -> RpcStrategy :
249+ """Create the RPC strategy for MQTT transport with optional custom decoder."""
250+ return RpcStrategy (
251+ name = "mqtt" ,
252+ channel = self ._mqtt_channel ,
253+ encoder = lambda x : x .encode_message (
254+ RoborockMessageProtocol .RPC_REQUEST ,
255+ security_data = self ._security_data ,
256+ ),
257+ decoder = decoder ,
258+ health_manager = self ._mqtt_health_manager ,
259+ )
122260
123261 async def subscribe (self , callback : Callable [[RoborockMessage ], None ]) -> Callable [[], None ]:
124262 """Subscribe to all messages from the device.
@@ -185,7 +323,7 @@ async def _get_networking_info(self, *, prefer_cache: bool = True) -> NetworkInf
185323 _LOGGER .debug ("Using cached network info for device %s" , self ._device_uid )
186324 return network_info
187325 try :
188- network_info = await self ._mqtt_rpc_channel .send_command (
326+ network_info = await self .mqtt_rpc_channel .send_command (
189327 RoborockCommand .GET_NETWORK_INFO , response_type = NetworkInfo
190328 )
191329 except RoborockException as e :
@@ -216,7 +354,6 @@ async def _local_connect(self, *, prefer_cache: bool = True) -> None:
216354 raise RoborockException (f"Error connecting to local device { self ._device_uid } : { e } " ) from e
217355 # Wire up the new channel
218356 self ._local_channel = local_channel
219- self ._local_rpc_channel = create_local_rpc_channel (self ._local_channel )
220357 self ._local_unsub = await self ._local_channel .subscribe (self ._on_local_message )
221358 _LOGGER .info ("Successfully connected to local device %s" , self ._device_uid )
222359
0 commit comments