11from __future__ import annotations
22
33import asyncio
4+ import hashlib
45import json
56import logging
67from asyncio import BaseTransport , Lock
78
89from construct import ( # type: ignore
910 Bytes ,
1011 Checksum ,
12+ GreedyBytes ,
1113 Int16ub ,
1214 Int32ub ,
15+ Prefixed ,
1316 RawCopy ,
1417 Struct ,
1518)
19+ from Crypto .Cipher import AES
1620
21+ from roborock import RoborockException
1722from roborock .containers import BroadcastMessage
1823from roborock .protocol import EncryptionAdapter , Utils , _Parser
1924
@@ -29,14 +34,41 @@ def __init__(self, timeout: int = 5):
2934 self .devices_found : list [BroadcastMessage ] = []
3035 self ._mutex = Lock ()
3136
32- def datagram_received (self , data , _ ):
33- [broadcast_message ], _ = BroadcastParser .parse (data )
34- if broadcast_message .payload :
35- parsed_message = BroadcastMessage .from_dict (json .loads (broadcast_message .payload ))
36- _LOGGER .debug (f"Received broadcast: { parsed_message } " )
37- self .devices_found .append (parsed_message )
37+ def datagram_received (self , data : bytes , _ ):
38+ """Handle incoming broadcast datagrams."""
39+ try :
40+ version = data [:3 ]
41+ if version == b"L01" :
42+ [parsed_msg ], _ = L01Parser .parse (data )
43+ encrypted_payload = parsed_msg .payload
44+ if encrypted_payload is None :
45+ raise RoborockException ("No encrypted payload found in broadcast message" )
46+ ciphertext = encrypted_payload [:- 16 ]
47+ tag = encrypted_payload [- 16 :]
3848
39- async def discover (self ):
49+ key = hashlib .sha256 (BROADCAST_TOKEN ).digest ()
50+ iv_digest_input = data [:9 ]
51+ digest = hashlib .sha256 (iv_digest_input ).digest ()
52+ iv = digest [:12 ]
53+
54+ cipher = AES .new (key , AES .MODE_GCM , nonce = iv )
55+ decrypted_payload_bytes = cipher .decrypt_and_verify (ciphertext , tag )
56+ json_payload = json .loads (decrypted_payload_bytes )
57+ parsed_message = BroadcastMessage (duid = json_payload ["duid" ], ip = json_payload ["ip" ], version = version )
58+ _LOGGER .debug (f"Received L01 broadcast: { parsed_message } " )
59+ self .devices_found .append (parsed_message )
60+ else :
61+ # Fallback to the original protocol parser for other versions
62+ [broadcast_message ], _ = BroadcastParser .parse (data )
63+ if broadcast_message .payload :
64+ json_payload = json .loads (broadcast_message .payload )
65+ parsed_message = BroadcastMessage (duid = json_payload ["duid" ], ip = json_payload ["ip" ], version = version )
66+ _LOGGER .debug (f"Received broadcast: { parsed_message } " )
67+ self .devices_found .append (parsed_message )
68+ except Exception as e :
69+ _LOGGER .warning (f"Failed to decode message: { data !r} . Error: { e } " )
70+
71+ async def discover (self ) -> list [BroadcastMessage ]:
4072 async with self ._mutex :
4173 try :
4274 loop = asyncio .get_event_loop ()
@@ -64,5 +96,19 @@ def close(self):
6496 "checksum" / Checksum (Int32ub , Utils .crc , lambda ctx : ctx .message .data ),
6597)
6698
99+ _L01BroadcastMessage = Struct (
100+ "message"
101+ / RawCopy (
102+ Struct (
103+ "version" / Bytes (3 ),
104+ "field1" / Bytes (4 ), # Unknown field
105+ "field2" / Bytes (2 ), # Unknown field
106+ "payload" / Prefixed (Int16ub , GreedyBytes ), # Encrypted payload with length prefix
107+ )
108+ ),
109+ "checksum" / Checksum (Int32ub , Utils .crc , lambda ctx : ctx .message .data ),
110+ )
111+
67112
68113BroadcastParser : _Parser = _Parser (_BroadcastMessage , False )
114+ L01Parser : _Parser = _Parser (_L01BroadcastMessage , False )
0 commit comments