Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion roborock/data/v1/v1_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ class AppInitStatus(RoborockBase):
new_feature_info_str: str
new_feature_info_2: int | None = None
carriage_type: int | None = None
dsp_version: int | None = None
Comment thread
Lash-L marked this conversation as resolved.
dsp_version: str | None = None


@dataclass
Expand Down
122 changes: 116 additions & 6 deletions roborock/devices/local_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,25 @@

from roborock.callbacks import CallbackList, decoder_callback
from roborock.exceptions import RoborockConnectionException, RoborockException
from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
from roborock.roborock_message import RoborockMessage
from roborock.protocol import create_local_decoder, create_local_encoder
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol

from ..protocols.v1_protocol import LocalProtocolVersion
from ..util import get_next_int
from .channel import Channel

_LOGGER = logging.getLogger(__name__)
_PORT = 58867
_TIMEOUT = 10.0
Comment thread
Lash-L marked this conversation as resolved.
Outdated


@dataclass
class LocalChannelParams:
"""Parameters for local channel encoder/decoder."""

local_key: str
connect_nonce: int
ack_nonce: int | None


@dataclass
Expand Down Expand Up @@ -45,12 +57,79 @@ def __init__(self, host: str, local_key: str):
self._protocol: _LocalProtocol | None = None
self._subscribers: CallbackList[RoborockMessage] = CallbackList(_LOGGER)
self._is_connected = False

self._decoder: Decoder = create_local_decoder(local_key)
self._encoder: Encoder = create_local_encoder(local_key)
self._local_key = local_key
Comment thread
Lash-L marked this conversation as resolved.
Outdated
self._local_protocol_version: LocalProtocolVersion | None = None
self._connect_nonce = get_next_int(10000, 32767)
self._ack_nonce: int | None = None
self._update_encoder_decoder()

def _update_encoder_decoder(self, params: LocalChannelParams | None = None):
if params is None:
params = LocalChannelParams(
local_key=self._local_key, connect_nonce=self._connect_nonce, ack_nonce=self._ack_nonce
)
self._encoder = create_local_encoder(
Comment thread
Lash-L marked this conversation as resolved.
local_key=params.local_key, connect_nonce=params.connect_nonce, ack_nonce=params.ack_nonce
)
self._decoder = create_local_decoder(
local_key=params.local_key, connect_nonce=params.connect_nonce, ack_nonce=params.ack_nonce
)
# Callback to decode messages and dispatch to subscribers
self._data_received: Callable[[bytes], None] = decoder_callback(self._decoder, self._subscribers, _LOGGER)

async def _do_hello(self, local_protocol_version: LocalProtocolVersion) -> LocalChannelParams | None:
"""Perform the initial handshaking and return encoder params if successful."""
_LOGGER.debug(
"Attempting to use the %s protocol for client %s...",
local_protocol_version,
self._host,
)
request = RoborockMessage(
protocol=RoborockMessageProtocol.HELLO_REQUEST,
version=local_protocol_version.encode(),
random=self._connect_nonce,
seq=1,
)
try:
response = await self.send_message(
roborock_message=request,
request_id=request.seq,
response_protocol=RoborockMessageProtocol.HELLO_RESPONSE,
)
_LOGGER.debug(
"Client %s speaks the %s protocol.",
self._host,
local_protocol_version,
)
return LocalChannelParams(
local_key=self._local_key, connect_nonce=self._connect_nonce, ack_nonce=response.random
)
except RoborockException as e:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a follow up: Is there a very specific specific error code or raised when the protocol is not supported? Would be nice if we could narrow down to a specific exception as right now the transport code is simply catching except Exception as err which is probably overly broad.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No response from the vac when protocol is not supported. So it just gives a timeout

_LOGGER.debug(
"Client %s did not respond or does not speak the %s protocol. %s",
self._host,
local_protocol_version,
e,
)
return None

async def _hello(self):
"""Send hello to the device to negotiate protocol."""
attempt_versions = [LocalProtocolVersion.V1, LocalProtocolVersion.L01]
if self._local_protocol_version:
# Sort to try the preferred version first
attempt_versions.sort(key=lambda v: v != self._local_protocol_version)

for version in attempt_versions:
params = await self._do_hello(version)
if params is not None:
self._ack_nonce = params.ack_nonce
self._local_protocol_version = version
self._update_encoder_decoder(params)
return

raise RoborockException("Failed to connect to device with any known protocol")

@property
def is_connected(self) -> bool:
"""Check if the channel is currently connected."""
Expand All @@ -62,7 +141,7 @@ def is_local_connected(self) -> bool:
return self._is_connected

async def connect(self) -> None:
"""Connect to the device."""
"""Connect to the device and negotiate protocol."""
if self._is_connected:
_LOGGER.warning("Already connected")
return
Expand All @@ -75,6 +154,14 @@ async def connect(self) -> None:
except OSError as e:
raise RoborockConnectionException(f"Failed to connect to {self._host}:{_PORT}") from e

# Perform protocol negotiation
try:
await self._hello()
except RoborockException:
# If protocol negotiation fails, clean up the connection state
self.close()
raise

def close(self) -> None:
"""Disconnect from the device."""
if self._transport:
Expand Down Expand Up @@ -113,6 +200,29 @@ async def publish(self, message: RoborockMessage) -> None:
logging.exception("Uncaught error sending command")
raise RoborockException(f"Failed to send message: {message}") from err

