Skip to content

Commit 46bdfca

Browse files
committed
fix: make threadsafe waiting queue
1 parent 67235cf commit 46bdfca

File tree

9 files changed

+217
-64
lines changed

9 files changed

+217
-64
lines changed

roborock/api.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
RoborockTimeout,
1818
UnknownMethodError,
1919
)
20-
from .roborock_future import RoborockFuture
20+
from .roborock_future import RequestKey, RoborockFuture, WaitingQueue
2121
from .roborock_message import (
2222
RoborockMessage,
2323
RoborockMessageProtocol,
@@ -38,7 +38,7 @@ def __init__(self, device_info: DeviceData) -> None:
3838
"""Initialize RoborockClient."""
3939
self.device_info = device_info
4040
self._nonce = secrets.token_bytes(16)
41-
self._waiting_queue: dict[int, RoborockFuture] = {}
41+
self._waiting_queue = WaitingQueue()
4242
self._last_device_msg_in = time.monotonic()
4343
self._last_disconnection = time.monotonic()
4444
self.keep_alive = KEEPALIVE
@@ -89,33 +89,22 @@ async def validate_connection(self) -> None:
8989
await self.async_disconnect()
9090
await self.async_connect()
9191

92-
async def _wait_response(self, request_id: int, queue: RoborockFuture) -> Any:
92+
async def _wait_response(self, request_key: RequestKey, future: RoborockFuture) -> Any:
9393
try:
94-
response = await queue.async_get(self.queue_timeout)
94+
response = await future.async_get(self.queue_timeout)
9595
if response == "unknown_method":
9696
raise UnknownMethodError("Unknown method")
9797
return response
9898
except (asyncio.TimeoutError, asyncio.CancelledError):
99-
raise RoborockTimeout(f"id={request_id} Timeout after {self.queue_timeout} seconds") from None
99+
raise RoborockTimeout(f"id={request_key} Timeout after {self.queue_timeout} seconds") from None
100100
finally:
101-
self._waiting_queue.pop(request_id, None)
102-
103-
def _async_response(self, request_id: int, protocol_id: int = 0) -> Any:
104-
queue = RoborockFuture(protocol_id)
105-
if request_id in self._waiting_queue and not (
106-
request_id == 2 and protocol_id == RoborockMessageProtocol.PING_REQUEST
107-
):
108-
new_id = get_next_int(10000, 32767)
109-
self._logger.warning(
110-
"Attempting to create a future with an existing id %s (%s)... New id is %s. "
111-
"Code may not function properly.",
112-
request_id,
113-
protocol_id,
114-
new_id,
115-
)
116-
request_id = new_id
117-
self._waiting_queue[request_id] = queue
118-
return asyncio.ensure_future(self._wait_response(request_id, queue))
101+
self._waiting_queue.safe_pop(request_key)
102+
103+
104+
def _async_response(self, request_key: RequestKey) -> Any:
105+
future = RoborockFuture()
106+
self._waiting_queue.put(request_key, future)
107+
return asyncio.ensure_future(self._wait_response(request_key, future))
119108

120109
@abstractmethod
121110
async def send_message(self, roborock_message: RoborockMessage):

roborock/cloud_api.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .containers import DeviceData, UserData
1515
from .exceptions import RoborockException, VacuumError
1616
from .protocol import MessageParser, md5hex
17-
from .roborock_future import RoborockFuture
17+
from .roborock_future import RequestKey
1818

1919
_LOGGER = logging.getLogger(__name__)
2020
CONNECT_REQUEST_ID = 0
@@ -72,12 +72,11 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
7272
self._mqtt_password = rriot.s
7373
self._hashed_password = md5hex(self._mqtt_password + ":" + rriot.k)[16:]
7474
self._mqtt_client.username_pw_set(self._hashed_user, self._hashed_password)
75-
self._waiting_queue: dict[int, RoborockFuture] = {}
7675
self._mutex = Lock()
7776

7877
def _mqtt_on_connect(self, *args, **kwargs):
7978
_, __, ___, rc, ____ = args
80-
connection_queue = self._waiting_queue.get(CONNECT_REQUEST_ID)
79+
connection_queue = self._waiting_queue.safe_pop(RequestKey(CONNECT_REQUEST_ID))
8180
if rc != mqtt.MQTT_ERR_SUCCESS:
8281
message = f"Failed to connect ({mqtt.error_string(rc)})"
8382
self._logger.error(message)
@@ -98,6 +97,8 @@ def _mqtt_on_connect(self, *args, **kwargs):
9897
self._logger.info(f"Subscribed to topic {topic}")
9998
if connection_queue:
10099
connection_queue.set_result(True)
100+
else:
101+
self._logger.debug("Connected but no connect future")
101102

102103
def _mqtt_on_message(self, *args, **kwargs):
103104
client, __, msg = args
@@ -112,9 +113,11 @@ def _mqtt_on_disconnect(self, *args, **kwargs):
112113
try:
113114
exc = RoborockException(mqtt.error_string(rc)) if rc != mqtt.MQTT_ERR_SUCCESS else None
114115
super().on_connection_lost(exc)
115-
connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID)
116+
connection_queue = self._waiting_queue.safe_pop(RequestKey(DISCONNECT_REQUEST_ID))
116117
if connection_queue:
117118
connection_queue.set_result(True)
119+
else:
120+
self._logger.debug("Disconnected but no disconnect future")
118121
except Exception as ex:
119122
self._logger.exception(ex)
120123

