Skip to content

Commit b693903

Browse files
committed
Merge branch 'main' of https://github.com/strands-agents/sdk-python into bidi-remove-311-features
2 parents 5421b56 + d1b523c commit b693903

File tree

9 files changed

+111
-50
lines changed

9 files changed

+111
-50
lines changed

src/strands/agent/agent.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -624,8 +624,7 @@ async def _run_loop(
624624
try:
625625
yield InitEventLoopEvent()
626626

627-
for message in messages:
628-
await self._append_message(message)
627+
await self._append_messages(*messages)
629628

630629
structured_output_context = StructuredOutputContext(
631630
structured_output_model or self._default_structured_output_model
@@ -715,7 +714,7 @@ async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
715714
tool_use_ids = [
716715
content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content
717716
]
718-
await self._append_message(
717+
await self._append_messages(
719718
{
720719
"role": "user",
721720
"content": generate_missing_tool_result_content(tool_use_ids),
@@ -811,10 +810,11 @@ def _initialize_system_prompt(
811810
else:
812811
return None, None
813812

814-
async def _append_message(self, message: Message) -> None:
815-
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
816-
self.messages.append(message)
817-
await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message))
813+
async def _append_messages(self, *messages: Message) -> None:
814+
"""Appends messages to history and invoke the callbacks for the MessageAddedEvent."""
815+
for message in messages:
816+
self.messages.append(message)
817+
await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message))
818818

819819
def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]:
820820
"""Redact user content preserving toolResult blocks.

src/strands/event_loop/event_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ async def event_loop_cycle(
230230
)
231231
structured_output_context.set_forced_mode()
232232
logger.debug("Forcing structured output tool")
233-
await agent._append_message(
233+
await agent._append_messages(
234234
{"role": "user", "content": [{"text": "You must format the previous response as structured output."}]}
235235
)
236236

src/strands/experimental/bidi/agent/agent.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
from ....tools.executors._executor import ToolExecutor
2727
from ....tools.registry import ToolRegistry
2828
from ....tools.watcher import ToolWatcher
29-
from ....types.content import Messages
29+
from ....types.content import Message, Messages
3030
from ....types.tools import AgentTool
31-
from ...hooks.events import BidiAgentInitializedEvent
31+
from ...hooks.events import BidiAgentInitializedEvent, BidiMessageAddedEvent
3232
from ...tools import ToolProvider
3333
from .._async import stop_all
3434
from ..models.model import BidiModel
@@ -167,6 +167,9 @@ def __init__(
167167
# TODO: Determine if full support is required
168168
self._interrupt_state = _InterruptState()
169169

170+
# Lock to ensure that paired messages are added to history in sequence without interference
171+
self._message_lock = asyncio.Lock()
172+
170173
self._started = False
171174

172175
@property
@@ -403,3 +406,17 @@ async def run_outputs(inputs_task: asyncio.Task) -> None:
403406
output_stops = [output.stop for output in outputs if isinstance(output, BidiOutput)]
404407

405408
await stop_all(*input_stops, *output_stops, self.stop)
409+
410+
async def _append_messages(self, *messages: Message) -> None:
411+
"""Append messages to history in sequence without interference.
412+
413+
The message lock ensures that paired messages are added to history in sequence without interference. For
414+
example, tool use and tool result messages must be added adjacent to each other.
415+
416+
Args:
417+
*messages: List of messages to add into history.
418+
"""
419+
async with self._message_lock:
420+
for message in messages:
421+
self.messages.append(message)
422+
await self.hooks.invoke_callbacks_async(BidiMessageAddedEvent(agent=self, message=message))

src/strands/experimental/bidi/agent/loop.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
BidiAfterInvocationEvent,
1616
BidiBeforeConnectionRestartEvent,
1717
BidiBeforeInvocationEvent,
18-
BidiMessageAddedEvent,
1918
)
2019
from ...hooks.events import (
2120
BidiInterruptionEvent as BidiInterruptionHookEvent,
@@ -51,8 +50,6 @@ class _BidiAgentLoop:
5150
that tools can access via their invocation_state parameter.
5251
_send_gate: Gate the sending of events to the model.
5352
Blocks when agent is reseting the model connection after timeout.
54-
_message_lock: Lock to ensure that paired messages are added to history in sequence without interference.
55-
For example, tool use and tool result messages must be added adjacent to each other.
5653
"""
5754

5855
def __init__(self, agent: "BidiAgent") -> None:
@@ -70,7 +67,6 @@ def __init__(self, agent: "BidiAgent") -> None:
7067
self._invocation_state: dict[str, Any]
7168

7269
self._send_gate = asyncio.Event()
73-
self._message_lock = asyncio.Lock()
7470

7571
async def start(self, invocation_state: dict[str, Any] | None = None) -> None:
7672
"""Start the agent loop.
@@ -145,7 +141,7 @@ async def send(self, event: BidiInputEvent | ToolResultEvent) -> None:
145141

146142
if isinstance(event, BidiTextInputEvent):
147143
message: Message = {"role": "user", "content": [{"text": event.text}]}
148-
await self._add_messages(message)
144+
await self._agent._append_messages(message)
149145

150146
await self._agent.model.send(event)
151147

@@ -224,7 +220,7 @@ async def _run_model(self) -> None:
224220
if isinstance(event, BidiTranscriptStreamEvent):
225221
if event["is_final"]:
226222
message: Message = {"role": event["role"], "content": [{"text": event["text"]}]}
227-
await self._add_messages(message)
223+
await self._agent._append_messages(message)
228224

229225
elif isinstance(event, ToolUseStreamEvent):
230226
tool_use = event["current_tool_use"]
@@ -282,7 +278,7 @@ async def _run_tool(self, tool_use: ToolUse) -> None:
282278

283279
tool_use_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]}
284280
tool_result_message: Message = {"role": "user", "content": [{"toolResult": tool_result_event.tool_result}]}
285-
await self._add_messages(tool_use_message, tool_result_message)
281+
await self._agent._append_messages(tool_use_message, tool_result_message)
286282

287283
await self._event_queue.put(ToolResultMessageEvent(tool_result_message))
288284

@@ -300,16 +296,3 @@ async def _run_tool(self, tool_use: ToolUse) -> None:
300296

301297
except Exception as error:
302298
await self._event_queue.put(error)
303-
304-
async def _add_messages(self, *messages: Message) -> None:
305-
"""Add messages to history in sequence without interference.
306-
307-
Args:
308-
*messages: List of messages to add into history.
309-
"""
310-
async with self._message_lock:
311-
for message in messages:
312-
self._agent.messages.append(message)
313-
await self._agent.hooks.invoke_callbacks_async(
314-
BidiMessageAddedEvent(agent=self._agent, message=message)
315-
)

src/strands/tools/_caller.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,7 @@ async def _record_tool_execution(
195195
}
196196

197197
# Add to message history
198-
await self._agent._append_message(user_msg)
199-
await self._agent._append_message(tool_use_msg)
200-
await self._agent._append_message(tool_result_msg)
201-
await self._agent._append_message(assistant_msg)
198+
await self._agent._append_messages(user_msg, tool_use_msg, tool_result_msg, assistant_msg)
202199

203200
def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]:
204201
"""Filter input parameters to only include those defined in the tool specification.

tests/strands/event_loop/test_event_loop_structured_output.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def mock_agent():
4242
agent.trace_span = None
4343
agent.trace_attributes = {}
4444
agent.tool_executor = Mock()
45-
agent._append_message = AsyncMock()
45+
agent._append_messages = AsyncMock()
4646

4747
# Set up _interrupt_state properly
4848
agent._interrupt_state = Mock()
@@ -186,8 +186,8 @@ async def test_event_loop_forces_structured_output_on_end_turn(
186186
await alist(stream)
187187

188188
# Should have appended a message to force structured output
189-
mock_agent._append_message.assert_called_once()
190-
args = mock_agent._append_message.call_args[0][0]
189+
mock_agent._append_messages.assert_called_once()
190+
args = mock_agent._append_messages.call_args[0][0]
191191
assert args["role"] == "user"
192192

193193
# Should have called recurse_event_loop with the context

tests/strands/experimental/bidi/agent/test_loop.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
import pytest_asyncio
55

66
from strands import tool
7+
from strands.experimental.bidi import BidiAgent
78
from strands.experimental.bidi.agent.loop import _BidiAgentLoop
89
from strands.experimental.bidi.models import BidiModelTimeoutError
910
from strands.experimental.bidi.types.events import BidiConnectionRestartEvent, BidiTextInputEvent
10-
from strands.hooks import HookRegistry
11-
from strands.tools.executors import SequentialToolExecutor
12-
from strands.tools.registry import ToolRegistry
1311
from strands.types._events import ToolResultEvent, ToolResultMessageEvent, ToolUseStreamEvent
1412

1513

@@ -24,20 +22,12 @@ async def func():
2422

2523
@pytest.fixture
2624
def agent(time_tool):
27-
mock = unittest.mock.Mock()
28-
mock.hooks = HookRegistry()
29-
mock.messages = []
30-
mock.model = unittest.mock.AsyncMock()
31-
mock.tool_executor = SequentialToolExecutor()
32-
mock.tool_registry = ToolRegistry()
33-
mock.tool_registry.process_tools([time_tool])
34-
35-
return mock
25+
return BidiAgent(model=unittest.mock.AsyncMock(), tools=[time_tool])
3626

3727

3828
@pytest_asyncio.fixture
3929
async def loop(agent):
40-
return _BidiAgentLoop(agent)
30+
return agent._loop
4131

4232

4333
@pytest.mark.asyncio

tests_integ/bidi/tools/__init__.py

Whitespace-only changes.
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import unittest.mock
2+
3+
import pytest
4+
5+
from strands import tool
6+
from strands.experimental.bidi.agent import BidiAgent
7+
8+
9+
@pytest.fixture
10+
def weather_tool():
11+
@tool(name="weather_tool")
12+
def func(city_name: str) -> str:
13+
return f"city_name=<{city_name}> | sunny"
14+
15+
return func
16+
17+
18+
@pytest.fixture
19+
def agent(weather_tool):
20+
return BidiAgent(record_direct_tool_call=True, tools=[weather_tool])
21+
22+
23+
def test_bidi_agent_tool_direct_call(agent):
24+
tru_result = agent.tool.weather_tool(city_name="new york")
25+
exp_result = {
26+
"content": [{"text": "city_name=<new york> | sunny"}],
27+
"status": "success",
28+
"toolUseId": unittest.mock.ANY,
29+
}
30+
assert tru_result == exp_result
31+
32+
tru_messages = agent.messages
33+
exp_messages = [
34+
{
35+
"content": [
36+
{
37+
"text": (
38+
"agent.tool.weather_tool direct tool call.\n"
39+
'Input parameters: {"city_name": "new york"}\n'
40+
),
41+
},
42+
],
43+
"role": "user",
44+
},
45+
{
46+
"content": [
47+
{
48+
"toolUse": {
49+
"input": {"city_name": "new york"},
50+
"name": "weather_tool",
51+
"toolUseId": unittest.mock.ANY,
52+
},
53+
},
54+
],
55+
"role": "assistant",
56+
},
57+
{
58+
"content": [
59+
{
60+
"toolResult": {
61+
"content": [{"text": "city_name=<new york> | sunny"}],
62+
"status": "success",
63+
"toolUseId": unittest.mock.ANY,
64+
},
65+
},
66+
],
67+
"role": "user",
68+
},
69+
{
70+
"content": [{"text": "agent.tool.weather_tool was called."}],
71+
"role": "assistant",
72+
},
73+
]
74+
assert tru_messages == exp_messages

0 commit comments

Comments
 (0)