diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index 8fca4b39e0..0c3642fd0d 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -196,7 +196,7 @@ class RunConfig(BaseModel): speech_config: Optional[types.SpeechConfig] = None """Speech configuration for the live agent.""" - response_modalities: Optional[list[str]] = None + response_modalities: Optional[list[types.Modality]] = None """The output modalities. If not set, it's default to AUDIO.""" avatar_config: Optional[types.AvatarConfig] = None diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index ae1cc12a75..d96e3113b3 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -1554,7 +1554,7 @@ async def run_live( # Some native audio models requires the modality to be set. So we set it to # AUDIO by default. if run_config.response_modalities is None: - run_config.response_modalities = ['AUDIO'] + run_config.response_modalities = [types.Modality.AUDIO] if session is None and (user_id is None or session_id is None): raise ValueError( 'Either session or user_id and session_id must be provided.' @@ -2021,7 +2021,7 @@ def _new_invocation_context_for_live( # For live multi-agents system, we need model's text transcription as # context for the transferred agent. if hasattr(self.agent, 'sub_agents') and self.agent.sub_agents: - if 'AUDIO' in run_config.response_modalities: + if types.Modality.AUDIO in run_config.response_modalities if not run_config.output_audio_transcription: run_config.output_audio_transcription = ( types.AudioTranscriptionConfig() diff --git a/tests/unittests/agents/test_run_config.py b/tests/unittests/agents/test_run_config.py index cbb82af019..3d0b56b5c6 100644 --- a/tests/unittests/agents/test_run_config.py +++ b/tests/unittests/agents/test_run_config.py @@ -15,6 +15,7 @@ import sys from unittest.mock import ANY from unittest.mock import patch +import warnings from google.adk.agents.run_config import RunConfig from google.genai import types @@ -67,6 +68,29 @@ def test_audio_transcription_configs_are_not_shared_between_instances(): ) +def test_response_modalities_accepts_enum(): + config = RunConfig(response_modalities=[types.Modality.AUDIO]) + assert config.response_modalities == [types.Modality.AUDIO] + assert isinstance(config.response_modalities[0], types.Modality) + + +def test_response_modalities_coerces_string_to_enum(): + config = RunConfig(response_modalities=["AUDIO"]) + assert config.response_modalities == [types.Modality.AUDIO] + assert isinstance(config.response_modalities[0], types.Modality) + + +def test_response_modalities_serialization_no_warning(): + config = RunConfig(response_modalities=[types.Modality.AUDIO]) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + config.model_dump() + pydantic_warnings = [ + x for x in w if "PydanticSerializationUnexpectedValue" in str(x.message) + ] + assert len(pydantic_warnings) == 0 + + def test_avatar_config_initialization(): custom_avatar = types.CustomizedAvatar( image_mime_type="image/jpeg", image_data=b"image_bytes"