diff --git a/src/aieng_bot/bookstack/agent.py b/src/aieng_bot/bookstack/agent.py
index 2126197..05ef081 100644
--- a/src/aieng_bot/bookstack/agent.py
+++ b/src/aieng_bot/bookstack/agent.py
@@ -21,6 +21,23 @@
 # MessageParam[content] constraints; cast to list[MessageParam] at call sites.
 MessageHistory = list[Any]
 
+# Thinking models (Qwen3, DeepSeek-R1, …) embed chain-of-thought reasoning in
+# the text stream before this marker. Text before it is buffered silently;
+# only what follows is forwarded as answer content.
+_THINK_END = "</think>"
+
+
+class _TurnState:
+    """Mutable container for the results of one LLM streaming turn."""
+
+    __slots__ = ("accumulated_text", "thinking_done", "text_streamed", "final_response")
+
+    def __init__(self) -> None:
+        self.accumulated_text: str = ""
+        self.thinking_done: bool = False
+        self.text_streamed: bool = False
+        self.final_response: Any = None
+
 
 class BookstackQAAgent:
     """Answer questions from the BookStack wiki using Claude with tool use.
@@ -201,6 +218,100 @@ def ask(
     # Async streaming path (API)
     # ------------------------------------------------------------------
 
+    async def _stream_llm_turn(
+        self,
+        messages: MessageHistory,
+        state: _TurnState,
+    ) -> AsyncGenerator[dict[str, Any], None]:
+        """Run one LLM streaming call; yield ``text_chunk``/``text_clear`` events.
+
+        Populates *state* with the accumulated answer text, whether ``</think>``
+        was seen, and the final message object.
+        """
+        skip_leading_nl = False
+
+        async with self._async_client.messages.stream(
+            model=self.model,
+            max_tokens=8192,
+            system=SYSTEM_PROMPT,
+            tools=ALL_TOOLS,
+            messages=cast(list[MessageParam], messages),
+        ) as stream:
+            async for event in stream:
+                event_type = getattr(event, "type", None)
+
+                if (
+                    event_type == "content_block_delta"
+                    and getattr(getattr(event, "delta", None), "type", None)
+                    == "text_delta"
+                ):
+                    chunk: str = event.delta.text  # type: ignore[union-attr]
+                    if not state.thinking_done:
+                        state.accumulated_text += chunk
+                        idx = state.accumulated_text.find(_THINK_END)
+                        if idx >= 0:
+                            state.thinking_done = True
+                            skip_leading_nl = True
+                            post = state.accumulated_text[
+                                idx + len(_THINK_END) :
+                            ].lstrip("\n")
+                            state.accumulated_text = post
+                            if post:
+                                skip_leading_nl = False
+                                yield {"type": "text_chunk", "chunk": post}
+                                state.text_streamed = True
+                    else:
+                        if skip_leading_nl:
+                            chunk = chunk.lstrip("\n")
+                            if not chunk:
+                                continue
+                            skip_leading_nl = False
+                        state.accumulated_text += chunk
+                        yield {"type": "text_chunk", "chunk": chunk}
+                        state.text_streamed = True
+
+                elif event_type == "content_block_start":
+                    block = getattr(event, "content_block", None)
+                    if getattr(block, "type", None) == "tool_use":
+                        if state.text_streamed:
+                            yield {"type": "text_clear"}
+                            state.text_streamed = False
+                        state.accumulated_text = ""
+                        state.thinking_done = False
+                        skip_leading_nl = False
+
+            state.final_response = await stream.get_final_message()
+
+    async def _execute_tool_calls(
+        self,
+        tool_uses: list[Any],
+    ) -> AsyncGenerator[dict[str, Any], None]:
+        """Execute tool calls sequentially; yield ``tool_use`` / ``tool_resolve`` events.
+
+        After iterating, read ``self._tool_results`` for the list of
+        ``tool_result`` dicts to append to the message history.
+        """
+        self._tool_results: list[dict[str, Any]] = []
+        for tu in tool_uses:
+            ti = dict(tu.input) if isinstance(tu.input, dict) else {}
+            yield {"type": "tool_use", "tool": tu.name, "input": ti}
+            result = await asyncio.to_thread(execute_tool, tu.name, ti, self.bookstack)
+            if tu.name == "get_page":
+                try:
+                    page_data = json.loads(result)
+                    page_title = str(page_data.get("name") or "")
+                    if page_title:
+                        yield {
+                            "type": "tool_resolve",
+                            "page_id": ti.get("page_id"),
+                            "page_title": page_title,
+                        }
+                except (json.JSONDecodeError, KeyError, ValueError):
+                    pass
+            self._tool_results.append(
+                {"type": "tool_result", "tool_use_id": tu.id, "content": result}
+            )
+
     async def ask_stream(
         self,
         question: str,
@@ -208,21 +319,24 @@ async def ask_stream(
     ) -> AsyncGenerator[dict[str, Any], None]:
         """Answer a question, yielding structured SSE events as they occur.
 
