Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/google/adk/plugins/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,22 @@ async def on_tool_error_callback(
allows the original error to be raised.
"""
pass

async def on_pipeline_error_callback(
self,
*,
invocation_context: InvocationContext,
error: Exception,
) -> Exception:
"""Callback executed when the runner pipeline encounters an error.

This callback provides an opportunity to handle pipeline errors globally.

Args:
invocation_context: The context for the entire invocation.
error: The exception that was raised during runner execution.

Returns:
An Exception to be raised (either the original error or a new/modified one).
"""
return error
36 changes: 35 additions & 1 deletion src/google/adk/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"after_model_callback",
"on_tool_error_callback",
"on_model_error_callback",
"on_pipeline_error_callback",
]

logger = logging.getLogger("google_adk." + __name__)
Expand Down Expand Up @@ -272,6 +273,33 @@ async def run_on_tool_error_callback(
error=error,
)

async def run_on_pipeline_error_callback(
self,
*,
invocation_context: InvocationContext,
error: Exception,
) -> Exception:
"""Runs the `on_pipeline_error_callback` for all plugins sequentially, chaining the error."""
for plugin in self.plugins:
try:
error = await plugin.on_pipeline_error_callback(
invocation_context=invocation_context, error=error
)
except Exception as e:
error_message = (
f"Error in plugin '{plugin.name}' during "
f"'on_pipeline_error_callback' callback: {e}"
)
logger.error(
"Error in plugin '%s' during 'on_pipeline_error_callback'"
" callback: %s",
plugin.name,
e,
exc_info=True,
)
raise RuntimeError(error_message) from e
return error

async def _run_callbacks(
self, callback_name: PluginCallbackName, **kwargs: Any
) -> Optional[Any]:
Expand Down Expand Up @@ -316,7 +344,13 @@ async def _run_callbacks(
f"Error in plugin '{plugin.name}' during '{callback_name}'"
f" callback: {e}"
)
logger.error(error_message, exc_info=True)
logger.error(
"Error in plugin '%s' during '%s' callback: %s",
plugin.name,
callback_name,
e,
exc_info=True,
)
raise RuntimeError(error_message) from e

return None
Expand Down
117 changes: 62 additions & 55 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,66 +1355,73 @@ async def _exec_with_plugin(

plugin_manager = invocation_context.plugin_manager

# Step 1: Run the before_run callbacks to see if we should early exit.
early_exit_result = await plugin_manager.run_before_run_callback(
invocation_context=invocation_context
)
if isinstance(early_exit_result, types.Content):
early_exit_event = Event(
invocation_id=invocation_context.invocation_id,
author='model',
content=early_exit_result,
)
_apply_run_config_custom_metadata(
early_exit_event, invocation_context.run_config
try:
# Step 1: Run the before_run callbacks to see if we should early exit.
early_exit_result = await plugin_manager.run_before_run_callback(
invocation_context=invocation_context
)
if self._should_append_event(early_exit_event, is_live_call):
await self.session_service.append_event(
session=invocation_context.session,
event=early_exit_event,
if isinstance(early_exit_result, types.Content):
early_exit_event = Event(
invocation_id=invocation_context.invocation_id,
author='model',
content=early_exit_result,
)
yield early_exit_event
else:
# Step 2: Otherwise continue with normal execution
async with aclosing(execute_fn(invocation_context)) as agen:
async for event in agen:
_apply_run_config_custom_metadata(
event, invocation_context.run_config
)
# Step 3: Run the on_event callbacks before persisting so callback
# changes are stored in the session and match the streamed event.
modified_event = await plugin_manager.run_on_event_callback(
invocation_context=invocation_context, event=event
)
output_event = self._get_output_event(
original_event=event,
modified_event=modified_event,
run_config=invocation_context.run_config,
_apply_run_config_custom_metadata(
early_exit_event, invocation_context.run_config
)
if self._should_append_event(early_exit_event, is_live_call):
await self.session_service.append_event(
session=invocation_context.session,
event=early_exit_event,
)
yield early_exit_event
else:
# Step 2: Otherwise continue with normal execution
async with aclosing(execute_fn(invocation_context)) as agen:
async for event in agen:
_apply_run_config_custom_metadata(
event, invocation_context.run_config
)
# Step 3: Run the on_event callbacks before persisting so callback
# changes are stored in the session and match the streamed event.
modified_event = await plugin_manager.run_on_event_callback(
invocation_context=invocation_context, event=event
)
output_event = self._get_output_event(
original_event=event,
modified_event=modified_event,
run_config=invocation_context.run_config,
)

if is_live_call:
# Skip partial transcriptions for Live
if event.partial is not True and self._should_append_event(
event, is_live_call
):
logger.debug('Appending live event: %s', output_event)
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)
else:
if event.partial is not True:
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)

yield output_event
if is_live_call:
# Skip partial transcriptions for Live
if event.partial is not True and self._should_append_event(
event, is_live_call
):
logger.debug('Appending live event: %s', output_event)
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)
else:
if event.partial is not True:
await self.session_service.append_event(
session=invocation_context.session, event=output_event
)

# Step 4: Run the after_run callbacks to perform global cleanup tasks or
# finalizing logs and metrics data.
# This does NOT emit any event.
await plugin_manager.run_after_run_callback(
invocation_context=invocation_context
)
yield output_event
except Exception as e:
if plugin_manager:
e = await plugin_manager.run_on_pipeline_error_callback(
invocation_context=invocation_context, error=e
)
raise e
finally:
# Step 4: Run the after_run callbacks to perform global cleanup tasks or
# finalizing logs and metrics data.
# This does NOT emit any event.
await plugin_manager.run_after_run_callback(
invocation_context=invocation_context
)

async def _append_new_message_to_session(
self,
Expand Down
51 changes: 51 additions & 0 deletions tests/unittests/plugins/test_plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ async def after_model_callback(self, **kwargs):
async def on_model_error_callback(self, **kwargs):
return await self._handle_callback("on_model_error_callback")

async def on_pipeline_error_callback(self, error: Exception, **kwargs):
self.call_log.append("on_pipeline_error_callback")
if "on_pipeline_error_callback" in self.exceptions_to_raise:
raise self.exceptions_to_raise["on_pipeline_error_callback"]
return self.return_values.get("on_pipeline_error_callback", error)


@pytest.fixture
def service() -> PluginManager:
Expand Down Expand Up @@ -252,6 +258,10 @@ async def test_all_callbacks_are_supported(
llm_request=mock_context,
error=mock_context,
)
await service.run_on_pipeline_error_callback(
invocation_context=mock_context,
error=ValueError("err"),
)

# Verify all callbacks were logged
expected_callbacks = [
Expand All @@ -267,6 +277,7 @@ async def test_all_callbacks_are_supported(
"before_model_callback",
"after_model_callback",
"on_model_error_callback",
"on_pipeline_error_callback",
]
assert set(plugin1.call_log) == set(expected_callbacks)

Expand Down Expand Up @@ -363,3 +374,43 @@ async def test_set_skip_closing_plugins_false_reverts_to_closing(
await service.close()

plugin1.close.assert_awaited_once()


@pytest.mark.asyncio
async def test_pipeline_error_callback_chaining(
service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin
):
"""Tests that on_pipeline_error_callback is called and errors are chained."""
error1 = ValueError("Original error")
error2 = RuntimeError("Chained error")
plugin1.return_values["on_pipeline_error_callback"] = error2

service.register_plugin(plugin1)
service.register_plugin(plugin2)

result_err = await service.run_on_pipeline_error_callback(
invocation_context=Mock(), error=error1
)

assert result_err is error2
assert "on_pipeline_error_callback" in plugin1.call_log
assert "on_pipeline_error_callback" in plugin2.call_log


@pytest.mark.asyncio
async def test_pipeline_error_callback_exception_wrap(
service: PluginManager, plugin1: TestPlugin
):
"""Tests that if on_pipeline_error_callback raises, it wraps in RuntimeError."""
plugin1.exceptions_to_raise["on_pipeline_error_callback"] = ValueError(
"Callback crashed"
)
service.register_plugin(plugin1)

with pytest.raises(RuntimeError) as excinfo:
await service.run_on_pipeline_error_callback(
invocation_context=Mock(), error=ValueError("Original")
)

assert "Error in plugin 'plugin1'" in str(excinfo.value)
assert "on_pipeline_error_callback" in str(excinfo.value)