From e50a1ba8df71388387d94dbec1f872ae94bd9106 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joris=20Pelgr=C3=B6m?= Date: Wed, 6 Aug 2025 21:55:18 +0200 Subject: [PATCH] Fix requires_feature decorator breaking named arguments --- letpot/deviceclient.py | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/letpot/deviceclient.py b/letpot/deviceclient.py index 2bcf3f3..c113c42 100644 --- a/letpot/deviceclient.py +++ b/letpot/deviceclient.py @@ -10,7 +10,7 @@ from datetime import time from functools import wraps from hashlib import md5, sha256 -from typing import Any, Callable, Concatenate +from typing import Any, Callable, ParamSpec, TypeVar, cast import aiomqtt @@ -32,32 +32,26 @@ _LOGGER = logging.getLogger(__name__) +T = TypeVar("T", bound="LetPotDeviceClient") +_R = TypeVar("_R") +P = ParamSpec("P") -def _create_ssl_context() -> ssl.SSLContext: - """Create a SSL context for the MQTT connection, avoids a blocking call later.""" - context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - context.load_default_certs() - return context - -_SSL_CONTEXT = _create_ssl_context() - - -def requires_feature[T: "LetPotDeviceClient", _R, **P]( +def requires_feature( *required_feature: DeviceFeature, ) -> Callable[ - [Callable[Concatenate[T, str, P], Coroutine[Any, Any, _R]]], - Callable[Concatenate[T, str, P], Coroutine[Any, Any, _R]], + [Callable[P, Coroutine[Any, Any, _R]]], + Callable[P, Coroutine[Any, Any, _R]], ]: """Decorate the function to require device type support for a specific feature (inferred from serial).""" def decorator( - func: Callable[Concatenate[T, str, P], Coroutine[Any, Any, _R]], - ) -> Callable[Concatenate[T, str, P], Coroutine[Any, Any, _R]]: + func: Callable[P, Coroutine[Any, Any, _R]], + ) -> Callable[P, Coroutine[Any, Any, _R]]: @wraps(func) - async def wrapper( - self: T, serial: str, *args: P.args, **kwargs: P.kwargs - ) -> _R: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> _R: + self = cast(LetPotDeviceClient, args[0]) + serial = cast(str, args[1]) exception_message = f"Device missing required feature: {required_feature}" try: supported_features = self._converter(serial).supported_features() @@ -67,13 +61,23 @@ async def wrapper( raise LetPotFeatureException(exception_message) except LetPotException: raise LetPotFeatureException(exception_message) - return await func(self, serial, *args, **kwargs) + return await func(*args, **kwargs) return wrapper return decorator +def _create_ssl_context() -> ssl.SSLContext: + """Create a SSL context for the MQTT connection, avoids a blocking call later.""" + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.load_default_certs() + return context + + +_SSL_CONTEXT = _create_ssl_context() + + class LetPotDeviceClient: """Client for connecting to LetPot device."""