-        Uses the Anthropic streaming API so text tokens are forwarded to the
-        client as they are generated.
+        Handles thinking models (e.g. Qwen3) that embed reasoning inside
+        ``<think>`` tags in the regular text stream. Text before ``</think>``
+        is silently discarded; text after it streams token-by-token. Models
+        that never emit ``</think>`` (Claude, GPT) have their text buffered
+        and emitted as fast chunks once the response is complete.
 
         Event types (dict with ``type`` key):
 
         - ``{"type": "text_chunk", "chunk": ""}``
-          — incremental text token streamed in real time.
+          — incremental text token (post-think, or burst-emit for non-thinking
+          models).
         - ``{"type": "text_clear"}``
-          — the text streamed so far was reasoning/planning text that preceded
-          a tool call; the UI should discard it.
+          — discard streamed text; only emitted if post-think text was already
+          streamed and a tool call follows.
         - ``{"type": "tool_use", "tool": "", "input": {...}}``
           — emitted before each tool call.
         - ``{"type": "answer", "text": "", "history": [...]}``
-          — emitted once at the end confirming the complete answer and updated
-          history. The caller must persist ``history`` for the next turn.
+          — final answer; caller must persist ``history`` for the next turn.
         - ``{"type": "error", "message": ""}``
 
         Parameters
