Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,7 @@ async def _postprocess_live(
and not llm_response.usage_metadata
and not llm_response.live_session_resumption_update
and not llm_response.grounding_metadata
and not llm_response.live_setup_complete
):
return

Expand All @@ -1146,6 +1147,14 @@ async def _postprocess_live(
yield model_response_event
return

# Handle setup completion signal from the Live model.
if llm_response.live_setup_complete:
model_response_event.live_setup_complete = (
llm_response.live_setup_complete
)
yield model_response_event
return

# Handle transcription events ONCE per llm_response, outside the event loop
if llm_response.input_transcription:
model_response_event.input_transcription = (
Expand Down
9 changes: 9 additions & 0 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,15 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
model_version=self._model_version,
live_session_id=live_session_id,
)
if message.setup_complete:
logger.debug(
'Received setup complete message: %s', message.setup_complete
)
yield LlmResponse(
live_setup_complete=message.setup_complete,
model_version=self._model_version,
live_session_id=live_session_id,
)

if tool_call_parts:
logger.debug('Exited loop with pending tool_call_parts')
Expand Down
5 changes: 5 additions & 0 deletions src/google/adk/models/llm_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class LlmResponse(BaseModel):
custom_metadata: The custom metadata of the LlmResponse.
input_transcription: Audio transcription of user input.
output_transcription: Audio transcription of model output.
live_setup_complete: The setup completion signal from the Live model,
indicating the model is ready to receive user input.
avg_logprobs: Average log probability of the generated tokens.
logprobs_result: Detailed log probabilities for chosen and top candidate tokens.
"""
Expand Down Expand Up @@ -123,6 +125,9 @@ class LlmResponse(BaseModel):
go_away: Optional[types.LiveServerGoAway] = None
"""The GoAway signal from the Live model."""

live_setup_complete: Optional[types.LiveServerSetupComplete] = None
"""The setup completion signal from the Live model."""

input_transcription: Optional[types.Transcription] = None
"""Audio transcription of user input."""

Expand Down
62 changes: 62 additions & 0 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ async def test_receive_transcript_finished(
msg.usage_metadata = None
msg.session_resumption_update = None
msg.go_away = None
msg.setup_complete = None
msg.server_content.model_turn = None
msg.server_content.interrupted = False
msg.server_content.turn_complete = False
Expand Down Expand Up @@ -232,6 +233,7 @@ async def test_receive_usage_metadata_and_server_content(
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None
mock_message.setup_complete = None

async def mock_receive_generator():
yield mock_message
Expand Down Expand Up @@ -299,6 +301,7 @@ async def test_receive_usage_metadata_remaps_output_tokens(
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None
mock_message.setup_complete = None

async def mock_receive_generator():
yield mock_message
Expand Down Expand Up @@ -341,6 +344,7 @@ async def test_receive_populates_live_session_id(
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None
mock_message.setup_complete = None

mock_server_content = mock.Mock()
mock_server_content.model_turn = types.Content(
Expand Down Expand Up @@ -394,6 +398,7 @@ async def test_receive_transcript_finished_on_interrupt(
message1.tool_call = None
message1.session_resumption_update = None
message1.go_away = None
message1.setup_complete = None

message2 = mock.Mock()
message2.usage_metadata = None
Expand All @@ -410,6 +415,7 @@ async def test_receive_transcript_finished_on_interrupt(
message2.tool_call = None
message2.session_resumption_update = None
message2.go_away = None
message2.setup_complete = None

message3 = mock.Mock()
message3.usage_metadata = None
Expand All @@ -424,6 +430,7 @@ async def test_receive_transcript_finished_on_interrupt(
message3.tool_call = None
message3.session_resumption_update = None
message3.go_away = None
message3.setup_complete = None

async def mock_receive_generator():
yield message1
Expand Down Expand Up @@ -480,6 +487,7 @@ async def test_receive_transcript_finished_on_generation_complete(
message1.tool_call = None
message1.session_resumption_update = None
message1.go_away = None
message1.setup_complete = None

message2 = mock.Mock()
message2.usage_metadata = None
Expand All @@ -496,6 +504,7 @@ async def test_receive_transcript_finished_on_generation_complete(
message2.tool_call = None
message2.session_resumption_update = None
message2.go_away = None
message2.setup_complete = None

message3 = mock.Mock()
message3.usage_metadata = None
Expand All @@ -510,6 +519,7 @@ async def test_receive_transcript_finished_on_generation_complete(
message3.tool_call = None
message3.session_resumption_update = None
message3.go_away = None
message3.setup_complete = None

async def mock_receive_generator():
yield message1
Expand Down Expand Up @@ -565,6 +575,7 @@ async def test_receive_transcript_finished_on_turn_complete(
message1.tool_call = None
message1.session_resumption_update = None
message1.go_away = None
message1.setup_complete = None

message2 = mock.Mock()
message2.usage_metadata = None
Expand All @@ -581,6 +592,7 @@ async def test_receive_transcript_finished_on_turn_complete(
message2.tool_call = None
message2.session_resumption_update = None
message2.go_away = None
message2.setup_complete = None

message3 = mock.Mock()
message3.usage_metadata = None
Expand All @@ -595,6 +607,7 @@ async def test_receive_transcript_finished_on_turn_complete(
message3.tool_call = None
message3.session_resumption_update = None
message3.go_away = None
message3.setup_complete = None

async def mock_receive_generator():
yield message1
Expand Down Expand Up @@ -643,6 +656,7 @@ async def test_receive_handles_input_transcription_fragments(
message1.tool_call = None
message1.session_resumption_update = None
message1.go_away = None
message1.setup_complete = None

message2 = mock.Mock()
message2.usage_metadata = None
Expand All @@ -659,6 +673,7 @@ async def test_receive_handles_input_transcription_fragments(
message2.tool_call = None
message2.session_resumption_update = None
message2.go_away = None
message2.setup_complete = None

message3 = mock.Mock()
message3.usage_metadata = None
Expand All @@ -675,6 +690,7 @@ async def test_receive_handles_input_transcription_fragments(
message3.tool_call = None
message3.session_resumption_update = None
message3.go_away = None
message3.setup_complete = None

async def mock_receive_generator():
yield message1
Expand Down Expand Up @@ -718,6 +734,7 @@ async def test_receive_handles_output_transcription_fragments(
message1.tool_call = None
message1.session_resumption_update = None
message1.go_away = None
message1.setup_complete = None

message2 = mock.Mock()
message2.usage_metadata = None
Expand All @@ -734,6 +751,7 @@ async def test_receive_handles_output_transcription_fragments(
message2.tool_call = None
message2.session_resumption_update = None
message2.go_away = None
message2.setup_complete = None

message3 = mock.Mock()
message3.usage_metadata = None
Expand All @@ -750,6 +768,7 @@ async def test_receive_handles_output_transcription_fragments(
message3.tool_call = None
message3.session_resumption_update = None
message3.go_away = None
message3.setup_complete = None

async def mock_receive_generator():
yield message1
Expand Down Expand Up @@ -1028,6 +1047,7 @@ async def test_receive_grounding_metadata_standalone(
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None
mock_message.setup_complete = None

async def mock_receive_generator():
yield mock_message
Expand Down Expand Up @@ -1073,6 +1093,7 @@ async def test_receive_grounding_metadata_with_content(
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None
mock_message.setup_complete = None

async def mock_receive_generator():
yield mock_message
Expand Down Expand Up @@ -1106,6 +1127,7 @@ async def test_receive_tool_call_and_grounding_metadata_with_native_audio(
mock_tool_call_msg.server_content = None
mock_tool_call_msg.session_resumption_update = None
mock_tool_call_msg.go_away = None
mock_tool_call_msg.setup_complete = None

function_call = types.FunctionCall(
name='enterprise_web_search',
Expand Down Expand Up @@ -1146,6 +1168,7 @@ async def test_receive_tool_call_and_grounding_metadata_with_native_audio(
mock_metadata_msg.tool_call = None
mock_metadata_msg.session_resumption_update = None
mock_metadata_msg.go_away = None
mock_metadata_msg.setup_complete = None

# 3. Message with turn_complete
mock_turn_complete_content = mock.create_autospec(
Expand All @@ -1167,6 +1190,7 @@ async def test_receive_tool_call_and_grounding_metadata_with_native_audio(
mock_turn_complete_msg.tool_call = None
mock_turn_complete_msg.session_resumption_update = None
mock_turn_complete_msg.go_away = None
mock_turn_complete_msg.setup_complete = None

async def mock_receive_generator():
yield mock_tool_call_msg
Expand Down Expand Up @@ -1217,6 +1241,7 @@ async def test_receive_multiple_tool_calls_buffered_until_turn_complete(
mock_tool_call_msg1.server_content = None
mock_tool_call_msg1.session_resumption_update = None
mock_tool_call_msg1.go_away = None
mock_tool_call_msg1.setup_complete = None

function_call1 = types.FunctionCall(
name='tool_1',
Expand All @@ -1236,6 +1261,7 @@ async def test_receive_multiple_tool_calls_buffered_until_turn_complete(
mock_tool_call_msg2.server_content = None
mock_tool_call_msg2.session_resumption_update = None
mock_tool_call_msg2.go_away = None
mock_tool_call_msg2.setup_complete = None

function_call2 = types.FunctionCall(
name='tool_2',
Expand Down Expand Up @@ -1266,6 +1292,7 @@ async def test_receive_multiple_tool_calls_buffered_until_turn_complete(
mock_turn_complete_msg.tool_call = None
mock_turn_complete_msg.session_resumption_update = None
mock_turn_complete_msg.go_away = None
mock_turn_complete_msg.setup_complete = None

async def mock_receive_generator():
yield mock_tool_call_msg1
Expand Down Expand Up @@ -1311,6 +1338,7 @@ async def test_receive_tool_calls_yielded_immediately_for_gemini_3_1(
mock_tool_call_msg.server_content = None
mock_tool_call_msg.session_resumption_update = None
mock_tool_call_msg.go_away = None
mock_tool_call_msg.setup_complete = None

function_call = types.FunctionCall(
name='test_tool',
Expand Down Expand Up @@ -1346,6 +1374,7 @@ async def test_receive_go_away(gemini_connection, mock_gemini_session):
mock_msg.tool_call = None
mock_msg.session_resumption_update = None
mock_msg.go_away = mock_go_away
mock_msg.setup_complete = None

async def mock_receive_generator():
yield mock_msg
Expand All @@ -1359,6 +1388,30 @@ async def mock_receive_generator():
assert responses[0].go_away == mock_go_away


@pytest.mark.asyncio
async def test_receive_setup_complete(gemini_connection, mock_gemini_session):
"""Test receive yields setup_complete message."""
mock_setup_complete = types.LiveServerSetupComplete()
mock_msg = mock.MagicMock()
mock_msg.usage_metadata = None
mock_msg.server_content = None
mock_msg.tool_call = None
mock_msg.session_resumption_update = None
mock_msg.go_away = None
mock_msg.setup_complete = mock_setup_complete

async def mock_receive_generator():
yield mock_msg

receive_mock = mock.Mock(return_value=mock_receive_generator())
mock_gemini_session.receive = receive_mock

responses = [resp async for resp in gemini_connection.receive()]

assert len(responses) == 1
assert responses[0].live_setup_complete == mock_setup_complete


@pytest.mark.asyncio
async def test_receive_aggregates_thoughts_separately(
gemini_connection, mock_gemini_session
Expand Down Expand Up @@ -1465,6 +1518,7 @@ async def test_receive_video_content(gemini_connection, mock_gemini_session):
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None
mock_message.setup_complete = None

async def mock_receive_generator():
yield mock_message
Expand Down Expand Up @@ -1499,6 +1553,7 @@ def make_msg(
tool_call=None,
session_resumption_update=None,
go_away=None,
setup_complete=None,
)
msg.server_content = mock.Mock(
interrupted=False,
Expand Down Expand Up @@ -1576,6 +1631,7 @@ async def test_receive_populates_turn_complete_reason(
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None
mock_message.setup_complete = None

async def mock_receive_generator():
yield mock_message
Expand Down Expand Up @@ -1617,6 +1673,7 @@ async def test_receive_populates_turn_complete_reason_standalone_grounding(
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None
mock_message.setup_complete = None

async def mock_receive_generator():
yield mock_message
Expand Down Expand Up @@ -1663,6 +1720,7 @@ async def test_receive_populates_turn_complete_reason_with_content(
mock_message.tool_call = None
mock_message.session_resumption_update = None
mock_message.go_away = None
mock_message.setup_complete = None

async def mock_receive_generator():
yield mock_message
Expand Down Expand Up @@ -1699,6 +1757,7 @@ def make_msg(
msg.tool_call = tool_call
msg.session_resumption_update = None
msg.go_away = None
msg.setup_complete = None
msg.server_content = mock.Mock()
msg.server_content.interrupted = False
msg.server_content.input_transcription = None
Expand Down Expand Up @@ -1770,6 +1829,7 @@ def make_msg(text: str | None = None, tc: bool = False) -> mock.Mock:
msg.tool_call = None
msg.session_resumption_update = None
msg.go_away = None
msg.setup_complete = None
msg.server_content = mock.Mock()
msg.server_content.interrupted = False
msg.server_content.input_transcription = None
Expand Down Expand Up @@ -1827,6 +1887,7 @@ def make_msg(
msg.tool_call = None
msg.session_resumption_update = None
msg.go_away = None
msg.setup_complete = None
msg.server_content = mock.Mock()
msg.server_content.interrupted = False
msg.server_content.input_transcription = (
Expand Down Expand Up @@ -1898,6 +1959,7 @@ def _create_mock_receive_message(
mock_message.tool_call = tool_call
mock_message.session_resumption_update = None
mock_message.go_away = None
mock_message.setup_complete = None
return mock_message


Expand Down
Loading
Loading