Skip to content
113 changes: 113 additions & 0 deletions roborock/mqtt_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import asyncio
import dataclasses
import logging
from collections.abc import Coroutine
from typing import Callable, Self
from urllib.parse import urlparse

import aiomqtt
from aiomqtt import TLSParameters

from roborock import RoborockException, UserData
from roborock.protocol import MessageParser, md5hex

from .containers import DeviceData

LOGGER = logging.getLogger(__name__)


@dataclasses.dataclass
class ClientWrapper:
publish_function: Coroutine[None]
unsubscribe_function: Coroutine[None]
subscribe_function: Coroutine[None]


class RoborockMqttManager:
client_wrappers: dict[str, ClientWrapper] = {}
_instance: Self = None

def __new__(cls) -> RoborockMqttManager:
Comment thread
Lash-L marked this conversation as resolved.
Outdated
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance

async def connect(self, user_data: UserData):
Comment thread
Lash-L marked this conversation as resolved.
Outdated
# Add some kind of lock so we don't try to connect if we are already trying to connect the same account.
if user_data.rriot.u not in self.client_wrappers:
loop = asyncio.get_event_loop()
loop.create_task(self._new_connect(user_data))

async def _new_connect(self, user_data: UserData):
rriot = user_data.rriot
mqtt_user = rriot.u
hashed_user = md5hex(mqtt_user + ":" + rriot.k)[2:10]
url = urlparse(rriot.r.m)
if not isinstance(url.hostname, str):
raise RoborockException("Url parsing returned an invalid hostname")
mqtt_host = str(url.hostname)
mqtt_port = url.port

mqtt_password = rriot.s
hashed_password = md5hex(mqtt_password + ":" + rriot.k)[16:]
LOGGER.debug("Connecting to %s for %s", mqtt_host, mqtt_user)

async with aiomqtt.Client(
hostname=mqtt_host,
port=mqtt_port,
username=hashed_user,
password=hashed_password,
keepalive=60,
tls_params=TLSParameters(),
) as client:
# TODO: Handle logic for when client loses connection
LOGGER.info("Connected to %s for %s", mqtt_host, mqtt_user)
callbacks: dict[str, Callable] = {}
device_map = {}

async def publish(device: DeviceData, payload: bytes):
await client.publish(f"rr/m/i/{mqtt_user}/{hashed_user}/{device.device.duid}", payload=payload)

async def subscribe(device: DeviceData, callback):
LOGGER.debug(f"Subscribing to rr/m/o/{mqtt_user}/{hashed_user}/{device.device.duid}")
await client.subscribe(f"rr/m/o/{mqtt_user}/{hashed_user}/{device.device.duid}")
LOGGER.debug(f"Subscribed to rr/m/o/{mqtt_user}/{hashed_user}/{device.device.duid}")
callbacks[device.device.duid] = callback
device_map[device.device.duid] = device
return

async def unsubscribe(device: DeviceData):
await client.unsubscribe(f"rr/m/o/{mqtt_user}/{hashed_user}/{device.device.duid}")

self.client_wrappers[user_data.rriot.u] = ClientWrapper(
publish_function=publish, unsubscribe_function=unsubscribe, subscribe_function=subscribe
)
async for message in client.messages:
try:
device_id = message.topic.value.split("/")[-1]
device = device_map[device_id]
message = MessageParser.parse(message.payload, device.device.local_key)
callbacks[device_id](message)
except Exception:
...

async def disconnect(self, user_data: UserData):
await self.client_wrappers[user_data.rriot.u].disconnect()

async def subscribe(self, user_data: UserData, device: DeviceData, callback):
if user_data.rriot.u not in self.client_wrappers:
await self.connect(user_data)
# add some kind of lock to make sure we don't subscribe until the connection is successful
await asyncio.sleep(2)
await self.client_wrappers[user_data.rriot.u].subscribe_function(device, callback)

async def unsubscribe(self):
pass

