Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions bec_lib/bec_lib/bl_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,18 @@ def __init__(self, client: BECClient) -> None:
self._connector.register(
MessageEndpoints.available_beamline_states(), cb=self._on_state_update, from_start=True
)
self._ready = False

@property
def ready(self) -> bool:
"""Returns true after beamline states have been loaded from Redis."""
return self._ready

def _on_state_update(self, msg_dict: dict, **_kwargs) -> None:
# type: ignore ; we know it's an AvailableBeamlineStatesMessage
msg: messages.AvailableBeamlineStatesMessage = msg_dict["data"]
self._update_states(msg.states) # pylint: disable=protected-access
self._ready = True

def _update_state(self, state: BeamlineStateConfig) -> None:
if state.name in self._states:
Expand Down
4 changes: 2 additions & 2 deletions bec_lib/bec_lib/builtin_actor_hli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ def states_watched(self) -> dict[str, BlStateStatus]:
return msg.states_watched
return {}

def add_state_to_interlock(self, state_name: str, required_value: BlStateStatus):
def add_state_to_interlock(self, state_name: str, required_value: BlStateStatus = "valid"):
"""
Add a beamline state and its status to watch to the ScanInterlockActor. If the state no
longer has this status, an interlock will be placed on the primary scan queue.
Args:
state_name (str): the state to watch
status (Literal["valid","invalid","warning","unknown"]): the status to watch for.
status (Literal["valid","invalid","warning","unknown"]): the status to watch for. Defaults to "valid".
"""
self._client.connector.xadd(
MessageEndpoints.modify_interlock_table(),
Expand Down
18 changes: 14 additions & 4 deletions bec_server/bec_server/actors/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def evaluate(self, *_, **__):
return
if (now := time.monotonic()) < self.last_evaluated + self.min_delay_s:
return
logger.info(f"{self.__class__.__name__} triggered")
logger.debug(f"{self.__class__.__name__} evaluated")
self.last_evaluated = now
return super().evaluate(*_, **__)

Expand All @@ -99,6 +99,7 @@ def run(self):
self.push_status(ProcedureWorkerStatus.IDLE)

def stop(self, *_):
"""Stop the actor and cleanup subscriptions."""
self._stopped = True
for endpoint in self._endpoints:
for cb in self.default_monitor_callbacks():
Expand All @@ -108,6 +109,7 @@ def stop(self, *_):
logger.error(
f"{self.__class__} {self.__qualname__} failed to unregister {cb} from {endpoint}: {e}"
)
self.stop_event.set()


class BlStateActor(SubscriptionActor):
Expand All @@ -128,8 +130,6 @@ def __init__(self, client: BECClient, name: str, exec_id: str):
}
super().__init__(client, name, exec_id)
self.state_cache: dict[str, BlStateStatus] = {}
self._update_cache()
self.evaluate()

def _update_cache(self):
with self.state_table_lock:
Expand All @@ -142,7 +142,9 @@ def _update_cache(self):
continue
self.state_cache[state] = status
for state in to_remove:
logger.warning(f"Removing {state} from watched states.")
logger.warning(
f"Removing {state} from watched states because it no longer seems to exist."
)
del self.state_table[state]

def all_states_match(self, client: BECClient):
Expand All @@ -162,6 +164,14 @@ def all_match_action(self, client: BECClient):
def some_mismatch_action(self, client: BECClient):
pass

def run(self):
while not self.client.beamline_states.ready and not self.stop_event.set():
logger.warning(f"{self.__class__.__name__} waiting for beamline states to become ready")
time.sleep(0.1)
self._update_cache()
self.evaluate()
return super().run()

def default_monitor_endpoints(self) -> set[EndpointInfo]:
return {MessageEndpoints.beamline_state(state) for state in self.state_table}

