diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py index 3639f61aa2..4d5a016173 100644 --- a/src/google/adk/plugins/base_plugin.py +++ b/src/google/adk/plugins/base_plugin.py @@ -370,3 +370,25 @@ async def on_tool_error_callback( allows the original error to be raised. """ pass + + async def on_state_change_callback( + self, + *, + callback_context: CallbackContext, + state_delta: dict[str, Any], + ) -> None: + """Callback executed when an event carries state changes. + + This callback is invoked after an event with a non-empty + ``state_delta`` is yielded from the runner. It is observational, but + returning a non-`None` value will short-circuit subsequent plugins. + + Args: + callback_context: The context for the current invocation. + state_delta: A copy of the state changes carried by the event. + Mutating this dict does not affect the original state. + + Returns: + None + """ + pass diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index c781e8fa4e..46954d7706 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -52,6 +52,7 @@ "after_model_callback", "on_tool_error_callback", "on_model_error_callback", + "on_state_change_callback", ] logger = logging.getLogger("google_adk." + __name__) @@ -257,6 +258,19 @@ async def run_on_tool_error_callback( error=error, ) + async def run_on_state_change_callback( + self, + *, + callback_context: CallbackContext, + state_delta: dict[str, Any], + ) -> Optional[Any]: + """Runs the `on_state_change_callback` for all plugins.""" + return await self._run_callbacks( + "on_state_change_callback", + callback_context=callback_context, + state_delta=state_delta, + ) + async def _run_callbacks( self, callback_name: PluginCallbackName, **kwargs: Any ) -> Optional[Any]: diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 545a0e83e6..906fc92448 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -839,9 +839,19 @@ async def _exec_with_plugin( _apply_run_config_custom_metadata( modified_event, invocation_context.run_config ) - yield modified_event + final_event = modified_event else: - yield event + final_event = event + yield final_event + + # Step 3b: Notify plugins of state changes, if any. + if final_event.actions.state_delta: + from .agents.callback_context import CallbackContext + + await plugin_manager.run_on_state_change_callback( + callback_context=CallbackContext(invocation_context), + state_delta=dict(final_event.actions.state_delta), + ) # Step 4: Run the after_run callbacks to perform global cleanup tasks or # finalizing logs and metrics data. diff --git a/tests/unittests/plugins/test_base_plugin.py b/tests/unittests/plugins/test_base_plugin.py index aa7c17fb01..fbe98b71df 100644 --- a/tests/unittests/plugins/test_base_plugin.py +++ b/tests/unittests/plugins/test_base_plugin.py @@ -79,6 +79,9 @@ async def after_model_callback(self, **kwargs) -> str: async def on_model_error_callback(self, **kwargs) -> str: return "overridden_on_model_error" + async def on_state_change_callback(self, **kwargs) -> str: + return "overridden_on_state_change" + def test_base_plugin_initialization(): """Tests that a plugin is initialized with the correct name.""" @@ -172,6 +175,13 @@ async def test_base_plugin_default_callbacks_return_none(): ) is None ) + assert ( + await plugin.on_state_change_callback( + callback_context=mock_context, + state_delta={}, + ) + is None + ) @pytest.mark.asyncio @@ -278,3 +288,10 @@ async def test_base_plugin_all_callbacks_can_be_overridden(): ) == "overridden_on_model_error" ) + assert ( + await plugin.on_state_change_callback( + callback_context=mock_callback_context, + state_delta={"key": "value"}, + ) + == "overridden_on_state_change" + ) diff --git a/tests/unittests/plugins/test_plugin_manager.py b/tests/unittests/plugins/test_plugin_manager.py index ba070ea8f3..fe47ee47a1 100644 --- a/tests/unittests/plugins/test_plugin_manager.py +++ b/tests/unittests/plugins/test_plugin_manager.py @@ -91,6 +91,9 @@ async def after_model_callback(self, **kwargs): async def on_model_error_callback(self, **kwargs): return await self._handle_callback("on_model_error_callback") + async def on_state_change_callback(self, **kwargs): + return await self._handle_callback("on_state_change_callback") + @pytest.fixture def service() -> PluginManager: @@ -252,6 +255,10 @@ async def test_all_callbacks_are_supported( llm_request=mock_context, error=mock_context, ) + await service.run_on_state_change_callback( + callback_context=mock_context, + state_delta={"key": "value"}, + ) # Verify all callbacks were logged expected_callbacks = [ @@ -267,6 +274,7 @@ async def test_all_callbacks_are_supported( "before_model_callback", "after_model_callback", "on_model_error_callback", + "on_state_change_callback", ] assert set(plugin1.call_log) == set(expected_callbacks) @@ -317,3 +325,57 @@ async def slow_close(): assert "Failed to close plugins: 'plugin1': TimeoutError" in str( excinfo.value ) + + +# --- on_state_change_callback tests --- + + +@pytest.mark.asyncio +async def test_run_on_state_change_callback( + service: PluginManager, plugin1: TestPlugin +): + """Tests that run_on_state_change_callback invokes the callback and returns None.""" + service.register_plugin(plugin1) + result = await service.run_on_state_change_callback( + callback_context=Mock(), + state_delta={"key": "value"}, + ) + assert result is None + assert "on_state_change_callback" in plugin1.call_log + + +@pytest.mark.asyncio +async def test_run_on_state_change_callback_calls_all_plugins( + service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin +): + """Tests that on_state_change_callback is called on all plugins.""" + service.register_plugin(plugin1) + service.register_plugin(plugin2) + + await service.run_on_state_change_callback( + callback_context=Mock(), + state_delta={"key": "value"}, + ) + + assert "on_state_change_callback" in plugin1.call_log + assert "on_state_change_callback" in plugin2.call_log + + +@pytest.mark.asyncio +async def test_run_on_state_change_callback_wraps_exceptions( + service: PluginManager, plugin1: TestPlugin +): + """Tests that exceptions in on_state_change_callback are wrapped in RuntimeError.""" + original_exception = ValueError("state change error") + plugin1.exceptions_to_raise["on_state_change_callback"] = original_exception + service.register_plugin(plugin1) + + with pytest.raises(RuntimeError) as excinfo: + await service.run_on_state_change_callback( + callback_context=Mock(), + state_delta={"key": "value"}, + ) + + assert "Error in plugin 'plugin1'" in str(excinfo.value) + assert "on_state_change_callback" in str(excinfo.value) + assert excinfo.value.__cause__ is original_exception