From 32bf321ac81c694e90c3a5716e825235eb305894 Mon Sep 17 00:00:00 2001 From: Michael Porter Date: Fri, 6 Feb 2026 12:50:47 +0000 Subject: [PATCH 1/2] feat(plugins): wire on_state_change_callback into plugin framework Add plumbing so that plugins are notified when an event carries session state changes (non-empty state_delta). This closes a gap where BasePlugin had no default method, PluginManager had no dispatcher, and the runner never triggered the callback. Fixes https://github.com/google/adk-python/issues/4393 --- src/google/adk/plugins/base_plugin.py | 22 +++++++ src/google/adk/plugins/plugin_manager.py | 14 +++++ src/google/adk/runners.py | 14 ++++- tests/unittests/plugins/test_base_plugin.py | 17 +++++ .../unittests/plugins/test_plugin_manager.py | 62 +++++++++++++++++++ 5 files changed, 127 insertions(+), 2 deletions(-) diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py index 3639f61aa2..b8e5ce383e 100644 --- a/src/google/adk/plugins/base_plugin.py +++ b/src/google/adk/plugins/base_plugin.py @@ -370,3 +370,25 @@ async def on_tool_error_callback( allows the original error to be raised. """ pass + + async def on_state_change_callback( + self, + *, + callback_context: CallbackContext, + state_delta: dict[str, Any], + ) -> None: + """Callback executed when an event carries state changes. + + This callback is invoked after an event with a non-empty + ``state_delta`` is yielded from the runner. It is observational: + returning a value has no effect on execution flow. + + Args: + callback_context: The context for the current invocation. + state_delta: A copy of the state changes carried by the event. + Mutating this dict does not affect the original state. + + Returns: + None + """ + pass diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index c781e8fa4e..b5de924cd4 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -52,6 +52,7 @@ "after_model_callback", "on_tool_error_callback", "on_model_error_callback", + "on_state_change_callback", ] logger = logging.getLogger("google_adk." + __name__) @@ -257,6 +258,19 @@ async def run_on_tool_error_callback( error=error, ) + async def run_on_state_change_callback( + self, + *, + callback_context: CallbackContext, + state_delta: dict[str, Any], + ) -> Optional[None]: + """Runs the `on_state_change_callback` for all plugins.""" + return await self._run_callbacks( + "on_state_change_callback", + callback_context=callback_context, + state_delta=state_delta, + ) + async def _run_callbacks( self, callback_name: PluginCallbackName, **kwargs: Any ) -> Optional[Any]: diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 545a0e83e6..906fc92448 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -839,9 +839,19 @@ async def _exec_with_plugin( _apply_run_config_custom_metadata( modified_event, invocation_context.run_config ) - yield modified_event + final_event = modified_event else: - yield event + final_event = event + yield final_event + + # Step 3b: Notify plugins of state changes, if any. + if final_event.actions.state_delta: + from .agents.callback_context import CallbackContext + + await plugin_manager.run_on_state_change_callback( + callback_context=CallbackContext(invocation_context), + state_delta=dict(final_event.actions.state_delta), + ) # Step 4: Run the after_run callbacks to perform global cleanup tasks or # finalizing logs and metrics data. diff --git a/tests/unittests/plugins/test_base_plugin.py b/tests/unittests/plugins/test_base_plugin.py index aa7c17fb01..fbe98b71df 100644 --- a/tests/unittests/plugins/test_base_plugin.py +++ b/tests/unittests/plugins/test_base_plugin.py @@ -79,6 +79,9 @@ async def after_model_callback(self, **kwargs) -> str: async def on_model_error_callback(self, **kwargs) -> str: return "overridden_on_model_error" + async def on_state_change_callback(self, **kwargs) -> str: + return "overridden_on_state_change" + def test_base_plugin_initialization(): """Tests that a plugin is initialized with the correct name.""" @@ -172,6 +175,13 @@ async def test_base_plugin_default_callbacks_return_none(): ) is None ) + assert ( + await plugin.on_state_change_callback( + callback_context=mock_context, + state_delta={}, + ) + is None + ) @pytest.mark.asyncio @@ -278,3 +288,10 @@ async def test_base_plugin_all_callbacks_can_be_overridden(): ) == "overridden_on_model_error" ) + assert ( + await plugin.on_state_change_callback( + callback_context=mock_callback_context, + state_delta={"key": "value"}, + ) + == "overridden_on_state_change" + ) diff --git a/tests/unittests/plugins/test_plugin_manager.py b/tests/unittests/plugins/test_plugin_manager.py index ba070ea8f3..fe47ee47a1 100644 --- a/tests/unittests/plugins/test_plugin_manager.py +++ b/tests/unittests/plugins/test_plugin_manager.py @@ -91,6 +91,9 @@ async def after_model_callback(self, **kwargs): async def on_model_error_callback(self, **kwargs): return await self._handle_callback("on_model_error_callback") + async def on_state_change_callback(self, **kwargs): + return await self._handle_callback("on_state_change_callback") + @pytest.fixture def service() -> PluginManager: @@ -252,6 +255,10 @@ async def test_all_callbacks_are_supported( llm_request=mock_context, error=mock_context, ) + await service.run_on_state_change_callback( + callback_context=mock_context, + state_delta={"key": "value"}, + ) # Verify all callbacks were logged expected_callbacks = [ @@ -267,6 +274,7 @@ async def test_all_callbacks_are_supported( "before_model_callback", "after_model_callback", "on_model_error_callback", + "on_state_change_callback", ] assert set(plugin1.call_log) == set(expected_callbacks) @@ -317,3 +325,57 @@ async def slow_close(): assert "Failed to close plugins: 'plugin1': TimeoutError" in str( excinfo.value ) + + +# --- on_state_change_callback tests --- + + +@pytest.mark.asyncio +async def test_run_on_state_change_callback( + service: PluginManager, plugin1: TestPlugin +): + """Tests that run_on_state_change_callback invokes the callback and returns None.""" + service.register_plugin(plugin1) + result = await service.run_on_state_change_callback( + callback_context=Mock(), + state_delta={"key": "value"}, + ) + assert result is None + assert "on_state_change_callback" in plugin1.call_log + + +@pytest.mark.asyncio +async def test_run_on_state_change_callback_calls_all_plugins( + service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin +): + """Tests that on_state_change_callback is called on all plugins.""" + service.register_plugin(plugin1) + service.register_plugin(plugin2) + + await service.run_on_state_change_callback( + callback_context=Mock(), + state_delta={"key": "value"}, + ) + + assert "on_state_change_callback" in plugin1.call_log + assert "on_state_change_callback" in plugin2.call_log + + +@pytest.mark.asyncio +async def test_run_on_state_change_callback_wraps_exceptions( + service: PluginManager, plugin1: TestPlugin +): + """Tests that exceptions in on_state_change_callback are wrapped in RuntimeError.""" + original_exception = ValueError("state change error") + plugin1.exceptions_to_raise["on_state_change_callback"] = original_exception + service.register_plugin(plugin1) + + with pytest.raises(RuntimeError) as excinfo: + await service.run_on_state_change_callback( + callback_context=Mock(), + state_delta={"key": "value"}, + ) + + assert "Error in plugin 'plugin1'" in str(excinfo.value) + assert "on_state_change_callback" in str(excinfo.value) + assert excinfo.value.__cause__ is original_exception From 1375419aa57431932425a55543ae489d663419fc Mon Sep 17 00:00:00 2001 From: Michael Porter Date: Fri, 6 Feb 2026 13:03:50 +0000 Subject: [PATCH 2/2] fix: address review comments from gemini-code-assist - Clarify docstring: non-None return short-circuits subsequent plugins - Fix return type: Optional[None] -> Optional[Any] to match _run_callbacks --- src/google/adk/plugins/base_plugin.py | 4 ++-- src/google/adk/plugins/plugin_manager.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py index b8e5ce383e..4d5a016173 100644 --- a/src/google/adk/plugins/base_plugin.py +++ b/src/google/adk/plugins/base_plugin.py @@ -380,8 +380,8 @@ async def on_state_change_callback( """Callback executed when an event carries state changes. This callback is invoked after an event with a non-empty - ``state_delta`` is yielded from the runner. It is observational: - returning a value has no effect on execution flow. + ``state_delta`` is yielded from the runner. It is observational, but + returning a non-`None` value will short-circuit subsequent plugins. Args: callback_context: The context for the current invocation. diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index b5de924cd4..46954d7706 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -263,7 +263,7 @@ async def run_on_state_change_callback( *, callback_context: CallbackContext, state_delta: dict[str, Any], - ) -> Optional[None]: + ) -> Optional[Any]: """Runs the `on_state_change_callback` for all plugins.""" return await self._run_callbacks( "on_state_change_callback",