Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion roborock/devices/mqtt_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,5 @@ async def restart(self) -> None:
def create_mqtt_channel(
user_data: UserData, mqtt_params: MqttParams, mqtt_session: MqttSession, device: HomeDataDevice
) -> MqttChannel:
"""Create a V1Channel for the given device."""
"""Create a MQTT channel for the given device."""
return MqttChannel(mqtt_session, device.duid, device.local_key, user_data.rriot, mqtt_params)
2 changes: 1 addition & 1 deletion roborock/devices/traits/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
from roborock.data.v1.v1_code_mappings import RoborockDockTypeCode
from roborock.devices.cache import Cache
from roborock.devices.traits import Trait
from roborock.devices.v1_rpc_channel import V1RpcChannel
from roborock.map.map_parser import MapParserConfig
from roborock.protocols.v1_protocol import V1RpcChannel
from roborock.web_api import UserWebApiClient

from .child_lock import ChildLockTrait
Expand Down
2 changes: 1 addition & 1 deletion roborock/devices/traits/v1/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import ClassVar, Self

from roborock.data import RoborockBase
from roborock.devices.v1_rpc_channel import V1RpcChannel
from roborock.protocols.v1_protocol import V1RpcChannel
from roborock.roborock_typing import RoborockCommand

_LOGGER = logging.getLogger(__name__)
Expand Down
209 changes: 186 additions & 23 deletions roborock/devices/v1_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,43 @@
import datetime
import logging
from collections.abc import Callable
from typing import TypeVar
from dataclasses import dataclass
from typing import Any, TypeVar, override

from roborock.data import HomeDataDevice, NetworkInfo, RoborockBase, UserData
from roborock.exceptions import RoborockException
from roborock.mqtt.health_manager import HealthManager
from roborock.mqtt.session import MqttParams, MqttSession
from roborock.protocols.v1_protocol import (
CommandType,
MapResponse,
ParamsType,
RequestMessage,
ResponseData,
ResponseMessage,
SecurityData,
V1RpcChannel,
create_map_response_decoder,
create_security_data,
decode_rpc_response,
)
from roborock.roborock_message import RoborockMessage
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
from roborock.roborock_typing import RoborockCommand

from .cache import Cache
from .channel import Channel
from .local_channel import LocalChannel, LocalSession, create_local_session
from .mqtt_channel import MqttChannel
from .v1_rpc_channel import (
PickFirstAvailable,
V1RpcChannel,
create_local_rpc_channel,
create_map_rpc_channel,
create_mqtt_rpc_channel,
)

_LOGGER = logging.getLogger(__name__)

__all__ = [
"V1Channel",
"create_v1_channel",
]

_T = TypeVar("_T", bound=RoborockBase)
_TIMEOUT = 10.0


# Exponential backoff parameters for reconnecting to local
MIN_RECONNECT_INTERVAL = datetime.timedelta(minutes=1)
Expand All @@ -50,6 +56,117 @@
LOCAL_CONNECTION_CHECK_INTERVAL = datetime.timedelta(seconds=15)


@dataclass(frozen=True)
class RpcStrategy:
"""Strategy for sending RPC commands over a specific channel.

This holds the configuration for a specific transport method that differ
Comment thread
allenporter marked this conversation as resolved.
Outdated
in how messages are encoded/decoded and which channel is used.
"""

name: str # For debug logging
channel: LocalChannel | MqttChannel
encoder: Callable[[RequestMessage], RoborockMessage]
decoder: Callable[[RoborockMessage], ResponseMessage | MapResponse | None]
health_manager: HealthManager | None = None


class RpcChannel(V1RpcChannel):
"""Wrapper to expose V1RpcChannel interface with a specific set of RpcStrategies.

This is used to provide a simpler interface to v1 traits for sending commands
over multiple possible transports (local, MQTT) with automatic fallback.
"""

def __init__(self, rpc_strategies: list[RpcStrategy]) -> None:
self._rpc_strategies = rpc_strategies

@override
async def send_command(
self,
method: CommandType,
*,
response_type: type[_T] | None = None,
params: ParamsType = None,
) -> _T | Any:
"""Send a command and return either a decoded or parsed response."""
request = RequestMessage(method, params=params)

# Try each channel in order until one succeeds
last_exception = None
for strategy in self._rpc_strategies:
try:
decoded_response = await self._send_rpc(strategy, request)
except RoborockException as e:
_LOGGER.warning("Command %s failed on %s channel: %s", method, strategy.name, e)
last_exception = e
except Exception as e:
_LOGGER.exception("Unexpected error sending command %s on %s channel", method, strategy.name)
last_exception = RoborockException(f"Unexpected error: {e}")
else:
if response_type is not None:
if not isinstance(decoded_response, dict):
raise RoborockException(
f"Expected dict response to parse {response_type.__name__}, got {type(decoded_response)}"
)
return response_type.from_dict(decoded_response)
return decoded_response

raise last_exception or RoborockException("No available connection to send command")

@staticmethod
async def _send_rpc(strategy: RpcStrategy, request: RequestMessage) -> ResponseData | bytes:
"""Send a command and return a parsed response RoborockBase type.

This provides an RPC interface over a given channel strategy. The device
channel only supports publish and subscribe, so this function handles
associating requests with their corresponding responses.

The provided RpcStrategy defines how to encode/decode messages and which
channel to use for communication.
"""
future: asyncio.Future[ResponseData | bytes] = asyncio.Future()
Comment thread
allenporter marked this conversation as resolved.
_LOGGER.debug(
"Sending command (%s, request_id=%s): %s, params=%s",
strategy.name,
request.request_id,
request.method,
request.params,
)

message = strategy.encoder(request)