@@ -124,10 +127,11 @@ def is_connected(self) -> bool:
124127

125128
def _sync_disconnect(self) -> Any:
126129
if not self.is_connected():
130+
self._logger.debug("Already disconnected from mqtt")
127131
return None
128132

129133
self._logger.info("Disconnecting from mqtt")
130-
disconnected_future = self._async_response(DISCONNECT_REQUEST_ID)
134+
disconnected_future = self._async_response(RequestKey(DISCONNECT_REQUEST_ID))
131135
rc = self._mqtt_client.disconnect()
132136

133137
if rc == mqtt.MQTT_ERR_NO_CONN:
@@ -149,7 +153,7 @@ def _sync_connect(self) -> Any:
149153
raise RoborockException("Mqtt information was not entered. Cannot connect.")
150154

151155
self._logger.debug("Connecting to mqtt")
152-
connected_future = self._async_response(CONNECT_REQUEST_ID)
156+
connected_future = self._async_response(RequestKey(CONNECT_REQUEST_ID))
153157
self._mqtt_client.connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE)
154158
self._mqtt_client.maybe_restart_loop()
155159
return connected_future

roborock/roborock_future.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,64 @@
11
from __future__ import annotations
22

3+
import logging
34
from asyncio import Future
5+
from dataclasses import dataclass
6+
from threading import Lock
47
from typing import Any
58

69
import async_timeout
710

8-
from .exceptions import VacuumError
11+
from .exceptions import UnknownMethodError, VacuumError
12+
from .roborock_message import RoborockMessageProtocol
13+
14+
_LOGGER = logging.getLogger(__name__)
15+
_TRIES = 3
16+
17+
18+
@dataclass(frozen=True)
19+
class RequestKey:
20+
"""A key for a Roborock message request."""
21+
22+
request_id: int
23+
protocol: RoborockMessageProtocol | int = 0
24+
25+
def __str__(self) -> str:
26+
"""Get the key for the request."""
27+
return f"{self.request_id}-{self.protocol}"
28+
29+
30+
class WaitingQueue:
31+
"""A threadsafe waiting queue for Roborock messages."""
32+
33+
def __init__(self) -> None:
34+
"""Initialize the waiting queue."""
35+
self._lock = Lock()
36+
self._queue: dict[RequestKey, RoborockFuture] = {}
37+
38+
def put(self, request_key: RequestKey, future: RoborockFuture) -> None:
39+
"""Create a future for the given protocol."""
40+
_LOGGER.debug("Putting request key %s in the queue", request_key)
41+
with self._lock:
42+
if request_key in self._queue:
43+
raise ValueError(f"Request key {request_key} already exists in the queue")
44+
self._queue[request_key] = future
45+
46+
def safe_pop(self, request_key: RequestKey) -> RoborockFuture | None:
47+
"""Get the future from the queue if it has not yet been popped, otherwise ignore."""
48+
_LOGGER.debug("Popping request key %s from the queue", request_key)
49+
with self._lock:
50+
return self._queue.pop(request_key, None)
951

1052

