Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions roborock/cloud_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .api import KEEPALIVE, RoborockClient
from .containers import DeviceData, UserData
from .exceptions import RoborockException, VacuumError
from .protocol import MessageParser, md5hex
from .protocol import Decoder, Encoder, create_mqtt_decoder, create_mqtt_encoder, md5hex
from .roborock_future import RoborockFuture

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -74,6 +74,8 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
self._mqtt_client.username_pw_set(self._hashed_user, self._hashed_password)
self._waiting_queue: dict[int, RoborockFuture] = {}
self._mutex = Lock()
self._decoder: Decoder = create_mqtt_decoder(device_info.device.local_key)
self._encoder: Encoder = create_mqtt_encoder(device_info.device.local_key)

def _mqtt_on_connect(self, *args, **kwargs):
_, __, ___, rc, ____ = args
Expand Down Expand Up @@ -102,7 +104,7 @@ def _mqtt_on_connect(self, *args, **kwargs):
def _mqtt_on_message(self, *args, **kwargs):
client, __, msg = args
try:
messages, _ = MessageParser.parse(msg.payload, local_key=self.device_info.device.local_key)
messages = self._decoder(msg.payload)
super().on_message_received(messages)
except Exception as ex:
self._logger.exception(ex)
Expand Down
12 changes: 5 additions & 7 deletions roborock/local_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from . import DeviceData
from .api import RoborockClient
from .exceptions import RoborockConnectionException, RoborockException
from .protocol import MessageParser
from .protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
from .roborock_message import RoborockMessage, RoborockMessageProtocol

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,20 +44,18 @@ def __init__(self, device_data: DeviceData):
self.host = device_data.host
self._batch_structs: list[RoborockMessage] = []
self._executing = False
self.remaining = b""
self.transport: Transport | None = None
self._mutex = Lock()
self.keep_alive_task: TimerHandle | None = None
RoborockClient.__init__(self, device_data)
self._local_protocol = _LocalProtocol(self._data_received, self._connection_lost)
self._encoder: Encoder = create_local_encoder(device_data.device.local_key)
self._decoder: Decoder = create_local_decoder(device_data.device.local_key)

def _data_received(self, message):
"""Called when data is received from the transport."""
if self.remaining:
message = self.remaining + message
self.remaining = b""
parser_msg, self.remaining = MessageParser.parse(message, local_key=self.device_info.device.local_key)
self.on_message_received(parser_msg)
parsed_msg = self._decoder(message)
self.on_message_received(parsed_msg)

def _connection_lost(self, exc: Exception | None):
"""Called when the transport connection is lost."""
Expand Down
53 changes: 53 additions & 0 deletions roborock/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,56 @@ def build(

MessageParser: _Parser = _Parser(_Messages, True)
BroadcastParser: _Parser = _Parser(_BroadcastMessage, False)


Decoder = Callable[[bytes], list[RoborockMessage]]
Encoder = Callable[[RoborockMessage], bytes]


def create_mqtt_decoder(local_key: str) -> Decoder:
"""Create a decoder for MQTT messages."""

def decode(data: bytes) -> list[RoborockMessage]:
"""Parse the given data into Roborock messages."""
messages, _ = MessageParser.parse(data, local_key)
return messages

return decode


def create_mqtt_encoder(local_key: str) -> Encoder:
"""Create an encoder for MQTT messages."""

def encode(messages: RoborockMessage) -> bytes:
"""Build the given Roborock messages into a byte string."""
return MessageParser.build(messages, local_key, prefixed=False)

return encode


def create_local_decoder(local_key: str) -> Decoder:
"""Create a decoder for local API messages."""

# This buffer is used to accumulate bytes until a complete message can be parsed.
# It is defined outside the decode function to maintain state across calls.
buffer: bytes = b""

def decode(bytes: bytes) -> list[RoborockMessage]:
"""Parse the given data into Roborock messages."""
nonlocal buffer
buffer += bytes
parsed_messages, remaining = MessageParser.parse(buffer, local_key=local_key)
buffer = remaining
return parsed_messages
Comment thread
allenporter marked this conversation as resolved.

return decode


def create_local_encoder(local_key: str) -> Encoder:
"""Create an encoder for local API messages."""

def encode(message: RoborockMessage) -> bytes:
"""Called when data is sent to the transport."""
return MessageParser.build(message, local_key=local_key)

return encode
4 changes: 1 addition & 3 deletions roborock/version_1_apis/roborock_local_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
from ..exceptions import VacuumError
from ..protocol import MessageParser
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
from ..util import RoborockLoggerAdapter
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
Expand Down Expand Up @@ -57,8 +56,7 @@ async def send_message(self, roborock_message: RoborockMessage):
response_protocol = RoborockMessageProtocol.GENERAL_REQUEST
if request_id is None:
raise RoborockException(f"Failed build message {roborock_message}")
local_key = self.device_info.device.local_key
msg = MessageParser.build(roborock_message, local_key=local_key)
msg = self._encoder(roborock_message)
if method:
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
# Send the command to the Roborock device
Expand Down
6 changes: 2 additions & 4 deletions roborock/version_1_apis/roborock_mqtt_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ..containers import DeviceData, UserData
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
from ..protocol import MessageParser, Utils
from ..protocol import Utils
from ..roborock_message import (
RoborockMessage,
RoborockMessageProtocol,
Expand Down Expand Up @@ -47,9 +47,7 @@ async def send_message(self, roborock_message: RoborockMessage):
response_protocol = (
RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE
)

local_key = self.device_info.device.local_key
msg = MessageParser.build(roborock_message, local_key, False)
msg = self._encoder(roborock_message)
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
async_response = self._async_response(request_id, response_protocol)
self._send_msg_raw(msg)
Expand Down
4 changes: 1 addition & 3 deletions roborock/version_a01_apis/roborock_mqtt_client_a01.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from roborock.cloud_api import RoborockMqttClient
from roborock.containers import DeviceData, RoborockCategory, UserData
from roborock.exceptions import RoborockException
from roborock.protocol import MessageParser
from roborock.roborock_message import (
RoborockDyadDataProtocol,
RoborockMessage,
Expand Down Expand Up @@ -43,8 +42,7 @@ async def send_message(self, roborock_message: RoborockMessage):
await self.validate_connection()
response_protocol = RoborockMessageProtocol.RPC_RESPONSE

local_key = self.device_info.device.local_key
m = MessageParser.build(roborock_message, local_key, prefixed=False)
m = self._encoder(roborock_message)
# self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
payload = json.loads(unpad(roborock_message.payload, AES.block_size))
futures = []
Expand Down