Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion roborock/devices/traits/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(
self.device_features = DeviceFeaturesTrait(product, self._device_cache)
self.status = StatusTrait(self.device_features, region=self._region)
self.consumables = ConsumableTrait()
self.rooms = RoomsTrait(home_data)
self.rooms = RoomsTrait(home_data, web_api)
self.maps = MapsTrait(self.status)
self.map_content = MapContentTrait(map_parser_config)
self.home = HomeTrait(self.status, self.maps, self.map_content, self.rooms, self._device_cache)
Expand Down
18 changes: 14 additions & 4 deletions roborock/devices/traits/v1/home.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
MAP_SLEEP = 3


def _is_default_room_name(name: str, segment_id: int) -> bool:
return name in ("Unknown", f"Room {segment_id}")


class HomeTrait(RoborockBase, common.V1TraitMixin):
"""Trait that represents a full view of the home layout."""

Expand Down Expand Up @@ -129,13 +133,19 @@ async def _refresh_map_info(self, map_info) -> CombinedMapInfo:
name=room.iot_name or "Unknown",
)

# Add rooms from rooms_trait. If room already exists and rooms_trait has "Unknown", don't override.
# Add rooms from rooms_trait.
# Prefer existing non-default map_info names over fallback names from RoomsTrait.
if self._rooms_trait.rooms:
for room in self._rooms_trait.rooms:
if room.segment_id is not None and room.name:
if room.segment_id not in rooms or room.name != "Unknown":
# Add the room to rooms if the room segment is not already in it
# or if the room name isn't unknown.
existing_room = rooms.get(room.segment_id)
if existing_room is None:
rooms[room.segment_id] = room
continue

if _is_default_room_name(existing_room.name, existing_room.segment_id) or not _is_default_room_name(
room.name, room.segment_id
):
rooms[room.segment_id] = room