1153
class RoborockFuture:
12-
def __init__(self, protocol: int):
13-
self.protocol = protocol
54+
"""A threadsafe asyncio Future for Roborock messages.
55+
56+
The results may be set from a background thread. The future
57+
must be awaited in an asyncio event loop.
58+
"""
59+
60+
def __init__(self):
61+
"""Initialize the Roborock future."""
1462
self.fut: Future = Future()
1563
self.loop = self.fut.get_loop()
1664

@@ -28,9 +76,15 @@ def _set_exception(self, exc: VacuumError) -> None:
2876
def set_exception(self, exc: VacuumError) -> None:
2977
self.loop.call_soon_threadsafe(self._set_exception, exc)
3078

31-
async def async_get(self, timeout: float | int) -> tuple[Any, VacuumError | None]:
79+
async def async_get(self, timeout: float | int) -> Any:
80+
"""Get the result from the future or raises an error."""
3281
try:
3382
async with async_timeout.timeout(timeout):
34-
return await self.fut
83+
response = await self.fut
84+
# This should be moved to the specific client that handles this
85+
# and set an exception directly rather than patching an exception here
86+
if response == "unknown_method":
87+
raise UnknownMethodError("Unknown method")
88+
return response
3589
finally:
3690
self.fut.cancel()

roborock/version_1_apis/roborock_client_v1.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
WashTowelMode,
4848
)
4949
from roborock.protocol import Utils
50+
from roborock.roborock_future import RequestKey
5051
from roborock.roborock_message import (
5152
ROBOROCK_DATA_CONSUMABLE_PROTOCOL,
5253
ROBOROCK_DATA_STATUS_PROTOCOL,
@@ -391,8 +392,9 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
391392
if data_point_number == "102":
392393
data_point_response = json.loads(data_point)
393394
request_id = data_point_response.get("id")
394-
queue = self._waiting_queue.get(request_id)
395-
if queue and queue.protocol == protocol:
395+
request_key = RequestKey(request_id, protocol)
396+
queue = self._waiting_queue.safe_pop(request_key)
397+
if queue:
396398
error = data_point_response.get("error")
397399
if error:
398400
queue.set_exception(
@@ -407,7 +409,7 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
407409
result = result[0]
408410
queue.set_result(result)
409411
else:
410-
self._logger.debug("Received response for unknown request id %s", request_id)
412+
self._logger.debug("Received response for unknown request id %s", request_key)
411413
else:
412414
try:
413415
data_protocol = RoborockDataProtocol(int(data_point_number))
@@ -467,19 +469,21 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
467469
except ValueError as err:
468470
raise RoborockException(f"Failed to decode {data.payload!r} for {data.protocol}") from err
469471
decompressed = Utils.decompress(decrypted)
470-
queue = self._waiting_queue.get(request_id)
472+
request_key = RequestKey(request_id, protocol)
473+
queue = self._waiting_queue.safe_pop(request_key)
471474
if queue:
472475
if isinstance(decompressed, list):
473476
decompressed = decompressed[0]
474477
queue.set_result(decompressed)
475478
else:
476-
self._logger.debug("Received response for unknown request id %s", request_id)
479+
self._logger.debug("Received response for unknown request id %s", request_key)
477480
else:
478-
queue = self._waiting_queue.get(data.seq)
481+
request_key = RequestKey(data.seq, protocol)
482+
queue = self._waiting_queue.safe_pop(request_key)
479483
if queue:
480484
queue.set_result(data.payload)
481485
else:
482-
self._logger.debug("Received response for unknown request id %s", data.seq)
486+
self._logger.debug("Received response for unknown request id %s", request_key)
483487
except Exception as ex:
484488
self._logger.exception(ex)
485489

roborock/version_1_apis/roborock_local_client_v1.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
66
from ..exceptions import VacuumError
77
from ..protocol import MessageParser
8+
from ..roborock_future import RequestKey
89
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
910
from ..util import RoborockLoggerAdapter
1011
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
@@ -54,15 +55,19 @@ async def send_message(self, roborock_message: RoborockMessage):
5455
response_protocol = request_id + 1
5556
else:
5657
request_id = roborock_message.get_request_id()
58+
_LOGGER.debug("Getting next request id: %s", request_id)
5759
response_protocol = RoborockMessageProtocol.GENERAL_REQUEST
5860
if request_id is None:
5961
raise RoborockException(f"Failed build message {roborock_message}")
6062
local_key = self.device_info.device.local_key
6163
msg = MessageParser.build(roborock_message, local_key=local_key)
64+
request_key = RequestKey(request_id, response_protocol)
6265
if method:
63-
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
66+
self._logger.debug(f"id={request_key} Requesting method {method} with {params}")
67+
else:
68+
self._logger.debug(f"id={request_key} Requesting with {params}")
6469
# Send the command to the Roborock device
65-
async_response = self._async_response(request_id, response_protocol)
70+
async_response = self._async_response(request_key)
6671
self._send_msg_raw(msg)
6772
diagnostic_key = method if method is not None else "unknown"
6873
try:

roborock/version_1_apis/roborock_mqtt_client_v1.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ..containers import DeviceData, UserData
1212
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
1313
from ..protocol import MessageParser, Utils
14+
from ..roborock_future import RequestKey
1415
from ..roborock_message import (
1516
RoborockMessage,
1617
RoborockMessageProtocol,
@@ -47,11 +48,11 @@ async def send_message(self, roborock_message: RoborockMessage):
4748
response_protocol = (
4849
RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE
4950
)
50-
51+
request_key = RequestKey(request_id, response_protocol)
5152
local_key = self.device_info.device.local_key
5253
msg = MessageParser.build(roborock_message, local_key, False)
53-
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
54-
async_response = self._async_response(request_id, response_protocol)
54+
self._logger.debug(f"id={request_key} Requesting method {method} with {params}")
55+
async_response = self._async_response(request_key)
5556
self._send_msg_raw(msg)
5657
diagnostic_key = method if method is not None else "unknown"
5758
try:
@@ -67,9 +68,9 @@ async def send_message(self, roborock_message: RoborockMessage):
6768
"response": response,
6869
}
6970
if response_protocol == RoborockMessageProtocol.MAP_RESPONSE:
70-
self._logger.debug(f"id={request_id} Response from {method}: {len(response)} bytes")
71+
self._logger.debug(f"id={request_key} Response from {method}: {len(response)} bytes")
7172
else:
72-
self._logger.debug(f"id={request_id} Response from {method}: {response}")
73+
self._logger.debug(f"id={request_key} Response from {method}: {response}")
7374
return response
7475

7576
async def _send_command(

roborock/version_a01_apis/roborock_client_a01.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ZeoTemperature,
3434
)
3535
from roborock.containers import DyadProductInfo, DyadSndState, RoborockCategory
36+
from roborock.roborock_future import RequestKey
3637
from roborock.roborock_message import (
3738
RoborockDyadDataProtocol,
3839
RoborockMessage,
@@ -142,9 +143,12 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
142143
if data_point_protocol in entries:
143144
# Auto convert into data struct we want.
144145
converted_response = entries[data_point_protocol].post_process_fn(data_point)
145-
queue = self._waiting_queue.get(int(data_point_number))
146-
if queue and queue.protocol == protocol:
147-
queue.set_result(converted_response)
146+
request_key = RequestKey(int(data_point_number), protocol)
147+
future = self._waiting_queue.safe_pop(request_key)
148+
if future is not None:
149+
future.set_result(converted_response)
150+
else:
151+
self._logger.debug(f"Got response for {request_key} but no future found")
148152

149153
@abstractmethod
150154
async def update_values(

roborock/version_a01_apis/roborock_mqtt_client_a01.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from roborock.containers import DeviceData, RoborockCategory, UserData
1111
from roborock.exceptions import RoborockException
1212
from roborock.protocol import MessageParser
13+
from roborock.roborock_future import RequestKey
1314
from roborock.roborock_message import (
1415
RoborockDyadDataProtocol,
1516
RoborockMessage,
@@ -50,7 +51,7 @@ async def send_message(self, roborock_message: RoborockMessage):
5051
futures = []
5152
if "10000" in payload["dps"]:
5253
for dps in json.loads(payload["dps"]["10000"]):
53-
futures.append(self._async_response(dps, response_protocol))
54+
futures.append(self._async_response(RequestKey(dps, response_protocol)))
5455
self._send_msg_raw(m)
5556
responses = await asyncio.gather(*futures, return_exceptions=True)
5657
dps_responses: dict[int, typing.Any] = {}

0 commit comments

Comments
 (0)