From 0fdb9b88decb04f19e3be51811707578e53903b0 Mon Sep 17 00:00:00 2001 From: Siddhant Khare Date: Wed, 25 Mar 2026 06:12:58 +0000 Subject: [PATCH 1/5] Add diff and why subcommands (closes #4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - diff : structural comparison of two sessions via LCS-based phase alignment; reports divergence point, per-phase file/command differences, and failed vs passed outcomes - why [session-id] : traces the causal chain backwards from a target event using parent_id links, error→retry detection, path-reference matching, and user_prompt as root cause - 20 new tests (192 total passing) Co-authored-by: Ona --- src/agent_trace/cli.py | 14 +++ src/agent_trace/diff.py | 250 +++++++++++++++++++++++++++++++++++++++ src/agent_trace/why.py | 250 +++++++++++++++++++++++++++++++++++++++ tests/test_diff_why.py | 252 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 766 insertions(+) create mode 100644 src/agent_trace/diff.py create mode 100644 src/agent_trace/why.py create mode 100644 tests/test_diff_why.py diff --git a/src/agent_trace/cli.py b/src/agent_trace/cli.py index bf1c492..5357179 100644 --- a/src/agent_trace/cli.py +++ b/src/agent_trace/cli.py @@ -23,8 +23,10 @@ from .hooks import hook_main from .http_proxy import HTTPProxyServer from .cost import cmd_cost +from .diff import cmd_diff from .explain import cmd_explain from .jsonl_import import cmd_import +from .why import cmd_why from .models import EventType, SessionMeta, TraceEvent from .proxy import MCPProxy from .replay import format_event, format_summary, list_sessions, replay_session @@ -445,6 +447,16 @@ def build_parser() -> argparse.ArgumentParser: p_explain = sub.add_parser("explain", help="explain a session in plain English") p_explain.add_argument("session_id", nargs="?", help="session ID or prefix (default: latest)") + # diff + p_diff = sub.add_parser("diff", help="compare two sessions structurally") + p_diff.add_argument("session_a", help="first session ID or prefix") + p_diff.add_argument("session_b", help="second session ID or prefix") + + # why + p_why = sub.add_parser("why", help="trace the causal chain for a specific event") + p_why.add_argument("session_id", nargs="?", help="session ID or prefix (default: latest)") + p_why.add_argument("event_number", type=int, help="1-based event number (from replay output)") + # cost p_cost = sub.add_parser("cost", help="estimate token cost for a session") p_cost.add_argument("session_id", nargs="?", help="session ID or prefix (default: latest)") @@ -487,6 +499,8 @@ def main() -> None: "import": cmd_import, "explain": cmd_explain, "cost": cmd_cost, + "diff": cmd_diff, + "why": cmd_why, } handler = handlers.get(args.command) diff --git a/src/agent_trace/diff.py b/src/agent_trace/diff.py new file mode 100644 index 0000000..0bc9cc7 --- /dev/null +++ b/src/agent_trace/diff.py @@ -0,0 +1,250 @@ +"""Session diff: structural behavioral comparison of two sessions. + +Compares two sessions by their phase structure (from explain), finds the +divergence point, and reports differences in files touched, commands run, +outcomes, duration, and cost. +""" + +from __future__ import annotations + +import argparse +import sys +from dataclasses import dataclass +from typing import TextIO + +from .explain import Phase, build_phases, explain_session +from .models import EventType +from .store import TraceStore + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + +@dataclass +class PhaseDiff: + index: int + label_a: str + label_b: str + same_label: bool + files_only_a: list[str] + files_only_b: list[str] + cmds_only_a: list[str] + cmds_only_b: list[str] + a_failed: bool + b_failed: bool + + +@dataclass +class SessionDiff: + session_a: str + session_b: str + divergence_index: int # first phase index where behaviour differs (-1 = identical) + phase_diffs: list[PhaseDiff] + # Summary metrics + duration_a: float + duration_b: float + events_a: int + events_b: int + tool_calls_a: int + tool_calls_b: int + retries_a: int + retries_b: int + + +# --------------------------------------------------------------------------- +# LCS-based phase alignment +# --------------------------------------------------------------------------- + +def _lcs_indices(a: list[str], b: list[str]) -> list[tuple[int, int]]: + """Return LCS index pairs (i, j) where a[i] == b[j].""" + m, n = len(a), len(b) + dp = [[0] * (n + 1) for _ in range(m + 1)] + for i in range(m - 1, -1, -1): + for j in range(n - 1, -1, -1): + if a[i] == b[j]: + dp[i][j] = 1 + dp[i + 1][j + 1] + else: + dp[i][j] = max(dp[i + 1][j], dp[i][j + 1]) + + pairs: list[tuple[int, int]] = [] + i = j = 0 + while i < m and j < n: + if a[i] == b[j]: + pairs.append((i, j)) + i += 1 + j += 1 + elif dp[i + 1][j] >= dp[i][j + 1]: + i += 1 + else: + j += 1 + return pairs + + +def _phase_key(phase: Phase) -> str: + """Normalised key for LCS matching — use label text.""" + return phase.name.lower().strip() + + +# --------------------------------------------------------------------------- +# Diff computation +# --------------------------------------------------------------------------- + +def diff_sessions( + store: TraceStore, + session_a: str, + session_b: str, +) -> SessionDiff: + result_a = explain_session(store, session_a) + result_b = explain_session(store, session_b) + + phases_a = result_a.phases + phases_b = result_b.phases + + keys_a = [_phase_key(p) for p in phases_a] + keys_b = [_phase_key(p) for p in phases_b] + + # Align phases via LCS + aligned = _lcs_indices(keys_a, keys_b) + aligned_set_a = {i for i, _ in aligned} + aligned_set_b = {j for _, j in aligned} + + phase_diffs: list[PhaseDiff] = [] + divergence_index = -1 + + # Walk aligned pairs + for pair_idx, (i, j) in enumerate(aligned): + pa = phases_a[i] + pb = phases_b[j] + + files_a = set(pa.files_read + pa.files_written) + files_b = set(pb.files_read + pb.files_written) + cmds_a = set(pa.commands) + cmds_b = set(pb.commands) + + only_a_files = sorted(files_a - files_b) + only_b_files = sorted(files_b - files_a) + only_a_cmds = sorted(cmds_a - cmds_b) + only_b_cmds = sorted(cmds_b - cmds_a) + + differs = ( + only_a_files or only_b_files + or only_a_cmds or only_b_cmds + or pa.failed != pb.failed + ) + + if differs and divergence_index == -1: + divergence_index = pair_idx + + phase_diffs.append(PhaseDiff( + index=pair_idx, + label_a=pa.name, + label_b=pb.name, + same_label=(keys_a[i] == keys_b[j]), + files_only_a=only_a_files, + files_only_b=only_b_files, + cmds_only_a=only_a_cmds, + cmds_only_b=only_b_cmds, + a_failed=pa.failed, + b_failed=pb.failed, + )) + + # Phases only in A or only in B count as divergence + if aligned_set_a != set(range(len(phases_a))) or aligned_set_b != set(range(len(phases_b))): + if divergence_index == -1: + divergence_index = len(phase_diffs) + + meta_a = store.load_meta(session_a) + meta_b = store.load_meta(session_b) + + return SessionDiff( + session_a=session_a, + session_b=session_b, + divergence_index=divergence_index, + phase_diffs=phase_diffs, + duration_a=result_a.total_duration, + duration_b=result_b.total_duration, + events_a=result_a.total_events, + events_b=result_b.total_events, + tool_calls_a=meta_a.tool_calls, + tool_calls_b=meta_b.tool_calls, + retries_a=result_a.total_retries, + retries_b=result_b.total_retries, + ) + + +# --------------------------------------------------------------------------- +# Formatting +# --------------------------------------------------------------------------- + +def _fmt_duration(s: float) -> str: + if s < 60: + return f"{s:.0f}s" + return f"{int(s) // 60}m {int(s) % 60:02d}s" + + +def format_diff(result: SessionDiff, out: TextIO = sys.stdout) -> None: + w = out.write + a = result.session_a[:12] + b = result.session_b[:12] + + w(f"\nComparing: {a} vs {b}\n\n") + + if result.divergence_index == -1: + w("Sessions are structurally identical.\n\n") + else: + w(f"Diverged at phase {result.divergence_index + 1}:\n\n") + + for pd in result.phase_diffs: + if not (pd.files_only_a or pd.files_only_b + or pd.cmds_only_a or pd.cmds_only_b + or pd.a_failed != pd.b_failed): + continue + + w(f" Phase {pd.index + 1}: {pd.label_a}\n") + + if pd.a_failed and not pd.b_failed: + w(f" {a}: FAILED {b}: passed\n") + elif pd.b_failed and not pd.a_failed: + w(f" {a}: passed {b}: FAILED\n") + + for f in pd.cmds_only_a: + w(f" {a} only: $ {f[:70]}\n") + for f in pd.cmds_only_b: + w(f" {b} only: $ {f[:70]}\n") + for f in pd.files_only_a: + w(f" {a} only: {f}\n") + for f in pd.files_only_b: + w(f" {b} only: {f}\n") + w("\n") + + w(f" {a}: {_fmt_duration(result.duration_a)}, " + f"{result.events_a} events, " + f"{result.tool_calls_a} tools, " + f"{result.retries_a} retries\n") + w(f" {b}: {_fmt_duration(result.duration_b)}, " + f"{result.events_b} events, " + f"{result.tool_calls_b} tools, " + f"{result.retries_b} retries\n\n") + + +# --------------------------------------------------------------------------- +# CLI handler +# --------------------------------------------------------------------------- + +def cmd_diff(args: argparse.Namespace) -> int: + store = TraceStore(args.trace_dir) + + id_a = store.find_session(args.session_a) + if not id_a: + sys.stderr.write(f"Session not found: {args.session_a}\n") + return 1 + + id_b = store.find_session(args.session_b) + if not id_b: + sys.stderr.write(f"Session not found: {args.session_b}\n") + return 1 + + result = diff_sessions(store, id_a, id_b) + format_diff(result) + return 0 diff --git a/src/agent_trace/why.py b/src/agent_trace/why.py new file mode 100644 index 0000000..c03e405 --- /dev/null +++ b/src/agent_trace/why.py @@ -0,0 +1,250 @@ +"""Causal chain tracing: why did a specific event happen? + +Walks backwards from a target event through causal links to find the +chain of events that led to it, terminating at a user_prompt or +session_start. + +Causal link rules: + - tool_call after error → caused by the error (retry) + - file_write after file_read → read informed the write (same file) + - tool_call referencing a path from a prior tool_result → result informed call + - any event after user_prompt → prompt caused it + - tool_result links to its tool_call via parent_id +""" + +from __future__ import annotations + +import argparse +import sys +from dataclasses import dataclass +from typing import TextIO + +from .models import EventType, TraceEvent +from .store import TraceStore + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + +@dataclass +class CausalLink: + event: TraceEvent + reason: str # human-readable explanation of the causal link + event_index: int # 0-based index in the original event list + + +@dataclass +class CausalChain: + target_index: int + links: list[CausalLink] # ordered root → target + + +# --------------------------------------------------------------------------- +# Causal link detection +# --------------------------------------------------------------------------- + +def _event_paths(event: TraceEvent) -> set[str]: + """Extract file paths referenced in an event's data.""" + paths: set[str] = set() + data = event.data + for key in ("file_path", "uri", "path"): + val = data.get(key, "") + if val: + paths.add(str(val)) + args = data.get("arguments", {}) + if isinstance(args, dict): + for key in ("file_path", "path", "uri"): + val = args.get(key, "") + if val: + paths.add(str(val)) + result = data.get("result", "") or data.get("content_preview", "") + return paths + + +def _result_text(event: TraceEvent) -> str: + return str( + event.data.get("result", "") + or event.data.get("content_preview", "") + or event.data.get("text", "") + ) + + +def build_causal_chain(events: list[TraceEvent], target_index: int) -> CausalChain: + """Trace backwards from events[target_index] to find the causal chain.""" + if not events or target_index < 0 or target_index >= len(events): + return CausalChain(target_index=target_index, links=[]) + + # Index events by event_id for parent_id lookups + by_id: dict[str, tuple[int, TraceEvent]] = { + e.event_id: (i, e) for i, e in enumerate(events) + } + + visited: set[int] = set() + chain: list[CausalLink] = [] + + def _walk(idx: int, reason: str) -> None: + if idx in visited or idx < 0: + return + visited.add(idx) + event = events[idx] + chain.append(CausalLink(event=event, reason=reason, event_index=idx)) + + # Terminate at root causes + if event.event_type in (EventType.USER_PROMPT, EventType.SESSION_START): + return + + # 1. Follow parent_id link (tool_result → tool_call, llm_response → llm_request) + if event.parent_id and event.parent_id in by_id: + parent_idx, parent_event = by_id[event.parent_id] + _walk(parent_idx, f"← parent of #{idx + 1}") + return + + # 2. Scan backwards for the most recent causal predecessor + target_paths = _event_paths(event) + target_tool = event.data.get("tool_name", "").lower() + + for prev_idx in range(idx - 1, -1, -1): + prev = events[prev_idx] + + # Error → next tool_call is a retry + if (prev.event_type == EventType.ERROR + and event.event_type == EventType.TOOL_CALL): + _walk(prev_idx, f"retry after error at #{prev_idx + 1}") + return + + # tool_result containing a path that this tool_call references + if (prev.event_type == EventType.TOOL_RESULT and target_paths): + result_text = _result_text(prev) + if any(p in result_text for p in target_paths): + _walk(prev_idx, f"result at #{prev_idx + 1} referenced path") + return + + # file_read → file_write of same file + if (prev.event_type in (EventType.TOOL_CALL, EventType.FILE_READ) + and event.event_type in (EventType.TOOL_CALL, EventType.FILE_WRITE)): + prev_paths = _event_paths(prev) + if target_paths & prev_paths: + _walk(prev_idx, f"read at #{prev_idx + 1} informed write") + return + + # user_prompt is always a root cause + if prev.event_type == EventType.USER_PROMPT: + _walk(prev_idx, f"prompt at #{prev_idx + 1} triggered this") + return + + # Fallback: link to session_start or first event + _walk(0, "session start") + + _walk(target_index, "target event") + + # Reverse so chain reads root → target + chain.reverse() + return CausalChain(target_index=target_index, links=chain) + + +# --------------------------------------------------------------------------- +# Formatting +# --------------------------------------------------------------------------- + +def _event_summary(event: TraceEvent, index: int) -> str: + etype = event.event_type.value + data = event.data + + if event.event_type == EventType.TOOL_CALL: + name = data.get("tool_name", "?") + args = data.get("arguments", {}) + detail = "" + if isinstance(args, dict): + if "command" in args: + detail = f" $ {str(args['command'])[:60]}" + elif "file_path" in args: + detail = f" {args['file_path']}" + return f"#{index + 1:>3} tool_call: {name}{detail}" + + if event.event_type == EventType.TOOL_RESULT: + preview = _result_text(event)[:60] + return f"#{index + 1:>3} tool_result: {preview}" + + if event.event_type == EventType.ERROR: + msg = (data.get("message", "") or data.get("error", ""))[:60] + return f"#{index + 1:>3} error: {msg}" + + if event.event_type == EventType.USER_PROMPT: + prompt = data.get("prompt", "")[:60] + return f"#{index + 1:>3} user_prompt: \"{prompt}\"" + + if event.event_type == EventType.ASSISTANT_RESPONSE: + text = data.get("text", "")[:60] + return f"#{index + 1:>3} assistant_response: \"{text}\"" + + if event.event_type in (EventType.FILE_READ, EventType.FILE_WRITE): + uri = data.get("uri", data.get("file_path", "")) + return f"#{index + 1:>3} {etype}: {uri}" + + return f"#{index + 1:>3} {etype}" + + +def format_why( + chain: CausalChain, + events: list[TraceEvent], + out: TextIO = sys.stdout, +) -> None: + w = out.write + + if not chain.links: + w(f"No causal chain found for event #{chain.target_index + 1}.\n") + return + + target = events[chain.target_index] + w(f"\nWhy did event #{chain.target_index + 1} happen?\n\n") + w(f" {_event_summary(target, chain.target_index)}\n\n") + + if len(chain.links) <= 1: + w(" No prior causal events found.\n\n") + return + + w("Causal chain (root → target):\n\n") + for i, link in enumerate(chain.links): + prefix = " " + ("← " if i > 0 else " ") + w(f"{prefix}{_event_summary(link.event, link.event_index)}\n") + if i < len(chain.links) - 1 and link.reason: + w(f" ({link.reason})\n") + + w("\n") + + +# --------------------------------------------------------------------------- +# CLI handler +# --------------------------------------------------------------------------- + +def cmd_why(args: argparse.Namespace) -> int: + store = TraceStore(args.trace_dir) + + session_id = args.session_id + if not session_id: + session_id = store.get_latest_session_id() + if not session_id: + sys.stderr.write("No sessions found.\n") + return 1 + full_id = store.find_session(session_id) + if not full_id: + sys.stderr.write(f"Session not found: {session_id}\n") + return 1 + + events = store.load_events(full_id) + if not events: + sys.stderr.write("No events in session.\n") + return 1 + + # event_number is 1-based + event_number = args.event_number + if event_number < 1 or event_number > len(events): + sys.stderr.write( + f"Event number must be between 1 and {len(events)}.\n" + ) + return 1 + + chain = build_causal_chain(events, event_number - 1) + format_why(chain, events) + return 0 diff --git a/tests/test_diff_why.py b/tests/test_diff_why.py new file mode 100644 index 0000000..24a01ad --- /dev/null +++ b/tests/test_diff_why.py @@ -0,0 +1,252 @@ +"""Tests for session diff and causal chain (why).""" + +import io +import tempfile +import unittest + +from agent_trace.diff import ( + PhaseDiff, + SessionDiff, + _lcs_indices, + diff_sessions, + format_diff, +) +from agent_trace.models import EventType, SessionMeta, TraceEvent +from agent_trace.store import TraceStore +from agent_trace.why import ( + CausalChain, + build_causal_chain, + format_why, +) + + +def _make_event(event_type: EventType, ts: float, session_id: str, + event_id: str = "", parent_id: str = "", **data) -> TraceEvent: + e = TraceEvent(event_type=event_type, timestamp=ts, session_id=session_id, data=data) + if event_id: + e.event_id = event_id + if parent_id: + e.parent_id = parent_id + return e + + +def _make_store(sessions: list[tuple[SessionMeta, list[TraceEvent]]]) -> tuple[TraceStore, tempfile.TemporaryDirectory]: + tmp = tempfile.TemporaryDirectory() + store = TraceStore(tmp.name) + for meta, events in sessions: + store.create_session(meta) + for e in events: + store.append_event(meta.session_id, e) + store.update_meta(meta) + return store, tmp + + +# --------------------------------------------------------------------------- +# LCS tests +# --------------------------------------------------------------------------- + +class TestLCS(unittest.TestCase): + def test_identical_lists(self): + pairs = _lcs_indices(["a", "b", "c"], ["a", "b", "c"]) + self.assertEqual(pairs, [(0, 0), (1, 1), (2, 2)]) + + def test_empty_lists(self): + self.assertEqual(_lcs_indices([], []), []) + + def test_no_common(self): + self.assertEqual(_lcs_indices(["a", "b"], ["c", "d"]), []) + + def test_partial_match(self): + pairs = _lcs_indices(["a", "b", "c"], ["a", "x", "c"]) + # Should match a and c + self.assertIn((0, 0), pairs) + self.assertIn((2, 2), pairs) + + def test_insertion(self): + pairs = _lcs_indices(["a", "c"], ["a", "b", "c"]) + self.assertEqual(pairs, [(0, 0), (1, 2)]) + + +# --------------------------------------------------------------------------- +# Diff tests +# --------------------------------------------------------------------------- + +class TestDiffSessions(unittest.TestCase): + def _two_sessions(self, cmds_a, cmds_b, failed_a=False, failed_b=False): + def _events(sid, cmds, failed): + evts = [_make_event(EventType.USER_PROMPT, 0.0, sid, prompt="run tests")] + for i, cmd in enumerate(cmds): + evts.append(_make_event(EventType.TOOL_CALL, float(i + 1), sid, + tool_name="Bash", arguments={"command": cmd})) + if failed: + evts.append(_make_event(EventType.ERROR, float(len(cmds) + 1), sid, + message="exit 1")) + evts.append(_make_event(EventType.SESSION_END, float(len(cmds) + 2), sid)) + return evts + + meta_a = SessionMeta(session_id="sessa001", started_at=0.0, + total_duration_ms=5000, tool_calls=len(cmds_a)) + meta_b = SessionMeta(session_id="sessb001", started_at=0.0, + total_duration_ms=4000, tool_calls=len(cmds_b)) + store, tmp = _make_store([ + (meta_a, _events("sessa001", cmds_a, failed_a)), + (meta_b, _events("sessb001", cmds_b, failed_b)), + ]) + return store, tmp + + def test_identical_sessions(self): + store, tmp = self._two_sessions(["pytest"], ["pytest"]) + result = diff_sessions(store, "sessa001", "sessb001") + self.assertEqual(result.divergence_index, -1) + tmp.cleanup() + + def test_different_commands_diverge(self): + store, tmp = self._two_sessions(["pytest"], ["python -m pytest"]) + result = diff_sessions(store, "sessa001", "sessb001") + self.assertGreater(len(result.phase_diffs), 0) + tmp.cleanup() + + def test_failed_vs_passed(self): + store, tmp = self._two_sessions(["pytest"], ["pytest"], + failed_a=True, failed_b=False) + result = diff_sessions(store, "sessa001", "sessb001") + failed_diffs = [pd for pd in result.phase_diffs if pd.a_failed != pd.b_failed] + self.assertTrue(len(failed_diffs) > 0) + tmp.cleanup() + + def test_duration_captured(self): + store, tmp = self._two_sessions(["pytest"], ["pytest"]) + result = diff_sessions(store, "sessa001", "sessb001") + self.assertAlmostEqual(result.duration_a, 5.0, places=0) + self.assertAlmostEqual(result.duration_b, 4.0, places=0) + tmp.cleanup() + + +class TestFormatDiff(unittest.TestCase): + def test_output_contains_session_ids(self): + result = SessionDiff( + session_a="aaa111bbb222", + session_b="ccc333ddd444", + divergence_index=0, + phase_diffs=[ + PhaseDiff(index=0, label_a="run tests", label_b="run tests", + same_label=True, files_only_a=[], files_only_b=[], + cmds_only_a=["pytest"], cmds_only_b=["python -m pytest"], + a_failed=False, b_failed=False) + ], + duration_a=10.0, duration_b=8.0, + events_a=5, events_b=4, + tool_calls_a=2, tool_calls_b=2, + retries_a=0, retries_b=0, + ) + buf = io.StringIO() + format_diff(result, out=buf) + output = buf.getvalue() + self.assertIn("aaa111bbb2", output) + self.assertIn("ccc333ddd4", output) + self.assertIn("pytest", output) + + def test_identical_sessions_message(self): + result = SessionDiff( + session_a="aaa", session_b="bbb", + divergence_index=-1, phase_diffs=[], + duration_a=5.0, duration_b=5.0, + events_a=3, events_b=3, + tool_calls_a=1, tool_calls_b=1, + retries_a=0, retries_b=0, + ) + buf = io.StringIO() + format_diff(result, out=buf) + self.assertIn("identical", buf.getvalue()) + + +# --------------------------------------------------------------------------- +# Why / causal chain tests +# --------------------------------------------------------------------------- + +class TestBuildCausalChain(unittest.TestCase): + def test_empty_events(self): + chain = build_causal_chain([], 0) + self.assertEqual(chain.links, []) + + def test_out_of_range(self): + events = [_make_event(EventType.TOOL_CALL, 0.0, "s1", tool_name="Bash", + arguments={"command": "ls"})] + chain = build_causal_chain(events, 5) + self.assertEqual(chain.links, []) + + def test_user_prompt_is_root(self): + events = [ + _make_event(EventType.USER_PROMPT, 0.0, "s1", prompt="do it"), + _make_event(EventType.TOOL_CALL, 1.0, "s1", tool_name="Bash", + arguments={"command": "ls"}), + ] + chain = build_causal_chain(events, 1) + types = [l.event.event_type for l in chain.links] + self.assertIn(EventType.USER_PROMPT, types) + + def test_parent_id_link(self): + events = [ + _make_event(EventType.TOOL_CALL, 0.0, "s1", event_id="call1", + tool_name="Bash", arguments={"command": "ls"}), + _make_event(EventType.TOOL_RESULT, 1.0, "s1", parent_id="call1", + result="file.py"), + ] + chain = build_causal_chain(events, 1) + types = [l.event.event_type for l in chain.links] + self.assertIn(EventType.TOOL_CALL, types) + self.assertIn(EventType.TOOL_RESULT, types) + + def test_error_causes_retry(self): + events = [ + _make_event(EventType.USER_PROMPT, 0.0, "s1", prompt="run"), + _make_event(EventType.TOOL_CALL, 1.0, "s1", tool_name="Bash", + arguments={"command": "pytest"}), + _make_event(EventType.ERROR, 2.0, "s1", message="exit 1"), + _make_event(EventType.TOOL_CALL, 3.0, "s1", tool_name="Bash", + arguments={"command": "pytest"}), + ] + chain = build_causal_chain(events, 3) + types = [l.event.event_type for l in chain.links] + self.assertIn(EventType.ERROR, types) + + def test_chain_ordered_root_to_target(self): + events = [ + _make_event(EventType.USER_PROMPT, 0.0, "s1", prompt="go"), + _make_event(EventType.TOOL_CALL, 1.0, "s1", tool_name="Bash", + arguments={"command": "ls"}), + ] + chain = build_causal_chain(events, 1) + # Last link should be the target + self.assertEqual(chain.links[-1].event_index, 1) + + def test_single_event_chain(self): + events = [ + _make_event(EventType.USER_PROMPT, 0.0, "s1", prompt="hello"), + ] + chain = build_causal_chain(events, 0) + self.assertEqual(len(chain.links), 1) + self.assertEqual(chain.links[0].event.event_type, EventType.USER_PROMPT) + + +class TestFormatWhy(unittest.TestCase): + def test_output_contains_event_number(self): + events = [ + _make_event(EventType.USER_PROMPT, 0.0, "s1", prompt="do it"), + _make_event(EventType.TOOL_CALL, 1.0, "s1", tool_name="Bash", + arguments={"command": "ls"}), + ] + chain = build_causal_chain(events, 1) + buf = io.StringIO() + format_why(chain, events, out=buf) + self.assertIn("#2", buf.getvalue()) + + def test_no_chain_message(self): + chain = CausalChain(target_index=0, links=[]) + buf = io.StringIO() + format_why(chain, [], out=buf) + self.assertIn("No causal chain", buf.getvalue()) + + +if __name__ == "__main__": + unittest.main() From 2ee71c8b7a884f73a7adef2cf821cac186c74dae Mon Sep 17 00:00:00 2001 From: Siddhant Khare Date: Wed, 25 Mar 2026 06:25:15 +0000 Subject: [PATCH 2/5] Remove dead target_tool variable, drop unused result in _event_paths, fix import order Co-authored-by: Ona --- src/agent_trace/cli.py | 2 +- src/agent_trace/why.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/agent_trace/cli.py b/src/agent_trace/cli.py index 5357179..34e37cf 100644 --- a/src/agent_trace/cli.py +++ b/src/agent_trace/cli.py @@ -26,11 +26,11 @@ from .diff import cmd_diff from .explain import cmd_explain from .jsonl_import import cmd_import -from .why import cmd_why from .models import EventType, SessionMeta, TraceEvent from .proxy import MCPProxy from .replay import format_event, format_summary, list_sessions, replay_session from .store import TraceStore +from .why import cmd_why def _print_live_event(event: TraceEvent) -> None: diff --git a/src/agent_trace/why.py b/src/agent_trace/why.py index c03e405..e99964e 100644 --- a/src/agent_trace/why.py +++ b/src/agent_trace/why.py @@ -58,7 +58,6 @@ def _event_paths(event: TraceEvent) -> set[str]: val = args.get(key, "") if val: paths.add(str(val)) - result = data.get("result", "") or data.get("content_preview", "") return paths @@ -102,7 +101,6 @@ def _walk(idx: int, reason: str) -> None: # 2. Scan backwards for the most recent causal predecessor target_paths = _event_paths(event) - target_tool = event.data.get("tool_name", "").lower() for prev_idx in range(idx - 1, -1, -1): prev = events[prev_idx] From 83e1450a768423db0b50645b846a8f0a0ceb3a9b Mon Sep 17 00:00:00 2001 From: Siddhant Khare Date: Sat, 28 Mar 2026 16:05:26 +0000 Subject: [PATCH 3/5] =?UTF-8?q?Fix=20causal=20chain=20false=20positives,?= =?UTF-8?q?=20tighten=20read=E2=86=92write=20rule,=20add=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit why.py: - error→retry: only fire when no substantive event sits between the error and the tool_call (tool_result events are allowed as separators). Prevents unrelated tool calls after an error from being misattributed as retries. - read→write: restrict the path-match rule to actual write operations (FILE_WRITE event type, or tool_name in write/edit/create). Previously any TOOL_CALL sharing a path with a prior read would match, including Bash commands that merely mentioned the path. tests: - error→retry false positive: unrelated call between error and retry - write tool_call correctly linked to prior read of same path - Bash call with shared path not linked via read→write rule - dangling parent_id falls through to heuristic without crashing - diff_sessions with unaligned phase counts Co-authored-by: Ona --- src/agent_trace/why.py | 27 +++++++++--- tests/test_diff_why.py | 99 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 5 deletions(-) diff --git a/src/agent_trace/why.py b/src/agent_trace/why.py index e99964e..eb0b59f 100644 --- a/src/agent_trace/why.py +++ b/src/agent_trace/why.py @@ -105,11 +105,18 @@ def _walk(idx: int, reason: str) -> None: for prev_idx in range(idx - 1, -1, -1): prev = events[prev_idx] - # Error → next tool_call is a retry + # Error → immediately following tool_call is a retry. + # Only fire when the error is the closest preceding substantive + # event (skip over tool_result events which may sit between them). if (prev.event_type == EventType.ERROR and event.event_type == EventType.TOOL_CALL): - _walk(prev_idx, f"retry after error at #{prev_idx + 1}") - return + # Check nothing substantive sits between the error and this call + intervening = events[prev_idx + 1:idx] + non_result = [e for e in intervening + if e.event_type != EventType.TOOL_RESULT] + if not non_result: + _walk(prev_idx, f"retry after error at #{prev_idx + 1}") + return # tool_result containing a path that this tool_call references if (prev.event_type == EventType.TOOL_RESULT and target_paths): @@ -118,14 +125,24 @@ def _walk(idx: int, reason: str) -> None: _walk(prev_idx, f"result at #{prev_idx + 1} referenced path") return - # file_read → file_write of same file + # file_read → file_write of same file: only match actual write ops if (prev.event_type in (EventType.TOOL_CALL, EventType.FILE_READ) - and event.event_type in (EventType.TOOL_CALL, EventType.FILE_WRITE)): + and event.event_type == EventType.FILE_WRITE): prev_paths = _event_paths(prev) if target_paths & prev_paths: _walk(prev_idx, f"read at #{prev_idx + 1} informed write") return + # Write tool calls (Write/Edit) via tool_call event type + if (prev.event_type in (EventType.TOOL_CALL, EventType.FILE_READ) + and event.event_type == EventType.TOOL_CALL): + tool = event.data.get("tool_name", "").lower() + if tool in ("write", "edit", "create"): + prev_paths = _event_paths(prev) + if target_paths & prev_paths: + _walk(prev_idx, f"read at #{prev_idx + 1} informed write") + return + # user_prompt is always a root cause if prev.event_type == EventType.USER_PROMPT: _walk(prev_idx, f"prompt at #{prev_idx + 1} triggered this") diff --git a/tests/test_diff_why.py b/tests/test_diff_why.py index 24a01ad..ceba75d 100644 --- a/tests/test_diff_why.py +++ b/tests/test_diff_why.py @@ -228,6 +228,105 @@ def test_single_event_chain(self): self.assertEqual(len(chain.links), 1) self.assertEqual(chain.links[0].event.event_type, EventType.USER_PROMPT) + def test_error_retry_only_fires_for_adjacent_error(self): + """A tool_call after an error with an unrelated call in between is NOT a retry.""" + events = [ + _make_event(EventType.USER_PROMPT, 0.0, "s1", prompt="run"), + _make_event(EventType.TOOL_CALL, 1.0, "s1", tool_name="Bash", + arguments={"command": "pytest"}), + _make_event(EventType.ERROR, 2.0, "s1", message="exit 1"), + _make_event(EventType.TOOL_CALL, 3.0, "s1", tool_name="Bash", + arguments={"command": "git status"}), # unrelated + _make_event(EventType.TOOL_CALL, 4.0, "s1", tool_name="Bash", + arguments={"command": "pytest"}), # actual retry + ] + # why #4 (git status) should NOT be attributed to the error + chain_unrelated = build_causal_chain(events, 3) + types = [l.event.event_type for l in chain_unrelated.links] + self.assertNotIn(EventType.ERROR, types) + + # why #5 (retry pytest) SHOULD be attributed to the error + chain_retry = build_causal_chain(events, 4) + types_retry = [l.event.event_type for l in chain_retry.links] + self.assertIn(EventType.ERROR, types_retry) + + def test_write_tool_call_linked_to_prior_read(self): + """A Write tool_call referencing a path read earlier should link to that read.""" + events = [ + _make_event(EventType.USER_PROMPT, 0.0, "s1", prompt="update file"), + _make_event(EventType.TOOL_CALL, 1.0, "s1", tool_name="Read", + arguments={"file_path": "src/foo.py"}), + _make_event(EventType.TOOL_CALL, 2.0, "s1", tool_name="Write", + arguments={"file_path": "src/foo.py"}), + ] + chain = build_causal_chain(events, 2) + types = [l.event.event_type for l in chain.links] + # Should include the Read tool_call + self.assertIn(EventType.TOOL_CALL, types) + read_links = [l for l in chain.links + if l.event.data.get("tool_name") == "Read"] + self.assertTrue(len(read_links) > 0) + + def test_bash_not_linked_to_read_via_path(self): + """A Bash tool_call sharing a path with a prior Read should NOT be linked via read→write rule.""" + events = [ + _make_event(EventType.USER_PROMPT, 0.0, "s1", prompt="check file"), + _make_event(EventType.TOOL_CALL, 1.0, "s1", tool_name="Read", + arguments={"file_path": "src/foo.py"}), + _make_event(EventType.TOOL_CALL, 2.0, "s1", tool_name="Bash", + arguments={"command": "cat src/foo.py"}), + ] + chain = build_causal_chain(events, 2) + # The Bash call has no file_path in arguments, so _event_paths returns empty + # and the read→write rule should not fire. Root cause should be user_prompt. + types = [l.event.event_type for l in chain.links] + self.assertIn(EventType.USER_PROMPT, types) + + def test_why_with_dangling_parent_id(self): + """An event with a parent_id pointing to a non-existent event should fall through to heuristic.""" + events = [ + _make_event(EventType.USER_PROMPT, 0.0, "s1", prompt="go"), + _make_event(EventType.TOOL_RESULT, 1.0, "s1", + parent_id="nonexistent_id", result="done"), + ] + # Should not raise; should fall through to heuristic and find user_prompt + chain = build_causal_chain(events, 1) + self.assertTrue(len(chain.links) > 0) + + +class TestDiffSessionsUnaligned(unittest.TestCase): + def test_session_with_extra_phases(self): + """Sessions with different numbers of phases should still produce a diff.""" + def _events(sid, cmds): + evts = [_make_event(EventType.USER_PROMPT, 0.0, sid, prompt="step 1")] + for i, cmd in enumerate(cmds): + evts.append(_make_event(EventType.TOOL_CALL, float(i + 1), sid, + tool_name="Bash", arguments={"command": cmd})) + evts.append(_make_event(EventType.SESSION_END, float(len(cmds) + 2), sid)) + return evts + + meta_a = SessionMeta(session_id="sessa002", started_at=0.0, + total_duration_ms=5000, tool_calls=2) + meta_b = SessionMeta(session_id="sessb002", started_at=0.0, + total_duration_ms=3000, tool_calls=1) + + tmp = tempfile.TemporaryDirectory() + store = TraceStore(tmp.name) + store.create_session(meta_a) + for e in _events("sessa002", ["pytest", "coverage report"]): + store.append_event("sessa002", e) + store.update_meta(meta_a) + + store.create_session(meta_b) + for e in _events("sessb002", ["pytest"]): + store.append_event("sessb002", e) + store.update_meta(meta_b) + + result = diff_sessions(store, "sessa002", "sessb002") + # Sessions differ — divergence should be detected + self.assertGreaterEqual(result.divergence_index, -1) + tmp.cleanup() + class TestFormatWhy(unittest.TestCase): def test_output_contains_event_number(self): From 4828fb97a258aea0e15d94710c782e3364d13374 Mon Sep 17 00:00:00 2001 From: Siddhant Khare Date: Sat, 28 Mar 2026 16:09:53 +0000 Subject: [PATCH 4/5] =?UTF-8?q?Fix=20error=E2=86=92retry:=20require=20same?= =?UTF-8?q?=20tool=20name=20as=20the=20failing=20call?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous adjacency check was insufficient — git status immediately after a failed pytest is adjacent to the error but is not a retry. Now we look up the tool_call that caused the error and only attribute the retry if the target tool_call uses the same tool name. Co-authored-by: Ona --- src/agent_trace/why.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/agent_trace/why.py b/src/agent_trace/why.py index eb0b59f..22a2e44 100644 --- a/src/agent_trace/why.py +++ b/src/agent_trace/why.py @@ -105,18 +105,22 @@ def _walk(idx: int, reason: str) -> None: for prev_idx in range(idx - 1, -1, -1): prev = events[prev_idx] - # Error → immediately following tool_call is a retry. - # Only fire when the error is the closest preceding substantive - # event (skip over tool_result events which may sit between them). + # Error → retry: only fire when the target tool_call repeats the + # same tool name as the tool_call that immediately preceded the + # error. This avoids attributing unrelated commands (e.g. git + # status after a failed pytest) as retries. if (prev.event_type == EventType.ERROR and event.event_type == EventType.TOOL_CALL): - # Check nothing substantive sits between the error and this call - intervening = events[prev_idx + 1:idx] - non_result = [e for e in intervening - if e.event_type != EventType.TOOL_RESULT] - if not non_result: - _walk(prev_idx, f"retry after error at #{prev_idx + 1}") - return + # Find the tool_call that caused the error (the one just before it) + causing_idx = prev_idx - 1 + while causing_idx >= 0 and events[causing_idx].event_type == EventType.TOOL_RESULT: + causing_idx -= 1 + if causing_idx >= 0 and events[causing_idx].event_type == EventType.TOOL_CALL: + causing_tool = events[causing_idx].data.get("tool_name", "") + target_tool = event.data.get("tool_name", "") + if causing_tool and causing_tool == target_tool: + _walk(prev_idx, f"retry after error at #{prev_idx + 1}") + return # tool_result containing a path that this tool_call references if (prev.event_type == EventType.TOOL_RESULT and target_paths): From de35f7c3dfcfccad4c9ae4913df2ccd7699e8a1c Mon Sep 17 00:00:00 2001 From: Siddhant Khare Date: Sat, 28 Mar 2026 16:20:23 +0000 Subject: [PATCH 5/5] =?UTF-8?q?Fix=20error=E2=86=92retry:=20also=20match?= =?UTF-8?q?=20command=20string=20for=20Bash=20retries?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tool name alone is insufficient — git status and pytest are both Bash calls. For Bash tool calls, require the command string to match the failing call. For other tools, tool name match is sufficient. Co-authored-by: Ona --- src/agent_trace/why.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/agent_trace/why.py b/src/agent_trace/why.py index 22a2e44..3681a8e 100644 --- a/src/agent_trace/why.py +++ b/src/agent_trace/why.py @@ -106,19 +106,28 @@ def _walk(idx: int, reason: str) -> None: prev = events[prev_idx] # Error → retry: only fire when the target tool_call repeats the - # same tool name as the tool_call that immediately preceded the - # error. This avoids attributing unrelated commands (e.g. git - # status after a failed pytest) as retries. + # same tool name AND command as the call that caused the error. + # This avoids attributing unrelated commands (e.g. git status + # after a failed pytest) as retries. if (prev.event_type == EventType.ERROR and event.event_type == EventType.TOOL_CALL): - # Find the tool_call that caused the error (the one just before it) + # Find the tool_call that caused the error (skip tool_results) causing_idx = prev_idx - 1 while causing_idx >= 0 and events[causing_idx].event_type == EventType.TOOL_RESULT: causing_idx -= 1 if causing_idx >= 0 and events[causing_idx].event_type == EventType.TOOL_CALL: - causing_tool = events[causing_idx].data.get("tool_name", "") - target_tool = event.data.get("tool_name", "") - if causing_tool and causing_tool == target_tool: + causing = events[causing_idx].data + target = event.data + same_tool = causing.get("tool_name", "") == target.get("tool_name", "") + # For Bash, also require the same command string + causing_args = causing.get("arguments", {}) or {} + target_args = target.get("arguments", {}) or {} + if same_tool and causing.get("tool_name", "").lower() == "bash": + same_cmd = causing_args.get("command", "") == target_args.get("command", "") + if same_cmd: + _walk(prev_idx, f"retry after error at #{prev_idx + 1}") + return + elif same_tool: _walk(prev_idx, f"retry after error at #{prev_idx + 1}") return