Skip to content

Commit bf4bbd6

Browse files
committed
chore: Add protocol snapshot tests for the mqtt and local e2e tests
Update the local tests to send the exact same binary format for the initial hello requests. This also adds L01 coverage.
1 parent c317f8e commit bf4bbd6

File tree

7 files changed

+310
-35
lines changed

7 files changed

+310
-35
lines changed

roborock/protocol.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,12 @@ def _l01_iv(timestamp: int, nonce: int, sequence: int) -> bytes:
163163
return digest[:12]
164164

165165
@staticmethod
166-
def _l01_aad(timestamp: int, nonce: int, sequence: int, connect_nonce: int, ack_nonce: int) -> bytes:
166+
def _l01_aad(timestamp: int, nonce: int, sequence: int, connect_nonce: int, ack_nonce: int | None = None) -> bytes:
167167
"""Derive AAD for L01 protocol."""
168168
return (
169169
sequence.to_bytes(4, "big")
170170
+ connect_nonce.to_bytes(4, "big")
171-
+ ack_nonce.to_bytes(4, "big")
171+
+ (ack_nonce.to_bytes(4, "big") if ack_nonce is not None else b"")
172172
+ nonce.to_bytes(4, "big")
173173
+ timestamp.to_bytes(4, "big")
174174
)
@@ -181,7 +181,7 @@ def encrypt_gcm_l01(
181181
sequence: int,
182182
nonce: int,
183183
connect_nonce: int,
184-
ack_nonce: int,
184+
ack_nonce: int | None = None,
185185
) -> bytes:
186186
"""Encrypt plaintext for L01 protocol using AES-256-GCM."""
187187
if not isinstance(plaintext, bytes):
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# serializer version: 1
2+
# name: test_connect
3+
[local >]
4+
00000000 00 00 00 15 31 2e 30 00 00 00 01 00 00 23 82 68 |....1.0......#.h|
5+
00000010 a6 a2 24 00 00 e6 b9 24 63 |..$....$c|
6+
[local <]
7+
00000000 00 00 00 27 31 2e 30 00 00 00 01 00 00 00 17 68 |...'1.0........h|
8+
00000010 a6 a2 23 00 01 00 10 cb 93 c7 39 b9 21 53 43 48 |..#.......9.!SCH|
9+
00000020 83 b3 c2 af 0f 51 2c da 9e ea 3b |.....Q,...;|
10+
# ---
11+
# name: test_l01_session
12+
[local >]
13+
00000000 00 00 00 15 31 2e 30 00 00 00 01 00 00 23 82 68 |....1.0......#.h|
14+
00000010 a6 a2 24 00 00 e6 b9 24 63 |..$....$c|
15+
[local <]
16+
00000000 00 |.|
17+
[local >]
18+
00000000 00 00 00 15 4c 30 31 00 00 00 01 00 00 23 82 68 |....L01......#.h|
19+
00000010 a6 a2 25 00 00 ee 2f 30 e8 |..%.../0.|
20+
[local <]
21+
00000000 00 00 00 29 4c 30 31 00 00 00 01 00 00 00 17 68 |...)L01........h|
22+
00000010 a6 a2 23 00 01 00 12 a0 4a ec 75 88 03 75 0f d2 |..#.....J.u..u..|
23+
00000020 40 33 69 02 f4 71 50 72 f3 81 56 80 f4 |@3i..qPr..V..|
24+
[local >]
25+
00000000 00 00 00 3e 4c 30 31 00 00 00 7b 00 00 23 83 68 |...>L01...{..#.h|
26+
00000010 a6 a2 26 00 65 00 27 9e fd c2 42 b7 01 b4 eb 9c |..&.e.'...B.....|
27+
00000020 00 84 4f fd 51 1f bc a5 65 12 c2 dc 45 0e 21 cb |..O.Q...e...E.!.|
28+
00000030 45 dc bb 0a ba 16 84 28 a7 33 e5 e2 fa a8 f1 f2 |E......(.3......|
29+
00000040 ec f4 |..|
30+
[local <]
31+
00000000 00 00 00 37 4c 30 31 00 00 00 7b 00 00 00 17 68 |...7L01...{....h|
32+
00000010 a6 a2 27 00 66 00 20 b7 72 49 8a 64 eb 16 a5 71 |..'.f. .rI.d...q|
33+
00000020 73 eb 9e 7e 37 64 3e 75 c0 70 ea 39 4e de 82 1f |s..~7d>u.p.9N...|
34+
00000030 e2 29 86 de 4a 7b 38 20 55 12 8a |.)..J{8 U..|
35+
# ---
36+
# name: test_send_command
37+
[local >]
38+
00000000 00 00 00 15 31 2e 30 00 00 00 01 00 00 23 82 68 |....1.0......#.h|
39+
00000010 a6 a2 24 00 00 e6 b9 24 63 |..$....$c|
40+
[local <]
41+
00000000 00 00 00 27 31 2e 30 00 00 00 01 00 00 00 17 68 |...'1.0........h|
42+
00000010 a6 a2 23 00 01 00 10 cb 93 c7 39 b9 21 53 43 48 |..#.......9.!SCH|
43+
00000020 83 b3 c2 af 0f 51 2c da 9e ea 3b |.....Q,...;|
44+
[local >]
45+
00000000 00 00 00 37 31 2e 30 00 00 00 7b 00 00 23 83 68 |...71.0...{..#.h|
46+
00000010 a6 a2 25 00 65 00 20 91 5b 1f 43 34 d5 22 47 9f |..%.e. .[.C4."G.|
47+
00000020 59 4e 45 53 85 f9 c6 6e f2 eb 27 eb 6d 03 d8 92 |YNES...n..'.m...|
48+
00000030 5b 30 83 b4 a4 ea f5 85 be 38 57 |[0.......8W|
49+
[local <]
50+
00000000 00 00 00 37 31 2e 30 00 00 00 7b 00 00 00 17 68 |...71.0...{....h|
51+
00000010 a6 a2 26 00 66 00 20 07 8b 28 60 a8 08 18 12 47 |..&.f. ..(`....G|
52+
00000020 05 20 3e f5 53 e3 fd 4a cc 03 72 7b b4 2c d9 84 |. >.S..J..r{.,..|
53+
00000030 7f 4b 18 d8 76 7d 5c 65 87 7c 2d |.K..v}\e.|-|
54+
# ---
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# serializer version: 1
2+
# name: test_session_e2e_publish_message
3+
[mqtt <]
4+
00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..|
5+
[mqtt >]
6+
00000000 10 21 00 04 4d 51 54 54 05 c2 00 3c 00 00 00 00 |.!..MQTT...<....|
7+
00000010 08 75 73 65 72 6e 61 6d 65 00 08 70 61 73 73 77 |.username..passw|
8+
00000020 6f 72 64 |ord|
9+
[mqtt >]
10+
00000000 30 41 00 07 74 6f 70 69 63 2d 31 00 31 2e 30 00 |0A..topic-1.1.0.|
11+
00000010 00 01 c8 00 00 23 82 68 a6 a2 23 00 65 00 20 91 |.....#.h..#.e. .|
12+
00000020 22 f1 91 1a 6e 89 71 ca ec 2d 44 2a 16 57 e7 5b |"...n.q..-D*.W.[|
13+
00000030 4a 9a c8 97 4b 13 37 3b f5 81 13 45 7c e7 48 03 |J...K.7;...E|.H.|
14+
00000040 99 71 bf |.q.|
15+
# ---
16+
# name: test_session_e2e_receive_message
17+
[mqtt <]
18+
00000000 20 09 02 00 06 22 00 0a 21 00 14 | ...."..!..|
19+
[mqtt >]
20+
00000000 10 21 00 04 4d 51 54 54 05 c2 00 3c 00 00 00 00 |.!..MQTT...<....|
21+
00000010 08 75 73 65 72 6e 61 6d 65 00 08 70 61 73 73 77 |.username..passw|
22+
00000020 6f 72 64 |ord|
23+
[mqtt <]
24+
00000000 90 04 00 01 00 00 |......|
25+
[mqtt >]
26+
00000000 82 0d 00 01 00 00 07 74 6f 70 69 63 2d 31 00 |.......topic-1.|
27+
[mqtt <]
28+
00000000 30 31 00 07 74 6f 70 69 63 2d 31 00 31 2e 30 00 |01..topic-1.1.0.|
29+
00000010 00 00 7b 00 00 23 82 68 a6 a2 23 00 66 00 10 45 |..{..#.h..#.f..E|
30+
00000020 3b c3 2b 12 a6 77 d9 55 f6 e0 89 f5 93 a5 30 5d |;.+..w.U......0]|
31+
00000030 a0 72 fa |.r.|
32+
# ---

tests/e2e/test_local_session.py

Lines changed: 190 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,101 @@
11
"""End-to-end tests for LocalChannel using fake sockets."""
22

33
import asyncio
4-
from collections.abc import AsyncGenerator
5-
from unittest.mock import patch
4+
from collections.abc import AsyncGenerator, Generator, Callable
5+
from unittest.mock import patch, Mock
6+
from typing import Any
67

78
import pytest
9+
import syrupy
810

911
from roborock.devices.local_channel import LocalChannel
10-
from roborock.protocol import create_local_decoder, create_local_encoder
12+
from roborock.protocol import MessageParser, create_local_decoder
13+
from roborock.protocols.v1_protocol import LocalProtocolVersion
1114
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
15+
from tests.fixtures.logging import CapturedRequestLog
16+
from tests.fixtures.mqtt import Subscriber
17+
from tests.fixtures.local_async_fixtures import AsyncLocalRequestHandler
1218
from tests.mock_data import LOCAL_KEY
1319

1420
TEST_HOST = "192.168.1.100"
1521
TEST_DEVICE_UID = "test_device_uid"
16-
TEST_CONNECT_NONCE = 12345
17-
TEST_ACK_NONCE = 67890
18-
TEST_RANDOM = 13579
22+
TEST_RANDOM = 23
23+
24+
25+
@pytest.fixture
26+
def auto_deterministic_message_fixtures(deterministic_message_fixtures: None) -> None:
27+
"""Auto-use deterministic message fixtures for all tests in this module."""
28+
pass
29+
30+
31+
@pytest.fixture(name="mock_create_local_connection")
32+
def create_local_connection_fixture(
33+
local_async_request_handler: AsyncLocalRequestHandler, log: CapturedRequestLog
34+
) -> Generator[None, None, None]:
35+
"""Fixture that overrides the transport creation to wire it up to the mock socket."""
36+
37+
async def create_connection(protocol_factory: Callable[[], asyncio.Protocol], *args, **kwargs) -> tuple[Any, Any]:
38+
protocol = protocol_factory()
39+
40+
async def handle_write(data: bytes) -> None:
41+
log.add_log_entry("[local >]", data)
42+
response = await local_async_request_handler(data)
43+
if response is not None:
44+
log.add_log_entry("[local <]", response)
45+
# Call data_received directly to avoid loop scheduling issues in test
46+
protocol.data_received(response)
47+
48+
closed = asyncio.Event()
49+
50+
mock_transport = Mock()
51+
mock_transport.write = handle_write
52+
mock_transport.close = closed.set
53+
mock_transport.is_reading = lambda: not closed.is_set()
54+
55+
return (mock_transport, protocol)
56+
57+
with patch("roborock.devices.local_channel.asyncio.get_running_loop") as mock_loop:
58+
mock_loop.return_value.create_connection.side_effect = create_connection
59+
yield
1960

2061

2162
@pytest.fixture(name="local_channel")
22-
async def local_channel_fixture(mock_async_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]:
23-
with patch(
24-
"roborock.devices.local_channel.get_next_int", return_value=TEST_CONNECT_NONCE, device_uid=TEST_DEVICE_UID
25-
):
26-
channel = LocalChannel(host=TEST_HOST, local_key=LOCAL_KEY, device_uid=TEST_DEVICE_UID)
27-
yield channel
28-
channel.close()
63+
async def local_channel_fixture(mock_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]:
64+
channel = LocalChannel(host=TEST_HOST, local_key=LOCAL_KEY, device_uid=TEST_DEVICE_UID)
65+
yield channel
66+
channel.close()
2967

3068

31-
def build_response(
69+
def build_raw_response(
3270
protocol: RoborockMessageProtocol,
3371
seq: int,
3472
payload: bytes,
35-
random: int,
73+
version: LocalProtocolVersion = LocalProtocolVersion.V1,
74+
connect_nonce: int | None = None,
75+
ack_nonce: int | None = None,
3676
) -> bytes:
37-
"""Build an encoded response message."""
38-
if protocol == RoborockMessageProtocol.HELLO_RESPONSE:
39-
encoder = create_local_encoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=None)
40-
else:
41-
encoder = create_local_encoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE)
42-
43-
msg = RoborockMessage(
77+
"""Build an encoded RPC response message."""
78+
message = RoborockMessage(
4479
protocol=protocol,
45-
random=random,
80+
random=23,
4681
seq=seq,
4782
payload=payload,
83+
version=version.value.encode(),
4884
)
49-
return encoder(msg)
85+
return MessageParser.build(message, local_key=LOCAL_KEY, connect_nonce=connect_nonce, ack_nonce=ack_nonce)
5086

5187

5288
async def test_connect(
5389
local_channel: LocalChannel,
5490
local_response_queue: asyncio.Queue[bytes],
5591
local_received_requests: asyncio.Queue[bytes],
92+
log: CapturedRequestLog,
93+
snapshot: syrupy.SnapshotAssertion,
5694
) -> None:
5795
"""Test connecting to the device."""
5896
# Queue HELLO response with payload to ensure it can be parsed
5997
local_response_queue.put_nowait(
60-
build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)
98+
build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok")
6199
)
62100

63101
await local_channel.connect()
@@ -76,16 +114,20 @@ async def test_connect(
76114
protocol_bytes = request_bytes[19:21]
77115
assert int.from_bytes(protocol_bytes, "big") == RoborockMessageProtocol.HELLO_REQUEST
78116

117+
assert snapshot == log
118+
79119

80120
async def test_send_command(
81121
local_channel: LocalChannel,
82122
local_response_queue: asyncio.Queue[bytes],
83123
local_received_requests: asyncio.Queue[bytes],
124+
log: CapturedRequestLog,
125+
snapshot: syrupy.SnapshotAssertion,
84126
) -> None:
85127
"""Test sending a command."""
86128
# Queue HELLO response
87129
local_response_queue.put_nowait(
88-
build_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok", random=TEST_RANDOM)
130+
build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok")
89131
)
90132

91133
await local_channel.connect()
@@ -101,16 +143,136 @@ async def test_send_command(
101143
seq=cmd_seq,
102144
payload=b'{"method":"get_status"}',
103145
)
146+
# Prepare a fake response to the command.
147+
response_queue.put(build_raw_response(RoborockMessageProtocol.RPC_RESPONSE, cmd_seq, payload=b'{"status": "ok"}'))
148+
149+
subscriber = Subscriber()
150+
unsub = await local_channel.subscribe(subscriber.append)
104151

105152
await local_channel.publish(msg)
106153

107-
# Verify request
154+
# Verify request received by the server
108155
request_bytes = await local_received_requests.get()
109156
assert local_received_requests.empty()
110157

111158
# Decode request
112-
decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=TEST_CONNECT_NONCE, ack_nonce=TEST_ACK_NONCE)
159+
decoder = create_local_decoder(local_key=LOCAL_KEY)
160+
msgs = list(decoder(request_bytes))
161+
assert len(msgs) == 1
162+
assert msgs[0].protocol == RoborockMessageProtocol.RPC_REQUEST
163+
assert msgs[0].payload == b'{"method":"get_status"}'
164+
assert msgs[0].version == LocalProtocolVersion.V1.value.encode()
165+
166+
# Verify response received by subscriber
167+
await subscriber.wait()
168+
assert len(subscriber.messages) == 1
169+
response_message = subscriber.messages[0]
170+
assert isinstance(response_message, RoborockMessage)
171+
assert response_message.protocol == RoborockMessageProtocol.RPC_RESPONSE
172+
assert response_message.payload == b'{"status": "ok"}'
173+
174+
unsub()
175+
176+
assert snapshot == log
177+
178+
179+
async def test_l01_session(
180+
local_channel: LocalChannel,
181+
local_response_queue: asyncio.Queue[bytes],
182+
local_received_requests: asyncio.Queue[bytes],
183+
log: CapturedRequestLog,
184+
snapshot: syrupy.SnapshotAssertion,
185+
) -> None:
186+
"""Test connecting to a device that speaks the L01 protocol.."""
187+
# Client first attempts 1.0 and we reply with a fake invalid response. The
188+
# response is arbitrary, and this could be improved by capturing a real L01
189+
# device response to a 1.0 message.
190+
local_response_queue.put_nowait(b"\x00")
191+
# The client attempts L01 protocol as a followup. The connect nonce uses
192+
# a deterministic number from deterministic_message_fixtures.
193+
connect_nonce = 9090
194+
local_response_queue.put_nowait(
195+
build_raw_response(
196+
RoborockMessageProtocol.HELLO_RESPONSE,
197+
1,
198+
payload=b"ok",
199+
version=LocalProtocolVersion.L01,
200+
connect_nonce=connect_nonce,
201+
ack_nonce=None,
202+
)
203+
)
204+
205+
await local_channel.connect()
206+
207+
assert local_channel.is_connected
208+
209+
# Verify 1.0 HELLO request
210+
request_bytes = local_received_requests.get()
211+
# Protocol is at offset 19 (2 bytes)
212+
# Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19
213+
assert len(request_bytes) >= 21
214+
protocol_bytes = request_bytes[19:21]
215+
assert int.from_bytes(protocol_bytes, "big") == RoborockMessageProtocol.HELLO_REQUEST
216+
217+
# Verify L01 HELLO request
218+
request_bytes = local_received_requests.get()
219+
# Protocol is at offset 19 (2 bytes)
220+
# Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19
221+
assert len(request_bytes) >= 21
222+
protocol_bytes = request_bytes[19:21]
223+
assert int.from_bytes(protocol_bytes, "big") == RoborockMessageProtocol.HELLO_REQUEST
224+
225+
assert local_received_requests.empty()
226+
227+
# Verify the channel switched to L01 protocol
228+
assert local_channel.protocol_version == LocalProtocolVersion.L01.value
229+
230+
# We have established a connection. Now send some messages.
231+
# Publish an L01 command. Currently the caller of the local channel needs to
232+
# determine the protocol version to use, but this could be pushed inside of
233+
# the channel in the future.
234+
cmd_seq = 123
235+
msg = RoborockMessage(
236+
protocol=RoborockMessageProtocol.RPC_REQUEST,
237+
seq=cmd_seq,
238+
payload=b'{"method":"get_status"}',
239+
version=b"L01",
240+
)
241+
# Prepare a fake response to the command.
242+
local_response_queue.put_nowait(
243+
build_raw_response(
244+
RoborockMessageProtocol.RPC_RESPONSE,
245+
cmd_seq,
246+
payload=b'{"status": "ok"}',
247+
version=LocalProtocolVersion.L01,
248+
connect_nonce=connect_nonce,
249+
ack_nonce=TEST_RANDOM,
250+
)
251+
)
252+
253+
# Set up a subscriber to listen for the response then publish the message.
254+
subscriber = Subscriber()
255+
unsub = await local_channel.subscribe(subscriber.append)
256+
await local_channel.publish(msg)
257+
258+
# Verify request received by the server
259+
request_bytes = await local_received_requests.get()
260+
decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=connect_nonce, ack_nonce=TEST_RANDOM)
113261
msgs = list(decoder(request_bytes))
114262
assert len(msgs) == 1
115263
assert msgs[0].protocol == RoborockMessageProtocol.RPC_REQUEST
116264
assert msgs[0].payload == b'{"method":"get_status"}'
265+
assert msgs[0].version == LocalProtocolVersion.L01.value.encode()
266+
267+
# Verify fake response published by the server, received by subscriber
268+
await subscriber.wait()
269+
assert len(subscriber.messages) == 1
270+
response_message = subscriber.messages[0]
271+
assert isinstance(response_message, RoborockMessage)
272+
assert response_message.protocol == RoborockMessageProtocol.RPC_RESPONSE
273+
assert response_message.payload == b'{"status": "ok"}'
274+
assert response_message.version == LocalProtocolVersion.L01.value.encode()
275+
276+
unsub()
277+
278+
assert snapshot == log

0 commit comments

Comments
 (0)