11"""End-to-end tests for LocalChannel using fake sockets."""
22
33import asyncio
4- from collections .abc import AsyncGenerator
5- from unittest .mock import patch
4+ from collections .abc import AsyncGenerator , Generator , Callable
5+ from unittest .mock import patch , Mock
6+ from typing import Any
67
78import pytest
9+ import syrupy
810
911from roborock .devices .local_channel import LocalChannel
10- from roborock .protocol import create_local_decoder , create_local_encoder
12+ from roborock .protocol import MessageParser , create_local_decoder
13+ from roborock .protocols .v1_protocol import LocalProtocolVersion
1114from roborock .roborock_message import RoborockMessage , RoborockMessageProtocol
15+ from tests .fixtures .logging import CapturedRequestLog
16+ from tests .fixtures .mqtt import Subscriber
17+ from tests .fixtures .local_async_fixtures import AsyncLocalRequestHandler
1218from tests .mock_data import LOCAL_KEY
1319
1420TEST_HOST = "192.168.1.100"
1521TEST_DEVICE_UID = "test_device_uid"
16- TEST_CONNECT_NONCE = 12345
17- TEST_ACK_NONCE = 67890
18- TEST_RANDOM = 13579
22+ TEST_RANDOM = 23
23+
24+
25+ @pytest .fixture
26+ def auto_deterministic_message_fixtures (deterministic_message_fixtures : None ) -> None :
27+ """Auto-use deterministic message fixtures for all tests in this module."""
28+ pass
29+
30+
31+ @pytest .fixture (name = "mock_create_local_connection" )
32+ def create_local_connection_fixture (
33+ local_async_request_handler : AsyncLocalRequestHandler , log : CapturedRequestLog
34+ ) -> Generator [None , None , None ]:
35+ """Fixture that overrides the transport creation to wire it up to the mock socket."""
36+
37+ async def create_connection (protocol_factory : Callable [[], asyncio .Protocol ], * args , ** kwargs ) -> tuple [Any , Any ]:
38+ protocol = protocol_factory ()
39+
40+ async def handle_write (data : bytes ) -> None :
41+ log .add_log_entry ("[local >]" , data )
42+ response = await local_async_request_handler (data )
43+ if response is not None :
44+ log .add_log_entry ("[local <]" , response )
45+ # Call data_received directly to avoid loop scheduling issues in test
46+ protocol .data_received (response )
47+
48+ closed = asyncio .Event ()
49+
50+ mock_transport = Mock ()
51+ mock_transport .write = handle_write
52+ mock_transport .close = closed .set
53+ mock_transport .is_reading = lambda : not closed .is_set ()
54+
55+ return (mock_transport , protocol )
56+
57+ with patch ("roborock.devices.local_channel.asyncio.get_running_loop" ) as mock_loop :
58+ mock_loop .return_value .create_connection .side_effect = create_connection
59+ yield
1960
2061
2162@pytest .fixture (name = "local_channel" )
22- async def local_channel_fixture (mock_async_create_local_connection : None ) -> AsyncGenerator [LocalChannel , None ]:
23- with patch (
24- "roborock.devices.local_channel.get_next_int" , return_value = TEST_CONNECT_NONCE , device_uid = TEST_DEVICE_UID
25- ):
26- channel = LocalChannel (host = TEST_HOST , local_key = LOCAL_KEY , device_uid = TEST_DEVICE_UID )
27- yield channel
28- channel .close ()
63+ async def local_channel_fixture (mock_create_local_connection : None ) -> AsyncGenerator [LocalChannel , None ]:
64+ channel = LocalChannel (host = TEST_HOST , local_key = LOCAL_KEY , device_uid = TEST_DEVICE_UID )
65+ yield channel
66+ channel .close ()
2967
3068
31- def build_response (
69+ def build_raw_response (
3270 protocol : RoborockMessageProtocol ,
3371 seq : int ,
3472 payload : bytes ,
35- random : int ,
73+ version : LocalProtocolVersion = LocalProtocolVersion .V1 ,
74+ connect_nonce : int | None = None ,
75+ ack_nonce : int | None = None ,
3676) -> bytes :
37- """Build an encoded response message."""
38- if protocol == RoborockMessageProtocol .HELLO_RESPONSE :
39- encoder = create_local_encoder (local_key = LOCAL_KEY , connect_nonce = TEST_CONNECT_NONCE , ack_nonce = None )
40- else :
41- encoder = create_local_encoder (local_key = LOCAL_KEY , connect_nonce = TEST_CONNECT_NONCE , ack_nonce = TEST_ACK_NONCE )
42-
43- msg = RoborockMessage (
77+ """Build an encoded RPC response message."""
78+ message = RoborockMessage (
4479 protocol = protocol ,
45- random = random ,
80+ random = 23 ,
4681 seq = seq ,
4782 payload = payload ,
83+ version = version .value .encode (),
4884 )
49- return encoder ( msg )
85+ return MessageParser . build ( message , local_key = LOCAL_KEY , connect_nonce = connect_nonce , ack_nonce = ack_nonce )
5086
5187
5288async def test_connect (
5389 local_channel : LocalChannel ,
5490 local_response_queue : asyncio .Queue [bytes ],
5591 local_received_requests : asyncio .Queue [bytes ],
92+ log : CapturedRequestLog ,
93+ snapshot : syrupy .SnapshotAssertion ,
5694) -> None :
5795 """Test connecting to the device."""
5896 # Queue HELLO response with payload to ensure it can be parsed
5997 local_response_queue .put_nowait (
60- build_response (RoborockMessageProtocol .HELLO_RESPONSE , 1 , payload = b"ok" , random = TEST_RANDOM )
98+ build_raw_response (RoborockMessageProtocol .HELLO_RESPONSE , 1 , payload = b"ok" )
6199 )
62100
63101 await local_channel .connect ()
@@ -76,16 +114,20 @@ async def test_connect(
76114 protocol_bytes = request_bytes [19 :21 ]
77115 assert int .from_bytes (protocol_bytes , "big" ) == RoborockMessageProtocol .HELLO_REQUEST
78116
117+ assert snapshot == log
118+
79119
80120async def test_send_command (
81121 local_channel : LocalChannel ,
82122 local_response_queue : asyncio .Queue [bytes ],
83123 local_received_requests : asyncio .Queue [bytes ],
124+ log : CapturedRequestLog ,
125+ snapshot : syrupy .SnapshotAssertion ,
84126) -> None :
85127 """Test sending a command."""
86128 # Queue HELLO response
87129 local_response_queue .put_nowait (
88- build_response (RoborockMessageProtocol .HELLO_RESPONSE , 1 , payload = b"ok" , random = TEST_RANDOM )
130+ build_raw_response (RoborockMessageProtocol .HELLO_RESPONSE , 1 , payload = b"ok" )
89131 )
90132
91133 await local_channel .connect ()
@@ -101,16 +143,136 @@ async def test_send_command(
101143 seq = cmd_seq ,
102144 payload = b'{"method":"get_status"}' ,
103145 )
146+ # Prepare a fake response to the command.
147+ response_queue .put (build_raw_response (RoborockMessageProtocol .RPC_RESPONSE , cmd_seq , payload = b'{"status": "ok"}' ))
148+
149+ subscriber = Subscriber ()
150+ unsub = await local_channel .subscribe (subscriber .append )
104151
105152 await local_channel .publish (msg )
106153
107- # Verify request
154+ # Verify request received by the server
108155 request_bytes = await local_received_requests .get ()
109156 assert local_received_requests .empty ()
110157
111158 # Decode request
112- decoder = create_local_decoder (local_key = LOCAL_KEY , connect_nonce = TEST_CONNECT_NONCE , ack_nonce = TEST_ACK_NONCE )
159+ decoder = create_local_decoder (local_key = LOCAL_KEY )
160+ msgs = list (decoder (request_bytes ))
161+ assert len (msgs ) == 1
162+ assert msgs [0 ].protocol == RoborockMessageProtocol .RPC_REQUEST
163+ assert msgs [0 ].payload == b'{"method":"get_status"}'
164+ assert msgs [0 ].version == LocalProtocolVersion .V1 .value .encode ()
165+
166+ # Verify response received by subscriber
167+ await subscriber .wait ()
168+ assert len (subscriber .messages ) == 1
169+ response_message = subscriber .messages [0 ]
170+ assert isinstance (response_message , RoborockMessage )
171+ assert response_message .protocol == RoborockMessageProtocol .RPC_RESPONSE
172+ assert response_message .payload == b'{"status": "ok"}'
173+
174+ unsub ()
175+
176+ assert snapshot == log
177+
178+
179+ async def test_l01_session (
180+ local_channel : LocalChannel ,
181+ local_response_queue : asyncio .Queue [bytes ],
182+ local_received_requests : asyncio .Queue [bytes ],
183+ log : CapturedRequestLog ,
184+ snapshot : syrupy .SnapshotAssertion ,
185+ ) -> None :
186+ """Test connecting to a device that speaks the L01 protocol.."""
187+ # Client first attempts 1.0 and we reply with a fake invalid response. The
188+ # response is arbitrary, and this could be improved by capturing a real L01
189+ # device response to a 1.0 message.
190+ local_response_queue .put_nowait (b"\x00 " )
191+ # The client attempts L01 protocol as a followup. The connect nonce uses
192+ # a deterministic number from deterministic_message_fixtures.
193+ connect_nonce = 9090
194+ local_response_queue .put_nowait (
195+ build_raw_response (
196+ RoborockMessageProtocol .HELLO_RESPONSE ,
197+ 1 ,
198+ payload = b"ok" ,
199+ version = LocalProtocolVersion .L01 ,
200+ connect_nonce = connect_nonce ,
201+ ack_nonce = None ,
202+ )
203+ )
204+
205+ await local_channel .connect ()
206+
207+ assert local_channel .is_connected
208+
209+ # Verify 1.0 HELLO request
210+ request_bytes = local_received_requests .get ()
211+ # Protocol is at offset 19 (2 bytes)
212+ # Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19
213+ assert len (request_bytes ) >= 21
214+ protocol_bytes = request_bytes [19 :21 ]
215+ assert int .from_bytes (protocol_bytes , "big" ) == RoborockMessageProtocol .HELLO_REQUEST
216+
217+ # Verify L01 HELLO request
218+ request_bytes = local_received_requests .get ()
219+ # Protocol is at offset 19 (2 bytes)
220+ # Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19
221+ assert len (request_bytes ) >= 21
222+ protocol_bytes = request_bytes [19 :21 ]
223+ assert int .from_bytes (protocol_bytes , "big" ) == RoborockMessageProtocol .HELLO_REQUEST
224+
225+ assert local_received_requests .empty ()
226+
227+ # Verify the channel switched to L01 protocol
228+ assert local_channel .protocol_version == LocalProtocolVersion .L01 .value
229+
230+ # We have established a connection. Now send some messages.
231+ # Publish an L01 command. Currently the caller of the local channel needs to
232+ # determine the protocol version to use, but this could be pushed inside of
233+ # the channel in the future.
234+ cmd_seq = 123
235+ msg = RoborockMessage (
236+ protocol = RoborockMessageProtocol .RPC_REQUEST ,
237+ seq = cmd_seq ,
238+ payload = b'{"method":"get_status"}' ,
239+ version = b"L01" ,
240+ )
241+ # Prepare a fake response to the command.
242+ local_response_queue .put_nowait (
243+ build_raw_response (
244+ RoborockMessageProtocol .RPC_RESPONSE ,
245+ cmd_seq ,
246+ payload = b'{"status": "ok"}' ,
247+ version = LocalProtocolVersion .L01 ,
248+ connect_nonce = connect_nonce ,
249+ ack_nonce = TEST_RANDOM ,
250+ )
251+ )
252+
253+ # Set up a subscriber to listen for the response then publish the message.
254+ subscriber = Subscriber ()
255+ unsub = await local_channel .subscribe (subscriber .append )
256+ await local_channel .publish (msg )
257+
258+ # Verify request received by the server
259+ request_bytes = await local_received_requests .get ()
260+ decoder = create_local_decoder (local_key = LOCAL_KEY , connect_nonce = connect_nonce , ack_nonce = TEST_RANDOM )
113261 msgs = list (decoder (request_bytes ))
114262 assert len (msgs ) == 1
115263 assert msgs [0 ].protocol == RoborockMessageProtocol .RPC_REQUEST
116264 assert msgs [0 ].payload == b'{"method":"get_status"}'
265+ assert msgs [0 ].version == LocalProtocolVersion .L01 .value .encode ()
266+
267+ # Verify fake response published by the server, received by subscriber
268+ await subscriber .wait ()
269+ assert len (subscriber .messages ) == 1
270+ response_message = subscriber .messages [0 ]
271+ assert isinstance (response_message , RoborockMessage )
272+ assert response_message .protocol == RoborockMessageProtocol .RPC_RESPONSE
273+ assert response_message .payload == b'{"status": "ok"}'
274+ assert response_message .version == LocalProtocolVersion .L01 .value .encode ()
275+
276+ unsub ()
277+
278+ assert snapshot == log
0 commit comments