Skip to content

Commit b9a241c

Browse files
authored
feat: Small tweaks to test fixtures (#704)
* chore: Small tweaks to test fixtures These are improvements factored out of a large change to add more e2e tests for device manager. * chore: Update device test snapshots * chore: update test fixtures * chore: fix lint errors * feat: revert whitespace change.
1 parent 6293a67 commit b9a241c

File tree

8 files changed

+31
-19
lines changed

8 files changed

+31
-19
lines changed

roborock/devices/local_channel.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def connection_lost(self, exc: Exception | None) -> None:
4545
self.connection_lost_cb(exc)
4646

4747

48+
def get_running_loop() -> asyncio.AbstractEventLoop:
49+
"""Get the running event loop, extracted for mocking purposes."""
50+
return asyncio.get_running_loop()
51+
52+
4853
class LocalChannel(Channel):
4954
"""Simple RPC-style channel for communicating with a device over a local network.
5055
@@ -179,7 +184,7 @@ async def connect(self) -> None:
179184
if self._is_connected:
180185
self._logger.debug("Unexpected call to connect when already connected")
181186
return
182-
loop = asyncio.get_running_loop()
187+
loop = get_running_loop()
183188
protocol = _LocalProtocol(self._data_received, self._connection_lost)
184189
try:
185190
self._transport, self._protocol = await loop.create_connection(lambda: protocol, self._host, _PORT)

tests/devices/test_local_channel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def setup_mock_loop(mock_transport: Mock) -> Generator[Mock, None, None]:
5252
loop = Mock()
5353
loop.create_connection = AsyncMock(return_value=(mock_transport, Mock()))
5454

55-
with patch("asyncio.get_running_loop", return_value=loop):
55+
with patch("roborock.devices.local_channel.get_running_loop", return_value=loop):
5656
yield loop
5757

5858

tests/e2e/__snapshots__/test_mqtt_session.ambr

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# serializer version: 1
22
# name: test_session_e2e_publish_message
3-
[mqtt <]
4-
00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..|
53
[mqtt >]
64
00000000 10 21 00 04 4d 51 54 54 05 c2 00 3c 00 00 00 00 |.!..MQTT...<....|
75
00000010 08 75 73 65 72 6e 61 6d 65 00 08 70 61 73 73 77 |.username..passw|
86
00000020 6f 72 64 |ord|
7+
[mqtt <]
8+
00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..|
99
[mqtt >]
1010
00000000 30 41 00 07 74 6f 70 69 63 2d 31 00 31 2e 30 00 |0A..topic-1.1.0.|
1111
00000010 00 01 c8 00 00 23 82 68 a6 a2 23 00 65 00 20 91 |.....#.h..#.e. .|
@@ -14,13 +14,13 @@
1414
00000040 99 71 bf |.q.|
1515
# ---
1616
# name: test_session_e2e_receive_message
17-
[mqtt <]
18-
00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..|
1917
[mqtt >]
2018
00000000 10 21 00 04 4d 51 54 54 05 c2 00 3c 00 00 00 00 |.!..MQTT...<....|
2119
00000010 08 75 73 65 72 6e 61 6d 65 00 08 70 61 73 73 77 |.username..passw|
2220
00000020 6f 72 64 |ord|
2321
[mqtt <]
22+
00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..|
23+
[mqtt <]
2424
00000000 90 04 00 01 00 00 |......|
2525
[mqtt >]
2626
00000000 82 0d 00 01 00 00 07 74 6f 70 69 63 2d 31 00 |.......topic-1.|

tests/fixtures/aiomqtt_fixtures.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@ async def mock_aiomqtt_client_fixture() -> AsyncGenerator[None, None]:
2828

2929
async def poll_sockets(client: mqtt.Client) -> None:
3030
"""Poll the mqtt client sockets in a loop to pick up new data."""
31-
while True:
32-
event_loop.call_soon_threadsafe(client.loop_read)
33-
event_loop.call_soon_threadsafe(client.loop_write)
34-
await asyncio.sleep(0.01)
31+
try:
32+
while True:
33+
event_loop.call_soon_threadsafe(client.loop_read)
34+
event_loop.call_soon_threadsafe(client.loop_write)
35+
await asyncio.sleep(0.01)
36+
except asyncio.CancelledError:
37+
pass
3538

3639
task: asyncio.Task[None] | None = None
3740

@@ -52,6 +55,7 @@ def new_client(*args: Any, **kwargs: Any) -> mqtt.Client:
5255
yield
5356
if task:
5457
task.cancel()
58+
await task
5559

5660

5761
@pytest.fixture

tests/fixtures/local_async_fixtures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def start_handle_write(data: bytes) -> None:
7979

8080
return (mock_transport, protocol)
8181

82-
with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop:
82+
with patch("roborock.devices.local_channel.get_running_loop") as mock_loop:
8383
mock_loop.return_value.create_connection.side_effect = create_connection
8484
yield
8585

tests/fixtures/mqtt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(
3232
self.handle_request = handle_request
3333
self.response_queue = response_queue
3434
self.log = log
35+
self.client_connected = False
3536

3637
def pending(self) -> int:
3738
"""Return the number of bytes in the response buffer."""
@@ -52,6 +53,7 @@ def handle_socket_recv(self, read_size: int) -> bytes:
5253

5354
def handle_socket_send(self, client_request: bytes) -> int:
5455
"""Receive an incoming request from the client."""
56+
self.client_connected = True
5557
_LOGGER.debug("Request: 0x%s", client_request.hex())
5658
self.log.add_log_entry("[mqtt >]", client_request)
5759
if (response := self.handle_request(client_request)) is not None:
@@ -64,7 +66,7 @@ def handle_socket_send(self, client_request: bytes) -> int:
6466

6567
def push_response(self) -> None:
6668
"""Push a response to the client."""
67-
if not self.response_queue.empty():
69+
if not self.response_queue.empty() and self.client_connected:
6870
response = self.response_queue.get()
6971
# Enqueue a response to be sent back to the client in the buffer.
7072
# The buffer will be emptied when the client calls recv() on the socket

tests/fixtures/pahomqtt_fixtures.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Common code for MQTT tests."""
22

33
import logging
4+
import warnings
45
from collections.abc import Callable, Generator
56
from queue import Queue
67
from typing import Any
@@ -50,9 +51,12 @@ def handle_select(rlist: list, wlist: list, *args: Any) -> list:
5051
@pytest.fixture(name="fake_mqtt_socket_handler")
5152
def fake_mqtt_socket_handler_fixture(
5253
mqtt_request_handler: MqttRequestHandler, mqtt_response_queue: Queue[bytes], log: CapturedRequestLog
53-
) -> FakeMqttSocketHandler:
54+
) -> Generator[FakeMqttSocketHandler, None, None]:
5455
"""Fixture that creates a fake MQTT broker."""
55-
return FakeMqttSocketHandler(mqtt_request_handler, mqtt_response_queue, log)
56+
socket_handler = FakeMqttSocketHandler(mqtt_request_handler, mqtt_response_queue, log)
57+
yield socket_handler
58+
if len(socket_handler.response_buf.getvalue()) > 0:
59+
warnings.warn("Some enqueued MQTT responses were not consumed during the test")
5660

