Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/google/adk/plugins/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,25 @@ async def on_tool_error_callback(
allows the original error to be raised.
"""
pass

async def on_state_change_callback(
self,
*,
callback_context: CallbackContext,
state_delta: dict[str, Any],
) -> None:
"""Callback executed when an event carries state changes.
This callback is invoked after an event with a non-empty
``state_delta`` is yielded from the runner. It is observational, but
returning a non-`None` value will short-circuit subsequent plugins.
Args:
callback_context: The context for the current invocation.
state_delta: A copy of the state changes carried by the event.
Mutating this dict does not affect the original state.
Returns:
None
"""
pass
14 changes: 14 additions & 0 deletions src/google/adk/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"after_model_callback",
"on_tool_error_callback",
"on_model_error_callback",
"on_state_change_callback",
]

logger = logging.getLogger("google_adk." + __name__)
Expand Down Expand Up @@ -257,6 +258,19 @@ async def run_on_tool_error_callback(
error=error,
)

async def run_on_state_change_callback(
self,
*,
callback_context: CallbackContext,
state_delta: dict[str, Any],
) -> Optional[Any]:
"""Runs the `on_state_change_callback` for all plugins."""
return await self._run_callbacks(
"on_state_change_callback",
callback_context=callback_context,
state_delta=state_delta,
)

async def _run_callbacks(
self, callback_name: PluginCallbackName, **kwargs: Any
) -> Optional[Any]:
Expand Down
14 changes: 12 additions & 2 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,9 +839,19 @@ async def _exec_with_plugin(
_apply_run_config_custom_metadata(
modified_event, invocation_context.run_config
)
yield modified_event
final_event = modified_event
else:
yield event
final_event = event
yield final_event

# Step 3b: Notify plugins of state changes, if any.
if final_event.actions.state_delta:
from .agents.callback_context import CallbackContext

await plugin_manager.run_on_state_change_callback(
callback_context=CallbackContext(invocation_context),
state_delta=dict(final_event.actions.state_delta),
)

# Step 4: Run the after_run callbacks to perform global cleanup tasks or
# finalizing logs and metrics data.
Expand Down
17 changes: 17 additions & 0 deletions tests/unittests/plugins/test_base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ async def after_model_callback(self, **kwargs) -> str:
async def on_model_error_callback(self, **kwargs) -> str:
return "overridden_on_model_error"

async def on_state_change_callback(self, **kwargs) -> str:
return "overridden_on_state_change"


def test_base_plugin_initialization():
"""Tests that a plugin is initialized with the correct name."""
Expand Down Expand Up @@ -172,6 +175,13 @@ async def test_base_plugin_default_callbacks_return_none():
)
is None
)
assert (
await plugin.on_state_change_callback(
callback_context=mock_context,
state_delta={},
)
is None
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -278,3 +288,10 @@ async def test_base_plugin_all_callbacks_can_be_overridden():
)
== "overridden_on_model_error"
)
assert (
await plugin.on_state_change_callback(
callback_context=mock_callback_context,
state_delta={"key": "value"},
)
== "overridden_on_state_change"
)
62 changes: 62 additions & 0 deletions tests/unittests/plugins/test_plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ async def after_model_callback(self, **kwargs):
async def on_model_error_callback(self, **kwargs):
return await self._handle_callback("on_model_error_callback")

async def on_state_change_callback(self, **kwargs):
return await self._handle_callback("on_state_change_callback")


@pytest.fixture
def service() -> PluginManager:
Expand Down Expand Up @@ -252,6 +255,10 @@ async def test_all_callbacks_are_supported(
llm_request=mock_context,
error=mock_context,
)
await service.run_on_state_change_callback(
callback_context=mock_context,
state_delta={"key": "value"},
)

# Verify all callbacks were logged
expected_callbacks = [
Expand All @@ -267,6 +274,7 @@ async def test_all_callbacks_are_supported(
"before_model_callback",
"after_model_callback",
"on_model_error_callback",
"on_state_change_callback",
]
assert set(plugin1.call_log) == set(expected_callbacks)

Expand Down Expand Up @@ -317,3 +325,57 @@ async def slow_close():
assert "Failed to close plugins: 'plugin1': TimeoutError" in str(
excinfo.value
)


# --- on_state_change_callback tests ---


@pytest.mark.asyncio
async def test_run_on_state_change_callback(
service: PluginManager, plugin1: TestPlugin
):
"""Tests that run_on_state_change_callback invokes the callback and returns None."""
service.register_plugin(plugin1)
result = await service.run_on_state_change_callback(
callback_context=Mock(),
state_delta={"key": "value"},
)
assert result is None
assert "on_state_change_callback" in plugin1.call_log


@pytest.mark.asyncio
async def test_run_on_state_change_callback_calls_all_plugins(
service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin
):
"""Tests that on_state_change_callback is called on all plugins."""
service.register_plugin(plugin1)
service.register_plugin(plugin2)

await service.run_on_state_change_callback(
callback_context=Mock(),
state_delta={"key": "value"},
)

assert "on_state_change_callback" in plugin1.call_log
assert "on_state_change_callback" in plugin2.call_log


@pytest.mark.asyncio
async def test_run_on_state_change_callback_wraps_exceptions(
service: PluginManager, plugin1: TestPlugin
):
"""Tests that exceptions in on_state_change_callback are wrapped in RuntimeError."""
original_exception = ValueError("state change error")
plugin1.exceptions_to_raise["on_state_change_callback"] = original_exception
service.register_plugin(plugin1)

with pytest.raises(RuntimeError) as excinfo:
await service.run_on_state_change_callback(
callback_context=Mock(),
state_delta={"key": "value"},
)

assert "Error in plugin 'plugin1'" in str(excinfo.value)
assert "on_state_change_callback" in str(excinfo.value)
assert excinfo.value.__cause__ is original_exception