def find_response(response_message: RoborockMessage) -> None:
try:
decoded = strategy.decoder(response_message)
except RoborockException as ex:
_LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex)
return
if decoded is None:
return
_LOGGER.debug("Received response (%s, request_id=%s)", strategy.name, decoded.request_id)
if decoded.request_id == request.request_id:
if isinstance(decoded, ResponseMessage) and decoded.api_error:
future.set_exception(decoded.api_error)
else:
future.set_result(decoded.data)

unsub = await strategy.channel.subscribe(find_response)
try:
await strategy.channel.publish(message)
result = await asyncio.wait_for(future, timeout=_TIMEOUT)
except TimeoutError as ex:
if strategy.health_manager:
await strategy.health_manager.on_timeout()
future.cancel()
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
finally:
unsub()
if strategy.health_manager:
await strategy.health_manager.on_success()
return result


class V1Channel(Channel):
"""Unified V1 protocol channel with automatic MQTT/local connection handling.

Expand All @@ -69,20 +186,17 @@ def __init__(
"""Initialize the V1Channel.

Args:
device_uid: Unique device identifier (DUID).
mqtt_channel: MQTT channel for cloud communication
local_session: Factory that creates LocalChannels for a hostname.
cache: Cache for storing network information.
"""
self._device_uid = device_uid
self._security_data = security_data
self._mqtt_channel = mqtt_channel
self._mqtt_rpc_channel = create_mqtt_rpc_channel(mqtt_channel, security_data)
self._mqtt_health_manager = HealthManager(self._mqtt_channel.restart)
self._local_session = local_session
self._local_channel: LocalChannel | None = None
self._local_rpc_channel: V1RpcChannel | None = None
# Prefer local, fallback to MQTT
self._combined_rpc_channel = PickFirstAvailable(
[lambda: self._local_rpc_channel, lambda: self._mqtt_rpc_channel]
)
self._map_rpc_channel = create_map_rpc_channel(mqtt_channel, security_data)
self._mqtt_unsub: Callable[[], None] | None = None
self._local_unsub: Callable[[], None] | None = None
self._callback: Callable[[RoborockMessage], None] | None = None
Comment thread
allenporter marked this conversation as resolved.
Expand All @@ -108,17 +222,67 @@ def is_mqtt_connected(self) -> bool:
@property
def rpc_channel(self) -> V1RpcChannel:
"""Return the combined RPC channel prefers local with a fallback to MQTT."""
return self._combined_rpc_channel
strategies = []
if local_rpc_strategy := self._create_local_rpc_strategy():
strategies.append(local_rpc_strategy)
strategies.append(self._create_mqtt_rpc_strategy())
return RpcChannel(strategies)

@property
def mqtt_rpc_channel(self) -> V1RpcChannel:
"""Return the MQTT RPC channel."""
return self._mqtt_rpc_channel
"""Return the MQTT-only RPC channel."""
Comment thread
allenporter marked this conversation as resolved.
return RpcChannel([self._create_mqtt_rpc_strategy()])

@property
def map_rpc_channel(self) -> V1RpcChannel:
"""Return the map RPC channel used for fetching map content."""
return self._map_rpc_channel
decoder = create_map_response_decoder(security_data=self._security_data)
return RpcChannel([self._create_mqtt_rpc_strategy(decoder)])

def _create_local_rpc_strategy(self) -> RpcStrategy | None:
"""Create the RPC strategy for local transport."""
if self._local_channel is None or not self.is_local_connected:
return None
return RpcStrategy(
name="local",
channel=self._local_channel,
encoder=self._local_encoder,
decoder=decode_rpc_response,
)

def _local_encoder(self, x: RequestMessage) -> RoborockMessage:
"""Encode a request message for local transport.

This is passed to the RpcStrategy as a function so that it will
read the current local channel's protocol version which changes as
the protocol version is discovered.
"""
if self._local_channel is None:
# This is for typing and should not happen since we only create the
# strategy if local is connected and it will never get set back to
# None once connected.
raise ValueError("Local channel is not available for encoding")
return x.encode_message(
RoborockMessageProtocol.GENERAL_REQUEST,
version=self._local_channel.protocol_version,
)

def _create_mqtt_rpc_strategy(self, decoder: Callable[[RoborockMessage], Any] = decode_rpc_response) -> RpcStrategy:
"""Create the RPC strategy for MQTT transport.

This can optionally take a custom decoder for different response types
such as map data.
"""
return RpcStrategy(
name="mqtt",
channel=self._mqtt_channel,
encoder=lambda x: x.encode_message(
RoborockMessageProtocol.RPC_REQUEST,
security_data=self._security_data,
),
decoder=decoder,
health_manager=self._mqtt_health_manager,
)

async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
"""Subscribe to all messages from the device.
Expand Down Expand Up @@ -185,7 +349,7 @@ async def _get_networking_info(self, *, prefer_cache: bool = True) -> NetworkInf
_LOGGER.debug("Using cached network info for device %s", self._device_uid)
return network_info
try:
network_info = await self._mqtt_rpc_channel.send_command(
network_info = await self.mqtt_rpc_channel.send_command(
RoborockCommand.GET_NETWORK_INFO, response_type=NetworkInfo
)
except RoborockException as e:
Expand Down Expand Up @@ -216,7 +380,6 @@ async def _local_connect(self, *, prefer_cache: bool = True) -> None:
raise RoborockException(f"Error connecting to local device {self._device_uid}: {e}") from e
# Wire up the new channel
self._local_channel = local_channel
self._local_rpc_channel = create_local_rpc_channel(self._local_channel)
self._local_unsub = await self._local_channel.subscribe(self._on_local_message)
_LOGGER.info("Successfully connected to local device %s", self._device_uid)

Expand Down
Loading
Loading