diff --git a/.gitignore b/.gitignore
index 7e12f67569..dbb5d6b601 100644
--- a/.gitignore
+++ b/.gitignore
@@ -84,3 +84,4 @@ htmlcov/
 # Memory2 autorecord
 recording*.db
+MUJOCO_LOG.TXT
diff --git a/dimos/agents/compaction_middleware.py b/dimos/agents/compaction_middleware.py
new file mode 100644
index 0000000000..45a006ac8a
--- /dev/null
+++ b/dimos/agents/compaction_middleware.py
@@ -0,0 +1,564 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+
+"""LangChain middleware that caps the input the agent sends to its model.
+
+The agent's `_history` would otherwise grow unbounded. This middleware runs
+`before_model` on every LLM call and, if the projected input exceeds
+`threshold_tokens`, compacts in two stages:
+
+  1. Drop image content blocks (replace with a small text placeholder).
+  2. If still over `target_tokens`, summarize the older messages into a single
+     `SystemMessage`, keeping the most-recent tail verbatim.
+
+The leading `SystemMessage` (the agent's identity, set via
+`create_agent(system_prompt=...)`) is preserved verbatim. Mid-list untagged
+messages and prior compaction summaries are eligible for summarization.
+
+Token counting is a pessimistic approximation for now (3 chars/token, 1000 tokens/image) and
+is memoized in `msg.additional_kwargs["dimos_tokens"]` so subsequent turns only
+pay for newly-added messages.
+
+Design — why a middleware, and why `before_model`
+-------------------------------------------------
+
+Compaction is an *invariant* of the prompt the LLM sees, not a feature of any
+particular caller. Two consequences shape the design:
+
+1. **Middleware.**
+   Putting compaction inside the state graph (via `create_agent(middleware=...)`)
+   means *every* path into the model is bounded — current callers, future
+   callers, alternate agents — without each one having to remember to call a
+   compaction helper. The contract is enforced via a hook that is always called.
+
+   It also handles a subtlety in the langgraph agent loop: a single user turn
+   can invoke the model multiple times (model → tool call → tool result →
+   model again → …). External pre-processing on `_history` would only run
+   once per user turn, leaving every intra-turn re-invocation unprotected.
+   The middleware fires *before each* model call, so the size bound is a true
+   invariant of the loop, not a "checked at user-message boundaries" property.
+
+2. **`before_model`, not `after_model` or `wrap_model_call`.**
+   `before_model` is the minimal-intervention hook that lets us transform the
+   state the model is about to receive. `after_model` runs too late — the
+   model has already been called and may have errored on context overflow.
+   `wrap_model_call` could work but means owning the entire request/response,
+   which conflates compaction with model-call concerns (retries, error
+   shaping, tool dispatch). `before_model` keeps the responsibility narrow:
+   adjust state in, return; everything else stays the agent loop's job.
+
+The current turn is treated as sacred (see `_current_turn_start`): even when
+over threshold we never touch the latest `dimos_turn` group, because that's
+the in-flight user query plus any tool calls/responses still being resolved.
+Compressing those would either confuse the model mid-step or strip the very
+context the user is asking about right now.
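+
+Wiring sketch (mirrors how `McpClient.on_system_modules` installs it below;
+the model names are illustrative and the budgets are the `McpClientConfig`
+defaults)::
+
+    from langchain.agents import create_agent
+    from langchain.chat_models import init_chat_model
+
+    summarizer = init_chat_model("gpt-4o")  # anything that duck-types invoke()
+    agent = create_agent(
+        model="gpt-4o",
+        tools=tools,                  # StructuredTools, assumed in scope
+        system_prompt=SYSTEM_PROMPT,  # assumed in scope
+        middleware=[
+            DimosCompactionMiddleware(
+                summarizer=summarizer,
+                threshold_tokens=40_000,    # when to compact
+                target_tokens=3_000,        # what to compact down to
+                summary_size_tokens=1_000,  # cap on the summary itself
+                system_prompt=SYSTEM_PROMPT,
+            )
+        ],
+    )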
+ +Tool-call coherence is the harness's responsibility +--------------------------------------------------- + +The middleware never *introduces* orphan tool calls: `_split` aligns cuts to +`dimos_turn` boundaries, so an `AIMessage(tool_calls=...)` and its matching +`ToolMessage` — both stamped with the same turn — always travel together +into either the summary or the kept tail. But the middleware doesn't *fix* +orphans it inherits either. If the harness appends an +`AIMessage(tool_calls=...)` without its corresponding `ToolMessage`, the +orphan is passed through verbatim when it lives in the current turn (and the +LLM call will typically raise on the malformed conversation). The middleware +surfaces the issue; it doesn't paper over it. Proper turn-ordering on append +is the caller's job. + +Known limitations +----------------- + +1. **Summarizer context overflow on huge transcripts.** The transcript fed to + the summarizer can be arbitrarily large — bounded only by what + `before_model` decides to summarize. On the very first compaction event of + a long-running session, the transcript could exceed the summarizer model's + own context window, raising a provider error that `@retry` will dutifully + re-issue twice before propagating. Mitigation when it happens in practice: + either pre-truncate the transcript here, or chunk-and-fold-summarize + iteratively. Out of scope for the current placeholder-tokenizer phase. + +2. **`@retry(on_exception=Exception)` is intentionally broad.** Because the + summarizer is duck-typed (`Any`), we don't know which provider-specific + exception classes (httpx, openai, anthropic, …) signal transient vs. + permanent failures. Catching `Exception` means a permanent error (bad API + key, invalid schema, programming bug) costs up to 3 attempts + 1s of + sleeps before propagating. Acceptable trade-off vs. coupling the + middleware to a specific provider SDK; narrow the exception list if you + pin to one. +""" + +from __future__ import annotations + +import json +from typing import Any, cast + +from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + RemoveMessage, + SystemMessage, +) +from langgraph.graph.message import REMOVE_ALL_MESSAGES + +from dimos.utils.decorators.decorators import retry +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +# These are magic numbers, but they are used as pessimistic estimations of +# prompt size. It's better to be more cautious. +CHARS_PER_TOKEN = 3 +TOKENS_PER_IMAGE = 1000 +IMAGE_PLACEHOLDER = "[image removed during compaction]" + +DEFAULT_SUMMARY_PROMPT = """\ +You are compacting a conversation between a user and an AI agent controlling a robot. +Write a single concise paragraph in this exact style: + + "User asked X. Agent did A, B, C. User asked Y. Agent did P, Q." + +Rules: + - Preserve user goals, decisions made, tool calls (with results), and current state. + - Drop pleasantries, intermediate reasoning, and repeated content. + - Aim for at most {summary_size_tokens} tokens worth of text. + - Output ONLY the summary. No preamble, no headings. + +TRANSCRIPT: +{transcript} +""" + + +def count_tokens(text: str) -> int: + """Approximate token count for a string. Pessimistic heuristic: ceil(len/3).""" + if not text: + return 0 + return max(1, (len(text) + CHARS_PER_TOKEN - 1) // CHARS_PER_TOKEN) + + +def count_image_tokens() -> int: + """Cost of one image content block. 
Pessimistic heuristic."""
+    return TOKENS_PER_IMAGE
+
+
+def _count_content(content: Any) -> int:
+    """Count tokens for a `content` value (str or list of content blocks)."""
+    if isinstance(content, str):
+        return count_tokens(content)
+    if isinstance(content, list):
+        total = 0
+        for block in content:
+            if isinstance(block, dict):
+                btype = block.get("type")
+                if btype == "text":
+                    total += count_tokens(block.get("text", ""))
+                elif btype in ("image_url", "image"):
+                    total += count_image_tokens()
+                else:
+                    total += count_tokens(str(block))
+            else:
+                total += count_tokens(str(block))
+        return total
+    return count_tokens(str(content))
+
+
+def count_message_tokens(msg: BaseMessage) -> int:
+    """Token count for a message, memoized into `additional_kwargs["dimos_tokens"]`."""
+    kwargs = getattr(msg, "additional_kwargs", None)
+    if isinstance(kwargs, dict):
+        cached = kwargs.get("dimos_tokens")
+        if isinstance(cached, int):
+            return cached
+
+    total = _count_content(msg.content)
+    # AIMessage tool_calls are sent as JSON; count their serialized form.
+    tool_calls = getattr(msg, "tool_calls", None)
+    if tool_calls:
+        total += count_tokens(json.dumps(tool_calls, default=str))
+
+    if isinstance(kwargs, dict):
+        kwargs["dimos_tokens"] = total
+    return total
+
+
+def _has_image(content: Any) -> bool:
+    if not isinstance(content, list):
+        return False
+    return any(isinstance(b, dict) and b.get("type") in ("image_url", "image") for b in content)
+
+
+class DimosCompactionMiddleware(AgentMiddleware):  # type: ignore[misc]
+    """`before_model` hook that compacts message history to a token budget.
+
+    See module docstring for the algorithm.
+    """
+
+    def __init__(
+        self,
+        summarizer: Any,
+        *,
+        threshold_tokens: int,
+        target_tokens: int,
+        summary_size_tokens: int = 500,
+        system_prompt: str | None = None,
+        tool_schemas: list[Any] | None = None,
+        summary_prompt_template: str = DEFAULT_SUMMARY_PROMPT,
+    ) -> None:
+        """`summarizer` must duck-type to a langchain chat model:
+        `.invoke(messages)` returning an object with a `.content` str attribute.
+        """
+        if target_tokens >= threshold_tokens:
+            raise ValueError("target_tokens must be < threshold_tokens")
+        if summary_size_tokens >= target_tokens:
+            raise ValueError("summary_size_tokens must be < target_tokens")
+
+        # Hard-cap the summarizer's output to `summary_size_tokens` (provider tokens,
+        # measured slightly differently than our placeholder; close enough as a hard
+        # upper bound). `.bind()` is on Runnable so works for all langchain chat models;
+        # fake models accept and ignore the kwarg.
+        try:
+            self._summarizer = summarizer.bind(max_tokens=summary_size_tokens)
+        except Exception:
+            self._summarizer = summarizer
+        self._threshold = threshold_tokens
+        self._target = target_tokens
+        self._summary_size = summary_size_tokens
+        self._summary_prompt_template = summary_prompt_template
+
+        self._system_prompt = system_prompt
+        self._tool_schemas = tool_schemas or []
+        # Static-token cache: computed once in `_static_tokens`; both inputs
+        # are fixed at construction, so it is never invalidated.
+        self._static_cache: tuple[int, int] | None = None
+
+    # -- public --
+
+    def before_model(self, state: Any, runtime: Any) -> dict[str, Any] | None:
+        messages: list[BaseMessage] = list(state.get("messages") or [])
+        if not messages:
+            return None
+
+        # The CURRENT TURN should remain untouched. Find its boundary using the last message's
+        # dimos_turn tag and protect everything from there to the end.
+ # This preserves in-flight context: the latest user query, any in-progress tool calls / tool responses, + # fresh images from perception, etc. + current_start = self._current_turn_start(messages) + compactable = messages[:current_start] + current_turn = messages[current_start:] + + static = self._static_tokens() + compactable_tokens = sum(count_message_tokens(m) for m in compactable) + current_turn_tokens = sum(count_message_tokens(m) for m in current_turn) + total = static + compactable_tokens + current_turn_tokens + + if total <= self._threshold: + return None + + if not compactable: + logger.warning( + "Compaction over threshold but everything is in the current turn; " + "passing through. Check compaction settings.", + total_tokens=total, + ) + return None + + # Stage 1: strip images ONLY in the compactable region. + stripped = self._strip_images(compactable) + stripped_tokens = sum(count_message_tokens(m) for m in stripped) + total_after_strip = static + stripped_tokens + current_turn_tokens + if total_after_strip <= self._target: + logger.info( + "Compaction fired (image-strip).", + tokens_before=total, + tokens_after=total_after_strip, + ) + return { + "messages": [ + RemoveMessage(id=REMOVE_ALL_MESSAGES), + *stripped, + *current_turn, + ] + } + + # Stage 2: split compactable into protected / to_summarize / keep_tail. + # Budget for the keep_tail accounts for the fixed cost of the current turn. + budget = max( + 0, + self._target - self._summary_size - current_turn_tokens, + ) + protected, to_summarize, keep = self._split(stripped, budget=budget) + if not to_summarize: + logger.warning( + "Compaction over threshold but nothing eligible to summarize; passing through.", + total_tokens=total_after_strip, + ) + return None + + summary_text = self._summarize(to_summarize) + summary_msg = self._build_summary_message(summary_text, to_summarize) + summary_tokens = count_message_tokens(summary_msg) + final_total = ( + static + + sum(count_message_tokens(m) for m in protected) + + summary_tokens + + sum(count_message_tokens(m) for m in keep) + + current_turn_tokens + ) + logger.info( + "Compaction fired (summarize).", + tokens_before=total, + tokens_after=final_total, + summarized_messages=len(to_summarize), + ) + return { + "messages": [ + RemoveMessage(id=REMOVE_ALL_MESSAGES), + *protected, + summary_msg, + *keep, + *current_turn, + ] + } + + # -- internals -- + + def _total_tokens(self, messages: list[BaseMessage]) -> int: + return self._static_tokens() + sum(count_message_tokens(m) for m in messages) + + def _static_tokens(self) -> int: + """Tokens for the system prompt + tool schemas. + + Computed once and cached forever — both inputs are bound at `__init__` + and never mutate, so there's no need to recompute (or even rehash) on + subsequent calls. + """ + if self._static_cache is not None: + return self._static_cache[1] + total = count_tokens(self._system_prompt or "") + if self._tool_schemas: + total += count_tokens(json.dumps(self._tool_schemas, default=str)) + self._static_cache = (0, total) # sentinel key; payload is immutable + return total + + def _strip_images(self, messages: list[BaseMessage]) -> list[BaseMessage]: + """Return a new list where image content blocks are replaced with text. + + Messages without images are reused by reference. Messages with images + are reconstructed via `model_copy` so every other field is preserved — + `id`, `name`, `tool_calls`, `tool_call_id`, `response_metadata`, etc. 
(Plain `m.__class__(content=..., additional_kwargs=...)` would drop
+        them, which silently breaks AIMessages-with-tool_calls and ToolMessages
+        — the latter requires `tool_call_id` and would refuse to construct.)
+        """
+        out: list[BaseMessage] = []
+        for m in messages:
+            if not _has_image(m.content):
+                out.append(m)
+                continue
+
+            new_blocks: list[Any] = []
+            for block in m.content:  # type: ignore[union-attr]
+                if isinstance(block, dict) and block.get("type") in (
+                    "image_url",
+                    "image",
+                ):
+                    new_blocks.append({"type": "text", "text": IMAGE_PLACEHOLDER})
+                else:
+                    new_blocks.append(block)
+
+            new_kwargs = dict(m.additional_kwargs or {})
+            new_kwargs.pop("dimos_tokens", None)
+            new_msg = m.model_copy(update={"content": new_blocks, "additional_kwargs": new_kwargs})
+            out.append(new_msg)
+        return out
+
+    def _current_turn_start(self, messages: list[BaseMessage]) -> int:
+        """Return the index where the 'current turn' begins.
+
+        The current turn is the contiguous suffix of `messages` that all share
+        the highest `dimos_turn` value (plus any trailing untagged messages,
+        which are typically in-flight tool calls / responses not yet stamped).
+        Everything from this index to the end is preserved verbatim by the
+        middleware: no image stripping, no summarization.
+
+        `dimos_turn` values are monotonically increasing within a history
+        (`McpClient._process_message` increments `self._turn` once per
+        user-facing turn before tagging). So we find the boundary in a single
+        backward pass: the first tagged message we see *is* the max turn,
+        and the first message we then encounter with a lower tag marks the
+        boundary; the current-turn group begins at the index just after it.
+
+        Untagged-history fallback: when no message carries a `dimos_turn`
+        tag at all (a caller wired the middleware in without going through
+        McpClient), we anchor on the latest `HumanMessage`. If there's no
+        `HumanMessage` either, we warn and fall through to full-list
+        compaction; `_split` still guarantees at least one message survives
+        in the kept tail.
+        """
+        max_turn: int | None = None
+        for i in range(len(messages) - 1, -1, -1):
+            t = (messages[i].additional_kwargs or {}).get("dimos_turn")
+            if isinstance(t, int):
+                if max_turn is None:
+                    max_turn = t
+                elif t < max_turn:
+                    return i + 1
+            # Untagged messages: in-flight in the current turn, keep walking.
+
+        if max_turn is None:
+            # No tags anywhere. Anchor on the latest HumanMessage.
+            for i in range(len(messages) - 1, -1, -1):
+                if isinstance(messages[i], HumanMessage):
+                    return i
+            logger.warning(
+                "Compaction: no `dimos_turn` tags and no HumanMessage found; "
+                "treating the entire history as compactable. "
+                "This should not happen — check whether the agent harness is "
+                "tagging messages or producing well-formed conversations.",
+                n_messages=len(messages),
+            )
+            return len(messages)
+
+        # All tagged messages share max_turn (no lower-turn boundary found).
+        return 0
+
+    def _split(
+        self, messages: list[BaseMessage], *, budget: int
+    ) -> tuple[list[BaseMessage], list[BaseMessage], list[BaseMessage]]:
+        """Partition messages into (protected_prefix, to_summarize, keep_tail).
+
+        - protected_prefix: leading SystemMessages WITHOUT additional_kwargs[
+          "dimos_compacted"]=True. Preserved verbatim.
+        - keep_tail: built back-to-front until adding the next message would
+          exceed (budget - protected_tokens). If the cut would split a
+          `dimos_turn` group, push it older so the entire turn falls on one
+          side (see the worked example below).
+        - to_summarize: everything in between.
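+
+        Worked example (illustrative numbers, placeholder tokenizer): with
+        budget=100, a protected SystemMessage of 20 tokens, and four
+        single-message turns t1..t4 of 40 tokens each, the effective tail
+        budget is 80; the back-to-front walk keeps t4 and t3 (80 tokens),
+        stops before t2, and t1..t2 become `to_summarize`.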
+ """ + protected: list[BaseMessage] = [] + rest_start = 0 + for i, m in enumerate(messages): + if isinstance(m, SystemMessage) and not (m.additional_kwargs or {}).get( + "dimos_compacted" + ): + protected.append(m) + rest_start = i + 1 + else: + break + rest = messages[rest_start:] + + protected_tokens = sum(count_message_tokens(m) for m in protected) + budget = max(0, budget - protected_tokens) + + # Build keep_tail from the end, walking older until budget exhausted. + keep: list[BaseMessage] = [] + used = 0 + keep_start_idx = len(rest) # rest[keep_start_idx:] is what we keep + for i in range(len(rest) - 1, -1, -1): + m_tokens = count_message_tokens(rest[i]) + if used + m_tokens > budget and keep: + # Adding this would overflow; stop. (Always keep at least one + # message even if a single message exceeds budget, so the agent + # gets the latest user input.) + break + keep.append(rest[i]) + used += m_tokens + keep_start_idx = i + keep.reverse() + + # Align the cut to a dimos_turn boundary so tagged tool_call/tool_response + # pairs aren't split. The kept tail starts at keep_start_idx; if the + # message there is tagged AND a message just before it shares the same + # turn, push the cut older to include all messages of that turn. + if 0 < keep_start_idx < len(rest): + border_turn = (rest[keep_start_idx].additional_kwargs or {}).get("dimos_turn") + if border_turn is not None: + while keep_start_idx > 0: + prev_turn = (rest[keep_start_idx - 1].additional_kwargs or {}).get("dimos_turn") + if prev_turn != border_turn: + break + keep_start_idx -= 1 + keep = rest[keep_start_idx:] + + to_summarize = rest[:keep_start_idx] + return protected, to_summarize, keep + + def _summarize(self, messages: list[BaseMessage]) -> str: + transcript = _render_transcript(messages) + prompt = self._summary_prompt_template.format( + transcript=transcript, summary_size_tokens=self._summary_size + ) + # `cast` because @retry erases the return type to Any. + return cast("str", self._invoke_summarizer(prompt)) + + @retry(max_retries=2, on_exception=Exception, delay=0.5) # type: ignore[untyped-decorator] + def _invoke_summarizer(self, prompt: str) -> str: + """LLM call, isolated for retry. Raises on final failure (propagates).""" + response = self._summarizer.invoke([HumanMessage(content=prompt)]) + text = getattr(response, "content", None) + if isinstance(text, str): + stripped = text.strip() + if stripped: + return stripped + # Some fakes return raw strings or empty content; coerce. 
+ if isinstance(text, list): + joined = " ".join(b.get("text", "") for b in text if isinstance(b, dict)) + if joined.strip(): + return joined.strip() + raise RuntimeError(f"Summarizer returned empty content: {response!r}") + + def _build_summary_message( + self, summary_text: str, summarized: list[BaseMessage] + ) -> SystemMessage: + max_turn: int | None = None + for m in summarized: + t = (m.additional_kwargs or {}).get("dimos_turn") + if isinstance(t, int) and (max_turn is None or t > max_turn): + max_turn = t + + kw: dict[str, Any] = { + "dimos_compacted": True, + "dimos_covers_count": len(summarized), + } + if max_turn is not None: + kw["dimos_turn"] = max_turn + + return SystemMessage( + content=f"[Prior conversation summary]\n{summary_text}", + additional_kwargs=kw, + ) + + +def _stringify_content(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, dict): + if block.get("type") == "text": + parts.append(block.get("text", "")) + else: + parts.append(f"<{block.get('type', 'unknown')}>") + else: + parts.append(str(block)) + return " ".join(parts) + return str(content) + + +def _render_transcript(messages: list[BaseMessage]) -> str: + lines: list[str] = [] + for m in messages: + role = type(m).__name__.replace("Message", "").lower() + turn = (m.additional_kwargs or {}).get("dimos_turn", "?") + content = _stringify_content(m.content) + if isinstance(m, AIMessage): + tool_calls = getattr(m, "tool_calls", None) or [] + if tool_calls: + tc_summary = "; ".join(f"{tc.get('name')}({tc.get('args')})" for tc in tool_calls) + content = (content + f"\n tool_calls: {tc_summary}").strip() + lines.append(f"[turn {turn} {role}] {content}") + return "\n".join(lines) diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py index 75b532e9cc..e8e57b508f 100644 --- a/dimos/agents/mcp/mcp_client.py +++ b/dimos/agents/mcp/mcp_client.py @@ -13,6 +13,7 @@ # limitations under the License. from collections.abc import Callable +import os from queue import Empty, Queue from threading import Event, RLock, Thread import time @@ -21,12 +22,16 @@ import httpx from langchain.agents import create_agent -from langchain_core.messages import HumanMessage +from langchain.chat_models import init_chat_model +from langchain_core.messages import HumanMessage, RemoveMessage from langchain_core.messages.base import BaseMessage from langchain_core.tools import StructuredTool +from langgraph.graph.message import REMOVE_ALL_MESSAGES from langgraph.graph.state import CompiledStateGraph +from pydantic import Field from reactivex.disposable import Disposable +from dimos.agents.compaction_middleware import DimosCompactionMiddleware from dimos.agents.mcp import tool_stream from dimos.agents.system_prompt import SYSTEM_PROMPT from dimos.agents.utils import pretty_print_langchain_message @@ -41,12 +46,40 @@ logger = setup_logger() +def _env_int(name: str) -> int | None: + v = os.environ.get(name) + if not v: + return None + try: + return int(v) + except ValueError: + raise ValueError(f"Environment variable {name!r} must be an integer, got {v!r}") from None + + +def _env_str(name: str) -> str | None: + return os.environ.get(name) or None + + class McpClientConfig(ModuleConfig): system_prompt: str | None = SYSTEM_PROMPT model: str = "gpt-4o" model_fixture: str | None = None mcp_server_url: str = "http://localhost:9990/mcp" + # Compaction: env-driven, agent-scoped. On by default. 
+ agent_compaction_threshold: int = Field( + default_factory=lambda: _env_int("AGENT_COMPACTION_THRESHOLD") or 40_000 + ) + agent_compaction_target: int = Field( + default_factory=lambda: _env_int("AGENT_COMPACTION_TARGET") or 3_000 + ) + agent_compaction_summary_size: int = Field( + default_factory=lambda: _env_int("AGENT_COMPACTION_SUMMARY_SIZE") or 1_000 + ) + agent_compaction_model: str | None = Field( + default_factory=lambda: _env_str("AGENT_COMPACTION_MODEL") + ) + class McpClient(Module): config: McpClientConfig @@ -64,6 +97,7 @@ class McpClient(Module): _http_client: httpx.Client _seq_ids: SequentialIds _tool_stream_cleanup: Callable[[], None] | None + _turn: int def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -81,6 +115,7 @@ def __init__(self, **kwargs: Any) -> None: self._http_client = httpx.Client(timeout=120.0) self._seq_ids = SequentialIds() self._tool_stream_cleanup = None + self._turn = 0 def __reduce__(self) -> Any: return (self.__class__, (), {}) @@ -218,11 +253,49 @@ def on_system_modules(self, _modules: list[RPCClient]) -> None: model = MockModel(json_path=self.config.model_fixture) + middleware: list[Any] = [] + if self.config.agent_compaction_threshold and self.config.agent_compaction_target: + if self.config.agent_compaction_model: + summarizer = init_chat_model(self.config.agent_compaction_model) + elif isinstance(model, str): + # `create_agent` accepts a model-name string and coerces internally, + # but the middleware needs an actual ChatModel object. + summarizer = init_chat_model(model) + else: + summarizer = model + + middleware.append( + DimosCompactionMiddleware( + summarizer=summarizer, + threshold_tokens=self.config.agent_compaction_threshold, + target_tokens=self.config.agent_compaction_target, + summary_size_tokens=self.config.agent_compaction_summary_size, + system_prompt=self.config.system_prompt, + # Pass JSON schemas (dicts), not pydantic class objects — + # otherwise json.dumps inside the middleware falls back to + # str() and produces a useless tiny string, leading to + # massive undercount of tool-definition tokens. + tool_schemas=[ + t.args_schema.model_json_schema() + for t in tools + if t.args_schema is not None and hasattr(t.args_schema, "model_json_schema") + ], + ) + ) + logger.info( + "Compaction middleware enabled.", + threshold=self.config.agent_compaction_threshold, + target=self.config.agent_compaction_target, + summary_size=self.config.agent_compaction_summary_size, + summarizer_model=self.config.agent_compaction_model or "(reuse agent)", + ) + with self._lock: self._state_graph = create_agent( model=model, tools=tools, system_prompt=self.config.system_prompt, + middleware=middleware, ) if not self._thread.is_alive(): self._thread.start() @@ -315,25 +388,89 @@ def _thread_loop(self) -> None: raise ValueError("No state graph initialized") self._process_message(self._state_graph, message) + def _apply_messages_update(self, node_messages: list[BaseMessage], turn: int) -> None: + """Merge a node's emitted messages into `self._history`, mirroring the + `add_messages` reducer langgraph uses internally. + + Honors `RemoveMessage(id=REMOVE_ALL_MESSAGES)` as "wipe history and use + what came after" so compaction-middleware replacements don't accrete + in our local history. Specific-id RemoveMessages prune matching entries. + Already-tagged messages (re-emitted by middleware) keep their tags; + new messages get the current turn id. 
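+
+        Shape example (illustrative): `node_messages =
+        [RemoveMessage(id=REMOVE_ALL_MESSAGES), sys, summary, tail]` wipes
+        `_history` and rebuilds it as `[sys, summary, tail]`, while a plain
+        `node_messages = [ai_msg]` is tagged with the current turn and
+        appended as usual.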
+
+        Publish discipline: a message is printed and published on the `agent`
+        stream at most once per session. When compaction replays previously-seen
+        messages alongside a fresh summary, we publish only the genuinely-new
+        objects (identified by Python identity against the pre-wipe history),
+        so downstream subscribers don't see duplicates.
+        """
+        wipe_idx: int | None = None
+        for i, m in enumerate(node_messages):
+            if isinstance(m, RemoveMessage) and m.id == REMOVE_ALL_MESSAGES:
+                wipe_idx = i
+
+        if wipe_idx is not None:
+            pre_wipe_obj_ids = {id(h) for h in self._history}
+            self._history = []
+            iter_msgs = node_messages[wipe_idx + 1 :]
+            is_replay = True
+        else:
+            pre_wipe_obj_ids = set()
+            iter_msgs = node_messages
+            is_replay = False
+
+        for msg in iter_msgs:
+            if isinstance(msg, RemoveMessage):
+                # Specific-id removal: drop matching from history.
+                self._history = [h for h in self._history if getattr(h, "id", None) != msg.id]
+                continue
+            if not is_replay:
+                _tag_turn(msg, turn)
+            self._history.append(msg)
+            # Skip publish for messages already shown before a compaction wipe.
+            # The middleware emits its replacement as
+            # `[RemoveMessage, *protected, summary, *keep_tail, *current_turn]`;
+            # only `summary` (and any fresh AIMessages from later nodes in the
+            # same stream) are new — the rest are the same Python objects that
+            # were already published when they first arrived.
+            if is_replay and id(msg) in pre_wipe_obj_ids:
+                continue
+            pretty_print_langchain_message(msg)
+            self.agent.publish(msg)
+
     def _process_message(
         self, state_graph: CompiledStateGraph[Any, Any, Any, Any], message: BaseMessage
     ) -> None:
         self.agent_idle.publish(False)
+        self._turn += 1
+        turn = self._turn
+        _tag_turn(message, turn)
         self._history.append(message)
         pretty_print_langchain_message(message)
         self.agent.publish(message)
 
         for update in state_graph.stream({"messages": self._history}, stream_mode="updates"):
             for node_output in update.values():
-                for msg in node_output.get("messages", []):
-                    self._history.append(msg)
-                    pretty_print_langchain_message(msg)
-                    self.agent.publish(msg)
+                # Middleware hooks (e.g. compaction's before_model) may emit
+                # updates whose value is None when they made no change.
+                if not isinstance(node_output, dict):
+                    continue
+                self._apply_messages_update(node_output.get("messages") or [], turn)
 
         if self._message_queue.empty():
             self.agent_idle.publish(True)
 
 
+def _tag_turn(message: BaseMessage, turn: int) -> None:
+    """Stamp a turn id into the message's additional_kwargs.
+
+    Used by prompt-compaction to group/score messages by the turn that produced them.
+    """
+    kwargs = getattr(message, "additional_kwargs", None)
+    if isinstance(kwargs, dict):
+        kwargs["dimos_turn"] = turn
+
+
 def _append_image_to_history(
     mcp_client: McpClient, func_name: str, uuid_: str, result: Any
 ) -> None:
diff --git a/dimos/agents/test_compaction_middleware.py b/dimos/agents/test_compaction_middleware.py
new file mode 100644
index 0000000000..4154530656
--- /dev/null
+++ b/dimos/agents/test_compaction_middleware.py
@@ -0,0 +1,649 @@
+# Copyright 2025-2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+
+"""Tests for `DimosCompactionMiddleware`.
+
+Hermetic — uses langchain's `FakeListChatModel` / `FakeMessagesListChatModel`
+(subclassed to record inputs) so no API key is ever needed.
Covers both unit +tests of `before_model` and full-loop integration tests where the middleware +runs inside a real `create_agent` graph. +""" + +from __future__ import annotations + +from typing import Any, cast + +from langchain.agents import create_agent +from langchain_core.language_models.fake_chat_models import ( + FakeListChatModel, + FakeMessagesListChatModel, +) +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + RemoveMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.tools import tool +from langgraph.graph.message import REMOVE_ALL_MESSAGES +from pydantic import Field +import pytest + +from dimos.agents.compaction_middleware import ( + CHARS_PER_TOKEN, + IMAGE_PLACEHOLDER, + TOKENS_PER_IMAGE, + DimosCompactionMiddleware, + count_image_tokens, + count_message_tokens, + count_tokens, +) +from dimos.agents.mcp.mcp_client import _tag_turn + + +def make_human(text: str, turn: int) -> HumanMessage: + m = HumanMessage(content=text) + _tag_turn(m, turn) + return m + + +def make_ai(text: str, turn: int, *, tool_calls: list[dict[str, Any]] | None = None) -> AIMessage: + m = AIMessage(content=text, tool_calls=tool_calls or []) + _tag_turn(m, turn) + return m + + +def make_tool(text: str, tool_call_id: str, turn: int) -> ToolMessage: + m = ToolMessage(content=text, tool_call_id=tool_call_id) + _tag_turn(m, turn) + return m + + +def make_image_human(turn: int) -> HumanMessage: + m = HumanMessage( + content=[ + {"type": "text", "text": "see this"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}, + ] + ) + _tag_turn(m, turn) + return m + + +def build_text_history(n_turns: int, text_per_turn: str = "x" * 60) -> list[BaseMessage]: + """SystemMessage prefix + n turns of (Human, AI) pairs, each tagged.""" + history: list[BaseMessage] = [SystemMessage(content="You are a test agent.")] + for i in range(1, n_turns + 1): + history.append(make_human(f"q{i}: {text_per_turn}", i)) + history.append(make_ai(f"a{i}: {text_per_turn}", i)) + return history + + +def state(messages: list[BaseMessage]) -> dict[str, Any]: + return {"messages": messages} + + +class CountingFake(FakeListChatModel): + """Langchain's FakeListChatModel + a side-list of every prompt it saw. + + Subclassing (rather than wrapping) is the only way to extend a pydantic v2 + model with new instance state. The mutable list-as-field is safe because + `append()` mutates in place — no attribute reassignment. 
+ """ + + received: list[str] = Field(default_factory=list) + + def invoke(self, input: Any, *args: Any, **kwargs: Any) -> Any: # type: ignore[override] + if isinstance(input, list): + text = "\n".join(str(getattr(m, "content", "")) for m in input) + else: + text = str(input) + self.received.append(text) + return super().invoke(input, *args, **kwargs) + + +def make_counting_fake(responses: list[str]) -> tuple[CountingFake, list[str]]: + m = CountingFake(responses=responses) + return m, m.received + + +def test_token_counter_text() -> None: + s = "x" * 30 + assert count_tokens(s) == 30 // CHARS_PER_TOKEN # 10 + m = HumanMessage(content=s) + assert count_message_tokens(m) == 10 + # memoized + assert m.additional_kwargs.get("dimos_tokens") == 10 + # second call uses the memo (we verify by mutating the memo and seeing it returned) + m.additional_kwargs["dimos_tokens"] = 999 + assert count_message_tokens(m) == 999 + + +def test_token_counter_image() -> None: + m = make_image_human(1) + n = count_message_tokens(m) + assert n == count_image_tokens() + count_tokens("see this") + assert n == TOKENS_PER_IMAGE + 3 # 8 chars / 3 = 3 (rounded up) + + +def test_static_tokens_cached() -> None: + mw = DimosCompactionMiddleware( + summarizer=CountingFake(responses=["UNUSED"]), + threshold_tokens=10_000, + target_tokens=5_000, + summary_size_tokens=500, + system_prompt="you are a test agent", + tool_schemas=[{"name": "echo", "args": {"text": "str"}}], + ) + a = mw._static_tokens() + b = mw._static_tokens() + assert a == b + assert mw._static_cache is not None # cache populated + + +def test_below_threshold_is_noop() -> None: + history = build_text_history(n_turns=2) + fake, received = make_counting_fake(["UNUSED"]) + mw = DimosCompactionMiddleware( + summarizer=fake, + threshold_tokens=10_000, + target_tokens=5_000, + summary_size_tokens=500, + ) + result = mw.before_model(state(history), runtime=None) + assert result is None + assert received == [] # summarizer never called + + +def test_image_stripping_alone_suffices() -> None: + """An image in an OLDER turn pushes us over; stripping it brings us back under. + + The image must live outside the current turn — the current turn is sacred + and never gets compacted. + """ + history: list[BaseMessage] = [SystemMessage(content="sys")] + history.append(make_human("small msg 1", 1)) + history.append(make_image_human(1)) # ~1003 tokens, in OLD turn 1 + history.append(make_ai("reply 1", 1)) + # Current (latest) turn — protected. + history.append(make_human("small msg 2", 2)) + history.append(make_ai("reply 2", 2)) + + fake, received = make_counting_fake(["UNUSED"]) + mw = DimosCompactionMiddleware( + summarizer=fake, + threshold_tokens=500, + target_tokens=300, + summary_size_tokens=50, + ) + result = mw.before_model(state(history), runtime=None) + assert result is not None + new_msgs = result["messages"] + assert isinstance(new_msgs[0], RemoveMessage) + assert new_msgs[0].id == REMOVE_ALL_MESSAGES + + # The image was replaced with the placeholder. + new_history = new_msgs[1:] + found_placeholder = False + for m in new_history: + if isinstance(m.content, list): + for block in m.content: + if isinstance(block, dict) and block.get("text") == IMAGE_PLACEHOLDER: + found_placeholder = True + assert found_placeholder + # Summarizer was NOT called — stage 1 alone was enough. 
+ assert received == [] + + +def test_summarize_when_image_strip_insufficient() -> None: + history = build_text_history(n_turns=10, text_per_turn="y" * 200) + fake, received = make_counting_fake(["[summary]"]) + mw = DimosCompactionMiddleware( + summarizer=fake, + threshold_tokens=400, + target_tokens=200, + summary_size_tokens=20, + ) + result = mw.before_model(state(history), runtime=None) + assert result is not None + new_history = result["messages"][1:] # skip RemoveMessage sentinel + + # Exactly one summary message present, marked dimos_compacted. + summaries = [ + m + for m in new_history + if isinstance(m, SystemMessage) and (m.additional_kwargs or {}).get("dimos_compacted") + ] + assert len(summaries) == 1 + assert "[summary]" in summaries[0].content + assert isinstance(summaries[0].additional_kwargs["dimos_turn"], int) + assert len(received) == 1 + + +def test_protected_prefix_preserved() -> None: + history = build_text_history(n_turns=10, text_per_turn="z" * 200) + sys_msg = history[0] + fake, received = make_counting_fake(["[summary]"]) + mw = DimosCompactionMiddleware( + summarizer=fake, + threshold_tokens=400, + target_tokens=200, + summary_size_tokens=20, + ) + result = mw.before_model(state(history), runtime=None) + assert result is not None + new_history = result["messages"][1:] + # The original SystemMessage is at index 0 of the new history, unmodified. + assert new_history[0] is sys_msg + + +def test_untagged_midlist_is_summarized() -> None: + """A hand-injected, untagged message in the middle gets folded into the summary.""" + injected_marker = "MIDLIST-UNTAGGED-XYZ" + + history: list[BaseMessage] = [SystemMessage(content="sys")] + for i in range(1, 6): + history.append(make_human(f"q{i} " + "x" * 100, i)) + history.append(make_ai(f"a{i} " + "x" * 100, i)) + # Inject an untagged HumanMessage between turn 5 and 6. + history.append(HumanMessage(content=injected_marker + " " + "x" * 100)) + for i in range(6, 11): + history.append(make_human(f"q{i} " + "x" * 100, i)) + history.append(make_ai(f"a{i} " + "x" * 100, i)) + + fake, received = make_counting_fake(["[summary]"]) + mw = DimosCompactionMiddleware( + summarizer=fake, + threshold_tokens=400, + target_tokens=200, + summary_size_tokens=20, + ) + result = mw.before_model(state(history), runtime=None) + assert result is not None + + # The injected marker appeared in what was sent to the summarizer. + assert any(injected_marker in p for p in received) + + # And the injected message is NOT in the final history (it was folded in). + new_history = result["messages"][1:] + for m in new_history: + if isinstance(m.content, str): + assert injected_marker not in m.content + + +def test_prior_summary_is_resummarized() -> None: + """A previous compaction's SystemMessage (dimos_compacted=True) folds into the next.""" + prior = SystemMessage( + content="[Prior conversation summary]\nUser asked older things.", + additional_kwargs={"dimos_compacted": True, "dimos_turn": 3}, + ) + + history: list[BaseMessage] = [SystemMessage(content="sys"), prior] + for i in range(4, 14): + history.append(make_human(f"q{i} " + "x" * 200, i)) + history.append(make_ai(f"a{i} " + "x" * 200, i)) + + fake, received = make_counting_fake(["[new summary]"]) + mw = DimosCompactionMiddleware( + summarizer=fake, + threshold_tokens=400, + target_tokens=200, + summary_size_tokens=20, + ) + result = mw.before_model(state(history), runtime=None) + assert result is not None + + # The prior summary's content was passed to the summarizer. 
+ assert any("older things" in p for p in received) + + # Exactly one compacted SystemMessage remains. + new_history = result["messages"][1:] + compacted = [ + m + for m in new_history + if isinstance(m, SystemMessage) and (m.additional_kwargs or {}).get("dimos_compacted") + ] + assert len(compacted) == 1 + assert "[new summary]" in compacted[0].content + + +def test_recent_turns_kept_verbatim() -> None: + history = build_text_history(n_turns=10, text_per_turn="w" * 200) + # Capture object identity of the last few messages. + tail_ids = {id(m) for m in history[-4:]} + + fake, received = make_counting_fake(["[summary]"]) + mw = DimosCompactionMiddleware( + summarizer=fake, + threshold_tokens=400, + target_tokens=300, + summary_size_tokens=20, + ) + result = mw.before_model(state(history), runtime=None) + assert result is not None + new_history = result["messages"][1:] + # At least one of the recent messages should be in the new history by identity. + assert any(id(m) in tail_ids for m in new_history) + + +def test_tool_call_pair_coherence() -> None: + """Tool-call and tool-response with same dimos_turn are never split by the cut.""" + history: list[BaseMessage] = [SystemMessage(content="sys")] + # Many fluffy turns to push us well over. + for i in range(1, 8): + history.append(make_human(f"q{i} " + "x" * 200, i)) + history.append(make_ai(f"a{i} " + "x" * 200, i)) + # A target turn with the tool-call pair. + target_turn = 8 + history.append(make_human("invoke add", target_turn)) + history.append( + make_ai( + "", + target_turn, + tool_calls=[{"name": "add", "args": {"a": 1, "b": 2}, "id": "call_xyz"}], + ) + ) + history.append(make_tool("3", "call_xyz", target_turn)) + history.append(make_ai("1 + 2 = 3", target_turn)) + # And more after. + for i in range(9, 12): + history.append(make_human(f"q{i} " + "x" * 200, i)) + history.append(make_ai(f"a{i} " + "x" * 200, i)) + + fake, received = make_counting_fake(["[summary]"]) + mw = DimosCompactionMiddleware( + summarizer=fake, + threshold_tokens=400, + target_tokens=250, + summary_size_tokens=20, + ) + result = mw.before_model(state(history), runtime=None) + assert result is not None + new_history = result["messages"][1:] + + # Find all messages of the target_turn that survived. + survivors_of_target_turn = [ + m for m in new_history if (m.additional_kwargs or {}).get("dimos_turn") == target_turn + ] + # Either: all 4 messages of that turn are in the kept tail, + # or: none of them are (they were summarized together). + assert len(survivors_of_target_turn) in (0, 4) + + +def test_untagged_history_anchors_current_turn_on_latest_human() -> None: + """If the input has no dimos_turn tags at all, the fallback treats the + latest HumanMessage as the start of the current turn — its content (and + anything emitted after it) is protected; older messages are compactable. + """ + # Manually-built history with NO turn tags anywhere. 
+ history: list[BaseMessage] = [ + SystemMessage(content="You are a test agent."), + HumanMessage(content="old q " + "x" * 500), + AIMessage(content="old a " + "x" * 500), + HumanMessage(content="old q2 " + "x" * 500), + AIMessage(content="old a2 " + "x" * 500), + HumanMessage(content="LATEST_USER_INPUT_UNIQUE_MARKER"), + AIMessage( + content="", + tool_calls=[{"name": "echo", "args": {"text": "x"}, "id": "c1"}], + ), + ] + fake, received = make_counting_fake(["[summary]"]) + mw = DimosCompactionMiddleware( + summarizer=fake, + threshold_tokens=400, + target_tokens=200, + summary_size_tokens=40, + ) + result = mw.before_model(state(history), runtime=None) + assert result is not None, "should compact: total is over threshold" + + new_history = result["messages"][1:] # skip RemoveMessage sentinel + # The latest HumanMessage is preserved verbatim in the kept tail. + assert any( + isinstance(m, HumanMessage) + and isinstance(m.content, str) + and "LATEST_USER_INPUT_UNIQUE_MARKER" in m.content + for m in new_history + ) + # The trailing AIMessage(tool_call) is also preserved. + assert any( + isinstance(m, AIMessage) + and (getattr(m, "tool_calls", None) or [{}])[0].get("id") == "c1" + for m in new_history + ) + # And the latest HumanMessage's content was NOT sent to the summarizer. + assert not any( + "LATEST_USER_INPUT_UNIQUE_MARKER" in p for p in received + ), "latest human input must not be summarized away" + + +def test_summarize_failure_propagates() -> None: + class BoomFake(FakeListChatModel): + def invoke(self, *args: Any, **kwargs: Any) -> Any: # type: ignore[override] + raise RuntimeError("boom") + + history = build_text_history(n_turns=10, text_per_turn="x" * 200) + mw = DimosCompactionMiddleware( + summarizer=BoomFake(responses=["never used"]), + threshold_tokens=400, + target_tokens=200, + summary_size_tokens=20, + ) + with pytest.raises(RuntimeError, match="boom"): + mw.before_model(state(history), runtime=None) + + +def test_recompaction_folds_prior_summary() -> None: + history = build_text_history(n_turns=10, text_per_turn="x" * 200) + fake, received = make_counting_fake(["[s1]", "[s2]"]) + mw = DimosCompactionMiddleware( + summarizer=fake, + threshold_tokens=400, + target_tokens=200, + summary_size_tokens=20, + ) + r1 = mw.before_model(state(history), runtime=None) + assert r1 is not None + after1 = r1["messages"][1:] + # Add more turns and run again. + n_so_far = max( + ( + (m.additional_kwargs or {}).get("dimos_turn", 0) + for m in after1 + if isinstance((m.additional_kwargs or {}).get("dimos_turn"), int) + ), + default=0, + ) + extended = list(after1) + for i in range(n_so_far + 1, n_so_far + 11): + extended.append(make_human(f"q{i} " + "x" * 200, i)) + extended.append(make_ai(f"a{i} " + "x" * 200, i)) + + r2 = mw.before_model(state(extended), runtime=None) + assert r2 is not None + after2 = r2["messages"][1:] + compacted = [ + m + for m in after2 + if isinstance(m, SystemMessage) and (m.additional_kwargs or {}).get("dimos_compacted") + ] + # Still exactly one compacted summary (the new one rolled the old in). + assert len(compacted) == 1 + # And [s1] was visible to the second summarizer call. + assert any("[s1]" in p for p in received[1:]) + + +class RecordingFakeAgent(FakeMessagesListChatModel): + """Fake agent chat model that records each `.invoke()`'s input messages. + + Subclassing FakeMessagesListChatModel keeps tool_call response support + while letting us inspect what the agent node saw at every step. 
The base + class raises on `bind_tools()`; we no-op it because the fake doesn't + actually need tool schemas — it returns predetermined responses. + """ + + received_inputs: list[list[BaseMessage]] = Field(default_factory=list) + + def bind_tools(self, tools: Any, **kwargs: Any) -> Any: # type: ignore[override] + return self + + def invoke(self, input: Any, *args: Any, **kwargs: Any) -> Any: # type: ignore[override] + if isinstance(input, list): + self.received_inputs.append(list(input)) + return super().invoke(input, *args, **kwargs) + + +@tool +def echo(text: str) -> str: + """Echo back the given text.""" + return text + + +@tool +def get_big_result() -> str: + """Return a chunk of text big enough to push history over a small threshold.""" + return "BIG_RESULT_LINE " + ("x" * 800) # ~270 tokens with the chars/3 placeholder + + +def test_full_loop_compaction_fires_inside_create_agent() -> None: + """Real `create_agent` loop: middleware fires, langgraph reducer honors + `RemoveMessage(REMOVE_ALL_MESSAGES)`, and the final state contains the + summary plus the agent's appended response. + """ + agent_model = RecordingFakeAgent(responses=[AIMessage(content="Acknowledged.")]) + summarizer = CountingFake(responses=["FAKE_SUMMARY"]) + + history: list[BaseMessage] = [SystemMessage(content="You are a test agent.")] + for i in range(1, 9): + history.append(make_human(f"q{i} " + "x" * 150, i)) + history.append(make_ai(f"a{i} " + "x" * 150, i)) + history.append(make_human("now please respond", 9)) + + mw = DimosCompactionMiddleware( + summarizer=summarizer, + threshold_tokens=400, + target_tokens=200, + summary_size_tokens=40, + system_prompt="test agent", + ) + graph: Any = create_agent( + model=agent_model, + tools=[echo], + middleware=[mw], + ) + + result = graph.invoke(cast("Any", {"messages": history})) + final_messages = result["messages"] + + # 1. The agent's response was appended (loop ran to completion). + assert any(isinstance(m, AIMessage) and m.content == "Acknowledged." for m in final_messages) + + # 2. A compaction summary message exists in the final state. + summaries = [ + m + for m in final_messages + if isinstance(m, SystemMessage) and (m.additional_kwargs or {}).get("dimos_compacted") + ] + assert len(summaries) == 1 + assert "FAKE_SUMMARY" in summaries[0].content + + # 3. The summarizer was invoked exactly once (proves compaction actually fired). + assert len(summarizer.received) == 1 + + # 4. The agent node received a *compacted* prompt — early turns are gone + # from what the model saw, summary is present, current turn is intact. + assert len(agent_model.received_inputs) == 1 + prompt_seen = agent_model.received_inputs[0] + contents = " | ".join( + m.content if isinstance(m.content, str) else str(m.content) for m in prompt_seen + ) + assert "q1 " not in contents, "earliest turn should have been summarized away" + assert "FAKE_SUMMARY" in contents, "summary should be in the agent's input" + assert "now please respond" in contents, "current turn must reach the model" + + # 5. Old turns are gone from the final state too (reducer wiped them via + # the REMOVE_ALL_MESSAGES sentinel). + final_contents = " | ".join( + m.content if isinstance(m.content, str) else str(m.content) for m in final_messages + ) + assert "q1 " not in final_contents + assert "q2 " not in final_contents + + +def test_compaction_fires_between_tool_call_and_final_answer() -> None: + """Multi-step turn: model → tool_call → tool result → model again. 
+ + The pre-tool state is under threshold (no compaction on first `before_model`). + The tool returns a chunk big enough that the SECOND `before_model` is over + threshold and must compact. Proves the "fires before every model call" + invariant — the property that motivates doing this as middleware at all. + """ + agent_model = RecordingFakeAgent( + responses=[ + AIMessage( + content="", + tool_calls=[{"name": "get_big_result", "args": {}, "id": "call_x"}], + ), + AIMessage(content="Tool reply received."), + ] + ) + summarizer = CountingFake(responses=["MID_TURN_SUMMARY"]) + + # Pre-load enough older history that we're CLOSE to threshold but under it. + # Adding the tool result will push us over and force compaction on the 2nd call. + history: list[BaseMessage] = [SystemMessage(content="You are a test agent.")] + for i in range(1, 4): + history.append(make_human(f"q{i} " + "x" * 100, i)) + history.append(make_ai(f"a{i} " + "x" * 100, i)) + history.append(make_human("call the tool", 4)) + + mw = DimosCompactionMiddleware( + summarizer=summarizer, + threshold_tokens=350, + target_tokens=180, + summary_size_tokens=40, + system_prompt="test agent", + ) + graph: Any = create_agent( + model=agent_model, + tools=[get_big_result], + middleware=[mw], + ) + + result = graph.invoke(cast("Any", {"messages": history})) + final_messages = result["messages"] + + # 1. Model was called twice (tool_call round-trip + final answer). + assert len(agent_model.received_inputs) == 2 + + # 2. First call: NO compaction yet. The agent's first prompt should still + # contain "q1" because we haven't crossed threshold yet. + first_contents = " | ".join( + m.content if isinstance(m.content, str) else str(m.content) + for m in agent_model.received_inputs[0] + ) + assert "q1 " in first_contents, "first model call should see uncompacted history" + + # 3. Second call: compaction DID fire (after the tool result inflated state). + second_contents = " | ".join( + m.content if isinstance(m.content, str) else str(m.content) + for m in agent_model.received_inputs[1] + ) + assert "MID_TURN_SUMMARY" in second_contents, ( + "second model call should see the summary — compaction must fire between " + "tool result and next model call" + ) + assert "q1 " not in second_contents, "second call should not see compacted-away turn 1" + + # 4. Summarizer was invoked exactly once (the second before_model triggered it). + assert len(summarizer.received) == 1 + + # 5. Final answer was appended. + assert any( + isinstance(m, AIMessage) and m.content == "Tool reply received." for m in final_messages + )