5761

5862
@pytest.fixture(name="mock_sock")
@@ -76,7 +80,8 @@ def response_queue_fixture() -> Generator[Queue[bytes], None, None]:
7680
"""Fixture that provides a queue for enqueueing responses to be sent to the client under test."""
7781
response_queue: Queue[bytes] = Queue()
7882
yield response_queue
79-
assert response_queue.empty(), "Not all fake responses were consumed"
83+
if not response_queue.empty():
84+
warnings.warn("Some enqueued MQTT responses were not consumed during the test")
8085

8186

8287
@pytest.fixture(name="mqtt_request_handler")

tests/mqtt/test_roborock_session.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,6 @@ async def test_session_no_subscribers(push_mqtt_response: Callable[[bytes], None
151151
"""Test the MQTT session."""
152152

153153
push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))
154-
push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345"))
155-
push_mqtt_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890"))
156154
session = await create_mqtt_session(FAKE_PARAMS)
157155
assert session.connected
158156

@@ -528,8 +526,6 @@ def succeed_then_fail_unauthorized() -> Any:
528526
# Don't produce messages, just exit and restart to reconnect
529527
message_iterator.loop = False
530528

531-
push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))
532-
533529
session = await create_mqtt_session(params)
534530
assert session.connected
535531

0 commit comments

Comments
 (0)