diff --git a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py index 6cee3264d..2b97ab1fc 100644 --- a/src/strands/experimental/bidi/_async/__init__.py +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -16,7 +16,7 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: funcs: Stop functions to call in sequence. Raises: - ExceptionGroup: If any stop function raises an exception. + RuntimeError: If any stop function raises an exception. """ exceptions = [] for func in funcs: @@ -26,4 +26,8 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: exceptions.append(exception) if exceptions: - raise ExceptionGroup("failed stop sequence", exceptions) + exceptions.append(RuntimeError("failed stop sequence")) + for i in range(1, len(exceptions)): + exceptions[i].__cause__ = exceptions[i - 1] + + raise exceptions[-1] diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 4012d5e2d..7cd48c466 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -390,9 +390,22 @@ async def run_outputs(inputs_task: asyncio.Task) -> None: for start in [*input_starts, *output_starts]: await start(self) - async with asyncio.TaskGroup() as task_group: - inputs_task = task_group.create_task(run_inputs()) - task_group.create_task(run_outputs(inputs_task)) + inputs_task = asyncio.create_task(run_inputs()) + outputs_task = asyncio.create_task(run_outputs(inputs_task)) + + try: + await asyncio.gather(inputs_task, outputs_task) + except (Exception, asyncio.CancelledError) as error: + inputs_task.cancel() + outputs_task.cancel() + await asyncio.gather(inputs_task, outputs_task, return_exceptions=True) + + if not isinstance(error, asyncio.CancelledError): + raise + + run_task = asyncio.current_task() + if run_task and run_task.cancelling() > 0: # externally cancelled + raise finally: input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)] diff --git a/tests/strands/experimental/bidi/_async/test__init__.py b/tests/strands/experimental/bidi/_async/test__init__.py index f8df25e14..3c5339d08 100644 --- a/tests/strands/experimental/bidi/_async/test__init__.py +++ b/tests/strands/experimental/bidi/_async/test__init__.py @@ -1,3 +1,4 @@ +import traceback from unittest.mock import AsyncMock import pytest @@ -10,17 +11,19 @@ async def test_stop_exception(): func1 = AsyncMock() func2 = AsyncMock(side_effect=ValueError("stop 2 failed")) func3 = AsyncMock() + func4 = AsyncMock(side_effect=ValueError("stop 4 failed")) - with pytest.raises(ExceptionGroup) as exc_info: - await stop_all(func1, func2, func3) + with pytest.raises(RuntimeError, match=r"failed stop sequence") as exc_info: + await stop_all(func1, func2, func3, func4) func1.assert_called_once() func2.assert_called_once() func3.assert_called_once() + func4.assert_called_once() - assert len(exc_info.value.exceptions) == 1 - with pytest.raises(ValueError, match=r"stop 2 failed"): - raise exc_info.value.exceptions[0] + tru_tb = "".join(traceback.format_exception(RuntimeError, exc_info.value, exc_info.tb)) + assert "ValueError: stop 2 failed" in tru_tb + assert "ValueError: stop 4 failed" in tru_tb @pytest.mark.asyncio diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index da516d4a0..6543dc4f2 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -13,8 +13,8 @@ import pytest from google.genai import types as genai_types -from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel +from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -185,7 +185,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) await model4.start() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(RuntimeError, match=r"failed stop sequence"): await model4.stop() @@ -572,7 +572,6 @@ def test_tool_formatting(model, tool_spec): assert formatted_empty == [] - # Tool Result Content Tests @@ -601,7 +600,7 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key assert isinstance(audio_event, BidiAudioStreamEvent) # Should use configured rates, not constants assert audio_event.sample_rate == 48000 # Custom config - assert audio_event.channels == 2 # Custom config + assert audio_event.channels == 2 # Custom config assert audio_event.format == "pcm" await model.stop() @@ -631,7 +630,7 @@ async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_ke assert isinstance(audio_event, BidiAudioStreamEvent) # Should use default rates assert audio_event.sample_rate == 24000 # Default output rate - assert audio_event.channels == 1 # Default channels + assert audio_event.channels == 1 # Default channels assert audio_event.format == "pcm" await model.stop() diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 5c9c0900d..5ab183da2 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -131,7 +131,9 @@ def test_audio_config_defaults(api_key, model_name): def test_audio_config_partial_override(api_key, model_name): """Test partial audio configuration override.""" provider_config = {"audio": {"output_rate": 48000, "voice": "echo"}} - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -154,7 +156,9 @@ def test_audio_config_full_override(api_key, model_name): "voice": "shimmer", } } - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -349,7 +353,7 @@ async def async_connect(*args, **kwargs): model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) await model4.start() mock_ws.close.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(RuntimeError, match=r"failed stop sequence"): await model4.stop() @@ -510,7 +514,7 @@ async def test_receive_lifecycle_events(mock_websocket, model): format="pcm", sample_rate=24000, channels=1, - ) + ), ] assert tru_events == exp_events