diff --git a/oocana/oocana/mainframe.py b/oocana/oocana/mainframe.py index 3f7083e..b0d7165 100644 --- a/oocana/oocana/mainframe.py +++ b/oocana/oocana/mainframe.py @@ -128,47 +128,86 @@ def on_message_once(_client, _userdata, message): self._logger.info("notify ready success in {} {}".format(session_id, job_id)) return replay - def add_request_response_callback(self, session_id: str, request_id: str, callback: Callable[[Any], Any]): - """Add a callback to be called when an error occurs while running a block.""" + def _add_callback( + self, + callbacks_dict: dict[str, list[Callable]], + key: str, + topic: str, + callback: Callable[[Any], Any] + ) -> None: + """Generic method to add a callback with subscription management. + + Args: + callbacks_dict: The dictionary storing callbacks (keyed by identifier) + key: The key to use in the callbacks dict + topic: The MQTT topic to subscribe to + callback: The callback function to add + """ if not callable(callback): raise ValueError("Callback must be callable") - - if request_id not in self.__request_response_callbacks: - self.__request_response_callbacks[request_id] = [] - self.subscribe(f"session/{session_id}/request/{request_id}/response", lambda payload: [cb(payload) for cb in self.__request_response_callbacks[request_id].copy()]) - self.__request_response_callbacks[request_id].append(callback) + if key not in callbacks_dict: + callbacks_dict[key] = [] + self.subscribe(topic, lambda payload: [cb(payload) for cb in callbacks_dict[key].copy()]) + + callbacks_dict[key].append(callback) + + def _remove_callback( + self, + callbacks_dict: dict[str, list[Callable]], + key: str, + topic: str, + callback: Callable[[Any], Any], + error_context: str + ) -> None: + """Generic method to remove a callback with subscription cleanup. + + Args: + callbacks_dict: The dictionary storing callbacks + key: The key in the callbacks dict + topic: The MQTT topic to unsubscribe from + callback: The callback function to remove + error_context: Context string for warning message + """ + if key in callbacks_dict and callback in callbacks_dict[key]: + callbacks_dict[key].remove(callback) + if len(callbacks_dict[key]) == 0: + del callbacks_dict[key] + self.unsubscribe(topic) + else: + self._logger.warning(f"Callback not found in {error_context}") + + def add_request_response_callback(self, session_id: str, request_id: str, callback: Callable[[Any], Any]): + """Add a callback to be called when a request response is received.""" + topic = f"session/{session_id}/request/{request_id}/response" + self._add_callback(self.__request_response_callbacks, request_id, topic, callback) def remove_request_response_callback(self, session_id: str, request_id: str, callback: Callable[[Any], Any]): - """Remove a previously added run block error callback.""" - if request_id in self.__request_response_callbacks and callback in self.__request_response_callbacks[request_id]: - self.__request_response_callbacks[request_id].remove(callback) - if len(self.__request_response_callbacks[request_id]) == 0: - del self.__request_response_callbacks[request_id] - self.unsubscribe(f"session/{session_id}/request/{request_id}/response") - else: - self._logger.warning("Callback not found in request/response callbacks for session {} and request {}.".format(session_id, request_id)) + """Remove a previously added request response callback.""" + topic = f"session/{session_id}/request/{request_id}/response" + self._remove_callback( + self.__request_response_callbacks, + request_id, + topic, + callback, + f"request/response callbacks for session {session_id} and request {request_id}" + ) def add_session_callback(self, session_id: str, callback: Callable[[dict], Any]): """Add a callback to be called when a session message is received.""" - if not callable(callback): - raise ValueError("Callback must be callable") - - if session_id not in self.__session_callbacks: - self.__session_callbacks[session_id] = [] - self.subscribe(f"session/{session_id}", lambda payload: [cb(payload) for cb in self.__session_callbacks[session_id].copy()]) - - self.__session_callbacks[session_id].append(callback) + topic = f"session/{session_id}" + self._add_callback(self.__session_callbacks, session_id, topic, callback) def remove_session_callback(self, session_id: str, callback: Callable[[dict], Any]): """Remove a previously added session callback.""" - if session_id in self.__session_callbacks and callback in self.__session_callbacks[session_id]: - self.__session_callbacks[session_id].remove(callback) - if len(self.__session_callbacks[session_id]) == 0: - del self.__session_callbacks[session_id] - self.unsubscribe(f"session/{session_id}") - else: - self._logger.warning("Callback not found in session callbacks for session: {}".format(session_id)) + topic = f"session/{session_id}" + self._remove_callback( + self.__session_callbacks, + session_id, + topic, + callback, + f"session callbacks for session: {session_id}" + ) def add_report_callback(self, fn): diff --git a/oocana/tests/test_mainframe_callbacks.py b/oocana/tests/test_mainframe_callbacks.py new file mode 100644 index 0000000..d321100 --- /dev/null +++ b/oocana/tests/test_mainframe_callbacks.py @@ -0,0 +1,157 @@ +import unittest +from unittest.mock import MagicMock, patch + + +class TestCallbackManagement(unittest.TestCase): + """Test cases for callback management methods in Mainframe.""" + + def setUp(self): + # Patch the mqtt client to avoid real network connections + self.mock_client_patcher = patch('paho.mqtt.client.Client') + self.mock_client_class = self.mock_client_patcher.start() + self.mock_client = MagicMock() + self.mock_client_class.return_value = self.mock_client + self.mock_client.is_connected.return_value = True + + from oocana import Mainframe + self.mainframe = Mainframe('mqtt://localhost:1883') + self.mainframe.client = self.mock_client + + def tearDown(self): + self.mock_client_patcher.stop() + + def test_add_request_response_callback(self): + """Test adding a request response callback.""" + session_id = 'test-session' + request_id = 'test-request' + callback = MagicMock() + + self.mainframe.add_request_response_callback(session_id, request_id, callback) + + # Verify subscribe was called with correct topic + expected_topic = f"session/{session_id}/request/{request_id}/response" + self.mock_client.message_callback_add.assert_called() + + def test_add_session_callback(self): + """Test adding a session callback.""" + session_id = 'test-session' + callback = MagicMock() + + self.mainframe.add_session_callback(session_id, callback) + + # Verify subscribe was called + self.mock_client.message_callback_add.assert_called() + + def test_add_callback_requires_callable(self): + """Test that non-callable raises ValueError.""" + session_id = 'test-session' + + with self.assertRaises(ValueError) as context: + self.mainframe.add_session_callback(session_id, "not a callable") + + self.assertIn("callable", str(context.exception)) + + def test_add_request_response_callback_requires_callable(self): + """Test that non-callable raises ValueError for request response callback.""" + with self.assertRaises(ValueError) as context: + self.mainframe.add_request_response_callback("session", "request", "not a callable") + + self.assertIn("callable", str(context.exception)) + + def test_remove_session_callback(self): + """Test removing a session callback.""" + session_id = 'test-session' + callback = MagicMock() + + # Add then remove + self.mainframe.add_session_callback(session_id, callback) + self.mainframe.remove_session_callback(session_id, callback) + + # Verify unsubscribe was called + self.mock_client.unsubscribe.assert_called() + + def test_remove_request_response_callback(self): + """Test removing a request response callback.""" + session_id = 'test-session' + request_id = 'test-request' + callback = MagicMock() + + # Add then remove + self.mainframe.add_request_response_callback(session_id, request_id, callback) + self.mainframe.remove_request_response_callback(session_id, request_id, callback) + + # Verify unsubscribe was called + self.mock_client.unsubscribe.assert_called() + + def test_remove_nonexistent_callback_logs_warning(self): + """Test that removing a nonexistent callback logs a warning.""" + session_id = 'test-session' + callback = MagicMock() + + # Create a mock logger + mock_logger = MagicMock() + self.mainframe._logger = mock_logger + + # Try to remove callback that was never added + self.mainframe.remove_session_callback(session_id, callback) + + # Verify warning was logged + mock_logger.warning.assert_called_once() + + def test_multiple_callbacks_for_same_session(self): + """Test that multiple callbacks can be added for the same session.""" + session_id = 'test-session' + callback1 = MagicMock() + callback2 = MagicMock() + + self.mainframe.add_session_callback(session_id, callback1) + self.mainframe.add_session_callback(session_id, callback2) + + # Remove first callback, should not unsubscribe yet + self.mainframe.remove_session_callback(session_id, callback1) + + # Subscribe should have been called only once (for first add) + call_count_before = self.mock_client.message_callback_add.call_count + + # Remove second callback, should unsubscribe now + self.mainframe.remove_session_callback(session_id, callback2) + + self.mock_client.unsubscribe.assert_called() + + def test_add_report_callback(self): + """Test adding a report callback.""" + callback = MagicMock() + + self.mainframe.add_report_callback(callback) + + # No error should occur + + def test_add_report_callback_requires_callable(self): + """Test that non-callable raises ValueError for report callback.""" + with self.assertRaises(ValueError) as context: + self.mainframe.add_report_callback("not a callable") + + self.assertIn("callable", str(context.exception)) + + def test_remove_report_callback(self): + """Test removing a report callback.""" + callback = MagicMock() + + self.mainframe.add_report_callback(callback) + self.mainframe.remove_report_callback(callback) + + # No error should occur + + def test_remove_nonexistent_report_callback_logs_warning(self): + """Test that removing a nonexistent report callback logs a warning.""" + callback = MagicMock() + mock_logger = MagicMock() + self.mainframe._logger = mock_logger + + self.mainframe.remove_report_callback(callback) + + mock_logger.warning.assert_called_once() + + +if __name__ == '__main__': + unittest.main()