Skip to content

Commit 8303a8c

Browse files
committed
🦎 q7: align B01 map helpers with maintainer feedback
1 parent 4887362 commit 8303a8c

File tree

4 files changed

+150
-156
lines changed

4 files changed

+150
-156
lines changed

roborock/devices/rpc/b01_q7_channel.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,12 @@ async def _send_command(
3434
request_message: Q7RequestMessage,
3535
*,
3636
response_matcher: Callable[[RoborockMessage], _T | None],
37-
timeout_error: str,
3837
) -> _T:
3938
"""Publish a B01 command and resolve on the first matching response."""
4039
roborock_message = encode_mqtt_payload(request_message)
4140
future: asyncio.Future[_T] = asyncio.get_running_loop().create_future()
4241

43-
def find_response(response_message: RoborockMessage) -> None:
42+
def on_message(response_message: RoborockMessage) -> None:
4443
if future.done():
4544
return
4645
try:
@@ -51,24 +50,22 @@ def find_response(response_message: RoborockMessage) -> None:
5150
if response is not None:
5251
future.set_result(response)
5352

54-
unsub = await mqtt_channel.subscribe(find_response)
53+
unsub = await mqtt_channel.subscribe(on_message)
5554
try:
5655
await mqtt_channel.publish(roborock_message)
5756
return await asyncio.wait_for(future, timeout=_TIMEOUT)
58-
except TimeoutError as ex:
59-
raise RoborockException(timeout_error) from ex
6057
finally:
6158
unsub()
6259

6360

6461
async def send_decoded_command(
6562
mqtt_channel: MqttChannel,
6663
request_message: Q7RequestMessage,
67-
) -> dict[str, Any] | None:
64+
) -> dict[str, Any]:
6865
"""Send a command on the MQTT channel and get a decoded response."""
6966
_LOGGER.debug("Sending B01 MQTT command: %s", request_message)
7067

