Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 145 additions & 75 deletions src/aieng_bot/bookstack/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@
# MessageParam[content] constraints; cast to list[MessageParam] at call sites.
MessageHistory = list[Any]

# Thinking models (Qwen3, DeepSeek-R1, …) embed chain-of-thought reasoning in
# the text stream before this marker. Text before it is buffered silently;
# only what follows is forwarded as answer content.
# The marker is looked up in the accumulated buffer rather than in each
# delta, so it is still found when it arrives split across stream chunks.
# Output that never contains the marker stays buffered until the turn ends.
_THINK_END = "</think>"


class _TurnState:
    """Scratch state shared between one streaming LLM call and its caller.

    The streaming helper mutates the fields while events arrive; the caller
    reads them once the stream has been drained.
    """

    __slots__ = (
        "accumulated_text",
        "thinking_done",
        "text_streamed",
        "final_response",
    )

    def __init__(self) -> None:
        # Final message object of the stream; stays None until it is fetched.
        self.final_response: Any = None
        # True while text_chunk events have been emitted and not yet cleared.
        self.text_streamed: bool = False
        # True once the ``</think>`` marker has been found in the text.
        self.thinking_done: bool = False
        # Text gathered so far: everything received until the marker shows
        # up, afterwards only the text that follows it.
        self.accumulated_text: str = ""


class BookstackQAAgent:
"""Answer questions from the BookStack wiki using Claude with tool use.
Expand Down Expand Up @@ -201,28 +218,125 @@ def ask(
# Async streaming path (API)
# ------------------------------------------------------------------

async def _stream_llm_turn(
    self,
    messages: MessageHistory,
    state: _TurnState,
) -> AsyncGenerator[dict[str, Any], None]:
    """Drive a single streaming model call, surfacing only answer text.

    Text deltas are held back until the ``</think>`` marker has been seen;
    text after the marker is yielded as ``text_chunk`` events. When a
    ``tool_use`` block starts, a ``text_clear`` event is yielded if text was
    already streamed, and the per-turn text state is reset.

    On completion *state* holds the text gathered since the last reset,
    whether ``</think>`` was seen, whether streamed text is still
    outstanding, and the final message returned by the stream.
    """
    # Armed right after the marker when nothing but newlines followed it, so
    # the newlines opening the next delta(s) are dropped as well.
    trim_newlines = False

    async with self._async_client.messages.stream(
        model=self.model,
        max_tokens=8192,
        system=SYSTEM_PROMPT,
        tools=ALL_TOOLS,
        messages=cast(list[MessageParam], messages),
    ) as stream:
        async for event in stream:
            kind = getattr(event, "type", None)

            if kind == "content_block_start":
                started = getattr(event, "content_block", None)
                if getattr(started, "type", None) != "tool_use":
                    continue
                # Anything already shown to the client preceded a tool call
                # and is not the answer: tell the UI to drop it.
                if state.text_streamed:
                    yield {"type": "text_clear"}
                state.text_streamed = False
                state.accumulated_text = ""
                state.thinking_done = False
                trim_newlines = False
                continue

            if kind != "content_block_delta":
                continue
            if getattr(getattr(event, "delta", None), "type", None) != "text_delta":
                continue

            piece: str = event.delta.text  # type: ignore[union-attr]

            if state.thinking_done:
                # Past the marker: forward text as it arrives.
                if trim_newlines:
                    piece = piece.lstrip("\n")
                    if not piece:
                        continue
                    trim_newlines = False
                state.accumulated_text += piece
                yield {"type": "text_chunk", "chunk": piece}
                state.text_streamed = True
                continue

            # Still before the marker: buffer, then look for the marker in
            # the whole buffer so a marker split across deltas is detected.
            state.accumulated_text += piece
            _, marker, tail = state.accumulated_text.partition(_THINK_END)
            if not marker:
                continue

            state.thinking_done = True
            answer_start = tail.lstrip("\n")
            # From here on the buffer only carries post-marker text.
            state.accumulated_text = answer_start
            trim_newlines = not answer_start
            if answer_start:
                yield {"type": "text_chunk", "chunk": answer_start}
                state.text_streamed = True

        state.final_response = await stream.get_final_message()

async def _execute_tool_calls(
    self,
    tool_uses: list[Any],
) -> AsyncGenerator[dict[str, Any], None]:
    """Execute tool calls sequentially; yield ``tool_use`` / ``tool_resolve`` events.

    For every entry of *tool_uses* a ``tool_use`` event is yielded, then the
    tool is run in a worker thread. For ``get_page`` calls whose result is a
    JSON object with a non-empty ``name``, a ``tool_resolve`` event carrying
    the page title is yielded as well.

    After iterating, read ``self._tool_results`` for the list of
    ``tool_result`` dicts to append to the message history.
    """
    # Collect into a list owned by this call. Appending through the
    # attribute would write into another call's list if an overlapping call
    # on the same agent rebinds it while this one is suspended.
    results: list[dict[str, Any]] = []
    self._tool_results: list[dict[str, Any]] = results
    for tu in tool_uses:
        ti = dict(tu.input) if isinstance(tu.input, dict) else {}
        yield {"type": "tool_use", "tool": tu.name, "input": ti}
        result = await asyncio.to_thread(execute_tool, tu.name, ti, self.bookstack)
        if tu.name == "get_page":
            # Title resolution is best effort: a result that is not a JSON
            # object must not abort the turn.
            try:
                page_data = json.loads(result)
            except (TypeError, ValueError):
                page_data = None
            page_title = ""
            if isinstance(page_data, dict):
                page_title = str(page_data.get("name") or "")
            # Yield outside the ``try`` so exceptions raised by the consumer
            # are never swallowed by the handler above.
            if page_title:
                yield {
                    "type": "tool_resolve",
                    "page_id": ti.get("page_id"),
                    "page_title": page_title,
                }
        results.append(
            {"type": "tool_result", "tool_use_id": tu.id, "content": result}
        )
    # Re-publish right before finishing: nothing suspends between this
    # assignment and the caller's read, so the caller sees this call's list.
    self._tool_results = results

async def ask_stream(
self,
question: str,
history: MessageHistory | None = None,
) -> AsyncGenerator[dict[str, Any], None]:
"""Answer a question, yielding structured SSE events as they occur.

