Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions roborock/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
until the API is stable.
"""

import asyncio
import datetime
import logging
from abc import ABC
from collections.abc import Callable, Mapping
from typing import Any, TypeVar, cast

from roborock.data import HomeDataDevice, HomeDataProduct
from roborock.exceptions import RoborockException
from roborock.roborock_message import RoborockMessage

from .channel import Channel
Expand All @@ -22,6 +25,11 @@
"RoborockDevice",
]

# Exponential backoff parameters
MIN_BACKOFF_INTERVAL = datetime.timedelta(seconds=10)
MAX_BACKOFF_INTERVAL = datetime.timedelta(minutes=30)
BACKOFF_MULTIPLIER = 1.5


class RoborockDevice(ABC, TraitsMixin):
"""A generic channel for establishing a connection with a Roborock device.
Expand Down Expand Up @@ -54,6 +62,7 @@ def __init__(
self._device_info = device_info
self._product = product
self._channel = channel
self._connect_task: asyncio.Task[None] | None = None
self._unsub: Callable[[], None] | None = None

@property
Expand Down Expand Up @@ -98,6 +107,31 @@ def is_local_connected(self) -> bool:
"""
return self._channel.is_local_connected

def start_connect(self) -> None:
"""Start a background task to connect to the device.

This will attempt to connect to the device using the appropriate protocol
channel. If the connection fails, it will retry with exponential backoff.

Once connected, the device will remain connected until `close()` is
called. The device will automatically attempt to reconnect if the connection
is lost.
"""

async def connect_loop() -> None:
backoff = MIN_BACKOFF_INTERVAL
while True:
try:
await self.connect()
return
except RoborockException as e:
_LOGGER.info("Failed to connect to device %s: %s", self.name, e)
_LOGGER.info("Retrying connection to device %s in %s seconds", self.name, backoff.total_seconds())
await asyncio.sleep(backoff.total_seconds())
backoff = min(backoff * BACKOFF_MULTIPLIER, MAX_BACKOFF_INTERVAL)
Comment thread
allenporter marked this conversation as resolved.
Outdated

self._connect_task = asyncio.create_task(connect_loop())

async def connect(self) -> None:
"""Connect to the device using the appropriate protocol channel."""
if self._unsub:
Expand All @@ -107,6 +141,8 @@ async def connect(self) -> None:

async def close(self) -> None:
"""Close all connections to the device."""
if self._connect_task:
self._connect_task.cancel()
Comment thread
allenporter marked this conversation as resolved.
if self._unsub:
self._unsub()
self._unsub = None
Expand Down
2 changes: 1 addition & 1 deletion roborock/devices/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ async def discover_devices(self) -> list[RoborockDevice]:
if duid in self._devices:
continue
new_device = self._device_creator(home_data, device, product)
await new_device.connect()
new_device.start_connect()
new_devices[duid] = new_device

self._devices.update(new_devices)
Expand Down
61 changes: 61 additions & 0 deletions tests/devices/test_device_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for the DeviceManager class."""

import datetime
import asyncio
from collections.abc import Generator, Iterator
from unittest.mock import AsyncMock, Mock, patch

Expand Down Expand Up @@ -34,6 +36,25 @@ def channel_fixture() -> Generator[Mock, None, None]:
yield mock_channel


@pytest.fixture(autouse=True)
def mock_sleep() -> Generator[Mock, None, None]:
"""Mock asyncio.sleep in device module to speed up tests."""
sleep_time = datetime.timedelta(seconds=0.001)
with patch("roborock.devices.device.MIN_BACKOFF_INTERVAL", sleep_time), patch("roborock.devices.device.MAX_BACKOFF_INTERVAL", sleep_time):
Comment thread
allenporter marked this conversation as resolved.
Outdated
yield


@pytest.fixture(name="channel_failure")
def channel_failure_fixture() -> Generator[Mock, None, None]:
"""Fixture that makes channel subscribe fail."""
with patch("roborock.devices.device_manager.create_v1_channel") as mock_channel:
mock_channel.return_value.subscribe = AsyncMock(
side_effect=RoborockException("Connection failed")
)
mock_channel.return_value.is_connected = False
yield mock_channel


@pytest.fixture(name="home_data_no_devices")
def home_data_no_devices_fixture() -> Iterator[HomeData]:
"""Mock home data API that returns no devices."""
Expand Down Expand Up @@ -127,3 +148,43 @@ async def mock_home_data_with_counter(*args, **kwargs) -> HomeData:
assert len(devices2) == 1

await device_manager.close()


Comment thread
allenporter marked this conversation as resolved.
async def test_start_connect_failure(home_data: HomeData, channel_failure: Mock, mock_sleep: Mock) -> None:
"""Test that start_connect retries when connection fails."""
device_manager = await create_device_manager(USER_PARAMS)
devices = await device_manager.get_devices()

# Wait for the device to attempt to connect
attempts = 0
subscribe_mock = channel_failure.return_value.subscribe
while subscribe_mock.call_count < 1:
await asyncio.sleep(0.01)
attempts += 1
assert attempts < 10, "Device did not connect after multiple attempts"

# Device should exist but not be connected
assert len(devices) == 1
assert not devices[0].is_connected

# Verify retry attempts
assert channel_failure.return_value.subscribe.call_count >= 1

# Reset the mock channel so that it succeeds on the next attempt
mock_unsub = Mock()
subscribe_mock = AsyncMock()
subscribe_mock.return_value = mock_unsub
channel_failure.return_value.subscribe = subscribe_mock
channel_failure.return_value.is_connected = True

# Wait for the device to attempt to connect again
attempts = 0
while subscribe_mock.call_count < 1:
await asyncio.sleep(0.01)
attempts += 1
assert attempts < 10, "Device did not connect after multiple attempts"

assert devices[0].is_connected

await device_manager.close()
assert mock_unsub.call_count == 1
Loading