async def publish(self, user_data: UserData, device, payload: bytes):
LOGGER.debug("Publishing topic for %s, Message: %s", device.device.duid, payload)
if user_data.rriot.u not in self.client_wrappers:
await self.connect(user_data)
await self.client_wrappers[user_data.rriot.u].publish_function(device, payload)
117 changes: 117 additions & 0 deletions roborock/roborock_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import base64
import json
import logging
import math
import secrets
import time

from . import RoborockCommand
from .containers import DeviceData, UserData
from .mqtt_manager import RoborockMqttManager
from .protocol import MessageParser, Utils
from .roborock_message import RoborockMessage, RoborockMessageProtocol
from .util import RoborockLoggerAdapter, get_next_int

_LOGGER = logging.getLogger(__name__)


class RoborockDevice:
def __init__(self, user_data: UserData, device_info: DeviceData):
self.user_data = user_data
self.device_info = device_info
self.data = None
self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER)
self._mqtt_endpoint = base64.b64encode(Utils.md5(user_data.rriot.k.encode())[8:14]).decode()
self._local_endpoint = "abc"
self._nonce = secrets.token_bytes(16)
self.manager = RoborockMqttManager()
self.update_commands = self.determine_supported_commands()

def determine_supported_commands(self):
# All devices support these
supported_commands = {
RoborockCommand.GET_CONSUMABLE,
RoborockCommand.GET_STATUS,
RoborockCommand.GET_CLEAN_SUMMARY,
}
# Get what features we can from the feature_set info.

# If a command is not described in feature_set, we should just add it anyways and then let it fail on the first call and remove it.
robot_new_features = int(self.device_info.device.feature_set)
new_feature_info_str = self.device_info.device.new_feature_set
if 33554432 & int(robot_new_features):
supported_commands.add(RoborockCommand.GET_DUST_COLLECTION_MODE)
if 2 & int(new_feature_info_str[-8:], 16):
# TODO: May not be needed as i think this can just be found in Status, but just POC
supported_commands.add(RoborockCommand.APP_GET_CLEAN_ESTIMATE_INFO)
return supported_commands

async def connect(self):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know you have a TODO above, but i think it would be best to just now pass in a connected session and not add this method, or add a TODO remove here i guess.

"""Connect via MQTT and Local if possible."""
await self.manager.subscribe(self.user_data, self.device_info, self.on_message)
await self.update()

async def update(self):
for cmd in self.update_commands:
await self.send_message(method=cmd)

def _get_payload(
self,
method: RoborockCommand | str,
params: list | dict | int | None = None,
secured=False,
use_cloud: bool = False,
):
timestamp = math.floor(time.time())
request_id = get_next_int(10000, 32767)
inner = {
"id": request_id,
"method": method,
"params": params or [],
}
if secured:
inner["security"] = {
"endpoint": self._mqtt_endpoint if use_cloud else self._local_endpoint,
"nonce": self._nonce.hex().lower(),
}
payload = bytes(
json.dumps(
{
"dps": {"101": json.dumps(inner, separators=(",", ":"))},
"t": timestamp,
},
separators=(",", ":"),
).encode()
)
return request_id, timestamp, payload

async def send_message(
self, method: RoborockCommand | str, params: list | dict | int | None = None, use_cloud: bool = True
):
request_id, timestamp, payload = self._get_payload(method, params, True, use_cloud)
request_protocol = RoborockMessageProtocol.RPC_REQUEST
roborock_message = RoborockMessage(timestamp=timestamp, protocol=request_protocol, payload=payload)

local_key = self.device_info.device.local_key
msg = MessageParser.build(roborock_message, local_key, False)
if use_cloud:
await self.manager.publish(self.user_data, self.device_info, msg)
else:
# Handle doing local commands
pass

def on_message(self, message: RoborockMessage):
# If message is command not supported - remove from self.update_commands

# If message is an error - log it?

# If message is 'ok' - ignore it

# If message is anything else - store ids, and map back to id to determine message type.
# Then update self.data

# If we haven't received a message in X seconds, the device is likely offline. I think we can continue the connection,
# but we should have some way to mark ourselves as unavailable.

# This should also probably be split with on_cloud_message and on_local_message.
print(message)
Loading