8 changes: 6 additions & 2 deletions src/strands/experimental/bidi/_async/__init__.py
@@ -16,7 +16,7 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
funcs: Stop functions to call in sequence.
Raises:
ExceptionGroup: If any stop function raises an exception.
RuntimeError: If any stop function raises an exception.
"""
exceptions = []
for func in funcs:
@@ -26,4 +26,8 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
exceptions.append(exception)

if exceptions:
raise ExceptionGroup("failed stop sequence", exceptions)
exceptions.append(RuntimeError("failed stop sequence"))
for i in range(1, len(exceptions)):
exceptions[i].__cause__ = exceptions[i - 1]
@pgrayy (Member Author), Dec 5, 2025:

The exception traceback will show all the causes chained together. We have this unit tested down below.
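
For illustration, here is a standalone sketch of the chaining behavior (not the stop_all implementation itself): linking each collected exception to the previous one through __cause__ before raising the last one makes the formatted traceback walk the whole chain.

import traceback

errors = [ValueError("stop 2 failed"), ValueError("stop 4 failed")]
errors.append(RuntimeError("failed stop sequence"))
for i in range(1, len(errors)):
    errors[i].__cause__ = errors[i - 1]

try:
    raise errors[-1]
except RuntimeError as error:
    # Both ValueErrors appear in the output, each joined by
    # "The above exception was the direct cause of the following exception".
    print("".join(traceback.format_exception(type(error), error, error.__traceback__)))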

Reviewer (Member):

This is slightly misleading, right? In reality they didn't all cause each other; we're just presenting it that way.

Should we indicate that somehow? (Like via the message or something.)

Can we also document this in the code?

@pgrayy (Member Author), Dec 5, 2025:

Ah shoot, I actually meant to get a traceback message that looks something like:

Traceback (most recent call last):
  File "/Users/pgrayy/Projects/Strands/sdk-python/src/strands/experimental/bidi/_async/__init__.py", line 24, in stop_all
    await func()
  File "/Users/pgrayy/Library/Application Support/hatch/env/virtual/.pythons/3.12/python/lib/python3.12/unittest/mock.py", line 2291, in _execute_mock_call
    raise effect
ValueError: stop 2 failed

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/pgrayy/Projects/Strands/sdk-python/src/strands/experimental/bidi/_async/__init__.py", line 24, in stop_all
    await func()
  File "/Users/pgrayy/Library/Application Support/hatch/env/virtual/.pythons/3.12/python/lib/python3.12/unittest/mock.py", line 2291, in _execute_mock_call
    raise effect
ValueError: stop 4 failed

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/pgrayy/Projects/Strands/sdk-python/tests/strands/experimental/bidi/_async/test__init__.py", line 17, in test_stop_exception
    await stop_all(func1, func2, func3, func4)
  File "/Users/pgrayy/Projects/Strands/sdk-python/src/strands/experimental/bidi/_async/__init__.py", line 33, in stop_all
    raise exceptions[-1]
RuntimeError: failed stop sequence

Note, this says "During handling of the above exception, another exception occurred" instead of "The above exception was the direct cause of the following exception". I would just need to switch __cause__ to __context__.

What would you think of this?
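
As a sketch of that alternative (the same standalone example as above, with only the linking attribute changed, and still not the exact stop_all code), chaining through __context__ instead of __cause__ produces the "During handling of the above exception, another exception occurred" wording:

errors = [ValueError("stop 2 failed"), ValueError("stop 4 failed")]
errors.append(RuntimeError("failed stop sequence"))
for i in range(1, len(errors)):
    # __context__ rather than __cause__: the failures are independent and
    # did not actually cause one another.
    errors[i].__context__ = errors[i - 1]

raise errors[-1]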


raise exceptions[-1]
19 changes: 16 additions & 3 deletions src/strands/experimental/bidi/agent/agent.py
@@ -390,9 +390,22 @@ async def run_outputs(inputs_task: asyncio.Task) -> None:
for start in [*input_starts, *output_starts]:
await start(self)

async with asyncio.TaskGroup() as task_group:
inputs_task = task_group.create_task(run_inputs())
task_group.create_task(run_outputs(inputs_task))
inputs_task = asyncio.create_task(run_inputs())
outputs_task = asyncio.create_task(run_outputs(inputs_task))

try:
await asyncio.gather(inputs_task, outputs_task)
except (Exception, asyncio.CancelledError) as error:
inputs_task.cancel()
outputs_task.cancel()
await asyncio.gather(inputs_task, outputs_task, return_exceptions=True)

if not isinstance(error, asyncio.CancelledError):
raise

run_task = asyncio.current_task()
if run_task and run_task.cancelling() > 0: # externally cancelled
raise
@pgrayy (Member Author):

As an example here, if the user runs run() through asyncio.create_task and calls task.cancel() themselves, we want that CancelledError to propagate. Similarly for KeyboardInterrupt. That is what TaskGroup does.

Internal cancellations, however, like the inputs_task.cancel() inside run_outputs, should not re-raise to the user.
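
A self-contained sketch of that distinction (hypothetical task names, not the agent code itself; Task.cancelling() requires Python 3.11+): the counter only tracks cancellation requests made against the current task, so it stays at 0 when the CancelledError comes from cancelling a child task internally.

import asyncio


async def child() -> None:
    await asyncio.sleep(10)


async def run() -> None:
    child_task = asyncio.create_task(child())
    child_task.cancel()  # internal cancellation, analogous to inputs_task.cancel()
    try:
        await child_task
    except asyncio.CancelledError:
        current = asyncio.current_task()
        if current and current.cancelling() > 0:  # externally cancelled
            raise  # a caller's task.cancel() (or KeyboardInterrupt) propagates
        # internal cancellation: absorb it instead of re-raising to the user


asyncio.run(run())  # completes normally because the cancel was internal
print("internal cancellation absorbed")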


finally:
input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)]
13 changes: 8 additions & 5 deletions tests/strands/experimental/bidi/_async/test__init__.py
@@ -1,3 +1,4 @@
import traceback
@pgrayy (Member Author), Dec 5, 2025:

There are some lint fixes in these 2 test files. They weren't caught previously because there is a bug in our lint command used specifically for bidi. I am addressing this as part of #1299, which resolves the issue by checking bidi using the existing hatch scripts we have configured for the rest of strands.

from unittest.mock import AsyncMock

import pytest
@@ -10,17 +11,19 @@ async def test_stop_exception():
func1 = AsyncMock()
func2 = AsyncMock(side_effect=ValueError("stop 2 failed"))
func3 = AsyncMock()
func4 = AsyncMock(side_effect=ValueError("stop 4 failed"))

with pytest.raises(ExceptionGroup) as exc_info:
await stop_all(func1, func2, func3)
with pytest.raises(RuntimeError, match=r"failed stop sequence") as exc_info:
await stop_all(func1, func2, func3, func4)

func1.assert_called_once()
func2.assert_called_once()
func3.assert_called_once()
func4.assert_called_once()

assert len(exc_info.value.exceptions) == 1
with pytest.raises(ValueError, match=r"stop 2 failed"):
raise exc_info.value.exceptions[0]
tru_tb = "".join(traceback.format_exception(RuntimeError, exc_info.value, exc_info.tb))
assert "ValueError: stop 2 failed" in tru_tb
assert "ValueError: stop 4 failed" in tru_tb


@pytest.mark.asyncio
9 changes: 4 additions & 5 deletions tests/strands/experimental/bidi/models/test_gemini_live.py
@@ -13,8 +13,8 @@
import pytest
from google.genai import types as genai_types

from strands.experimental.bidi.models.model import BidiModelTimeoutError
from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel
from strands.experimental.bidi.models.model import BidiModelTimeoutError
from strands.experimental.bidi.types.events import (
BidiAudioInputEvent,
BidiAudioStreamEvent,
@@ -185,7 +185,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id):
model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key})
await model4.start()
mock_live_session_cm.__aexit__.side_effect = Exception("Close failed")
with pytest.raises(ExceptionGroup):
with pytest.raises(RuntimeError, match=r"failed stop sequence"):
await model4.stop()


@@ -572,7 +572,6 @@ def test_tool_formatting(model, tool_spec):
assert formatted_empty == []



# Tool Result Content Tests


@@ -601,7 +600,7 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key
assert isinstance(audio_event, BidiAudioStreamEvent)
# Should use configured rates, not constants
assert audio_event.sample_rate == 48000 # Custom config
assert audio_event.channels == 2 # Custom config
assert audio_event.channels == 2 # Custom config
assert audio_event.format == "pcm"

await model.stop()
@@ -631,7 +630,7 @@ async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_ke
assert isinstance(audio_event, BidiAudioStreamEvent)
# Should use default rates
assert audio_event.sample_rate == 24000 # Default output rate
assert audio_event.channels == 1 # Default channels
assert audio_event.channels == 1 # Default channels
assert audio_event.format == "pcm"

await model.stop()
12 changes: 8 additions & 4 deletions tests/strands/experimental/bidi/models/test_openai_realtime.py
@@ -131,7 +131,9 @@ def test_audio_config_defaults(api_key, model_name):
def test_audio_config_partial_override(api_key, model_name):
"""Test partial audio configuration override."""
provider_config = {"audio": {"output_rate": 48000, "voice": "echo"}}
model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config)
model = BidiOpenAIRealtimeModel(
model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config
)

# Overridden values
assert model.config["audio"]["output_rate"] == 48000
@@ -154,7 +156,9 @@ def test_audio_config_full_override(api_key, model_name):
"voice": "shimmer",
}
}
model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config)
model = BidiOpenAIRealtimeModel(
model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config
)

assert model.config["audio"]["input_rate"] == 48000
assert model.config["audio"]["output_rate"] == 48000
@@ -349,7 +353,7 @@ async def async_connect(*args, **kwargs):
model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key})
await model4.start()
mock_ws.close.side_effect = Exception("Close failed")
with pytest.raises(ExceptionGroup):
with pytest.raises(RuntimeError, match=r"failed stop sequence"):
Reviewer (Member):

Do we have guidelines on when we use a custom exception vs. a general one?

@pgrayy (Member Author):

We don't have guidelines and probably should. In this particular case, users wouldn't have reason to act specifically on this error, which is why I made it a RuntimeError. They would most likely just catch it generally with except Exception because it is an unrecoverable internal error. That is in contrast to something like BidiModelTimeoutError, which we created so that we could trigger the connection restart workflow internally.

So if we are to develop guidelines, I would probably start with something like this: would the exception be used to drive a specific action/workflow? If not, and the user would just end up catching it with except Exception, then maybe we don't require a custom error. I would have to think about this some more, but this is where my head was at for this particular case.
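
A rough sketch of that guideline (all names here are hypothetical stand-ins, not actual strands APIs): a custom error earns its own type when it drives a specific workflow, while an unrecoverable internal failure can stay a plain RuntimeError that callers catch generically.

class ModelTimeoutError(Exception):
    """Hypothetical custom error: it exists only to trigger a restart workflow."""


async def run_with_restart(session) -> None:
    while True:
        try:
            await session.step()     # hypothetical call that may time out
            return
        except ModelTimeoutError:
            await session.restart()  # the specific action the custom error drives
        # Anything else, e.g. RuntimeError("failed stop sequence"), is an
        # unrecoverable internal error; callers would just catch it with
        # `except Exception`, so it doesn't need its own type.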

await model4.stop()


@@ -510,7 +514,7 @@ async def test_receive_lifecycle_events(mock_websocket, model):
format="pcm",
sample_rate=24000,
channels=1,
)
),
]
assert tru_events == exp_events