Expand Down
52 changes: 35 additions & 17 deletions bec_server/bec_server/actors/builtin_actor_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from bec_lib.messages import (
AvailableBeamlineStatesMessage,
BlStateStatus,
BuiltinActorStateChangeNotification,
BuiltinActorStateUpdatedNotification,
ScanInterlockModifyStateTableMessage,
Expand Down Expand Up @@ -52,6 +54,9 @@ def __init__(self, bootstrap_server: str) -> None:
self._client.connector.register(
MessageEndpoints.modify_interlock_table(), cb=self._modify_interlock_table
)
self._client.connector.register(
MessageEndpoints.available_beamline_states(), cb=self._handle_state_update
)

def _ping_clients(self, actor_name: str):
self._client.connector.send(
Expand Down Expand Up @@ -105,6 +110,31 @@ def shutdown(self):
self._client.shutdown()

# Actor specific management methods:
def _set_interlock_states_in_redis(self, states: dict[str, BlStateStatus]):
self._client.connector.set(
MessageEndpoints.scan_interlock_states(),
ScanInterlockStateTableContent(states_watched=states),
)

def _current_watched_states(self) -> dict[str, BlStateStatus]:
states: ScanInterlockStateTableContent | None = self._client.connector.get(
MessageEndpoints.scan_interlock_states()
)
return states.states_watched if states is not None else {}

def _handle_state_update(self, msg_dict: dict):
msg: AvailableBeamlineStatesMessage = msg_dict["data"]
state_names = [state.name for state in msg.states]
for watched_state in self._current_watched_states():
if watched_state not in state_names:
self._modify_interlock_table(
{
"data": ScanInterlockModifyStateTableMessage(
action="remove", state_name=watched_state
)
}
)

def _modify_interlock_table(self, msg_dict):
"""Update the watched states for ScanInterlockActor - handled by the actor itself if it is
active, otherwise just the config in redis is updated."""
Expand All @@ -113,27 +143,15 @@ def _modify_interlock_table(self, msg_dict):
actor, _, _ = ats
actor._on_state_modification(msg)
else:
states: ScanInterlockStateTableContent | None = self._client.connector.get(
MessageEndpoints.scan_interlock_states()
)
current_watched = states.states_watched if states is not None else {}
if msg.action == "add":
current_watched = self._current_watched_states()
if msg.action == "add" and msg.state_name not in current_watched:
logger.info(f"Adding {msg.state_name} to the scan interlock actor")
current_watched[msg.state_name] = msg.status
self._client.connector.set(
MessageEndpoints.scan_interlock_states(),
ScanInterlockStateTableContent(states_watched=current_watched),
)
self._set_interlock_states_in_redis(current_watched)
elif msg.action == "remove_all":
self._client.connector.set(
MessageEndpoints.scan_interlock_states(),
ScanInterlockStateTableContent(states_watched={}),
)
self._set_interlock_states_in_redis({})
else:
logger.info(f"Removing {msg.state_name} from the scan interlock actor")
current_watched.pop(msg.state_name, None)
self._client.connector.set(
MessageEndpoints.scan_interlock_states(),
ScanInterlockStateTableContent(states_watched=current_watched),
)
self._set_interlock_states_in_redis(current_watched)
self._ping_clients("ScanInterlockActor")
7 changes: 5 additions & 2 deletions bec_server/bec_server/actors/scan_interlock.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, client: BECClient, name: str, exec_id: str):
super().__init__(client, name, exec_id)

def _ping_clients(self):
logger.warning(self.name)
logger.debug(f"{self.name} pinging clients that it was updated")
self.client.connector.send(
MessageEndpoints.builtin_actor_update_notif(self.name),
BuiltinActorStateUpdatedNotification(actor_name=self.name),
Expand Down Expand Up @@ -97,7 +97,10 @@ def all_match_action(self, client: BECClient):
def _unlock(self):
if self.client.queue is None:
return
self.client.queue.remove_queue_lock(queue="primary", lock_id=self._LOCK_ID)
if (q := self.client.queue) is not None:
if (curr_q := q.queue_storage.current_scan_queue) is not None:
if (primary := curr_q.get("primary")) is not None and primary.locks != []:
self.client.queue.remove_queue_lock(queue="primary", lock_id=self._LOCK_ID)

def run(self):
super().run()
Expand Down
41 changes: 40 additions & 1 deletion bec_server/tests/tests_actors/test_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from bec_lib.endpoints import MessageEndpoints
from bec_lib.messages import ActorStartRequestMessage, ProcedureWorkerStatus, RawMessage
from bec_lib.redis_connector import MessageObject, RedisConnector
from bec_server.actors.actor import ActorBase
from bec_server.actors.actor import ActorBase, BlStateActor
from bec_server.actors.manager import ActorManager
from bec_server.actors.worker import actor_procedure
from bec_server.procedures.constants import BecClientType
Expand Down Expand Up @@ -173,3 +173,42 @@ def test_actor_procedure_logs_error_not_actor():
with patch("bec_server.actors.worker.logger") as logger:
actor_procedure("bec_server.test.actor_test_utils", "EndpointInfo", "test", MagicMock())
assert "is not a valid Actor!" in logger.error.call_args.args[0]


class BlStateTestActor(BlStateActor):
state_table = {"test_state": "valid", "test_state_2": "valid"}


def test_blstateactor_init_table_and_cache():
mock_client = MagicMock()

def get_status_by_name(name: str):
if name == "test_state":
return "valid"

mock_client.beamline_states.get_status_by_name.side_effect = get_status_by_name
actor = BlStateTestActor(mock_client, "Test", "Test")
actor.stop_event.set()
actor.run()

assert actor.state_table == {"test_state": "valid"}
assert actor.state_cache == {"test_state": "valid"}


def test_bl_state_actor_waits_for_states():
mock_client = MagicMock()

mock_client.beamline_states.ready = False
actor = BlStateTestActor(mock_client, "Test", "Test")
actor.evaluate = MagicMock()
with patch("bec_server.actors.actor.logger") as mock_logger:
t = Thread(target=actor.run)
t.start()
sleep(0.1)
mock_logger.warning.assert_called()
actor.evaluate.assert_not_called()
mock_client.beamline_states.ready = True
sleep(0.2)
actor.stop_event.set()
t.join()
actor.evaluate.assert_called()
46 changes: 45 additions & 1 deletion bec_server/tests/tests_scan_server/test_builtin_actor_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

import pytest

from bec_lib.messages import (
AvailableBeamlineStatesMessage,
ScanInterlockModifyStateTableMessage,
ScanInterlockStateTableContent,
)
from bec_server.actors.builtin_actor_manager import BuiltinActorManager


Expand Down Expand Up @@ -38,7 +43,7 @@ def test_init_registers_callback(mocked_manager):

mock_client.start.assert_called_once()

assert mock_client.connector.register.call_count == 2
assert mock_client.connector.register.call_count == 3

kwargs = mock_client.connector.register.call_args_list[0].kwargs
assert "cb" in kwargs
Expand Down Expand Up @@ -148,3 +153,42 @@ def test_shutdown_stops_all_and_shuts_down_client(mocked_manager):
assert mock_stop.call_count == 2

mock_client.shutdown.assert_called_once()


def _req_msg(action, state_name, status):
return {
"data": ScanInterlockModifyStateTableMessage(
action=action, state_name=state_name, status=status
)
}


INITIAL_STATES = {"initial_state_1": "valid", "initial_state_2": "valid"}


@pytest.mark.parametrize(
["modification_request", "new_watched_states"],
[
(_req_msg("add", "test_state", "valid"), {**INITIAL_STATES, "test_state": "valid"}),
(_req_msg("remove", "initial_state_1", None), {"initial_state_2": "valid"}),
(_req_msg("remove_all", None, None), {}),
(_req_msg("remove", "missing_state", None), INITIAL_STATES),
],
)
def test_modify_interlock_table(mocked_manager, modification_request, new_watched_states):
manager, mock_client = mocked_manager
mock_client.connector.get.return_value = ScanInterlockStateTableContent(
states_watched=INITIAL_STATES
)
manager._modify_interlock_table(modification_request)
assert mock_client.connector.set.call_args.args[1].states_watched == new_watched_states


def test_handle_state_update(mocked_manager):
manager, _ = mocked_manager
manager._current_watched_states = lambda: {"test_state": "valid"}
manager._modify_interlock_table = MagicMock()
manager._handle_state_update({"data": AvailableBeamlineStatesMessage(states=[])})
manager._modify_interlock_table.assert_called_with(
{"data": ScanInterlockModifyStateTableMessage(action="remove", state_name="test_state")}
)
Loading