diff --git a/effectful/handlers/llm/completions.py b/effectful/handlers/llm/completions.py index fc6ca47a..e1108772 100644 --- a/effectful/handlers/llm/completions.py +++ b/effectful/handlers/llm/completions.py @@ -6,6 +6,7 @@ import inspect import string import textwrap +import threading import traceback import typing import uuid @@ -25,7 +26,7 @@ ) from effectful.handlers.llm.encoding import DecodedToolCall, Encodable -from effectful.handlers.llm.template import Template, Tool +from effectful.handlers.llm.template import Template, Tool, get_bound_agent from effectful.internals.unification import nested_type from effectful.ops.semantics import fwd, handler from effectful.ops.syntax import ObjectInterpretation, implements @@ -71,6 +72,30 @@ def append_message(message: Message): pass +@Operation.define +def get_agent_history(agent_id: str) -> collections.OrderedDict[str, Message]: + """Get the message history for an agent. Returns empty OrderedDict by default.""" + return collections.OrderedDict() + + +class AgentHistoryHandler(ObjectInterpretation): + """Handler that stores per-agent message histories in memory. 
+ + Install this handler to give :class:`Agent` instances persistent + in-memory histories across template calls:: + + with handler(AgentHistoryHandler()), handler(LiteLLMProvider()): + bot.ask("question") # history accumulates across calls + """ + + def __init__(self) -> None: + self._histories: dict[str, collections.OrderedDict[str, Message]] = {} + + @implements(get_agent_history) + def _get(self, agent_id: str) -> collections.OrderedDict[str, Message]: + return self._histories.setdefault(agent_id, collections.OrderedDict()) + + def _make_message(content: dict) -> Message: m_id = content.get("id") or str(uuid.uuid1()) message = typing.cast(Message, {**content, "id": m_id}) @@ -442,7 +467,11 @@ def _call_tool(self, tool_call: DecodedToolCall) -> Message: class LiteLLMProvider(ObjectInterpretation): - """Implements templates using the LiteLLM API.""" + """Implements templates using the LiteLLM API. + + Also provides per-agent message history storage via + :func:`get_agent_history`. + """ config: collections.abc.Mapping[str, typing.Any] @@ -451,6 +480,19 @@ def __init__(self, model="gpt-4o", **config): "model": model, **inspect.signature(litellm.completion).bind_partial(**config).kwargs, } + self._histories: dict[str, collections.OrderedDict[str, Message]] = {} + self._tls = threading.local() + + def _get_depths(self) -> dict[str, int]: + if not hasattr(self._tls, "depths"): + self._tls.depths = {} + return self._tls.depths + + @implements(get_agent_history) + def _get_agent_history( + self, agent_id: str + ) -> collections.OrderedDict[str, Message]: + return self._histories.setdefault(agent_id, collections.OrderedDict()) @implements(Template.__apply__) def _call[**P, T]( @@ -464,29 +506,50 @@ def _call[**P, T]( # Create response_model with env so tools passed as arguments are available response_model = Encodable.define(template.__signature__.return_annotation, env) - history: collections.OrderedDict[str, Message] = getattr( - template, "__history__", 
collections.OrderedDict() - ) # type: ignore - history_copy = history.copy() + # Get history: from agent history handler if bound to an agent, else fresh + agent = get_bound_agent(template) + if agent is not None: + agent_id = agent.__agent_id__ + history = get_agent_history(agent_id) + else: + agent_id = None + history = collections.OrderedDict() + + # Track nesting depth per agent so only the outermost call writes back. + # Inner calls work on their own copy but discard it on return. + # See: TestNestedTemplateCalling.test_only_outermost_writes_to_history + depths = self._get_depths() + if agent_id is not None: + depth = depths.get(agent_id, 0) + depths[agent_id] = depth + 1 + is_outermost = depth == 0 + else: + depth = 0 + is_outermost = False - with handler({_get_history: lambda: history_copy}): - call_system(template) - - message: Message = call_user(template.__prompt_template__, env) - - # loop based on: https://cookbook.openai.com/examples/reasoning_function_calls - tool_calls: list[DecodedToolCall] = [] - result: T | None = None - while message["role"] != "assistant" or tool_calls: - message, tool_calls, result = call_assistant( - template.tools, response_model, **self.config - ) - for tool_call in tool_calls: - message = call_tool(tool_call) + history_copy = history.copy() try: - _get_history() - except NotImplementedError: - history.clear() - history.update(history_copy) - return typing.cast(T, result) + with handler({_get_history: lambda: history_copy}): + call_system(template) + + message: Message = call_user(template.__prompt_template__, env) + + # loop based on: https://cookbook.openai.com/examples/reasoning_function_calls + tool_calls: list[DecodedToolCall] = [] + result: T | None = None + while message["role"] != "assistant" or tool_calls: + message, tool_calls, result = call_assistant( + template.tools, response_model, **self.config + ) + for tool_call in tool_calls: + message = call_tool(tool_call) + + # Only outermost call writes back to 
canonical history + if is_outermost: + history.clear() + history.update(history_copy) + return typing.cast(T, result) + finally: + if agent_id is not None: + depths[agent_id] = depth diff --git a/effectful/handlers/llm/persistence.py b/effectful/handlers/llm/persistence.py new file mode 100644 index 00000000..1d93ae9a --- /dev/null +++ b/effectful/handlers/llm/persistence.py @@ -0,0 +1,444 @@ +import dataclasses +import json +import sqlite3 +import threading +from collections import OrderedDict +from pathlib import Path +from typing import Any + +from effectful.handlers.llm.completions import get_agent_history +from effectful.handlers.llm.template import Agent, Template, get_bound_agent +from effectful.ops.semantics import fwd +from effectful.ops.syntax import ObjectInterpretation, implements +from effectful.ops.types import NotHandled + + +@Template.define +def summarize_context(transcript: str) -> str: + """Summarise the following conversation transcript into a concise + context summary. Preserve key facts, decisions, and any + information the agent would need to continue working. + + Transcript: + {transcript}""" + raise NotHandled + + +class PersistentAgent(Agent): + """An :class:`Agent` whose history can be persisted by :class:`PersistenceHandler`. + + This is a lightweight marker class. All persistence *behaviour* + (checkpointing, handoff, DB I/O) lives in :class:`PersistenceHandler`, + a composable handler following the same pattern as + :class:`~effectful.handlers.llm.completions.RetryLLMHandler`. + + Unlike plain :class:`Agent` (which uses ``id(self)`` by default), + ``PersistentAgent`` **requires** a stable ``agent_id`` so that + checkpoints can be matched across process restarts. + + Override :meth:`checkpoint_state` and :meth:`restore_state` to persist + custom subclass state alongside the message history. 
+ + **Usage**:: + + from pathlib import Path + from effectful.handlers.llm.persistence import PersistentAgent, PersistenceHandler + from effectful.handlers.llm import Template + from effectful.handlers.llm.completions import LiteLLMProvider + from effectful.ops.semantics import handler + from effectful.ops.types import NotHandled + + class ResearchBot(PersistentAgent): + \"""You are a research assistant that remembers prior sessions.\""" + + @Template.define + def ask(self, question: str) -> str: + \"""Answer: {question}\""" + raise NotHandled + + bot = ResearchBot(agent_id="research-bot") + + with handler(LiteLLMProvider()), handler(PersistenceHandler(Path("./state/checkpoints.db"))): + bot.ask("What is the capital of France?") + # Kill process here, restart, and the bot resumes with context. + """ + + def __init__(self, *, agent_id: str): + self.__agent_id__ = agent_id + + def checkpoint_state(self) -> dict[str, Any]: + """Return a JSON-serialisable dict of subclass state to persist. + + The default implementation serialises all + :func:`dataclasses.dataclass` fields. Override this (and + :meth:`restore_state`) for custom serialisation. + """ + if not dataclasses.is_dataclass(self): + return {} + state: dict[str, Any] = {} + for f in dataclasses.fields(self): + val = getattr(self, f.name) + try: + json.dumps(val) + state[f.name] = val + except (TypeError, ValueError): + pass # skip non-serialisable fields + return state + + def restore_state(self, state: dict[str, Any]) -> None: + """Restore subclass state from *state* dict. + + The default implementation sets each key as an attribute. + Override this (and :meth:`checkpoint_state`) for custom + deserialisation. 
+ """ + for key, value in state.items(): + setattr(self, key, value) + + +def _init_db(conn: sqlite3.Connection) -> None: + """Create the checkpoints table and configure WAL mode for crash tolerance.""" + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.execute( + """ + CREATE TABLE IF NOT EXISTS checkpoints ( + agent_id TEXT PRIMARY KEY, + handoff TEXT NOT NULL DEFAULT '', + state TEXT NOT NULL DEFAULT '{}', + history TEXT NOT NULL DEFAULT '[]' + ) + """ + ) + conn.commit() + + +class PersistenceHandler(ObjectInterpretation): + """Handler that persists :class:`PersistentAgent` history to a SQLite database. + + Install alongside + :class:`~effectful.handlers.llm.completions.LiteLLMProvider`:: + + with handler(LiteLLMProvider()), handler(PersistenceHandler(Path("./state/checkpoints.db"))): + bot.ask("question") + + Uses SQLite WAL mode for crash tolerance. If the process is killed + mid-write, SQLite's journal-based recovery ensures the database + remains consistent. + + All state is read from and written to the database directly — no + in-memory caching. This makes the handler stateless (aside from + nesting depth tracking) and easy to reason about. + + **Automatic checkpointing**: + + - **Before** each top-level template call: saves a checkpoint with a + handoff note describing the in-progress work. + - **After** each successful call: clears the handoff and saves again. + - **On failure**: saves the checkpoint (with handoff) so the next + session can resume. + + **Crash recovery**: on the next run, the handoff note from the prior + crash is injected into the system prompt so the LLM can resume. + + **Nested calls** (e.g. a tool calling another template on the same + agent) are passed through without additional checkpointing. 
+ + Composes with :class:`~effectful.handlers.llm.completions.RetryLLMHandler` + and :class:`CompactionHandler`:: + + with ( + handler(LiteLLMProvider()), + handler(RetryLLMHandler()), + handler(CompactionHandler()), + handler(PersistenceHandler(Path("./state/checkpoints.db"))), + ): + bot.ask("question") + + **Crash recovery example**:: + + from pathlib import Path + from effectful.handlers.llm.persistence import PersistentAgent, PersistenceHandler + from effectful.handlers.llm import Template + from effectful.handlers.llm.completions import LiteLLMProvider + from effectful.ops.semantics import handler + from effectful.ops.types import NotHandled + + class Bot(PersistentAgent): + \"""You are a helpful assistant.\""" + + @Template.define + def work(self, task: str) -> str: + \"""Do: {task}\""" + raise NotHandled + + bot = Bot(agent_id="worker") + persist = PersistenceHandler(Path("./state/checkpoints.db")) + + # Session 1 — process crashes mid-call + with handler(LiteLLMProvider(model="gpt-4o-mini")), handler(persist): + bot.work("step 1") # completes, checkpointed + bot.work("step 2") # process killed here + + # Session 2 — restart with the same db_path + bot2 = Bot(agent_id="worker") + persist2 = PersistenceHandler(Path("./state/checkpoints.db")) + with handler(LiteLLMProvider(model="gpt-4o-mini")), handler(persist2): + # History from session 1 is restored automatically. + # The handoff note "Executing work ..." tells the LLM what + # was in progress when the crash occurred. + bot2.work("step 2") # resumes with full context + + Use :meth:`save` for manual checkpointing outside the automatic flow + (e.g. after initialising agent state in a choreography). + + Args: + db_path: Path to the SQLite database file. 
+ """ + + def __init__(self, db_path: Path) -> None: + self._db_path = Path(db_path) + self._tls = threading.local() + self._db_lock = threading.Lock() + self._db_initialized = False + + def _connect(self) -> sqlite3.Connection: + """Open a new SQLite connection to the checkpoint database. + + Each call returns a fresh connection, making it safe to use from + any thread. WAL mode and table creation are applied once on the + first call (guarded by ``_db_lock``). + """ + conn = sqlite3.connect(str(self._db_path)) + conn.execute("PRAGMA busy_timeout=5000") + if not self._db_initialized: + with self._db_lock: + if not self._db_initialized: + _init_db(conn) + self._db_initialized = True + return conn + + @property + def db_path(self) -> Path: + """Path to the SQLite database file.""" + return self._db_path + + def _get_depths(self) -> dict[str, int]: + if not hasattr(self._tls, "depths"): + self._tls.depths = {} + return self._tls.depths + + def _load_row(self, agent_id: str) -> tuple[str, str, str] | None: + """Read a checkpoint row from the database. + + Returns ``(handoff, state_json, history_json)`` or ``None``. + """ + conn = self._connect() + try: + row = conn.execute( + "SELECT handoff, state, history FROM checkpoints WHERE agent_id = ?", + (agent_id,), + ).fetchone() + finally: + conn.close() + return row + + def _ensure_loaded(self, agent: PersistentAgent) -> bool: + """Load an agent's checkpoint from the database into the in-process history. + + Safe to call multiple times — only loads once per agent (tracked + via thread-local ``_loaded`` set to avoid re-seeding history that + is already live in memory). + + Returns ``True`` if a checkpoint was found and loaded. 
+ """ + agent_id = agent.__agent_id__ + loaded = self._get_loaded() + if agent_id in loaded: + return False + loaded.add(agent_id) + + row = self._load_row(agent_id) + if row is None: + return False + + _handoff, state_json, history_json = row + agent.restore_state(json.loads(state_json)) + stored = get_agent_history(agent_id) + stored.clear() + stored.update({msg["id"]: msg for msg in json.loads(history_json)}) + return True + + def _get_loaded(self) -> set[str]: + if not hasattr(self._tls, "loaded"): + self._tls.loaded = set() + return self._tls.loaded + + def save(self, agent: PersistentAgent, handoff: str = "") -> Path: + """Write an agent's current state to the database and return the db path.""" + agent_id = agent.__agent_id__ + history = get_agent_history(agent_id) + state_json = json.dumps(agent.checkpoint_state(), default=str) + history_json = json.dumps(list(history.values()), default=str) + + conn = self._connect() + try: + conn.execute( + """ + INSERT INTO checkpoints (agent_id, handoff, state, history) + VALUES (?, ?, ?, ?) 
+ ON CONFLICT(agent_id) DO UPDATE SET + handoff = excluded.handoff, + state = excluded.state, + history = excluded.history + """, + (agent_id, handoff, state_json, history_json), + ) + conn.commit() + finally: + conn.close() + + return self.db_path + + def _get_handoff(self, agent_id: str) -> str: + """Return the current handoff note for *agent_id* (reads from DB).""" + row = self._load_row(agent_id) + if row is None: + return "" + return row[0] + + @implements(Template.__apply__) + def _call[**P, T]( + self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs + ) -> T: + agent = get_bound_agent(template) + if not isinstance(agent, PersistentAgent): + return fwd(template, *args, **kwargs) + + agent_id = agent.__agent_id__ + self._ensure_loaded(agent) + + # Nesting: only checkpoint for outermost call per agent + depths = self._get_depths() + depth = depths.get(agent_id, 0) + depths[agent_id] = depth + 1 + is_outermost = depth == 0 + + try: + if is_outermost: + # Inject prior-session handoff into system prompt + prior_handoff = self._get_handoff(agent_id) + if prior_handoff: + template.__system_prompt__ = ( + f"{template.__system_prompt__}\n\n" + f"[HANDOFF FROM PRIOR SESSION] {prior_handoff}" + ) + + # Record current call as handoff for crash recovery + current_handoff = ( + f"Executing {template.__name__} with args={repr(args)[:200]}" + ) + self.save(agent, handoff=current_handoff) + + result = fwd(template, *args, **kwargs) + + if is_outermost: + self.save(agent, handoff="") + + return result + except BaseException: + if is_outermost: + # Preserve handoff so next session knows what was in progress + self.save(agent, handoff=current_handoff) + raise + finally: + depths[agent_id] = depth + + +class CompactionHandler(ObjectInterpretation): + """Handler that compacts agent history when it exceeds a threshold. 
+ + After each top-level template call on an :class:`Agent`, if the + message history exceeds ``max_history_len``, older messages are + summarised into a single context-summary message via an LLM call.:: + + with handler(LiteLLMProvider()), handler(CompactionHandler(max_history_len=20)): + agent.ask("question") # history auto-compacted after call + """ + + def __init__(self, max_history_len: int = 50) -> None: + self._max_history_len = max_history_len + self._tls = threading.local() + + def _get_depths(self) -> dict[str, int]: + if not hasattr(self._tls, "depths"): + self._tls.depths = {} + return self._tls.depths + + @implements(Template.__apply__) + def _call[**P, T]( + self, template: Template[P, T], *args: P.args, **kwargs: P.kwargs + ) -> T: + agent = get_bound_agent(template) + if not isinstance(agent, Agent): + return fwd(template, *args, **kwargs) + + agent_id = agent.__agent_id__ + depths = self._get_depths() + depth = depths.get(agent_id, 0) + depths[agent_id] = depth + 1 + is_outermost = depth == 0 + + try: + result = fwd(template, *args, **kwargs) + + if is_outermost: + history = get_agent_history(agent_id) + if len(history) > self._max_history_len: + self._compact(agent_id, history) + + return result + finally: + depths[agent_id] = depth + + def _compact(self, agent_id: str, history: OrderedDict[str, Any]) -> None: + keep_recent = max(self._max_history_len // 2, 4) + items = list(history.items()) + if len(items) <= keep_recent: + return + + split = len(items) - keep_recent + # Never split between a tool_use and its tool_result(s). 
+ while split > 0 and items[split][1].get("role") == "tool": + split -= 1 + if split <= 0: + return + + old_items = items[:split] + recent_items = items[split:] + + old_text_parts: list[str] = [] + for _, msg in old_items: + role = msg.get("role", "unknown") + content = msg.get("content", "") + if isinstance(content, list): + text_parts = [p.get("text", "") for p in content if isinstance(p, dict)] + content = " ".join(text_parts) + if content: + old_text_parts.append(f"[{role}]: {content}") + old_transcript = "\n".join(old_text_parts) + + if not old_transcript.strip(): + return + + summary = summarize_context(old_transcript) + + summary_msg: dict[str, Any] = { + "id": f"compaction-{agent_id}", + "role": "user", + "content": f"[CONTEXT SUMMARY FROM PRIOR CONVERSATION]\n{summary}", + } + history.clear() + history[summary_msg["id"]] = summary_msg + for key, msg in recent_items: + history[key] = msg diff --git a/effectful/handlers/llm/template.py b/effectful/handlers/llm/template.py index 93e7f085..3bb418d5 100644 --- a/effectful/handlers/llm/template.py +++ b/effectful/handlers/llm/template.py @@ -5,7 +5,7 @@ import string import types import typing -from collections import ChainMap, OrderedDict +from collections import ChainMap from collections.abc import Callable, Mapping, MutableMapping from typing import Annotated, Any @@ -257,8 +257,6 @@ def __get__[S](self, instance: S | None, owner: type[S] | None = None): self_param_name = list(self.__signature__.parameters.keys())[0] result.__context__ = self.__context__.new_child({self_param_name: instance}) if isinstance(instance, Agent): - assert isinstance(result, Template) and not hasattr(result, "__history__") - result.__history__ = instance.__history__ # type: ignore[attr-defined] result.__system_prompt__ = "\n\n".join( part for part in ( @@ -378,18 +376,31 @@ def send(self, user_input: str) -> str: """ - __history__: OrderedDict[str, Mapping[str, Any]] __system_prompt__: str + @functools.cached_property + def 
__agent_id__(self) -> str: + return str(id(self)) + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - if not hasattr(cls, "__history__"): - prop = functools.cached_property(lambda _: OrderedDict()) - prop.__set_name__(cls, "__history__") - cls.__history__ = prop if not hasattr(cls, "__system_prompt__"): sp = functools.cached_property( lambda self: inspect.getdoc(type(self)) or "" ) sp.__set_name__(cls, "__system_prompt__") cls.__system_prompt__ = sp + + +def get_bound_agent(template: Template) -> "Agent | None": + """Extract the bound :class:`Agent` instance from a template, if any. + + Bound method templates have a first context map with exactly one entry + (``{self_param_name: instance}``), while standalone templates have a + larger map (module globals). + """ + ctx = getattr(template, "__context__", None) + if ctx is None or not ctx.maps or len(ctx.maps[0]) != 1: + return None + val = next(iter(ctx.maps[0].values())) + return val if isinstance(val, Agent) else None diff --git a/tests/test_handlers_llm_encoding.py b/tests/test_handlers_llm_encoding.py index 9cccf93d..adadee23 100644 --- a/tests/test_handlers_llm_encoding.py +++ b/tests/test_handlers_llm_encoding.py @@ -686,7 +686,6 @@ def test_callable_encode_non_callable(): def test_callable_encode_no_source_no_docstring(): - class _NoDocCallable: __name__ = "nodoc" __doc__ = None diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index eec2fc78..f3d2af9c 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -10,6 +10,7 @@ import json import os import re +import sqlite3 from collections.abc import Callable from enum import StrEnum from pathlib import Path @@ -37,9 +38,15 @@ call_assistant, call_tool, completion, + get_agent_history, ) from effectful.handlers.llm.encoding import Encodable, SynthesizedFunction from effectful.handlers.llm.evaluation import UnsafeEvalProvider +from effectful.handlers.llm.persistence 
import ( + CompactionHandler, + PersistenceHandler, + PersistentAgent, +) from effectful.ops.semantics import fwd, handler from effectful.ops.syntax import ObjectInterpretation, implements from effectful.ops.types import NotHandled @@ -1797,17 +1804,17 @@ def simple_task(self, instruction: str) -> str: agent = SimpleAgent() - # No outer _get_history handler: LiteLLMProvider._call detects this is the - # outermost template and writes back to the agent's __history__. + provider = LiteLLMProvider(model="test") with ( - handler(LiteLLMProvider(model="test")), + handler(provider), handler(mock_handler), ): result = agent.simple_task("go") assert result == "done" - # Agent's __history__ should have messages written back (system + user + assistant) - assert len(agent.__history__) >= 2 + # Agent's history should have messages written back (system + user + assistant) + history = provider._histories.get(agent.__agent_id__, {}) + assert len(history) >= 2 class TestAgentCrossTemplateRecovery: @@ -1862,8 +1869,9 @@ def _completion(self, model, messages=None, **kwargs): agent = TestAgent() - with handler(TwoPhaseCompletionHandler()): - with handler(LiteLLMProvider(model="test")): + provider = LiteLLMProvider(model="test") + with handler(provider): + with handler(TwoPhaseCompletionHandler()): # First call should fail with tool execution error with pytest.raises(ToolCallExecutionError): agent.step_with_tool("stage 1") @@ -1874,7 +1882,7 @@ def _completion(self, model, messages=None, **kwargs): assert result == "summary result" # Verify history doesn't contain messages from the failed call - history = agent.__history__ + history = provider._histories.get(agent.__agent_id__, {}) for msg in history.values(): tool_calls = msg.get("tool_calls") if tool_calls: @@ -1916,12 +1924,14 @@ def do_work(self, task: str) -> str: mock = MockCompletionHandler(responses) agent = CleanupAgent() + provider = LiteLLMProvider(model="test") with pytest.raises(ToolCallExecutionError): - with 
handler(LiteLLMProvider(model="test")), handler(mock): + with handler(provider), handler(mock): agent.do_work("go") # Agent history should be empty — all messages from failed call pruned - assert len(agent.__history__) == 0 + history = provider._histories.get(agent.__agent_id__, {}) + assert len(history) == 0 def test_agent_history_preserved_for_successful_calls(self): """Successful calls should leave messages in agent history.""" @@ -1943,12 +1953,14 @@ def greet(self, name: str) -> str: mock = MockCompletionHandler(responses) agent = SuccessAgent() - with handler(LiteLLMProvider(model="test")), handler(mock): + provider = LiteLLMProvider(model="test") + with handler(provider), handler(mock): result = agent.greet("world") assert result == "Hello!" # History should contain messages from the successful call - assert len(agent.__history__) >= 2 # user + assistant at minimum + history = provider._histories.get(agent.__agent_id__, {}) + assert len(history) >= 2 # user + assistant at minimum def test_agent_multiple_successful_calls_accumulate_history(self): """Multiple successful calls should accumulate in agent history.""" @@ -1977,14 +1989,16 @@ def _completion(self, model, messages=None, **kwargs): agent = ChatAgent() - with handler(LiteLLMProvider(model="test")), handler(MultiResponseHandler()): + provider = LiteLLMProvider(model="test") + with handler(provider), handler(MultiResponseHandler()): r1 = agent.chat("first") r2 = agent.chat("second") assert r1 == "reply 1" assert r2 == "reply 2" # History should have messages from both calls - assert len(agent.__history__) >= 4 # 2 * (user + assistant) + history = provider._histories.get(agent.__agent_id__, {}) + assert len(history) >= 4 # 2 * (user + assistant) def test_agent_error_then_success_accumulates_only_success(self): """After a failed call, only the subsequent successful call's messages remain.""" @@ -2025,19 +2039,21 @@ def _completion(self, model, messages=None, **kwargs): agent = RecoveryAgent() - with 
handler(LiteLLMProvider(model="test")), handler(PhaseHandler()): + provider = LiteLLMProvider(model="test") + with handler(provider), handler(PhaseHandler()): with pytest.raises(ToolCallExecutionError): agent.risky("step 1") - history_after_error = len(agent.__history__) + history_after_error = len(get_agent_history(agent.__agent_id__)) assert history_after_error == 0 result = agent.safe("step 2") assert result == "safe result" # Only messages from the successful call should be in history - assert len(agent.__history__) >= 2 - assert len(agent.__history__) > history_after_error + history = provider._histories.get(agent.__agent_id__, {}) + assert len(history) >= 2 + assert len(history) > history_after_error class TestAgentSystemMessageDeduplication: @@ -2109,13 +2125,15 @@ def _completion(self, model, messages=None, **kwargs): agent = SystemMsgAgent() - with handler(LiteLLMProvider(model="test")), handler(MultiHandler()): + provider = LiteLLMProvider(model="test") + with handler(provider), handler(MultiHandler()): agent.do("a") agent.do("b") agent.do("c") agent.do("d") - system_msgs = [m for m in agent.__history__.values() if m["role"] == "system"] + history = provider._histories.get(agent.__agent_id__, {}) + system_msgs = [m for m in history.values() if m["role"] == "system"] assert len(system_msgs) == 1, ( f"Expected exactly 1 system message, got {len(system_msgs)}" ) @@ -2151,14 +2169,16 @@ def _completion(self, model, messages=None, **kwargs): agent = MemoryAgent() - with handler(LiteLLMProvider(model="test")), handler(MemoryHandler()): + provider = LiteLLMProvider(model="test") + with handler(provider), handler(MemoryHandler()): agent.chat("first") agent.chat("second") agent.chat("third") # History should have: 1 system + 3 user + 3 assistant = 7 - assert len(agent.__history__) == 7 - roles = [m["role"] for m in agent.__history__.values()] + history = provider._histories.get(agent.__agent_id__, {}) + assert len(history) == 7 + roles = [m["role"] for m in 
history.values()] assert roles.count("system") == 1 assert roles.count("user") == 3 assert roles.count("assistant") == 3 @@ -2187,12 +2207,400 @@ def _completion(self, model, messages=None, **kwargs): agent = OrderAgent() - with handler(LiteLLMProvider(model="test")), handler(OrderHandler()): + provider = LiteLLMProvider(model="test") + with handler(provider), handler(OrderHandler()): agent.step(1) agent.step(2) agent.step(3) - messages = list(agent.__history__.values()) + history = provider._histories.get(agent.__agent_id__, {}) + messages = list(history.values()) assert messages[0]["role"] == "system", ( "System message should be the first message in history" ) + + +# --------------------------------------------------------------------------- +# Integration tests: Agent & PersistentAgent with real LLM +# --------------------------------------------------------------------------- + + +class _PlainHelper(Agent): + """You are a concise helper. Reply with at most 10 words.""" + + @Template.define + def answer(self, q: str) -> str: + """Answer concisely: {q}""" + raise NotHandled + + +class _PersistentOrchestrator(Agent): + """You are an orchestrator. Use `ask_helper` to get answers. + + Reply with the helper's answer verbatim — do NOT call the tool more + than once. + """ + + def __init__(self, helper: _PlainHelper): + self._helper = helper + + @Tool.define + def ask_helper(self, question: str) -> str: + """Ask the helper agent a question. Call this exactly once.""" + return self._helper.answer(question) + + @Template.define + def orchestrate(self, task: str) -> str: + """Task: {task}. 
Call `ask_helper` once, then return the answer.""" + raise NotHandled + + +@requires_openai +def test_plain_agent_simple_call_integration(): + """Plain Agent makes a single LLM call and returns a string.""" + helper = _PlainHelper() + with ( + handler(LiteLLMProvider(model="gpt-4o-mini", max_tokens=30)), + handler(LimitLLMCallsHandler(max_calls=1)), + ): + result = helper.answer("What is 2+2?") + + assert isinstance(result, str) + assert len(result) > 0 + + +@requires_openai +def test_agent_nested_tool_call_integration(): + """An Agent delegates to another Agent via a tool call. + + The orchestrator calls ask_helper (one tool round-trip) then returns. + LimitLLMCallsHandler caps total LLM calls at 4 to prevent runaway + recursion. + """ + helper = _PlainHelper() + orchestrator = _PersistentOrchestrator(helper) + + with ( + handler(LiteLLMProvider(model="gpt-4o-mini", max_tokens=60)), + handler(LimitLLMCallsHandler(max_calls=4)), + ): + result = orchestrator.orchestrate("What is the capital of France?") + + assert isinstance(result, str) + assert len(result) > 0 + + +@requires_openai +def test_persistent_agent_with_persistence_integration(tmp_path): + """PersistentAgent checkpoints to disk after a real LLM call.""" + + class Bot(PersistentAgent): + """You are a concise assistant. 
Reply in at most 10 words.""" + + @Template.define + def ask(self, q: str) -> str: + """Answer: {q}""" + raise NotHandled + + bot = Bot(agent_id="integration-bot") + persist = PersistenceHandler(tmp_path / "checkpoints.db") + + with ( + handler(LiteLLMProvider(model="gpt-4o-mini", max_tokens=30)), + handler(LimitLLMCallsHandler(max_calls=1)), + handler(persist), + ): + result = bot.ask("Say hello") + + assert isinstance(result, str) + db_path = tmp_path / "checkpoints.db" + assert db_path.exists() + conn = sqlite3.connect(str(db_path)) + row = conn.execute( + "SELECT handoff, history FROM checkpoints WHERE agent_id = ?", + ("integration-bot",), + ).fetchone() + conn.close() + assert row is not None + assert len(json.loads(row[1])) > 0 + assert row[0] == "" + + +@requires_openai +def test_persistent_and_plain_agent_cooperate_integration(tmp_path): + """A plain Agent and PersistentAgent work together via tool delegation. + + The persistent orchestrator calls ask_helper (a plain Agent) via a tool. + LimitLLMCallsHandler caps calls at 5 to prevent runaway tool loops. + """ + + class Orchestrator(PersistentAgent): + """You orchestrate tasks. Use `ask_helper` exactly once, then + return the helper's answer verbatim. Do NOT call tools more than once. + """ + + def __init__(self, helper_agent: _PlainHelper, **kwargs): + super().__init__(**kwargs) + self._helper = helper_agent + + @Tool.define + def ask_helper(self, question: str) -> str: + """Ask the helper a question. Call exactly once.""" + return self._helper.answer(question) + + @Template.define + def run(self, task: str) -> str: + """Task: {task}. 
Use `ask_helper` once, then return the result.""" + raise NotHandled + + helper = _PlainHelper() + orch = Orchestrator(helper, agent_id="orch") + persist = PersistenceHandler(tmp_path / "checkpoints.db") + + with ( + handler(LiteLLMProvider(model="gpt-4o-mini", max_tokens=60)), + handler(LimitLLMCallsHandler(max_calls=4)), + handler(persist), + ): + result = orch.run("What is 3 * 7?") + + assert isinstance(result, str) + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + row = conn.execute( + "SELECT 1 FROM checkpoints WHERE agent_id = ?", ("orch",) + ).fetchone() + conn.close() + assert row is not None + + +@requires_openai +def test_compaction_after_multiple_calls_integration(): + """History is compacted after enough calls exceed the threshold.""" + helper = _PlainHelper() + provider = LiteLLMProvider(model="gpt-4o-mini", max_tokens=30) + + with ( + handler(provider), + handler(LimitLLMCallsHandler(max_calls=4)), + handler(CompactionHandler(max_history_len=4)), + ): + # Each call adds ~2 msgs (user + assistant). + # After the 2nd call history exceeds 4 msgs, triggering compaction. + for i in range(2): + helper.answer(f"What is {i} + 1?") + + history = provider._histories.get(helper.__agent_id__, {}) + first_msg = next(iter(history.values())) + assert "CONTEXT SUMMARY" in first_msg["content"] + + +@requires_openai +def test_compaction_with_persistence_integration(tmp_path): + """Compaction and persistence compose: compacted history is checkpointed.""" + + class Bot(PersistentAgent): + """You are a concise assistant. 
Reply in at most 10 words.""" + + @Template.define + def ask(self, q: str) -> str: + """Answer: {q}""" + raise NotHandled + + bot = Bot(agent_id="compact-bot") + provider = LiteLLMProvider(model="gpt-4o-mini", max_tokens=30) + persist = PersistenceHandler(tmp_path / "checkpoints.db") + + with ( + handler(provider), + handler(LimitLLMCallsHandler(max_calls=4)), + handler(CompactionHandler(max_history_len=4)), + handler(persist), + ): + for i in range(2): + bot.ask(f"What is {i} + 1?") + + # Compacted history should be persisted to disk + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + row = conn.execute( + "SELECT history FROM checkpoints WHERE agent_id = ?", ("compact-bot",) + ).fetchone() + conn.close() + history = json.loads(row[0]) + first_msg = history[0] + assert "CONTEXT SUMMARY" in first_msg["content"] + + +class _ToolAgent(Agent): + """You are a concise assistant. Answer in at most 10 words.""" + + @Tool.define + def lookup(self, query: str) -> str: + """Look up factual information about a topic.""" + return f"Result: The answer to '{query}' is 42." + + @Template.define + def ask(self, question: str) -> str: + """Answer: {question}. You MUST use the lookup tool first.""" + raise NotHandled + + +@pytest.mark.parametrize( + "model", + [ + pytest.param("gpt-4o-mini", marks=requires_openai), + pytest.param("claude-haiku-4-5-20251001", marks=requires_anthropic), + ], +) +def test_compaction_with_tool_calls_does_not_break_api(model): + """After compaction of history containing tool pairs, subsequent calls succeed. + + Each tool-using call generates ~4 messages (user, assistant/tool_use, + tool/result, assistant/final). With max_history_len=4, compaction fires + after the 1st call. The 2nd call must succeed — if compaction orphaned + a tool_result the API would reject the conversation. 
+ """ + bot = _ToolAgent() + provider = LiteLLMProvider(model=model, max_tokens=60) + + with ( + handler(provider), + handler(RetryLLMHandler(stop=tenacity.stop_after_attempt(2))), + handler(LimitLLMCallsHandler(max_calls=8)), + handler(CompactionHandler(max_history_len=4)), + ): + bot.ask("What is the meaning of life?") + # Compaction should have fired. This call must not fail. + result = bot.ask("Summarize what you told me.") + + assert isinstance(result, str) + assert len(result) > 0 + + # Verify no orphaned tool_result messages in final history. + history = provider._histories.get(bot.__agent_id__, {}) + tool_use_ids: set[str] = set() + for msg in history.values(): + # Anthropic format: tool_use blocks in content + content = msg.get("content", "") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_use": + tool_use_ids.add(block["id"]) + # OpenAI format: tool_calls field on assistant messages + for tc in msg.get("tool_calls") or []: + tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) + if tc_id: + tool_use_ids.add(tc_id) + + for msg in history.values(): + if msg.get("role") == "tool": + tc_id = msg.get("tool_call_id", "") + assert tc_id in tool_use_ids, ( + f"Orphaned tool_result with tool_call_id={tc_id!r}" + ) + + +# --------------------------------------------------------------------------- +# Integration tests: SQLite persistence +# --------------------------------------------------------------------------- + + +@requires_openai +def test_sqlite_persistence_crash_recovery_integration(tmp_path): + """After a simulated crash, a new session resumes from the SQLite checkpoint.""" + + class Bot(PersistentAgent): + """You are a concise assistant. 
Reply in at most 10 words.""" + + @Template.define + def ask(self, q: str) -> str: + """Answer: {q}""" + raise NotHandled + + # Session 1: successful call + bot = Bot(agent_id="crash-test-bot") + persist = PersistenceHandler(tmp_path / "checkpoints.db") + + with ( + handler(LiteLLMProvider(model="gpt-4o-mini", max_tokens=30)), + handler(LimitLLMCallsHandler(max_calls=1)), + handler(persist), + ): + result1 = bot.ask("What is 2+2?") + + assert isinstance(result1, str) + + # Verify SQLite DB exists and has data + db_path = tmp_path / "checkpoints.db" + assert db_path.exists() + conn = sqlite3.connect(str(db_path)) + result = conn.execute("PRAGMA integrity_check").fetchone()[0] + assert result == "ok" + row = conn.execute( + "SELECT handoff, history FROM checkpoints WHERE agent_id = ?", + ("crash-test-bot",), + ).fetchone() + conn.close() + assert row is not None + assert row[0] == "" # handoff cleared after success + assert len(json.loads(row[1])) > 0 + + # Session 2: new process loads from SQLite and continues + bot2 = Bot(agent_id="crash-test-bot") + + with ( + handler(LiteLLMProvider(model="gpt-4o-mini", max_tokens=30)), + handler(LimitLLMCallsHandler(max_calls=1)), + handler(PersistenceHandler(tmp_path / "checkpoints.db")), + ): + result2 = bot2.ask("What did I just ask?") + + assert isinstance(result2, str) + + # History should have grown (session 2 sees session 1 messages) + conn = sqlite3.connect(str(db_path)) + row = conn.execute( + "SELECT history FROM checkpoints WHERE agent_id = ?", + ("crash-test-bot",), + ).fetchone() + conn.close() + history = json.loads(row[0]) + assert len(history) > 3 # at least system + user + assistant from each session + + +@requires_openai +def test_sqlite_persistence_db_integrity_after_compaction_integration(tmp_path): + """Compacted history is correctly persisted to SQLite.""" + + class Bot(PersistentAgent): + """You are a concise assistant. 
Reply in at most 10 words.""" + + @Template.define + def ask(self, q: str) -> str: + """Answer: {q}""" + raise NotHandled + + bot = Bot(agent_id="compact-sqlite-bot") + provider = LiteLLMProvider(model="gpt-4o-mini", max_tokens=30) + persist = PersistenceHandler(tmp_path / "checkpoints.db") + + with ( + handler(provider), + handler(LimitLLMCallsHandler(max_calls=4)), + handler(CompactionHandler(max_history_len=4)), + handler(persist), + ): + for i in range(2): + bot.ask(f"What is {i} + 1?") + + # Verify SQLite DB integrity + conn = sqlite3.connect(str(tmp_path / "checkpoints.db")) + result = conn.execute("PRAGMA integrity_check").fetchone()[0] + assert result == "ok" + row = conn.execute( + "SELECT history FROM checkpoints WHERE agent_id = ?", + ("compact-sqlite-bot",), + ).fetchone() + conn.close() + history = json.loads(row[0]) + first_msg = history[0] + assert "CONTEXT SUMMARY" in first_msg["content"] diff --git a/tests/test_handlers_llm_template.py b/tests/test_handlers_llm_template.py index 7c3bd5bc..cff3da06 100644 --- a/tests/test_handlers_llm_template.py +++ b/tests/test_handlers_llm_template.py @@ -11,12 +11,15 @@ from effectful.handlers.llm import Agent, Template, Tool from effectful.handlers.llm.completions import ( DEFAULT_SYSTEM_PROMPT, + AgentHistoryHandler, LiteLLMProvider, RetryLLMHandler, + call_assistant, call_user, completion, + get_agent_history, ) -from effectful.ops.semantics import handler +from effectful.ops.semantics import fwd, handler from effectful.ops.syntax import ObjectInterpretation, implements from effectful.ops.types import NotHandled @@ -281,10 +284,10 @@ def test_history_contains_all_messages_after_two_calls(self): bot.send("a") bot.send("b") - # After two complete calls the history should have: - # call 1: system, user, assistant (3) - # call 2: system, user, assistant (3) - assert len(bot.__history__) >= 4 + # After two complete calls the history should have: + # call 1: system, user, assistant (3) + # call 2: system, user, 
assistant (3) + assert len(get_agent_history(bot.__agent_id__)) >= 4 def test_message_ids_are_unique(self): mock = MockCompletionHandler( @@ -299,8 +302,8 @@ def test_message_ids_are_unique(self): bot.send("a") bot.send("b") - ids = list(bot.__history__.keys()) - assert len(ids) == len(set(ids)), "message IDs must be unique" + ids = list(get_agent_history(bot.__agent_id__).keys()) + assert len(ids) == len(set(ids)), "message IDs must be unique" class TestAgentIsolation: @@ -314,20 +317,24 @@ def test_two_agents_have_independent_histories(self): ] ) bot1 = ChatBot() + bot1.__agent_id__ = "bot1" bot2 = ChatBot() + bot2.__agent_id__ = "bot2" with handler(LiteLLMProvider()), handler(mock): bot1.send("msg for bot1") bot2.send("msg for bot2") - # bot2's call should NOT contain bot1's messages — only system + user - assert len(mock.received_messages[1]) == len(mock.received_messages[0]) + # bot2's call should NOT contain bot1's messages — only system + user + assert len(mock.received_messages[1]) == len(mock.received_messages[0]) - # Each bot made exactly one call, so their histories should be equal in size - assert len(bot1.__history__) == len(bot2.__history__) + # Each bot made exactly one call, so their histories should be equal in size + h1 = get_agent_history(bot1.__agent_id__) + h2 = get_agent_history(bot2.__agent_id__) + assert len(h1) == len(h2) - # Histories share no message IDs - assert set(bot1.__history__.keys()).isdisjoint(set(bot2.__history__.keys())) + # Histories share no message IDs + assert set(h1.keys()).isdisjoint(set(h2.keys())) def test_non_agent_template_gets_fresh_sequence(self): @Template.define @@ -494,13 +501,13 @@ class ValidDocAgent(Agent): ) -class TestAgentCachedProperty: - """__history__ is lazily created per instance without requiring __init__.""" +class TestAgentHistoryViaHandler: + """History is managed via AgentHistoryHandler, not as a cached property.""" - def test_no_init_required(self): + def test_history_defaults_to_empty(self): 
class MinimalAgent(Agent): - """You are a minimal cached-property test agent. - Your goal is to expose lazily initialized Agent state. + """You are a minimal history-handler test agent. + Your goal is to expose handler-managed Agent history. """ @Template.define @@ -509,9 +516,10 @@ def greet(self, name: str) -> str: raise NotHandled agent = MinimalAgent() - # Should be an OrderedDict, created on first access - assert isinstance(agent.__history__, collections.OrderedDict) - assert len(agent.__history__) == 0 + with handler(AgentHistoryHandler()): + history = get_agent_history(agent.__agent_id__) + assert isinstance(history, collections.OrderedDict) + assert len(history) == 0 def test_subclass_with_own_init(self): class CustomAgent(Agent): @@ -529,13 +537,20 @@ def greet(self) -> str: agent = CustomAgent("Alice") assert agent.name == "Alice" - assert isinstance(agent.__history__, collections.OrderedDict) + with handler(AgentHistoryHandler()): + assert isinstance( + get_agent_history(agent.__agent_id__), collections.OrderedDict + ) def test_history_is_per_instance(self): a = ChatBot() + a.__agent_id__ = "a" b = ChatBot() - a.__history__["fake"] = {"id": "fake", "role": "user", "content": "x"} - assert "fake" not in b.__history__ + b.__agent_id__ = "b" + with handler(AgentHistoryHandler()): + hist_a = get_agent_history(a.__agent_id__) + hist_a["fake"] = {"id": "fake", "role": "user", "content": "x"} + assert "fake" not in get_agent_history(b.__agent_id__) class TestAgentWithToolCalls: @@ -570,13 +585,13 @@ def compute(self, question: str) -> str: with handler(LiteLLMProvider()), handler(mock): result = agent.compute("what is 2+3?") - assert result == "The answer is 5" + assert result == "The answer is 5" - # History should contain: system, user, assistant (tool_call), - # tool (result), assistant (final) - roles = [m["role"] for m in agent.__history__.values()] - assert "tool" in roles - assert roles.count("assistant") == 2 + # History should contain: system, user, 
assistant (tool_call), + # tool (result), assistant (final) + roles = [m["role"] for m in get_agent_history(agent.__agent_id__).values()] + assert "tool" in roles + assert roles.count("assistant") == 2 class TestAgentWithRetryHandler: @@ -611,13 +626,13 @@ def pick_number(self) -> int: ): result = agent.pick_number() - assert result == 42 + assert result == 42 - # The malformed assistant message and error feedback from the retry - # should NOT appear in the agent's history. Only the final successful - # assistant message should be there. - roles = {m["role"] for m in agent.__history__.values()} - assert {"user", "assistant"} == roles - {"system"} + # The malformed assistant message and error feedback from the retry + # should NOT appear in the agent's history. Only the final successful + # assistant message should be there. + roles = {m["role"] for m in get_agent_history(agent.__agent_id__).values()} + assert {"user", "assistant"} == roles - {"system"} class TestNestedTemplateCalling: @@ -626,7 +641,7 @@ class TestNestedTemplateCalling: When a Template triggers a tool call whose implementation invokes another Template on the same Agent, the inner call must: - work on a fresh copy of the agent's history - - NOT write its messages back to agent.__history__ + - NOT write its messages back to the agent's history - return its result correctly so the outer template can continue """ @@ -647,7 +662,7 @@ def test_same_agent_nested_template_via_tool(self): assert result == "all good" def test_only_outermost_writes_to_history(self): - """Inner template's messages are absent from agent.__history__.""" + """Inner template's messages are absent from agent history.""" mock = MockCompletionHandler( [ make_tool_call_response("self__nested_tool", '{"payload": "demo"}'), @@ -660,14 +675,14 @@ def test_only_outermost_writes_to_history(self): with handler(LiteLLMProvider()), handler(mock): agent.outer("demo") - roles = [m["role"] for m in agent.__history__.values()] - # Outer call 
produces: user, assistant(tool_call), tool, assistant(final) - # Inner call's user + assistant are NOT written back - assert set(roles) <= {"system", "user", "assistant", "tool"} - assert roles.count("system") == 1 - assert roles.count("user") == 1 - assert roles.count("assistant") == 2 # tool_call + final - assert roles.count("tool") == 1 + roles = [m["role"] for m in get_agent_history(agent.__agent_id__).values()] + # Outer call produces: user, assistant(tool_call), tool, assistant(final) + # Inner call's user + assistant are NOT written back + assert set(roles) <= {"system", "user", "assistant", "tool"} + assert roles.count("system") == 1 + assert roles.count("user") == 1 + assert roles.count("assistant") == 2 # tool_call + final + assert roles.count("tool") == 1 def test_inner_template_gets_fresh_messages(self): """The nested template's LLM call sees only its own system + user, @@ -709,7 +724,7 @@ def test_inner_template_sees_prior_completed_history(self): agent.outer("first") agent.outer("second") - # After first call, agent.__history__ has 2 messages (user + assistant). + # After first call, agent history has 2 messages (user + assistant). # Second outer call (call 1): starts from history(2) + own user = 3. # Inner call (call 2): starts from history(2) + own user = 3. # Both see the same base history. 
If inner saw the outer's in-flight @@ -747,6 +762,44 @@ def test_sequential_call_after_nested_sees_history(self): second_call_roles = [m["role"] for m in mock.received_messages[3]] assert second_call_roles.count("assistant") >= 2 # from first call's history + def test_inner_success_outer_failure_no_history_leak(self): + """When inner call succeeds but outer fails, canonical history must + not be left with inner call's stale messages.""" + + class LimitCallsHandler(ObjectInterpretation): + def __init__(self, max_calls): + self.max_calls = max_calls + self.count = 0 + + @implements(call_assistant) + def _call(self, *args, **kwargs): + self.count += 1 + if self.count > self.max_calls: + raise RuntimeError(f"Exceeded {self.max_calls} calls") + return fwd() + + mock = MockCompletionHandler( + [ + make_tool_call_response("self__nested_tool", '{"payload": "demo"}'), + make_text_response("inner"), + make_text_response("outer"), # won't be reached + ] + ) + agent = _DesignerAgent() + provider = LiteLLMProvider() + # Allow 2 calls: outer's first + inner's. Outer's 2nd call (#3) fails. + limiter = LimitCallsHandler(max_calls=2) + + with pytest.raises(RuntimeError, match="Exceeded"): + with handler(provider), handler(mock), handler(limiter): + agent.outer("demo") + + with handler(provider): + history = get_agent_history(agent.__agent_id__) + # Canonical history should be empty — outer never completed. + # It must NOT contain inner call's system/user/assistant. 
+ assert len(history) == 0 + # --------------------------------------------------------------------------- # Template method and scoping tests (moved from test_handlers_llm_template.py) @@ -1185,8 +1238,6 @@ def static_method(x: int) -> int: # static_method remains a plain Template accessible on class and instance assert isinstance(MyAgent.static_method, Template) assert isinstance(agent.static_method, Template) - # static_method should NOT have __history__ set - assert not hasattr(MyAgent.static_method, "__history__") def test_agent_skips_classmethod_template(self): """Agent.__init_subclass__ does not wrap classmethod Templates @@ -1211,8 +1262,6 @@ def class_method(cls) -> str: agent = MyAgent() assert isinstance(agent.instance_method, Template) assert isinstance(MyAgent.class_method, Template) - # class_method should NOT have __history__ set - assert not hasattr(MyAgent.class_method, "__history__") def test_template_formatting_scoped(): diff --git a/tests/test_persistent_agent.py b/tests/test_persistent_agent.py new file mode 100644 index 00000000..2c4458f4 --- /dev/null +++ b/tests/test_persistent_agent.py @@ -0,0 +1,1612 @@ +"""Tests for PersistentAgent + PersistenceHandler + CompactionHandler. + +Checkpointing, compaction, crash recovery, nested calls, subclass state +persistence, and system prompt augmentation. 
+""" + +import dataclasses +import json +import sqlite3 +from collections import OrderedDict +from pathlib import Path +from typing import Any + +import pytest +from litellm import ModelResponse + +from effectful.handlers.llm import Agent, Template, Tool +from effectful.handlers.llm.completions import ( + AgentHistoryHandler, + LiteLLMProvider, + RetryLLMHandler, + ToolCallExecutionError, + completion, + get_agent_history, +) +from effectful.handlers.llm.persistence import ( + CompactionHandler, + PersistenceHandler, + PersistentAgent, +) +from effectful.handlers.llm.template import get_bound_agent +from effectful.ops.semantics import handler +from effectful.ops.syntax import ObjectInterpretation, implements +from effectful.ops.types import NotHandled + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_text_response(content: str) -> ModelResponse: + return ModelResponse( + id="test", + choices=[ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + model="test-model", + ) + + +def make_tool_call_response( + tool_name: str, tool_args: str, tool_call_id: str = "call_1" +) -> ModelResponse: + return ModelResponse( + id="test", + choices=[ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": tool_call_id, + "type": "function", + "function": {"name": tool_name, "arguments": tool_args}, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + model="test-model", + ) + + +class MockCompletionHandler(ObjectInterpretation): + """Returns pre-configured responses and captures messages sent to the LLM.""" + + def __init__(self, responses: list[ModelResponse]): + self.responses = responses + self.call_count = 0 + self.received_messages: list[list] = [] + + @implements(completion) + def _completion(self, model, messages=None, **kwargs): + 
self.received_messages.append(list(messages) if messages else []) + response = self.responses[min(self.call_count, len(self.responses) - 1)] + self.call_count += 1 + return response + + +def read_checkpoint(tmp_path: Path, agent_id: str) -> dict: + """Read a checkpoint from the SQLite database.""" + db_path = tmp_path / "checkpoints.db" + conn = sqlite3.connect(str(db_path)) + row = conn.execute( + "SELECT handoff, state, history FROM checkpoints WHERE agent_id = ?", + (agent_id,), + ).fetchone() + conn.close() + if row is None: + raise FileNotFoundError(f"No checkpoint for {agent_id}") + return { + "agent_id": agent_id, + "handoff": row[0], + "state": json.loads(row[1]), + "history": json.loads(row[2]), + } + + +def has_checkpoint(tmp_path: Path, agent_id: str) -> bool: + """Check if a checkpoint exists in the SQLite database.""" + db_path = tmp_path / "checkpoints.db" + if not db_path.exists(): + return False + conn = sqlite3.connect(str(db_path)) + row = conn.execute( + "SELECT 1 FROM checkpoints WHERE agent_id = ?", + (agent_id,), + ).fetchone() + conn.close() + return row is not None + + +# --------------------------------------------------------------------------- +# Test agents +# --------------------------------------------------------------------------- + + +class ChatBot(PersistentAgent): + """You are a persistent chat bot for testing.""" + + @Template.define + def send(self, user_input: str) -> str: + """User says: {user_input}""" + raise NotHandled + + +@dataclasses.dataclass +class StatefulBot(PersistentAgent): + """You are a stateful bot that tracks learned patterns.""" + + __agent_id__ = "StatefulBot" + + learned_patterns: list[str] = dataclasses.field(default_factory=list) + call_count: int = 0 + + @Template.define + def send(self, user_input: str) -> str: + """User says: {user_input}""" + raise NotHandled + + +class NestedBot(PersistentAgent): + """You are a nested-call test bot.""" + + @Template.define + def inner_check(self, payload: str) -> str: 
+ """Check: {payload}. Do not use tools.""" + raise NotHandled + + @Tool.define + def check_tool(self, payload: str) -> str: + """Check payload by calling an inner template.""" + return self.inner_check(payload) + + @Template.define + def outer(self, payload: str) -> str: + """Call `check_tool` for: {payload}, then return final answer.""" + raise NotHandled + + +# --------------------------------------------------------------------------- +# Tests: Agent.__agent_id__ +# --------------------------------------------------------------------------- + + +class TestAgentId: + """All Agent subclasses get agent_id.""" + + def test_plain_agent_defaults_to_id(self): + class PlainAgent(Agent): + """Plain.""" + + @Template.define + def ask(self, q: str) -> str: + """Q: {q}""" + raise NotHandled + + agent = PlainAgent() + assert agent.__agent_id__ == str(id(agent)) + + def test_persistent_agent_requires_agent_id(self): + bot = ChatBot(agent_id="my-chatbot") + assert bot.__agent_id__ == "my-chatbot" + + def test_dataclass_class_level_id(self): + """Dataclass subclasses can set __agent_id__ as a class attribute.""" + bot = StatefulBot() + assert bot.__agent_id__ == "StatefulBot" + + def test_bound_template_has_agent_via_context(self): + bot = ChatBot(agent_id="ChatBot") + bound = bot.send + assert get_bound_agent(bound) is bot + + +# --------------------------------------------------------------------------- +# Tests: basic persistence (PersistenceHandler) +# --------------------------------------------------------------------------- + + +class TestCheckpointing: + """PersistenceHandler save/load round-trip correctly.""" + + def test_save_creates_file(self, tmp_path: Path): + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = ChatBot(agent_id="ChatBot") + with handler(AgentHistoryHandler()): + path = persist.save(bot) + assert path.exists() + assert path.suffix == ".db" + + def test_save_round_trip_empty(self, tmp_path: Path): + persist = 
PersistenceHandler(tmp_path / "checkpoints.db") + bot = ChatBot(agent_id="ChatBot") + with handler(AgentHistoryHandler()): + persist.save(bot) + + data = read_checkpoint(tmp_path, "ChatBot") + assert len(data["history"]) == 0 + assert data["handoff"] == "" + + def test_save_round_trip_with_history(self, tmp_path: Path): + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = ChatBot(agent_id="ChatBot") + with handler(AgentHistoryHandler()): + history = get_agent_history(bot.__agent_id__) + history["msg1"] = { + "id": "msg1", + "role": "user", + "content": "hello", + } + persist.save(bot, handoff="working on X") + + data = read_checkpoint(tmp_path, "ChatBot") + assert data["handoff"] == "working on X" + assert len(data["history"]) == 1 + assert data["history"][0]["content"] == "hello" + + def test_atomic_write(self, tmp_path: Path): + """Checkpoint write uses SQLite transactions for atomicity.""" + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = ChatBot(agent_id="ChatBot") + with handler(AgentHistoryHandler()): + path = persist.save(bot) + assert path.exists() + data = read_checkpoint(tmp_path, "ChatBot") + assert data["agent_id"] == "ChatBot" + + def test_custom_agent_id(self, tmp_path: Path): + persist = PersistenceHandler(tmp_path / "checkpoints.db") + + class CustomBot(PersistentAgent): + """Custom.""" + + @Template.define + def ask(self, q: str) -> str: + """Q: {q}""" + raise NotHandled + + bot = CustomBot(agent_id="custom-bot") + with handler(AgentHistoryHandler()): + persist.save(bot) + data = read_checkpoint(tmp_path, "custom-bot") + assert data["agent_id"] == "custom-bot" + + +# --------------------------------------------------------------------------- +# Tests: subclass state persistence (checkpoint_state / restore_state) +# --------------------------------------------------------------------------- + + +class TestSubclassStatePersistence: + """Dataclass fields on subclasses are automatically persisted.""" + + def 
test_dataclass_fields_round_trip(self, tmp_path: Path): + """Dataclass state survives a save and is visible in the DB.""" + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = StatefulBot() + bot.learned_patterns = ["pattern A", "pattern B"] + bot.call_count = 5 + with handler(AgentHistoryHandler()): + persist.save(bot) + + data = read_checkpoint(tmp_path, "StatefulBot") + assert data["state"]["learned_patterns"] == ["pattern A", "pattern B"] + assert data["state"]["call_count"] == 5 + + def test_non_dataclass_has_empty_state(self): + """Non-dataclass subclass returns empty state dict.""" + bot = ChatBot(agent_id="ChatBot") + assert bot.checkpoint_state() == {} + + def test_non_serializable_fields_skipped(self, tmp_path: Path): + @dataclasses.dataclass + class WeirdBot(PersistentAgent): + """Bot with a non-serializable field.""" + + __agent_id__ = "WeirdBot" + + callback: object = dataclasses.field(default=None) + name: str = "test" + + @Template.define + def send(self, msg: str) -> str: + """Say: {msg}""" + raise NotHandled + + persist = PersistenceHandler(tmp_path / "checkpoints.db") + bot = WeirdBot() + bot.callback = lambda x: x # not JSON serializable + bot.name = "Alice" + with handler(AgentHistoryHandler()): + persist.save(bot) + + data = read_checkpoint(tmp_path, "WeirdBot") + assert data["state"]["name"] == "Alice" + # callback is not JSON serializable, so it should be skipped + assert "callback" not in data["state"] + + def test_custom_checkpoint_restore(self, tmp_path: Path): + """Users can override checkpoint_state / restore_state.""" + + class CustomBot(PersistentAgent): + """Custom serialisation bot.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.data = {"counter": 0} + + def checkpoint_state(self): + return {"data": self.data} + + def restore_state(self, state): + self.data = state.get("data", {"counter": 0}) + + @Template.define + def send(self, msg: str) -> str: + """Say: {msg}""" + raise NotHandled + + 
+        persist = PersistenceHandler(tmp_path / "checkpoints.db")
+        bot = CustomBot(agent_id="CustomBot")
+        bot.data["counter"] = 42
+        with handler(AgentHistoryHandler()):
+            persist.save(bot)
+
+        data = read_checkpoint(tmp_path, "CustomBot")
+        assert data["state"]["data"]["counter"] == 42
+
+    def test_state_saved_in_checkpoint_file(self, tmp_path: Path):
+        """The checkpoint DB contains state with subclass fields."""
+        persist = PersistenceHandler(tmp_path / "checkpoints.db")
+        bot = StatefulBot()
+        bot.learned_patterns = ["X"]
+        bot.call_count = 3
+        with handler(AgentHistoryHandler()):
+            persist.save(bot)
+
+        data = read_checkpoint(tmp_path, "StatefulBot")
+        assert "state" in data
+        assert data["state"]["learned_patterns"] == ["X"]
+        assert data["state"]["call_count"] == 3
+
+
+# ---------------------------------------------------------------------------
+# Tests: automatic checkpointing around template calls
+# ---------------------------------------------------------------------------
+
+
+class TestAutomaticCheckpointing:
+    """Template calls on PersistentAgent trigger auto-checkpointing."""
+
+    def test_checkpoint_saved_after_successful_call(self, tmp_path: Path):
+        mock = MockCompletionHandler([make_text_response("hello")])
+        bot = ChatBot(agent_id="ChatBot")
+
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            bot.send("hi")
+
+        data = read_checkpoint(tmp_path, "ChatBot")
+        assert len(data["history"]) > 0
+        assert data["handoff"] == ""
+
+    def test_checkpoint_saved_on_exception(self, tmp_path: Path):
+        class FailingMock(ObjectInterpretation):
+            @implements(completion)
+            def _completion(self, *args, **kwargs):
+                raise RuntimeError("boom")
+
+        bot = ChatBot(agent_id="ChatBot")
+        with pytest.raises(RuntimeError, match="boom"):
+            with (
+                handler(LiteLLMProvider()),
+                handler(FailingMock()),
+                handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+            ):
+                bot.send("hi")
+
+        data = read_checkpoint(tmp_path, "ChatBot")
+        assert "Executing send" in data["handoff"]
+
+    def test_handoff_describes_current_call(self, tmp_path: Path):
+        """Before the template runs, handoff records what's in progress."""
+        handoff_during_call = []
+
+        class SpyMock(ObjectInterpretation):
+            @implements(completion)
+            def _completion(self_, model, messages=None, **kwargs):
+                data = read_checkpoint(tmp_path, "ChatBot")
+                handoff_during_call.append(data["handoff"])
+                return make_text_response("ok")
+
+        bot = ChatBot(agent_id="ChatBot")
+        with (
+            handler(LiteLLMProvider()),
+            handler(SpyMock()),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            bot.send("hello")
+
+        assert len(handoff_during_call) == 1
+        assert "Executing send" in handoff_during_call[0]
+
+    def test_history_persists_across_sessions(self, tmp_path: Path):
+        """A 'restart' (new handler + agent) sees prior history."""
+        mock = MockCompletionHandler([make_text_response("reply1")])
+        bot = ChatBot(agent_id="ChatBot")
+
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            bot.send("first message")
+
+        data_after_first = read_checkpoint(tmp_path, "ChatBot")
+        history_len_first = len(data_after_first["history"])
+
+        # "Restart" — new handler + new agent instance
+        mock2 = MockCompletionHandler([make_text_response("reply after restart")])
+        bot2 = ChatBot(agent_id="ChatBot")
+
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock2),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            bot2.send("second message")
+
+        data_after_second = read_checkpoint(tmp_path, "ChatBot")
+        assert len(data_after_second["history"]) > history_len_first
+
+    def test_second_call_sees_prior_history(self, tmp_path: Path):
+        mock = MockCompletionHandler(
+            [make_text_response("r1"), make_text_response("r2")]
+        )
+        bot = ChatBot(agent_id="ChatBot")
+
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            bot.send("a")
+            bot.send("b")
+
+        assert len(mock.received_messages[1]) > len(mock.received_messages[0])
+
+    def test_dataclass_state_saved_around_template_calls(self, tmp_path: Path):
+        mock = MockCompletionHandler([make_text_response("ok")])
+        bot = StatefulBot()
+        bot.call_count = 7
+
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            bot.send("test")
+
+        data = read_checkpoint(tmp_path, "StatefulBot")
+        assert data["state"]["call_count"] == 7
+
+
+# ---------------------------------------------------------------------------
+# Tests: crash recovery
+# ---------------------------------------------------------------------------
+
+
+class TestCrashRecovery:
+    """Handoff notes enable resumption after crashes."""
+
+    def test_handoff_survives_crash(self, tmp_path: Path):
+        class CrashMock(ObjectInterpretation):
+            @implements(completion)
+            def _completion(self, *args, **kwargs):
+                raise RuntimeError("process killed")
+
+        bot = ChatBot(agent_id="ChatBot")
+        with pytest.raises(RuntimeError):
+            with (
+                handler(LiteLLMProvider()),
+                handler(CrashMock()),
+                handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+            ):
+                bot.send("important task")
+
+        data = read_checkpoint(tmp_path, "ChatBot")
+        assert "Executing send" in data["handoff"]
+
+    def test_system_prompt_includes_handoff(self, tmp_path: Path):
+        """After a crash, the next call's system prompt includes the handoff."""
+
+        # Simulate crash
+        class CrashMock(ObjectInterpretation):
+            @implements(completion)
+            def _completion(self, *args, **kwargs):
+                raise RuntimeError("crash")
+
+        bot = ChatBot(agent_id="ChatBot")
+        with pytest.raises(RuntimeError):
+            with (
+                handler(LiteLLMProvider()),
+                handler(CrashMock()),
+                handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+            ):
+                bot.send("important task")
+
+        # Next session: spy on system prompt
+        system_prompts = []
+
+        class SpyMock(ObjectInterpretation):
+            @implements(completion)
+            def _completion(self_, model, messages=None, **kwargs):
+                system_prompts.extend(
+                    m.get("content", "")
+                    for m in (messages or [])
+                    if m.get("role") == "system"
+                )
+                return make_text_response("resumed")
+
+        bot2 = ChatBot(agent_id="ChatBot")
+        with (
+            handler(LiteLLMProvider()),
+            handler(SpyMock()),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            bot2.send("resume")
+
+        assert any("[HANDOFF FROM PRIOR SESSION]" in p for p in system_prompts)
+
+    def test_handoff_cleared_on_success(self, tmp_path: Path):
+        # Create crash checkpoint
+        class CrashMock(ObjectInterpretation):
+            @implements(completion)
+            def _completion(self, *a, **kw):
+                raise RuntimeError("crash")
+
+        bot = ChatBot(agent_id="ChatBot")
+        with pytest.raises(RuntimeError):
+            with (
+                handler(LiteLLMProvider()),
+                handler(CrashMock()),
+                handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+            ):
+                bot.send("crash task")
+
+        # Successful run clears handoff
+        mock = MockCompletionHandler([make_text_response("done")])
+        bot2 = ChatBot(agent_id="ChatBot")
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            bot2.send("new task")
+
+        data = read_checkpoint(tmp_path, "ChatBot")
+        assert data["handoff"] == ""
+
+    def test_dataclass_state_survives_crash(self, tmp_path: Path):
+        class CrashMock(ObjectInterpretation):
+            @implements(completion)
+            def _completion(self, *a, **kw):
+                raise RuntimeError("crash")
+
+        bot = StatefulBot()
+        bot.learned_patterns = ["important insight"]
+        bot.call_count = 3
+
+        with pytest.raises(RuntimeError):
+            with (
+                handler(LiteLLMProvider()),
+                handler(CrashMock()),
+                handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+            ):
+                bot.send("boom")
+
+        data = read_checkpoint(tmp_path, "StatefulBot")
+        assert data["state"]["learned_patterns"] == ["important insight"]
+        assert data["state"]["call_count"] == 3
+
+
+# ---------------------------------------------------------------------------
+# Tests: nested template calls
+# ---------------------------------------------------------------------------
+
+
+class TestNestedCalls:
+    """Only outermost template call triggers checkpointing."""
+
+    def test_nested_template_via_tool_completes(self, tmp_path: Path):
+        mock = MockCompletionHandler(
+            [
+                make_tool_call_response("self__check_tool", '{"payload": "demo"}'),
+                make_text_response("inner result"),
+                make_text_response("outer result"),
+            ]
+        )
+        bot = NestedBot(agent_id="NestedBot")
+
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            result = bot.outer("demo")
+
+        assert result == "outer result"
+
+    def test_nested_call_does_not_double_checkpoint(self, tmp_path: Path):
+        save_count = 0
+        original_save = PersistenceHandler.save
+
+        def counting_save(self, agent, handoff=""):
+            nonlocal save_count
+            save_count += 1
+            return original_save(self, agent, handoff=handoff)
+
+        mock = MockCompletionHandler(
+            [
+                make_tool_call_response("self__check_tool", '{"payload": "demo"}'),
+                make_text_response("inner"),
+                make_text_response("outer"),
+            ]
+        )
+        bot = NestedBot(agent_id="NestedBot")
+        PersistenceHandler.save = counting_save  # type: ignore[method-assign]
+        try:
+            with (
+                handler(LiteLLMProvider()),
+                handler(mock),
+                handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+            ):
+                bot.outer("demo")
+        finally:
+            PersistenceHandler.save = original_save  # type: ignore[method-assign]
+
+        # Should be exactly 2: one before call, one after
+        assert save_count == 2
+
+    def test_handoff_cleared_after_nested_success(self, tmp_path: Path):
+        mock = MockCompletionHandler(
+            [
+                make_tool_call_response("self__check_tool", '{"payload": "demo"}'),
+                make_text_response("inner"),
+                make_text_response("outer"),
+            ]
+        )
+        bot = NestedBot(agent_id="NestedBot")
+
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            bot.outer("demo")
+
+        data = read_checkpoint(tmp_path, "NestedBot")
+        assert data["handoff"] == ""
+
+
+# ---------------------------------------------------------------------------
+# Tests: context compaction (CompactionHandler)
+# ---------------------------------------------------------------------------
+
+
+class TestContextCompaction:
+    """CompactionHandler compacts agent history after template calls."""
+
+    def test_compact_reduces_history(self):
+        history: OrderedDict[str, Any] = OrderedDict()
+        for i in range(10):
+            history[f"msg{i}"] = {
+                "id": f"msg{i}",
+                "role": "user" if i % 2 == 0 else "assistant",
+                "content": f"Message {i}",
+            }
+
+        compaction = CompactionHandler(max_history_len=6)
+        mock = MockCompletionHandler(
+            [make_text_response("Summary of prior conversation.")]
+        )
+        provider = LiteLLMProvider()
+        with handler(provider), handler(mock):
+            stored = get_agent_history("PlainBot")
+            stored.update(history)
+            compaction._compact("PlainBot", stored)
+
+        result = provider._histories["PlainBot"]
+        assert len(result) < 10
+        first_msg = next(iter(result.values()))
+        assert "CONTEXT SUMMARY" in first_msg["content"]
+
+    def test_compaction_preserves_recent_messages(self):
+        history: OrderedDict[str, Any] = OrderedDict()
+        for i in range(10):
+            history[f"msg{i}"] = {
+                "id": f"msg{i}",
+                "role": "user" if i % 2 == 0 else "assistant",
+                "content": f"Message {i}",
+            }
+
+        compaction = CompactionHandler(max_history_len=6)
+        keep_recent = max(6 // 2, 4)
+        mock = MockCompletionHandler([make_text_response("Summary.")])
+        provider = LiteLLMProvider()
+        with handler(provider), handler(mock):
+            stored = get_agent_history("ChatBot")
+            stored.update(history)
+            compaction._compact("ChatBot", stored)
+
+        result = provider._histories["ChatBot"]
+        remaining_ids = list(result.keys())
+        for i in range(10 - keep_recent, 10):
+            assert f"msg{i}" in remaining_ids
+
+    def test_compaction_triggered_by_template_call(self, tmp_path: Path):
+        bot = ChatBot(agent_id="ChatBot")
+        provider = LiteLLMProvider()
+
+        with handler(provider):
+            history = get_agent_history(bot.__agent_id__)
+            for i in range(6):
+                history[f"old{i}"] = {
+                    "id": f"old{i}",
+                    "role": "user" if i % 2 == 0 else "assistant",
+                    "content": f"Old message {i}",
+                }
+
+        mock = MockCompletionHandler(
+            [
+                make_text_response("new reply"),
+                make_text_response("Summary of old conversation."),
+            ]
+        )
+        with (
+            handler(provider),
+            handler(mock),
+            handler(CompactionHandler(max_history_len=4)),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            bot.send("trigger compaction")
+
+        result: OrderedDict[str, Any] = provider._histories.get(
+            "ChatBot", OrderedDict()
+        )
+        assert len(result) <= 4 + 4
+
+    def test_compaction_works_on_plain_agent(self):
+        """CompactionHandler works on any Agent, not just PersistentAgent."""
+
+        class PlainBot(Agent):
+            """Plain bot."""
+
+            @Template.define
+            def send(self, msg: str) -> str:
+                """Say: {msg}"""
+                raise NotHandled
+
+        bot = PlainBot()
+        provider = LiteLLMProvider()
+
+        with handler(provider):
+            history = get_agent_history(bot.__agent_id__)
+            for i in range(10):
+                history[f"msg{i}"] = {
+                    "id": f"msg{i}",
+                    "role": "user" if i % 2 == 0 else "assistant",
+                    "content": f"Message {i}",
+                }
+
+        mock = MockCompletionHandler(
+            [make_text_response("reply"), make_text_response("Summary.")]
+        )
+        with (
+            handler(provider),
+            handler(mock),
+            handler(CompactionHandler(max_history_len=4)),
+        ):
+            bot.send("trigger")
+
+        result = provider._histories.get(bot.__agent_id__, {})
+        assert len(result) > 0, "history should not be empty after compaction"
+        first_msg = next(iter(result.values()))
+        assert "CONTEXT SUMMARY" in first_msg["content"]
+        assert len(result) <= 4 + 4
+
+    def test_compaction_triggered_naturally_on_plain_agent(self):
+        """CompactionHandler compacts after enough calls accumulate history.
+
+        Makes multiple template calls on a plain Agent so that history
+        exceeds max_history_len, then verifies compaction fires and
+        produces a summary message.
+        """
+
+        class PlainBot(Agent):
+            """Plain bot."""
+
+            @Template.define
+            def send(self, msg: str) -> str:
+                """Say: {msg}"""
+                raise NotHandled
+
+        # 4 calls × ~3 msgs each (system+user+assistant) = ~12 msgs
+        # Compaction threshold is 6, so it should trigger.
+        responses = [make_text_response(f"reply-{i}") for i in range(4)]
+        # Extra response for the summarize_context call during compaction
+        responses.append(make_text_response("Summary of conversation."))
+        mock = MockCompletionHandler(responses)
+
+        bot = PlainBot()
+        provider = LiteLLMProvider()
+        with (
+            handler(provider),
+            handler(mock),
+            handler(CompactionHandler(max_history_len=6)),
+        ):
+            for i in range(4):
+                bot.send(f"message-{i}")
+
+        history = provider._histories.get(bot.__agent_id__, {})
+        # Should have been compacted: summary + recent messages
+        first_msg = next(iter(history.values()))
+        assert "CONTEXT SUMMARY" in first_msg["content"]
+
+    def test_compaction_does_not_split_tool_use_tool_result_pairs(self):
+        """Compaction must not split tool_use/tool_result message pairs.
+
+        If the cut point falls between an assistant message with tool_use
+        blocks and the corresponding tool_result message, the Anthropic API
+        rejects the conversation. This test constructs a history where the
+        naive positional split would do exactly that and asserts that both
+        messages end up on the same side of the cut.
+        """
+        history: OrderedDict[str, Any] = OrderedDict()
+
+        for i in range(7):
+            history[f"msg{i}"] = {
+                "id": f"msg{i}",
+                "role": "user" if i % 2 == 0 else "assistant",
+                "content": f"Message {i}",
+            }
+
+        # msg7: assistant with tool_use (will be last item in old_items)
+        history["msg7"] = {
+            "id": "msg7",
+            "role": "assistant",
+            "content": [
+                {
+                    "type": "tool_use",
+                    "id": "tool_call_xyz",
+                    "name": "check_tool",
+                    "input": {"payload": "test"},
+                }
+            ],
+        }
+
+        # msg8: tool_result (will be first item in recent_items)
+        history["msg8"] = {
+            "id": "msg8",
+            "role": "tool",
+            "tool_call_id": "tool_call_xyz",
+            "content": "tool result here",
+        }
+
+        # msg9, msg10, msg11: padding so recent has 4 items
+        history["msg9"] = {
+            "id": "msg9",
+            "role": "assistant",
+            "content": "Response after tool",
+        }
+        history["msg10"] = {
+            "id": "msg10",
+            "role": "user",
+            "content": "Follow up question",
+        }
+        history["msg11"] = {
+            "id": "msg11",
+            "role": "assistant",
+            "content": "Final answer",
+        }
+
+        compaction = CompactionHandler(max_history_len=8)
+        mock = MockCompletionHandler(
+            [make_text_response("Summary of prior conversation.")]
+        )
+        provider = LiteLLMProvider()
+        with handler(provider), handler(mock):
+            stored = get_agent_history("ToolPairBot")
+            stored.update(history)
+            compaction._compact("ToolPairBot", stored)
+
+        result = provider._histories["ToolPairBot"]
+        result_items = list(result.values())
+
+        # After compaction, there must be no orphaned tool_result messages.
+        tool_use_ids: set[str] = set()
+        for msg in result_items:
+            content = msg.get("content", "")
+            if isinstance(content, list):
+                for block in content:
+                    if isinstance(block, dict) and block.get("type") == "tool_use":
+                        tool_use_ids.add(block["id"])
+
+        for msg in result_items:
+            if msg.get("role") == "tool":
+                tc_id = msg.get("tool_call_id", "")
+                assert tc_id in tool_use_ids, (
+                    f"Orphaned tool_result with tool_call_id={tc_id!r} after "
+                    f"compaction — the matching tool_use was discarded. "
+                    f"Remaining messages: {[m.get('id') for m in result_items]}"
+                )
+
+    def test_compaction_on_plain_agent_preserves_functionality(self):
+        """After compaction, the plain Agent still works for subsequent calls."""
+
+        class PlainBot(Agent):
+            """Plain bot."""
+
+            @Template.define
+            def send(self, msg: str) -> str:
+                """Say: {msg}"""
+                raise NotHandled
+
+        responses = [make_text_response(f"reply-{i}") for i in range(4)]
+        # Compaction summary call fires after the 4th reply
+        responses.append(make_text_response("Summary."))
+        # Then the 5th send() call
+        responses.append(make_text_response("reply-4"))
+        mock = MockCompletionHandler(responses)
+
+        bot = PlainBot()
+        provider = LiteLLMProvider()
+        with (
+            handler(provider),
+            handler(mock),
+            handler(CompactionHandler(max_history_len=6)),
+        ):
+            for i in range(4):
+                bot.send(f"msg-{i}")
+            # This call happens after compaction
+            result = bot.send("after-compaction")
+
+        assert result == "reply-4"
+
+
+# ---------------------------------------------------------------------------
+# Tests: system prompt
+# ---------------------------------------------------------------------------
+
+
+class TestSystemPrompt:
+    """System prompt of PersistentAgent includes class docstring."""
+
+    def test_base_docstring_used(self):
+        bot = ChatBot(agent_id="ChatBot")
+        assert "persistent chat bot" in bot.__system_prompt__
+
+    def test_no_handoff_initially(self):
+        bot = ChatBot(agent_id="ChatBot")
+        assert "[HANDOFF" not in bot.__system_prompt__
+
+
+# ---------------------------------------------------------------------------
+# Tests: agent isolation
+# ---------------------------------------------------------------------------
+
+
+class TestAgentIsolation:
+    """Multiple PersistentAgent instances are independent in the handler."""
+
+    def test_two_agents_independent(self, tmp_path: Path):
+        bot1 = ChatBot(agent_id="bot1")
+
+        persist = PersistenceHandler(tmp_path / "checkpoints.db")
+        with handler(AgentHistoryHandler()):
+            persist.save(bot1, handoff="bot1 work")
+
+        # bot2 was never saved — should not exist in DB
+        assert not has_checkpoint(tmp_path, "bot2")
+        data = read_checkpoint(tmp_path, "bot1")
+        assert data["handoff"] == "bot1 work"
+
+    def test_same_db_different_agent_id(self, tmp_path: Path):
+        persist = PersistenceHandler(tmp_path / "checkpoints.db")
+        bot_a = ChatBot(agent_id="alpha")
+        bot_b = ChatBot(agent_id="beta")
+
+        with handler(AgentHistoryHandler()):
+            persist.save(bot_a, handoff="alpha work")
+            persist.save(bot_b, handoff="beta work")
+
+        data_a = read_checkpoint(tmp_path, "alpha")
+        data_b = read_checkpoint(tmp_path, "beta")
+        assert data_a["handoff"] == "alpha work"
+        assert data_b["handoff"] == "beta work"
+
+
+# ---------------------------------------------------------------------------
+# Tests: compatibility with RetryLLMHandler
+# ---------------------------------------------------------------------------
+
+
+class TestRetryCompatibility:
+    """PersistentAgent works with RetryLLMHandler and PersistenceHandler."""
+
+    def test_retry_then_success(self, tmp_path: Path):
+        mock = MockCompletionHandler(
+            [
+                make_text_response('"not_an_int"'),
+                make_text_response('{"value": 42}'),
+            ]
+        )
+
+        class NumberBot(PersistentAgent):
+            """You are a number bot."""
+
+            @Template.define
+            def pick(self) -> int:
+                """Pick a number."""
+                raise NotHandled
+
+        bot = NumberBot(agent_id="NumberBot")
+        with (
+            handler(LiteLLMProvider()),
+            handler(RetryLLMHandler()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            result = bot.pick()
+
+        assert result == 42
+        data = read_checkpoint(tmp_path, "NumberBot")
+        assert data["handoff"] == ""
+
+
+# ---------------------------------------------------------------------------
+# Tests: PersistenceHandler is optional
+# ---------------------------------------------------------------------------
+
+
+class TestWithoutHandler:
+    """PersistentAgent works without PersistenceHandler — no auto-checkpointing."""
+
+    def test_agent_works_without_persistence_handler(self):
+        mock = MockCompletionHandler([make_text_response("hello")])
+        bot = ChatBot(agent_id="ChatBot")
+
+        with handler(LiteLLMProvider()), handler(mock):
+            result = bot.send("hi")
+
+        assert result == "hello"
+
+
+# ---------------------------------------------------------------------------
+# Tests: nested calls with failures + persistence
+# ---------------------------------------------------------------------------
+
+
+class TestNestedCallFailuresWithPersistence:
+    """Nested tool calls that fail should not corrupt persistence state."""
+
+    def test_nested_tool_failure_still_checkpoints(self, tmp_path: Path):
+        """If a nested tool raises, the outermost handler saves a crash checkpoint."""
+
+        class FailingBot(PersistentAgent):
+            """Bot whose tool always fails."""
+
+            @Template.define
+            def inner(self, payload: str) -> str:
+                """Check: {payload}"""
+                raise NotHandled
+
+            @Tool.define
+            def failing_tool(self, payload: str) -> str:
+                """Check payload — always raises."""
+                raise RuntimeError("tool exploded")
+
+            @Template.define
+            def outer(self, payload: str) -> str:
+                """Call `failing_tool` for: {payload}."""
+                raise NotHandled
+
+        mock = MockCompletionHandler(
+            [
+                make_tool_call_response("self__failing_tool", '{"payload": "boom"}'),
+            ]
+        )
+        bot = FailingBot(agent_id="FailingBot")
+
+        with pytest.raises(ToolCallExecutionError, match="tool exploded"):
+            with (
+                handler(LiteLLMProvider()),
+                handler(mock),
+                handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+            ):
+                bot.outer("go")
+
+        data = read_checkpoint(tmp_path, "FailingBot")
+        assert "Executing outer" in data["handoff"]
+
+    def test_nested_tool_failure_then_recovery(self, tmp_path: Path):
+        """After a nested tool failure, next session resumes with handoff."""
+        mock_crash = MockCompletionHandler(
+            [
+                make_tool_call_response("self__check_tool", '{"payload": "crash"}'),
+            ]
+        )
+
+        class CrashInnerBot(PersistentAgent):
+            """Bot with crashing inner tool."""
+
+            call_count = 0
+
+            @Template.define
+            def inner_check(self, payload: str) -> str:
+                """Check: {payload}"""
+                raise NotHandled
+
+            @Tool.define
+            def check_tool(self, payload: str) -> str:
+                """Check payload."""
+                self.call_count += 1
+                if self.call_count == 1:
+                    raise RuntimeError("first call fails")
+                return self.inner_check(payload)
+
+            @Template.define
+            def outer(self, payload: str) -> str:
+                """Call `check_tool` for: {payload}, then return answer."""
+                raise NotHandled
+
+        bot = CrashInnerBot(agent_id="CrashInnerBot")
+
+        # Session 1: crash
+        with pytest.raises(ToolCallExecutionError, match="first call fails"):
+            with (
+                handler(LiteLLMProvider()),
+                handler(mock_crash),
+                handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+            ):
+                bot.outer("task")
+
+        # Session 2: successful recovery
+        system_prompts: list[str] = []
+
+        class SpyMock(ObjectInterpretation):
+            @implements(completion)
+            def _completion(self_, model, messages=None, **kwargs):
+                system_prompts.extend(
+                    m.get("content", "")
+                    for m in (messages or [])
+                    if m.get("role") == "system"
+                )
+                return make_text_response("recovered")
+
+        bot2 = CrashInnerBot(agent_id="CrashInnerBot")
+        with (
+            handler(LiteLLMProvider()),
+            handler(SpyMock()),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            result = bot2.outer("retry")
+
+        assert result == "recovered"
+        assert any("[HANDOFF FROM PRIOR SESSION]" in p for p in system_prompts)
+
+
+# ---------------------------------------------------------------------------
+# Tests: Agent and PersistentAgent coexistence
+# ---------------------------------------------------------------------------
+
+
+class TestAgentPersistentAgentCoexistence:
+    """Plain Agent and PersistentAgent work side-by-side."""
+
+    def test_plain_and_persistent_agent_in_same_handler(self, tmp_path: Path):
+        """Both agent types work under the same LiteLLMProvider."""
+
+        class PlainBot(Agent):
+            """Plain bot."""
+
+            @Template.define
+            def ask(self, q: str) -> str:
+                """Q: {q}"""
+                raise NotHandled
+
+        mock = MockCompletionHandler(
+            [make_text_response("plain-reply"), make_text_response("persist-reply")]
+        )
+        plain = PlainBot()
+        persistent = ChatBot(agent_id="ChatBot")
+
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            r1 = plain.ask("hello")
+            r2 = persistent.send("hello")
+
+        assert r1 == "plain-reply"
+        assert r2 == "persist-reply"
+        # Only the PersistentAgent gets a checkpoint entry
+        assert has_checkpoint(tmp_path, "ChatBot")
+        assert not has_checkpoint(tmp_path, plain.__agent_id__)
+
+    def test_persistent_agent_tool_calls_plain_agent(self, tmp_path: Path):
+        """A PersistentAgent's tool can delegate to a plain Agent.
+
+        Mock response sequence:
+          0: outer → tool_call(self__delegate, {"q": "sub-task"})
+          1: inner plain agent → "inner-answer"
+          2: outer → "final-answer" (after getting tool result)
+        """
+
+        class InnerPlainAgent(Agent):
+            """Inner helper agent."""
+
+            @Template.define
+            def answer(self, q: str) -> str:
+                """Answer: {q}"""
+                raise NotHandled
+
+        inner = InnerPlainAgent()
+
+        class OuterPersistent(PersistentAgent):
+            """Outer persistent agent that delegates via tool."""
+
+            @Tool.define
+            def delegate(self, q: str) -> str:
+                """Delegate a sub-question to an inner agent."""
+                return inner.answer(q)
+
+            @Template.define
+            def process(self, task: str) -> str:
+                """Process: {task}. Use `delegate` for sub-questions."""
+                raise NotHandled
+
+        mock = MockCompletionHandler(
+            [
+                make_tool_call_response("self__delegate", '{"q": "sub-task"}'),
+                make_text_response("inner-answer"),
+                make_text_response("final-answer"),
+            ]
+        )
+        outer = OuterPersistent(agent_id="outer")
+
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            result = outer.process("do it")
+
+        assert result == "final-answer"
+        data = read_checkpoint(tmp_path, "outer")
+        assert data["agent_id"] == "outer"
+
+    def test_plain_agent_tool_calls_persistent_agent(self, tmp_path: Path):
+        """A plain Agent's tool can delegate to a PersistentAgent.
+
+        Mock response sequence:
+          0: outer plain → tool_call(self__delegate, {"q": "sub"})
+          1: inner persistent → "persisted-answer"
+          2: outer plain → "done" (after getting tool result)
+        """
+
+        inner = ChatBot(agent_id="inner-bot")
+
+        class OuterPlain(Agent):
+            """Plain agent that delegates to a persistent agent."""
+
+            @Tool.define
+            def delegate(self, q: str) -> str:
+                """Delegate to persistent bot."""
+                return inner.send(q)
+
+            @Template.define
+            def run(self, task: str) -> str:
+                """Run: {task}. Use `delegate` for sub-tasks."""
+                raise NotHandled
+
+        mock = MockCompletionHandler(
+            [
+                make_tool_call_response("self__delegate", '{"q": "sub"}'),
+                make_text_response("persisted-answer"),
+                make_text_response("done"),
+            ]
+        )
+        outer = OuterPlain()
+
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            result = outer.run("go")
+
+        assert result == "done"
+
+    def test_two_persistent_agents_cooperate(self, tmp_path: Path):
+        """Two PersistentAgents with different IDs work independently.
+
+        Mock response sequence:
+          0: planner → "the plan"
+          1: executor → "executed"
+        """
+        mock = MockCompletionHandler(
+            [make_text_response("the plan"), make_text_response("executed")]
+        )
+
+        planner = ChatBot(agent_id="planner")
+        executor = ChatBot(agent_id="executor")
+
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            plan = planner.send("make a plan")
+            result = executor.send(f"execute: {plan}")
+
+        assert plan == "the plan"
+        assert result == "executed"
+
+        # Each has independent history
+        planner_data = read_checkpoint(tmp_path, "planner")
+        executor_data = read_checkpoint(tmp_path, "executor")
+        assert len(planner_data["history"]) > 0
+        assert len(executor_data["history"]) > 0
+
+
+# ---------------------------------------------------------------------------
+# Tests: SQLite crash tolerance
+# ---------------------------------------------------------------------------
+
+
+class TestSQLiteCrashTolerance:
+    """SQLite-backed persistence is crash tolerant and restartable."""
+
+    def test_wal_mode_enabled(self, tmp_path: Path):
+        """Database uses WAL journal mode for crash tolerance."""
+        persist = PersistenceHandler(tmp_path / "checkpoints.db")
+        bot = ChatBot(agent_id="ChatBot")
+        with handler(AgentHistoryHandler()):
+            persist.save(bot)
+
+        conn = sqlite3.connect(str(tmp_path / "checkpoints.db"))
+        mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
+        conn.close()
+        assert mode == "wal"
+
+    def test_database_survives_incomplete_write(self, tmp_path: Path):
+        """Prior committed data survives if a subsequent write is interrupted."""
+        persist = PersistenceHandler(tmp_path / "checkpoints.db")
+        bot = ChatBot(agent_id="ChatBot")
+
+        # First write: commit a valid checkpoint
+        with handler(AgentHistoryHandler()):
+            history = get_agent_history(bot.__agent_id__)
+            history["msg1"] = {"id": "msg1", "role": "user", "content": "hello"}
+            persist.save(bot)
+
+        # Simulate an interrupted write by opening a connection, beginning
+        # a write, then rolling back (mimicking a crash before commit).
+        conn = sqlite3.connect(str(tmp_path / "checkpoints.db"))
+        conn.execute(
+            "UPDATE checkpoints SET handoff = 'interrupted' WHERE agent_id = ?",
+            ("ChatBot",),
+        )
+        # Do NOT commit — simulate crash by closing without commit
+        conn.close()
+
+        # The prior committed data should still be intact
+        data = read_checkpoint(tmp_path, "ChatBot")
+        assert data["handoff"] == ""
+        assert len(data["history"]) == 1
+
+    def test_database_integrity_after_multiple_saves(self, tmp_path: Path):
+        """Multiple rapid saves produce a consistent database."""
+        mock = MockCompletionHandler(
+            [make_text_response(f"reply-{i}") for i in range(3)]
+        )
+        bot = ChatBot(agent_id="ChatBot")
+
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            for i in range(3):
+                bot.send(f"msg-{i}")
+
+        # Verify integrity
+        conn = sqlite3.connect(str(tmp_path / "checkpoints.db"))
+        result = conn.execute("PRAGMA integrity_check").fetchone()[0]
+        conn.close()
+        assert result == "ok"
+
+        data = read_checkpoint(tmp_path, "ChatBot")
+        assert data["handoff"] == ""
+        assert len(data["history"]) > 0
+
+    def test_recovery_from_crash_mid_template_call(self, tmp_path: Path):
+        """After a crash mid-call, the DB has a handoff and can be reloaded."""
+
+        class CrashMock(ObjectInterpretation):
+            @implements(completion)
+            def _completion(self, *args, **kwargs):
+                raise RuntimeError("process killed")
+
+        bot = ChatBot(agent_id="ChatBot")
+        with pytest.raises(RuntimeError):
+            with (
+                handler(LiteLLMProvider()),
+                handler(CrashMock()),
+                handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+            ):
+                bot.send("important task")
+
+        # Verify the DB is consistent and has the crash handoff
+        conn = sqlite3.connect(str(tmp_path / "checkpoints.db"))
+        result = conn.execute("PRAGMA integrity_check").fetchone()[0]
+        assert result == "ok"
+        conn.close()
+
+        data = read_checkpoint(tmp_path, "ChatBot")
+        assert "Executing send" in data["handoff"]
+
+        # New process can load and resume
+        mock2 = MockCompletionHandler([make_text_response("recovered")])
+        bot2 = ChatBot(agent_id="ChatBot")
+        with (
+            handler(LiteLLMProvider()),
+            handler(mock2),
+            handler(PersistenceHandler(tmp_path / "checkpoints.db")),
+        ):
+            result = bot2.send("resume")
+
+        assert result == "recovered"
+        data2 = read_checkpoint(tmp_path, "ChatBot")
+        assert data2["handoff"] == ""
+
+    def test_multiple_agents_single_db(self, tmp_path: Path):
+        """All agents share one DB file, not separate JSON files."""
+        persist = PersistenceHandler(tmp_path / "checkpoints.db")
+        bots = [ChatBot(agent_id=f"bot-{i}") for i in range(5)]
+
+        with handler(AgentHistoryHandler()):
+            for bot in bots:
+                persist.save(bot)
+
+        # Only one DB file, no JSON files
+        assert (tmp_path / "checkpoints.db").exists()
+        assert not list(tmp_path.glob("*.json"))
+
+        conn = sqlite3.connect(str(tmp_path / "checkpoints.db"))
+        count = conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()[0]
+        conn.close()
+        assert count == 5
+
+    def test_save_is_idempotent(self, tmp_path: Path):
+        """Saving the same agent multiple times updates rather than duplicates."""
+        persist = PersistenceHandler(tmp_path / "checkpoints.db")
+        bot = ChatBot(agent_id="ChatBot")
+
+        with handler(AgentHistoryHandler()):
+            persist.save(bot)
+            persist.save(bot, handoff="updated handoff")
+            persist.save(bot, handoff="final handoff")
+
+        conn = sqlite3.connect(str(tmp_path / "checkpoints.db"))
+        count = conn.execute(
+            "SELECT COUNT(*) FROM checkpoints WHERE agent_id = ?", ("ChatBot",)
+        ).fetchone()[0]
+        conn.close()
+        assert count == 1
+
+        data = read_checkpoint(tmp_path, "ChatBot")
+        assert data["handoff"] == "final handoff"
+
+
+# ---------------------------------------------------------------------------
+# Tests: Thread safety
+# ---------------------------------------------------------------------------
+
+
+class TestThreadSafety:
+    """PersistenceHandler is safe to use from multiple threads."""
+
+    def test_concurrent_saves_from_threads(self, tmp_path: Path):
+        """Multiple threads saving different agents concurrently don't corrupt the DB."""
+        import concurrent.futures
+
+        persist = PersistenceHandler(tmp_path / "checkpoints.db")
+        errors: list[Exception] = []
+
+        def save_agent(agent_id: str) -> None:
+            try:
+                bot = ChatBot(agent_id=agent_id)
+                hist = AgentHistoryHandler()
+                with handler(hist):
+                    history = get_agent_history(agent_id)
+                    for j in range(3):
+                        history[f"{agent_id}-msg{j}"] = {
+                            "id": f"{agent_id}-msg{j}",
+                            "role": "user",
+                            "content": f"msg {j} from {agent_id}",
+                        }
+                    # NOTE(review): reconstructed indentation — save appears to
+                    # happen inside the history handler scope; confirm upstream.
+                    persist.save(bot)
+            except Exception as e:
+                errors.append(e)
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
+            futures = [pool.submit(save_agent, f"agent-{i}") for i in range(8)]
+            concurrent.futures.wait(futures)
+
+        assert errors == [], f"Thread errors: {errors}"
+
+        # Verify DB integrity and all agents saved
+        conn = sqlite3.connect(str(tmp_path / "checkpoints.db"))
+        result = conn.execute("PRAGMA integrity_check").fetchone()[0]
+        assert result == "ok"
+        count = conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()[0]
+        conn.close()
+        assert count == 8
+
+    def test_concurrent_reads_and_writes(self, tmp_path: Path):
+        """Readers and writers can operate concurrently without errors."""
+        import concurrent.futures
+
+        persist = PersistenceHandler(tmp_path / "checkpoints.db")
+        errors: list[Exception] = []
+
+        # Seed some data
+        with handler(AgentHistoryHandler()):
+            for i in range(4):
+                bot = ChatBot(agent_id=f"agent-{i}")
+                persist.save(bot)
+
+        def writer(agent_id: str) -> None:
+            try:
+                bot = ChatBot(agent_id=agent_id)
+                with handler(AgentHistoryHandler()):
+                    history = get_agent_history(agent_id)
+                    history["update"] = {
+                        "id": "update",
+                        "role": "user",
+                        "content": "updated",
+                    }
+                    persist.save(bot)
+            except Exception as e:
+                errors.append(e)
+
+        def reader(agent_id: str) -> None:
+            try:
+                # Verify checkpoint is readable via direct DB access
+                read_checkpoint(tmp_path, agent_id)
+            except Exception as e:
+                errors.append(e)
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
+            futures = []
+            for i in range(4):
+                futures.append(pool.submit(writer, f"agent-{i}"))
+                futures.append(pool.submit(reader, f"agent-{i}"))
+            concurrent.futures.wait(futures)
+
+        assert errors == [], f"Thread errors: {errors}"