diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index 5b0100818c..243e5d15b2 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -17,7 +17,9 @@ import asyncio import copy import importlib +import json import logging +import time from typing import Any from typing import AsyncGenerator from typing import Optional @@ -26,12 +28,14 @@ from google.genai import errors from google.genai import types from google.genai.types import Content +import opentelemetry.context as context_api +from opentelemetry.trace import set_span_in_context +from opentelemetry.trace import Span from pydantic import BaseModel from websockets.exceptions import ConnectionClosed from websockets.exceptions import ConnectionClosedOK from ..agents.callback_context import CallbackContext -from ..agents.invocation_context import InvocationContext from ..agents.live_request_queue import LiveRequestQueue from ..agents.llm_agent import Agent from ..agents.run_config import RunConfig @@ -39,7 +43,6 @@ from ..artifacts.base_artifact_service import BaseArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..events.event import Event -from ..flows.llm_flows.functions import handle_function_calls_live from ..memory.base_memory_service import BaseMemoryService from ..memory.in_memory_memory_service import InMemoryMemoryService from ..models.llm_request import LlmRequest @@ -47,6 +50,8 @@ from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService from ..sessions.session import Session +from ..telemetry import tracing as _telemetry +from ..telemetry._token_usage import TokenUsage from ..utils.context_utils import Aclosing from ._retry_options_utils import EnsureRetryOptionsPlugin from .app_details import AgentDetails @@ -59,6 +64,7 @@ from .eval_case import 
SessionInput from .eval_set import EvalSet from .request_intercepter_plugin import _RequestIntercepterPlugin +from .simulation.user_simulator import BaseUserSimulatorConfig from .simulation.user_simulator import Status as UserSimulatorStatus from .simulation.user_simulator import UserSimulator from .simulation.user_simulator_provider import UserSimulatorProvider @@ -68,6 +74,121 @@ _USER_AUTHOR = "user" _DEFAULT_AUTHOR = "agent" +# Fallback idle window for draining a turn's events when no finished=True +# transcription arrives. Each received event resets the window. +_TRANSCRIPTION_TAIL_GRACE_SECONDS = 2.0 + +# `Event.custom_metadata` key marking the turn-final `turn_complete` (as +# opposed to the intermediate one emitted when the model issues a tool call). +_FINAL_TURN_COMPLETE_METADATA_KEY = "final_turn_complete" + + +def _extract_content_text(content: Optional[Content]) -> Optional[str]: + """Joins the text parts of a `Content` into a single string, or None.""" + if not content or not content.parts: + return None + text = " ".join(p.text for p in content.parts if p.text) + return text or None + + +def _record_live_turn_telemetry( + span: Span, + events: list[Event], + invocation_id: str, + user_message: Optional[Content], +) -> None: + """Records the turn's request/response and chronology on a `live_turn` span. + + Utterances are the connection's consolidated (non-partial) transcriptions, + one per speech segment between tool calls. `llm_response` carries one part + per utterance, and each utterance / tool call / tool response is also + recorded as a timestamped span event. Token usage reported by the live API + (one `usage_metadata` event per generation) is summed onto the span using + the same `gen_ai.usage.*` attributes as non-live `call_llm` spans. 
+ """ + user_text = _extract_content_text(user_message) + if user_text: + span.set_attribute( + "gcp.vertex.agent.llm_request", + json.dumps({ + "contents": [{ + "role": "user", + "parts": [{"text": user_text}], + }], + }), + ) + + utterances: list[str] = [] + usage_metadatas: list[types.GenerateContentResponseUsageMetadata] = [] + + for evt in events: + if evt.invocation_id != invocation_id or evt.author == _USER_AUTHOR: + continue + if evt.usage_metadata: + usage_metadatas.append(evt.usage_metadata) + if evt.get_function_calls(): + for call in evt.get_function_calls(): + span.add_event( + "tool_call", + attributes={ + "tool": call.name or "", + "args": json.dumps(call.args, default=str), + }, + timestamp=int(evt.timestamp * 1e9), + ) + elif evt.get_function_responses(): + for response in evt.get_function_responses(): + span.add_event( + "tool_response", + attributes={ + "tool": response.name or "", + "response": json.dumps(response.response, default=str), + }, + timestamp=int(evt.timestamp * 1e9), + ) + elif ( + evt.output_transcription + and evt.output_transcription.text + and not evt.partial + ): + utterances.append(evt.output_transcription.text) + span.add_event( + "model_utterance", + attributes={"text": evt.output_transcription.text}, + timestamp=int(evt.timestamp * 1e9), + ) + + if usage_metadatas: + + def _sum_tokens(field: str) -> Optional[int]: + values = [ + getattr(usage, field) + for usage in usage_metadatas + if getattr(usage, field) is not None + ] + return sum(values) if values else None + + aggregated_usage = types.GenerateContentResponseUsageMetadata( + prompt_token_count=_sum_tokens("prompt_token_count"), + candidates_token_count=_sum_tokens("candidates_token_count"), + thoughts_token_count=_sum_tokens("thoughts_token_count"), + tool_use_prompt_token_count=_sum_tokens("tool_use_prompt_token_count"), + cached_content_token_count=_sum_tokens("cached_content_token_count"), + total_token_count=_sum_tokens("total_token_count"), + ) + 
span.set_attributes(TokenUsage(aggregated_usage).to_attributes()) + + if utterances: + span.set_attribute( + "gcp.vertex.agent.llm_response", + json.dumps({ + "content": { + "role": "model", + "parts": [{"text": text} for text in utterances], + }, + }), + ) + class EvalCaseResponses(BaseModel): """Contains multiple responses associated with an EvalCase. @@ -99,6 +220,13 @@ def __init__( self.live_finished = asyncio.Event() self.current_invocation_id = Event.new_id() self.consume_task = None + # OTel context whose current span is the per-turn `live_turn`. Set by + # the main task before sending a user message and cleared after the + # turn completes. The consume task attaches it around + # `handle_function_calls_live` so tool spans are parented under + # `live_turn` (i.e. live in the same trace as their turn) instead of + # under whatever ambient context this task happens to have. + self.current_turn_context: Optional[context_api.Context] = None async def __aenter__(self) -> _LiveSession: """Starts the background task.""" @@ -124,6 +252,25 @@ async def _consume_events(self) -> None: self.session, self.runner.agent ) + # Run before_agent_callback before any instruction preprocessing. + # `agent.run_live` (bypassed below) would normally fire this. Without it, + # an agent that seeds session state in before_agent_callback raises + # KeyError when `_preprocess_async` renders a `{state_var}` referenced by + # its instruction template. The callback writes through to + # `session.state` (State.__setitem__), and we append the resulting event + # so the state delta is persisted for non-in-memory session services too. 
+ before_agent_event = ( + await invocation_context.agent._handle_before_agent_callback( + invocation_context + ) + ) + if before_agent_event: + await self.runner.session_service.append_event( + session=self.session, event=before_agent_event + ) + if invocation_context.end_invocation: + return + callback_context = None llm_request = LlmRequest() @@ -146,82 +293,85 @@ async def _consume_events(self) -> None: ) in_function_call_loop = False + # Bypass `agent.run_live`: it wraps the flow in `record_agent_invocation` + # which opens a single long-lived `invoke_agent` span covering the + # entire session. That collapses every turn into one trace and adds an + # empty/erroring trace to session views (the WebSocket close at + # session-end gets recorded as a span exception). Call the impl + # directly; per-turn `live_turn` spans (opened by the main task) take + # over the role of session-grouped invocation spans for eval purposes. + # `before_agent_callback` is run explicitly above (it must fire before + # `_preprocess_async`, and `run_live` would fire it a second time); + # `after_agent_callback` is still skipped here — fine for evals of + # agents that don't rely on it. 
async with Aclosing( - invocation_context.agent.run_live(invocation_context) + invocation_context.agent._run_live_impl(invocation_context) ) as agen: - async for event in agen: - assert event is not None - event.invocation_id = self.current_invocation_id - if callback_context: - await invocation_context.plugin_manager.run_after_model_callback( - callback_context=callback_context, - llm_response=event, - ) - await self.event_queue.put(event) - if not event.partial: - await self.runner.session_service.append_event( - session=self.session, event=event - ) - function_calls = event.get_function_calls() - if function_calls: - in_function_call_loop = True - inv_context = InvocationContext( - session_service=self.runner.session_service, - invocation_id=event.invocation_id, - agent=self.runner.agent, - session=self.session, - run_config=run_config, - ) - - if isinstance(self.runner.agent, Agent): - resolved_tools = await self.runner.agent.canonical_tools( - inv_context - ) - tools_dict = {t.name: t for t in resolved_tools} - else: - tools_dict = {} - + # Drive the generator manually so we can attach + # `current_turn_context` around BOTH the generator's internal work + # (which also runs `handle_function_calls_live` — see + # base_llm_flow.py:_receive_from_model) and our own body below. + # Without this, the first invocation of the tool — the one inside + # `_run_live_impl` — runs without `live_turn` as parent and lands + # as an orphan root span. 
+ while True: + token = ( + context_api.attach(self.current_turn_context) + if self.current_turn_context is not None + else None + ) + try: try: - response_event = await handle_function_calls_live( - invocation_context=inv_context, - function_call_event=event, - tools_dict=tools_dict, + event = await agen.__anext__() + except StopAsyncIteration: + break + assert event is not None + event.invocation_id = self.current_invocation_id + if callback_context: + await invocation_context.plugin_manager.run_after_model_callback( + callback_context=callback_context, + llm_response=event, ) - - if ( - response_event - and response_event.content - and response_event.content.parts - ): - for part in response_event.content.parts: - if part.function_response: - tool_content = types.Content( - role="tool", - parts=[part], - ) - self.live_request_queue.send_content(tool_content) - except (ValueError, RuntimeError, KeyError, TypeError) as e: - logger.error( - "Failed to handle function calls: %s", - e, - exc_info=True, + # Tag the turn-final `turn_complete` before queueing so the + # drain can recognize the end of the turn. 
+ if ( + event.turn_complete + and event.author != _USER_AUTHOR + and not in_function_call_loop + ): + event.custom_metadata = { + **(event.custom_metadata or {}), + _FINAL_TURN_COMPLETE_METADATA_KEY: True, + } + await self.event_queue.put(event) + if not event.partial: + await self.runner.session_service.append_event( + session=self.session, event=event ) - for fc in function_calls: - response_content = types.FunctionResponse( - name=fc.name, - id=fc.id, - response={"error": str(e)}, - ) - tool_content = types.Content( - role="tool", - parts=[types.Part(function_response=response_content)], - ) - self.live_request_queue.send_content(tool_content) - if event.turn_complete and event.author != _USER_AUTHOR: - if not in_function_call_loop: - self.turn_complete_event.set() - else: - in_function_call_loop = False + # Track the "function-call → tool-response → final answer" + # interlude so the first `turn_complete` (which signals "I've + # issued my tool call") doesn't release the main task — the + # turn isn't really done until the post-tool reply arrives. + if event.get_function_calls(): + in_function_call_loop = True + + # The flow handles the whole tool loop by itself: + # `_receive_from_model` runs `handle_function_calls_live` and + # yields the function_response event, and `run_live` + # (base_llm_flow.py, "send back the function response to + # models") forwards it into the live request queue. Executing + # the tool or forwarding the response here would do either a + # second time — the model would receive the tool result twice + # and answer the same question twice. 
+
+            if event.turn_complete and event.author != _USER_AUTHOR:
+              if not in_function_call_loop:
+                self.turn_complete_event.set()
+              else:
+                in_function_call_loop = False
+          finally:
+            if token is not None:
+              context_api.detach(token)
     finally:
       self.live_finished.set()
       self.turn_complete_event.set()  # Unblock any waiters
@@ -263,6 +413,7 @@ async def generate_responses(
       agent_module_path: str,
       repeat_num: int = 3,
       agent_name: str = None,
+      user_simulator_config: Optional[BaseUserSimulatorConfig] = None,
   ) -> list[EvalCaseResponses]:
     """Returns evaluation responses for the given dataset and agent.
 
@@ -273,12 +424,19 @@ async def generate_responses(
         usually done to remove uncertainty that a single run may bring.
       agent_name: The name of the agent that should be evaluated. This is
         usually the sub-agent.
+      user_simulator_config: Optional configuration for the user simulator.
+        Only relevant for eval cases that use a `conversation_scenario` (which
+        are driven by `LlmBackedUserSimulator`); ignored for static
+        conversations. Pass an `LlmBackedUserSimulatorConfig` to override the
+        user-simulation model, max invocations, or custom instructions.
     """
     results = []
 
     for eval_case in eval_set.eval_cases:
-      # assume only static conversations are needed
-      user_simulator = UserSimulatorProvider().provide(eval_case)
+      user_simulator = UserSimulatorProvider(
+          user_simulator_config=user_simulator_config
+      ).provide(eval_case)
+
       responses = []
       for _ in range(repeat_num):
         response_invocations = await EvaluationGenerator._process_query(
@@ -321,6 +479,11 @@ def generate_responses_from_session(session_path, eval_dataset):
 
     return results
 
+  @staticmethod
+  def _is_live_api_model(name: str) -> bool:
+    """Detects Gemini Live API models by name; non-str values are not live."""
+    return isinstance(name, str) and "live" in name.lower()
+
   @staticmethod
   async def _process_query(
       module_name: str,
@@ -340,12 +503,23 @@ async def _process_query(
       agent_to_evaluate = root_agent.find_agent(agent_name)
       assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found."
 
-    return await EvaluationGenerator._generate_inferences_from_root_agent(
-        agent_to_evaluate,
-        user_simulator=user_simulator,
-        reset_func=reset_func,
-        initial_session=initial_session,
-    )
+    # Non-LLM agents may not define `model`; treat those as non-live.
+    agent_model = getattr(agent_to_evaluate, "model", "")
+    if EvaluationGenerator._is_live_api_model(agent_model):
+      return (
+          await EvaluationGenerator._generate_inferences_from_root_agent_live(
+              root_agent=agent_to_evaluate,
+              user_simulator=user_simulator,
+              reset_func=reset_func,
+              initial_session=initial_session,
+          )
+      )
+    return await EvaluationGenerator._generate_inferences_from_root_agent(
+        agent_to_evaluate,
+        user_simulator=user_simulator,
+        reset_func=reset_func,
+        initial_session=initial_session,
+    )
 
   @staticmethod
   async def _generate_inferences_for_single_user_invocation(
@@ -405,26 +579,51 @@ async def _generate_inferences_for_single_user_invocation_live(
       )
       raise
 
-    while not event_queue.empty():
-      event = await event_queue.get()
-      if event.invocation_id == current_invocation_id:
-        yield event
-        # Emit a synthetic text event for each transcription, preserving
-        # the order in which events are received.
-        if (
-            event.author != _USER_AUTHOR
-            and event.output_transcription
-            and event.output_transcription.text
-            and event.partial
-        ):
-          yield Event(
-              content=Content(
-                  role="model",
-                  parts=[types.Part(text=event.output_transcription.text)],
-              ),
-              author=agent_name,
-              invocation_id=current_invocation_id,
-          )
+    # `turn_complete` only ends generation; transcription can trail it by
+    # seconds. Stop once the turn-final `turn_complete` has been seen and
+    # the transcription closed with a finished=True consolidation (the
+    # common case — no idle wait); otherwise fall back to the grace window.
+ saw_final_turn_complete = False + transcription_settled = False + while True: + try: + event = await asyncio.wait_for( + event_queue.get(), timeout=_TRANSCRIPTION_TAIL_GRACE_SECONDS + ) + except asyncio.TimeoutError: + break + if event.invocation_id != current_invocation_id: + logger.debug( + "Dropped straggler event from invocation %s while draining %s.", + event.invocation_id, + current_invocation_id, + ) + continue + yield event + # Emit one synthetic text event per utterance, sourced from the + # connection's consolidated (non-partial) transcription. + if ( + event.author != _USER_AUTHOR + and event.output_transcription + and event.output_transcription.text + and not event.partial + ): + yield Event( + content=Content( + role="model", + parts=[types.Part(text=event.output_transcription.text)], + ), + author=agent_name, + invocation_id=current_invocation_id, + ) + if event.custom_metadata and event.custom_metadata.get( + _FINAL_TURN_COMPLETE_METADATA_KEY + ): + saw_final_turn_complete = True + if event.output_transcription: + transcription_settled = not event.partial + if saw_final_turn_complete and transcription_settled: + break @staticmethod async def _generate_inferences_from_root_agent_live( @@ -503,18 +702,50 @@ async def _generate_inferences_from_root_agent_live( logger.info("Waiting for model to complete turn %d...", turn_idx) - async for ( - event - ) in EvaluationGenerator._generate_inferences_for_single_user_invocation_live( - live_request_queue=live_session.live_request_queue, - event_queue=live_session.event_queue, - user_message=next_user_message.user_message, - current_invocation_id=live_session.current_invocation_id, - turn_complete_event=live_session.turn_complete_event, - live_timeout_seconds=live_timeout_seconds, - agent_name=runner.agent.name, - ): - events.append(event) + # Open a per-turn root span. By using an empty parent context it + # becomes the root of its own trace, which session-grouping + # tracing pipelines (e.g. 
MLflow Sessions) treat as one chat-turn + # entry. Tool calls executed during the turn are re-parented + # under this span by `_LiveSession._consume_events` via + # `current_turn_context`, so the whole turn lives in one trace. + live_turn_span = _telemetry.tracer.start_span( + "live_turn", + context=context_api.Context(), + start_time=time.time_ns(), + ) + live_turn_span.set_attribute("gen_ai.conversation.id", session_id) + live_turn_span.set_attribute("gen_ai.agent.name", runner.agent.name) + live_turn_span.set_attribute("gen_ai.operation.name", "chat") + live_session.current_turn_context = set_span_in_context( + live_turn_span + ) + try: + async for ( + event + ) in EvaluationGenerator._generate_inferences_for_single_user_invocation_live( + live_request_queue=live_session.live_request_queue, + event_queue=live_session.event_queue, + user_message=next_user_message.user_message, + current_invocation_id=live_session.current_invocation_id, + turn_complete_event=live_session.turn_complete_event, + live_timeout_seconds=live_timeout_seconds, + agent_name=runner.agent.name, + ): + events.append(event) + + # The synthetic text events for the eval trajectory are emitted + # per transcription chunk (in arrival order) by + # `_generate_inferences_for_single_user_invocation_live`; the + # span gets per-utterance attributes and timestamped events. 
+            _record_live_turn_telemetry(
+                live_turn_span,
+                events,
+                live_session.current_invocation_id,
+                next_user_message.user_message,
+            )
+          finally:
+            live_session.current_turn_context = None
+            live_turn_span.end()
 
           if live_session.live_finished.is_set():
             logger.info("Live session finished signal detected.")
diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index 61fd8bbdf6..cfbadb1a7f 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -389,6 +389,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
                   live_session_id=live_session_id,
               )
             if message.server_content.output_transcription.finished:
+              logger.debug(
+                  'live-transcription: finished=True SERVER-SENT, text=%r',
+                  self._output_transcription_text,
+              )
               yield LlmResponse(
                   output_transcription=types.Transcription(
                       text=self._output_transcription_text,
@@ -419,6 +423,15 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
               )
               self._input_transcription_text = ''
             if self._output_transcription_text:
+              logger.debug(
+                  'live-transcription: finished=True FABRICATED'
+                  ' (interrupted=%s turn_complete=%s generation_complete=%s),'
+                  ' text=%r',
+                  message.server_content.interrupted,
+                  message.server_content.turn_complete,
+                  message.server_content.generation_complete,
+                  self._output_transcription_text,
+              )
               yield LlmResponse(
                   output_transcription=types.Transcription(
                       text=self._output_transcription_text,
diff --git a/tests/unittests/evaluation/test_evaluation_generator.py b/tests/unittests/evaluation/test_evaluation_generator.py
index 05ab25cc72..ff188179ca 100644
--- a/tests/unittests/evaluation/test_evaluation_generator.py
+++ b/tests/unittests/evaluation/test_evaluation_generator.py
@@ -15,12 +15,17 @@
 from __future__ import annotations
 
 import asyncio
+import json
 
 from google.adk.evaluation.app_details import AgentDetails
 from google.adk.evaluation.app_details import AppDetails
+from
google.adk.evaluation.eval_case import EvalCase +from google.adk.evaluation.eval_set import EvalSet from google.adk.evaluation.evaluation_generator import _LiveSession +from google.adk.evaluation.evaluation_generator import _record_live_turn_telemetry from google.adk.evaluation.evaluation_generator import EvaluationGenerator from google.adk.evaluation.request_intercepter_plugin import _RequestIntercepterPlugin +from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig from google.adk.evaluation.simulation.user_simulator import NextUserMessage from google.adk.evaluation.simulation.user_simulator import Status as UserSimulatorStatus from google.adk.evaluation.simulation.user_simulator import UserSimulator @@ -463,14 +468,22 @@ async def test_generate_inferences_live_with_synthetic_events(self, mocker): user_content = types.Content(parts=[types.Part(text="User query")]) invocation_id = "inv1" - transcription = types.Transcription(text="Partial transcription") partial_event = Event( author="agent", content=types.Content(parts=[]), invocation_id=invocation_id, - output_transcription=transcription, + output_transcription=types.Transcription(text="Partial "), partial=True, ) + consolidated_event = Event( + author="agent", + content=types.Content(parts=[]), + invocation_id=invocation_id, + output_transcription=types.Transcription( + text="Partial transcription", finished=True + ), + partial=False, + ) gen = EvaluationGenerator._generate_inferences_for_single_user_invocation_live( live_request_queue=mock_live_request_queue, @@ -491,24 +504,134 @@ async def test_generate_inferences_live_with_synthetic_events(self, mocker): # Mock turn_complete_event.wait to avoid blocking turn_complete_event.wait = mocker.AsyncMock() - # Put the partial event in the queue await event_queue.put(partial_event) + await event_queue.put(consolidated_event) - # Now advance + # Partial events are passed through without a synthetic companion. 
second_event = await gen.__anext__() assert second_event == partial_event - # Next should be the synthetic event third_event = await gen.__anext__() - assert third_event.author == "custom_agent_name" - assert third_event.invocation_id == invocation_id - assert third_event.content.role == "model" - assert third_event.content.parts[0].text == "Partial transcription" + assert third_event == consolidated_event + + # The consolidated transcription yields one synthetic text event. + fourth_event = await gen.__anext__() + assert fourth_event.author == "custom_agent_name" + assert fourth_event.invocation_id == invocation_id + assert fourth_event.content.role == "model" + assert fourth_event.content.parts[0].text == "Partial transcription" # The generator should be exhausted now with pytest.raises(StopAsyncIteration): await gen.__anext__() + @pytest.mark.asyncio + async def test_generate_inferences_live_waits_for_transcription_tail( + self, mocker + ): + """The drain captures an ASR tail that trails turn_complete by ~1s.""" + mock_live_request_queue = mocker.MagicMock() + event_queue = asyncio.Queue() + turn_complete_event = asyncio.Event() + invocation_id = "inv1" + + flushed_fragment = Event( + author="agent", + invocation_id=invocation_id, + partial=False, + output_transcription=types.Transcription( + text="I can provide weather for", finished=True + ), + ) + tail_chunk = Event( + author="agent", + invocation_id=invocation_id, + partial=True, + output_transcription=types.Transcription( + text=" London and Berlin.", finished=False + ), + ) + + await event_queue.put(flushed_fragment) + turn_complete_event.set() + + async def put_tail_late(): + await asyncio.sleep(0.8) + await event_queue.put(tail_chunk) + + tail_task = asyncio.create_task(put_tail_late()) + + gen = EvaluationGenerator._generate_inferences_for_single_user_invocation_live( + live_request_queue=mock_live_request_queue, + event_queue=event_queue, + user_message=types.Content(parts=[types.Part(text="Which 
cities?")]), + current_invocation_id=invocation_id, + turn_complete_event=turn_complete_event, + live_timeout_seconds=300, + ) + + async def collect(): + return [event async for event in gen] + + events = await collect() + await tail_task + + assert tail_chunk in events + + @pytest.mark.asyncio + async def test_generate_inferences_live_stops_without_idle_wait(self, mocker): + """A server-sent finished + final turn_complete ends the drain at once.""" + mock_live_request_queue = mocker.MagicMock() + event_queue = asyncio.Queue() + turn_complete_event = asyncio.Event() + invocation_id = "inv1" + + chunk = Event( + author="agent", + invocation_id=invocation_id, + partial=True, + output_transcription=types.Transcription( + text="It is sunny.", finished=False + ), + ) + server_finished = Event( + author="agent", + invocation_id=invocation_id, + partial=False, + output_transcription=types.Transcription( + text="It is sunny.", finished=True + ), + ) + final_turn_complete = Event( + author="agent", + invocation_id=invocation_id, + turn_complete=True, + custom_metadata={"final_turn_complete": True}, + ) + + await event_queue.put(chunk) + await event_queue.put(server_finished) + await event_queue.put(final_turn_complete) + turn_complete_event.set() + + gen = EvaluationGenerator._generate_inferences_for_single_user_invocation_live( + live_request_queue=mock_live_request_queue, + event_queue=event_queue, + user_message=types.Content(parts=[types.Part(text="Weather?")]), + current_invocation_id=invocation_id, + turn_complete_event=turn_complete_event, + live_timeout_seconds=300, + ) + + async def collect(): + return [event async for event in gen] + + # Must finish well under the grace window — no idle wait. 
+ events = await asyncio.wait_for(collect(), timeout=1.0) + + assert server_finished in events + assert final_turn_complete in events + @pytest.fixture def mock_runner(mocker): @@ -713,7 +836,7 @@ async def mock_preprocess_async(invocation_context, llm_request): mock_flow._preprocess_async = mock_preprocess_async mock_agent._llm_flow = mock_flow - # Mock run_live stream yielding one event + # Mock the _run_live_impl stream (bypassing run_live) yielding one event mock_event = Event( author="agent", content=types.Content(parts=[types.Part(text="Hello")]), @@ -723,7 +846,10 @@ async def mock_preprocess_async(invocation_context, llm_request): async def mock_run_live(*args, **kwargs): yield mock_event - mock_agent.run_live.return_value = mock_run_live() + mock_agent._run_live_impl.return_value = mock_run_live() + mock_agent._handle_before_agent_callback = mocker.AsyncMock( + return_value=None + ) # Mock plugin_manager on invocation context mock_plugin_manager = mocker.MagicMock() @@ -733,6 +859,9 @@ async def mock_run_live(*args, **kwargs): mock_plugin_manager ) mock_runner._new_invocation_context_for_live.return_value.agent = mock_agent + mock_runner._new_invocation_context_for_live.return_value.end_invocation = ( + False + ) # 2. 
class TestLiveSessionFunctionResponses:
  """Checks how `_LiveSession._consume_events` treats tool traffic.

  Two behaviours are covered: function responses seen on the event stream are
  not pushed onto the live request queue, and only the last `turn_complete`
  event following a tool call is tagged as the final one.
  """

  @staticmethod
  def _make_live_session(mocker, events):
    """Builds a `_LiveSession` over a mocked runner that streams `events`.

    Both tests in this class need the same runner / agent / plugin-manager
    scaffolding; it lives here so each test only states its events and its
    assertions.

    Args:
      mocker: The pytest-mock fixture used to create the mocks.
      events: Events yielded, in order, by the mocked agent's
        `_run_live_impl` stream.

    Returns:
      A `_LiveSession` bound to the mocked runner.
    """
    from google.adk.agents.llm_agent import Agent

    mock_runner = mocker.MagicMock()
    mock_runner.session_service.append_event = mocker.AsyncMock()
    mock_agent = mocker.MagicMock(spec=Agent)
    mock_runner.agent = mock_agent
    mock_runner._find_agent_to_run.return_value = mock_agent
    mock_agent.name = "test_agent"

    async def mock_preprocess_async(invocation_context, llm_request):
      # The unreachable `yield` only turns this into an async generator
      # function; it produces no items.
      return
      yield

    mock_flow = mocker.MagicMock()
    mock_flow._preprocess_async = mock_preprocess_async
    mock_agent._llm_flow = mock_flow
    mock_agent._handle_before_agent_callback = mocker.AsyncMock(
        return_value=None
    )

    async def mock_run_live(*args, **kwargs):
      for event in events:
        yield event

    # A single generator instance is handed out, as a real stream would be
    # consumed only once.
    mock_agent._run_live_impl.return_value = mock_run_live()

    mock_plugin_manager = mocker.MagicMock()
    mock_plugin_manager.run_before_model_callback = mocker.AsyncMock()
    mock_plugin_manager.run_after_model_callback = mocker.AsyncMock()

    invocation_context = (
        mock_runner._new_invocation_context_for_live.return_value
    )
    invocation_context.plugin_manager = mock_plugin_manager
    invocation_context.agent = mock_agent
    # A bare MagicMock attribute is truthy; presumably the session checks this
    # flag, so it is pinned to a real False.
    invocation_context.end_invocation = False

    return _LiveSession(
        runner=mock_runner,
        session=mocker.MagicMock(),
        user_id="test_user",
        session_id="test_session",
    )

  @pytest.mark.asyncio
  async def test_consume_events_does_not_resend_function_responses(
      self, mocker
  ):
    """The flow's run_live already forwards tool responses to the model.

    Forwarding them again from _consume_events makes the model receive the
    tool result twice and answer the same question twice.
    """
    function_response_event = Event(
        author="agent",
        invocation_id="inv1",
        content=types.Content(
            parts=[
                types.Part(
                    function_response=types.FunctionResponse(
                        name="get_temperature", response={"temp": 8.5}
                    )
                )
            ]
        ),
    )
    live_session = self._make_live_session(mocker, [function_response_event])
    send_content = mocker.patch.object(
        live_session.live_request_queue, "send_content"
    )

    await live_session._consume_events()

    send_content.assert_not_called()

  @pytest.mark.asyncio
  async def test_consume_events_tags_only_final_turn_complete(self, mocker):
    """The post-tool turn_complete gets tagged; the intermediate one not."""
    function_call_event = Event(
        author="agent",
        invocation_id="inv1",
        content=types.Content(
            parts=[
                types.Part(
                    function_call=types.FunctionCall(
                        name="get_temperature", args={}
                    )
                )
            ]
        ),
    )
    intermediate_turn_complete = Event(
        author="agent", invocation_id="inv1", turn_complete=True
    )
    final_turn_complete = Event(
        author="agent", invocation_id="inv1", turn_complete=True
    )
    live_session = self._make_live_session(
        mocker,
        [function_call_event, intermediate_turn_complete, final_turn_complete],
    )

    await live_session._consume_events()

    # `custom_metadata` may be None on an untouched event, hence the `or {}`.
    assert not (intermediate_turn_complete.custom_metadata or {}).get(
        "final_turn_complete"
    )
    assert final_turn_complete.custom_metadata == {"final_turn_complete": True}
class TestRecordLiveTurnTelemetry:
  """Unit tests for `_record_live_turn_telemetry` against a mocked span."""

  def test_record_live_turn_telemetry_splits_utterances_at_tool_calls(
      self, mocker
  ):
    """Utterances on either side of a tool call keep their real order."""
    mock_span = mocker.MagicMock()
    speech_before_tool = Event(
        author="agent",
        invocation_id="inv1",
        partial=False,
        output_transcription=types.Transcription(
            text="Let me check", finished=True
        ),
        timestamp=1.0,
    )
    tool_call = Event(
        author="agent",
        invocation_id="inv1",
        content=types.Content(
            parts=[
                types.Part(
                    function_call=types.FunctionCall(
                        name="get_temperature", args={"city": "berlin"}
                    )
                )
            ]
        ),
        timestamp=2.0,
    )
    tool_response = Event(
        author="agent",
        invocation_id="inv1",
        content=types.Content(
            parts=[
                types.Part(
                    function_response=types.FunctionResponse(
                        name="get_temperature", response={"temp": 8.5}
                    )
                )
            ]
        ),
        timestamp=3.0,
    )
    speech_after_tool = Event(
        author="agent",
        invocation_id="inv1",
        partial=False,
        output_transcription=types.Transcription(
            text="The temperature in Berlin is 8.5 degrees Celsius.",
            finished=True,
        ),
        timestamp=4.0,
    )
    user_message = types.Content(
        parts=[types.Part(text="What's the temperature in Berlin?")]
    )

    _record_live_turn_telemetry(
        mock_span,
        [speech_before_tool, tool_call, tool_response, speech_after_tool],
        "inv1",
        user_message,
    )

    # Span events are compared by name, in emission order.
    emitted_names = [
        call.args[0] for call in mock_span.add_event.call_args_list
    ]
    assert emitted_names == [
        "model_utterance",
        "tool_call",
        "tool_response",
        "model_utterance",
    ]

    # Collapse the positional (key, value) pairs of every set_attribute call.
    recorded = dict(
        (call.args[0], call.args[1])
        for call in mock_span.set_attribute.call_args_list
    )
    llm_response = json.loads(recorded["gcp.vertex.agent.llm_response"])
    assert llm_response == {
        "content": {
            "role": "model",
            "parts": [
                {"text": "Let me check"},
                {"text": "The temperature in Berlin is 8.5 degrees Celsius."},
            ],
        },
    }

  def test_record_live_turn_telemetry_ignores_audio_and_partial_events(
      self, mocker
  ):
    """Partial transcriptions and audio blobs yield no utterance events."""
    mock_span = mocker.MagicMock()
    partial_transcription = Event(
        author="agent",
        invocation_id="inv1",
        partial=True,
        output_transcription=types.Transcription(text="It is"),
        timestamp=1.0,
    )
    audio_chunk = Event(
        author="agent",
        invocation_id="inv1",
        content=types.Content(
            parts=[
                types.Part(
                    inline_data=types.Blob(data=b"pcm", mime_type="audio/pcm")
                )
            ]
        ),
        timestamp=1.5,
    )
    consolidated_transcription = Event(
        author="agent",
        invocation_id="inv1",
        partial=False,
        output_transcription=types.Transcription(
            text="It is sunny.", finished=True
        ),
        timestamp=2.0,
    )

    _record_live_turn_telemetry(
        mock_span,
        [partial_transcription, audio_chunk, consolidated_transcription],
        "inv1",
        None,
    )

    add_event_calls = mock_span.add_event.call_args_list
    emitted_names = [call.args[0] for call in add_event_calls]
    assert emitted_names == ["model_utterance"]
    utterance_text = add_event_calls[0].kwargs["attributes"]["text"]
    assert utterance_text == "It is sunny."

  def test_record_live_turn_telemetry_aggregates_token_usage(self, mocker):
    """Token counts of every usage_metadata event are summed on the span."""
    mock_span = mocker.MagicMock()
    first_usage = types.GenerateContentResponseUsageMetadata(
        prompt_token_count=100,
        candidates_token_count=20,
    )
    second_usage = types.GenerateContentResponseUsageMetadata(
        prompt_token_count=150,
        candidates_token_count=30,
        thoughts_token_count=5,
        cached_content_token_count=10,
    )
    usage_events = [
        Event(author="agent", invocation_id="inv1", usage_metadata=first_usage),
        Event(
            author="agent", invocation_id="inv1", usage_metadata=second_usage
        ),
    ]

    _record_live_turn_telemetry(mock_span, usage_events, "inv1", None)

    expected_usage = {
        "gen_ai.usage.input_tokens": 250,
        "gen_ai.usage.output_tokens": 55,
        "gen_ai.usage.cache_read.input_tokens": 10,
        "gen_ai.usage.reasoning.output_tokens": 5,
    }
    mock_span.set_attributes.assert_called_once_with(expected_usage)

  def test_record_live_turn_telemetry_ignores_other_invocations(self, mocker):
    """Only events carrying the requested invocation id are recorded."""
    mock_span = mocker.MagicMock()
    foreign_event = Event(
        author="agent",
        invocation_id="other_inv",
        partial=False,
        output_transcription=types.Transcription(
            text="Old turn text.", finished=True
        ),
        timestamp=1.0,
    )
    current_event = Event(
        author="agent",
        invocation_id="inv1",
        partial=False,
        output_transcription=types.Transcription(
            text="New turn text.", finished=True
        ),
        timestamp=2.0,
    )

    _record_live_turn_telemetry(
        mock_span, [foreign_event, current_event], "inv1", None
    )

    recorded = dict(
        (call.args[0], call.args[1])
        for call in mock_span.set_attribute.call_args_list
    )
    llm_response = json.loads(recorded["gcp.vertex.agent.llm_response"])
    assert llm_response == {
        "content": {
            "role": "model",
            "parts": [{"text": "New turn text."}],
        },
    }


class TestGenerateResponses:
  """Unit tests for `EvaluationGenerator.generate_responses`."""

  @pytest.mark.asyncio
  async def test_generate_responses_forwards_llm_backed_user_simulator_config(
      self, mocker
  ):
    """The given simulator config reaches `UserSimulatorProvider` unchanged."""
    # `_process_query` is stubbed so no agent module is actually loaded/run.
    mocker.patch(
        "google.adk.evaluation.evaluation_generator.EvaluationGenerator._process_query",
        new_callable=mocker.AsyncMock,
        return_value=[],
    )
    provider_cls = mocker.patch(
        "google.adk.evaluation.evaluation_generator.UserSimulatorProvider"
    )

    simulator_config = LlmBackedUserSimulatorConfig(
        model="test-model",
        max_allowed_invocations=5,
    )
    single_case = EvalCase(eval_id="case_0", conversation=[])
    eval_set = EvalSet(eval_set_id="test_set", eval_cases=[single_case])

    await EvaluationGenerator.generate_responses(
        eval_set=eval_set,
        agent_module_path="some.agent.module",
        repeat_num=1,
        user_simulator_config=simulator_config,
    )

    provider_cls.assert_called_once_with(
        user_simulator_config=simulator_config
    )
    # Equality alone would accept a copy; the very same object must be passed.
    forwarded_config = provider_cls.call_args.kwargs["user_simulator_config"]
    assert forwarded_config is simulator_config