Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 8 additions & 60 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,18 @@ a2a = [
"starlette>=0.46.2,<1.0.0",
]

bidi = [
"aws_sdk_bedrock_runtime; python_version>='3.12'",
bidi-io = [
"prompt_toolkit>=3.0.0,<4.0.0",
"pyaudio>=0.2.13,<1.0.0",
"smithy-aws-core>=0.0.1; python_version>='3.12'",
]
bidi-gemini = ["google-genai>=1.32.0,<2.0.0"]
bidi-nova = [
"aws_sdk_bedrock_runtime; python_version>='3.12'",
"smithy-aws-core>=0.0.1; python_version>='3.12'",
]
bidi-openai = ["websockets>=15.0.0,<16.0.0"]

all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"]
bidi-all = ["strands-agents[a2a,bidi,bidi-gemini,bidi-openai,docs,otel]"]
all = ["strands-agents[a2a,anthropic,bidi-io,bidi-gemini,bidi-openai,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"]

dev = [
"commitizen>=4.4.0,<5.0.0",
Expand Down Expand Up @@ -130,7 +131,7 @@ format-fix = [
]
lint-check = [
"ruff check",
"mypy ./src"
"mypy -p src"
]
lint-fix = [
"ruff check --fix"
Expand Down Expand Up @@ -204,16 +205,10 @@ warn_no_return = true
warn_unreachable = true
follow_untyped_imports = true
ignore_missing_imports = false
exclude = ["src/strands/experimental/bidi"]

[[tool.mypy.overrides]]
module = ["strands.experimental.bidi.*"]
follow_imports = "skip"

[tool.ruff]
line-length = 120
include = ["examples/**/*.py", "src/**/*.py", "tests/**/*.py", "tests_integ/**/*.py"]
exclude = ["src/strands/experimental/bidi/**/*.py", "tests/strands/experimental/bidi/**/*.py", "tests_integ/bidi/**/*.py"]

[tool.ruff.lint]
select = [
Expand All @@ -236,16 +231,14 @@ convention = "google"
[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_default_fixture_loop_scope = "function"
addopts = "--ignore=tests/strands/experimental/bidi --ignore=tests_integ/bidi"

addopts = "--ignore=tests/strands/experimental/bidi/models/test_nova_sonic.py --ignore=tests_integ/bidi"

[tool.coverage.run]
branch = true
source = ["src"]
context = "thread"
parallel = true
concurrency = ["thread", "multiprocessing"]
omit = ["src/strands/experimental/bidi/*"]

[tool.coverage.report]
show_missing = true
Expand Down Expand Up @@ -275,48 +268,3 @@ style = [
["text", ""],
["disabled", "fg:#858585 italic"]
]

# =========================
# Bidi development configs
# =========================

[tool.hatch.envs.bidi]
dev-mode = true
features = ["dev", "bidi-all"]
installer = "uv"

[tool.hatch.envs.bidi.scripts]
prepare = [
"hatch run bidi-lint:format-fix",
"hatch run bidi-lint:quality-fix",
"hatch run bidi-lint:type-check",
"hatch run bidi-test:test-cov",
]

[tools.hatch.envs.bidi-lint]
template = "bidi"

[tool.hatch.envs.bidi-lint.scripts]
format-check = "format-fix --check"
format-fix = "ruff format {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py"
quality-check = "ruff check {args} --target-version py312 ./src/strands/experimental/bidi/**/*.py"
quality-fix = "quality-check --fix"
type-check = "mypy {args} --python-version 3.12 ./src/strands/experimental/bidi/**/*.py"

[tool.hatch.envs.bidi-test]
template = "bidi"

[tool.hatch.envs.bidi-test.scripts]
test = "pytest {args} tests/strands/experimental/bidi"
test-cov = """
test \
--cov=strands.experimental.bidi \
--cov-config= \
--cov-branch \
--cov-report=term-missing \
--cov-report=xml:build/coverage/bidi-coverage.xml \
--cov-report=html:build/coverage/bidi-html
"""

[[tool.hatch.envs.bidi-test.matrix]]
python = ["3.13", "3.12"]
14 changes: 7 additions & 7 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,7 @@ async def _run_loop(
try:
yield InitEventLoopEvent()

for message in messages:
await self._append_message(message)
await self._append_messages(*messages)

structured_output_context = StructuredOutputContext(
structured_output_model or self._default_structured_output_model
Expand Down Expand Up @@ -715,7 +714,7 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
tool_use_ids = [
content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content
]
await self._append_message(
await self._append_messages(
{
"role": "user",
"content": generate_missing_tool_result_content(tool_use_ids),
Expand Down Expand Up @@ -811,10 +810,11 @@ def _initialize_system_prompt(
else:
return None, None

async def _append_message(self, message: Message) -> None:
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
self.messages.append(message)
await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message))
async def _append_messages(self, *messages: Message) -> None:
"""Appends messages to history and invoke the callbacks for the MessageAddedEvent."""
for message in messages:
self.messages.append(message)
await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message))

def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]:
"""Redact user content preserving toolResult blocks.
Expand Down
2 changes: 1 addition & 1 deletion src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ async def event_loop_cycle(
)
structured_output_context.set_forced_mode()
logger.debug("Forcing structured output tool")
await agent._append_message(
await agent._append_messages(
{"role": "user", "content": [{"text": "You must format the previous response as structured output."}]}
)

Expand Down
4 changes: 2 additions & 2 deletions src/strands/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This module implements experimental features that are subject to change in future revisions without notice.
"""

from . import steering, tools
from . import bidi, steering, tools
from .agent_config import config_to_agent

__all__ = ["config_to_agent", "tools", "steering"]
__all__ = ["bidi", "config_to_agent", "tools", "steering"]
8 changes: 0 additions & 8 deletions src/strands/experimental/bidi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
"""Bidirectional streaming package."""

import sys

if sys.version_info < (3, 12):
raise ImportError("bidi only supported for >= Python 3.12")

# Main components - Primary user interface
# Re-export standard agent events for tool handling
from ...types._events import (
Expand All @@ -19,7 +14,6 @@

# Model interface (for custom implementations)
from .models.model import BidiModel
from .models.nova_sonic import BidiNovaSonicModel

# Built-in tools
from .tools import stop_conversation
Expand Down Expand Up @@ -48,8 +42,6 @@
"BidiAgent",
# IO channels
"BidiAudioIO",
# Model providers
"BidiNovaSonicModel",
# Built-in tools
"stop_conversation",
# Input Event types
Expand Down
8 changes: 6 additions & 2 deletions src/strands/experimental/bidi/_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
funcs: Stop functions to call in sequence.

Raises:
ExceptionGroup: If any stop function raises an exception.
RuntimeError: If any stop function raises an exception.
"""
exceptions = []
for func in funcs:
Expand All @@ -26,4 +26,8 @@ async def stop_all(*funcs: Callable[..., Awaitable[None]]) -> None:
exceptions.append(exception)

if exceptions:
raise ExceptionGroup("failed stop sequence", exceptions)
exceptions.append(RuntimeError("failed stop sequence"))
for i in range(1, len(exceptions)):
exceptions[i].__cause__ = exceptions[i - 1]

raise exceptions[-1]
49 changes: 36 additions & 13 deletions src/strands/experimental/bidi/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,12 @@
from ....tools.executors._executor import ToolExecutor
from ....tools.registry import ToolRegistry
from ....tools.watcher import ToolWatcher
from ....types.content import Messages
from ....types.content import Message, Messages
from ....types.tools import AgentTool
from ...hooks.events import BidiAgentInitializedEvent
from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent
from ...tools import ToolProvider
from .._async import stop_all
from ..models.model import BidiModel
from ..models.nova_sonic import BidiNovaSonicModel
from ..types.agent import BidiAgentInput
from ..types.events import (
BidiAudioInputEvent,
Expand Down Expand Up @@ -100,13 +99,13 @@ def __init__(
ValueError: If model configuration is invalid or state is invalid type.
TypeError: If model type is unsupported.
"""
self.model = (
BidiNovaSonicModel()
if not model
else BidiNovaSonicModel(model_id=model)
if isinstance(model, str)
else model
)
if isinstance(model, BidiModel):
self.model = model
else:
from ..models.nova_sonic import BidiNovaSonicModel

self.model = BidiNovaSonicModel(model_id=model) if isinstance(model, str) else BidiNovaSonicModel()

self.system_prompt = system_prompt
self.messages = messages or []

Expand Down Expand Up @@ -167,6 +166,9 @@ def __init__(
# TODO: Determine if full support is required
self._interrupt_state = _InterruptState()

# Lock to ensure that paired messages are added to history in sequence without interference.
self._message_lock = asyncio.Lock()

self._started = False

@property
Expand Down Expand Up @@ -387,12 +389,33 @@ async def run_outputs(inputs_task: asyncio.Task) -> None:
for start in [*input_starts, *output_starts]:
await start(self)

async with asyncio.TaskGroup() as task_group:
inputs_task = task_group.create_task(run_inputs())
task_group.create_task(run_outputs(inputs_task))
inputs_task = asyncio.create_task(run_inputs())
outputs_task = asyncio.create_task(run_outputs(inputs_task))

try:
await asyncio.gather(inputs_task, outputs_task)
except (Exception, asyncio.CancelledError):
inputs_task.cancel()
outputs_task.cancel()
await asyncio.gather(inputs_task, outputs_task, return_exceptions=True)
raise

finally:
input_stops = [input_.stop for input_ in inputs if isinstance(input_, BidiInput)]
output_stops = [output.stop for output in outputs if isinstance(output, BidiOutput)]

await stop_all(*input_stops, *output_stops, self.stop)

async def _append_messages(self, *messages: Message) -> None:
"""Append messages to history in sequence without interference.

The message lock ensures that paired messages are added to history in sequence without interference. For
example, tool use and tool result messages must be added adjacent to each other.

Args:
*messages: List of messages to add into history.
"""
async with self._message_lock:
for message in messages:
self.messages.append(message)
await self.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self, message=message))
23 changes: 3 additions & 20 deletions src/strands/experimental/bidi/agent/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
BidiAfterInvocationEvent,
BidiBeforeConnectionRestartEvent,
BidiBeforeInvocationEvent,
BidiMessageAddedEvent,
)
from ...hooks.events import (
BidiInterruptionEvent as BidiInterruptionHookEvent,
Expand Down Expand Up @@ -51,8 +50,6 @@ class _BidiAgentLoop:
that tools can access via their invocation_state parameter.
_send_gate: Gate the sending of events to the model.
Blocks when agent is reseting the model connection after timeout.
_message_lock: Lock to ensure that paired messages are added to history in sequence without interference.
For example, tool use and tool result messages must be added adjacent to each other.
"""

def __init__(self, agent: "BidiAgent") -> None:
Expand All @@ -70,7 +67,6 @@ def __init__(self, agent: "BidiAgent") -> None:
self._invocation_state: dict[str, Any]

self._send_gate = asyncio.Event()
self._message_lock = asyncio.Lock()

async def start(self, invocation_state: dict[str, Any] | None = None) -> None:
"""Start the agent loop.
Expand Down Expand Up @@ -145,7 +141,7 @@ async def send(self, event: BidiInputEvent | ToolResultEvent) -> None:

if isinstance(event, BidiTextInputEvent):
message: Message = {"role": "user", "content": [{"text": event.text}]}
await self._add_messages(message)
await self._agent._append_messages(message)

await self._agent.model.send(event)

Expand Down Expand Up @@ -224,7 +220,7 @@ async def _run_model(self) -> None:
if isinstance(event, BidiTranscriptStreamEvent):
if event["is_final"]:
message: Message = {"role": event["role"], "content": [{"text": event["text"]}]}
await self._add_messages(message)
await self._agent._append_messages(message)

elif isinstance(event, ToolUseStreamEvent):
tool_use = event["current_tool_use"]
Expand Down Expand Up @@ -282,7 +278,7 @@ async def _run_tool(self, tool_use: ToolUse) -> None:

tool_use_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]}
tool_result_message: Message = {"role": "user", "content": [{"toolResult": tool_result_event.tool_result}]}
await self._add_messages(tool_use_message, tool_result_message)
await self._agent._append_messages(tool_use_message, tool_result_message)

await self._event_queue.put(ToolResultMessageEvent(tool_result_message))

Expand All @@ -300,16 +296,3 @@ async def _run_tool(self, tool_use: ToolUse) -> None:

except Exception as error:
await self._event_queue.put(error)

async def _add_messages(self, *messages: Message) -> None:
"""Add messages to history in sequence without interference.

Args:
*messages: List of messages to add into history.
"""
async with self._message_lock:
for message in messages:
self._agent.messages.append(message)
await self._agent.hooks.invoke_callbacks_async(
BidiMessageAddedEvent(agent=self._agent, message=message)
)
2 changes: 0 additions & 2 deletions src/strands/experimental/bidi/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Bidirectional model interfaces and implementations."""

from .model import BidiModel, BidiModelTimeoutError
from .nova_sonic import BidiNovaSonicModel

__all__ = [
"BidiModel",
"BidiModelTimeoutError",
"BidiNovaSonicModel",
]
3 changes: 2 additions & 1 deletion src/strands/experimental/bidi/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

import logging
from typing import Any, AsyncIterable, Protocol
from typing import Any, AsyncIterable, Protocol, runtime_checkable

from ....types._events import ToolResultEvent
from ....types.content import Messages
Expand All @@ -27,6 +27,7 @@
logger = logging.getLogger(__name__)


@runtime_checkable
class BidiModel(Protocol):
"""Protocol for bidirectional streaming models.

Expand Down
Loading
Loading