-
Notifications
You must be signed in to change notification settings - Fork 1
fix(oocana): extract common callback management logic (DRY) #453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -128,47 +128,86 @@ def on_message_once(_client, _userdata, message): | |||||||||||||||||||||
| self._logger.info("notify ready success in {} {}".format(session_id, job_id)) | ||||||||||||||||||||||
| return replay | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def add_request_response_callback(self, session_id: str, request_id: str, callback: Callable[[Any], Any]): | ||||||||||||||||||||||
| """Add a callback to be called when an error occurs while running a block.""" | ||||||||||||||||||||||
| def _add_callback( | ||||||||||||||||||||||
| self, | ||||||||||||||||||||||
| callbacks_dict: dict[str, list[Callable]], | ||||||||||||||||||||||
| key: str, | ||||||||||||||||||||||
| topic: str, | ||||||||||||||||||||||
| callback: Callable[[Any], Any] | ||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||
| """Generic method to add a callback with subscription management. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Args: | ||||||||||||||||||||||
| callbacks_dict: The dictionary storing callbacks (keyed by identifier) | ||||||||||||||||||||||
| key: The key to use in the callbacks dict | ||||||||||||||||||||||
| topic: The MQTT topic to subscribe to | ||||||||||||||||||||||
| callback: The callback function to add | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| if not callable(callback): | ||||||||||||||||||||||
| raise ValueError("Callback must be callable") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if request_id not in self.__request_response_callbacks: | ||||||||||||||||||||||
| self.__request_response_callbacks[request_id] = [] | ||||||||||||||||||||||
| self.subscribe(f"session/{session_id}/request/{request_id}/response", lambda payload: [cb(payload) for cb in self.__request_response_callbacks[request_id].copy()]) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| self.__request_response_callbacks[request_id].append(callback) | ||||||||||||||||||||||
| if key not in callbacks_dict: | ||||||||||||||||||||||
| callbacks_dict[key] = [] | ||||||||||||||||||||||
| self.subscribe(topic, lambda payload: [cb(payload) for cb in callbacks_dict[key].copy()]) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
Comment on lines
+151
to
+152
|
||||||||||||||||||||||
| self.subscribe(topic, lambda payload: [cb(payload) for cb in callbacks_dict[key].copy()]) | |
| def _dispatch_payload(payload, _callbacks_dict=callbacks_dict, _key=key): | |
| callbacks = _callbacks_dict.get(_key) | |
| if not callbacks: | |
| return | |
| for cb in callbacks.copy(): | |
| cb(payload) | |
| self.subscribe(topic, _dispatch_payload) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
避免回调移除后触发 KeyError
当最后一个回调被移除后会 del callbacks_dict[key],若此时仍有滞后消息进入订阅回调,callbacks_dict[key] 会抛出 KeyError,导致 MQTT 回调异常。建议用 get 安全读取并用显式循环执行回调。
🛠️ 建议修复
- self.subscribe(topic, lambda payload: [cb(payload) for cb in callbacks_dict[key].copy()])
+ def _dispatch(payload):
+ for cb in list(callbacks_dict.get(key, [])):
+ cb(payload)
+ self.subscribe(topic, _dispatch)🤖 Prompt for AI Agents
In `@oocana/oocana/mainframe.py` around lines 149 - 176, The current subscription
handler closes over callbacks_dict[key] and can raise KeyError if the last
callback is removed concurrently; update the subscribe call (the lambda created
where callbacks_dict[key] is set) to safely capture the callbacks list via
callbacks_dict.get(key, []) and copy it into a local variable, then iterate with
an explicit for-loop calling each cb(payload). Also ensure the removal logic in
_remove_callback remains unchanged except to rely on this safe access pattern so
late-arriving messages won't access callbacks_dict[key] after del; reference the
subscribe call, the lambda using callbacks_dict[key].copy(), and the
_remove_callback method when making this change.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
| import unittest | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
|
|
||
| class TestCallbackManagement(unittest.TestCase): | ||
| """Test cases for callback management methods in Mainframe.""" | ||
|
|
||
| def setUp(self): | ||
| # Patch the mqtt client to avoid real network connections | ||
| self.mock_client_patcher = patch('paho.mqtt.client.Client') | ||
| self.mock_client_class = self.mock_client_patcher.start() | ||
| self.mock_client = MagicMock() | ||
| self.mock_client_class.return_value = self.mock_client | ||
| self.mock_client.is_connected.return_value = True | ||
|
|
||
| from oocana import Mainframe | ||
| self.mainframe = Mainframe('mqtt://localhost:1883') | ||
| self.mainframe.client = self.mock_client | ||
|
|
||
| def tearDown(self): | ||
| self.mock_client_patcher.stop() | ||
|
|
||
| def test_add_request_response_callback(self): | ||
| """Test adding a request response callback.""" | ||
| session_id = 'test-session' | ||
| request_id = 'test-request' | ||
| callback = MagicMock() | ||
|
|
||
| self.mainframe.add_request_response_callback(session_id, request_id, callback) | ||
|
|
||
| # Verify subscribe was called with correct topic | ||
| expected_topic = f"session/{session_id}/request/{request_id}/response" | ||
| self.mock_client.message_callback_add.assert_called() | ||
|
|
||
|
Comment on lines
+31
to
+34
|
||
| def test_add_session_callback(self): | ||
| """Test adding a session callback.""" | ||
| session_id = 'test-session' | ||
| callback = MagicMock() | ||
|
|
||
| self.mainframe.add_session_callback(session_id, callback) | ||
|
|
||
| # Verify subscribe was called | ||
| self.mock_client.message_callback_add.assert_called() | ||
|
|
||
| def test_add_callback_requires_callable(self): | ||
| """Test that non-callable raises ValueError.""" | ||
| session_id = 'test-session' | ||
|
|
||
| with self.assertRaises(ValueError) as context: | ||
| self.mainframe.add_session_callback(session_id, "not a callable") | ||
|
|
||
| self.assertIn("callable", str(context.exception)) | ||
|
|
||
| def test_add_request_response_callback_requires_callable(self): | ||
| """Test that non-callable raises ValueError for request response callback.""" | ||
| with self.assertRaises(ValueError) as context: | ||
| self.mainframe.add_request_response_callback("session", "request", "not a callable") | ||
|
|
||
| self.assertIn("callable", str(context.exception)) | ||
|
Comment on lines
+45
to
+59
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 类型检查失败:传入非可调用值需显式忽略 CI 已报 🧾 建议修复- self.mainframe.add_session_callback(session_id, "not a callable")
+ self.mainframe.add_session_callback(session_id, "not a callable") # type: ignore[arg-type]
@@
- self.mainframe.add_request_response_callback("session", "request", "not a callable")
+ self.mainframe.add_request_response_callback("session", "request", "not a callable") # type: ignore[arg-type]🧰 Tools🪛 GitHub Actions: layer[error] 50-50: Argument of type "Literal['not a callable']" cannot be assigned to parameter "callback" of type "(dict[Unknown, Unknown]) -> Any" in function "add_session_callback". Type "Literal['not a callable']" is not assignable to type "(dict[Unknown, Unknown]) -> Any" (reportArgumentType) [error] 57-57: Argument of type "Literal['not a callable']" cannot be assigned to parameter "callback" of type "(Any) -> Any" in function "add_request_response_callback". Type "Literal['not a callable']" is not assignable to type "(Any) -> Any" (reportArgumentType) 🪛 GitHub Actions: pr[error] 50-50: Argument of type "Literal['not a callable']" cannot be assigned to parameter "callback" of type "(dict[Unknown, Unknown]) -> Any" in function "add_session_callback" (reportArgumentType) [error] 57-57: Argument of type "Literal['not a callable']" cannot be assigned to parameter "callback" of type "(Any) -> Any" in function "add_request_response_callback" (reportArgumentType) 🤖 Prompt for AI Agents |
||
|
|
||
| def test_remove_session_callback(self): | ||
| """Test removing a session callback.""" | ||
| session_id = 'test-session' | ||
| callback = MagicMock() | ||
|
|
||
| # Add then remove | ||
| self.mainframe.add_session_callback(session_id, callback) | ||
| self.mainframe.remove_session_callback(session_id, callback) | ||
|
|
||
| # Verify unsubscribe was called | ||
| self.mock_client.unsubscribe.assert_called() | ||
|
|
||
| def test_remove_request_response_callback(self): | ||
| """Test removing a request response callback.""" | ||
| session_id = 'test-session' | ||
| request_id = 'test-request' | ||
| callback = MagicMock() | ||
|
|
||
| # Add then remove | ||
| self.mainframe.add_request_response_callback(session_id, request_id, callback) | ||
| self.mainframe.remove_request_response_callback(session_id, request_id, callback) | ||
|
|
||
| # Verify unsubscribe was called | ||
| self.mock_client.unsubscribe.assert_called() | ||
|
|
||
| def test_remove_nonexistent_callback_logs_warning(self): | ||
| """Test that removing a nonexistent callback logs a warning.""" | ||
| session_id = 'test-session' | ||
| callback = MagicMock() | ||
|
|
||
| # Create a mock logger | ||
| mock_logger = MagicMock() | ||
| self.mainframe._logger = mock_logger | ||
|
|
||
| # Try to remove callback that was never added | ||
| self.mainframe.remove_session_callback(session_id, callback) | ||
|
|
||
| # Verify warning was logged | ||
| mock_logger.warning.assert_called_once() | ||
|
|
||
| def test_multiple_callbacks_for_same_session(self): | ||
| """Test that multiple callbacks can be added for the same session.""" | ||
| session_id = 'test-session' | ||
| callback1 = MagicMock() | ||
| callback2 = MagicMock() | ||
|
|
||
| self.mainframe.add_session_callback(session_id, callback1) | ||
| self.mainframe.add_session_callback(session_id, callback2) | ||
|
|
||
| # Remove first callback, should not unsubscribe yet | ||
| self.mainframe.remove_session_callback(session_id, callback1) | ||
|
|
||
| # Subscribe should have been called only once (for first add) | ||
| call_count_before = self.mock_client.message_callback_add.call_count | ||
|
|
||
| # Remove second callback, should unsubscribe now | ||
| self.mainframe.remove_session_callback(session_id, callback2) | ||
|
|
||
| self.mock_client.unsubscribe.assert_called() | ||
|
Comment on lines
+113
to
+119
|
||
|
|
||
| def test_add_report_callback(self): | ||
| """Test adding a report callback.""" | ||
| callback = MagicMock() | ||
|
|
||
| self.mainframe.add_report_callback(callback) | ||
|
|
||
| # No error should occur | ||
|
|
||
| def test_add_report_callback_requires_callable(self): | ||
| """Test that non-callable raises ValueError for report callback.""" | ||
| with self.assertRaises(ValueError) as context: | ||
| self.mainframe.add_report_callback("not a callable") | ||
|
|
||
| self.assertIn("callable", str(context.exception)) | ||
|
|
||
| def test_remove_report_callback(self): | ||
| """Test removing a report callback.""" | ||
| callback = MagicMock() | ||
|
|
||
| self.mainframe.add_report_callback(callback) | ||
| self.mainframe.remove_report_callback(callback) | ||
|
|
||
| # No error should occur | ||
|
|
||
| def test_remove_nonexistent_report_callback_logs_warning(self): | ||
| """Test that removing a nonexistent report callback logs a warning.""" | ||
| callback = MagicMock() | ||
| mock_logger = MagicMock() | ||
| self.mainframe._logger = mock_logger | ||
|
|
||
| self.mainframe.remove_report_callback(callback) | ||
|
|
||
| mock_logger.warning.assert_called_once() | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| unittest.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR description references
_add_callback_with_subscription()/_remove_callback()helper names, but the implementation introduces_add_callback()/_remove_callback(). Please align the PR description (or rename the helper) so reviewers/users aren’t looking for the wrong method names.