return CombinedMapInfo(
Expand Down
60 changes: 46 additions & 14 deletions roborock/devices/traits/v1/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from roborock.data import HomeData, NamedRoomMapping, RoborockBase
from roborock.devices.traits.v1 import common
from roborock.roborock_typing import RoborockCommand
from roborock.web_api import UserWebApiClient

_LOGGER = logging.getLogger(__name__)

_DEFAULT_NAME = "Unknown"


@dataclass
class Rooms(RoborockBase):
Expand All @@ -32,50 +31,83 @@ class RoomsTrait(Rooms, common.V1TraitMixin):

command = RoborockCommand.GET_ROOM_MAPPING

def __init__(self, home_data: HomeData) -> None:
def __init__(self, home_data: HomeData, web_api: UserWebApiClient) -> None:
"""Initialize the RoomsTrait."""
super().__init__()
self._home_data = home_data
self._web_api = web_api
self._seen_unknown_room_iot_ids: set[str] = set()

async def refresh(self) -> None:
"""Refresh room mappings and backfill unknown room names from the web API."""
response = await self.rpc_channel.send_command(self.command)
if not isinstance(response, list):
raise ValueError(f"Unexpected RoomsTrait response format: {response!r}")

segment_map = _extract_segment_map(response)
await self._populate_missing_home_data_rooms(segment_map)

new_data = self._parse_response(response, segment_map)
self._update_trait_values(new_data)
_LOGGER.debug("Refreshed %s: %s", self.__class__.__name__, new_data)

@property
def _iot_id_room_name_map(self) -> dict[str, str]:
"""Returns a dictionary of Room IOT IDs to room names."""
return {str(room.id): room.name for room in self._home_data.rooms or ()}

def _parse_response(self, response: common.V1ResponseData) -> Rooms:
def _parse_response(self, response: common.V1ResponseData, segment_map: dict[int, str] | None = None) -> Rooms:
"""Parse the response from the device into a list of NamedRoomMapping."""
if not isinstance(response, list):
raise ValueError(f"Unexpected RoomsTrait response format: {response!r}")
if segment_map is None:
segment_map = _extract_segment_map(response)
name_map = self._iot_id_room_name_map
segment_pairs = _extract_segment_pairs(response)
return Rooms(
rooms=[
NamedRoomMapping(segment_id=segment_id, iot_id=iot_id, name=name_map.get(iot_id, _DEFAULT_NAME))
for segment_id, iot_id in segment_pairs
NamedRoomMapping(segment_id=segment_id, iot_id=iot_id, name=name_map.get(iot_id, f"Room {segment_id}"))
for segment_id, iot_id in segment_map.items()
]
)

async def _populate_missing_home_data_rooms(self, segment_map: dict[int, str]) -> None:
"""Load missing room names into home data for newly-seen unknown room ids."""
missing_room_iot_ids = set(segment_map.values()) - set(self._iot_id_room_name_map.keys())
new_missing_room_iot_ids = missing_room_iot_ids - self._seen_unknown_room_iot_ids
if not new_missing_room_iot_ids:
return

try:
web_rooms = await self._web_api.get_rooms()
except Exception:
_LOGGER.debug("Failed to fetch rooms from web API", exc_info=True)
else:
if isinstance(web_rooms, list) and web_rooms:
self._home_data.rooms = web_rooms

self._seen_unknown_room_iot_ids.update(missing_room_iot_ids)


def _extract_segment_pairs(response: list) -> list[tuple[int, str]]:
"""Extract segment_id and iot_id pairs from the response.
def _extract_segment_map(response: list) -> dict[int, str]:
"""Extract a segment_id -> iot_id mapping from the response.

The response format can be either a flat list of [segment_id, iot_id] or a
list of lists, where each inner list is a pair of [segment_id, iot_id]. This
function normalizes the response into a list of (segment_id, iot_id) tuples
function normalizes the response into a dict of segment_id to iot_id.

NOTE: We currently only partial samples of the room mapping formats, so
improving test coverage with samples from a real device with this format
would be helpful.
"""
if len(response) == 2 and not isinstance(response[0], list):
segment_id, iot_id = response[0], response[1]
return [(segment_id, iot_id)]
return {segment_id: str(iot_id)}

segment_pairs: list[tuple[int, str]] = []
segment_map: dict[int, str] = {}
for part in response:
if not isinstance(part, list) or len(part) < 2:
_LOGGER.warning("Unexpected room mapping entry format: %r", part)
continue
segment_id, iot_id = part[0], part[1]
segment_pairs.append((segment_id, iot_id))
return segment_pairs
segment_map[segment_id] = str(iot_id)
return segment_map
8 changes: 6 additions & 2 deletions roborock/web_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,10 +544,10 @@ async def get_rooms(self, user_data: UserData, home_id: int | None = None) -> li
rriot.r.a,
self.session,
{
"Authorization": _get_hawk_authentication(rriot, "/v2/user/homes/" + str(home_id)),
"Authorization": _get_hawk_authentication(rriot, f"/user/homes/{home_id}/rooms"),
},
)
room_response = await room_request.request("get", f"/user/homes/{str(home_id)}/rooms" + str(home_id))
room_response = await room_request.request("get", f"/user/homes/{home_id}/rooms")
if not room_response.get("success"):
raise RoborockException(room_response)
rooms = room_response.get("result")
Expand Down Expand Up @@ -752,6 +752,10 @@ async def get_routines(self, device_id: str) -> list[HomeDataScene]:
"""Fetch routines (scenes) for a specific device."""
return await self._web_api.get_scenes(self._user_data, device_id)

async def get_rooms(self) -> list[HomeDataRoom]:
"""Fetch rooms using the API client."""
return await self._web_api.get_rooms(self._user_data)

async def execute_routine(self, scene_id: int) -> None:
"""Execute a specific routine (scene) by its ID."""
await self._web_api.execute_scene(self._user_data, scene_id)
3 changes: 2 additions & 1 deletion tests/devices/traits/v1/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Fixtures for V1 trait tests."""

from copy import deepcopy
from unittest.mock import AsyncMock

import pytest
Expand Down Expand Up @@ -89,7 +90,7 @@ def device_fixture(
trait=v1.create(
device_info.duid,
product,
HOME_DATA,
deepcopy(HOME_DATA),
mock_rpc_channel,
mock_mqtt_rpc_channel,
mock_map_rpc_channel,
Expand Down
Loading