From c59e89e7b032db54f90097d09a4c0e17f7015a56 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 5 Dec 2025 12:25:24 -0500 Subject: [PATCH] build - bidi - isolate nova provider --- pyproject.toml | 68 +------- src/strands/agent/agent.py | 14 +- src/strands/event_loop/event_loop.py | 2 +- src/strands/experimental/__init__.py | 4 +- src/strands/experimental/bidi/__init__.py | 8 - .../experimental/bidi/_async/__init__.py | 8 +- src/strands/experimental/bidi/agent/agent.py | 49 ++++-- src/strands/experimental/bidi/agent/loop.py | 23 +-- .../experimental/bidi/models/__init__.py | 2 - src/strands/experimental/bidi/models/model.py | 3 +- .../experimental/bidi/models/nova_sonic.py | 30 +++- src/strands/tools/_caller.py | 9 +- src/strands/tools/executors/_executor.py | 65 +++++--- src/strands/tools/executors/concurrent.py | 9 +- src/strands/tools/executors/sequential.py | 5 +- .../test_event_loop_structured_output.py | 6 +- .../experimental/bidi/_async/test__init__.py | 13 +- .../experimental/bidi/agent/__init__.py | 2 +- .../experimental/bidi/agent/test_agent.py | 154 +++++++++--------- .../experimental/bidi/agent/test_loop.py | 50 +++--- .../experimental/bidi/io/test_audio.py | 3 +- .../strands/experimental/bidi/io/test_text.py | 2 +- .../bidi/models/test_gemini_live.py | 9 +- .../bidi/models/test_nova_sonic.py | 18 +- .../bidi/models/test_openai_realtime.py | 12 +- 25 files changed, 279 insertions(+), 289 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2c2a6b260..16be677e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,17 +70,18 @@ a2a = [ "starlette>=0.46.2,<1.0.0", ] -bidi = [ - "aws_sdk_bedrock_runtime; python_version>='3.12'", +bidi-io = [ "prompt_toolkit>=3.0.0,<4.0.0", "pyaudio>=0.2.13,<1.0.0", - "smithy-aws-core>=0.0.1; python_version>='3.12'", ] bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] +bidi-nova = [ + "aws_sdk_bedrock_runtime; python_version>='3.12'", + "smithy-aws-core>=0.0.1; python_version>='3.12'", +] bidi-openai = ["websockets>=15.0.0,<16.0.0"] -all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] -bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"] +all = ["strands-agents[a2a,anthropic,bidi-io,bidi-gemini,bidi-openai,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", @@ -130,7 +131,7 @@ format-fix = [ ] lint-check = [ "ruff check", - "mypy ./src" + "mypy -p src" ] lint-fix = [ "ruff check --fix" @@ -204,16 +205,10 @@ warn_no_return = true warn_unreachable = true follow_untyped_imports = true ignore_missing_imports = false -exclude = ["src/strands/experimental/bidi"] - -[[tool.mypy.overrides]] -module = ["strands.experimental.bidi.*"] -follow_imports = "skip" [tool.ruff] line-length = 120 include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"] -exclude = ["src/strands/experimental/bidi/**/*.py", "tests/strands/experimental/bidi/**/*.py", "tests_integ/bidi/**/*.py"] [tool.ruff.lint] select = [ @@ -236,8 +231,7 @@ convention = "google" [tool.pytest.ini_options] testpaths = ["tests"] asyncio_default_fixture_loop_scope = "function" -addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi" - +addopts = "--ignore=tests/strands/experimental/bidi/models/test_nova_sonic.py --ignore=tests_integ/bidi" [tool.coverage.run] branch = true @@ -245,7 +239,6 @@ source = ["src"] context = "thread" parallel = true concurrency = ["thread", "multiprocessing"] -omit = ["src/strands/experimental/bidi/*"] [tool.coverage.report] show_missing = true @@ -275,48 +268,3 @@ style = [ ["text", ""], ["disabled", "fg:#858585 italic"] ] - -# ========================= -# Bidi development configs -# ========================= - -[tool.hatch.envs.bidi] -dev-mode = true -features = ["dev", "bidi-all"] -installer = "uv" - -[tool.hatch.envs.bidi.scripts] -prepare = [ - "hatch run bidi-lint:format-fix", - "hatch run bidi-lint:quality-fix", - "hatch run bidi-lint:type-check", - "hatch run bidi-test:test-cov", -] - -[tools.hatch.envs.bidi-lint] -template = "bidi" - -[tool.hatch.envs.bidi-lint.scripts] -format-check = "format-fix --check" -format-fix = "ruff format {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py" -quality-check = "ruff check {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py" -quality-fix = "quality-check --fix" -type-check = "mypy {args} --python-version 3.12 ./src/strands/experimental/bidi/**/*.py" - -[tool.hatch.envs.bidi-test] -template = "bidi" - -[tool.hatch.envs.bidi-test.scripts] -test = "pytest {args} tests/strands/experimental/bidi" -test-cov = """ -test \ - --cov=strands.experimental.bidi \ - --cov-config= \ - --cov-branch \ - --cov-report=term-missing \ - --cov-report=xml:build/coverage/bidi-coverage.xml \ - --cov-report=html:build/coverage/bidi-html -""" - -[[tool.hatch.envs.bidi-test.matrix]] -python = ["3.13", "3.12"] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 232e2ca2a..ff0a1c3c3 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -624,8 +624,7 @@ async def _run_loop( try: yield InitEventLoopEvent() - for message in messages: - await self._append_message(message) + await self._append_messages(*messages) structured_output_context = StructuredOutputContext( structured_output_model or self._default_structured_output_model @@ -715,7 +714,7 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: tool_use_ids = [ content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content ] - await self._append_message( + await self._append_messages( { "role": "user", "content": generate_missing_tool_result_content(tool_use_ids), @@ -811,10 +810,11 @@ def _initialize_system_prompt( else: return None, None - async def _append_message(self, message: Message) -> None: - """Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent.""" - self.messages.append(message) - await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message)) + async def _append_messages(self, *messages: Message) -> None: + """Appends messages to history and invoke the callbacks for the MessageAddedEvent.""" + for message in messages: + self.messages.append(message) + await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message)) def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]: """Redact user content preserving toolResult blocks. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 186ead708..f25057e4d 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -230,7 +230,7 @@ async def event_loop_cycle( ) structured_output_context.set_forced_mode() logger.debug("Forcing structured output tool") - await agent._append_message( + await agent._append_messages( {"role": "user", "content": [{"text": "You must format the previous response as structured output."}]} ) diff --git a/src/strands/experimental/__init__.py b/src/strands/experimental/__init__.py index 3c1d0ee46..7a24ecd38 100644 --- a/src/strands/experimental/__init__.py +++ b/src/strands/experimental/__init__.py @@ -3,7 +3,7 @@ This module implements experimental features that are subject to change in future revisions without notice. """ -from . import steering, tools +from . import bidi, steering, tools from .agent_config import config_to_agent -__all__ = ["config_to_agent", "tools", "steering"] +__all__ = ["bidi", "config_to_agent", "tools", "steering"] diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 57986062e..1c0e74aae 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -1,10 +1,5 @@ """Bidirectional streaming package.""" -import sys - -if sys.version_info < (3, 12): - raise ImportError("bidi only supported for >= Python 3.12") - # Main components - Primary user interface # Re-export standard agent events for tool handling from ...types._events import ( @@ -19,7 +14,6 @@ # Model interface (for custom implementations) from .models.model import BidiModel -from .models.nova_sonic import BidiNovaSonicModel # Built-in tools from .tools import stop_conversation @@ -48,8 +42,6 @@ "BidiAgent", # IO channels "BidiAudioIO", - # Model providers - "BidiNovaSonicModel", # Built-in tools "stop_conversation", # Input Event types diff --git a/src/strands/experimental/bidi/_async/__init__.py b/src/strands/experimental/bidi/_async/__init__.py index 6cee3264d..2b97ab1fc 100644 --- a/src/strands/experimental/bidi/_async/__init__.py +++ b/src/strands/experimental/bidi/_async/__init__.py @@ -16,7 +16,7 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: funcs: Stop functions to call in sequence. Raises: - ExceptionGroup: If any stop function raises an exception. + RuntimeError: If any stop function raises an exception. """ exceptions = [] for func in funcs: @@ -26,4 +26,8 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None: exceptions.append(exception) if exceptions: - raise ExceptionGroup("failed stop sequence", exceptions) + exceptions.append(RuntimeError("failed stop sequence")) + for i in range(1, len(exceptions)): + exceptions[i].__cause__ = exceptions[i - 1] + + raise exceptions[-1] diff --git a/src/strands/experimental/bidi/agent/agent.py b/src/strands/experimental/bidi/agent/agent.py index 360dfe707..c41bf2ba4 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -26,13 +26,12 @@ from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry from ....tools.watcher import ToolWatcher -from ....types.content import Messages +from ....types.content import Message, Messages from ....types.tools import AgentTool -from ...hooks.events import BidiAgentInitializedEvent +from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent from ...tools import ToolProvider from .._async import stop_all from ..models.model import BidiModel -from ..models.nova_sonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput from ..types.events import ( BidiAudioInputEvent, @@ -100,13 +99,13 @@ def __init__( ValueError: If model configuration is invalid or state is invalid type. TypeError: If model type is unsupported. """ - self.model = ( - BidiNovaSonicModel() - if not model - else BidiNovaSonicModel(model_id=model) - if isinstance(model, str) - else model - ) + if isinstance(model, BidiModel): + self.model = model + else: + from ..models.nova_sonic import BidiNovaSonicModel + + self.model = BidiNovaSonicModel(model_id=model) if isinstance(model, str) else BidiNovaSonicModel() + self.system_prompt = system_prompt self.messages = messages or [] @@ -167,6 +166,9 @@ def __init__( # TODO: Determine if full support is required self._interrupt_state = _InterruptState() + # Lock to ensure that paired messages are added to history in sequence without interference. + self._message_lock = asyncio.Lock() + self._started = False @property @@ -387,12 +389,33 @@ async def run_outputs(inputs_task: asyncio.Task) -> None: for start in [*input_starts, *output_starts]: await start(self) - async with asyncio.TaskGroup() as task_group: - inputs_task = task_group.create_task(run_inputs()) - task_group.create_task(run_outputs(inputs_task)) + inputs_task = asyncio.create_task(run_inputs()) + outputs_task = asyncio.create_task(run_outputs(inputs_task)) + + try: + await asyncio.gather(inputs_task, outputs_task) + except (Exception, asyncio.CancelledError): + inputs_task.cancel() + outputs_task.cancel() + await asyncio.gather(inputs_task, outputs_task, return_exceptions=True) + raise finally: input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)] output_stops = [output.stop for output in outputs if isinstance(output, BidiOutput)] await stop_all(*input_stops, *output_stops, self.stop) + + async def _append_messages(self, *messages: Message) -> None: + """Append messages to history in sequence without interference. + + The message lock ensures that paired messages are added to history in sequence without interference. For + example, tool use and tool result messages must be added adjacent to each other. + + Args: + *messages: List of messages to add into history. + """ + async with self._message_lock: + for message in messages: + self.messages.append(message) + await self.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self, message=message)) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 13b7033a4..2b883cf73 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -15,7 +15,6 @@ BidiAfterInvocationEvent, BidiBeforeConnectionRestartEvent, BidiBeforeInvocationEvent, - BidiMessageAddedEvent, ) from ...hooks.events import ( BidiInterruptionEvent as BidiInterruptionHookEvent, @@ -51,8 +50,6 @@ class _BidiAgentLoop: that tools can access via their invocation_state parameter. _send_gate: Gate the sending of events to the model. Blocks when agent is reseting the model connection after timeout. - _message_lock: Lock to ensure that paired messages are added to history in sequence without interference. - For example, tool use and tool result messages must be added adjacent to each other. """ def __init__(self, agent: "BidiAgent") -> None: @@ -70,7 +67,6 @@ def __init__(self, agent: "BidiAgent") -> None: self._invocation_state: dict[str, Any] self._send_gate = asyncio.Event() - self._message_lock = asyncio.Lock() async def start(self, invocation_state: dict[str, Any] | None = None) -> None: """Start the agent loop. @@ -145,7 +141,7 @@ async def send(self, event: BidiInputEvent | ToolResultEvent) -> None: if isinstance(event, BidiTextInputEvent): message: Message = {"role": "user", "content": [{"text": event.text}]} - await self._add_messages(message) + await self._agent._append_messages(message) await self._agent.model.send(event) @@ -224,7 +220,7 @@ async def _run_model(self) -> None: if isinstance(event, BidiTranscriptStreamEvent): if event["is_final"]: message: Message = {"role": event["role"], "content": [{"text": event["text"]}]} - await self._add_messages(message) + await self._agent._append_messages(message) elif isinstance(event, ToolUseStreamEvent): tool_use = event["current_tool_use"] @@ -282,7 +278,7 @@ async def _run_tool(self, tool_use: ToolUse) -> None: tool_use_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} tool_result_message: Message = {"role": "user", "content": [{"toolResult": tool_result_event.tool_result}]} - await self._add_messages(tool_use_message, tool_result_message) + await self._agent._append_messages(tool_use_message, tool_result_message) await self._event_queue.put(ToolResultMessageEvent(tool_result_message)) @@ -300,16 +296,3 @@ async def _run_tool(self, tool_use: ToolUse) -> None: except Exception as error: await self._event_queue.put(error) - - async def _add_messages(self, *messages: Message) -> None: - """Add messages to history in sequence without interference. - - Args: - *messages: List of messages to add into history. - """ - async with self._message_lock: - for message in messages: - self._agent.messages.append(message) - await self._agent.hooks.invoke_callbacks_async( - BidiMessageAddedEvent(agent=self._agent, message=message) - ) diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index cc62c9987..4b8b22e2d 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -1,10 +1,8 @@ """Bidirectional model interfaces and implementations.""" from .model import BidiModel, BidiModelTimeoutError -from .nova_sonic import BidiNovaSonicModel __all__ = [ "BidiModel", "BidiModelTimeoutError", - "BidiNovaSonicModel", ] diff --git a/src/strands/experimental/bidi/models/model.py b/src/strands/experimental/bidi/models/model.py index f5e34aa50..5941d7e41 100644 --- a/src/strands/experimental/bidi/models/model.py +++ b/src/strands/experimental/bidi/models/model.py @@ -14,7 +14,7 @@ """ import logging -from typing import Any, AsyncIterable, Protocol +from typing import Any, AsyncIterable, Protocol, runtime_checkable from ....types._events import ToolResultEvent from ....types.content import Messages @@ -27,6 +27,7 @@ logger = logging.getLogger(__name__) +@runtime_checkable class BidiModel(Protocol): """Protocol for bidirectional streaming models. diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index 6a2477e22..966b571db 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -11,9 +11,16 @@ - Tool execution with content containers and identifier tracking - 8-minute connection limits with proper cleanup sequences - Interruption detection through stopReason events + +Note, BidiNovaSonicModel is only supported for Python 3.12+ """ -import asyncio +import sys + +if sys.version_info < (3, 12): + raise ImportError("BidiNovaSonicModel is only supported for Python 3.12+") + +import asyncio # type: ignore[unreachable] import base64 import json import logging @@ -21,17 +28,24 @@ from typing import Any, AsyncGenerator, cast import boto3 -from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput -from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme -from aws_sdk_bedrock_runtime.models import ( +from aws_sdk_bedrock_runtime.client import ( # type: ignore[import-not-found] + BedrockRuntimeClient, + InvokeModelWithBidirectionalStreamOperationInput, +) +from aws_sdk_bedrock_runtime.config import ( # type: ignore[import-not-found] + Config, + HTTPAuthSchemeResolver, + SigV4AuthScheme, +) +from aws_sdk_bedrock_runtime.models import ( # type: ignore[import-not-found] BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk, ModelTimeoutException, ValidationException, ) -from smithy_aws_core.identity.static import StaticCredentialsResolver -from smithy_core.aio.eventstream import DuplexEventStream -from smithy_core.shapes import ShapeID +from smithy_aws_core.identity.static import StaticCredentialsResolver # type: ignore[import-not-found] +from smithy_core.aio.eventstream import DuplexEventStream # type: ignore[import-not-found] +from smithy_core.shapes import ShapeID # type: ignore[import-not-found] from ....types._events import ToolResultEvent, ToolUseStreamEvent from ....types.content import Messages @@ -93,6 +107,8 @@ class BidiNovaSonicModel(BidiModel): Manages Nova Sonic's complex event sequencing, audio format conversion, and tool execution patterns while providing the standard BidiModel interface. + Note, BidiNovaSonicModel is only supported for Python 3.12+. + Attributes: _stream: open bedrock stream to nova sonic. """ diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index 3ab576947..6871ce814 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -9,7 +9,7 @@ import json import random -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, cast from .._async import run_async from ..tools.executors._executor import ToolExecutor @@ -108,7 +108,7 @@ async def acall() -> ToolResult: # Apply conversation management if agent supports it (traditional agents) if hasattr(self._agent, "conversation_manager"): - self._agent.conversation_manager.apply_management(self._agent) + self._agent.conversation_manager.apply_management(cast("Agent", self._agent)) return tool_result @@ -195,10 +195,7 @@ async def _record_tool_execution( } # Add to message history - await self._agent._append_message(user_msg) - await self._agent._append_message(tool_use_msg) - await self._agent._append_message(tool_result_msg) - await self._agent._append_message(assistant_msg) + await self._agent._append_messages(user_msg, tool_use_msg, tool_result_msg, assistant_msg) def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: """Filter input parameters to only include those defined in the tool specification. diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index a4f9e7e1f..72a533505 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -49,15 +49,24 @@ async def _invoke_before_tool_call_hook( invocation_state: dict[str, Any], ) -> tuple[BeforeToolCallEvent | BidiBeforeToolCallEvent, list[Interrupt]]: """Invoke the appropriate before tool call hook based on agent type.""" - event_cls = BeforeToolCallEvent if ToolExecutor._is_agent(agent) else BidiBeforeToolCallEvent - return await agent.hooks.invoke_callbacks_async( - event_cls( - agent=agent, - selected_tool=tool_func, - tool_use=tool_use, - invocation_state=invocation_state, + if ToolExecutor._is_agent(agent): + return await agent.hooks.invoke_callbacks_async( + BeforeToolCallEvent( + agent=cast("Agent", agent), + selected_tool=tool_func, + tool_use=tool_use, + invocation_state=invocation_state, + ) + ) + else: + return await agent.hooks.invoke_callbacks_async( + BidiBeforeToolCallEvent( + agent=cast("BidiAgent", agent), + selected_tool=tool_func, + tool_use=tool_use, + invocation_state=invocation_state, + ) ) - ) @staticmethod async def _invoke_after_tool_call_hook( @@ -70,18 +79,30 @@ async def _invoke_after_tool_call_hook( cancel_message: str | None = None, ) -> tuple[AfterToolCallEvent | BidiAfterToolCallEvent, list[Interrupt]]: """Invoke the appropriate after tool call hook based on agent type.""" - event_cls = AfterToolCallEvent if ToolExecutor._is_agent(agent) else BidiAfterToolCallEvent - return await agent.hooks.invoke_callbacks_async( - event_cls( - agent=agent, - selected_tool=selected_tool, - tool_use=tool_use, - invocation_state=invocation_state, - result=result, - exception=exception, - cancel_message=cancel_message, + if ToolExecutor._is_agent(agent): + return await agent.hooks.invoke_callbacks_async( + AfterToolCallEvent( + agent=cast("Agent", agent), + selected_tool=selected_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + exception=exception, + cancel_message=cancel_message, + ) + ) + else: + return await agent.hooks.invoke_callbacks_async( + BidiAfterToolCallEvent( + agent=cast("BidiAgent", agent), + selected_tool=selected_tool, + tool_use=tool_use, + invocation_state=invocation_state, + result=result, + exception=exception, + cancel_message=cancel_message, + ) ) - ) @staticmethod async def _stream( @@ -247,7 +268,7 @@ async def _stream( @staticmethod async def _stream_with_trace( - agent: "Agent | BidiAgent", + agent: "Agent", tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -259,7 +280,7 @@ async def _stream_with_trace( """Execute tool with tracing and metrics collection. Args: - agent: The agent (Agent or BidiAgent) for which the tool is being executed. + agent: The agent for which the tool is being executed. tool_use: Metadata and inputs for the tool to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. @@ -308,7 +329,7 @@ async def _stream_with_trace( # pragma: no cover def _execute( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index da5c1ff10..216eee379 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent - from ...experimental.bidi import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -22,7 +21,7 @@ class ConcurrentToolExecutor(ToolExecutor): @override async def _execute( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -33,7 +32,7 @@ async def _execute( """Execute tools concurrently. Args: - agent: The agent (Agent or BidiAgent) for which tools are being executed. + agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. @@ -79,7 +78,7 @@ async def _execute( async def _task( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_use: ToolUse, tool_results: list[ToolResult], cycle_trace: Trace, @@ -94,7 +93,7 @@ async def _task( """Execute a single tool and put results in the task queue. Args: - agent: The agent (Agent or BidiAgent) executing the tool. + agent: The agent executing the tool. tool_use: Tool use metadata and inputs. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index 6163fc195..f78e60872 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -11,7 +11,6 @@ if TYPE_CHECKING: # pragma: no cover from ...agent import Agent - from ...experimental.bidi import BidiAgent from ..structured_output._structured_output_context import StructuredOutputContext @@ -21,7 +20,7 @@ class SequentialToolExecutor(ToolExecutor): @override async def _execute( self, - agent: "Agent | BidiAgent", + agent: "Agent", tool_uses: list[ToolUse], tool_results: list[ToolResult], cycle_trace: Trace, @@ -34,7 +33,7 @@ async def _execute( Breaks early if an interrupt is raised by the user. Args: - agent: The agent (Agent or BidiAgent) for which tools are being executed. + agent: The agent for which tools are being executed. tool_uses: Metadata and inputs for the tools to be executed. tool_results: List of tool results from each tool execution. cycle_trace: Trace object for the current event loop cycle. diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index 30a25312b..508042af0 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -42,7 +42,7 @@ def mock_agent(): agent.trace_span = None agent.trace_attributes = {} agent.tool_executor = Mock() - agent._append_message = AsyncMock() + agent._append_messages = AsyncMock() # Set up _interrupt_state properly agent._interrupt_state = Mock() @@ -186,8 +186,8 @@ async def test_event_loop_forces_structured_output_on_end_turn( await alist(stream) # Should have appended a message to force structured output - mock_agent._append_message.assert_called_once() - args = mock_agent._append_message.call_args[0][0] + mock_agent._append_messages.assert_called_once() + args = mock_agent._append_messages.call_args[0][0] assert args["role"] == "user" # Should have called recurse_event_loop with the context diff --git a/tests/strands/experimental/bidi/_async/test__init__.py b/tests/strands/experimental/bidi/_async/test__init__.py index f8df25e14..3c5339d08 100644 --- a/tests/strands/experimental/bidi/_async/test__init__.py +++ b/tests/strands/experimental/bidi/_async/test__init__.py @@ -1,3 +1,4 @@ +import traceback from unittest.mock import AsyncMock import pytest @@ -10,17 +11,19 @@ async def test_stop_exception(): func1 = AsyncMock() func2 = AsyncMock(side_effect=ValueError("stop 2 failed")) func3 = AsyncMock() + func4 = AsyncMock(side_effect=ValueError("stop 4 failed")) - with pytest.raises(ExceptionGroup) as exc_info: - await stop_all(func1, func2, func3) + with pytest.raises(RuntimeError, match=r"failed stop sequence") as exc_info: + await stop_all(func1, func2, func3, func4) func1.assert_called_once() func2.assert_called_once() func3.assert_called_once() + func4.assert_called_once() - assert len(exc_info.value.exceptions) == 1 - with pytest.raises(ValueError, match=r"stop 2 failed"): - raise exc_info.value.exceptions[0] + tru_tb = "".join(traceback.format_exception(RuntimeError, exc_info.value, exc_info.tb)) + assert "ValueError: stop 2 failed" in tru_tb + assert "ValueError: stop 4 failed" in tru_tb @pytest.mark.asyncio diff --git a/tests/strands/experimental/bidi/agent/__init__.py b/tests/strands/experimental/bidi/agent/__init__.py index 3359c6565..dd401a83d 100644 --- a/tests/strands/experimental/bidi/agent/__init__.py +++ b/tests/strands/experimental/bidi/agent/__init__.py @@ -1 +1 @@ -"""Bidirectional streaming agent tests.""" \ No newline at end of file +"""Bidirectional streaming agent tests.""" diff --git a/tests/strands/experimental/bidi/agent/test_agent.py b/tests/strands/experimental/bidi/agent/test_agent.py index 19d3525d7..29e5b7303 100644 --- a/tests/strands/experimental/bidi/agent/test_agent.py +++ b/tests/strands/experimental/bidi/agent/test_agent.py @@ -1,22 +1,24 @@ """Unit tests for BidiAgent.""" -import unittest.mock import asyncio -import pytest +import unittest.mock from uuid import uuid4 +import pytest + from strands.experimental.bidi.agent.agent import BidiAgent -from strands.experimental.bidi.models.nova_sonic import BidiNovaSonicModel +from strands.experimental.bidi.models import BidiModel from strands.experimental.bidi.types.events import ( - BidiTextInputEvent, BidiAudioInputEvent, BidiAudioStreamEvent, - BidiTranscriptStreamEvent, - BidiConnectionStartEvent, BidiConnectionCloseEvent, + BidiConnectionStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, ) -class MockBidiModel: + +class MockBidiModel(BidiModel): """Mock bidirectional model for testing.""" def __init__(self, config=None, model_id="mock-model"): @@ -46,14 +48,14 @@ async def receive(self): """Async generator yielding mock events.""" if not self._started: raise RuntimeError("model not started | call start before sending/receiving") - + # Yield connection start event yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) - + # Yield any configured events for event in self._events_to_yield: yield event - + # Yield connection end event yield BidiConnectionCloseEvent(connection_id=self._connection_id, reason="complete") @@ -61,11 +63,13 @@ def set_events(self, events): """Helper to set events this mock model will yield.""" self._events_to_yield = events + @pytest.fixture def mock_model(): """Create a mock BidiModel instance.""" return MockBidiModel() + @pytest.fixture def mock_tool_registry(): """Mock tool registry with some basic tools.""" @@ -73,15 +77,15 @@ def mock_tool_registry(): registry.get_all_tool_specs.return_value = [ { "name": "calculator", - "description": "Perform calculations", - "inputSchema": {"json": {"type": "object", "properties": {}}} + "description": "Perform calculations", + "inputSchema": {"json": {"type": "object", "properties": {}}}, } ] registry.get_all_tools_config.return_value = {"calculator": {}} return registry -@pytest.fixture +@pytest.fixture def mock_tool_caller(): """Mock tool caller for testing tool execution.""" caller = unittest.mock.AsyncMock() @@ -94,203 +98,194 @@ def agent(mock_model, mock_tool_registry, mock_tool_caller): """Create a BidiAgent instance for testing.""" with unittest.mock.patch("strands.experimental.bidi.agent.agent.ToolRegistry") as mock_registry_class: mock_registry_class.return_value = mock_tool_registry - + with unittest.mock.patch("strands.experimental.bidi.agent.agent._ToolCaller") as mock_caller_class: mock_caller_class.return_value = mock_tool_caller - + # Don't pass tools to avoid real tool loading agent = BidiAgent(model=mock_model) return agent + def test_bidi_agent_init_with_various_configurations(): """Test agent initialization with various configurations.""" # Test default initialization mock_model = MockBidiModel() agent = BidiAgent(model=mock_model) - + assert agent.model == mock_model assert agent.system_prompt is None assert not agent._started assert agent.model._connection_id is None - + # Test with configuration system_prompt = "You are a helpful assistant." - agent_with_config = BidiAgent( - model=mock_model, - system_prompt=system_prompt, - agent_id="test_agent" - ) - + agent_with_config = BidiAgent(model=mock_model, system_prompt=system_prompt, agent_id="test_agent") + assert agent_with_config.system_prompt == system_prompt assert agent_with_config.agent_id == "test_agent" - + # Test with string model ID - model_id = "amazon.nova-sonic-v1:0" - agent_with_string = BidiAgent(model=model_id) - - assert isinstance(agent_with_string.model, BidiNovaSonicModel) - assert agent_with_string.model.model_id == model_id - + # model_id = "amazon.nova-sonic-v1:0" + # agent_with_string = BidiAgent(model=model_id) + + # assert isinstance(agent_with_string.model, BidiNovaSonicModel) + # assert agent_with_string.model.model_id == model_id + # Test model config access config = agent.model.config assert config["audio"]["input_rate"] == 16000 assert config["audio"]["output_rate"] == 24000 assert config["audio"]["channels"] == 1 + @pytest.mark.asyncio async def test_bidi_agent_start_stop_lifecycle(agent): """Test agent start/stop lifecycle and state management.""" # Initial state assert not agent._started assert agent.model._connection_id is None - + # Start agent await agent.start() assert agent._started assert agent.model._connection_id is not None connection_id = agent.model._connection_id - + # Double start should error with pytest.raises(RuntimeError, match="agent already started"): await agent.start() - + # Stop agent await agent.stop() assert not agent._started assert agent.model._connection_id is None - + # Multiple stops should be safe await agent.stop() await agent.stop() - + # Restart should work with new connection ID await agent.start() assert agent._started assert agent.model._connection_id != connection_id + @pytest.mark.asyncio async def test_bidi_agent_send_with_input_types(agent): """Test sending various input types through agent.send().""" await agent.start() - + # Test text input with TypedEvent text_input = BidiTextInputEvent(text="Hello", role="user") await agent.send(text_input) assert len(agent.messages) == 1 assert agent.messages[0]["content"][0]["text"] == "Hello" - + # Test string input (shorthand) await agent.send("World") assert len(agent.messages) == 2 assert agent.messages[1]["content"][0]["text"] == "World" - + # Test audio input (doesn't add to messages) audio_input = BidiAudioInputEvent( audio="dGVzdA==", # base64 "test" format="pcm", sample_rate=16000, - channels=1 + channels=1, ) await agent.send(audio_input) assert len(agent.messages) == 2 # Still 2, audio doesn't add - + # Test concurrent sends - sends = [ - agent.send(BidiTextInputEvent(text=f"Message {i}", role="user")) - for i in range(3) - ] + sends = [agent.send(BidiTextInputEvent(text=f"Message {i}", role="user")) for i in range(3)] await asyncio.gather(*sends) assert len(agent.messages) == 5 # 2 + 3 new messages + @pytest.mark.asyncio async def test_bidi_agent_receive_events_from_model(agent): """Test receiving events from model.""" # Configure mock model to yield events events = [ - BidiAudioStreamEvent( - audio="dGVzdA==", - format="pcm", - sample_rate=24000, - channels=1 - ), + BidiAudioStreamEvent(audio="dGVzdA==", format="pcm", sample_rate=24000, channels=1), BidiTranscriptStreamEvent( text="Hello world", role="assistant", is_final=True, delta={"text": "Hello world"}, - current_transcript="Hello world" - ) + current_transcript="Hello world", + ), ] agent.model.set_events(events) - + await agent.start() - + received_events = [] async for event in agent.receive(): received_events.append(event) if len(received_events) >= 4: # Stop after getting expected events break - + # Verify event types and order assert len(received_events) >= 3 assert isinstance(received_events[0], BidiConnectionStartEvent) assert isinstance(received_events[1], BidiAudioStreamEvent) assert isinstance(received_events[2], BidiTranscriptStreamEvent) - + # Test empty events agent.model.set_events([]) await agent.stop() await agent.start() - + empty_events = [] async for event in agent.receive(): empty_events.append(event) if len(empty_events) >= 2: break - + assert len(empty_events) >= 1 assert isinstance(empty_events[0], BidiConnectionStartEvent) + def test_bidi_agent_tool_integration(agent, mock_tool_registry): """Test agent tool integration and properties.""" # Test tool property access - assert hasattr(agent, 'tool') + assert hasattr(agent, "tool") assert agent.tool is not None assert agent.tool == agent._tool_caller - + # Test tool names property - mock_tool_registry.get_all_tools_config.return_value = { - "calculator": {}, - "weather": {} - } - + mock_tool_registry.get_all_tools_config.return_value = {"calculator": {}, "weather": {}} + tool_names = agent.tool_names assert isinstance(tool_names, list) assert len(tool_names) == 2 assert "calculator" in tool_names assert "weather" in tool_names + @pytest.mark.asyncio async def test_bidi_agent_send_receive_error_before_start(agent): """Test error handling in various scenarios.""" # Test send before start with pytest.raises(RuntimeError, match="call start before"): await agent.send(BidiTextInputEvent(text="Hello", role="user")) - + # Test receive before start with pytest.raises(RuntimeError, match="call start before"): - async for event in agent.receive(): + async for _ in agent.receive(): pass - + # Test send after stop await agent.start() await agent.stop() with pytest.raises(RuntimeError, match="call start before"): await agent.send(BidiTextInputEvent(text="Hello", role="user")) - + # Test receive after stop with pytest.raises(RuntimeError, match="call start before"): - async for event in agent.receive(): + async for _ in agent.receive(): pass @@ -301,43 +296,44 @@ async def test_bidi_agent_start_receive_propagates_model_errors(): mock_model = MockBidiModel() mock_model.start = unittest.mock.AsyncMock(side_effect=Exception("Connection failed")) error_agent = BidiAgent(model=mock_model) - + with pytest.raises(Exception, match="Connection failed"): await error_agent.start() - + # Test model receive error mock_model2 = MockBidiModel() agent2 = BidiAgent(model=mock_model2) await agent2.start() - + async def failing_receive(): yield BidiConnectionStartEvent(connection_id="test", model="test-model") raise Exception("Receive failed") - + agent2.model.receive = failing_receive with pytest.raises(Exception, match="Receive failed"): - async for event in agent2.receive(): + async for _ in agent2.receive(): pass + @pytest.mark.asyncio async def test_bidi_agent_state_consistency(agent): """Test that agent state remains consistent across operations.""" # Initial state assert not agent._started assert agent.model._connection_id is None - + # Start await agent.start() assert agent._started assert agent.model._connection_id is not None connection_id = agent.model._connection_id - + # Send operations shouldn't change connection state await agent.send(BidiTextInputEvent(text="Hello", role="user")) assert agent._started assert agent.model._connection_id == connection_id - + # Stop await agent.stop() assert not agent._started - assert agent.model._connection_id is None \ No newline at end of file + assert agent.model._connection_id is None diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py index d19cada60..bd735a84c 100644 --- a/tests/strands/experimental/bidi/agent/test_loop.py +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -4,12 +4,15 @@ import pytest_asyncio from strands import tool -from strands.experimental.bidi.agent.loop import _BidiAgentLoop -from strands.experimental.bidi.models import BidiModelTimeoutError +from strands.experimental.bidi import BidiAgent + +# from strands.experimental.bidi.agent.loop import _BidiAgentLoop +from strands.experimental.bidi.models import BidiModel, BidiModelTimeoutError from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent -from strands.hooks import HookRegistry -from strands.tools.executors import SequentialToolExecutor -from strands.tools.registry import ToolRegistry + +# from strands.hooks import HookRegistry +# from strands.tools.executors import SequentialToolExecutor +# from strands.tools.registry import ToolRegistry from strands.types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent @@ -24,20 +27,24 @@ async def func(): @pytest.fixture def agent(time_tool): - mock = unittest.mock.Mock() - mock.hooks = HookRegistry() - mock.messages = [] - mock.model = unittest.mock.AsyncMock() - mock.tool_executor = SequentialToolExecutor() - mock.tool_registry = ToolRegistry() - mock.tool_registry.process_tools([time_tool]) - - return mock + # mock = unittest.mock.Mock() + # mock.hooks = HookRegistry() + # mock.messages = [] + # mock.model = unittest.mock.AsyncMock() + # mock.tool_executor = SequentialToolExecutor() + # mock.tool_registry = ToolRegistry() + # mock.tool_registry.process_tools([time_tool]) + # mock._append_messages = unittest.mock.AsyncMock() + # mock._message_lock = asyncio.Lock() + return BidiAgent(model=unittest.mock.AsyncMock(spec=BidiModel), tools=[time_tool]) + + # return mock @pytest_asyncio.fixture async def loop(agent): - return _BidiAgentLoop(agent) + return agent._loop + # return _BidiAgentLoop(agent) @pytest.mark.asyncio @@ -48,19 +55,19 @@ async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerato agent.model.receive = unittest.mock.Mock(side_effect=[timeout_error, agenerator([text_event])]) await loop.start() - + tru_events = [] async for event in loop.receive(): tru_events.append(event) if len(tru_events) >= 2: break - + exp_events = [ BidiConnectionRestartEvent(timeout_error), text_event, ] assert tru_events == exp_events - + agent.model.stop.assert_called_once() assert agent.model.start.call_count == 2 agent.model.start.assert_called_with( @@ -73,7 +80,6 @@ async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerato @pytest.mark.asyncio async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): - tool_use = {"toolUseId": "t1", "name": "time_tool", "input": {}} tool_result = {"toolUseId": "t1", "status": "success", "content": [{"text": "12:00"}]} @@ -81,9 +87,9 @@ async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): tool_result_event = ToolResultEvent(tool_result) agent.model.receive = unittest.mock.Mock(return_value=agenerator([tool_use_event])) - + await loop.start() - + tru_events = [] async for event in loop.receive(): tru_events.append(event) @@ -96,7 +102,7 @@ async def test_bidi_agent_loop_receive_tool_use(loop, agent, agenerator): ToolResultMessageEvent({"role": "user", "content": [{"toolResult": tool_result}]}), ] assert tru_events == exp_events - + tru_messages = agent.messages exp_messages = [ {"role": "assistant", "content": [{"toolUse": tool_use}]}, diff --git a/tests/strands/experimental/bidi/io/test_audio.py b/tests/strands/experimental/bidi/io/test_audio.py index 459faa78a..9b502700b 100644 --- a/tests/strands/experimental/bidi/io/test_audio.py +++ b/tests/strands/experimental/bidi/io/test_audio.py @@ -29,7 +29,7 @@ def agent(): "voice": "test-voice", }, } - return mock + return mock @pytest.fixture @@ -49,6 +49,7 @@ def config(): "output_frames_per_buffer": 2048, } + @pytest.fixture def audio_io(py_audio, config): _ = py_audio diff --git a/tests/strands/experimental/bidi/io/test_text.py b/tests/strands/experimental/bidi/io/test_text.py index 5507a8c0f..e21e149bd 100644 --- a/tests/strands/experimental/bidi/io/test_text.py +++ b/tests/strands/experimental/bidi/io/test_text.py @@ -42,7 +42,7 @@ async def test_bidi_text_io_input(prompt_session, text_input): (BidiInterruptionEvent(reason="user_speech"), "interrupted"), (BidiTranscriptStreamEvent(text="test text", delta="", is_final=False, role="user"), "Preview: test text"), (BidiTranscriptStreamEvent(text="test text", delta="", is_final=True, role="user"), "test text"), - ] + ], ) @pytest.mark.asyncio async def test_bidi_text_io_output(event, exp_print, text_output, capsys): diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index da516d4a0..6543dc4f2 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -13,8 +13,8 @@ import pytest from google.genai import types as genai_types -from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel +from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -185,7 +185,7 @@ async def test_connection_edge_cases(mock_genai_client, api_key, model_id): model4 = BidiGeminiLiveModel(model_id=model_id, client_config={"api_key": api_key}) await model4.start() mock_live_session_cm.__aexit__.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(RuntimeError, match=r"failed stop sequence"): await model4.stop() @@ -572,7 +572,6 @@ def test_tool_formatting(model, tool_spec): assert formatted_empty == [] - # Tool Result Content Tests @@ -601,7 +600,7 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key assert isinstance(audio_event, BidiAudioStreamEvent) # Should use configured rates, not constants assert audio_event.sample_rate == 48000 # Custom config - assert audio_event.channels == 2 # Custom config + assert audio_event.channels == 2 # Custom config assert audio_event.format == "pcm" await model.stop() @@ -631,7 +630,7 @@ async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_ke assert isinstance(audio_event, BidiAudioStreamEvent) # Should use default rates assert audio_event.sample_rate == 24000 # Default output rate - assert audio_event.channels == 1 # Default channels + assert audio_event.channels == 1 # Default channels assert audio_event.format == "pcm" await model.stop() diff --git a/tests/strands/experimental/bidi/models/test_nova_sonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py index 04f8043be..933fd2088 100644 --- a/tests/strands/experimental/bidi/models/test_nova_sonic.py +++ b/tests/strands/experimental/bidi/models/test_nova_sonic.py @@ -13,10 +13,10 @@ import pytest_asyncio from aws_sdk_bedrock_runtime.models import ModelTimeoutException, ValidationException +from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.models.nova_sonic import ( BidiNovaSonicModel, ) -from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, @@ -538,12 +538,12 @@ async def test_custom_audio_rates_in_events(model_id, region): audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") nova_event = {"audioOutput": {"content": audio_base64}} result = model._convert_nova_event(nova_event) - + assert result is not None assert isinstance(result, BidiAudioStreamEvent) # Should use configured rates, not constants assert result.sample_rate == 48000 # Custom config - assert result.channels == 2 # Custom config + assert result.channels == 2 # Custom config assert result.format == "pcm" @@ -558,12 +558,12 @@ async def test_default_audio_rates_in_events(model_id, region): audio_base64 = base64.b64encode(audio_bytes).decode("utf-8") nova_event = {"audioOutput": {"content": audio_base64}} result = model._convert_nova_event(nova_event) - + assert result is not None assert isinstance(result, BidiAudioStreamEvent) # Should use default rates assert result.sample_rate == 16000 # Default output rate - assert result.channels == 1 # Default channels + assert result.channels == 1 # Default channels assert result.format == "pcm" @@ -573,9 +573,9 @@ async def test_bidi_nova_sonic_model_receive_timeout(nova_model, mock_stream): mock_output = AsyncMock() mock_output.receive.side_effect = ModelTimeoutException("Connection timeout") mock_stream.await_output.return_value = (None, mock_output) - + await nova_model.start() - + with pytest.raises(BidiModelTimeoutError, match=r"Connection timeout"): async for _ in nova_model.receive(): pass @@ -586,9 +586,9 @@ async def test_bidi_nova_sonic_model_receive_timeout_validation(nova_model, mock mock_output = AsyncMock() mock_output.receive.side_effect = ValidationException("InternalErrorCode=531: Request timeout") mock_stream.await_output.return_value = (None, mock_output) - + await nova_model.start() - + with pytest.raises(BidiModelTimeoutError, match=r"InternalErrorCode=531"): async for _ in nova_model.receive(): pass diff --git a/tests/strands/experimental/bidi/models/test_openai_realtime.py b/tests/strands/experimental/bidi/models/test_openai_realtime.py index 5c9c0900d..5ab183da2 100644 --- a/tests/strands/experimental/bidi/models/test_openai_realtime.py +++ b/tests/strands/experimental/bidi/models/test_openai_realtime.py @@ -131,7 +131,9 @@ def test_audio_config_defaults(api_key, model_name): def test_audio_config_partial_override(api_key, model_name): """Test partial audio configuration override.""" provider_config = {"audio": {"output_rate": 48000, "voice": "echo"}} - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) # Overridden values assert model.config["audio"]["output_rate"] == 48000 @@ -154,7 +156,9 @@ def test_audio_config_full_override(api_key, model_name): "voice": "shimmer", } } - model = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config) + model = BidiOpenAIRealtimeModel( + model_id=model_name, client_config={"api_key": api_key}, provider_config=provider_config + ) assert model.config["audio"]["input_rate"] == 48000 assert model.config["audio"]["output_rate"] == 48000 @@ -349,7 +353,7 @@ async def async_connect(*args, **kwargs): model4 = BidiOpenAIRealtimeModel(model_id=model_name, client_config={"api_key": api_key}) await model4.start() mock_ws.close.side_effect = Exception("Close failed") - with pytest.raises(ExceptionGroup): + with pytest.raises(RuntimeError, match=r"failed stop sequence"): await model4.stop() @@ -510,7 +514,7 @@ async def test_receive_lifecycle_events(mock_websocket, model): format="pcm", sample_rate=24000, channels=1, - ) + ), ] assert tru_events == exp_events