Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions letpot/deviceclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from datetime import time
from functools import wraps
from hashlib import md5, sha256
from typing import Any, Callable, Concatenate
from typing import Any, Callable, ParamSpec, TypeVar, cast

import aiomqtt

Expand All @@ -32,32 +32,26 @@

_LOGGER = logging.getLogger(__name__)

T = TypeVar("T", bound="LetPotDeviceClient")
_R = TypeVar("_R")
P = ParamSpec("P")

def _create_ssl_context() -> ssl.SSLContext:
"""Create a SSL context for the MQTT connection, avoids a blocking call later."""
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.load_default_certs()
return context


_SSL_CONTEXT = _create_ssl_context()


def requires_feature[T: "LetPotDeviceClient", _R, **P](
def requires_feature(
*required_feature: DeviceFeature,
) -> Callable[
[Callable[Concatenate[T, str, P], Coroutine[Any, Any, _R]]],
Callable[Concatenate[T, str, P], Coroutine[Any, Any, _R]],
[Callable[P, Coroutine[Any, Any, _R]]],
Callable[P, Coroutine[Any, Any, _R]],
]:
"""Decorate the function to require device type support for a specific feature (inferred from serial)."""

def decorator(
func: Callable[Concatenate[T, str, P], Coroutine[Any, Any, _R]],
) -> Callable[Concatenate[T, str, P], Coroutine[Any, Any, _R]]:
func: Callable[P, Coroutine[Any, Any, _R]],
) -> Callable[P, Coroutine[Any, Any, _R]]:
@wraps(func)
async def wrapper(
self: T, serial: str, *args: P.args, **kwargs: P.kwargs
) -> _R:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> _R:
self = cast(LetPotDeviceClient, args[0])
serial = cast(str, args[1])
exception_message = f"Device missing required feature: {required_feature}"
try:
supported_features = self._converter(serial).supported_features()
Expand All @@ -67,13 +61,23 @@ async def wrapper(
raise LetPotFeatureException(exception_message)
except LetPotException:
raise LetPotFeatureException(exception_message)
return await func(self, serial, *args, **kwargs)
return await func(*args, **kwargs)

return wrapper

return decorator


def _create_ssl_context() -> ssl.SSLContext:
"""Create a SSL context for the MQTT connection, avoids a blocking call later."""
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.load_default_certs()
return context


_SSL_CONTEXT = _create_ssl_context()


class LetPotDeviceClient:
"""Client for connecting to LetPot device."""

Expand Down