diff --git a/bec_lib/bec_lib/bl_state_manager.py b/bec_lib/bec_lib/bl_state_manager.py index ae2a5d364..0d6684c4b 100644 --- a/bec_lib/bec_lib/bl_state_manager.py +++ b/bec_lib/bec_lib/bl_state_manager.py @@ -114,11 +114,18 @@ def __init__(self, client: BECClient) -> None: self._connector.register( MessageEndpoints.available_beamline_states(), cb=self._on_state_update, from_start=True ) + self._ready = False + + @property + def ready(self) -> bool: + """Returns true after beamline states have been loaded from Redis.""" + return self._ready def _on_state_update(self, msg_dict: dict, **_kwargs) -> None: # type: ignore ; we know it's an AvailableBeamlineStatesMessage msg: messages.AvailableBeamlineStatesMessage = msg_dict["data"] self._update_states(msg.states) # pylint: disable=protected-access + self._ready = True def _update_state(self, state: BeamlineStateConfig) -> None: if state.name in self._states: diff --git a/bec_lib/bec_lib/builtin_actor_hli.py b/bec_lib/bec_lib/builtin_actor_hli.py index 2a558723d..d73ec2eeb 100644 --- a/bec_lib/bec_lib/builtin_actor_hli.py +++ b/bec_lib/bec_lib/builtin_actor_hli.py @@ -41,13 +41,13 @@ def states_watched(self) -> dict[str, BlStateStatus]: return msg.states_watched return {} - def add_state_to_interlock(self, state_name: str, required_value: BlStateStatus): + def add_state_to_interlock(self, state_name: str, required_value: BlStateStatus = "valid"): """ Add a beamline state and its status to watch to the ScanInterlockActor. If the state no longer has this status, an interlock will be placed on the primary scan queue. Args: state_name (str): the state to watch - status (Literal["valid","invalid","warning","unknown"]): the status to watch for. + status (Literal["valid","invalid","warning","unknown"]): the status to watch for. Defaults to "valid". """ self._client.connector.xadd( MessageEndpoints.modify_interlock_table(), diff --git a/bec_server/bec_server/actors/actor.py b/bec_server/bec_server/actors/actor.py index 896cda9e1..92178c8c6 100644 --- a/bec_server/bec_server/actors/actor.py +++ b/bec_server/bec_server/actors/actor.py @@ -86,7 +86,7 @@ def evaluate(self, *_, **__): return if (now := time.monotonic()) < self.last_evaluated + self.min_delay_s: return - logger.info(f"{self.__class__.__name__} triggered") + logger.debug(f"{self.__class__.__name__} evaluated") self.last_evaluated = now return super().evaluate(*_, **__) @@ -99,6 +99,7 @@ def run(self): self.push_status(ProcedureWorkerStatus.IDLE) def stop(self, *_): + """Stop the actor and cleanup subscriptions.""" self._stopped = True for endpoint in self._endpoints: for cb in self.default_monitor_callbacks(): @@ -108,6 +109,7 @@ def stop(self, *_): logger.error( f"{self.__class__} {self.__qualname__} failed to unregister {cb} from {endpoint}: {e}" ) + self.stop_event.set() class BlStateActor(SubscriptionActor): @@ -128,8 +130,6 @@ def __init__(self, client: BECClient, name: str, exec_id: str): } super().__init__(client, name, exec_id) self.state_cache: dict[str, BlStateStatus] = {} - self._update_cache() - self.evaluate() def _update_cache(self): with self.state_table_lock: @@ -142,7 +142,9 @@ def _update_cache(self): continue self.state_cache[state] = status for state in to_remove: - logger.warning(f"Removing {state} from watched states.") + logger.warning( + f"Removing {state} from watched states because it no longer seems to exist." + ) del self.state_table[state] def all_states_match(self, client: BECClient): @@ -162,6 +164,14 @@ def all_match_action(self, client: BECClient): def some_mismatch_action(self, client: BECClient): pass + def run(self): + while not self.client.beamline_states.ready and not self.stop_event.set(): + logger.warning(f"{self.__class__.__name__} waiting for beamline states to become ready") + time.sleep(0.1) + self._update_cache() + self.evaluate() + return super().run() + def default_monitor_endpoints(self) -> set[EndpointInfo]: return {MessageEndpoints.beamline_state(state) for state in self.state_table} diff --git a/bec_server/bec_server/actors/builtin_actor_manager.py b/bec_server/bec_server/actors/builtin_actor_manager.py index 5a5a12389..73fd8de84 100644 --- a/bec_server/bec_server/actors/builtin_actor_manager.py +++ b/bec_server/bec_server/actors/builtin_actor_manager.py @@ -6,6 +6,8 @@ from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger from bec_lib.messages import ( + AvailableBeamlineStatesMessage, + BlStateStatus, BuiltinActorStateChangeNotification, BuiltinActorStateUpdatedNotification, ScanInterlockModifyStateTableMessage, @@ -52,6 +54,9 @@ def __init__(self, bootstrap_server: str) -> None: self._client.connector.register( MessageEndpoints.modify_interlock_table(), cb=self._modify_interlock_table ) + self._client.connector.register( + MessageEndpoints.available_beamline_states(), cb=self._handle_state_update + ) def _ping_clients(self, actor_name: str): self._client.connector.send( @@ -105,6 +110,31 @@ def shutdown(self): self._client.shutdown() # Actor specific management methods: + def _set_interlock_states_in_redis(self, states: dict[str, BlStateStatus]): + self._client.connector.set( + MessageEndpoints.scan_interlock_states(), + ScanInterlockStateTableContent(states_watched=states), + ) + + def _current_watched_states(self) -> dict[str, BlStateStatus]: + states: ScanInterlockStateTableContent | None = self._client.connector.get( + MessageEndpoints.scan_interlock_states() + ) + return states.states_watched if states is not None else {} + + def _handle_state_update(self, msg_dict: dict): + msg: AvailableBeamlineStatesMessage = msg_dict["data"] + state_names = [state.name for state in msg.states] + for watched_state in self._current_watched_states(): + if watched_state not in state_names: + self._modify_interlock_table( + { + "data": ScanInterlockModifyStateTableMessage( + action="remove", state_name=watched_state + ) + } + ) + def _modify_interlock_table(self, msg_dict): """Update the watched states for ScanInterlockActor - handled by the actor itself if it is active, otherwise just the config in redis is updated.""" @@ -113,27 +143,15 @@ def _modify_interlock_table(self, msg_dict): actor, _, _ = ats actor._on_state_modification(msg) else: - states: ScanInterlockStateTableContent | None = self._client.connector.get( - MessageEndpoints.scan_interlock_states() - ) - current_watched = states.states_watched if states is not None else {} - if msg.action == "add": + current_watched = self._current_watched_states() + if msg.action == "add" and msg.state_name not in current_watched: logger.info(f"Adding {msg.state_name} to the scan interlock actor") current_watched[msg.state_name] = msg.status - self._client.connector.set( - MessageEndpoints.scan_interlock_states(), - ScanInterlockStateTableContent(states_watched=current_watched), - ) + self._set_interlock_states_in_redis(current_watched) elif msg.action == "remove_all": - self._client.connector.set( - MessageEndpoints.scan_interlock_states(), - ScanInterlockStateTableContent(states_watched={}), - ) + self._set_interlock_states_in_redis({}) else: logger.info(f"Removing {msg.state_name} from the scan interlock actor") current_watched.pop(msg.state_name, None) - self._client.connector.set( - MessageEndpoints.scan_interlock_states(), - ScanInterlockStateTableContent(states_watched=current_watched), - ) + self._set_interlock_states_in_redis(current_watched) self._ping_clients("ScanInterlockActor") diff --git a/bec_server/bec_server/actors/scan_interlock.py b/bec_server/bec_server/actors/scan_interlock.py index e48d71c0f..86a0952e4 100644 --- a/bec_server/bec_server/actors/scan_interlock.py +++ b/bec_server/bec_server/actors/scan_interlock.py @@ -28,7 +28,7 @@ def __init__(self, client: BECClient, name: str, exec_id: str): super().__init__(client, name, exec_id) def _ping_clients(self): - logger.warning(self.name) + logger.debug(f"{self.name} pinging clients that it was updated") self.client.connector.send( MessageEndpoints.builtin_actor_update_notif(self.name), BuiltinActorStateUpdatedNotification(actor_name=self.name), @@ -97,7 +97,10 @@ def all_match_action(self, client: BECClient): def _unlock(self): if self.client.queue is None: return - self.client.queue.remove_queue_lock(queue="primary", lock_id=self._LOCK_ID) + if (q := self.client.queue) is not None: + if (curr_q := q.queue_storage.current_scan_queue) is not None: + if (primary := curr_q.get("primary")) is not None and primary.locks != []: + self.client.queue.remove_queue_lock(queue="primary", lock_id=self._LOCK_ID) def run(self): super().run() diff --git a/bec_server/tests/tests_actors/test_actors.py b/bec_server/tests/tests_actors/test_actors.py index a91d5b8b4..361816d2f 100644 --- a/bec_server/tests/tests_actors/test_actors.py +++ b/bec_server/tests/tests_actors/test_actors.py @@ -9,7 +9,7 @@ from bec_lib.endpoints import MessageEndpoints from bec_lib.messages import ActorStartRequestMessage, ProcedureWorkerStatus, RawMessage from bec_lib.redis_connector import MessageObject, RedisConnector -from bec_server.actors.actor import ActorBase +from bec_server.actors.actor import ActorBase, BlStateActor from bec_server.actors.manager import ActorManager from bec_server.actors.worker import actor_procedure from bec_server.procedures.constants import BecClientType @@ -173,3 +173,42 @@ def test_actor_procedure_logs_error_not_actor(): with patch("bec_server.actors.worker.logger") as logger: actor_procedure("bec_server.test.actor_test_utils", "EndpointInfo", "test", MagicMock()) assert "is not a valid Actor!" in logger.error.call_args.args[0] + + +class BlStateTestActor(BlStateActor): + state_table = {"test_state": "valid", "test_state_2": "valid"} + + +def test_blstateactor_init_table_and_cache(): + mock_client = MagicMock() + + def get_status_by_name(name: str): + if name == "test_state": + return "valid" + + mock_client.beamline_states.get_status_by_name.side_effect = get_status_by_name + actor = BlStateTestActor(mock_client, "Test", "Test") + actor.stop_event.set() + actor.run() + + assert actor.state_table == {"test_state": "valid"} + assert actor.state_cache == {"test_state": "valid"} + + +def test_bl_state_actor_waits_for_states(): + mock_client = MagicMock() + + mock_client.beamline_states.ready = False + actor = BlStateTestActor(mock_client, "Test", "Test") + actor.evaluate = MagicMock() + with patch("bec_server.actors.actor.logger") as mock_logger: + t = Thread(target=actor.run) + t.start() + sleep(0.1) + mock_logger.warning.assert_called() + actor.evaluate.assert_not_called() + mock_client.beamline_states.ready = True + sleep(0.2) + actor.stop_event.set() + t.join() + actor.evaluate.assert_called() diff --git a/bec_server/tests/tests_scan_server/test_builtin_actor_manager.py b/bec_server/tests/tests_scan_server/test_builtin_actor_manager.py index 204a23257..f9d2fe4c2 100644 --- a/bec_server/tests/tests_scan_server/test_builtin_actor_manager.py +++ b/bec_server/tests/tests_scan_server/test_builtin_actor_manager.py @@ -3,6 +3,11 @@ import pytest +from bec_lib.messages import ( + AvailableBeamlineStatesMessage, + ScanInterlockModifyStateTableMessage, + ScanInterlockStateTableContent, +) from bec_server.actors.builtin_actor_manager import BuiltinActorManager @@ -38,7 +43,7 @@ def test_init_registers_callback(mocked_manager): mock_client.start.assert_called_once() - assert mock_client.connector.register.call_count == 2 + assert mock_client.connector.register.call_count == 3 kwargs = mock_client.connector.register.call_args_list[0].kwargs assert "cb" in kwargs @@ -148,3 +153,42 @@ def test_shutdown_stops_all_and_shuts_down_client(mocked_manager): assert mock_stop.call_count == 2 mock_client.shutdown.assert_called_once() + + +def _req_msg(action, state_name, status): + return { + "data": ScanInterlockModifyStateTableMessage( + action=action, state_name=state_name, status=status + ) + } + + +INITIAL_STATES = {"initial_state_1": "valid", "initial_state_2": "valid"} + + +@pytest.mark.parametrize( + ["modification_request", "new_watched_states"], + [ + (_req_msg("add", "test_state", "valid"), {**INITIAL_STATES, "test_state": "valid"}), + (_req_msg("remove", "initial_state_1", None), {"initial_state_2": "valid"}), + (_req_msg("remove_all", None, None), {}), + (_req_msg("remove", "missing_state", None), INITIAL_STATES), + ], +) +def test_modify_interlock_table(mocked_manager, modification_request, new_watched_states): + manager, mock_client = mocked_manager + mock_client.connector.get.return_value = ScanInterlockStateTableContent( + states_watched=INITIAL_STATES + ) + manager._modify_interlock_table(modification_request) + assert mock_client.connector.set.call_args.args[1].states_watched == new_watched_states + + +def test_handle_state_update(mocked_manager): + manager, _ = mocked_manager + manager._current_watched_states = lambda: {"test_state": "valid"} + manager._modify_interlock_table = MagicMock() + manager._handle_state_update({"data": AvailableBeamlineStatesMessage(states=[])}) + manager._modify_interlock_table.assert_called_with( + {"data": ScanInterlockModifyStateTableMessage(action="remove", state_name="test_state")} + )