From c06b5c1eae89cc7c93c50e169d0fd8d90fb7c407 Mon Sep 17 00:00:00 2001 From: Rob Berwick Date: Sat, 15 Feb 2025 11:07:59 +0000 Subject: [PATCH] feat: implement backend connection checks and add no_backend_required decorator --- src/blinkstick/clients/blinkstick.py | 17 ++++++++++++++++- src/blinkstick/decorators.py | 15 +++++++++++++++ src/blinkstick/exceptions.py | 7 +++++++ tests/clients/test_blinkstick.py | 18 ++++++++++++++++++ 4 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 src/blinkstick/decorators.py diff --git a/src/blinkstick/clients/blinkstick.py b/src/blinkstick/clients/blinkstick.py index 14fedf5..ca3d4e8 100644 --- a/src/blinkstick/clients/blinkstick.py +++ b/src/blinkstick/clients/blinkstick.py @@ -12,8 +12,10 @@ remap_rgb_value_reverse, ColorFormat, ) -from blinkstick.enums import BlinkStickVariant +from blinkstick.decorators import no_backend_required from blinkstick.devices import BlinkStickDevice +from blinkstick.enums import BlinkStickVariant +from blinkstick.exceptions import NotConnected from blinkstick.utilities import string_to_info_block_data if sys.platform == "win32": @@ -63,6 +65,19 @@ def __init__( self.backend = USBBackend(device) self.bs_serial = self.get_serial() + def __getattribute__(self, name): + """Default all callables to require a backend unless they have the no_backend_required attribute""" + attr = object.__getattribute__(self, name) + if callable(attr) and not getattr(attr, "no_backend_required", False): + + def wrapper(*args, **kwargs): + if self.backend is None: + raise NotConnected("No backend set") + return attr(*args, **kwargs) + + return wrapper + return attr + def get_serial(self) -> str: """ Returns the serial number of backend.:: diff --git a/src/blinkstick/decorators.py b/src/blinkstick/decorators.py new file mode 100644 index 0000000..beefd8e --- /dev/null +++ b/src/blinkstick/decorators.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from functools import wraps + + +def no_backend_required(func): + """no-op decorator to mark a function as requiring a backend. See BlinkStick.__getattribute__ for usage.""" + + func.no_backend_required = True + + @wraps(func) + def wrapper(self, *args, **kwargs): + return func(self, *args, **kwargs) + + return wrapper diff --git a/src/blinkstick/exceptions.py b/src/blinkstick/exceptions.py index e1f04c6..86f6f51 100644 --- a/src/blinkstick/exceptions.py +++ b/src/blinkstick/exceptions.py @@ -1,2 +1,9 @@ +from __future__ import annotations + + class BlinkStickException(Exception): pass + + +class NotConnected(BlinkStickException): + pass diff --git a/tests/clients/test_blinkstick.py b/tests/clients/test_blinkstick.py index eb7d007..6ac420f 100644 --- a/tests/clients/test_blinkstick.py +++ b/tests/clients/test_blinkstick.py @@ -7,14 +7,32 @@ from blinkstick.clients.blinkstick import BlinkStick from pytest_mock import MockFixture +from blinkstick.exceptions import NotConnected from tests.conftest import make_blinkstick def test_instantiate(): + """Test that we can instantiate a BlinkStick object.""" bs = BlinkStick() assert bs is not None +def test_all_methods_require_backend(make_blinkstick): + """Test that all methods require a backend.""" + bs = make_blinkstick() + bs.backend = None # noqa + + class_methods = ( + method + for method in dir(BlinkStick) + if callable(getattr(bs, method)) and not method.startswith("__") + ) + for method_name in class_methods: + method = getattr(bs, method_name) + with pytest.raises(NotConnected): + method() + + @pytest.mark.parametrize( "serial, version_attribute, expected_variant, expected_variant_value", [