Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/google/adk/agents/run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class RunConfig(BaseModel):
speech_config: Optional[types.SpeechConfig] = None
"""Speech configuration for the live agent."""

response_modalities: Optional[list[str]] = None
response_modalities: Optional[list[types.Modality]] = None
"""The output modalities. If not set, it's default to AUDIO."""

avatar_config: Optional[types.AvatarConfig] = None
Expand Down
4 changes: 2 additions & 2 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1554,7 +1554,7 @@ async def run_live(
# Some native audio models require the modality to be set. So we set it to
# AUDIO by default.
if run_config.response_modalities is None:
run_config.response_modalities = ['AUDIO']
run_config.response_modalities = [types.Modality.AUDIO]
if session is None and (user_id is None or session_id is None):
raise ValueError(
'Either session or user_id and session_id must be provided.'
Expand Down Expand Up @@ -2021,7 +2021,7 @@ def _new_invocation_context_for_live(
# For live multi-agents system, we need model's text transcription as
# context for the transferred agent.
if hasattr(self.agent, 'sub_agents') and self.agent.sub_agents:
if 'AUDIO' in run_config.response_modalities:
if types.Modality.AUDIO in run_config.response_modalities:
if not run_config.output_audio_transcription:
run_config.output_audio_transcription = (
types.AudioTranscriptionConfig()
Expand Down
24 changes: 24 additions & 0 deletions tests/unittests/agents/test_run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sys
from unittest.mock import ANY
from unittest.mock import patch
import warnings

from google.adk.agents.run_config import RunConfig
from google.genai import types
Expand Down Expand Up @@ -67,6 +68,29 @@ def test_audio_transcription_configs_are_not_shared_between_instances():
)


def test_response_modalities_accepts_enum():
  """RunConfig accepts Modality enum members in response_modalities."""
  audio = types.Modality.AUDIO
  run_config = RunConfig(response_modalities=[audio])

  # The stored list must compare equal to the input and hold an enum member.
  assert run_config.response_modalities == [audio]
  first_modality = run_config.response_modalities[0]
  assert isinstance(first_modality, types.Modality)


def test_response_modalities_coerces_string_to_enum():
  """A plain "AUDIO" string is expected to validate to Modality.AUDIO."""
  run_config = RunConfig(response_modalities=["AUDIO"])

  # Checks both value equality and the concrete element type after validation.
  expected = [types.Modality.AUDIO]
  assert run_config.response_modalities == expected
  first_modality = run_config.response_modalities[0]
  assert isinstance(first_modality, types.Modality)


def test_response_modalities_serialization_no_warning():
  """model_dump() on an enum-valued config emits no serialization warning."""
  run_config = RunConfig(response_modalities=[types.Modality.AUDIO])
  marker = "PydanticSerializationUnexpectedValue"

  with warnings.catch_warnings(record=True) as captured:
    # Record every warning, including ones already shown once in this process.
    warnings.simplefilter("always")
    run_config.model_dump()
    flagged = []
    for entry in captured:
      if marker in str(entry.message):
        flagged.append(entry)

  assert len(flagged) == 0


def test_avatar_config_initialization():
custom_avatar = types.CustomizedAvatar(
image_mime_type="image/jpeg", image_data=b"image_bytes"
Expand Down
Loading