|
5 | 5 |
|
6 | 6 | import logging |
7 | 7 | from abc import ABC, abstractmethod |
| 8 | +from collections.abc import Callable |
8 | 9 | from dataclasses import fields |
9 | | -from typing import ClassVar |
| 10 | +from typing import Any, ClassVar |
10 | 11 |
|
| 12 | +from roborock.callbacks import CallbackList |
11 | 13 | from roborock.data import RoborockBase |
12 | 14 | from roborock.protocols.v1_protocol import V1RpcChannel |
| 15 | +from roborock.roborock_message import RoborockDataProtocol |
13 | 16 | from roborock.roborock_typing import RoborockCommand |
14 | 17 |
|
15 | 18 | _LOGGER = logging.getLogger(__name__) |
@@ -173,3 +176,74 @@ def wrapper(*args, **kwargs): |
173 | 176 |
|
174 | 177 | cls.map_rpc_channel = True # type: ignore[attr-defined] |
175 | 178 | return wrapper |
| 179 | + |
| 180 | + |
| 181 | +# TODO(allenporter): Merge with roborock.devices.traits.b01.q10.common.TraitUpdateListener |
| 182 | +class TraitUpdateListener(ABC): |
| 183 | + """Trait update listener. |
| 184 | +
|
| 185 | + This is a base class for traits to support notifying listeners when they |
| 186 | + have been updated. Clients may register callbacks to be notified when the |
| 187 | + trait has been updated. When the listener callback is invoked, the client |
| 188 | + should read the trait's properties to get the updated values. |
| 189 | + """ |
| 190 | + |
| 191 | + def __init__(self, logger: logging.Logger) -> None: |
| 192 | + """Initialize the trait update listener.""" |
| 193 | + self._update_callbacks: CallbackList[None] = CallbackList(logger=logger) |
| 194 | + |
| 195 | + def add_update_listener(self, callback: Callable[[], None]) -> Callable[[], None]: |
| 196 | + """Register a callback when the trait has been updated. |
| 197 | +
|
| 198 | + Returns a callable to remove the listener. |
| 199 | + """ |
| 200 | + # We wrap the callback to ignore the value passed to it. |
| 201 | + return self._update_callbacks.add_callback(lambda _: callback()) |
| 202 | + |
| 203 | + def _notify_update(self) -> None: |
| 204 | + """Notify all update listeners.""" |
| 205 | + self._update_callbacks(None) |
| 206 | + |
| 207 | + |
| 208 | +class DpsDataConverter: |
| 209 | + """Utility to handle the transformation and merging of DPS data into models. |
| 210 | +
|
| 211 | + This class pre-calculates the mapping between Data Point IDs and dataclass fields |
| 212 | + to optimize repeated updates from device streams. |
| 213 | + """ |
| 214 | + |
| 215 | + def __init__(self, dps_type_map: dict[RoborockDataProtocol, type], dps_field_map: dict[RoborockDataProtocol, str]): |
| 216 | + """Initialize the converter for a specific RoborockBase-derived class.""" |
| 217 | + self._dps_type_map = dps_type_map |
| 218 | + self._dps_field_map = dps_field_map |
| 219 | + |
| 220 | + @classmethod |
| 221 | + def from_dataclass(cls, dataclass_type: type[RoborockBase]): |
| 222 | + """Initialize the converter for a specific RoborockBase-derived class.""" |
| 223 | + dps_type_map: dict[RoborockDataProtocol, type] = {} |
| 224 | + dps_field_map: dict[RoborockDataProtocol, str] = {} |
| 225 | + for field_obj in fields(dataclass_type): |
| 226 | + if field_obj.metadata and "dps" in field_obj.metadata: |
| 227 | + dps_id = field_obj.metadata["dps"] |
| 228 | + dps_type_map[dps_id] = field_obj.type |
| 229 | + dps_field_map[dps_id] = field_obj.name |
| 230 | + return cls(dps_type_map, dps_field_map) |
| 231 | + |
| 232 | + def update_from_dps(self, target: RoborockBase, decoded_dps: dict[RoborockDataProtocol, Any]) -> bool: |
| 233 | + """Convert and merge raw DPS data into the target object. |
| 234 | +
|
| 235 | + Uses the pre-calculated type mapping to ensure values are converted to the |
| 236 | + correct Python types before being updated on the target. |
| 237 | +
|
| 238 | + Args: |
| 239 | + target: The target object to update. |
| 240 | + decoded_dps: The decoded DPS data to convert. |
| 241 | +
|
| 242 | + Returns: |
| 243 | + True if any values were updated, False otherwise. |
| 244 | + """ |
| 245 | + conversions = RoborockBase.convert_dict(self._dps_type_map, decoded_dps) |
| 246 | + for dps_id, value in conversions.items(): |
| 247 | + field_name = self._dps_field_map[dps_id] |
| 248 | + setattr(target, field_name, value) |
| 249 | + return bool(conversions) |
0 commit comments