Uses the Anthropic streaming API so text tokens are forwarded to the
client as they are generated.
Handles thinking models (e.g. Qwen3) that embed reasoning inside
``</think>`` tags in the regular text stream. Text before ``</think>``
is silently discarded; text after it streams token-by-token. Models
that never emit ``</think>`` (Claude, GPT) have their text buffered
and emitted as fast chunks once the response is complete.

Event types (dict with ``type`` key):

- ``{"type": "text_chunk", "chunk": "<text>"}``
— incremental text token streamed in real time.
— incremental text token (post-think, or burst-emit for non-thinking
models).
- ``{"type": "text_clear"}``
the text streamed so far was reasoning/planning text that preceded
a tool call; the UI should discard it.
discard streamed text; only emitted if post-think text was already
streamed and a tool call follows.
- ``{"type": "tool_use", "tool": "<name>", "input": {...}}``
— emitted before each tool call.
- ``{"type": "answer", "text": "<markdown>", "history": [...]}``
— emitted once at the end confirming the complete answer and updated
history. The caller must persist ``history`` for the next turn.
— final answer; caller must persist ``history`` for the next turn.
- ``{"type": "error", "message": "<msg>"}``

Parameters
Expand All @@ -243,88 +357,44 @@ async def ask_stream(

try:
for _ in range(self.MAX_TURNS):
accumulated_text = ""
text_streamed = False
final_response: Any = None

async with self._async_client.messages.stream(
model=self.model,
max_tokens=8192,
system=SYSTEM_PROMPT,
tools=ALL_TOOLS,
messages=cast(list[MessageParam], messages),
) as stream:
async for event in stream:
event_type = getattr(event, "type", None)

if (
event_type == "content_block_delta"
and getattr(getattr(event, "delta", None), "type", None)
== "text_delta"
):
chunk: str = event.delta.text # type: ignore[union-attr]
accumulated_text += chunk
yield {"type": "text_chunk", "chunk": chunk}
text_streamed = True

elif event_type == "content_block_start":
block = getattr(event, "content_block", None)
if (
getattr(block, "type", None) == "tool_use"
and text_streamed
):
# Reasoning/planning text preceded this tool call — discard it
yield {"type": "text_clear"}
accumulated_text = ""
text_streamed = False

final_response = await stream.get_final_message()
state = _TurnState()
async for event in self._stream_llm_turn(messages, state):
yield event

final_response = state.final_response
tool_uses = [b for b in final_response.content if b.type == "tool_use"]

if not tool_uses:
# Final answer — text was already streamed via text_chunk events.
answer = accumulated_text.strip() or self._extract_text(
final_response
)
if state.thinking_done:
answer = state.accumulated_text.strip() or self._extract_text(
final_response
)
else:
# Non-thinking model: burst-emit buffer in small chunks.
raw = state.accumulated_text or self._extract_text(
final_response
)
answer = raw.strip()
chunk_size = 20
for i in range(0, len(answer), chunk_size):
yield {
"type": "text_chunk",
"chunk": answer[i : i + chunk_size],
}
await asyncio.sleep(0)
messages.append({"role": "assistant", "content": answer})
yield {"type": "answer", "text": answer, "history": messages}
return

# Tool-use turn: persist content and execute tools
messages.append(
{
"role": "assistant",
"content": self._content_from_response(final_response),
}
)

tool_results: list[dict[str, Any]] = []
for tu in tool_uses:
ti = dict(tu.input) if isinstance(tu.input, dict) else {}
yield {"type": "tool_use", "tool": tu.name, "input": ti}
result = await asyncio.to_thread(
execute_tool, tu.name, ti, self.bookstack
)
# For get_page, emit the resolved page title so the UI can
# display it instead of the raw numeric ID.
if tu.name == "get_page":
try:
page_data = json.loads(result)
page_title = str(page_data.get("name") or "")
if page_title:
yield {
"type": "tool_resolve",
"page_id": ti.get("page_id"),
"page_title": page_title,
}
except (json.JSONDecodeError, KeyError, ValueError):
pass
tool_results.append(
{"type": "tool_result", "tool_use_id": tu.id, "content": result}
)

messages.append({"role": "user", "content": tool_results})
async for event in self._execute_tool_calls(tool_uses):
yield event
messages.append({"role": "user", "content": self._tool_results})

yield {
"type": "error",
Expand Down
80 changes: 67 additions & 13 deletions tests/bookstack/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,19 +332,19 @@ async def test_stream_no_history_starts_fresh(
assert history[0] == {"role": "user", "content": "Fresh start?"}

@pytest.mark.asyncio
async def test_stream_text_clear_emitted_when_text_precedes_tool_call(
async def test_stream_thinking_text_buffered_not_emitted_before_tool_call(
self, agent: BookstackQAAgent
) -> None:
"""text_clear is emitted when reasoning text appears before a tool call."""
# Turn 1: reasoning text streamed, then a tool_use block starts
text_event = _make_text_delta_event("Let me search for that.")
"""Text before </think> is buffered silently — no text_chunk or text_clear emitted."""
# Qwen3 pattern: turn 1 has thinking text + </think> + (empty) + tool_use block
think_event = _make_text_delta_event("Let me search.\n</think>\n\n")
tool_start_event = _make_tool_use_block_start_event("search_bookstack")
tool_final = _make_sync_response(
[_make_tool_use_block("search_bookstack", "tu_1", {"query": "policy"})]
)
ctx1 = _make_stream_ctx([text_event, tool_start_event], tool_final)
ctx1 = _make_stream_ctx([think_event, tool_start_event], tool_final)

# Turn 2: actual answer
# Turn 2: actual answer (no thinking)
answer_final = _make_sync_response([_make_text_block("The policy says…")])
ctx2 = _make_stream_ctx([], answer_final)

Expand All @@ -359,11 +359,65 @@ async def test_stream_text_clear_emitted_when_text_precedes_tool_call(
events.append(evt)

types = [e["type"] for e in events]
# text_chunk from the reasoning text
assert types[0] == "text_chunk"
# text_clear follows to discard the reasoning text
assert "text_clear" in types
text_clear_idx = types.index("text_clear")
# tool_use comes after text_clear
assert "tool_use" in types[text_clear_idx:]
# Thinking text must never appear as a text_chunk
assert "text_clear" not in types, (
"no text was streamed so text_clear is unnecessary"
)
assert "tool_use" in types
assert types[-1] == "answer"

@pytest.mark.asyncio
async def test_stream_thinking_model_final_answer_streams_post_think(
    self, agent: BookstackQAAgent
) -> None:
    """Reasoning ahead of ``</think>`` is never surfaced; the text after it is."""
    # The first delta closes the reasoning section, the second is the answer.
    stream_events = [
        _make_text_delta_event("Thinking...\n</think>\n\n"),
        _make_text_delta_event("Paris."),
    ]
    final_msg = _make_sync_response([_make_text_block("Paris.")])
    stream_ctx = _make_stream_ctx(stream_events, final_msg)
    agent._async_client.messages.stream.return_value = stream_ctx  # type: ignore[attr-defined]

    events = [evt async for evt in agent.ask_stream("Capital of France?")]

    event_types = [evt["type"] for evt in events]
    emitted = [evt["chunk"] for evt in events if evt["type"] == "text_chunk"]
    # No emitted chunk may carry reasoning text or the marker itself.
    assert not any("think" in text.lower() for text in emitted)
    assert "Paris." in emitted
    assert event_types[-1] == "answer"
    assert events[-1]["text"] == "Paris."

@pytest.mark.asyncio
async def test_stream_text_clear_emitted_when_post_think_text_precedes_tool_call(
    self, agent: BookstackQAAgent
) -> None:
    """text_clear fires only when text *after* </think> was already streamed.

    Besides the presence of each event, the order is pinned: the bridge text
    is streamed first, then cleared, and only then is the tool call reported.
    """
    # Model generates thinking + </think> + "Let me search." + tool_use
    think_event = _make_text_delta_event("Thinking...\n</think>\n\n")
    bridge_event = _make_text_delta_event("Let me search.")
    tool_start = _make_tool_use_block_start_event("search_bookstack")
    tool_final = _make_sync_response(
        [_make_tool_use_block("search_bookstack", "tu_1", {"query": "policy"})]
    )
    ctx1 = _make_stream_ctx([think_event, bridge_event, tool_start], tool_final)

    # Second turn: plain answer without any thinking section.
    answer_final = _make_sync_response([_make_text_block("The policy says…")])
    ctx2 = _make_stream_ctx([], answer_final)
    agent._async_client.messages.stream.side_effect = [ctx1, ctx2]  # type: ignore[attr-defined]

    with patch(
        "aieng_bot.bookstack.agent.execute_tool",
        return_value=json.dumps({"data": [], "total": 0}),
    ):
        events = []
        async for evt in agent.ask_stream("What is the leave policy?"):
            events.append(evt)

    types = [e["type"] for e in events]
    assert "text_chunk" in types  # "Let me search." was streamed
    assert "text_clear" in types  # then discarded when tool_use detected
    assert "tool_use" in types
    # Membership alone would accept a mis-ordered stream; pin the sequence.
    first_chunk_idx = types.index("text_chunk")
    text_clear_idx = types.index("text_clear")
    assert first_chunk_idx < text_clear_idx
    assert "tool_use" in types[text_clear_idx:]
    assert "tool_use" not in types[:text_clear_idx]
    assert types[-1] == "answer"