diff --git a/bec_server/bec_server/scan_server/scans/scan_base.py b/bec_server/bec_server/scan_server/scans/scan_base.py index 453d83d18..42fb00d3e 100644 --- a/bec_server/bec_server/scan_server/scans/scan_base.py +++ b/bec_server/bec_server/scan_server/scans/scan_base.py @@ -8,7 +8,7 @@ import enum import threading from collections.abc import Sequence -from typing import TYPE_CHECKING, Annotated, Type +from typing import Annotated, Callable, Type import numpy as np import pint @@ -201,6 +201,7 @@ def __init__( self._premove_motor_status = None self.positions = np.array([]) self.start_positions = [] + self._scan_original_hooks = self._collect_original_scan_hooks() self._scan_modifier_hooks = ( get_scan_hooks_impl(scan_modifier) if scan_modifier is not None else {} ) @@ -263,3 +264,20 @@ def update_scan_info( setattr(self.scan_info, key, value) else: self.scan_info.additional_scan_parameters[key] = value + + def _collect_original_scan_hooks(self) -> dict[str, Callable]: + """ + Bind the undecorated scan hook implementations to this scan instance. + + Returns: + dict[str, Callable]: Mapping from hook name to the original bound method. + """ + original_hooks = {} + for attr_name in dir(type(self)): + attr = getattr(type(self), attr_name) + hook_info = getattr(attr, "_scan_hook_info", None) + original_func = getattr(attr, "_scan_hook_original", None) + if hook_info is None or original_func is None: + continue + original_hooks[hook_info["method_name"]] = original_func.__get__(self, type(self)) + return original_hooks diff --git a/bec_server/bec_server/scan_server/scans/scan_modifier.py b/bec_server/bec_server/scan_server/scans/scan_modifier.py index 406a9cfb4..f63367af0 100644 --- a/bec_server/bec_server/scan_server/scans/scan_modifier.py +++ b/bec_server/bec_server/scan_server/scans/scan_modifier.py @@ -1,5 +1,6 @@ from __future__ import annotations +from fnmatch import fnmatchcase from functools import wraps from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, get_args @@ -27,6 +28,44 @@ VALID_SCAN_HOOKS = set(get_args(ScanHookName)) +def _matches_scan_name(scan_name: str | None, patterns: list[str] | None) -> bool: + if not patterns: + return True + if scan_name is None: + return False + return any(fnmatchcase(scan_name, pattern) for pattern in patterns) + + +def _get_hook_method_name( + hook_name: str, + hook_info: dict[str, str | dict[str, str | list[str]] | list[str | dict[str, str | list[str]]]], + hook_type: str, + scan_name: str | None, +) -> str | None: + hook_config = hook_info.get(hook_type) + if hook_config is None: + return None + if isinstance(hook_config, list): + matched_method_names = [] + for config in hook_config: + if isinstance(config, str): + matched_method_names.append(config) + continue + if _matches_scan_name(scan_name, config.get("scan_names")): + matched_method_names.append(config["method_name"]) + if len(matched_method_names) > 1: + raise ValueError( + f"Multiple scan modifier implementations matched hook '{hook_name}' " + f"for lifecycle '{hook_type}' and scan '{scan_name}'" + ) + return matched_method_names[0] if matched_method_names else None + if isinstance(hook_config, str): + return hook_config + if not _matches_scan_name(scan_name, hook_config.get("scan_names")): + return None + return hook_config["method_name"] + + def scan_hook(func): """ Decorator for scan hooks. It registers the decorated method as a scan hook and thus allows @@ -46,41 +85,55 @@ def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) hook_info = self._scan_modifier_hooks[func.__name__] - if "before" in hook_info: - before_method = getattr(self._scan_modifier, hook_info["before"]) + scan_name = getattr( + getattr(self, "scan_info", None), "scan_name", getattr(self, "scan_name", None) + ) + + before_method_name = _get_hook_method_name(func.__name__, hook_info, "before", scan_name) + if before_method_name is not None: + before_method = getattr(self._scan_modifier, before_method_name) before_method(*args, **kwargs) - if "replace" in hook_info: - replace_method = getattr(self._scan_modifier, hook_info["replace"]) + replace_method_name = _get_hook_method_name(func.__name__, hook_info, "replace", scan_name) + if replace_method_name is not None: + replace_method = getattr(self._scan_modifier, replace_method_name) replace_method(*args, **kwargs) else: func(self, *args, **kwargs) - if "after" in hook_info: - after_method = getattr(self._scan_modifier, hook_info["after"]) + after_method_name = _get_hook_method_name(func.__name__, hook_info, "after", scan_name) + if after_method_name is not None: + after_method = getattr(self._scan_modifier, after_method_name) after_method(*args, **kwargs) return # pylint: disable=protected-access wrapper._scan_hook_info = {"method_name": func.__name__} # type: ignore + wrapper._scan_hook_original = func # type: ignore[attr-defined] return wrapper def scan_hook_impl( - hook_name: ScanHookName, hook_type: Literal["before", "after", "replace"] = "before" + hook_name: ScanHookName, + hook_type: Literal["before", "after", "replace"] = "before", + scan_names: list[str] | None = None, ): """ Decorator for scan hook implementations. It registers the decorated method as an implementation of the specified scan hook. The hook_name must refer to an existing scan hook. The hook_type should be one of the following: "before", "after" or "replace". + The optional scan_names list can be used to restrict the implementation to matching scan names. + Wildcards are supported using shell-style patterns such as ``*_line_scan``. This allows the scan modifier to specify whether the decorated method should be executed before, after or instead of the original scan hook method. """ if hook_name not in VALID_SCAN_HOOKS: raise ValueError(f"Invalid scan hook: {hook_name}") if hook_type not in {"before", "after", "replace"}: raise ValueError(f"Invalid scan hook type: {hook_type}") + if scan_names is not None and not isinstance(scan_names, list): + raise ValueError("scan_names must be a list of scan name patterns") def decorator(func): @wraps(func) @@ -88,20 +141,26 @@ def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) # pylint: disable=protected-access - wrapper._scan_hook_impl_info = {"hook_name": hook_name, "hook_type": hook_type} # type: ignore + wrapper._scan_hook_impl_info = { + "hook_name": hook_name, + "hook_type": hook_type, + "scan_names": scan_names, + } # type: ignore return wrapper return decorator -def get_scan_hooks_impl(cls) -> dict[str, dict[str, str]]: +def get_scan_hooks_impl( + cls, +) -> dict[ + str, dict[str, str | dict[str, str | list[str]] | list[str | dict[str, str | list[str]]]] +]: """ Get the scan hooks implemented by the given class. It returns a dictionary mapping the original scan hook names to the corresponding method names and hook types in the scan modifier. - Raises: - ValueError: If the class implements multiple hooks for the same hook_type (before, after, replace) for the same scan hook. """ hooks = {} for attr_name in dir(cls): @@ -112,11 +171,19 @@ def get_scan_hooks_impl(cls) -> dict[str, dict[str, str]]: hook_type = info["hook_type"] if hook_name not in hooks: hooks[hook_name] = {} - if hook_type in hooks[hook_name]: - raise ValueError( - f"Multiple implementations for the same hook type '{hook_type}' for the scan hook '{hook_name}' in class '{cls.__name__}'" - ) - hooks[hook_name][hook_type] = attr_name + scan_names = info.get("scan_names") + hook_config: str | dict[str, str | list[str]] + if scan_names is None: + hook_config = attr_name + else: + hook_config = {"method_name": attr_name, "scan_names": scan_names} + existing_hook_config = hooks[hook_name].get(hook_type) + if existing_hook_config is None: + hooks[hook_name][hook_type] = hook_config + elif isinstance(existing_hook_config, list): + existing_hook_config.append(hook_config) + else: + hooks[hook_name][hook_type] = [existing_hook_config, hook_config] return hooks @@ -225,3 +292,27 @@ def device_is_available(self, device: list[str] | str, check_enabled: bool = Tru if check_enabled and not self.dev[dev_name].enabled: return False return True + + def call_original(self, hook_name: ScanHookName, *args, **kwargs): + """ + Call the scan's original hook implementation directly, bypassing scan modifier dispatch. + + Args: + hook_name (ScanHookName): Name of the original scan hook to call. + *args: Positional arguments forwarded to the original hook. + **kwargs: Keyword arguments forwarded to the original hook. + + Returns: + Any: The return value of the original hook implementation. + + Raises: + AttributeError: If the scan does not expose an original implementation for the hook. + """ + original_hooks = getattr(self.scan, "_scan_original_hooks", {}) + try: + original_hook = original_hooks[hook_name] + except KeyError as exc: + raise AttributeError( + f"Scan {type(self.scan).__name__!r} does not expose an original hook for {hook_name!r}" + ) from exc + return original_hook(*args, **kwargs) diff --git a/bec_server/tests/tests_scan_server/test_scan_modifier.py b/bec_server/tests/tests_scan_server/test_scan_modifier.py index 46e300e4f..080acc6ac 100644 --- a/bec_server/tests/tests_scan_server/test_scan_modifier.py +++ b/bec_server/tests/tests_scan_server/test_scan_modifier.py @@ -1,9 +1,10 @@ from __future__ import annotations -from types import SimpleNamespace +from unittest import mock import pytest +from bec_server.scan_server.scans.scan_base import ScanBase from bec_server.scan_server.scans.scan_modifier import ( ScanModifier, get_scan_hooks_impl, @@ -22,14 +23,98 @@ def after_close_scan(self): pass +class _ScopedModifier: + @scan_hook_impl("at_each_point", "replace", ["_v4_test_scan_modifier"]) + def replace_exact_match(self, ind, pos): + return self.call_original("at_each_point", ind, pos) + + @scan_hook_impl("post_scan", "after", ["*_scan_modifier"]) + def after_wildcard_match(self): + pass + + +class _MultiScopedModifier(ScanModifier): + @scan_hook_impl("post_scan", "after", ["*_scan_modifier"]) + def after_first_wildcard_match(self): + pass + + @scan_hook_impl("post_scan", "after", ["_v4_grid_scan"]) + def after_grid_match(self): + pass + + +class _AmbiguousScanNameFilteringModifier(ScanModifier): + @scan_hook_impl("post_scan", "after", ["*_scan_modifier"]) + def after_wildcard_match(self): + self.scan.modifier_calls.append("after_wildcard_match") + + @scan_hook_impl("post_scan", "after", ["_v4_test_scan_modifier"]) + def after_exact_match(self): + self.scan.modifier_calls.append("after_exact_match") + + +class _TestScan(ScanBase): + scan_name = "_v4_test_scan_modifier" + scan_type = None + + def __init__(self, *args, **kwargs): + self.original_hook_calls = [] + self.after_hook_calls = [] + self.modifier_calls = [] + super().__init__(*args, **kwargs) + + @scan_hook + def at_each_point(self, ind, pos): + self.original_hook_calls.append((ind, pos)) + return "original-result" + + @scan_hook + def post_scan(self): + self.after_hook_calls.append("post_scan") + + +class _OriginalCallingModifier(ScanModifier): + @scan_hook_impl("at_each_point", "replace") + def replace_at_each_point(self, ind, pos): + return self.call_original("at_each_point", ind, pos) + + +class _ScanNameFilteringModifier(ScanModifier): + @scan_hook_impl("at_each_point", "replace", ["_v4_test_scan_modifier"]) + def replace_exact_match(self, ind, pos): + self.scan.modifier_calls.append(("replace_exact_match", ind, pos)) + return self.call_original("at_each_point", ind, pos) + + @scan_hook_impl("post_scan", "after", ["*_scan_modifier"]) + def after_wildcard_match(self): + self.scan.modifier_calls.append("after_wildcard_match") + self.scan.after_hook_calls.append("modifier:after_post_scan") + + +@pytest.fixture +def test_scan(device_manager, connected_connector): + return _TestScan( + scan_id="scan-id", + redis_connector=connected_connector, + device_manager=device_manager, + instruction_handler=mock.MagicMock(), + request_inputs={}, + system_config={}, + scan_modifier=None, + ) + + def test_scan_hook_marks_method_with_hook_info(): @scan_hook def prepare_scan(self): return "ok" - scan = SimpleNamespace(_scan_modifier_hooks={}, _scan_modifier=None) + scan = mock.MagicMock() + scan._scan_modifier_hooks = {} + scan._scan_modifier = None assert prepare_scan._scan_hook_info == {"method_name": "prepare_scan"} # type: ignore[attr-defined] + assert prepare_scan._scan_hook_original is prepare_scan.__wrapped__ # type: ignore[attr-defined] assert prepare_scan(scan) == "ok" @@ -42,6 +127,35 @@ def test_scan_hook_impl_registers_hook_metadata(): } +def test_scan_hook_impl_registers_scan_name_filters(): + hooks = get_scan_hooks_impl(_ScopedModifier) + + assert hooks == { + "at_each_point": { + "replace": { + "method_name": "replace_exact_match", + "scan_names": ["_v4_test_scan_modifier"], + } + }, + "post_scan": { + "after": {"method_name": "after_wildcard_match", "scan_names": ["*_scan_modifier"]} + }, + } + + +def test_scan_hook_impl_registers_multiple_same_lifecycle_filters(): + hooks = get_scan_hooks_impl(_MultiScopedModifier) + + assert hooks == { + "post_scan": { + "after": [ + {"method_name": "after_first_wildcard_match", "scan_names": ["*_scan_modifier"]}, + {"method_name": "after_grid_match", "scan_names": ["_v4_grid_scan"]}, + ] + } + } + + def test_scan_hook_impl_rejects_invalid_hook_name(): with pytest.raises(ValueError, match="Invalid scan hook"): scan_hook_impl("not_a_hook") # type: ignore[arg-type] @@ -52,13 +166,17 @@ def test_scan_hook_impl_rejects_invalid_hook_type(): scan_hook_impl("stage", "during") # type: ignore[arg-type] +def test_scan_hook_impl_rejects_non_list_scan_names(): + with pytest.raises(ValueError, match="scan_names must be a list"): + scan_hook_impl("stage", "before", "line_scan") # type: ignore[arg-type] + + def test_scan_modifier_device_is_available_checks_presence_and_enabled_state(): - scan = SimpleNamespace( - dev={"samx": SimpleNamespace(enabled=True), "samy": SimpleNamespace(enabled=False)}, - actions=None, - components=None, - scan_info=None, - ) + scan = mock.MagicMock() + scan.dev = {"samx": mock.MagicMock(enabled=True), "samy": mock.MagicMock(enabled=False)} + scan.actions = None + scan.components = None + scan.scan_info = None modifier = ScanModifier(scan) assert modifier.device_is_available("samx") is True @@ -66,3 +184,80 @@ def test_scan_modifier_device_is_available_checks_presence_and_enabled_state(): assert modifier.device_is_available("samy") is False assert modifier.device_is_available("samy", check_enabled=False) is True assert modifier.device_is_available(["samx", "samy"], check_enabled=False) is True + + +def test_scan_modifier_call_original_calls_bound_original_hook(test_scan): + test_scan._scan_modifier_hooks = get_scan_hooks_impl(_OriginalCallingModifier) + test_scan._scan_modifier = _OriginalCallingModifier(test_scan) + + result = test_scan.at_each_point(3, [1, 2]) + + assert result is None + assert test_scan.original_hook_calls == [(3, [1, 2])] + + +def test_scan_modifier_call_original_raises_for_missing_hook(test_scan): + modifier = ScanModifier(test_scan) + + with pytest.raises(AttributeError, match="does not expose an original hook"): + modifier.call_original("stage") + + +def test_scan_modifier_hook_filters_apply_for_exact_scan_name_match(test_scan): + test_scan._scan_modifier_hooks = get_scan_hooks_impl(_ScanNameFilteringModifier) + test_scan._scan_modifier = _ScanNameFilteringModifier(test_scan) + + test_scan.at_each_point(3, [1, 2]) + + assert test_scan.modifier_calls == [("replace_exact_match", 3, [1, 2])] + assert test_scan.original_hook_calls == [(3, [1, 2])] + + +def test_scan_modifier_hook_filters_skip_non_matching_scan_name(test_scan): + test_scan.scan_info.scan_name = "_v4_other_scan" + test_scan._scan_modifier_hooks = get_scan_hooks_impl(_ScanNameFilteringModifier) + test_scan._scan_modifier = _ScanNameFilteringModifier(test_scan) + + test_scan.at_each_point(3, [1, 2]) + + assert test_scan.modifier_calls == [] + assert test_scan.original_hook_calls == [(3, [1, 2])] + + +def test_scan_modifier_hook_filters_apply_wildcard_scan_name_match(test_scan): + test_scan._scan_modifier_hooks = get_scan_hooks_impl(_ScanNameFilteringModifier) + test_scan._scan_modifier = _ScanNameFilteringModifier(test_scan) + + test_scan.post_scan() + + assert test_scan.modifier_calls == ["after_wildcard_match"] + assert test_scan.after_hook_calls == ["post_scan", "modifier:after_post_scan"] + + +def test_scan_modifier_hook_filters_skip_non_matching_wildcard_scan_name(test_scan): + test_scan.scan_info.scan_name = "_v4_grid_scan" + test_scan._scan_modifier_hooks = get_scan_hooks_impl(_ScanNameFilteringModifier) + test_scan._scan_modifier = _ScanNameFilteringModifier(test_scan) + + test_scan.post_scan() + + assert test_scan.modifier_calls == [] + assert test_scan.after_hook_calls == ["post_scan"] + + +def test_scan_modifier_hook_filters_allow_disjoint_same_lifecycle_matches(test_scan): + test_scan.scan_info.scan_name = "_v4_grid_scan" + test_scan._scan_modifier_hooks = get_scan_hooks_impl(_MultiScopedModifier) + test_scan._scan_modifier = _MultiScopedModifier(test_scan) + + test_scan.post_scan() + + assert test_scan.after_hook_calls == ["post_scan"] + + +def test_scan_modifier_hook_filters_raise_on_ambiguous_match(test_scan): + test_scan._scan_modifier_hooks = get_scan_hooks_impl(_AmbiguousScanNameFilteringModifier) + test_scan._scan_modifier = _AmbiguousScanNameFilteringModifier(test_scan) + + with pytest.raises(ValueError, match="Multiple scan modifier implementations matched hook"): + test_scan.post_scan()