diff --git a/src/google/adk/agents/loop_agent.py b/src/google/adk/agents/loop_agent.py index 2980f68ac0..432cddb846 100644 --- a/src/google/adk/agents/loop_agent.py +++ b/src/google/adk/agents/loop_agent.py @@ -26,8 +26,10 @@ from typing_extensions import override from ..events.event import Event +from ..events.event_actions import EventActions from ..features import experimental from ..features import FeatureName +from ..termination.termination_condition import TerminationCondition from ..utils.context_utils import Aclosing from .base_agent import BaseAgent from .base_agent import BaseAgentState @@ -66,6 +68,18 @@ class LoopAgent(BaseAgent): escalates. """ + termination_condition: Optional[TerminationCondition] = None + """An optional termination condition that controls when the loop stops. + + The condition is evaluated after each event emitted by a sub-agent. When + it fires, the loop yields a final event with + ``actions.termination_reason`` set and ``actions.escalate`` set to + ``True``, then stops. + + The condition is automatically reset at the start of each ``_run_async_impl`` + call, so the same instance can be reused across multiple runs. + """ + @override async def _run_async_impl( self, ctx: InvocationContext @@ -73,6 +87,9 @@ async def _run_async_impl( if not self.sub_agents: return + if self.termination_condition: + await self.termination_condition.reset() + agent_state = self._load_agent_state(ctx, LoopAgentState) is_resuming_at_current_agent = agent_state is not None times_looped, start_index = self._get_start_state(agent_state) @@ -102,6 +119,21 @@ async def _run_async_impl( yield event if event.actions.escalate: should_exit = True + + if self.termination_condition and not should_exit: + result = await self.termination_condition.check([event]) + if result: + termination_event = Event( + invocation_id=ctx.invocation_id, + author=self.name, + actions=EventActions( + escalate=True, + termination_reason=result.reason, + ), + ) + yield termination_event + return + if ctx.should_pause_invocation(event): pause_invocation = True diff --git a/src/google/adk/events/event_actions.py b/src/google/adk/events/event_actions.py index cfa73324b5..62ba9a5175 100644 --- a/src/google/adk/events/event_actions.py +++ b/src/google/adk/events/event_actions.py @@ -112,3 +112,7 @@ class EventActions(BaseModel): render_ui_widgets: Optional[list[UiWidget]] = None """List of UI widgets to be rendered by the UI.""" + + termination_reason: Optional[str] = None + """The human-readable reason the conversation was terminated by a + TerminationCondition. Only set on synthetic termination events.""" diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 8e352794a4..e6388cfacc 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -61,6 +61,7 @@ from .sessions.in_memory_session_service import InMemorySessionService from .sessions.session import Session from .telemetry.tracing import tracer +from .termination.termination_condition import TerminationCondition from .tools.base_toolset import BaseToolset from .utils._debug_output import print_event from .utils.context_utils import Aclosing @@ -509,6 +510,7 @@ async def run_async( new_message: Optional[types.Content] = None, state_delta: Optional[dict[str, Any]] = None, run_config: Optional[RunConfig] = None, + termination_condition: Optional[TerminationCondition] = None, ) -> AsyncGenerator[Event, None]: """Main entry method to run the agent in this runner. @@ -526,6 +528,8 @@ async def run_async( new_message: A new message to append to the session. state_delta: Optional state changes to apply to the session. run_config: The run config for the agent. + termination_condition: An optional condition that stops the run when + triggered. Reset automatically before the run begins. Yields: The events generated by the agent. @@ -602,9 +606,24 @@ async def _run_with_trace( return async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: + if termination_condition: + await termination_condition.reset() async with Aclosing(ctx.agent.run_async(ctx)) as agen: async for event in agen: yield event + if termination_condition: + termination_result = await termination_condition.check([event]) + if termination_result: + termination_event = Event( + invocation_id=ctx.invocation_id, + author=ctx.agent.name, + actions=EventActions( + escalate=True, + termination_reason=termination_result.reason, + ), + ) + yield termination_event + return async with Aclosing( self._exec_with_plugin( diff --git a/src/google/adk/termination/__init__.py b/src/google/adk/termination/__init__.py new file mode 100644 index 0000000000..9049106bb2 --- /dev/null +++ b/src/google/adk/termination/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .external_termination import ExternalTermination +from .function_call_termination import FunctionCallTermination +from .max_iterations_termination import MaxIterationsTermination +from .termination_condition import AndTerminationCondition +from .termination_condition import OrTerminationCondition +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult +from .text_mention_termination import TextMentionTermination +from .timeout_termination import TimeoutTermination +from .token_usage_termination import TokenUsageTermination + +__all__ = [ + 'AndTerminationCondition', + 'ExternalTermination', + 'FunctionCallTermination', + 'MaxIterationsTermination', + 'OrTerminationCondition', + 'TerminationCondition', + 'TerminationResult', + 'TextMentionTermination', + 'TimeoutTermination', + 'TokenUsageTermination', +] diff --git a/src/google/adk/termination/external_termination.py b/src/google/adk/termination/external_termination.py new file mode 100644 index 0000000000..85b2314c31 --- /dev/null +++ b/src/google/adk/termination/external_termination.py @@ -0,0 +1,64 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A termination condition controlled programmatically via ``set()``.""" + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from ..events.event import Event +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult + + +class ExternalTermination(TerminationCondition): + """A termination condition that is controlled externally by calling ``set()``. + + Useful for integrating external stop signals such as a UI "Stop" button + or application-level logic. + + Example:: + + stop_button = ExternalTermination() + + agent = LoopAgent( + name='my_loop', + sub_agents=[...], + termination_condition=stop_button, + ) + + # Elsewhere (e.g. from a UI event handler): + stop_button.set() + """ + + def __init__(self) -> None: + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + def set(self) -> None: + """Signals that the conversation should terminate at the next check.""" + self._terminated = True + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return TerminationResult(reason='Externally terminated') + return None + + async def reset(self) -> None: + self._terminated = False diff --git a/src/google/adk/termination/function_call_termination.py b/src/google/adk/termination/function_call_termination.py new file mode 100644 index 0000000000..1d43df8a6f --- /dev/null +++ b/src/google/adk/termination/function_call_termination.py @@ -0,0 +1,60 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Terminates when a specific function (tool) has been executed.""" + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from ..events.event import Event +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult + + +class FunctionCallTermination(TerminationCondition): + """Terminates when a tool with a specific name has been executed. + + The condition checks ``FunctionResponse`` parts in events. + + Example:: + + # Stop when the "approve" tool is called + condition = FunctionCallTermination('approve') + """ + + def __init__(self, function_name: str) -> None: + self._function_name = function_name + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + + for event in events: + for response in event.get_function_responses(): + if response.name == self._function_name: + self._terminated = True + return TerminationResult( + reason=f"Function '{self._function_name}' was executed" + ) + return None + + async def reset(self) -> None: + self._terminated = False diff --git a/src/google/adk/termination/max_iterations_termination.py b/src/google/adk/termination/max_iterations_termination.py new file mode 100644 index 0000000000..17b832a293 --- /dev/null +++ b/src/google/adk/termination/max_iterations_termination.py @@ -0,0 +1,64 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Terminates after a maximum number of events have been processed.""" + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from ..events.event import Event +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult + + +class MaxIterationsTermination(TerminationCondition): + """Terminates the conversation after a maximum number of events. + + Example:: + + # Stop after 10 events + condition = MaxIterationsTermination(10) + """ + + def __init__(self, max_iterations: int) -> None: + if max_iterations <= 0: + raise ValueError('max_iterations must be a positive integer.') + self._max_iterations = max_iterations + self._count = 0 + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + self._count += len(events) + + if self._count >= self._max_iterations: + self._terminated = True + return TerminationResult( + reason=( + f'Maximum iterations of {self._max_iterations} reached,' + f' current count: {self._count}' + ) + ) + return None + + async def reset(self) -> None: + self._terminated = False + self._count = 0 diff --git a/src/google/adk/termination/termination_condition.py b/src/google/adk/termination/termination_condition.py new file mode 100644 index 0000000000..cbe51ff9a2 --- /dev/null +++ b/src/google/adk/termination/termination_condition.py @@ -0,0 +1,171 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base termination condition and compound combinators.""" + +from __future__ import annotations + +import abc +from dataclasses import dataclass +from typing import Optional +from typing import Sequence + +from ..events.event import Event + + +@dataclass +class TerminationResult: + """The result returned by a termination condition when the conversation should stop.""" + + reason: str + """A human-readable description of why the conversation was terminated.""" + + +class TerminationCondition(abc.ABC): + """Abstract base class for all termination conditions. + + A termination condition is evaluated after each event in the agent loop. + When ``check()`` returns a ``TerminationResult``, the loop stops and the + ``reason`` is surfaced in the final event's ``actions.termination_reason``. + + Conditions are stateful but reset automatically at the start of each run. + They can be combined with ``.and_()`` and ``.or_()`` to create compound + logic. + + Example:: + + condition = MaxIterationsTermination(10).or_( + TextMentionTermination('TERMINATE') + ) + """ + + @property + @abc.abstractmethod + def terminated(self) -> bool: + """Whether this termination condition has been reached.""" + + @abc.abstractmethod + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + """Checks whether the termination condition is met. + + Called after each event emitted by the agent. Returns a + ``TerminationResult`` if the loop should stop, or ``None`` to continue. + + Args: + events: The delta sequence of events since the last check. + """ + + @abc.abstractmethod + async def reset(self) -> None: + """Resets this condition to its initial state. + + Called automatically at the start of each run so the same instance can + be reused across multiple runs. + """ + + def and_(self, other: TerminationCondition) -> TerminationCondition: + """Returns a new condition that terminates only when BOTH conditions are met. + + Args: + other: The other termination condition. + """ + return AndTerminationCondition(self, other) + + def or_(self, other: TerminationCondition) -> TerminationCondition: + """Returns a new condition that terminates when EITHER condition is met. + + Args: + other: The other termination condition. + """ + return OrTerminationCondition(self, other) + + def __and__(self, other: TerminationCondition) -> TerminationCondition: + """Supports ``condition_a & condition_b`` syntax.""" + return self.and_(other) + + def __or__(self, other: TerminationCondition) -> TerminationCondition: + """Supports ``condition_a | condition_b`` syntax.""" + return self.or_(other) + + +class AndTerminationCondition(TerminationCondition): + """A compound condition that terminates only when ALL children have fired.""" + + def __init__( + self, + left: TerminationCondition, + right: TerminationCondition, + ) -> None: + self._left = left + self._right = right + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + # Forward to both children so each accumulates its own state. + await self._left.check(events) + await self._right.check(events) + + if self._left.terminated and self._right.terminated: + self._terminated = True + return TerminationResult(reason='All termination conditions met') + return None + + async def reset(self) -> None: + self._terminated = False + await self._left.reset() + await self._right.reset() + + +class OrTerminationCondition(TerminationCondition): + """A compound condition that terminates when ANY child fires first.""" + + def __init__( + self, + left: TerminationCondition, + right: TerminationCondition, + ) -> None: + self._left = left + self._right = right + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + + left_result = await self._left.check(events) + if left_result: + self._terminated = True + return left_result + + right_result = await self._right.check(events) + if right_result: + self._terminated = True + return right_result + + return None + + async def reset(self) -> None: + self._terminated = False + await self._left.reset() + await self._right.reset() diff --git a/src/google/adk/termination/text_mention_termination.py b/src/google/adk/termination/text_mention_termination.py new file mode 100644 index 0000000000..4f52da7600 --- /dev/null +++ b/src/google/adk/termination/text_mention_termination.py @@ -0,0 +1,77 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Terminates when a specific text string is found in event content.""" + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from ..events.event import Event +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult + + +def _stringify_event_content(event: Event) -> str: + """Extracts a text representation from an event's content.""" + if not event.content or not event.content.parts: + return '' + texts = [] + for part in event.content.parts: + if part.text: + texts.append(part.text) + return ' '.join(texts) + + +class TextMentionTermination(TerminationCondition): + """Terminates the conversation when a specific text is found in event content. + + Example:: + + # Stop when any agent says "TERMINATE" + condition = TextMentionTermination('TERMINATE') + + # Stop only when the "critic" agent says "APPROVE" + condition = TextMentionTermination('APPROVE', sources=['critic']) + """ + + def __init__( + self, + text: str, + sources: Optional[Sequence[str]] = None, + ) -> None: + self._text = text + self._sources = list(sources) if sources else None + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + + for event in events: + if self._sources and (event.author or '') not in self._sources: + continue + + if self._text in _stringify_event_content(event): + self._terminated = True + return TerminationResult(reason=f"Text '{self._text}' mentioned") + return None + + async def reset(self) -> None: + self._terminated = False diff --git a/src/google/adk/termination/timeout_termination.py b/src/google/adk/termination/timeout_termination.py new file mode 100644 index 0000000000..9cc7dac358 --- /dev/null +++ b/src/google/adk/termination/timeout_termination.py @@ -0,0 +1,70 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Terminates after a specified duration has elapsed.""" + +from __future__ import annotations + +import time +from typing import Optional +from typing import Sequence + +from ..events.event import Event +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult + + +class TimeoutTermination(TerminationCondition): + """Terminates the conversation after a specified duration has elapsed. + + The timer starts on the first ``check()`` call. + + Example:: + + # Stop after 30 seconds + condition = TimeoutTermination(30) + """ + + def __init__(self, timeout_seconds: float) -> None: + if timeout_seconds <= 0: + raise ValueError('timeout_seconds must be a positive number.') + self._timeout_seconds = timeout_seconds + self._start_time: Optional[float] = None + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + + if self._start_time is None: + self._start_time = time.monotonic() + + elapsed = time.monotonic() - self._start_time + if elapsed >= self._timeout_seconds: + self._terminated = True + return TerminationResult( + reason=( + f'Timeout of {self._timeout_seconds}s reached' + f' (elapsed: {elapsed:.2f}s)' + ) + ) + return None + + async def reset(self) -> None: + self._terminated = False + self._start_time = None diff --git a/src/google/adk/termination/token_usage_termination.py b/src/google/adk/termination/token_usage_termination.py new file mode 100644 index 0000000000..845b9acac1 --- /dev/null +++ b/src/google/adk/termination/token_usage_termination.py @@ -0,0 +1,129 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Terminates when cumulative token usage exceeds a limit.""" + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from ..events.event import Event +from .termination_condition import TerminationCondition +from .termination_condition import TerminationResult + + +class TokenUsageTermination(TerminationCondition): + """Terminates when cumulative token usage exceeds configured limits. + + At least one of the token limits must be provided. + + Example:: + + # Stop after 10000 total tokens + condition = TokenUsageTermination(max_total_tokens=10_000) + + # Stop after 5000 prompt tokens OR 2000 completion tokens + condition = TokenUsageTermination( + max_prompt_tokens=5_000, + max_completion_tokens=2_000, + ) + """ + + def __init__( + self, + *, + max_total_tokens: Optional[int] = None, + max_prompt_tokens: Optional[int] = None, + max_completion_tokens: Optional[int] = None, + ) -> None: + if ( + max_total_tokens is None + and max_prompt_tokens is None + and max_completion_tokens is None + ): + raise ValueError( + 'At least one of max_total_tokens, max_prompt_tokens, or' + ' max_completion_tokens must be provided.' + ) + self._max_total_tokens = max_total_tokens + self._max_prompt_tokens = max_prompt_tokens + self._max_completion_tokens = max_completion_tokens + self._total_tokens = 0 + self._prompt_tokens = 0 + self._completion_tokens = 0 + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def check(self, events: Sequence[Event]) -> Optional[TerminationResult]: + if self._terminated: + return None + + for event in events: + if not event.usage_metadata: + continue + + self._total_tokens += event.usage_metadata.total_token_count or 0 + self._prompt_tokens += event.usage_metadata.prompt_token_count or 0 + self._completion_tokens += ( + event.usage_metadata.candidates_token_count or 0 + ) + + if ( + self._max_total_tokens is not None + and self._total_tokens >= self._max_total_tokens + ): + self._terminated = True + return TerminationResult( + reason=( + f'Token limit exceeded: total_tokens={self._total_tokens}' + f' >= max_total_tokens={self._max_total_tokens}' + ) + ) + + if ( + self._max_prompt_tokens is not None + and self._prompt_tokens >= self._max_prompt_tokens + ): + self._terminated = True + return TerminationResult( + reason=( + f'Token limit exceeded: prompt_tokens={self._prompt_tokens}' + f' >= max_prompt_tokens={self._max_prompt_tokens}' + ) + ) + + if ( + self._max_completion_tokens is not None + and self._completion_tokens >= self._max_completion_tokens + ): + self._terminated = True + return TerminationResult( + reason=( + 'Token limit exceeded:' + f' completion_tokens={self._completion_tokens}' + f' >= max_completion_tokens={self._max_completion_tokens}' + ) + ) + + return None + + async def reset(self) -> None: + self._terminated = False + self._total_tokens = 0 + self._prompt_tokens = 0 + self._completion_tokens = 0 diff --git a/tests/unittests/agents/test_loop_agent.py b/tests/unittests/agents/test_loop_agent.py index 68f5d963d3..a5b4042402 100644 --- a/tests/unittests/agents/test_loop_agent.py +++ b/tests/unittests/agents/test_loop_agent.py @@ -24,6 +24,8 @@ from google.adk.events.event import Event from google.adk.events.event_actions import EventActions from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.termination.max_iterations_termination import MaxIterationsTermination +from google.adk.termination.text_mention_termination import TextMentionTermination from google.genai import types import pytest from typing_extensions import override @@ -249,3 +251,84 @@ async def test_run_async_with_escalate_action( ), ] assert simplified_events == expected_events + + +@pytest.mark.asyncio +async def test_run_async_with_termination_condition_stops_loop( + request: pytest.FixtureRequest, +): + """Termination condition fires after first event and stops the loop.""" + agent = _TestingAgent(name=f'{request.function.__name__}_agent') + loop_agent = LoopAgent( + name=f'{request.function.__name__}_loop', + max_iterations=5, + sub_agents=[agent], + termination_condition=MaxIterationsTermination(1), + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, loop_agent + ) + + events = [e async for e in loop_agent.run_async(parent_ctx)] + + # The sub-agent emits one text event, then the loop emits a termination event. + assert len(events) == 2 + text_event, termination_event = events + assert text_event.author == agent.name + assert termination_event.author == loop_agent.name + assert termination_event.actions.escalate is True + assert termination_event.actions.termination_reason is not None + + +@pytest.mark.asyncio +async def test_run_async_with_termination_condition_text_mention( + request: pytest.FixtureRequest, +): + """TextMentionTermination fires when the keyword appears in an event.""" + agent = _TestingAgent(name=f'{request.function.__name__}_agent') + loop_agent = LoopAgent( + name=f'{request.function.__name__}_loop', + max_iterations=10, + sub_agents=[agent], + # 'Hello' appears in the sub-agent's response, so it should fire at once. + termination_condition=TextMentionTermination('Hello'), + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, loop_agent + ) + + events = [e async for e in loop_agent.run_async(parent_ctx)] + + termination_event = events[-1] + assert termination_event.author == loop_agent.name + assert termination_event.actions.escalate is True + assert termination_event.actions.termination_reason is not None + + +@pytest.mark.asyncio +async def test_run_async_termination_condition_resets_between_runs( + request: pytest.FixtureRequest, +): + """The termination condition is reset at the start of each run.""" + agent = _TestingAgent(name=f'{request.function.__name__}_agent') + condition = MaxIterationsTermination(1) + loop_agent = LoopAgent( + name=f'{request.function.__name__}_loop', + max_iterations=5, + sub_agents=[agent], + termination_condition=condition, + ) + + # First run – condition fires and gets reset automatically. + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, loop_agent + ) + events_first = [e async for e in loop_agent.run_async(parent_ctx)] + assert events_first[-1].actions.termination_reason is not None + + # Second run – condition must have been reset; should fire again identically. + parent_ctx2 = await _create_parent_invocation_context( + request.function.__name__ + '_2', loop_agent + ) + events_second = [e async for e in loop_agent.run_async(parent_ctx2)] + assert events_second[-1].actions.termination_reason is not None diff --git a/tests/unittests/runners/test_runner_termination.py b/tests/unittests/runners/test_runner_termination.py new file mode 100644 index 0000000000..af541d1b7f --- /dev/null +++ b/tests/unittests/runners/test_runner_termination.py @@ -0,0 +1,160 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Runner.run_async with termination_condition.""" + +from __future__ import annotations + +from typing import AsyncGenerator + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.termination.max_iterations_termination import MaxIterationsTermination +from google.adk.termination.text_mention_termination import TextMentionTermination +from google.genai import types +import pytest +from typing_extensions import override + + +class _TextAgent(BaseAgent): + """A simple agent that yields a fixed text event and then stops.""" + + text: str = 'hello' + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + invocation_id=ctx.invocation_id, + author=self.name, + content=types.Content( + role='model', + parts=[types.Part(text=self.text)], + ), + ) + + +async def _run_with_termination( + agent: BaseAgent, + termination_condition=None, +) -> list[Event]: + """Creates a runner with an in-memory session and collects all events.""" + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + runner = Runner( + app_name='test_app', + agent=agent, + session_service=session_service, + artifact_service=InMemoryArtifactService(), + ) + events = [] + async for event in runner.run_async( + user_id=session.user_id, + session_id=session.id, + new_message=types.Content(role='user', parts=[types.Part(text='go')]), + termination_condition=termination_condition, + ): + events.append(event) + return events + + +@pytest.mark.asyncio +async def test_run_async_without_termination_condition(): + """Baseline: runner emits the agent event with no condition attached.""" + agent = _TextAgent(name='agent') + events = await _run_with_termination(agent) + assert any(e.author == 'agent' for e in events) + assert not any( + e.actions.termination_reason for e in events + ), 'No termination event should exist without a condition' + + +@pytest.mark.asyncio +async def test_run_async_termination_condition_stops_run(): + """Termination condition fires after the first event and stops the run.""" + agent = _TextAgent(name='agent', text='hello world') + condition = MaxIterationsTermination(1) + + events = await _run_with_termination(agent, termination_condition=condition) + + # The last event must be a synthetic termination event. + termination_event = events[-1] + assert termination_event.actions.escalate is True + assert termination_event.actions.termination_reason is not None + + +@pytest.mark.asyncio +async def test_run_async_text_mention_termination(): + """TextMentionTermination fires when the keyword is found in an event.""" + agent = _TextAgent(name='agent', text='STOP now') + condition = TextMentionTermination('STOP') + + events = await _run_with_termination(agent, termination_condition=condition) + + termination_event = events[-1] + assert termination_event.actions.escalate is True + assert 'STOP' in termination_event.actions.termination_reason + + +@pytest.mark.asyncio +async def test_run_async_text_mention_termination_keyword_absent(): + """Run completes normally when the keyword is not present in any event.""" + agent = _TextAgent(name='agent', text='everything is fine') + condition = TextMentionTermination('STOP') + + events = await _run_with_termination(agent, termination_condition=condition) + + assert not any( + e.actions.termination_reason for e in events + ), 'No termination event expected when keyword is absent' + + +@pytest.mark.asyncio +async def test_run_async_termination_condition_resets_between_runs(): + """The condition is reset at the start of each run_async call.""" + agent = _TextAgent(name='agent', text='hello') + condition = MaxIterationsTermination(1) + + session_service = InMemorySessionService() + runner = Runner( + app_name='test_app', + agent=agent, + session_service=session_service, + artifact_service=InMemoryArtifactService(), + ) + + for _ in range(2): + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + events = [] + async for event in runner.run_async( + user_id=session.user_id, + session_id=session.id, + new_message=types.Content(role='user', parts=[types.Part(text='go')]), + termination_condition=condition, + ): + events.append(event) + + # Each run should emit a termination event, proving the reset happened. + assert ( + events[-1].actions.termination_reason is not None + ), 'Expected a termination event on every run' diff --git a/tests/unittests/termination/__init__.py b/tests/unittests/termination/__init__.py new file mode 100644 index 0000000000..58d482ea38 --- /dev/null +++ b/tests/unittests/termination/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/termination/test_termination_conditions.py b/tests/unittests/termination/test_termination_conditions.py new file mode 100644 index 0000000000..5af60af2b8 --- /dev/null +++ b/tests/unittests/termination/test_termination_conditions.py @@ -0,0 +1,472 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for termination conditions.""" + +from __future__ import annotations + +import asyncio +import time + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.termination.external_termination import ExternalTermination +from google.adk.termination.function_call_termination import FunctionCallTermination +from google.adk.termination.max_iterations_termination import MaxIterationsTermination +from google.adk.termination.termination_condition import AndTerminationCondition +from google.adk.termination.termination_condition import OrTerminationCondition +from google.adk.termination.text_mention_termination import TextMentionTermination +from google.adk.termination.timeout_termination import TimeoutTermination +from google.adk.termination.token_usage_termination import TokenUsageTermination +from google.genai import types +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_text_event(text: str, author: str = 'agent') -> Event: + return Event( + invocation_id='inv-1', + author=author, + actions=EventActions(), + content=types.Content( + role='model', + parts=[types.Part(text=text)], + ), + ) + + +def _make_token_event( + total_tokens: int, + prompt_tokens: int, + completion_tokens: int, +) -> Event: + return Event( + invocation_id='inv-1', + author='agent', + actions=EventActions(), + usage_metadata=types.GenerateContentResponseUsageMetadata( + total_token_count=total_tokens, + prompt_token_count=prompt_tokens, + candidates_token_count=completion_tokens, + ), + ) + + +def _make_function_response_event(function_name: str) -> Event: + return Event( + invocation_id='inv-1', + author='agent', + actions=EventActions(), + content=types.Content( + role='model', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name=function_name, + response={'result': 'ok'}, + ) + ) + ], + ), + ) + + +# --------------------------------------------------------------------------- +# MaxIterationsTermination +# --------------------------------------------------------------------------- + + +class TestMaxIterationsTermination: + + def test_raises_if_not_positive(self): + with pytest.raises(ValueError): + MaxIterationsTermination(0) + with pytest.raises(ValueError): + MaxIterationsTermination(-1) + + @pytest.mark.asyncio + async def test_does_not_terminate_before_limit(self): + condition = MaxIterationsTermination(3) + result = await condition.check([_make_text_event('hello')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_terminates_at_limit(self): + condition = MaxIterationsTermination(3) + await condition.check([_make_text_event('a'), _make_text_event('b')]) + result = await condition.check([_make_text_event('c')]) + assert result is not None + assert '3' in result.reason + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_does_not_fire_again_after_termination(self): + condition = MaxIterationsTermination(1) + await condition.check([_make_text_event('first')]) + assert condition.terminated is True + second = await condition.check([_make_text_event('second')]) + assert second is None + + @pytest.mark.asyncio + async def test_reset(self): + condition = MaxIterationsTermination(1) + await condition.check([_make_text_event('first')]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + + result = await condition.check([_make_text_event('first again')]) + assert result is not None + + +# --------------------------------------------------------------------------- +# TextMentionTermination +# --------------------------------------------------------------------------- + + +class TestTextMentionTermination: + + @pytest.mark.asyncio + async def test_terminates_when_text_found(self): + condition = TextMentionTermination('TERMINATE') + result = await condition.check([_make_text_event('Please TERMINATE now.')]) + assert result is not None + assert 'TERMINATE' in result.reason + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_does_not_terminate_when_absent(self): + condition = TextMentionTermination('TERMINATE') + result = await condition.check([_make_text_event('Keep going!')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_respects_sources_filter(self): + condition = TextMentionTermination('APPROVE', sources=['critic']) + + # Wrong source — should NOT fire + no_fire = await condition.check([_make_text_event('APPROVE', 'primary')]) + assert no_fire is None + + # Correct source — should fire + fire = await condition.check([_make_text_event('APPROVE', 'critic')]) + assert fire is not None + + @pytest.mark.asyncio + async def test_reset(self): + condition = TextMentionTermination('DONE') + await condition.check([_make_text_event('DONE')]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + + result = await condition.check([_make_text_event('not done yet')]) + assert result is None + + +# --------------------------------------------------------------------------- +# TokenUsageTermination +# --------------------------------------------------------------------------- + + +class TestTokenUsageTermination: + + def test_raises_if_no_limit(self): + with pytest.raises(ValueError): + TokenUsageTermination() + + @pytest.mark.asyncio + async def test_terminates_on_total_tokens(self): + condition = TokenUsageTermination(max_total_tokens=100) + result = await condition.check([_make_token_event(101, 50, 51)]) + assert result is not None + assert 'total_tokens' in result.reason + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_terminates_on_prompt_tokens(self): + condition = TokenUsageTermination(max_prompt_tokens=50) + result = await condition.check([_make_token_event(60, 55, 5)]) + assert result is not None + assert 'prompt_tokens' in result.reason + + @pytest.mark.asyncio + async def test_terminates_on_completion_tokens(self): + condition = TokenUsageTermination(max_completion_tokens=30) + result = await condition.check([_make_token_event(40, 5, 35)]) + assert result is not None + assert 'completion_tokens' in result.reason + + @pytest.mark.asyncio + async def test_accumulates_across_events(self): + condition = TokenUsageTermination(max_total_tokens=100) + await condition.check([_make_token_event(60, 40, 20)]) + assert condition.terminated is False + + result = await condition.check([_make_token_event(50, 30, 20)]) + assert result is not None + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_ignores_events_without_usage(self): + condition = TokenUsageTermination(max_total_tokens=10) + result = await condition.check([_make_text_event('no tokens here')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_reset(self): + condition = TokenUsageTermination(max_total_tokens=100) + await condition.check([_make_token_event(200, 100, 100)]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + result = await condition.check([_make_token_event(50, 30, 20)]) + assert result is None + + +# --------------------------------------------------------------------------- +# TimeoutTermination +# --------------------------------------------------------------------------- + + +class TestTimeoutTermination: + + def test_raises_if_not_positive(self): + with pytest.raises(ValueError): + TimeoutTermination(0) + with pytest.raises(ValueError): + TimeoutTermination(-5) + + @pytest.mark.asyncio + async def test_does_not_terminate_before_timeout(self): + condition = TimeoutTermination(60) + result = await condition.check([_make_text_event('hello')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_terminates_after_timeout(self): + condition = TimeoutTermination(0.01) # 10ms + # Warm up the start time. + await condition.check([_make_text_event('trigger start')]) + # Wait slightly longer than the timeout. + await asyncio.sleep(0.02) + + result = await condition.check([_make_text_event('after timeout')]) + assert result is not None + assert 'Timeout' in result.reason + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_reset(self): + condition = TimeoutTermination(0.01) + await condition.check([_make_text_event('start')]) + await asyncio.sleep(0.02) + await condition.check([_make_text_event('fires')]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + # After reset a fresh check starts a new timer. + result = await condition.check([_make_text_event('fresh start')]) + assert result is None + + +# --------------------------------------------------------------------------- +# FunctionCallTermination +# --------------------------------------------------------------------------- + + +class TestFunctionCallTermination: + + @pytest.mark.asyncio + async def test_terminates_on_matching_function(self): + condition = FunctionCallTermination('approve') + result = await condition.check([_make_function_response_event('approve')]) + assert result is not None + assert 'approve' in result.reason + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_does_not_terminate_for_different_function(self): + condition = FunctionCallTermination('approve') + result = await condition.check([_make_function_response_event('search')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_does_not_terminate_on_text_only(self): + condition = FunctionCallTermination('approve') + result = await condition.check([_make_text_event('approve this')]) + assert result is None + + @pytest.mark.asyncio + async def test_reset(self): + condition = FunctionCallTermination('approve') + await condition.check([_make_function_response_event('approve')]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + + +# --------------------------------------------------------------------------- +# ExternalTermination +# --------------------------------------------------------------------------- + + +class TestExternalTermination: + + @pytest.mark.asyncio + async def test_does_not_terminate_before_set(self): + condition = ExternalTermination() + result = await condition.check([_make_text_event('anything')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_terminates_after_set(self): + condition = ExternalTermination() + condition.set() + result = await condition.check([_make_text_event('anything')]) + assert result is not None + assert 'Externally terminated' in result.reason + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_reset(self): + condition = ExternalTermination() + condition.set() + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + result = await condition.check([_make_text_event('should not fire')]) + assert result is None + + +# --------------------------------------------------------------------------- +# OrTerminationCondition (.or_()) +# --------------------------------------------------------------------------- + + +class TestOrTerminationCondition: + + @pytest.mark.asyncio + async def test_terminates_on_first(self): + condition = MaxIterationsTermination(1).or_(TextMentionTermination('DONE')) + result = await condition.check([_make_text_event('any')]) + assert result is not None + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_terminates_on_second(self): + condition = MaxIterationsTermination(100).or_( + TextMentionTermination('DONE') + ) + result = await condition.check([_make_text_event('DONE')]) + assert result is not None + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_does_not_terminate_when_neither_fires(self): + condition = MaxIterationsTermination(100).or_( + TextMentionTermination('DONE') + ) + result = await condition.check([_make_text_event('keep going')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_reset_both_children(self): + condition = MaxIterationsTermination(1).or_(TextMentionTermination('DONE')) + await condition.check([_make_text_event('fires')]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + + def test_is_or_instance(self): + condition = MaxIterationsTermination(1).or_(TextMentionTermination('X')) + assert isinstance(condition, OrTerminationCondition) + + @pytest.mark.asyncio + async def test_pipe_operator(self): + condition = MaxIterationsTermination(1) | TextMentionTermination('DONE') + result = await condition.check([_make_text_event('any')]) + assert result is not None + assert isinstance(condition, OrTerminationCondition) + + +# --------------------------------------------------------------------------- +# AndTerminationCondition (.and_()) +# --------------------------------------------------------------------------- + + +class TestAndTerminationCondition: + + @pytest.mark.asyncio + async def test_does_not_terminate_when_only_first_fires(self): + condition = MaxIterationsTermination(1).and_(TextMentionTermination('DONE')) + result = await condition.check([_make_text_event('no keyword here')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_does_not_terminate_when_only_second_fires(self): + condition = MaxIterationsTermination(100).and_( + TextMentionTermination('DONE') + ) + result = await condition.check([_make_text_event('DONE')]) + assert result is None + assert condition.terminated is False + + @pytest.mark.asyncio + async def test_terminates_when_both_fire(self): + left = MaxIterationsTermination(1) + right = TextMentionTermination('DONE') + condition = left.and_(right) + + result = await condition.check([_make_text_event('DONE')]) + assert result is not None + assert condition.terminated is True + + @pytest.mark.asyncio + async def test_reset_both_children(self): + condition = MaxIterationsTermination(1).and_(TextMentionTermination('DONE')) + await condition.check([_make_text_event('DONE')]) + assert condition.terminated is True + + await condition.reset() + assert condition.terminated is False + + def test_is_and_instance(self): + condition = MaxIterationsTermination(1).and_(TextMentionTermination('X')) + assert isinstance(condition, AndTerminationCondition) + + @pytest.mark.asyncio + async def test_ampersand_operator(self): + condition = MaxIterationsTermination(1) & TextMentionTermination('DONE') + result = await condition.check([_make_text_event('DONE')]) + assert result is not None + assert isinstance(condition, AndTerminationCondition)