Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions src/aieng_bot/bookstack/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,17 +208,18 @@ async def ask_stream(
) -> AsyncGenerator[dict[str, Any], None]:
"""Answer a question, yielding structured SSE events as they occur.

Uses the Anthropic streaming API so final-answer text tokens are
forwarded to the client as they are generated.
Uses the Anthropic streaming API so text tokens are forwarded to the
client as they are generated.

Event types (dict with ``type`` key):

- ``{"type": "tool_use", "tool": "<name>", "input": {...}}``
— emitted before each tool call (clears any in-progress text in UI).
- ``{"type": "text_chunk", "chunk": "<text>"}``
— incremental text token from the current turn's response.
When a ``tool_use`` event follows, the UI should discard these
(they were planning/thinking text, not the final answer).
— incremental text token streamed in real time.
- ``{"type": "text_clear"}``
— the text streamed so far was reasoning/planning text that preceded
a tool call; the UI should discard it.
- ``{"type": "tool_use", "tool": "<name>", "input": {...}}``
— emitted before each tool call.
- ``{"type": "answer", "text": "<markdown>", "history": [...]}``
— emitted once at the end confirming the complete answer and updated
history. The caller must persist ``history`` for the next turn.
Expand All @@ -243,6 +244,7 @@ async def ask_stream(
try:
for _ in range(self.MAX_TURNS):
accumulated_text = ""
text_streamed = False
final_response: Any = None

async with self._async_client.messages.stream(
Expand All @@ -253,33 +255,38 @@ async def ask_stream(
messages=cast(list[MessageParam], messages),
) as stream:
async for event in stream:
# Accumulate text silently — we only forward it to the
# UI once we know this is a final-answer turn (no tool
# use). On-prem models like Qwen emit reasoning text
# before tool calls; streaming it and then clearing it
# is not reliable because the gateway may flush all
# text deltas before the tool-use block start event.
event_type = getattr(event, "type", None)

if (
getattr(event, "type", None) == "content_block_delta"
event_type == "content_block_delta"
and getattr(getattr(event, "delta", None), "type", None)
== "text_delta"
):
chunk: str = event.delta.text # type: ignore[union-attr]
accumulated_text += chunk
yield {"type": "text_chunk", "chunk": chunk}
text_streamed = True

elif event_type == "content_block_start":
block = getattr(event, "content_block", None)
if (
getattr(block, "type", None) == "tool_use"
and text_streamed
):
# Reasoning/planning text preceded this tool call — discard it
yield {"type": "text_clear"}
accumulated_text = ""
text_streamed = False

final_response = await stream.get_final_message()

tool_uses = [b for b in final_response.content if b.type == "tool_use"]

if not tool_uses:
# Final answer — no tool use, so it is safe to surface the
# accumulated text. Stream it chunk-by-chunk so the UI
# still renders progressively.
# Final answer — text was already streamed via text_chunk events.
answer = accumulated_text.strip() or self._extract_text(
final_response
)
for char in answer:
yield {"type": "text_chunk", "chunk": char}
messages.append({"role": "assistant", "content": answer})
yield {"type": "answer", "text": answer, "history": messages}
return
Expand All @@ -295,7 +302,6 @@ async def ask_stream(
tool_results: list[dict[str, Any]] = []
for tu in tool_uses:
ti = dict(tu.input) if isinstance(tu.input, dict) else {}
# Signal UI to clear any in-progress text and show tool status
yield {"type": "tool_use", "tool": tu.name, "input": ti}
result = await asyncio.to_thread(
execute_tool, tu.name, ti, self.bookstack
Expand Down
49 changes: 49 additions & 0 deletions tests/bookstack/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ def _make_text_delta_event(text: str) -> MagicMock:
return event


def _make_tool_use_block_start_event(name: str) -> MagicMock:
"""Build a fake content_block_start event with a tool_use block."""
block = MagicMock()
block.type = "tool_use"
block.name = name

event = MagicMock()
event.type = "content_block_start"
event.content_block = block
return event


def _make_stream_ctx(
events: list[MagicMock],
final_message: MagicMock,
Expand Down Expand Up @@ -318,3 +330,40 @@ async def test_stream_no_history_starts_fresh(
answer_event = next(e for e in events if e["type"] == "answer")
history = answer_event["history"]
assert history[0] == {"role": "user", "content": "Fresh start?"}

@pytest.mark.asyncio
async def test_stream_text_clear_emitted_when_text_precedes_tool_call(
    self, agent: BookstackQAAgent
) -> None:
    """Text streamed ahead of a tool call must be followed by ``text_clear``."""
    # First stream: one text delta, then a tool_use block start; the final
    # message for this stream contains a single tool_use block.
    first_turn = _make_stream_ctx(
        [
            _make_text_delta_event("Let me search for that."),
            _make_tool_use_block_start_event("search_bookstack"),
        ],
        _make_sync_response(
            [_make_tool_use_block("search_bookstack", "tu_1", {"query": "policy"})]
        ),
    )

    # Second stream: no events are replayed; the final message holds only a
    # text block.
    second_turn = _make_stream_ctx(
        [],
        _make_sync_response([_make_text_block("The policy says…")]),
    )

    # Each call to messages.stream() hands back the next prepared context.
    agent._async_client.messages.stream.side_effect = [first_turn, second_turn]  # type: ignore[attr-defined]

    canned_tool_output = json.dumps({"data": [], "total": 0})
    with patch(
        "aieng_bot.bookstack.agent.execute_tool",
        return_value=canned_tool_output,
    ):
        collected = [
            evt async for evt in agent.ask_stream("What is the leave policy?")
        ]

    kinds = [evt["type"] for evt in collected]

    # The streamed text is forwarded first ...
    assert kinds[0] == "text_chunk"
    # ... then a text_clear is emitted for it ...
    assert "text_clear" in kinds
    clear_pos = kinds.index("text_clear")
    # ... the tool_use event does not come before that text_clear ...
    assert "tool_use" in kinds[clear_pos:]
    # ... and the stream finishes with the answer event.
    assert kinds[-1] == "answer"