async def send_message(
Comment thread
Lash-L marked this conversation as resolved.
Outdated
self,
roborock_message: RoborockMessage,
request_id: int,
response_protocol: int,
) -> RoborockMessage:
"""Send a raw message and wait for a raw response."""
future: asyncio.Future[RoborockMessage] = asyncio.Future()

def find_response(response_message: RoborockMessage) -> None:
if response_message.protocol == response_protocol and response_message.seq == request_id:
future.set_result(response_message)

unsub = await self.subscribe(find_response)
try:
await self.publish(roborock_message)
return await asyncio.wait_for(future, timeout=_TIMEOUT)
except TimeoutError as ex:
future.cancel()
raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
finally:
unsub()


# This module provides a factory function to create LocalChannel instances.
#
Expand Down
1 change: 1 addition & 0 deletions roborock/mqtt/roborock_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ async def _run_task(self, start_future: asyncio.Future[None] | None) -> None:
# Reset backoff once we've successfully connected
self._backoff = MIN_BACKOFF_INTERVAL
self._healthy = True
_LOGGER.info("MQTT Session connected.")
Comment thread
Lash-L marked this conversation as resolved.
if start_future:
start_future.set_result(None)
start_future = None
Expand Down
8 changes: 8 additions & 0 deletions roborock/protocols/v1_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any

from roborock.data import RRiot
Expand All @@ -32,6 +33,13 @@
ParamsType = list | dict | int | None


class LocalProtocolVersion(StrEnum):
"""Supported local protocol versions. Different from vacuum protocol versions."""

L01 = "L01"
V1 = "1.0"


@dataclass(frozen=True, kw_only=True)
class SecurityData:
"""Security data included in the request for some V1 commands."""
Expand Down
10 changes: 1 addition & 9 deletions roborock/version_1_apis/roborock_local_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,19 @@
from asyncio import Lock, TimerHandle, Transport, get_running_loop
from collections.abc import Callable
from dataclasses import dataclass
from enum import StrEnum

from .. import CommandVacuumError, DeviceData, RoborockCommand
from ..api import RoborockClient
from ..exceptions import RoborockConnectionException, RoborockException, VacuumError
from ..protocol import create_local_decoder, create_local_encoder
from ..protocols.v1_protocol import RequestMessage
from ..protocols.v1_protocol import LocalProtocolVersion, RequestMessage
from ..roborock_message import RoborockMessage, RoborockMessageProtocol
from ..util import RoborockLoggerAdapter, get_next_int
from .roborock_client_v1 import CLOUD_REQUIRED, RoborockClientV1

_LOGGER = logging.getLogger(__name__)


class LocalProtocolVersion(StrEnum):
"""Supported local protocol versions. Different from vacuum protocol versions."""

L01 = "L01"
V1 = "1.0"


@dataclass
class _LocalProtocol(asyncio.Protocol):
"""Callbacks for the Roborock local client transport."""
Expand Down
48 changes: 44 additions & 4 deletions tests/devices/test_local_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

import pytest

from roborock.devices.local_channel import LocalChannel
from roborock.devices.local_channel import LocalChannel, LocalChannelParams
from roborock.exceptions import RoborockConnectionException
from roborock.protocol import create_local_decoder, create_local_encoder
from roborock.protocols.v1_protocol import LocalProtocolVersion
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol

TEST_HOST = "192.168.1.100"
Expand Down Expand Up @@ -56,9 +57,18 @@ def setup_mock_loop(mock_transport: Mock) -> Generator[Mock, None, None]:


@pytest.fixture(name="local_channel")
def setup_local_channel() -> LocalChannel:
"""Fixture to set up the local channel for tests."""
return LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY)
async def setup_local_channel_with_hello_mock() -> LocalChannel:
"""Fixture to set up the local channel with automatic hello mocking."""
channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY)

async def mock_do_hello(local_protocol_version):
Comment thread
Lash-L marked this conversation as resolved.
Outdated
"""Mock _do_hello to return successful params without sending actual request."""
return LocalChannelParams(local_key=channel._local_key, connect_nonce=channel._connect_nonce, ack_nonce=54321)

# Replace the _do_hello method
setattr(channel, "_do_hello", mock_do_hello)

return channel


@pytest.fixture(name="received_messages")
Expand Down Expand Up @@ -231,3 +241,33 @@ async def test_connection_lost_without_exception(
assert local_channel._is_connected is False
assert local_channel._transport is None
assert "Connection lost to 192.168.1.100" in caplog.text


async def test_hello_fallback_to_l01_protocol(mock_loop: Mock, mock_transport: Mock) -> None:
"""Test that when first hello() message fails (V1) but second succeeds (L01), we use L01."""

# Create a channel without the automatic hello mocking
channel = LocalChannel(host=TEST_HOST, local_key=TEST_LOCAL_KEY)

# Mock _do_hello to fail for V1 but succeed for L01
async def mock_do_hello(local_protocol_version: LocalProtocolVersion) -> LocalChannelParams | None:
if local_protocol_version == LocalProtocolVersion.V1:
# First attempt (V1) fails - return None to simulate failure
return None
elif local_protocol_version == LocalProtocolVersion.L01:
# Second attempt (L01) succeeds
return LocalChannelParams(
local_key=channel._local_key, connect_nonce=channel._connect_nonce, ack_nonce=54321
)
return None

# Replace the _do_hello method
setattr(channel, "_do_hello", mock_do_hello)

# Connect and verify L01 protocol is used
await channel.connect()

# Verify that the channel is using L01 protocol
assert channel._local_protocol_version == LocalProtocolVersion.L01
assert channel._ack_nonce == 54321
assert channel._is_connected is True
Loading