From c478c8bc618c0058d78be8ad44547db43d0a78a5 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 21 Jun 2026 13:10:31 -0400 Subject: [PATCH 1/3] Properly deal with backend kwargs --- serialx/__init__.py | 4 ++ serialx/async_serial.py | 11 ++-- serialx/common.py | 122 ++++++++++++++++++++++++++++++++++++---- tests/test_common.py | 85 ++++++++++++++++++++++++++++ 4 files changed, 208 insertions(+), 14 deletions(-) diff --git a/serialx/__init__.py b/serialx/__init__.py index 5d37c1c..d6d8bcf 100644 --- a/serialx/__init__.py +++ b/serialx/__init__.py @@ -8,8 +8,10 @@ open_serial_connection, ) from .common import ( + AllConnectKwargs, BaseSerial, BaseSerialTransport, + ConnectKwargs, ModemPins, Parity, PinState, @@ -57,8 +59,10 @@ "Parity", "PinState", "Platform", + "AllConnectKwargs", "BaseSerial", "BaseSerialTransport", + "ConnectKwargs", "Serial", "SerialException", "UnsupportedSetting", diff --git a/serialx/async_serial.py b/serialx/async_serial.py index f284aa8..d8d84c3 100644 --- a/serialx/async_serial.py +++ b/serialx/async_serial.py @@ -7,9 +7,10 @@ import logging from typing import Any, Generic, TypeVar, cast -from typing_extensions import Self +from typing_extensions import Self, Unpack from .common import ( + AllConnectKwargs, BaseSerialTransport, ModemPins, Parity, @@ -17,6 +18,7 @@ SerialException, StopBits, get_uri_handler, + route_backend_kwargs, ) LOGGER = logging.getLogger(__name__) @@ -41,7 +43,7 @@ def __init__( url: str | None, *, transport_cls: type[BaseSerialTransport] | None = None, - **kwargs: Any, + **kwargs: Unpack[AllConnectKwargs], ) -> None: """Initialize an unopened serial port. @@ -50,7 +52,7 @@ def __init__( """ self._url = url - self._connect_kwargs: dict[str, Any] = kwargs + self._connect_kwargs: dict[str, Any] = dict(kwargs) self._transport_cls = transport_cls self._reader: asyncio.StreamReader | None = None @@ -282,6 +284,7 @@ async def create_serial_connection( None, get_uri_handler, url ) resolved_cls = handler.async_transport_cls + kwargs = route_backend_kwargs(handler, kwargs) # pylint: disable=serialx-reassigned-parameter protocol = protocol_factory() transport = resolved_cls(loop=loop, protocol=protocol) @@ -322,7 +325,7 @@ def async_serial_for_url( url: str | None, *, transport_cls: type[BaseSerialTransport] | None = None, - **kwargs: Any, + **kwargs: Unpack[AllConnectKwargs], ) -> AsyncSerial: """Build an unopened AsyncSerial. Use `async with` or `await serial.open()`.""" return AsyncSerial(url, transport_cls=transport_cls, **kwargs) diff --git a/serialx/common.py b/serialx/common.py index d866ed3..d280e91 100644 --- a/serialx/common.py +++ b/serialx/common.py @@ -13,6 +13,7 @@ from enum import Enum import functools import io +import logging import os.path from pathlib import Path import time @@ -23,6 +24,8 @@ from typing_extensions import Buffer, Self, TypedDict, Unpack +LOGGER = logging.getLogger(__name__) + class Platform(str, Enum): """Built-in platform name.""" @@ -52,6 +55,7 @@ class RegisteredUriHandler: list_serial_ports_func: Callable[..., list[SerialPortInfo]] async_list_serial_ports_func: Callable[..., Awaitable[list[SerialPortInfo]]] strip_uri_scheme: bool + connect_kwargs: frozenset[str] = frozenset() class _RegistryEntry(NamedTuple): @@ -87,6 +91,7 @@ def register_uri_handler( ] = async_empty_port_list, weight: int = 1, strip_uri_scheme: bool = False, + connect_kwargs: frozenset[str] = frozenset(), ) -> Callable[[], None]: """Register a URI handler. @@ -110,6 +115,7 @@ def register_uri_handler( strip_uri_scheme: If ``True``, the leading ``scheme`` / ``unique_scheme`` is removed before the URL is passed to the sync class. Set this when the underlying class expects a bare device path rather than a URL. + connect_kwargs: Names of backend-specific connect kwargs this handler accepts. Returns: A callable that unregisters the handler. @@ -140,6 +146,7 @@ def register_uri_handler( list_serial_ports_func=list_serial_ports_func, async_list_serial_ports_func=async_list_serial_ports_func, strip_uri_scheme=strip_uri_scheme, + connect_kwargs=connect_kwargs, ), ) bisect.insort_right(_REGISTERED_URI_HANDLERS[scheme], item) @@ -165,6 +172,36 @@ def get_uri_handler(uri: str) -> RegisteredUriHandler: return handlers[-1].handler +def route_backend_kwargs( + handler: RegisteredUriHandler, kwargs: dict[str, Any] +) -> dict[str, Any]: + """Drop kwargs that belong to a different backend before dispatch.""" + all_backend_specific_kwargs: set[str] = set() + + for extras in BACKEND_CONNECT_KWARGS.values(): + all_backend_specific_kwargs |= extras + + for entries in _REGISTERED_URI_HANDLERS.values(): + for entry in entries: + all_backend_specific_kwargs |= entry.handler.connect_kwargs + + backend_specific_kwargs = ( + BACKEND_CONNECT_KWARGS.get(handler.unique_scheme, set()) + | handler.connect_kwargs + ) + + other_kwargs = all_backend_specific_kwargs - backend_specific_kwargs + dropped = other_kwargs & kwargs.keys() + + if not dropped: + return dict(kwargs) + + LOGGER.debug( + "Ignoring kwarg not accepted by %r backend: %s", handler.unique_scheme, dropped + ) + return {key: value for key, value in kwargs.items() if key not in dropped} + + class SerialException(Exception): """Base serial exception.""" @@ -195,18 +232,78 @@ class Parity(str, Enum): SPACE = "S" -class ConnectKwargs( # type: ignore[call-arg] # PEP 728 not in mypy yet - TypedDict, total=False, extra_items=Any -): - """Kwargs forwarded to BaseSerialTransport.connect / _connect.""" +class _CommonConnectKwargs(TypedDict, total=False): + """Connect kwargs accepted by every backend (see `BaseSerial.__init__`).""" baudrate: int - parity: Parity - stopbits: StopBits + parity: Parity | str | None + stopbits: StopBits | int | float xonxoff: bool rtscts: bool - exclusive: bool + dsrdtr: bool byte_size: int + read_timeout: float | None + write_timeout: float | None + rtsdtr_on_open: PinState + rtsdtr_on_close: PinState + exclusive: bool + + # pyserial compatibility kwargs + port: str | None + timeout: float | None + bytesize: int | None + writeTimeout: float | None + do_not_open: bool | None + + +class ConnectKwargs( # type: ignore[call-arg] # PEP 728 not in mypy yet + _CommonConnectKwargs, total=False, extra_items=Any +): + """Connect kwargs plumbed internally to `BaseSerialTransport.connect`.""" + + +class AllConnectKwargs(_CommonConnectKwargs, total=False): + """Every connect kwarg any built-in backend accepts, for typing.""" + + # linux:// + low_latency: bool + + # socket:// + tcp:// + rfc2217:// + esphome:// + connect_timeout: float | None + + # rfc2217:// + receive_buffer_size: int + + # windows:// + read_buffer_size: int + write_buffer_size: int + + # esphome:// + port_name: str | None + port_instance: int | None + key: str | None + password: str | None + noise_psk: str | None + + +# Backend-specific connect kwargs per unique URI scheme +BACKEND_CONNECT_KWARGS: dict[str, frozenset[str]] = { + "linux://": frozenset({"low_latency"}), + "windows://": frozenset({"read_buffer_size", "write_buffer_size"}), + "socket://": frozenset({"connect_timeout"}), + "tcp://": frozenset({"connect_timeout"}), + "rfc2217://": frozenset({"connect_timeout", "receive_buffer_size"}), + "esphome://": frozenset( + { + "connect_timeout", + "port_name", + "port_instance", + "key", + "password", + "noise_psk", + } + ), +} class PinState(Enum): @@ -397,7 +494,9 @@ def _check_broken(self) -> None: raise self._broken @classmethod - def from_url(cls, url: str, *args: Any, **kwargs: Any) -> BaseSerial: + def from_url( + cls, url: str, *args: Any, **kwargs: Unpack[AllConnectKwargs] + ) -> BaseSerial: """Create the appropriate serial port subclass for the given URL.""" handler = get_uri_handler(url) target = url @@ -405,7 +504,8 @@ def from_url(cls, url: str, *args: Any, **kwargs: Any) -> BaseSerial: target = url.removeprefix(handler.scheme).removeprefix( handler.unique_scheme ) - return handler.sync_cls(target, *args, **kwargs) + routed = route_backend_kwargs(handler, dict(kwargs)) + return handler.sync_cls(target, *args, **routed) @maybe_wrap_exceptions def open(self) -> None: @@ -1165,6 +1265,8 @@ async def async_list_serial_ports( return await handler.async_list_serial_ports_func(**kwargs) -def serial_for_url(url: str, *args: Any, **kwargs: Any) -> BaseSerial: +def serial_for_url( + url: str, *args: Any, **kwargs: Unpack[AllConnectKwargs] +) -> BaseSerial: """Create the appropriate serial port subclass for the given URL.""" return BaseSerial.from_url(url, *args, **kwargs) diff --git a/tests/test_common.py b/tests/test_common.py index b3d356b..67841b3 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -3,6 +3,8 @@ from __future__ import annotations from collections.abc import Generator +import logging +from typing import Any from unittest.mock import Mock import pytest @@ -16,11 +18,14 @@ ) from serialx.common import ( _REGISTERED_URI_HANDLERS, + BACKEND_CONNECT_KWARGS, + AllConnectKwargs, BaseSerial, BaseSerialTransport, SerialPortInfo, UnknownUriScheme, get_uri_handler, + route_backend_kwargs, ) @@ -128,6 +133,86 @@ def test_register_uri_handler_dispatch_and_unregister() -> None: get_uri_handler("test-shared-2://") +def test_backend_connect_kwargs_match_typed_dict() -> None: + """Every kwarg in the runtime table is also declared on `AllConnectKwargs`.""" + declared: set[str] = set() + + for extras in BACKEND_CONNECT_KWARGS.values(): + declared |= extras + + assert declared <= set(AllConnectKwargs.__optional_keys__) + + +class _StubSerial(BaseSerial): + """A concrete `BaseSerial` whose abstract methods are inert stubs.""" + + def _open(self) -> None: ... + def _close(self) -> None: ... + def _configure_port(self) -> None: ... + def _flush(self) -> None: ... + def _readinto(self, buf: Any) -> int: ... # type:ignore[empty-body, override] + def _write(self, data: Any) -> int: ... # type:ignore[empty-body, override] + def _reset_read_buffer(self) -> None: ... + def _reset_write_buffer(self) -> None: ... + def _get_modem_pins(self) -> Any: ... + def _set_modem_pins(self, modem_pins: Any) -> None: ... + + @property + def is_open(self) -> bool: ... # type:ignore[empty-body] + + @property + def num_unread_bytes(self) -> int: ... # type:ignore[empty-body, override] + + @property + def num_unwritten_bytes(self) -> int: ... # type:ignore[empty-body, override] + + +def test_serial_kwarg_forwarding(caplog: pytest.LogCaptureFixture) -> None: + """.""" + + class TestSerialBackend(_StubSerial): + def __init__(self, *args: Any, test_opt: int = 1, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + unregister = register_uri_handler( + scheme="test-backend://", + unique_scheme="test-backend://", + sync_cls=TestSerialBackend, + async_transport_cls=BaseSerialTransport, # type: ignore[type-abstract] + connect_kwargs=frozenset({"test_opt"}), + ) + + try: + handler = get_uri_handler("test-backend://") + + # The chosen backend's own kwarg and common kwargs are forwarded + assert route_backend_kwargs(handler, {"test_opt": 5, "baudrate": 115200}) == { + "test_opt": 5, + "baudrate": 115200, + } + + # A built-in backend-specific kwarg is dropped (not raised) for other backends + with caplog.at_level(logging.DEBUG, logger="serialx.common"): + assert route_backend_kwargs(handler, {"low_latency": False}) == {} + + assert "low_latency" in caplog.text + + # A typo isn't backend-specific, so it's forwarded for the backend to reject + assert route_backend_kwargs(handler, {"test_opt": 1}) == {"test_opt": 1} + + # `from_url` drops other-backend kwargs; a real typo raises from the backend + instance = BaseSerial.from_url( + "test-backend://dev", low_latency=False, baudrate=4800 + ) + assert isinstance(instance, TestSerialBackend) + assert instance.baudrate == 4800 + + with pytest.raises(TypeError, match="nonsense"): + BaseSerial.from_url("test-backend://dev", nonsense=1) # type: ignore[call-arg] + finally: + unregister() + + @pytest.mark.parametrize( ("port", "expected"), [ From 21fec5ad6bf2eaa223029f0ff223218d61bc2d33 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 21 Jun 2026 20:43:07 -0400 Subject: [PATCH 2/3] Fix typing --- serialx/common.py | 3 +++ serialx_compat/serial/__init__.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/serialx/common.py b/serialx/common.py index d280e91..de30f9e 100644 --- a/serialx/common.py +++ b/serialx/common.py @@ -22,6 +22,7 @@ import urllib.parse import warnings +from aioesphomeapi.client import APIClient from typing_extensions import Buffer, Self, TypedDict, Unpack LOGGER = logging.getLogger(__name__) @@ -279,6 +280,7 @@ class AllConnectKwargs(_CommonConnectKwargs, total=False): write_buffer_size: int # esphome:// + api: APIClient | None port_name: str | None port_instance: int | None key: str | None @@ -295,6 +297,7 @@ class AllConnectKwargs(_CommonConnectKwargs, total=False): "rfc2217://": frozenset({"connect_timeout", "receive_buffer_size"}), "esphome://": frozenset( { + "api", "connect_timeout", "port_name", "port_instance", diff --git a/serialx_compat/serial/__init__.py b/serialx_compat/serial/__init__.py index f8bb302..d1198ae 100644 --- a/serialx_compat/serial/__init__.py +++ b/serialx_compat/serial/__init__.py @@ -90,7 +90,7 @@ def __init__( @classmethod def from_url(cls, url: str, *args: Any, **kwargs: Any) -> BaseSerial: """Create the appropriate serial port subclass for the given URL.""" - return super().from_url(url, *args, _wrap_exceptions=True, **kwargs) + return super().from_url(url, *args, _wrap_exceptions=True, **kwargs) # type:ignore[call-arg] Serial = CompatSerial From faa1afd0137ad2a921e270393b1696dc4bd44b6f Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 21 Jun 2026 20:45:38 -0400 Subject: [PATCH 3/3] Clean up 3.10 typing --- serialx/common.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/serialx/common.py b/serialx/common.py index de30f9e..3ed0d93 100644 --- a/serialx/common.py +++ b/serialx/common.py @@ -18,13 +18,15 @@ from pathlib import Path import time from types import TracebackType -from typing import Any, Concatenate, NamedTuple, ParamSpec, TypeVar, cast +from typing import TYPE_CHECKING, Any, Concatenate, NamedTuple, ParamSpec, TypeVar, cast import urllib.parse import warnings -from aioesphomeapi.client import APIClient from typing_extensions import Buffer, Self, TypedDict, Unpack +if TYPE_CHECKING: + from aioesphomeapi.client import APIClient + LOGGER = logging.getLogger(__name__)