@@ -243,88 +357,44 @@ async def ask_stream(
 
         try:
             for _ in range(self.MAX_TURNS):
-                accumulated_text = ""
-                text_streamed = False
-                final_response: Any = None
-
-                async with self._async_client.messages.stream(
-                    model=self.model,
-                    max_tokens=8192,
-                    system=SYSTEM_PROMPT,
-                    tools=ALL_TOOLS,
-                    messages=cast(list[MessageParam], messages),
-                ) as stream:
-                    async for event in stream:
-                        event_type = getattr(event, "type", None)
-
-                        if (
-                            event_type == "content_block_delta"
-                            and getattr(getattr(event, "delta", None), "type", None)
-                            == "text_delta"
-                        ):
-                            chunk: str = event.delta.text  # type: ignore[union-attr]
-                            accumulated_text += chunk
-                            yield {"type": "text_chunk", "chunk": chunk}
-                            text_streamed = True
-
-                        elif event_type == "content_block_start":
-                            block = getattr(event, "content_block", None)
-                            if (
-                                getattr(block, "type", None) == "tool_use"
-                                and text_streamed
-                            ):
-                                # Reasoning/planning text preceded this tool call — discard it
-                                yield {"type": "text_clear"}
-                                accumulated_text = ""
-                                text_streamed = False
-
-                    final_response = await stream.get_final_message()
+                state = _TurnState()
+                async for event in self._stream_llm_turn(messages, state):
+                    yield event
+                final_response = state.final_response
 
                 tool_uses = [b for b in final_response.content if b.type == "tool_use"]
 
                 if not tool_uses:
-                    # Final answer — text was already streamed via text_chunk events.
-                    answer = accumulated_text.strip() or self._extract_text(
-                        final_response
-                    )
+                    if state.thinking_done:
+                        answer = state.accumulated_text.strip() or self._extract_text(
+                            final_response
+                        )
+                    else:
+                        # Non-thinking model: burst-emit buffer in small chunks.
+                        raw = state.accumulated_text or self._extract_text(
+                            final_response
+                        )
+                        answer = raw.strip()
+                        chunk_size = 20
+                        for i in range(0, len(answer), chunk_size):
+                            yield {
+                                "type": "text_chunk",
+                                "chunk": answer[i : i + chunk_size],
+                            }
+                            await asyncio.sleep(0)
                     messages.append({"role": "assistant", "content": answer})
                     yield {"type": "answer", "text": answer, "history": messages}
                     return
 
-                # Tool-use turn: persist content and execute tools
                 messages.append(
                     {
                         "role": "assistant",
                         "content": self._content_from_response(final_response),
                     }
                 )
-
-                tool_results: list[dict[str, Any]] = []
-                for tu in tool_uses:
-                    ti = dict(tu.input) if isinstance(tu.input, dict) else {}
-                    yield {"type": "tool_use", "tool": tu.name, "input": ti}
-                    result = await asyncio.to_thread(
-                        execute_tool, tu.name, ti, self.bookstack
-                    )
-                    # For get_page, emit the resolved page title so the UI can
-                    # display it instead of the raw numeric ID.
-                    if tu.name == "get_page":
-                        try:
-                            page_data = json.loads(result)
-                            page_title = str(page_data.get("name") or "")
-                            if page_title:
-                                yield {
-                                    "type": "tool_resolve",
-                                    "page_id": ti.get("page_id"),
-                                    "page_title": page_title,
-                                }
-                        except (json.JSONDecodeError, KeyError, ValueError):
-                            pass
-                    tool_results.append(
-                        {"type": "tool_result", "tool_use_id": tu.id, "content": result}
-                    )
-
-                messages.append({"role": "user", "content": tool_results})
+                async for event in self._execute_tool_calls(tool_uses):
+                    yield event
+                messages.append({"role": "user", "content": self._tool_results})
 
             yield {
                 "type": "error",
diff --git a/tests/bookstack/test_agent.py b/tests/bookstack/test_agent.py
index 5fe1588..ae6400a 100644
--- a/tests/bookstack/test_agent.py
+++ b/tests/bookstack/test_agent.py
@@ -332,19 +332,19 @@ async def test_stream_no_history_starts_fresh(
         assert history[0] == {"role": "user", "content": "Fresh start?"}
 
     @pytest.mark.asyncio
-    async def test_stream_text_clear_emitted_when_text_precedes_tool_call(
+    async def test_stream_thinking_text_buffered_not_emitted_before_tool_call(
         self, agent: BookstackQAAgent
     ) -> None:
-        """text_clear is emitted when reasoning text appears before a tool call."""
-        # Turn 1: reasoning text streamed, then a tool_use block starts
-        text_event = _make_text_delta_event("Let me search for that.")
+        """Text before </think> is buffered silently — no text_chunk or text_clear emitted."""
+        # Qwen3 pattern: turn 1 has thinking text + </think> + (empty) + tool_use block
+        think_event = _make_text_delta_event("<think>Let me search.\n</think>\n\n")
         tool_start_event = _make_tool_use_block_start_event("search_bookstack")
         tool_final = _make_sync_response(
             [_make_tool_use_block("search_bookstack", "tu_1", {"query": "policy"})]
         )
-        ctx1 = _make_stream_ctx([text_event, tool_start_event], tool_final)
+        ctx1 = _make_stream_ctx([think_event, tool_start_event], tool_final)
 
-        # Turn 2: actual answer
+        # Turn 2: actual answer (no thinking)
         answer_final = _make_sync_response([_make_text_block("The policy says…")])
         ctx2 = _make_stream_ctx([], answer_final)
 
@@ -359,11 +359,65 @@ async def test_stream_text_clear_emitted_when_text_precedes_tool_call(
                 events.append(evt)
 
         types = [e["type"] for e in events]
-        # text_chunk from the reasoning text
-        assert types[0] == "text_chunk"
-        # text_clear follows to discard the reasoning text
-        assert "text_clear" in types
-        text_clear_idx = types.index("text_clear")
-        # tool_use comes after text_clear
-        assert "tool_use" in types[text_clear_idx:]
+        # Thinking text must never appear as a text_chunk
+        assert "text_clear" not in types, (
+            "no text was streamed so text_clear is unnecessary"
+        )
+        assert "tool_use" in types
+        assert types[-1] == "answer"
+
+    @pytest.mark.asyncio
+    async def test_stream_thinking_model_final_answer_streams_post_think(
+        self, agent: BookstackQAAgent
+    ) -> None:
+        """For a thinking model, only text after </think> is emitted as text_chunk."""
+        # </think> splits from thinking; answer follows in real time
+        think_chunk = _make_text_delta_event("<think>Thinking...\n</think>\n\n")
+        answer_chunk = _make_text_delta_event("Paris.")
+        final_msg = _make_sync_response([_make_text_block("Paris.")])
+        ctx = _make_stream_ctx([think_chunk, answer_chunk], final_msg)
+        agent._async_client.messages.stream.return_value = ctx  # type: ignore[attr-defined]
+
+        events = []
+        async for evt in agent.ask_stream("Capital of France?"):
+            events.append(evt)
+
+        types = [e["type"] for e in events]
+        chunks = [e["chunk"] for e in events if e["type"] == "text_chunk"]
+        # Only post-think text should appear
+        assert all("think" not in c.lower() for c in chunks)
+        assert "Paris." in chunks
+        assert types[-1] == "answer"
+        assert events[-1]["text"] == "Paris."
+
+    @pytest.mark.asyncio
+    async def test_stream_text_clear_emitted_when_post_think_text_precedes_tool_call(
+        self, agent: BookstackQAAgent
+    ) -> None:
+        """text_clear fires only when text *after* </think> was already streamed."""
+        # Model generates thinking + </think> + "Let me search." + tool_use
+        think_event = _make_text_delta_event("<think>Thinking...\n</think>\n\n")
+        bridge_event = _make_text_delta_event("Let me search.")
+        tool_start = _make_tool_use_block_start_event("search_bookstack")
+        tool_final = _make_sync_response(
+            [_make_tool_use_block("search_bookstack", "tu_1", {"query": "policy"})]
+        )
+        ctx1 = _make_stream_ctx([think_event, bridge_event, tool_start], tool_final)
+
+        answer_final = _make_sync_response([_make_text_block("The policy says…")])
+        ctx2 = _make_stream_ctx([], answer_final)
+        agent._async_client.messages.stream.side_effect = [ctx1, ctx2]  # type: ignore[attr-defined]
+
+        with patch(
+            "aieng_bot.bookstack.agent.execute_tool",
+            return_value=json.dumps({"data": [], "total": 0}),
+        ):
+            events = []
+            async for evt in agent.ask_stream("What is the leave policy?"):
+                events.append(evt)
+
+        types = [e["type"] for e in events]
+        assert "text_chunk" in types  # "Let me search." was streamed
+        assert "text_clear" in types  # then discarded when tool_use detected
+        assert "tool_use" in types
         assert types[-1] == "answer"