71-
def find_response(response_message: RoborockMessage) -> Any | None:
68+
def find_response(response_message: RoborockMessage) -> dict[str, Any] | None:
7269
"""Handle incoming messages and resolve the future."""
7370
try:
7471
decoded_dps = decode_rpc_response(response_message)
@@ -94,15 +91,13 @@ def find_response(response_message: RoborockMessage) -> Any | None:
9491
continue
9592
if isinstance(inner, dict) and inner.get("msgId") == str(request_message.msg_id):
9693
_LOGGER.debug("Received query response: %s", inner)
97-
# Check for error code (0 = success, non-zero = error)
9894
code = inner.get("code", 0)
9995
if code != 0:
10096
error_msg = f"B01 command failed with code {code} ({request_message})"
10197
_LOGGER.debug("B01 error response: %s", error_msg)
10298
raise RoborockException(error_msg)
10399
data = inner.get("data")
104-
# All get commands should be dicts
105-
if request_message.command.endswith(".get") and not isinstance(data, dict):
100+
if not isinstance(data, dict):
106101
raise RoborockException(f"Unexpected data type for response {data} ({request_message})")
107102
return data
108103
return None
@@ -112,16 +107,16 @@ def find_response(response_message: RoborockMessage) -> Any | None:
112107
mqtt_channel,
113108
request_message,
114109
response_matcher=find_response,
115-
timeout_error=f"B01 command timed out after {_TIMEOUT}s ({request_message})",
116110
)
111+
except TimeoutError as ex:
112+
raise RoborockException(f"B01 command timed out after {_TIMEOUT}s ({request_message})") from ex
117113
except RoborockException as ex:
118114
_LOGGER.warning(
119115
"Error sending B01 decoded command (%ss): %s",
120116
request_message,
121117
ex,
122118
)
123119
raise
124-
125120
except Exception as ex:
126121
_LOGGER.exception(
127122
"Error sending B01 decoded command (%ss): %s",
@@ -138,9 +133,11 @@ async def send_map_command(mqtt_channel: MqttChannel, request_message: Q7Request
138133
raw ``MAP_RESPONSE`` payload bytes instead of a decoded RPC ``data`` payload.
139134
"""
140135

141-
return await _send_command(
142-
mqtt_channel,
143-
request_message,
144-
response_matcher=lambda response_message: _matches_map_response(response_message, version=B01_VERSION),
145-
timeout_error=f"B01 map command timed out after {_TIMEOUT}s ({request_message})",
146-
)
136+
try:
137+
return await _send_command(
138+
mqtt_channel,
139+
request_message,
140+
response_matcher=lambda response_message: _matches_map_response(response_message, version=B01_VERSION),
141+
)
142+
except TimeoutError as ex:
143+
raise RoborockException(f"B01 map command timed out after {_TIMEOUT}s ({request_message})") from ex

roborock/devices/traits/b01/q7/map.py

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,28 @@ class Q7MapList(RoborockBase):
2626

2727
map_list: list[Q7MapListEntry] = field(default_factory=list)
2828

29+
@property
30+
def current_map_id(self) -> int | None:
31+
"""Current map id, preferring the entry marked current."""
32+
if not self.map_list:
33+
return None
34+
35+
ordered = sorted(self.map_list, key=lambda entry: entry.cur or False, reverse=True)
36+
first = next(iter(ordered), None)
37+
if first is None or not isinstance(first.id, int):
38+
return None
39+
return first.id
40+
2941

30-
class MapTrait(Trait):
42+
class MapTrait(Q7MapList, Trait):
3143
"""Map retrieval + map metadata helpers for Q7 devices."""
3244

3345
def __init__(self, channel: MqttChannel) -> None:
46+
super().__init__()
3447
self._channel = channel
3548
# Map uploads are serialized per-device to avoid response cross-wiring.
3649
self._map_command_lock = asyncio.Lock()
37-
self._map_list: Q7MapList | None = None
38-
39-
@property
40-
def map_list(self) -> Q7MapList | None:
41-
"""Latest cached map list metadata, populated by ``refresh()``."""
42-
return self._map_list
43-
44-
@property
45-
def current_map_id(self) -> int | None:
46-
"""Current map id derived from cached map list metadata."""
47-
if self._map_list is None:
48-
return None
49-
return self._extract_current_map_id(self._map_list)
50+
self._loaded = False
5051

5152
async def refresh(self) -> None:
5253
"""Refresh cached map list metadata from the device."""
@@ -55,13 +56,15 @@ async def refresh(self) -> None:
5556
Q7RequestMessage(dps=B01_Q7_DPS, command=RoborockB01Q7Methods.GET_MAP_LIST, params={}),
5657
)
5758
if not isinstance(response, dict):
58-
raise TypeError(f"Unexpected response type for GET_MAP_LIST: {type(response).__name__}: {response!r}")
59+
raise RoborockException(
60+
f"Unexpected response type for GET_MAP_LIST: {type(response).__name__}: {response!r}"
61+
)
5962

60-
parsed = Q7MapList.from_dict(response)
61-
if parsed is None:
62-
raise TypeError(f"Failed to decode map list response: {response!r}")
63+
if (parsed := Q7MapList.from_dict(response)) is None:
64+
raise RoborockException(f"Failed to decode map list response: {response!r}")
6365

64-
self._map_list = parsed
66+
self.map_list = parsed.map_list
67+
self._loaded = True
6568

6669
async def _get_map_payload(self, *, map_id: int) -> bytes:
6770
"""Fetch raw map payload bytes for the given map id."""
@@ -75,25 +78,10 @@ async def _get_map_payload(self, *, map_id: int) -> bytes:
7578

7679
async def get_current_map_payload(self) -> bytes:
7780
"""Fetch raw map payload bytes for the currently selected map."""
78-
if self._map_list is None:
81+
if not self._loaded:
7982
await self.refresh()
8083

8184
map_id = self.current_map_id
8285
if map_id is None:
83-
raise RoborockException(f"Unable to determine map_id from map list response: {self._map_list!r}")
86+
raise RoborockException(f"Unable to determine map_id from map list response: {self!r}")
8487
return await self._get_map_payload(map_id=map_id)
85-
86-
@staticmethod
87-
def _extract_current_map_id(map_list_response: Q7MapList) -> int | None:
88-
map_list = map_list_response.map_list
89-
if not map_list:
90-
return None
91-
92-
for entry in map_list:
93-
if entry.cur and isinstance(entry.id, int):
94-
return entry.id
95-
96-
first = map_list[0]
97-
if isinstance(first.id, int):
98-
return first.id
99-
return None

tests/devices/traits/b01/q7/test_init.py

Lines changed: 1 addition & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from roborock.devices.traits.b01.q7 import Q7PropertiesApi
1818
from roborock.exceptions import RoborockException
1919
from roborock.protocols.b01_q7_protocol import B01_VERSION, Q7RequestMessage
20-
from roborock.roborock_message import RoborockB01Props, RoborockMessage, RoborockMessageProtocol
20+
from roborock.roborock_message import RoborockB01Props, RoborockMessageProtocol
2121
from tests.fixtures.channel_fixtures import FakeChannel
2222

2323
from . import B01MessageBuilder
@@ -27,16 +27,12 @@ async def test_q7_api_query_values(
2727
q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder
2828
):
2929
"""Test that Q7PropertiesApi correctly converts raw values."""
30-
# We need to construct the expected result based on the mappings
31-
# status: 1 -> WAITING_FOR_ORDERS
32-
# wind: 2 -> STANDARD
3330
response_data = {
3431
"status": 1,
3532
"wind": 2,
3633
"battery": 100,
3734
}
3835

39-
# Queue the response
4036
fake_channel.response_queue.append(message_builder.build(response_data))
4137

4238
result = await q7_api.query_values(
@@ -48,23 +44,15 @@ async def test_q7_api_query_values(
4844

4945
assert result is not None
5046
assert result.status == WorkStatusMapping.WAITING_FOR_ORDERS
51-
# wind might be mapped to SCWindMapping.STANDARD (2)
52-
# let's verify checking the prop definition in B01Props
53-
# wind: SCWindMapping | None = None
54-
# SCWindMapping.STANDARD is 2 ('balanced')
55-
from roborock.data.b01_q7 import SCWindMapping
56-
5747
assert result.wind == SCWindMapping.STANDARD
5848

5949
assert len(fake_channel.published_messages) == 1
6050
message = fake_channel.published_messages[0]
6151
assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
6252
assert message.version == B01_VERSION
6353

64-
# Verify request payload
6554
assert message.payload is not None
6655
payload_data = json.loads(unpad(message.payload, AES.block_size))
67-
# {"dps": {"10000": {"method": "prop.get", "msgId": "123456789", "params": {"property": ["status", "wind"]}}}}
6856
assert "dps" in payload_data
6957
assert "10000" in payload_data["dps"]
7058
inner = payload_data["dps"]["10000"]
@@ -110,8 +98,6 @@ async def test_send_decoded_command_non_dict_response(fake_channel: FakeChannel,
11098
message = message_builder.build("some_string_error")
11199
fake_channel.response_queue.append(message)
112100

113-
# Use a random string for command type to avoid needing import
114-
115101
with pytest.raises(RoborockException, match="Unexpected data type for response"):
116102
await send_decoded_command(fake_channel, Q7RequestMessage(dps=10000, command="prop.get", params=[])) # type: ignore[arg-type]
117103

@@ -275,90 +261,3 @@ async def test_q7_api_clean_segments(
275261
"ctrl_value": SCDeviceCleanParam.START.code,
276262
"room_ids": [10, 11],
277263
}
278-
279-
280-
async def test_q7_api_get_current_map_payload(
281-
q7_api: Q7PropertiesApi,
282-
fake_channel: FakeChannel,
283-
message_builder: B01MessageBuilder,
284-
):
285-
"""Fetch current map by map-list lookup, then upload_by_mapid."""
286-
fake_channel.response_queue.append(message_builder.build({"map_list": [{"id": 1772093512, "cur": True}]}))
287-
fake_channel.response_queue.append(
288-
RoborockMessage(
289-
protocol=RoborockMessageProtocol.MAP_RESPONSE,
290-
payload=b"raw-map-payload",
291-
version=b"B01",
292-
seq=message_builder.seq + 1,
293-
)
294-
)
295-
296-
raw_payload = await q7_api.map.get_current_map_payload()
297-
assert raw_payload == b"raw-map-payload"
298-
299-
assert len(fake_channel.published_messages) == 2
300-
301-
first = fake_channel.published_messages[0]
302-
first_payload = json.loads(unpad(first.payload, AES.block_size))
303-
assert first_payload["dps"]["10000"]["method"] == "service.get_map_list"
304-
assert first_payload["dps"]["10000"]["params"] == {}
305-
306-
second = fake_channel.published_messages[1]
307-
second_payload = json.loads(unpad(second.payload, AES.block_size))
308-
assert second_payload["dps"]["10000"]["method"] == "service.upload_by_mapid"
309-
assert second_payload["dps"]["10000"]["params"] == {"map_id": 1772093512}
310-
311-
312-
async def test_q7_api_map_trait_refresh_populates_cached_values(
313-
q7_api: Q7PropertiesApi,
314-
fake_channel: FakeChannel,
315-
message_builder: B01MessageBuilder,
316-
):
317-
"""Map trait follows refresh + cached-value access pattern."""
318-
fake_channel.response_queue.append(message_builder.build({"map_list": [{"id": 101, "cur": True}]}))
319-
320-
assert q7_api.map.map_list is None
321-
assert q7_api.map.current_map_id is None
322-
323-
await q7_api.map.refresh()
324-
325-
assert len(fake_channel.published_messages) == 1
326-
assert q7_api.map.map_list is not None
327-
assert q7_api.map.map_list.map_list[0].id == 101
328-
assert q7_api.map.map_list.map_list[0].cur is True
329-
assert q7_api.map.current_map_id == 101
330-
331-
332-
async def test_q7_api_get_current_map_payload_falls_back_to_first_map(
333-
q7_api: Q7PropertiesApi,
334-
fake_channel: FakeChannel,
335-
message_builder: B01MessageBuilder,
336-
):
337-
"""If no current map marker exists, first map in list is used."""
338-
fake_channel.response_queue.append(message_builder.build({"map_list": [{"id": 111}, {"id": 222, "cur": False}]}))
339-
fake_channel.response_queue.append(
340-
RoborockMessage(
341-
protocol=RoborockMessageProtocol.MAP_RESPONSE,
342-
payload=b"raw-map-payload",
343-
version=b"B01",
344-
seq=message_builder.seq + 1,
345-
)
346-
)
347-
348-
await q7_api.map.get_current_map_payload()
349-
350-
second = fake_channel.published_messages[1]
351-
second_payload = json.loads(unpad(second.payload, AES.block_size))
352-
assert second_payload["dps"]["10000"]["params"] == {"map_id": 111}
353-
354-
355-
async def test_q7_api_get_current_map_payload_errors_without_map_list(
356-
q7_api: Q7PropertiesApi,
357-
fake_channel: FakeChannel,
358-
message_builder: B01MessageBuilder,
359-
):
360-
"""Current-map payload fetch should fail clearly when map list is unusable."""
361-
fake_channel.response_queue.append(message_builder.build({"map_list": []}))
362-
363-
with pytest.raises(RoborockException, match="Unable to determine map_id"):
364-
await q7_api.map.get_current_map_payload()

0 commit comments

Comments
 (0)