diff --git a/nemo_retriever/src/nemo_retriever/llm/clients/judge.py b/nemo_retriever/src/nemo_retriever/llm/clients/judge.py index 97b01c5dc..61198a0e2 100644 --- a/nemo_retriever/src/nemo_retriever/llm/clients/judge.py +++ b/nemo_retriever/src/nemo_retriever/llm/clients/judge.py @@ -153,6 +153,8 @@ def judge(self, query: str, reference: str, candidate: str) -> JudgeResult: def _parse_judge_response(raw: str) -> JudgeResult: """Parse the judge's JSON response into a JudgeResult.""" text = raw.strip() + # Reasoning models can emit a ... block before the final JSON. + text = re.sub(r".*?", "", text, flags=re.DOTALL).strip() text = re.sub(r"^```(?:json)?\s*", "", text, flags=re.MULTILINE) text = re.sub(r"\s*```$", "", text, flags=re.MULTILINE) diff --git a/nemo_retriever/src/nemo_retriever/retriever_graph_utils.py b/nemo_retriever/src/nemo_retriever/retriever_graph_utils.py index a5b6857d7..89ae103fc 100644 --- a/nemo_retriever/src/nemo_retriever/retriever_graph_utils.py +++ b/nemo_retriever/src/nemo_retriever/retriever_graph_utils.py @@ -22,11 +22,20 @@ def hits_lists_to_rerank_dataframe( query_texts: list[str], hits_per_query: list[list[dict[str, Any]]], ) -> pd.DataFrame: - """One row per (query, hit) with payload to rebuild hits after reranking.""" + """One row per (query, hit) with payload to rebuild hits after reranking. + + Returns a DataFrame with the columns ``query``, ``text``, ``_hit`` even when + there are no hits — ``pd.DataFrame([])`` yields a column-less DataFrame, + which crashes the downstream rerank actor with ``KeyError: 'query'`` on a + legitimate empty-results path (empty/unmatched corpus, freshly-ingested + table, etc.). + """ rows: list[dict[str, Any]] = [] for q, hits in zip(query_texts, hits_per_query): for h in hits: rows.append({"query": q, "text": str(h.get("text", "")), "_hit": dict(h)}) + if not rows: + return pd.DataFrame(columns=["query", "text", "_hit"]) return pd.DataFrame(rows) diff --git a/nemo_retriever/src/nemo_retriever/skill_eval/cli.py b/nemo_retriever/src/nemo_retriever/skill_eval/cli.py index 8db8f323e..e86d3ebd4 100644 --- a/nemo_retriever/src/nemo_retriever/skill_eval/cli.py +++ b/nemo_retriever/src/nemo_retriever/skill_eval/cli.py @@ -2,14 +2,16 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""`retriever skill-eval run` benchmark.""" +"""`retriever skill-eval` benchmark.""" from __future__ import annotations +import json import logging import os import shutil from collections import defaultdict +from dataclasses import asdict, fields from pathlib import Path from typing import Any, Optional @@ -22,14 +24,21 @@ from nemo_retriever.skill_eval.report import overall_recall, write_summary from nemo_retriever.skill_eval.runner import ( CONDITIONS, + DEFAULT_AGENT_MODELS, + SUPPORTED_AGENTS, + UNSCORABLE_JUDGE_ERRORS, + TrialResult, + _apply_judge, + archive_session_log, cleanup_condition_workdir, + extract_compact_trace, run_condition, save_trial, ) DEFAULT_ORDER = ("c1_base", "c2_retriever", "c3_retriever_skill") -app = typer.Typer(help="Benchmark Claude with vs. without the /nemo-retriever skill on a folder of PDFs.") +app = typer.Typer(help="Benchmark coding agents with vs. without the /nemo-retriever skill on a folder of PDFs.") logger = logging.getLogger(__name__) @@ -85,6 +94,61 @@ def _build_judge(cfg: dict) -> Optional[Any]: return judge +def _build_trace_summarizer(cfg: dict) -> Optional[Any]: + """Construct a ``TraceSummarizer`` from ``cfg['summarizer']`` or return ``None``.""" + sum_cfg = cfg.get("summarizer") or {} + if not sum_cfg.get("enabled", True): + typer.echo("Trace summarizer disabled by config (summarizer.enabled=false).") + return None + if shutil.which("claude") is None: + typer.echo("Trace summarizer disabled: `claude` CLI is not on PATH.") + return None + from nemo_retriever.skill_eval.trace_summarizer import TraceSummarizer + + summarizer = TraceSummarizer.from_kwargs( + model=str(sum_cfg.get("model", "claude-opus-4-7")), + ) + typer.echo(f"Trace summarizer enabled: model={summarizer.model}") + return summarizer + + +def _resolve_agent(value: str) -> str: + agent = value.strip().lower() + if agent not in SUPPORTED_AGENTS: + raise typer.BadParameter(f"agent must be one of {', '.join(SUPPORTED_AGENTS)}") + return agent + + +def _resolve_agent_model(cfg: dict, agent: str, override: Optional[str]) -> str: + if override: + return override + models = cfg.get("agent_models") + if isinstance(models, dict) and models.get(agent): + return str(models[agent]) + if cfg.get("agent_model"): + return str(cfg["agent_model"]) + return DEFAULT_AGENT_MODELS[agent] + + +def _resolve_conditions(value: Optional[str], cfg: dict) -> list[str]: + if value is not None: + selected = [c.strip() for c in value.split(",") if c.strip()] + else: + raw = cfg.get("conditions") or list(DEFAULT_ORDER) + if isinstance(raw, str): + selected = [c.strip() for c in raw.split(",") if c.strip()] + elif isinstance(raw, list): + selected = [str(c).strip() for c in raw if str(c).strip()] + else: + raise typer.BadParameter("config 'conditions' must be a list or comma-separated string") + if not selected: + raise typer.BadParameter("at least one condition must be selected") + for c in selected: + if c not in CONDITIONS: + raise typer.BadParameter(f"unknown condition '{c}'. Choose from {CONDITIONS}.") + return selected + + def _resolve_domain_label(entries: list[DatasetEntry], cfg: dict, domain: str) -> str: """Pick a human-readable label for the setup prompt. @@ -112,12 +176,12 @@ def run_command( "--eval-manifest", help="Path to an agent-eval manifest (JSON list). Overrides config.eval_manifest_path.", ), - conditions: str = typer.Option( - ",".join(DEFAULT_ORDER), + conditions: Optional[str] = typer.Option( + None, "--conditions", help=( - "Comma-separated conditions in execution order. Each (condition, domain) workdir is deleted after it runs, " - "so only one LanceDB is on disk at a time." + "Comma-separated conditions in execution order. Defaults to config.conditions, then " + f"{','.join(DEFAULT_ORDER)}. Each (agent, condition, domain) workdir is deleted after it runs." ), ), domains: Optional[str] = typer.Option( @@ -128,19 +192,26 @@ def run_command( artifacts_root: Optional[Path] = typer.Option( None, "--artifacts-root", help="Override the artifact root; defaults to /nemo_retriever/artifacts/" ), + agent_name: Optional[str] = typer.Option( + None, + "--agent", + help="Agent CLI to evaluate: claude or codex. Overrides config.agent.", + ), + model_override: Optional[str] = typer.Option( + None, + "--model", + help="Agent model override for this run.", + ), ) -> None: - """Run the benchmark across the dataset's domains × selected conditions, sequentially.""" + """Run the benchmark across the dataset's domains x selected conditions.""" logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") - if shutil.which("claude") is None: - typer.echo("Error: `claude` CLI is not on PATH; install Claude Code first.", err=True) - raise typer.Exit(code=2) cfg = load_config(config) - selected = [c.strip() for c in conditions.split(",") if c.strip()] - for c in selected: - if c not in CONDITIONS: - typer.echo(f"Error: unknown condition '{c}'. Choose from {CONDITIONS}.", err=True) - raise typer.Exit(code=2) + agent = _resolve_agent(str(agent_name or cfg.get("agent") or "claude")) + if shutil.which(agent) is None: + typer.echo(f"Error: `{agent}` CLI is not on PATH.", err=True) + raise typer.Exit(code=2) + selected = _resolve_conditions(conditions, cfg) manifest_path = eval_manifest or cfg.get("eval_manifest_path") if not manifest_path: @@ -170,9 +241,13 @@ def run_command( skill_source = Path( str(cfg.get("skill_source_dir") or REPO_ROOT / ".claude" / "skills" / "nemo-retriever") ).expanduser() + if any(c in ("c2_retriever", "c3_retriever_skill") for c in selected) and not (skill_source / "SKILL.md").is_file(): + typer.echo(f"Error: skill source '{skill_source}' does not contain SKILL.md.", err=True) + raise typer.Exit(code=2) + workdir_root = Path(str(cfg.get("per_trial_workdir_root", "/tmp/skill_eval"))).expanduser() workdir_root.mkdir(parents=True, exist_ok=True) - model = str(cfg.get("agent_model", "claude-opus-4-7")) + model = _resolve_agent_model(cfg, agent, model_override) budget = float(cfg.get("per_trial_budget_usd", 5.0)) timeout = int(cfg.get("per_trial_timeout_s", 600)) testdata_prefixes_raw = cfg.get("testdata_prefixes") or [] @@ -182,15 +257,21 @@ def run_command( testdata_prefixes = tuple(str(p) for p in testdata_prefixes_raw) judge = _build_judge(cfg) + summarizer = _build_trace_summarizer(cfg) base_dir = str(artifacts_root) if artifacts_root else None session_dir = create_session_dir("skilleval", base_dir=base_dir) typer.echo(f"Session dir: {session_dir}") + typer.echo(f"Agent: {agent} model={model} conditions={selected}") - (session_dir / "config.yaml").write_text(yaml.safe_dump(cfg, default_flow_style=False), encoding="utf-8") + resolved_cfg = dict(cfg) + resolved_cfg["agent"] = agent + resolved_cfg["agent_model"] = model + resolved_cfg["conditions"] = selected + (session_dir / "config.yaml").write_text(yaml.safe_dump(resolved_cfg, default_flow_style=False), encoding="utf-8") - # Results are keyed (condition, domain) so the report can break out per-domain numbers. - results_by_key: dict[tuple[str, str], list] = {} + # Results are keyed (agent, condition, domain) so reports can compare agent runs. + results_by_key: dict[tuple[str, str, str], list[TrialResult]] = {} for cond in selected: for domain in domain_order: domain_entries = by_domain[domain] @@ -204,10 +285,11 @@ def run_command( raise typer.Exit(code=2) domain_label = _resolve_domain_label(domain_entries, cfg, domain) typer.echo( - f"Starting session for {cond}/{domain} — setup + {len(domain_entries)} query turns " + f"Starting {agent} session for {cond}/{domain} - setup + {len(domain_entries)} query turns " f"(pdfs={pdf_source})" ) workdir, results = run_condition( + agent=agent, condition=cond, entries=domain_entries, workdir_root=workdir_root, @@ -221,28 +303,57 @@ def run_command( judge=judge, testdata_prefixes=testdata_prefixes, ) + if summarizer is not None and results: + trace = extract_compact_trace(agent, workdir, results[0].session_id) + if trace: + narrative = summarizer.summarize(condition=f"{agent}/{cond}", domain=domain, trace=trace) + if narrative: + for r in results: + if r.is_setup: + r.tool_use_summary = narrative + break + typer.echo(f" tool-use summary: {len(narrative)} chars") + else: + typer.echo(" tool-use summary: (summarizer returned empty)") + else: + typer.echo(" tool-use summary skipped: session JSONL unavailable") for r in results: save_trial(r, session_dir) kind = "setup" if r.is_setup else f"entry_id={r.entry_id} query_id={r.query_id}" judge_str = "" if r.is_setup or r.judge_score is None else f" judge={r.judge_score}" + cost_str = f"${r.total_cost_usd:.3f}" if r.cost_available else "n/a" typer.echo( - f" turn {r.num_turns} [{domain}] {kind}: status={r.status} " + f" turn {r.num_turns} [{agent}/{domain}] {kind}: status={r.status} " f"tokens(in/out/cache_r)={r.input_tokens}/{r.output_tokens}/{r.cache_read_input_tokens} " - f"cost=${r.total_cost_usd:.3f} retrieved={len(r.ranked_retrieved)}{judge_str}" + f"cost={cost_str} retrieved={len(r.ranked_retrieved)}{judge_str}" ) - results_by_key[(cond, domain)] = results + results_by_key[(agent, cond, domain)] = results entries_by_id = {e.entry_id: e for e in domain_entries} scores = overall_recall(results, entries_by_id) typer.echo( - f"\nRecall for {cond}/{domain}: " + f"\nRecall for {agent}/{cond}/{domain}: " f"recall@1={scores['recall_1']:.3f} " f"recall@5={scores['recall_5']:.3f} " f"recall@10={scores['recall_10']:.3f}" ) + if results: + archived = archive_session_log( + session_dir=session_dir, + agent=agent, + condition=cond, + domain=domain, + session_uuid=results[0].session_id, + workdir=workdir, + ) + if archived is not None: + typer.echo(f" archived session log: {archived.relative_to(session_dir)}") + else: + typer.echo(f" session log not found for archiving ({agent}/{cond}/{domain})") + cleanup_condition_workdir(workdir) - typer.echo(f"Cleaned up workdir for {cond}/{domain}\n") + typer.echo(f"Cleaned up workdir for {agent}/{cond}/{domain}\n") if judge is not None: typer.echo("\nLLM-as-judge scores (mean over query turns, 0-5 scale):") @@ -250,7 +361,7 @@ def run_command( scored: list[int] = [] errored = 0 for domain in domain_order: - for r in results_by_key.get((cond, domain), []): + for r in results_by_key.get((agent, cond, domain), []): if r.is_setup: continue if r.judge_score is not None: @@ -259,17 +370,193 @@ def run_command( errored += 1 if scored: mean_score = sum(scored) / len(scored) - typer.echo(f" {cond}: mean={mean_score:.2f} n={len(scored)} errors={errored}") + typer.echo(f" {agent}/{cond}: mean={mean_score:.2f} n={len(scored)} errors={errored}") else: - typer.echo(f" {cond}: no scores errors={errored} (check judge config / litellm install)") + typer.echo(f" {agent}/{cond}: no scores errors={errored} (check judge config / litellm install)") json_path, md_path = write_summary( session_dir=session_dir, results_by_key=results_by_key, entries=entries, - config=cfg, + config=resolved_cfg, + agent=agent, + model=model, config_path=str(config) if config else "", ) typer.echo(f"\nWrote {json_path}") typer.echo(f"Wrote {md_path}") typer.echo("\nDone.") + + +def _needs_rescore(trial: dict[str, Any]) -> bool: + """Return whether a query-turn trial needs fresh judge scoring.""" + if trial.get("is_setup"): + return False + judge_error = trial.get("judge_error") or "" + if judge_error in UNSCORABLE_JUDGE_ERRORS: + return False + score = trial.get("judge_score") + if score is None: + return True + if judge_error: + return True + return False + + +def _load_trial(path: Path) -> tuple[dict[str, Any], TrialResult] | None: + """Load a trial JSON and reconstruct a ``TrialResult``. + + Returns ``None`` (and logs a warning) if the file is missing, truncated, + or otherwise unparseable, so callers can skip individual corrupt trials + without aborting the whole run. + """ + try: + data = json.loads(path.read_text(encoding="utf-8")) + known = {f.name for f in fields(TrialResult)} + ctor_kwargs = {k: v for k, v in data.items() if k in known} + return data, TrialResult(**ctor_kwargs) + except (OSError, ValueError, TypeError) as exc: + typer.echo(f" {path.name}: skip (corrupt trial: {exc})", err=True) + return None + + +def _iter_trial_files(session_dir: Path) -> list[Path]: + return sorted((session_dir / "trials").rglob("*.json")) + + +@app.command("rescore") +def rescore_command( + session_dir: Path = typer.Argument( + ..., + exists=True, + file_okay=False, + dir_okay=True, + help="Artifact session directory from a previous `retriever skill-eval run`.", + ), + config: Optional[Path] = typer.Option( + None, + "--config", + help="Judge/manifest config to use. Defaults to the session's own config.yaml.", + ), + eval_manifest: Optional[Path] = typer.Option( + None, + "--eval-manifest", + help="Manifest path. Overrides eval_manifest_path from --config / session config.", + ), + force: bool = typer.Option( + False, + "--force", + help="Rescore every query-turn trial, not just the empty/failed ones.", + ), +) -> None: + """Re-judge query-turn trials with missing or failed judge scores.""" + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") + + session_dir = session_dir.resolve() + trials_dir = session_dir / "trials" + if not trials_dir.is_dir(): + typer.echo(f"Error: {trials_dir} does not exist - is this a skill_eval session dir?", err=True) + raise typer.Exit(code=2) + + session_cfg_path = session_dir / "config.yaml" + if config is not None: + cfg = load_config(config) + config_path_str = str(config) + elif session_cfg_path.is_file(): + cfg = load_config(session_cfg_path) + config_path_str = str(session_cfg_path) + else: + typer.echo( + f"Error: no --config given and {session_cfg_path} is missing; cannot resolve judge settings.", + err=True, + ) + raise typer.Exit(code=2) + + manifest_path = eval_manifest or cfg.get("eval_manifest_path") + if not manifest_path: + typer.echo("Error: config is missing 'eval_manifest_path' and --eval-manifest was not provided.", err=True) + raise typer.Exit(code=2) + entries = load_eval_manifest(Path(str(manifest_path)).expanduser().resolve()) + entries_by_id = {e.entry_id: e for e in entries} + + judge = _build_judge(cfg) + if judge is None: + typer.echo("Error: judge is not configured (see messages above). Cannot rescore.", err=True) + raise typer.Exit(code=2) + + trial_files = _iter_trial_files(session_dir) + candidates = [] + for path in trial_files: + loaded = _load_trial(path) + if loaded is None: + continue + data, _ = loaded + if data.get("is_setup"): + continue + if force or _needs_rescore(data): + candidates.append(path) + + typer.echo( + f"Rescoring {len(candidates)} trial(s) out of {len(trial_files)} on disk " + f"(force={'on' if force else 'off'})." + ) + + rescored = 0 + unscorable = 0 + still_failed = 0 + for path in candidates: + loaded = _load_trial(path) + if loaded is None: + continue + raw, result = loaded + entry = entries_by_id.get(result.entry_id) + if entry is None: + typer.echo(f" {path.name}: skip (entry_id={result.entry_id} not in manifest)") + continue + + result.judge_score = None + result.judge_reasoning = "" + result.judge_error = "" + + _apply_judge(judge, entry, result) + + raw.update(asdict(result)) + path.write_text(json.dumps(raw, indent=2) + "\n", encoding="utf-8") + + if result.judge_score is not None: + rescored += 1 + typer.echo(f" {path.name}: entry_id={result.entry_id} judge={result.judge_score}") + elif result.judge_error in UNSCORABLE_JUDGE_ERRORS: + unscorable += 1 + typer.echo(f" {path.name}: entry_id={result.entry_id} unscorable ({result.judge_error})") + else: + still_failed += 1 + typer.echo( + f" {path.name}: entry_id={result.entry_id} still failed " f"(error={result.judge_error or 'unknown'})" + ) + + typer.echo(f"\nRescored {rescored}; unscorable {unscorable}; still failed {still_failed}.") + + results_by_key: dict[tuple[str, str, str], list[TrialResult]] = defaultdict(list) + for path in trial_files: + loaded = _load_trial(path) + if loaded is None: + continue + _, result = loaded + results_by_key[(result.agent, result.condition, result.domain)].append(result) + + agent = str(cfg.get("agent") or next((r.agent for rows in results_by_key.values() for r in rows), "claude")) + model = _resolve_agent_model(cfg, agent, None) + + json_path, md_path = write_summary( + session_dir=session_dir, + results_by_key=dict(results_by_key), + entries=entries, + config=cfg, + agent=agent, + model=model, + config_path=config_path_str, + ) + typer.echo(f"Wrote {json_path}") + typer.echo(f"Wrote {md_path}") + typer.echo("\nDone.") diff --git a/nemo_retriever/src/nemo_retriever/skill_eval/configs/skill_eval.yaml b/nemo_retriever/src/nemo_retriever/skill_eval/configs/skill_eval.yaml index 6a17f9820..df5e68364 100644 --- a/nemo_retriever/src/nemo_retriever/skill_eval/configs/skill_eval.yaml +++ b/nemo_retriever/src/nemo_retriever/skill_eval/configs/skill_eval.yaml @@ -54,9 +54,16 @@ pdf_dirs: {} testdata_prefixes: [] # --------------------------------------------------------------------------- -# Agent model and per-trial limits +# Agent selection, model, and per-trial limits # --------------------------------------------------------------------------- -agent_model: claude-opus-4-7 +# `agent` must be `claude` or `codex`. Override per run with --agent. +agent: claude +agent_models: + claude: claude-opus-4-7 + codex: gpt-5.5 +# Back-compat fallback: if set, this overrides agent_models. unless +# --model is passed. Leave unset for per-agent defaults. +# agent_model: claude-opus-4-7 per_trial_budget_usd: 5.0 per_trial_timeout_s: 600 per_trial_workdir_root: /tmp/skill_eval @@ -92,3 +99,15 @@ judge: api_key_env: NVIDIA_API_KEY temperature: 0.1 max_tokens: 4096 + +# --------------------------------------------------------------------------- +# Tool-use summarizer +# --------------------------------------------------------------------------- +# After each agent/domain/condition session, the harness reads the agent session +# JSONL and asks Claude to narrate what tools were called and what strategy was +# used. The result is stamped onto the setup-turn TrialResult and rendered in +# session_summary.md. Shells out to `claude --print`, so it reuses Claude Code +# auth and does not require a separate API key. Set enabled: false to skip it. +summarizer: + enabled: true + model: claude-opus-4-7 diff --git a/nemo_retriever/src/nemo_retriever/skill_eval/report.py b/nemo_retriever/src/nemo_retriever/skill_eval/report.py index 5efa60d1d..2c6e00023 100644 --- a/nemo_retriever/src/nemo_retriever/skill_eval/report.py +++ b/nemo_retriever/src/nemo_retriever/skill_eval/report.py @@ -2,7 +2,7 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Aggregate per-trial results into a per-condition / per-domain session summary.""" +"""Aggregate per-trial results into a per-agent / per-condition / per-domain summary.""" from __future__ import annotations @@ -77,16 +77,18 @@ def _aggregate( metrics["output_tokens"] = mean(r.output_tokens for r in query_results) metrics["cache_read_input_tokens"] = mean(r.cache_read_input_tokens for r in query_results) metrics["cache_creation_input_tokens"] = mean(r.cache_creation_input_tokens for r in query_results) - metrics["total_cost_usd"] = mean(r.total_cost_usd for r in query_results) + costed = [r.total_cost_usd for r in query_results if r.cost_available] + metrics["total_cost_usd"] = mean(costed) if costed else None metrics["duration_ms"] = mean(r.duration_ms for r in query_results) # When aggregating across multiple sessions there may be more than one setup - # turn (one per domain); sum them so the "one-time cost" reflects the full run. + # turn (one per domain); sum them so the one-time cost reflects the full run. if setup_results: metrics["setup_input_tokens"] = sum(r.input_tokens for r in setup_results) metrics["setup_output_tokens"] = sum(r.output_tokens for r in setup_results) metrics["setup_cache_read_input_tokens"] = sum(r.cache_read_input_tokens for r in setup_results) metrics["setup_cache_creation_input_tokens"] = sum(r.cache_creation_input_tokens for r in setup_results) - metrics["setup_cost_usd"] = sum(r.total_cost_usd for r in setup_results) + setup_costed = [r.total_cost_usd for r in setup_results if r.cost_available] + metrics["setup_cost_usd"] = sum(setup_costed) if setup_costed else None metrics["setup_duration_ms"] = sum(r.duration_ms for r in setup_results) metrics["setup_status"] = ( "ok" if all(r.status == "ok" for r in setup_results) else ",".join(r.status for r in setup_results) @@ -95,7 +97,8 @@ def _aggregate( metrics["session_output_tokens"] = sum(r.output_tokens for r in results) metrics["session_cache_read_input_tokens"] = sum(r.cache_read_input_tokens for r in results) metrics["session_cache_creation_input_tokens"] = sum(r.cache_creation_input_tokens for r in results) - metrics["session_total_cost_usd"] = sum(r.total_cost_usd for r in results) + session_costed = [r.total_cost_usd for r in results if r.cost_available] + metrics["session_total_cost_usd"] = sum(session_costed) if session_costed else None metrics["num_query_turns"] = len(query_results) metrics["success_rate"] = sum(1 for r in results if r.status == "ok") / len(results) metrics["retriever_used_rate"] = sum(1 for r in results if r.retriever_used_ever) / len(results) @@ -107,12 +110,15 @@ def _aggregate( metrics["judge_score_mean"] = sum(judge_scores) / len(judge_scores) metrics["judge_score_n"] = len(judge_scores) + tool_use_summary = next((r.tool_use_summary for r in setup_results if r.tool_use_summary), "") + return { "run_name": run_name, "success": all(r.status == "ok" for r in results), "metrics": metrics, - "tags": [results[0].condition, *extra_tags, f"n_queries={len(query_results)}"], + "tags": [results[0].agent, results[0].condition, *extra_tags, f"n_queries={len(query_results)}"], "artifact_dir": artifact_dir, + "tool_use_summary": tool_use_summary, } @@ -121,22 +127,28 @@ def aggregate_condition(results: Iterable[TrialResult], entries_by_id: dict[int, results_list = list(results) if not results_list: return {} + agent = getattr(results_list[0], "agent", "claude") + condition = results_list[0].condition return _aggregate( results_list, entries_by_id, - run_name=results_list[0].condition, - artifact_dir=str(Path("trials") / results_list[0].condition), + run_name=f"{agent}/{condition}", + artifact_dir=str(Path("trials") / agent / condition), ) +def _fmt_cost(value: Any) -> str: + return "n/a" if value is None else f"${float(value):.3f}" + + def _md_row(row: dict[str, Any]) -> str: m = row.get("metrics", {}) - judge_cell = f"{m['judge_score_mean']:.2f} (n={m.get('judge_score_n', 0)})" if "judge_score_mean" in m else "—" + judge_cell = f"{m['judge_score_mean']:.2f} (n={m.get('judge_score_n', 0)})" if "judge_score_mean" in m else "-" return ( - "| {cond} | {sr:.2f} | {retr:.2f} | {r1:.3f} | {r5:.3f} | {r10:.3f} | {judge} " - "| {ipt:.0f} | {opt:.0f} | {cr:.0f} | {cc:.0f} | ${cost:.3f} |" + "| {run} | {sr:.2f} | {retr:.2f} | {r1:.3f} | {r5:.3f} | {r10:.3f} | {judge} " + "| {ipt:.0f} | {opt:.0f} | {cr:.0f} | {cc:.0f} | {cost} |" ).format( - cond=row.get("run_name", "?"), + run=row.get("run_name", "?"), sr=m.get("success_rate", 0.0), retr=m.get("retriever_used_rate", 0.0), r1=m.get("recall_1", 0.0), @@ -147,12 +159,12 @@ def _md_row(row: dict[str, Any]) -> str: opt=m.get("output_tokens", 0.0), cr=m.get("cache_read_input_tokens", 0.0), cc=m.get("cache_creation_input_tokens", 0.0), - cost=m.get("total_cost_usd", 0.0), + cost=_fmt_cost(m.get("total_cost_usd")), ) _MAIN_TABLE_HEADER = ( - "| condition | success_rate | retr_used | recall@1 | recall@5 | recall@10 | judge | q_input | q_output " + "| run | success_rate | retr_used | recall@1 | recall@5 | recall@10 | judge | q_input | q_output " "| q_cache_read | q_cache_create | q_cost |" ) _MAIN_TABLE_DIVIDER = "|---|---|---|---|---|---|---|---|---|---|---|---|" @@ -163,16 +175,19 @@ def write_summary_md( rows_by_domain: dict[str, list[dict[str, Any]]], overall_rows: list[dict[str, Any]], config: dict[str, Any], + agent: str, + model: str, ) -> Path: lines = [ - f"# skill_eval session summary — `{session_dir.name}`", + f"# skill_eval session summary - `{session_dir.name}`", "", - f"- Agent model: `{config.get('agent_model', '?')}`", + f"- Agent: `{agent}`", + f"- Agent model: `{model}`", f"- Per-trial budget: ${config.get('per_trial_budget_usd', '?')}", f"- Per-trial timeout: {config.get('per_trial_timeout_s', '?')}s", "", "_Agent-session tokens only. Pipeline-side LLM calls (embeddings, VLM, etc.) are not instrumented._", - "_Each (condition, domain) is one Claude session: turn 1 = setup, turns 2..N = query turns._", + "_Each (agent, condition, domain) is one agent session: turn 1 = setup, turns 2..N = query turns._", "", "## Overall (averaged across all queries in this run)", "", @@ -198,21 +213,21 @@ def write_summary_md( lines += [ "", - "## Setup turns (one-time cost per condition, summed across domains)", + "## Setup turns (one-time cost per run, summed across domains)", "", - "| condition | status | setup_input | setup_output | setup_cache_read | setup_cost | setup_ms |", + "| run | status | setup_input | setup_output | setup_cache_read | setup_cost | setup_ms |", "|---|---|---|---|---|---|---|", ] for row in overall_rows: m = row.get("metrics", {}) lines.append( - "| {cond} | {st} | {ipt:.0f} | {opt:.0f} | {cr:.0f} | ${cost:.3f} | {ms:.0f} |".format( - cond=row.get("run_name", "?"), + "| {run} | {st} | {ipt:.0f} | {opt:.0f} | {cr:.0f} | {cost} | {ms:.0f} |".format( + run=row.get("run_name", "?"), st=m.get("setup_status", "?"), ipt=m.get("setup_input_tokens", 0), opt=m.get("setup_output_tokens", 0), cr=m.get("setup_cache_read_input_tokens", 0), - cost=m.get("setup_cost_usd", 0.0), + cost=_fmt_cost(m.get("setup_cost_usd")), ms=m.get("setup_duration_ms", 0), ) ) @@ -221,20 +236,20 @@ def write_summary_md( "", "## Session totals (setup + all query turns)", "", - "| condition | query_turns | total_input | total_output | total_cache_read | total_cache_create | total_cost |", + "| run | query_turns | total_input | total_output | total_cache_read | total_cache_create | total_cost |", "|---|---|---|---|---|---|---|", ] for row in overall_rows: m = row.get("metrics", {}) lines.append( - "| {cond} | {n} | {ipt} | {opt} | {cr} | {cc} | ${cost:.3f} |".format( - cond=row.get("run_name", "?"), + "| {run} | {n} | {ipt} | {opt} | {cr} | {cc} | {cost} |".format( + run=row.get("run_name", "?"), n=m.get("num_query_turns", 0), ipt=m.get("session_input_tokens", 0), opt=m.get("session_output_tokens", 0), cr=m.get("session_cache_read_input_tokens", 0), cc=m.get("session_cache_creation_input_tokens", 0), - cost=m.get("session_total_cost_usd", 0.0), + cost=_fmt_cost(m.get("session_total_cost_usd")), ) ) @@ -248,53 +263,84 @@ def write_summary_md( lines.append("## Diagnostics") lines.extend(diag_lines) + summary_blocks: list[tuple[str, str]] = [] + for domain in sorted(rows_by_domain): + for row in rows_by_domain[domain]: + text = row.get("tool_use_summary") or "" + if text: + summary_blocks.append((str(row.get("run_name", "?")), text)) + if summary_blocks: + lines += ["", "## Tool-use summaries", ""] + for run_name, text in summary_blocks: + lines.append(f"### {run_name}") + lines.append("") + lines.append(text) + lines.append("") + out = session_dir / "session_summary.md" out.write_text("\n".join(lines) + "\n", encoding="utf-8") return out +def _condition_order(condition: str) -> int: + try: + return CONDITIONS.index(condition) + except ValueError: + return len(CONDITIONS) + + def write_summary( session_dir: Path, - results_by_key: dict[tuple[str, str], list[TrialResult]], + results_by_key: dict[tuple[str, str, str], list[TrialResult]], entries: list[DatasetEntry], config: dict[str, Any], + agent: str, + model: str, config_path: str, ) -> tuple[Path, Path]: entries_by_id = {e.entry_id: e for e in entries} - # Per-(condition, domain) rows. + # Per-(agent, condition, domain) rows. domain_rows: dict[str, list[dict[str, Any]]] = defaultdict(list) - # Roll-up per condition across all domains. - by_condition: dict[str, list[TrialResult]] = defaultdict(list) + # Roll-up per agent/condition across all domains. + by_run: dict[tuple[str, str], list[TrialResult]] = defaultdict(list) - for (cond, domain), results in results_by_key.items(): + for (agent_name, cond, domain), results in results_by_key.items(): if not results: continue + if domain: + artifact_dir = str(Path("trials") / agent_name / cond / domain) + else: + artifact_dir = str(Path("trials") / agent_name / cond) domain_rows[domain].append( _aggregate( results, entries_by_id, - run_name=f"{cond}/{domain}", - artifact_dir=str(Path("trials") / cond / domain) if domain else str(Path("trials") / cond), - extra_tags=(f"domain={domain}",) if domain else (), + run_name=f"{agent_name}/{cond}/{domain}" if domain else f"{agent_name}/{cond}", + artifact_dir=artifact_dir, + extra_tags=(f"agent={agent_name}", f"domain={domain}") if domain else (f"agent={agent_name}",), ) ) - by_condition[cond].extend(results) + by_run[(agent_name, cond)].extend(results) overall_rows: list[dict[str, Any]] = [] - for cond in CONDITIONS: - results = by_condition.get(cond, []) + for agent_name, cond in sorted(by_run, key=lambda x: (x[0], _condition_order(x[1]), x[1])): + results = by_run[(agent_name, cond)] if not results: continue overall_rows.append( _aggregate( results, entries_by_id, - run_name=cond, - artifact_dir=str(Path("trials") / cond), + run_name=f"{agent_name}/{cond}", + artifact_dir=str(Path("trials") / agent_name / cond), + extra_tags=(f"agent={agent_name}",), ) ) + for rows in domain_rows.values(): + rows.sort(key=lambda row: tuple(str(row.get("run_name", "")).split("/", 2)[:2])) + flat_rows = overall_rows + [r for rows in domain_rows.values() for r in rows] json_path = write_session_summary( session_dir=session_dir, @@ -302,5 +348,5 @@ def write_summary( session_type="skill_eval", config_path=config_path, ) - md_path = write_summary_md(session_dir, dict(domain_rows), overall_rows, config) + md_path = write_summary_md(session_dir, dict(domain_rows), overall_rows, config, agent=agent, model=model) return json_path, md_path diff --git a/nemo_retriever/src/nemo_retriever/skill_eval/runner.py b/nemo_retriever/src/nemo_retriever/skill_eval/runner.py index 8c9a74bcf..49cfa5ef8 100644 --- a/nemo_retriever/src/nemo_retriever/skill_eval/runner.py +++ b/nemo_retriever/src/nemo_retriever/skill_eval/runner.py @@ -2,7 +2,7 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Per-trial runner: build sandboxed workdir, spawn `claude -p`, parse outputs.""" +"""Per-trial runner: build sandboxed workdirs, spawn an agent CLI, parse outputs.""" from __future__ import annotations @@ -25,7 +25,13 @@ logger = logging.getLogger(__name__) +BASE_CONDITION = "c1_base" CONDITIONS = ("c1_base", "c2_retriever", "c3_retriever_skill") +SUPPORTED_AGENTS = ("claude", "codex") +DEFAULT_AGENT_MODELS = { + "claude": "claude-opus-4-7", + "codex": "gpt-5.5", +} @functools.lru_cache(maxsize=8) @@ -47,6 +53,7 @@ class TrialResult: total_cost_usd: float model_id: str session_id: str + agent: str = "claude" input_tokens: int = 0 output_tokens: int = 0 cache_read_input_tokens: int = 0 @@ -64,18 +71,16 @@ class TrialResult: judge_score: int | None = None judge_reasoning: str = "" judge_error: str = "" + tool_use_summary: str = "" + cost_available: bool = True def _remap_pdf_paths(text: str, prefixes: tuple[str, ...]) -> str: """Rewrite caller-supplied path prefixes in *text* to ``./pdfs/``. - Some agent-eval manifests' paraphrased prompts hard-code paths from the - dataset source tree. Each trial workdir symlinks the domain's PDFs to - ``./pdfs/``, so the agent only needs the basename — rewriting the prefix - lets the natural-language reference resolve to a real file. - - Prefixes are configured per-run via the ``testdata_prefixes`` config key - (no dataset paths are hardcoded in this module). + Some agent-eval manifests' paraphrased prompts hard-code dataset-source + paths in the user-facing text. Each trial workdir symlinks the domain's + PDFs to ``./pdfs/``, so the agent only needs the basename. """ for prefix in prefixes: text = text.replace(prefix, "./pdfs") @@ -147,45 +152,47 @@ def _copy_skill(skill_source: Path, dest: Path) -> None: def _c1_settings_json() -> str: - """Project-level settings for the c1_base trial. + """Project-level settings for the c1_base Claude trial. - `--permission-mode bypassPermissions` auto-approves tool calls that aren't - explicitly denied; the deny patterns below catch every reasonable path - into the nemo_retriever library so the agent has to fall back on CPU-only - primitives (Read, Grep, pdftotext, etc.). + ``--permission-mode bypassPermissions`` auto-approves tool calls that aren't + explicitly denied; these deny patterns catch every reasonable path into the + nemo_retriever library so Claude has to fall back on CPU-only primitives. """ return json.dumps({"permissions": {"deny": list(_C1_BASH_DENY_PATTERNS)}}, indent=2) + "\n" def _build_condition_workdir( + agent: str, condition: str, root: Path, pdf_source: Path, skill_source: Path, domain: str = "", ) -> Path: - """Build one workdir per condition. Shared across all turns in the session. + """Build one workdir per agent/condition/domain session. Workdir contents: - pdfs/ symlink farm into the source PDF folder - - .claude/ sandbox (settings + per-condition skill copy) - - .bin/retriever shim (c1 only) so retriever is unavailable on PATH - - The agent itself creates any retrieval artifacts (e.g., ./lancedb/) inside the - workdir on the setup turn. + - .claude/ sandbox (settings + per-condition skill copy for Claude) + - .codex/ skill copy for Codex skill-aware installations + - .bin/retriever shim (c1 only) so the retriever CLI is unavailable on PATH """ domain_seg = f"_{domain}" if domain else "" - workdir = root / f"{condition}{domain_seg}_{uuid.uuid4().hex[:8]}" + workdir = root / f"{agent}_{condition}{domain_seg}_{uuid.uuid4().hex[:8]}" workdir.mkdir(parents=True, exist_ok=True) _build_pdf_symlinks(pdf_source, workdir / "pdfs") - (workdir / ".claude").mkdir(parents=True, exist_ok=True) - # c1 gets explicit Bash deny rules; c2/c3 keep the empty settings.json. - settings_text = _c1_settings_json() if condition == "c1_base" else "{}\n" - (workdir / ".claude" / "settings.json").write_text(settings_text, encoding="utf-8") - # c2 and c3 both have retriever installed AND the nemo-retriever skill loaded. - # The c2/c3 distinction is purely the prompt style (NL vs explicit slash command). + + if agent == "claude": + (workdir / ".claude").mkdir(parents=True, exist_ok=True) + settings_text = _c1_settings_json() if condition == "c1_base" else "{}\n" + (workdir / ".claude" / "settings.json").write_text(settings_text, encoding="utf-8") + if condition in ("c2_retriever", "c3_retriever_skill"): - _copy_skill(skill_source, workdir / ".claude" / "skills" / "nemo-retriever") + if agent == "claude": + _copy_skill(skill_source, workdir / ".claude" / "skills" / "nemo-retriever") + elif agent == "codex": + _copy_skill(skill_source, workdir / ".codex" / "skills" / "nemo-retriever") + if condition == "c1_base": _write_shim(workdir / ".bin", "retriever") # Empty HuggingFace cache redirect; env vars are wired up in _env_for. @@ -194,10 +201,7 @@ def _build_condition_workdir( def cleanup_condition_workdir(workdir: Path) -> None: - """Remove a condition's scratch workdir (PDFs symlinks, .claude/, agent-built - artifacts like .venv/, lancedb/, scratch scripts). Called after a session - completes and its results have been persisted to the artifact dir. - """ + """Remove a condition's scratch workdir after results have been persisted.""" if not workdir.exists(): return shutil.rmtree(workdir, ignore_errors=True) @@ -210,8 +214,6 @@ def _env_for(condition: str, workdir: Path) -> dict[str, str]: env["PATH"] = f"{workdir / '.bin'}{os.pathsep}{env.get('PATH', '')}" # Point HuggingFace cache env vars at an empty workdir-local dir so # any HF Python tooling the agent invokes sees no cached models. - # Direct filesystem reads (e.g. `ls ~/.cache/huggingface/`) are - # blocked separately by the Bash deny rules in settings.json. hf_empty = str(workdir / ".hf_empty") env["HF_HOME"] = hf_empty env["HF_HUB_CACHE"] = hf_empty @@ -219,7 +221,7 @@ def _env_for(condition: str, workdir: Path) -> dict[str, str]: return env -def _build_command( +def _build_claude_command( condition: str, model: str, budget_usd: float, @@ -228,10 +230,11 @@ def _build_command( *, resume: bool = False, ) -> list[str]: - """Build the `claude -p` command. First turn uses --session-id; subsequent turns use --resume. + """Build the ``claude --print`` command. - We deliberately do NOT pass --no-session-persistence because multi-turn requires - the session to persist between subprocess invocations. + First turn uses ``--session-id``; subsequent turns use ``--resume``. We + deliberately keep session persistence enabled because this benchmark is + multi-turn. """ cmd = [ "claude", @@ -249,22 +252,73 @@ def _build_command( "--setting-sources", "project", ] - # c2/c3 run fully un-gated. c1 omits --allow-dangerously-skip-permissions - # so the project-level settings.json deny rules are actually consulted by - # Claude Code instead of being short-circuited. + # c2/c3 run fully ungated. c1 omits the dangerous skip flag so the + # project-level deny rules are consulted. if condition != "c1_base": cmd.append("--allow-dangerously-skip-permissions") if resume: cmd.extend(["--resume", session_uuid]) else: cmd.extend(["--session-id", session_uuid]) - # Only c1 disables skills entirely. c2 has the skill loaded but uses NL prompt - # (relying on description-based auto-discovery); c3 explicitly invokes via slash. + # Only c1 disables skills entirely. c2 has the skill loaded but uses an NL + # prompt; c3 explicitly invokes via slash. if condition == "c1_base": cmd.append("--disable-slash-commands") return cmd +def _build_codex_command( + model: str, + session_uuid: str, + workdir: Path, + *, + resume: bool = False, +) -> list[str]: + """Build a non-interactive Codex command. + + Codex assigns the first session id itself; subsequent turns resume the id + parsed from the setup turn's JSONL events. + """ + common = [ + "--json", + "--model", + model, + "--skip-git-repo-check", + "--ignore-user-config", + "--ignore-rules", + "--dangerously-bypass-approvals-and-sandbox", + ] + if resume: + return ["codex", "exec", "resume", *common, session_uuid, "-"] + return [ + "codex", + "exec", + *common, + "--cd", + str(workdir), + "--add-dir", + str(workdir), + "-", + ] + + +def _build_command( + *, + agent: str, + condition: str, + model: str, + budget_usd: float, + session_uuid: str, + workdir: Path, + resume: bool = False, +) -> list[str]: + if agent == "claude": + return _build_claude_command(condition, model, budget_usd, session_uuid, workdir, resume=resume) + if agent == "codex": + return _build_codex_command(model, session_uuid, workdir, resume=resume) + raise ValueError(f"unsupported agent: {agent}") + + def _parse_envelope(raw: str) -> dict[str, Any]: raw = raw.strip() if not raw: @@ -284,7 +338,42 @@ def _parse_envelope(raw: str) -> dict[str, Any]: return {} -def _populate_tokens(result: TrialResult, envelope: dict[str, Any]) -> None: +def _parse_jsonl_events(raw: str) -> list[dict[str, Any]]: + events: list[dict[str, Any]] = [] + for line in raw.splitlines(): + line = line.strip() + if not line: + continue + try: + ev = json.loads(line) + except json.JSONDecodeError: + continue + if isinstance(ev, dict): + events.append(ev) + return events + + +def _codex_session_id(events: list[dict[str, Any]], fallback: str) -> str: + for ev in events: + if ev.get("type") != "session_meta": + continue + payload = ev.get("payload") or {} + if isinstance(payload, dict) and payload.get("id"): + return str(payload["id"]) + return fallback + + +def _codex_has_error(events: list[dict[str, Any]]) -> bool: + for ev in events: + payload = ev.get("payload") or {} + if not isinstance(payload, dict): + continue + if payload.get("type") in {"error", "task_failed", "turn_aborted"}: + return True + return False + + +def _populate_claude_tokens(result: TrialResult, envelope: dict[str, Any]) -> None: usage = envelope.get("usage") or {} result.input_tokens = int(usage.get("input_tokens") or 0) result.output_tokens = int(usage.get("output_tokens") or 0) @@ -295,6 +384,59 @@ def _populate_tokens(result: TrialResult, envelope: dict[str, Any]) -> None: result.ephemeral_1h_input_tokens = int(cache_detail.get("ephemeral_1h_input_tokens") or 0) +_CODEX_USAGE_FIELDS = ( + "input_tokens", + "output_tokens", + "cached_input_tokens", + "reasoning_output_tokens", +) + + +def _extract_codex_total_usage(events: list[dict[str, Any]]) -> dict[str, int]: + """Return the most recent cumulative ``total_token_usage`` from codex events. + + Each ``token_count`` event carries running session-wide counters; we want the + last one so deltas between two snapshots equal one turn's true work. + """ + for ev in reversed(events): + if ev.get("type") != "event_msg": + continue + payload = ev.get("payload") or {} + if not isinstance(payload, dict) or payload.get("type") != "token_count": + continue + info = payload.get("info") or {} + if not isinstance(info, dict): + continue + usage = info.get("total_token_usage") or {} + if not isinstance(usage, dict): + continue + return {k: int(usage.get(k) or 0) for k in _CODEX_USAGE_FIELDS} + return {k: 0 for k in _CODEX_USAGE_FIELDS} + + +def _populate_codex_tokens( + result: TrialResult, + current_totals: dict[str, int], + prior_totals: dict[str, int], +) -> None: + """Set per-turn token fields as the delta of cumulative ``total_token_usage``. + + Codex's resumed-session log is append-only across all turns, and each + ``token_count`` event reports cumulative counters, so per-turn cost is the + difference between snapshots taken before and after the subprocess call. + ``output_tokens`` here folds in ``reasoning_output_tokens`` so the column + reflects everything the model emitted, matching Claude's accounting. + """ + + def d(key: str) -> int: + return max(0, current_totals.get(key, 0) - prior_totals.get(key, 0)) + + result.input_tokens = d("input_tokens") + result.output_tokens = d("output_tokens") + d("reasoning_output_tokens") + result.cache_read_input_tokens = d("cached_input_tokens") + result.cache_creation_input_tokens = 0 + + def _parse_output_json(workdir: Path) -> tuple[str, list[dict[str, Any]], str, list[str]]: out_path = workdir / "output.json" errors: list[str] = [] @@ -332,26 +474,43 @@ def _extract_model_id(envelope: dict[str, Any], fallback: str) -> str: return str(envelope.get("model") or fallback) +def _extract_claude_error_detail(envelope: dict[str, Any]) -> str: + for key in ("error", "message", "result"): + value = envelope.get(key) + if value: + return str(value) + + content = envelope.get("content") + if isinstance(content, str) and content: + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict) and item.get("text"): + parts.append(str(item["text"])) + if parts: + return " ".join(parts) + return "" + + _PIPELINE_SEP = re.compile(r"(?:;|&&|\|\||\||\n|\$\(|`)") _ENV_ASSIGN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*=") _WRAPPER_CMDS = {"sudo", "time", "nice", "nohup", "exec", "env", "command", "builtin"} def _retriever_in_command(cmd: str) -> bool: - """Does this shell command line invoke the retriever CLI as a command? + """Return whether this shell command invokes the retriever CLI as a command. - Matches when the **executable** in any pipeline segment is the retriever - CLI — ``retriever``, ``./retriever``, ``/abs/path/retriever``, ``uv run - retriever``, or ``python -m nemo_retriever``. Deliberately does *not* - match cases where ``retriever`` appears only as a path argument (e.g. - ``cat .bin/retriever``, ``ls /path/retriever/``, ``echo "use retriever"``). + Matches when the executable in any pipeline segment is the retriever CLI: + ``retriever``, ``./retriever``, ``/abs/path/retriever``, ``uv run + retriever``, or ``python -m nemo_retriever``. Deliberately does not match + cases where ``retriever`` appears only as a path argument or prose. """ if not cmd: return False for segment in _PIPELINE_SEP.split(cmd): seg = segment.strip() - # Strip leading env-var assignments and command wrappers (sudo, time, ...). while seg: first = seg.split(None, 1) if not first: @@ -371,17 +530,11 @@ def _retriever_in_command(cmd: str) -> bool: if head == "retriever" or head == "./retriever": return True if head.endswith("/retriever") and "/" in head[: -len("/retriever") + 1]: - # An absolute or relative path whose final component is `retriever`, - # e.g. /home/.../venv/bin/retriever. Reject pure ``/retriever`` which - # is implausible as a real binary path. Also reject ``.bin/retriever`` - # paths: c1_base's workdir setup installs a deny-shim with that exact - # name (see ``_write_shim``); invoking the shim is the *opposite* of - # using the real retriever CLI. + # Reject c1_base's deny shim; invoking it is the opposite of using + # the real retriever CLI. if "/.bin/retriever" in head: continue return True - # ``uv run retriever ...`` and ``python -m nemo_retriever ...`` — - # check the first two tokens of the segment. tokens = seg.split() if len(tokens) >= 3 and tokens[0] == "uv" and tokens[1] == "run" and tokens[2] == "retriever": return True @@ -396,33 +549,271 @@ def _retriever_in_command(cmd: str) -> bool: def _claude_session_log_path(workdir: Path, session_uuid: str) -> Path: - """Claude Code persists per-session transcripts at - ``~/.claude/projects//.jsonl`` where ```` is the - project dir with ``/`` and ``_`` both replaced by ``-`` (and a leading ``-`` - preserved for the filesystem root). - """ + """Return Claude Code's per-session JSONL transcript path.""" slug = str(workdir).replace("/", "-").replace("_", "-") if not slug.startswith("-"): slug = "-" + slug return Path.home() / ".claude" / "projects" / slug / f"{session_uuid}.jsonl" -def _scan_transcript_for_signals( - envelope: dict[str, Any], - workdir: Path | None = None, - session_uuid: str | None = None, -) -> tuple[int | None, bool]: - """Detect whether the agent invoked the ``retriever`` CLI. +def _codex_session_log_path(session_uuid: str) -> Path | None: + sessions_root = Path.home() / ".codex" / "sessions" + if not sessions_root.exists(): + return None + matches = sorted( + sessions_root.glob(f"**/*{session_uuid}.jsonl"), + key=lambda p: p.stat().st_mtime if p.exists() else 0, + reverse=True, + ) + return matches[0] if matches else None - Primary signal: scan the Claude Code session jsonl for tool-use entries that - spawn a shell command containing ``retriever``. This catches every actual - invocation, regardless of whether the agent quoted it in its final reply. - Fallback signal: if the session log isn't accessible (older runs, missing - file), look for ``retriever`` in the envelope's ``result`` text — the legacy - proxy. This undercounts but never overcounts. - """ - # Primary: tool-call trace. +def _codex_session_meta_from_log(path: Path) -> dict[str, Any]: + try: + with path.open(encoding="utf-8") as f: + for raw_line in f: + raw_line = raw_line.strip() + if not raw_line: + continue + try: + ev = json.loads(raw_line) + except json.JSONDecodeError: + continue + if ev.get("type") != "session_meta": + continue + payload = ev.get("payload") or {} + return payload if isinstance(payload, dict) else {} + except OSError: + return {} + return {} + + +def _codex_session_log_for_workdir(workdir: Path) -> Path | None: + sessions_root = Path.home() / ".codex" / "sessions" + if not sessions_root.exists(): + return None + workdir_str = str(workdir) + matches = sorted( + sessions_root.glob("**/rollout-*.jsonl"), + key=lambda p: p.stat().st_mtime if p.exists() else 0, + reverse=True, + ) + for path in matches: + meta = _codex_session_meta_from_log(path) + if str(meta.get("cwd") or "") == workdir_str: + return path + return None + + +def _read_jsonl_events(path: Path | None) -> list[dict[str, Any]]: + if path is None: + return [] + try: + return _parse_jsonl_events(path.read_text(encoding="utf-8")) + except OSError: + return [] + + +_TRACE_TOOL_INPUT_CAP = 200 +_TRACE_FINAL_TEXT_CAP = 400 + + +def _truncate(s: str, cap: int) -> str: + s = " ".join(s.split()) + return s if len(s) <= cap else s[: cap - 1] + "..." + + +def _format_tool_input(name: str, inp: dict[str, Any]) -> str: + """Render a Claude tool_use input dict to a single short line.""" + if name == "Bash": + cmd = str(inp.get("command", "")) + return f"Bash: {_truncate(cmd, _TRACE_TOOL_INPUT_CAP)}" + if name == "Read": + path = str(inp.get("file_path", "")) + offset = inp.get("offset") + limit = inp.get("limit") + tail = f" offset={offset} limit={limit}" if offset is not None or limit is not None else "" + return f"Read: {path}{tail}" + if name == "Grep": + pat = str(inp.get("pattern", "")) + path = str(inp.get("path", "")) + return f"Grep: pattern={_truncate(pat, 80)} path={path}" + if name == "Glob": + return f"Glob: {inp.get('pattern', '')}" + if name in ("Edit", "Write"): + return f"{name}: {inp.get('file_path', '')}" + parts = [f"{k}={_truncate(str(v), 80)}" for k, v in inp.items()] + return f"{name}: " + " ".join(parts) if parts else name + + +def _extract_claude_compact_trace(workdir: Path, session_uuid: str) -> str | None: + """Walk a Claude Code JSONL transcript and emit a turn-organized trace.""" + log_path = _claude_session_log_path(workdir, session_uuid) + if not log_path.exists(): + return None + + turn_idx = 0 + lines_out: list[str] = [] + current_assistant_text: list[str] = [] + try: + with log_path.open(encoding="utf-8") as f: + for raw_line in f: + raw_line = raw_line.strip() + if not raw_line: + continue + try: + ev = json.loads(raw_line) + except json.JSONDecodeError: + continue + msg = ev.get("message") or {} + role = msg.get("role") or ev.get("type") + content = msg.get("content") + + if role == "user": + if current_assistant_text: + joined = " ".join(current_assistant_text).strip() + if joined: + lines_out.append(f" assistant: {_truncate(joined, _TRACE_FINAL_TEXT_CAP)}") + current_assistant_text = [] + turn_idx += 1 + user_text = "" + if isinstance(content, str): + user_text = content + elif isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + user_text = str(item.get("text", "")) + break + label = "setup" if turn_idx == 1 else f"query {turn_idx - 1}" + lines_out.append("") + lines_out.append(f"[Turn {turn_idx} - {label}]") + if user_text: + lines_out.append(f" user: {_truncate(user_text, _TRACE_FINAL_TEXT_CAP)}") + elif role == "assistant" and isinstance(content, list): + for item in content: + if not isinstance(item, dict): + continue + itype = item.get("type") + if itype == "tool_use": + name = str(item.get("name", "?")) + inp = item.get("input") or {} + if isinstance(inp, dict): + lines_out.append(f" tool_use {_format_tool_input(name, inp)}") + else: + lines_out.append(f" tool_use {name}") + elif itype == "text": + text = str(item.get("text", "")).strip() + if text: + current_assistant_text.append(text) + except OSError: + return None + + if current_assistant_text: + joined = " ".join(current_assistant_text).strip() + if joined: + lines_out.append(f" assistant: {_truncate(joined, _TRACE_FINAL_TEXT_CAP)}") + + trace = "\n".join(lines_out).strip() + return trace or None + + +def _string_from_content_items(content: Any, *, input_text: bool = True) -> str: + if isinstance(content, str): + return content + if not isinstance(content, list): + return "" + out: list[str] = [] + wanted = "input_text" if input_text else "output_text" + fallback = "text" + for item in content: + if not isinstance(item, dict): + continue + if item.get("type") in {wanted, fallback}: + out.append(str(item.get("text") or "")) + return " ".join(x for x in out if x).strip() + + +def _codex_tool_arguments(payload: dict[str, Any]) -> Any: + args = payload.get("arguments") or "" + if isinstance(args, str): + try: + return json.loads(args) + except json.JSONDecodeError: + return args + return args + + +def _codex_tool_command(payload: dict[str, Any]) -> str: + args = _codex_tool_arguments(payload) + if isinstance(args, dict): + for key in ("cmd", "command"): + value = args.get(key) + if isinstance(value, str): + return value + return json.dumps(args, sort_keys=False) + return str(args) + + +def _format_codex_tool_input(payload: dict[str, Any]) -> str: + name = str(payload.get("name") or "?") + args = _codex_tool_arguments(payload) + if not isinstance(args, str): + args = json.dumps(args, sort_keys=False) + return f"{name}: {_truncate(args, _TRACE_TOOL_INPUT_CAP)}" + + +def _extract_codex_compact_trace(session_uuid: str) -> str | None: + log_path = _codex_session_log_path(session_uuid) + events = _read_jsonl_events(log_path) + if not events: + return None + + turn_idx = 0 + lines_out: list[str] = [] + for ev in events: + etype = ev.get("type") + payload = ev.get("payload") or {} + if not isinstance(payload, dict): + continue + + if etype == "event_msg" and payload.get("type") == "user_message": + turn_idx += 1 + label = "setup" if turn_idx == 1 else f"query {turn_idx - 1}" + lines_out.append("") + lines_out.append(f"[Turn {turn_idx} - {label}]") + text = str(payload.get("message") or "") + if text: + lines_out.append(f" user: {_truncate(text, _TRACE_FINAL_TEXT_CAP)}") + elif etype == "event_msg" and payload.get("type") == "agent_message": + text = str(payload.get("message") or "") + if text: + lines_out.append(f" assistant: {_truncate(text, _TRACE_FINAL_TEXT_CAP)}") + elif etype == "response_item": + ptype = payload.get("type") + if ptype == "function_call": + lines_out.append(f" tool_use {_format_codex_tool_input(payload)}") + elif ptype == "message" and payload.get("role") == "assistant": + text = _string_from_content_items(payload.get("content"), input_text=False) + if text: + lines_out.append(f" assistant: {_truncate(text, _TRACE_FINAL_TEXT_CAP)}") + + trace = "\n".join(lines_out).strip() + return trace or None + + +def extract_compact_trace(agent: str, workdir: Path, session_uuid: str) -> str | None: + if agent == "claude": + return _extract_claude_compact_trace(workdir, session_uuid) + if agent == "codex": + return _extract_codex_compact_trace(session_uuid) + return None + + +def _scan_claude_transcript_for_signals( + envelope: dict[str, Any], + workdir: Path | None, + session_uuid: str | None, +) -> tuple[int | None, bool]: if workdir is not None and session_uuid: log_path = _claude_session_log_path(workdir, session_uuid) if log_path.exists(): @@ -443,25 +834,68 @@ def _scan_transcript_for_signals( for item in content: if not isinstance(item, dict): continue - if item.get("type") != "tool_use": - continue - if item.get("name") != "Bash": + if item.get("type") != "tool_use" or item.get("name") != "Bash": continue cmd = (item.get("input") or {}).get("command") or "" if _retriever_in_command(cmd): return 1, True return None, False except OSError: - pass # fall through to fallback + pass - # Fallback: scan the assistant's final text. text = str(envelope.get("result") or "") used = "retriever " in text or "\nretriever\n" in text return (1 if used else None), used +def _scan_codex_transcript_for_signals( + session_uuid: str, + fallback_events: list[dict[str, Any]], +) -> tuple[int | None, bool]: + log_events = _read_jsonl_events(_codex_session_log_path(session_uuid)) + events = log_events or fallback_events + for ev in events: + if ev.get("type") != "response_item": + continue + payload = ev.get("payload") or {} + if not isinstance(payload, dict) or payload.get("type") != "function_call": + continue + if _retriever_in_command(_codex_tool_command(payload)): + return 1, True + + text_parts: list[str] = [] + for ev in events: + payload = ev.get("payload") or {} + if not isinstance(payload, dict): + continue + if ev.get("type") == "event_msg" and payload.get("type") == "agent_message": + text_parts.append(str(payload.get("message") or "")) + elif ev.get("type") == "response_item" and payload.get("type") == "message": + text_parts.append(_string_from_content_items(payload.get("content"), input_text=False)) + text = "\n".join(text_parts) + used = "retriever " in text or "\nretriever\n" in text + return (1 if used else None), used + + +def _scan_transcript_for_signals( + *, + agent: str, + envelope: dict[str, Any], + codex_events: list[dict[str, Any]], + workdir: Path | None = None, + session_uuid: str | None = None, +) -> tuple[int | None, bool]: + """Detect whether the agent invoked the ``retriever`` CLI.""" + if agent == "claude": + return _scan_claude_transcript_for_signals(envelope, workdir, session_uuid) + if agent == "codex" and session_uuid: + return _scan_codex_transcript_for_signals(session_uuid, codex_events) + return None, False + + def _run_one_turn( *, + agent: str, condition: str, prompt: str, trial_id: str, @@ -477,15 +911,21 @@ def _run_one_turn( timeout_s: int, model: str, ) -> TrialResult: - """Execute one turn. Query turns (is_setup=False) expect the agent to write - ./output.json; the setup turn does not.""" + """Execute one turn. Query turns expect the agent to write ``./output.json``.""" out_path = workdir / "output.json" if out_path.exists(): out_path.unlink() domain_tag = f"[{domain}] " if domain else "" label = "setup" if is_setup else f"entry_id={entry_id}, query_id={query_id}" - logger.info("turn %d for %s %s(%s)", turn_idx + 1, condition, domain_tag, label) + logger.info("turn %d for %s/%s %s(%s)", turn_idx + 1, agent, condition, domain_tag, label) + + prior_codex_usage: dict[str, int] = {k: 0 for k in _CODEX_USAGE_FIELDS} + if agent == "codex": + prior_log = _codex_session_log_path(session_uuid) + if prior_log is not None: + prior_codex_usage = _extract_codex_total_usage(_read_jsonl_events(prior_log)) + t0 = time.monotonic() try: proc = subprocess.run( @@ -512,34 +952,74 @@ def _run_one_turn( total_cost_usd=0.0, model_id=model, session_id=session_uuid, + agent=agent, errors=[f"turn exceeded {timeout_s}s wall timeout"], is_setup=is_setup, domain=domain, + cost_available=(agent == "claude"), + ) + + elapsed_ms = int((time.monotonic() - t0) * 1000) + envelope: dict[str, Any] = {} + codex_events: list[dict[str, Any]] = [] + token_events: list[dict[str, Any]] = [] + if agent == "claude": + envelope = _parse_envelope(proc.stdout) + agent_error = bool(envelope.get("is_error", False)) + duration_ms = int(envelope.get("duration_ms") or elapsed_ms) + duration_api_ms = int(envelope.get("duration_api_ms") or 0) + total_cost_usd = float(envelope.get("total_cost_usd") or 0.0) + model_id = _extract_model_id(envelope, fallback=model) + actual_session_id = str(envelope.get("session_id") or session_uuid) + else: + codex_events = _parse_jsonl_events(proc.stdout) + agent_error = _codex_has_error(codex_events) + duration_ms = elapsed_ms + duration_api_ms = 0 + total_cost_usd = 0.0 + model_id = model + log_path = _codex_session_log_path(session_uuid) + if log_path is None and is_setup: + log_path = _codex_session_log_for_workdir(workdir) + token_events = _read_jsonl_events(log_path) if log_path is not None else codex_events + actual_session_id = _codex_session_id( + token_events, + fallback=_codex_session_id(codex_events, fallback=session_uuid), ) - envelope = _parse_envelope(proc.stdout) stderr = proc.stderr.strip() result = TrialResult( trial_id=trial_id, condition=condition, entry_id=entry_id, query_id=query_id, - status="ok" if proc.returncode == 0 and not envelope.get("is_error", False) else "error", + status="ok" if proc.returncode == 0 and not agent_error else "error", extraction_method="n/a" if is_setup else "output_json", - duration_ms=int(envelope.get("duration_ms") or (time.monotonic() - t0) * 1000), - duration_api_ms=int(envelope.get("duration_api_ms") or 0), + duration_ms=duration_ms, + duration_api_ms=duration_api_ms, num_turns=turn_idx + 1, - total_cost_usd=float(envelope.get("total_cost_usd") or 0.0), - model_id=_extract_model_id(envelope, fallback=model), - session_id=str(envelope.get("session_id") or session_uuid), + total_cost_usd=total_cost_usd, + model_id=model_id, + session_id=actual_session_id, + agent=agent, is_setup=is_setup, domain=domain, + cost_available=(agent == "claude"), ) - _populate_tokens(result, envelope) + if agent == "claude": + _populate_claude_tokens(result, envelope) + else: + current_codex_usage = _extract_codex_total_usage(token_events or codex_events) + _populate_codex_tokens(result, current_codex_usage, prior_codex_usage) if proc.returncode != 0: result.errors.append(f"non-zero exit {proc.returncode}") - if envelope.get("is_error"): + if agent == "claude" and envelope.get("is_error"): result.errors.append(f"envelope is_error: {envelope.get('subtype') or '?'}") + detail = _extract_claude_error_detail(envelope) + if detail: + result.errors.append(f"claude error: {detail[:500]}") + if agent == "codex" and agent_error: + result.errors.append("codex event stream reported an error") if stderr: result.errors.append(f"stderr: {stderr[:500]}") @@ -557,24 +1037,39 @@ def _run_one_turn( if out_path.exists(): out_path.rename(workdir / f"output_e{entry_id}.json") - first_use, used = _scan_transcript_for_signals(envelope, workdir=workdir, session_uuid=session_uuid) + first_use, used = _scan_transcript_for_signals( + agent=agent, + envelope=envelope, + codex_events=codex_events, + workdir=workdir, + session_uuid=actual_session_id, + ) result.retriever_first_use_turn = first_use result.retriever_used_ever = used - # c1 has the skill unavailable; leave skill_fired=None to distinguish from "loaded but didn't fire". + # c1 has the skill unavailable; leave skill_fired=None to distinguish from + # "loaded but didn't fire". if condition in ("c2_retriever", "c3_retriever_skill"): result.skill_fired = used and (first_use is not None) and first_use <= 2 return result +UNSCORABLE_JUDGE_ERRORS: frozenset[str] = frozenset({"no_ground_truth", "empty_candidate"}) + + def _apply_judge(judge: Any, entry: DatasetEntry, result: TrialResult) -> None: """Score ``result.final_answer`` against ``entry.ground_truth_answer``. - Mutates the result in place. Skips silently when the judge is unset, the - ground-truth answer is empty, or the trial didn't produce a final answer. - Errors are recorded on the result rather than raised so a flaky judge - endpoint never breaks an in-flight session. + Missing ground truth and empty candidates are recorded as terminal + ``judge_error`` values so ``rescore`` can skip intrinsically unscorable + trials instead of retrying them forever. """ - if judge is None or not entry.ground_truth_answer or not result.final_answer: + if judge is None: + return + if not entry.ground_truth_answer: + result.judge_error = "no_ground_truth" + return + if not result.final_answer: + result.judge_error = "empty_candidate" return try: verdict = judge.judge( @@ -582,7 +1077,7 @@ def _apply_judge(judge: Any, entry: DatasetEntry, result: TrialResult) -> None: reference=entry.ground_truth_answer, candidate=result.final_answer, ) - except Exception as exc: # defensive — LLMJudge already catches, but be safe. + except Exception as exc: result.judge_error = f"judge_invocation_error: {exc}" logger.warning("LLMJudge raised for entry_id=%s: %s", result.entry_id, exc, exc_info=True) return @@ -594,6 +1089,7 @@ def _apply_judge(judge: Any, entry: DatasetEntry, result: TrialResult) -> None: def run_condition( *, + agent: str, condition: str, entries: list[DatasetEntry], workdir_root: Path, @@ -607,20 +1103,17 @@ def run_condition( judge: Any = None, testdata_prefixes: tuple[str, ...] = (), ) -> tuple[Path, list[TrialResult]]: - """Run one Claude Code session covering setup + all `entries` for `condition`. - - Turn 1 creates the session via --session-id; subsequent turns resume it. The - first TrialResult has is_setup=True; the rest are query results, one per entry. - All ``entries`` are expected to share the same ``domain`` (the caller groups - by domain so each session sees a single PDF corpus). - """ + """Run one agent session covering setup + all entries for one condition.""" + if agent not in SUPPORTED_AGENTS: + raise ValueError(f"unsupported agent: {agent}") if condition not in CONDITIONS: raise ValueError(f"unknown condition: {condition}") - workdir = _build_condition_workdir(condition, workdir_root, pdf_source, skill_source, domain=domain) + workdir = _build_condition_workdir(agent, condition, workdir_root, pdf_source, skill_source, domain=domain) session_uuid = str(uuid.uuid4()) env = _env_for(condition, workdir) logger.info( - "starting session for %s/%s: workdir=%s session_id=%s", + "starting session for %s/%s/%s: workdir=%s session_id=%s", + agent, condition, domain or "default", workdir, @@ -629,9 +1122,18 @@ def run_condition( results: list[TrialResult] = [] - setup_trial_id = f"{condition}_{domain or 'default'}_setup_t1" - setup_cmd = _build_command(condition, model, budget_usd, session_uuid, workdir, resume=False) + setup_trial_id = f"{agent}_{condition}_{domain or 'default'}_setup_t1" + setup_cmd = _build_command( + agent=agent, + condition=condition, + model=model, + budget_usd=budget_usd, + session_uuid=session_uuid, + workdir=workdir, + resume=False, + ) setup_result = _run_one_turn( + agent=agent, condition=condition, prompt=_render_setup_prompt(condition, domain_label), trial_id=setup_trial_id, @@ -649,13 +1151,33 @@ def run_condition( ) results.append(setup_result) - resume_cmd = _build_command(condition, model, budget_usd, session_uuid, workdir, resume=True) + if setup_result.status != "ok": + logger.warning( + "setup turn failed for %s/%s/%s; skipping %d query turns", + agent, + condition, + domain or "default", + len(entries), + ) + return workdir, results + + session_uuid = setup_result.session_id or session_uuid + resume_cmd = _build_command( + agent=agent, + condition=condition, + model=model, + budget_usd=budget_usd, + session_uuid=session_uuid, + workdir=workdir, + resume=True, + ) for i, entry in enumerate(entries): turn_idx = i + 1 result = _run_one_turn( + agent=agent, condition=condition, prompt=_render_prompt(entry, condition, testdata_prefixes), - trial_id=f"{condition}_{domain or 'default'}_e{entry.entry_id}_t{turn_idx + 1}", + trial_id=f"{agent}_{condition}_{domain or 'default'}_e{entry.entry_id}_t{turn_idx + 1}", entry_id=entry.entry_id, query_id=entry.query_id, domain=domain, @@ -674,10 +1196,43 @@ def run_condition( def save_trial(result: TrialResult, session_dir: Path) -> Path: - parts = [session_dir, "trials", result.condition] + parts = [session_dir, "trials", result.agent, result.condition] if result.domain: parts.append(result.domain) out = Path(*[str(p) for p in parts]) / f"{result.trial_id}.json" out.parent.mkdir(parents=True, exist_ok=True) out.write_text(json.dumps(asdict(result), indent=2) + "\n", encoding="utf-8") return out + + +def archive_session_log( + *, + session_dir: Path, + agent: str, + condition: str, + domain: str, + session_uuid: str, + workdir: Path, +) -> Path | None: + """Copy the agent's rollout log into the artifact dir so it survives ``cleanup_condition_workdir``. + + Without this, the per-trial JSONs are the only persistent record of the run — + you cannot retroactively recompute token deltas, tool-use signals, or anything + else that requires the raw event stream. + """ + if agent == "claude": + src = _claude_session_log_path(workdir, session_uuid) + elif agent == "codex": + src = _codex_session_log_path(session_uuid) + else: + return None + if src is None or not src.exists(): + return None + parts = [session_dir, "trials", agent, condition] + if domain: + parts.append(domain) + logs_dir = Path(*[str(p) for p in parts]) / "logs" + logs_dir.mkdir(parents=True, exist_ok=True) + dest = logs_dir / src.name + shutil.copy2(src, dest) + return dest diff --git a/nemo_retriever/src/nemo_retriever/skill_eval/trace_summarizer.py b/nemo_retriever/src/nemo_retriever/skill_eval/trace_summarizer.py new file mode 100644 index 000000000..1a6b087eb --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/skill_eval/trace_summarizer.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""LLM-generated tool-use summaries via the ``claude`` CLI. + +Reads a compact trace of one agent session (setup turn + N query turns) and +asks a strong Anthropic model to narrate what the agent did: which tools it +called, in what order, what strategy it took, and where it improvised. + +Shells out to ``claude --print`` so it reuses Claude Code's existing auth. Each +call runs in a neutral temp cwd with ``--setting-sources user`` so project-level +skills or settings do not leak into the summarization session. +""" + +from __future__ import annotations + +import logging +import subprocess +import tempfile + +logger = logging.getLogger(__name__) + +_SUMMARIZER_PROMPT_TEMPLATE = """\ +You are summarizing the tool-use trace of a coding agent that just ran an +information-retrieval benchmark over a corpus of PDFs. + +Produce a concise markdown narrative with these sections: + +**Overall strategy** - one or two sentences. What approach did the agent take? +Did it build an index, fall back to grep/pdftotext, use a skill? + +**Tool-use breakdown** - bulleted list of tool names with counts and one or two +representative invocations each. Keep inputs short. + +**Notable patterns** - retries, dead ends, fallback chains, suspicious behavior. +Skip this section if nothing stands out. + +**Per-question variation** - only include if the agent's approach changed +between query turns. Otherwise omit. + +Be terse. Aim for under 250 words total. Do not editorialize about whether the +strategy was good or bad; just describe what happened. + +--- + +Condition: {condition} +Domain: {domain} + +Trace: +{trace} +""" + +_DEFAULT_MODEL = "claude-opus-4-7" + + +class TraceSummarizer: + """Per-session tool-use narrator backed by the ``claude`` CLI.""" + + def __init__( + self, + *, + model: str = _DEFAULT_MODEL, + timeout: float = 120.0, + ): + self.model = model + self._timeout = timeout + + @classmethod + def from_kwargs(cls, **kwargs) -> "TraceSummarizer": + return cls(**kwargs) + + def summarize(self, condition: str, domain: str, trace: str) -> str: + """Return a markdown narrative of ``trace``. Empty string on failure.""" + if not trace.strip(): + return "" + + prompt = _SUMMARIZER_PROMPT_TEMPLATE.format(condition=condition, domain=domain, trace=trace) + cmd = [ + "claude", + "--print", + "--model", + self.model, + "--setting-sources", + "user", + ] + with tempfile.TemporaryDirectory(prefix="skill_eval_summarize_") as tmpdir: + try: + proc = subprocess.run( + cmd, + input=prompt, + capture_output=True, + text=True, + timeout=self._timeout, + cwd=tmpdir, + check=False, + ) + except subprocess.TimeoutExpired: + logger.warning("trace summarizer timed out after %ss", self._timeout) + return "" + except FileNotFoundError: + logger.warning("trace summarizer: `claude` CLI not on PATH") + return "" + + if proc.returncode != 0: + logger.warning( + "trace summarizer exited %d: %s", + proc.returncode, + (proc.stderr or "")[:300], + ) + return "" + return (proc.stdout or "").strip()