From 4727f1c7ef6ab92d1ae8cb6ebcb498527a047e43 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Fri, 27 Feb 2026 15:14:43 -0500 Subject: [PATCH 01/29] Unify winrate HF upload to hf repo with winrate_dir --- docs/medarc-eval-winrate.md | 9 +++-- medarc_verifiers/cli/hf/sync.py | 29 ++++++++++++++- medarc_verifiers/cli/main.py | 66 ++++++++++++++++++++++++--------- tests/test_cli/test_main.py | 24 ++++++++++-- 4 files changed, 101 insertions(+), 27 deletions(-) diff --git a/docs/medarc-eval-winrate.md b/docs/medarc-eval-winrate.md index d1f50e99..2882c0a5 100644 --- a/docs/medarc-eval-winrate.md +++ b/docs/medarc-eval-winrate.md @@ -161,7 +161,7 @@ medarc-eval winrate --weight-policy ln ```bash medarc-eval winrate \ - --hf-processed-repo your-org/processed-benchmarks \ + --hf-repo your-org/processed-benchmarks \ --hf-processed-pull \ --hf-token $HF_TOKEN ``` @@ -170,7 +170,8 @@ medarc-eval winrate \ ```bash medarc-eval winrate \ - --hf-winrate-repo your-org/winrate-results \ + --hf-repo your-org/processed-benchmarks \ + --hf-winrate-dir winrate \ --hf-token $HF_TOKEN \ --hf-private ``` @@ -186,8 +187,8 @@ missing_policy: neg-inf weight_policy: ln hf: - repo: your-org/processed-data # Pull processed from here - winrate_repo: your-org/winrate-results # Upload results here + repo: your-org/processed-data # Pull processed from here; upload winrate here + winrate_dir: winrate # Subdirectory in repo for winrate artifacts (default: winrate) branch: main token: ${HF_TOKEN} private: true diff --git a/medarc_verifiers/cli/hf/sync.py b/medarc_verifiers/cli/hf/sync.py index 9f462f9d..48ed64a5 100644 --- a/medarc_verifiers/cli/hf/sync.py +++ b/medarc_verifiers/cli/hf/sync.py @@ -9,6 +9,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, Sequence +from medarc_verifiers.utils.pathing import resolve_under + if TYPE_CHECKING: from medarc_verifiers.cli.process.writer import EnvWriteSummary @@ -166,6 +168,7 @@ def sync_files_to_hub( request_timeout_s: float | None = None, retries: int = 3, max_files_per_commit: int | None = None, + path_in_repo_prefix: str | None = None, is_tty: bool = False, assume_yes: bool = False, prompt_func: Callable[[str], str] | None = None, @@ -195,6 +198,7 @@ def sync_files_to_hub( _configure_hf_http_timeout(float(request_timeout_s)) api = HfApi(token=token) + repo_prefix = _normalize_repo_path_prefix(path_in_repo_prefix) if max_files_per_commit is None or max_files_per_commit <= 0: batches = [file_list] @@ -207,7 +211,10 @@ def sync_files_to_hub( for batch_index, batch_files in enumerate(batches, start=1): operations = [ - CommitOperationAdd(path_in_repo=rel_path, path_or_fileobj=str(output_dir / rel_path)) + CommitOperationAdd( + path_in_repo=_join_repo_path(repo_prefix, rel_path), + path_or_fileobj=str(output_dir / rel_path), + ) for rel_path in batch_files ] commit_message = message @@ -264,6 +271,26 @@ def sync_files_to_hub( time.sleep(delay) +def _normalize_repo_path_prefix(value: str | None) -> str | None: + if value is None: + return None + raw = str(value).strip().replace("\\", "/").strip("/") + if not raw: + return None + candidate = resolve_under(Path("."), raw) + if candidate is None: + raise ValueError(f"Invalid path_in_repo_prefix: {value!r}") + normalized = candidate.as_posix().lstrip("./") + return normalized or None + + +def _join_repo_path(prefix: str | None, rel_path: str) -> str: + rel = rel_path.strip().replace("\\", "/").lstrip("/") + if not prefix: + return rel + return f"{prefix}/{rel}" if rel else prefix + + def sync_to_hub( env_summaries: Sequence[EnvWriteSummary], config: HFSyncConfig, diff --git a/medarc_verifiers/cli/main.py b/medarc_verifiers/cli/main.py index 72d0a20e..1bc076f0 100644 --- a/medarc_verifiers/cli/main.py +++ b/medarc_verifiers/cli/main.py @@ -47,6 +47,7 @@ slugify, validate_simple_name, ) +from medarc_verifiers.utils.pathing import resolve_under from medarc_verifiers.cli.winrate import ( WinrateConfig, _resolve_source, @@ -465,7 +466,7 @@ def build_winrate_parser() -> argparse.ArgumentParser: "per-model uses the legacy behavior where each model may be averaged over a different dataset set." ), ) - parser.add_argument("--hf-processed-repo", help="Hugging Face repo id for processed dataset download.") + parser.add_argument("--hf-repo", help="Hugging Face repo id used for processed download and winrate upload.") parser.add_argument( "--hf-processed-pull", action="store_true", @@ -474,7 +475,11 @@ def build_winrate_parser() -> argparse.ArgumentParser: ) parser.add_argument("--hf-branch", help="Target HF branch or revision for processed download.") parser.add_argument("--hf-token", help="Auth token for HF operations.") - parser.add_argument("--hf-winrate-repo", help="Hugging Face repo id for winrate artifact upload.") + parser.add_argument( + "--hf-winrate-dir", + default=None, + help="Path under the HF repo where winrate artifacts are uploaded (default: winrate).", + ) parser.add_argument( "--hf-private", action=argparse.BooleanOptionalAction, @@ -665,7 +670,7 @@ def _run_process_mode(argv: Sequence[str]) -> int: if winrate_args is None: winrate_args = _build_winrate_args_from_config(Path(args.winrate), parser=parser) winrate_args.processed_dir = options.output_dir - winrate_args.hf_processed_repo = None + winrate_args.hf_repo = None winrate_args.hf_processed_pull = False winrate_cfg = WinrateConfig( missing_policy=winrate_args.missing_policy, @@ -697,13 +702,15 @@ def _run_process_mode(argv: Sequence[str]) -> int: "Computed win rates for %d dataset(s): %s", len(winrate_result.datasets), winrate_result.output_path ) print_winrate_summary_markdown(winrate_result.result) - if winrate_args.hf_winrate_repo: + if options.hf_config and options.hf_config.repo_id: _upload_winrate_outputs( output_dir=winrate_args.output_dir, output_paths=winrate_result.output_paths, - repo_id=winrate_args.hf_winrate_repo, - token=winrate_args.hf_token, - private=bool(winrate_args.hf_private), + repo_id=options.hf_config.repo_id, + token=options.hf_config.token, + branch=options.hf_config.branch, + private=bool(options.hf_config.private), + winrate_dir=winrate_args.hf_winrate_dir, assume_yes=bool(args.yes), ) @@ -750,12 +757,18 @@ def _load_config_payload(path: Path, *, mode: Literal["process", "winrate"]) -> def _normalize_mode_payload(payload: dict[str, Any], *, mode: Literal["process", "winrate"]) -> None: + if mode == "winrate": + if "hf_processed_repo" in payload and "hf_repo" not in payload: + payload["hf_repo"] = payload["hf_processed_repo"] + if "hf_winrate_repo" in payload: + raise ValueError("Winrate config field 'hf_winrate_repo' was removed; use 'hf.repo' and 'hf.winrate_dir'.") + hf_payload = payload.get("hf") if isinstance(hf_payload, Mapping): for key, value in hf_payload.items(): if mode == "winrate": if key == "repo": - payload.setdefault("hf_processed_repo", value) + payload.setdefault("hf_repo", value) continue if key == "branch": payload.setdefault("hf_branch", value) @@ -766,6 +779,10 @@ def _normalize_mode_payload(payload: dict[str, Any], *, mode: Literal["process", if key == "private": payload.setdefault("hf_private", value) continue + if key == "winrate_repo": + raise ValueError( + "Winrate config field 'hf.winrate_repo' was removed; use 'hf.repo' and 'hf.winrate_dir'." + ) payload.setdefault(f"hf_{key}", value) if "exclude_datasets" not in payload and "exclude_dataset" in payload: @@ -783,9 +800,9 @@ def _load_and_apply_config( ) -> None: try: payload = _load_config_payload(path, mode=mode) + _normalize_mode_payload(payload, mode=mode) except (FileNotFoundError, ValueError) as exc: parser.error(str(exc)) - _normalize_mode_payload(payload, mode=mode) path_fields = { "process": { @@ -811,8 +828,8 @@ def _load_and_apply_config( "weight_policy": "weight_policy", "partial_datasets": "partial_datasets", "dataset_coverage": "dataset_coverage", - "hf_processed_repo": "hf_processed_repo", - "hf_winrate_repo": "hf_winrate_repo", + "hf_repo": "hf_repo", + "hf_winrate_dir": "hf_winrate_dir", "hf_branch": "hf_branch", "hf_token": "hf_token", }, @@ -891,9 +908,9 @@ def _build_winrate_args_from_config(path: Path, *, parser: argparse.ArgumentPars exclude_model=None, exclude_dataset=None, partial_datasets=None, - hf_processed_repo=None, + hf_repo=None, hf_processed_pull=None, - hf_winrate_repo=None, + hf_winrate_dir=None, hf_branch=None, hf_token=None, hf_private=None, @@ -935,6 +952,7 @@ def _finalize_config_args(args: argparse.Namespace, *, mode: Literal["process", "exclude_dataset": [], "partial_datasets": "strict", "hf_processed_pull": False, + "hf_winrate_dir": "winrate", "hf_private": False, "yes": False, }, @@ -955,11 +973,19 @@ def _upload_winrate_outputs( output_paths: Sequence[Path], repo_id: str, token: str | None, + branch: str | None, private: bool, + winrate_dir: str | None, assume_yes: bool = False, ) -> None: if not output_paths: return + raw_dir = "winrate" if winrate_dir is None else str(winrate_dir).strip() + if not raw_dir: + raw_dir = "winrate" + if resolve_under(Path("."), raw_dir) is None: + logger.error("Invalid winrate_dir '%s'; skipping upload.", winrate_dir) + return output_dir = Path(output_dir) files: list[str] = [] for path in output_paths: @@ -981,6 +1007,8 @@ def _upload_winrate_outputs( token=token, private=private, message=message, + branch=branch, + path_in_repo_prefix=raw_dir, is_tty=sys.stdin.isatty(), assume_yes=assume_yes, prompt_func=input, @@ -996,17 +1024,17 @@ def _run_winrate_mode(argv: Sequence[str]) -> int: _finalize_config_args(args, mode="winrate") hf_config = HFSyncConfig.from_cli( - repo=args.hf_processed_repo, + repo=args.hf_repo, branch=args.hf_branch, token=args.hf_token, - private=False, + private=bool(args.hf_private), dry_run=False, ) if args.list_models: source_dir, datasets, source_desc = _resolve_source( args.processed_dir, - hf_config=hf_config if args.hf_processed_repo else None, + hf_config=hf_config if args.hf_repo else None, hf_processed_pull=bool(args.hf_processed_pull), ) if args.exclude_dataset: @@ -1054,13 +1082,15 @@ def _run_winrate_mode(argv: Sequence[str]) -> int: logger.info("Computed win rates for %d dataset(s): %s", len(winrate_result.datasets), winrate_result.output_path) print_winrate_summary_markdown(winrate_result.result) - if args.hf_winrate_repo: + if args.hf_repo: _upload_winrate_outputs( output_dir=args.output_dir, output_paths=winrate_result.output_paths, - repo_id=args.hf_winrate_repo, + repo_id=args.hf_repo, token=args.hf_token, + branch=args.hf_branch, private=bool(args.hf_private), + winrate_dir=args.hf_winrate_dir, assume_yes=bool(args.yes), ) return 0 diff --git a/tests/test_cli/test_main.py b/tests/test_cli/test_main.py index 4bee1d52..512da2b7 100644 --- a/tests/test_cli/test_main.py +++ b/tests/test_cli/test_main.py @@ -1876,11 +1876,16 @@ def fake_run_winrate( } return SimpleNamespace( output_path=tmp_path / "out.json", + output_paths=[tmp_path / "out.json"], result={"models": {}}, datasets=[("demo-env", [Path("demo-env.parquet")])], ) + def fake_sync_files_to_hub(**kwargs): + captured["upload"] = kwargs + monkeypatch.setattr(main, "run_winrate", fake_run_winrate) + monkeypatch.setattr(main, "sync_files_to_hub", fake_sync_files_to_hub) monkeypatch.setattr(main, "print_winrate_summary_markdown", lambda *_args, **_kwargs: None) exit_code = main.main(["winrate", "--config", str(cfg_path), "--processed-at", "2024-01-01T00:00:00Z"]) @@ -1895,6 +1900,12 @@ def fake_run_winrate( assert cfg.weight_cap == 99 assert cfg.include_models == ("alpha", "beta") assert cfg.exclude_models == ("gamma",) + assert captured["run_kwargs"]["hf_config"] is not None + assert captured["run_kwargs"]["hf_config"].repo_id == "medarc/demo" + upload = captured.get("upload") + assert upload is not None + assert upload["repo_id"] == "medarc/demo" + assert upload["path_in_repo_prefix"] == "winrate" exit_code = main.main( [ @@ -1936,9 +1947,8 @@ def test_process_cli_runs_winrate_post_step(monkeypatch: pytest.MonkeyPatch, tmp output_dir: winrate-out output_name: from-config missing_policy: zero - hf_processed_repo: ignored/also - hf_winrate_repo: medarc/winrate - hf_token: secret-token + hf: + winrate_dir: winrate-post """, encoding="utf-8", ) @@ -1981,6 +1991,7 @@ def fake_sync_files_to_hub( "message": message, "branch": branch, "dry_run": dry_run, + **_kw, } monkeypatch.setattr(main, "run_process", fake_run_process) @@ -1997,6 +2008,10 @@ def fake_sync_files_to_hub( str(tmp_path / "processed"), "--winrate", str(cfg_path), + "--hf-repo", + "medarc/shared", + "--hf-token", + "secret-token", ] ) assert exit_code == 0 @@ -2006,9 +2021,10 @@ def fake_sync_files_to_hub( assert captured["run_kwargs"]["hf_processed_pull"] is False upload = captured.get("upload") assert upload is not None - assert upload["repo_id"] == "medarc/winrate" + assert upload["repo_id"] == "medarc/shared" assert upload["token"] == "secret-token" assert upload["files"] == ["winrate.json"] + assert upload["path_in_repo_prefix"] == "winrate-post" def test_process_config_sets_winrate_path(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: From 526078e4af9b5ed115f6363ddfb3be5e621a5d18 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sat, 28 Feb 2026 13:34:56 -0500 Subject: [PATCH 02/29] Remove process manifest preflight --- medarc_verifiers/cli/main.py | 36 ------------- medarc_verifiers/cli/process/rows.py | 6 ++- tests/test_cli/test_main.py | 2 +- tests/test_cli/test_process_pipeline.py | 69 +++++++++++++++++++++++-- 4 files changed, 70 insertions(+), 43 deletions(-) diff --git a/medarc_verifiers/cli/main.py b/medarc_verifiers/cli/main.py index 1bc076f0..146d29fd 100644 --- a/medarc_verifiers/cli/main.py +++ b/medarc_verifiers/cli/main.py @@ -33,7 +33,6 @@ from medarc_verifiers.cli._job_executor import ExecutorSettings, JobExecutionResult, execute_jobs from medarc_verifiers.cli._manifest import MANIFEST_FILENAME, ManifestJobEntry, RunManifest, compute_snapshot_checksum from medarc_verifiers.cli._manifest_planner import ManifestPlanner -from medarc_verifiers.cli._manifest_tools import format_validation_issues, validate_manifests_in_runs from medarc_verifiers.cli._schemas import EnvironmentConfigSchema, EnvironmentExportConfig from medarc_verifiers.cli._single_run import run_single_mode from medarc_verifiers.cli.hf import HFSyncConfig, sync_files_to_hub @@ -287,18 +286,6 @@ def build_process_parser() -> argparse.ArgumentParser: ) parser.add_argument("--processed-at", default=None, help="Override processed_at timestamp (ISO8601).") parser.add_argument("--dry-run", action="store_true", default=None, help="Plan processing without writing outputs.") - parser.add_argument( - "--validate-manifest", - action=argparse.BooleanOptionalAction, - default=None, - help="Validate run manifests before processing (default: enabled).", - ) - parser.add_argument( - "--strict-manifest", - action="store_true", - default=None, - help="Treat manifest validation problems as errors.", - ) parser.add_argument( "--process-incomplete", dest="process_incomplete", @@ -610,8 +597,6 @@ def _run_process_mode(argv: Sequence[str]) -> int: "exclude_models": args.exclude_model or [], "dry_run": bool(args.dry_run), "clean": bool(args.clean), - "validate_manifest": bool(args.validate_manifest), - "strict_manifest": bool(args.strict_manifest), "only_complete_runs": not bool(args.process_incomplete), "hf_repo": args.hf_repo, "hf_pull_policy": args.hf_pull_policy, @@ -638,23 +623,6 @@ def _run_process_mode(argv: Sequence[str]) -> int: max_workers=args.max_workers, ) - if args.validate_manifest: - validation = validate_manifests_in_runs(options.runs_dir, strict=bool(args.strict_manifest)) - for line in format_validation_issues(validation.issues): - if line.startswith("[ERROR]"): - logger.error("%s", line) - else: - logger.warning("%s", line) - logger.info( - "Manifest preflight: checked %d manifest(s), %d job(s), %d issue(s).", - validation.manifests_checked, - validation.jobs_checked, - len(validation.issues), - ) - if validation.has_errors: - logger.error("Manifest validation failed in strict mode; aborting process.") - return 1 - try: result = run_process(options, env_export_map=env_export_map) except Exception as exc: # noqa: BLE001 @@ -840,8 +808,6 @@ def _load_and_apply_config( "clean": "clean", "yes": "yes", "process_incomplete": "process_incomplete", - "validate_manifest": "validate_manifest", - "strict_manifest": "strict_manifest", "hf_private": "hf_private", }, "winrate": {"hf_processed_pull": "hf_processed_pull", "hf_private": "hf_private"}, @@ -933,8 +899,6 @@ def _finalize_config_args(args: argparse.Namespace, *, mode: Literal["process", "clean": False, "yes": False, "process_incomplete": False, - "validate_manifest": True, - "strict_manifest": False, "exclude_dataset": [], "exclude_model": [], }, diff --git a/medarc_verifiers/cli/process/rows.py b/medarc_verifiers/cli/process/rows.py index d06cb1e4..80d27264 100644 --- a/medarc_verifiers/cli/process/rows.py +++ b/medarc_verifiers/cli/process/rows.py @@ -28,8 +28,10 @@ def load_rows( """Load results.jsonl rows and attach manifest metadata.""" record = metadata.record if not record.has_results: - logger.debug("Run %s missing results.jsonl; skipping.", record.job_id) - return [] + raise FileNotFoundError( + "Missing results.jsonl for selected run " + f"(job_run_id={record.manifest.job_run_id}, job_id={record.job_id}, path={record.results_path})" + ) results_path = record.results_path extras_keys = {column for column in extra_columns or () if column} diff --git a/tests/test_cli/test_main.py b/tests/test_cli/test_main.py index 512da2b7..dc1c5771 100644 --- a/tests/test_cli/test_main.py +++ b/tests/test_cli/test_main.py @@ -2178,7 +2178,7 @@ def fake_run(options, env_export_map): monkeypatch.setattr(main, "run_process", fake_run) exit_code = main.main( - ["process", "--config", str(cfg_path), "--max-workers", "2", "--dry-run", "--no-validate-manifest"] + ["process", "--config", str(cfg_path), "--max-workers", "2", "--dry-run"] ) assert exit_code == 0 assert captured["options"].max_workers == 2 diff --git a/tests/test_cli/test_process_pipeline.py b/tests/test_cli/test_process_pipeline.py index b55eaf0a..83e3a0b4 100644 --- a/tests/test_cli/test_process_pipeline.py +++ b/tests/test_cli/test_process_pipeline.py @@ -93,6 +93,8 @@ def _write_run( reward: float, env_id: str = "demo-env-rollout3", model_id: str = "gpt-mini", + status: str = "completed", + results_text: str | None = None, ) -> Path: runs_dir = tmp_path / "runs" run_dir = runs_dir / run_id @@ -110,10 +112,10 @@ def _write_run( "env_templates": {"demo-env-template": {"module": env_id}}, "summary": { "total": 1, - "completed": 1, + "completed": 1 if status == "completed" else 0, "pending": 0, "running": 0, - "failed": 0, + "failed": 1 if status == "failed" else 0, "skipped": 0, }, "jobs": [ @@ -126,6 +128,7 @@ def _write_run( "env_args": {}, "results_dir": "demo-job", } + "status": status, ], } _write_json(run_dir / "run_manifest.json", manifest) @@ -137,8 +140,10 @@ def _write_run( _write_json(results_dir / "metadata.json", metadata) results_path = results_dir / "results.jsonl" results_path.parent.mkdir(parents=True, exist_ok=True) - row = {"example_id": f"ex-{run_id}", "reward": reward} - results_path.write_text(json.dumps(row) + "\n", encoding="utf-8") + if results_text is None: + row = {"example_id": f"ex-{run_id}", "reward": reward} + results_text = json.dumps(row) + "\n" + results_path.write_text(results_text, encoding="utf-8") return runs_dir @@ -540,6 +545,62 @@ def test_process_latest_only_selects_latest_and_delta_skips(tmp_path: Path) -> N assert result_repeat.env_summaries == [] assert result_repeat.rows_processed == 0 +def test_process_ignores_invalid_superseded_run(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-1", + updated_at="2024-01-01T00:00:00Z", + reward=0.1, + results_text='{"example_id": ', + ) + _write_run(tmp_path, run_id="run-2", updated_at="2024-01-02T00:00:00Z", reward=0.9) + output_dir = tmp_path / "processed" + + result = run_process(ProcessOptions(runs_dir=runs_dir, output_dir=output_dir, dry_run=False, max_workers=1)) + + assert result.env_summaries + table = pq.read_table(result.env_summaries[0].output_path) + assert table.column("reward").to_pylist() == [0.9] + + +def test_process_ignores_invalid_incomplete_run_by_default(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-1", + updated_at="2024-01-01T00:00:00Z", + reward=0.1, + status="running", + results_text='{"example_id": ', + ) + _write_run(tmp_path, run_id="run-2", updated_at="2024-01-02T00:00:00Z", reward=0.9, env_id="other-env") + output_dir = tmp_path / "processed" + + result = run_process(ProcessOptions(runs_dir=runs_dir, output_dir=output_dir, dry_run=False, max_workers=1)) + + assert {summary.env_id for summary in result.env_summaries} == {"other-env"} + + +def test_process_selected_invalid_results_still_fail(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-1", + updated_at="2024-01-01T00:00:00Z", + reward=0.1, + results_text='{"example_id": ', + ) + + with pytest.raises(ValueError, match="Failed to parse JSONL line 1"): + run_process(ProcessOptions(runs_dir=runs_dir, output_dir=tmp_path / "processed", dry_run=False, max_workers=1)) + + +def test_process_selected_missing_results_still_fail(tmp_path: Path) -> None: + runs_dir = _setup_run(tmp_path) + missing_results = runs_dir / "run-1" / "demo-job" / "results.jsonl" + missing_results.unlink() + + with pytest.raises(FileNotFoundError, match="Missing results.jsonl"): + run_process(ProcessOptions(runs_dir=runs_dir, output_dir=tmp_path / "processed", dry_run=False, max_workers=1)) + def test_process_clean_clears_outputs(tmp_path: Path) -> None: runs_dir = _setup_run(tmp_path) From 2eac1affbf8c657ac04e2074c4bc54bfb804db3b Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sat, 28 Feb 2026 13:35:22 -0500 Subject: [PATCH 03/29] Simplify process selection and delta handling --- medarc_verifiers/cli/process/aggregate.py | 53 +- medarc_verifiers/cli/process/env_index.py | 17 +- medarc_verifiers/cli/process/metadata.py | 35 +- medarc_verifiers/cli/process/pipeline.py | 598 ++++++++++++---------- medarc_verifiers/cli/process/rows.py | 11 +- tests/test_cli/test_process_pipeline.py | 149 +++++- 6 files changed, 557 insertions(+), 306 deletions(-) diff --git a/medarc_verifiers/cli/process/aggregate.py b/medarc_verifiers/cli/process/aggregate.py index b00d5ff2..7d7b1b68 100644 --- a/medarc_verifiers/cli/process/aggregate.py +++ b/medarc_verifiers/cli/process/aggregate.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Any, Iterable, Mapping +from medarc_verifiers.cli.process.metadata import RunIdentity from medarc_verifiers.cli.process.rollout import derive_base_env_id logger = logging.getLogger(__name__) @@ -25,9 +26,17 @@ class AggregatedEnvRows: def aggregate_rows_by_env( rows: Iterable[Mapping[str, Any]], + *, + identities: Iterable[RunIdentity] | None = None, ) -> list[AggregatedEnvRows]: """Group enriched rows by (model_id, base_env_id), capturing unioned schemas.""" groups: dict[tuple[str, str], dict[str, Any]] = {} + identity_list = list(identities or ()) + fake_rollout_groups = { + (identity.model_id, identity.output_env_id) + for identity in identity_list + if identity.rollout_index is not None + } for row in rows: base_env_id = str(row.get("base_env_id") or row.get("env_id") or "") @@ -68,7 +77,15 @@ def aggregate_rows_by_env( # processing "fake rollouts" that are created by running separate jobs with rollout suffixes # (e.g., env-a-rollout7) and then combining them under a shared base_env_id. normalized_rows: list[Mapping[str, Any]] = list(group["rows"]) # shallow copy - if _group_uses_rollout_suffixes(normalized_rows, base_env_id=group["base_env_id"] or key[1]): + if key in fake_rollout_groups: + _ensure_rollout_index_from_identities( + normalized_rows, + identities=identity_list, + model_id=group["model_id"], + base_env_id=group["base_env_id"] or key[1], + ) + _normalize_rollout_indices(normalized_rows) + elif _group_uses_rollout_suffixes(normalized_rows, base_env_id=group["base_env_id"] or key[1]): _ensure_rollout_index_from_suffix(normalized_rows, base_env_id=group["base_env_id"] or key[1]) _normalize_rollout_indices(normalized_rows) candidate_env_id = group["env_id"] or group["base_env_id"] or "" @@ -85,6 +102,40 @@ def aggregate_rows_by_env( return aggregated +def _ensure_rollout_index_from_identities( + rows: list[Mapping[str, Any]], + *, + identities: list[RunIdentity], + model_id: str, + base_env_id: str, +) -> None: + rollout_by_manifest_env: dict[str, int] = {} + for identity in identities: + if identity.model_id != model_id or identity.output_env_id != base_env_id: + continue + if identity.rollout_index is None: + continue + rollout_by_manifest_env[identity.manifest_env_id] = identity.rollout_index + + if not rollout_by_manifest_env: + return + + for row in rows: + value = row.get("rollout_index") + if _coerce_rollout_index(value) is not None: + continue + manifest_env_id = row.get("manifest_env_id") + if not isinstance(manifest_env_id, str): + continue + resolved = rollout_by_manifest_env.get(manifest_env_id) + if resolved is None: + continue + try: + row["rollout_index"] = resolved + except TypeError: + continue + + def _group_uses_rollout_suffixes(rows: list[Mapping[str, Any]], *, base_env_id: str) -> bool: for row in rows: manifest_env_id = row.get("manifest_env_id") diff --git a/medarc_verifiers/cli/process/env_index.py b/medarc_verifiers/cli/process/env_index.py index 89c85c37..86fecd50 100644 --- a/medarc_verifiers/cli/process/env_index.py +++ b/medarc_verifiers/cli/process/env_index.py @@ -54,21 +54,9 @@ def read_env_index_inventory(processed_dir: Path) -> EnvIndexInventory: """Read env_index.json and return a dataset inventory.""" index_path = processed_dir / "env_index.json" payload = load_env_index(index_path) - version = payload.get("version") if isinstance(payload, Mapping) else None - if version == 2: + if isinstance(payload, Mapping) and int(payload.get("version") or 0) == 2: return _inventory_from_v2(payload, processed_dir) - return EnvIndexInventory(env_paths={}, version=int(version or 1)) - - -def read_env_index_runs(processed_dir: Path) -> tuple[int, dict[str, Mapping[str, Any]]]: - """Return env_index version and run metadata map.""" - index_path = processed_dir / "env_index.json" - payload = load_env_index(index_path) - version = int(payload.get("version") or 1) if isinstance(payload, Mapping) else 1 - runs = payload.get("runs") if isinstance(payload, Mapping) else None - if version != 2 or not isinstance(runs, Mapping): - return version, {} - return version, {str(k): v for k, v in runs.items() if isinstance(v, Mapping)} + return EnvIndexInventory(env_paths={}, version=0) def read_env_index_files(processed_dir: Path) -> dict[str, Mapping[str, Any]]: @@ -118,7 +106,6 @@ def read_env_index_models(processed_dir: Path) -> set[str]: __all__ = [ "EnvIndexInventory", "read_env_index_inventory", - "read_env_index_runs", "read_env_index_files", "read_env_index_models", ] diff --git a/medarc_verifiers/cli/process/metadata.py b/medarc_verifiers/cli/process/metadata.py index 6bfae643..99f5fd72 100644 --- a/medarc_verifiers/cli/process/metadata.py +++ b/medarc_verifiers/cli/process/metadata.py @@ -32,6 +32,7 @@ class _MetadataPayload(BaseModel): class NormalizedMetadata: """Normalized view of metadata.json merged with manifest discovery data.""" + identity: "RunIdentity" record: RunRecord metadata_path: Path | None raw_metadata: Mapping[str, Any] @@ -47,6 +48,18 @@ class NormalizedMetadata: rollouts_per_example: int | None +@dataclass(frozen=True, slots=True) +class RunIdentity: + """Canonical identity for selecting and exporting a discovered run record.""" + + model_id: str + manifest_env_id: str + base_env_id: str + rollout_index: int | None + job_run_id: str + output_env_id: str + + def load_normalized_metadata( record: RunRecord, *, @@ -81,20 +94,36 @@ def load_normalized_metadata( rollout_index = alt_index model_id = record.model_id or metadata_model + if not model_id: + raise RuntimeError( + "Missing model_id for run " + f"(job_run_id={record.manifest.job_run_id}, job_id={record.job_id}, " + f"results_dir={record.results_dir}, manifest={record.manifest.manifest_path})" + ) + resolved_rollout_index = rollout_index if rollout_index != 0 or manifest_env_id != base_env_id else None + identity = RunIdentity( + model_id=model_id, + manifest_env_id=manifest_env_id, + base_env_id=base_env_id, + rollout_index=resolved_rollout_index, + job_run_id=record.manifest.job_run_id, + output_env_id=base_env_id or manifest_env_id or record.job_id, + ) num_examples = record.num_examples or (metadata_payload.num_examples if metadata_payload else None) rollouts_per_example = record.rollouts_per_example or ( metadata_payload.rollouts_per_example if metadata_payload else None ) return NormalizedMetadata( + identity=identity, record=record, metadata_path=record.metadata_path if record.has_metadata else None, raw_metadata=raw_metadata, manifest_env_id=manifest_env_id, metadata_env_id=metadata_env_id, base_env_id=base_env_id, - rollout_index=rollout_index, - model_id=model_id, + rollout_index=identity.rollout_index or 0, + model_id=identity.model_id, metadata_model=metadata_model, env_args=env_args, sampling_args=sampling_args, @@ -164,4 +193,4 @@ def _extract_env_config_id(env_config: Mapping[str, Any] | None) -> str | None: return None -__all__ = ["NormalizedMetadata", "load_normalized_metadata"] +__all__ = ["NormalizedMetadata", "RunIdentity", "load_normalized_metadata"] diff --git a/medarc_verifiers/cli/process/pipeline.py b/medarc_verifiers/cli/process/pipeline.py index 36609ae5..a9dbdc46 100644 --- a/medarc_verifiers/cli/process/pipeline.py +++ b/medarc_verifiers/cli/process/pipeline.py @@ -1,4 +1,4 @@ -"""Top-level pipeline wiring discovery, row loading, aggregation, and writing.""" +"""Top-level pipeline wiring discovery, selection, row loading, aggregation, and writing.""" from __future__ import annotations @@ -10,20 +10,14 @@ from pathlib import Path from typing import Any, Callable, Iterable, Mapping, Sequence -from medarc_verifiers.cli._schemas import EnvironmentExportConfig +import pyarrow.parquet as pq + from medarc_verifiers.cli import hf as hf_sync -from medarc_verifiers.cli.process import ( - aggregate, - discovery, - env_index, - metadata, - rows, - rollout, - writer, - workspace, -) -from medarc_verifiers.cli.process.aggregate import AggregatedEnvRows +from medarc_verifiers.cli._schemas import EnvironmentExportConfig from medarc_verifiers.cli.hf import HFSyncConfig, HFSyncSummary +from medarc_verifiers.cli.process import aggregate, discovery, env_index, metadata, rollout, rows, workspace, writer +from medarc_verifiers.cli.process.aggregate import AggregatedEnvRows +from medarc_verifiers.cli.process.metadata import RunIdentity from medarc_verifiers.cli.process.writer import EnvWriteSummary, WriterConfig from medarc_verifiers.cli.utils.shared import ( dataset_is_excluded, @@ -44,6 +38,8 @@ class ProcessOptions: only_complete_runs: bool = True exclude_datasets: Sequence[str] = field(default_factory=tuple) exclude_models: Sequence[str] = field(default_factory=tuple) + replace_models: Sequence[str] = field(default_factory=tuple) + replace_envs: Sequence[str] = field(default_factory=tuple) processed_at: str | None = None processed_with_args: Mapping[str, Any] = field(default_factory=dict) status_filter: Sequence[str] = field(default_factory=tuple) @@ -63,6 +59,8 @@ def __post_init__(self) -> None: self.status_filter = tuple(str(status) for status in self.status_filter) self.exclude_datasets = tuple(str(value) for value in self.exclude_datasets if str(value).strip()) self.exclude_models = tuple(str(value) for value in self.exclude_models if str(value).strip()) + self.replace_models = tuple(str(value) for value in self.replace_models if str(value).strip()) + self.replace_envs = tuple(str(value) for value in self.replace_envs if str(value).strip()) @dataclass(slots=True) @@ -76,8 +74,8 @@ class ProcessResult: hf_summary: HFSyncSummary | None -@dataclass(slots=True) -class _RecordWork: +@dataclass(frozen=True, slots=True) +class PlannedRecord: """Per-record settings for row loading.""" normalized: metadata.NormalizedMetadata @@ -86,26 +84,24 @@ class _RecordWork: answer_column: str | None -@dataclass(slots=True) -class _NormalizedRecord: - record: discovery.RunRecord - normalized: metadata.NormalizedMetadata - extra_columns: Sequence[str] - drop_columns: Sequence[str] - answer_column: str | None - model_key: str - env_key: str - job_run_id: str - run_timestamp: str +@dataclass(frozen=True, slots=True) +class PlannedWorkItem: + """A single selected (model, env) output to process.""" + identity: RunIdentity + records: list[PlannedRecord] + env_export_config: EnvironmentExportConfig -@dataclass(slots=True) -class _EnvGroupSelection: - model_key: str - env_key: str - job_run_id: str - run_timestamp: str - records: list[_NormalizedRecord] + +@dataclass(frozen=True, slots=True) +class SelectionResult: + """Complete output of the selection phase.""" + + work_items: list[PlannedWorkItem] + skipped_incomplete: int + skipped_by_delta: int + skipped_by_exclusion: int + total_discovered: int def run_process( @@ -134,84 +130,42 @@ def _run_pipeline() -> ProcessResult: prompt_func=input, ) - index_version, index_runs = env_index.read_env_index_runs(options.output_dir) - index_files = env_index.read_env_index_files(options.output_dir) - if options.clean: - index_version = 0 - index_runs = {} - index_files = {} - + index_files = {} if options.clean else env_index.read_env_index_files(options.output_dir) discovered = discovery.discover_run_records( options.runs_dir, filter_status=options.status_filter or None, only_complete_runs=False, ) - - use_delta = index_version == 2 and not options.clean - if index_version != 2 and not options.clean: - logger.info("Delta processing disabled: missing or legacy env_index.json; running full reprocess.") - records: list[discovery.RunRecord] = list(discovered) - if options.only_complete_runs: - records = [ - record - for record in records - if not ( - record.manifest.summary_total_known - and record.manifest.summary_completed != record.manifest.summary_total - ) - ] - normalized_records = _normalize_records(records, env_export_map) - env_groups = _select_latest_env_groups(normalized_records) - if use_delta: - env_groups = _filter_env_groups_by_delta( - env_groups, - index_runs, - index_files, - output_dir=options.output_dir, - ) - if options.exclude_datasets: - env_groups = _filter_env_groups_by_exclusion(env_groups, options.exclude_datasets) - if options.exclude_models: - env_groups = _filter_env_groups_by_model_exclusion(env_groups, options.exclude_models) - records = [item.record for group in env_groups for item in group.records] - + selection = select_work_items( + discovered, + options=options, + env_export_map=env_export_map, + index_files=index_files, + ) + selected_records = [planned.normalized.record for item in selection.work_items for planned in item.records] _print_records_table( discovered, - records, + selected_records, options.only_complete_runs, exclude_datasets=options.exclude_datasets, exclude_models=options.exclude_models, + skipped_incomplete=selection.skipped_incomplete, + skipped_by_delta=selection.skipped_by_delta, + skipped_by_exclusion=selection.skipped_by_exclusion, ) - grouped: dict[tuple[str, str], list[_RecordWork]] = {} run_metadata: dict[str, dict[str, Any]] = {} - record_items = [item for group in env_groups for item in group.records] - record_iter: Iterable[_NormalizedRecord] = record_items - try: - from rich.progress import track - - record_iter = track(record_items, description="Reading run outputs", transient=True) - except Exception: - pass - - for record in record_iter: - normalized = record.normalized - grouped.setdefault((record.model_key, record.env_key), []).append( - _RecordWork( - normalized=normalized, - extra_columns=record.extra_columns, - drop_columns=record.drop_columns, - answer_column=record.answer_column, + for item in selection.work_items: + for planned in item.records: + record = planned.normalized.record + run_metadata.setdefault( + record.manifest.job_run_id, + { + "created_at": record.manifest.created_at, + "updated_at": _source_updated_at(record), + "config_checksum": record.manifest.config_checksum, + }, ) - ) - run_metadata.setdefault( - record.job_run_id, - { - "created_at": record.record.manifest.created_at, - "updated_at": _source_updated_at(record.record), - "config_checksum": record.record.manifest.config_checksum, - }, - ) writer_config = WriterConfig( output_dir=options.output_dir, @@ -223,20 +177,20 @@ def _run_pipeline() -> ProcessResult: env_groups: list[AggregatedEnvRows] = [] env_summaries: list[EnvWriteSummary] = [] rows_processed = 0 + work_items = sorted(selection.work_items, key=lambda item: (item.identity.model_id, item.identity.output_env_id)) - env_items = sorted(grouped.items()) try: - if options.max_workers <= 1 or len(env_items) <= 1: - env_iter: Iterable[tuple[tuple[str, str], list[_RecordWork]]] = env_items + if options.max_workers <= 1 or len(work_items) <= 1: + work_iter: Iterable[PlannedWorkItem] = work_items try: from rich.progress import track - env_iter = track(env_items, description="Processing datasets", transient=True) + work_iter = track(work_items, description="Processing datasets", transient=True) except Exception: - env_iter = env_items + work_iter = work_items - for _, work_items in env_iter: - aggregated, row_count = _process_env_group(work_items) + for item in work_iter: + aggregated, row_count = _process_env_group(item) rows_processed += row_count env_groups.extend(aggregated) summaries = writer.write_env_groups(aggregated, writer_config, write_index=False) @@ -249,8 +203,8 @@ def _run_pipeline() -> ProcessResult: futures = [] try: executor = ProcessPoolExecutor(max_workers=options.max_workers) - for _, work_items in env_items: - futures.append(executor.submit(_process_env_group, work_items)) + for item in work_items: + futures.append(executor.submit(_process_env_group, item)) future_iter: Iterable[Any] = as_completed(futures) try: @@ -273,8 +227,8 @@ def _run_pipeline() -> ProcessResult: group.rows.clear() except KeyboardInterrupt: logger.warning("Processing cancelled by user; shutting down workers.") - for f in futures: - f.cancel() + for future in futures: + future.cancel() if executor is not None: executor.shutdown(cancel_futures=True) raise @@ -310,7 +264,7 @@ def _run_pipeline() -> ProcessResult: env_groups = [_strip_env_group_rows(group) for group in env_groups] return ProcessResult( - records_processed=len(records), + records_processed=len(selected_records), rows_processed=rows_processed, env_groups=env_groups, env_summaries=env_summaries, @@ -323,24 +277,231 @@ def _run_pipeline() -> ProcessResult: return _run_pipeline() +def select_work_items( + discovered: Sequence[discovery.RunRecord], + *, + options: ProcessOptions, + env_export_map: Mapping[str, EnvironmentExportConfig], + index_files: Mapping[str, Mapping[str, Any]], +) -> SelectionResult: + """Filter discovered runs down to selected work items before row loading begins.""" + eligible_records: list[discovery.RunRecord] = [] + skipped_incomplete = 0 + for record in discovered: + if options.only_complete_runs and not _manifest_is_complete(record.manifest): + skipped_incomplete += 1 + continue + eligible_records.append(record) + + planned_records = [_plan_record(record, env_export_map) for record in eligible_records] + work_items = _select_latest_work_items(planned_records) + + work_items, skipped_by_exclusion = _apply_exclusions( + work_items, + exclude_datasets=options.exclude_datasets, + exclude_models=options.exclude_models, + ) + _validate_replace_targets(work_items, options) + work_items, skipped_by_delta = _apply_additive_delta(work_items, options=options, index_files=index_files) + + return SelectionResult( + work_items=work_items, + skipped_incomplete=skipped_incomplete, + skipped_by_delta=skipped_by_delta, + skipped_by_exclusion=skipped_by_exclusion, + total_discovered=len(discovered), + ) + + def _resolve_env_export( manifest_env_id: str | None, env_export_map: Mapping[str, EnvironmentExportConfig], -) -> EnvironmentExportConfig | None: +) -> EnvironmentExportConfig: if not manifest_env_id: - return None + return EnvironmentExportConfig() if manifest_env_id in env_export_map: return env_export_map[manifest_env_id] base_env_id, _ = rollout.derive_base_env_id(manifest_env_id) if base_env_id and base_env_id in env_export_map: return env_export_map[base_env_id] - return None + return EnvironmentExportConfig() def _resolve_columns(env_columns: Sequence[str]) -> Sequence[str]: return tuple(str(column).strip() for column in env_columns if str(column).strip()) +def _plan_record( + record: discovery.RunRecord, + env_export_map: Mapping[str, EnvironmentExportConfig], +) -> PlannedRecord: + env_export = _resolve_env_export(record.manifest_env_id, env_export_map) + normalized = metadata.load_normalized_metadata(record, combine_rollouts=bool(env_export.combine_rollouts)) + return PlannedRecord( + normalized=normalized, + extra_columns=_resolve_columns(env_export.extra_columns), + drop_columns=_resolve_columns(env_export.drop_columns), + answer_column=env_export.answer_column, + ) + + +def _select_latest_work_items(records: Sequence[PlannedRecord]) -> list[PlannedWorkItem]: + grouped: dict[tuple[str, str], dict[str, list[PlannedRecord]]] = {} + run_timestamps: dict[str, str] = {} + + for planned in records: + identity = planned.normalized.identity + group_key = (identity.model_id, identity.output_env_id) + grouped.setdefault(group_key, {}).setdefault(identity.job_run_id, []).append(planned) + run_timestamps.setdefault(identity.job_run_id, _source_updated_at(planned.normalized.record)) + + selected: list[PlannedWorkItem] = [] + for _, run_groups in grouped.items(): + latest_run_id = max(run_groups.keys(), key=lambda run_id: _run_sort_key(run_timestamps.get(run_id, ""), run_id)) + latest_records = run_groups[latest_run_id] + representative = latest_records[0] + selected.append( + PlannedWorkItem( + identity=representative.normalized.identity, + records=list(latest_records), + env_export_config=EnvironmentExportConfig(), + ) + ) + return selected + + +def _apply_exclusions( + work_items: Sequence[PlannedWorkItem], + *, + exclude_datasets: Sequence[str], + exclude_models: Sequence[str], +) -> tuple[list[PlannedWorkItem], int]: + exclude_dataset_set = normalize_dataset_ids(exclude_datasets, label="process exclude dataset") + exclude_model_set = normalize_model_ids(exclude_models, label="process exclude model") + filtered: list[PlannedWorkItem] = [] + skipped = 0 + for item in work_items: + if exclude_dataset_set and _env_is_excluded(item.identity.output_env_id, exclude_dataset_set): + skipped += 1 + continue + if exclude_model_set and model_is_excluded(item.identity.model_id, exclude_model_set): + skipped += 1 + continue + filtered.append(item) + return filtered, skipped + + +def _validate_replace_targets(work_items: Sequence[PlannedWorkItem], options: ProcessOptions) -> None: + if not options.replace_models and not options.replace_envs: + return + + if options.replace_models: + matched_models = {item.identity.model_id for item in work_items if item.identity.model_id in options.replace_models} + if not matched_models: + raise RuntimeError( + "No selected processed outputs match --replace-model values: " + f"{', '.join(sorted(options.replace_models))}." + ) + if options.replace_envs: + matched_envs = {item.identity.output_env_id for item in work_items if item.identity.output_env_id in options.replace_envs} + if not matched_envs: + raise RuntimeError( + "No selected processed outputs match --replace-env values: " + f"{', '.join(sorted(options.replace_envs))}." + ) + if options.replace_models and options.replace_envs: + intersection = [ + item + for item in work_items + if item.identity.model_id in options.replace_models and item.identity.output_env_id in options.replace_envs + ] + if not intersection: + raise RuntimeError( + "No selected processed outputs match the intersection of --replace-model and --replace-env." + ) + + +def _apply_additive_delta( + work_items: Sequence[PlannedWorkItem], + *, + options: ProcessOptions, + index_files: Mapping[str, Mapping[str, Any]], +) -> tuple[list[PlannedWorkItem], int]: + if options.clean: + return list(work_items), 0 + + filtered: list[PlannedWorkItem] = [] + skipped = 0 + for item in work_items: + output_path = writer.build_output_path( + options.output_dir, + model_id=item.identity.model_id, + env_id=item.identity.output_env_id, + ) + if not output_path.exists(): + filtered.append(item) + continue + if _should_replace_existing_output(item.identity, options): + filtered.append(item) + continue + _validate_existing_output_integrity(output_path, output_dir=options.output_dir, index_files=index_files) + skipped += 1 + return filtered, skipped + + +def _should_replace_existing_output(identity: RunIdentity, options: ProcessOptions) -> bool: + if options.clean: + return True + has_model_filter = bool(options.replace_models) + has_env_filter = bool(options.replace_envs) + if not has_model_filter and not has_env_filter: + return False + if has_model_filter and has_env_filter: + return identity.model_id in options.replace_models and identity.output_env_id in options.replace_envs + if has_model_filter: + return identity.model_id in options.replace_models + return identity.output_env_id in options.replace_envs + + +def _validate_existing_output_integrity( + output_path: Path, + *, + output_dir: Path, + index_files: Mapping[str, Mapping[str, Any]], +) -> None: + try: + metadata_obj = pq.ParquetFile(output_path).metadata + except Exception as exc: # noqa: BLE001 + raise RuntimeError( + f"Existing processed output {output_path} is unreadable. " + "Rebuild it with --replace-model/--replace-env or re-run with --clean." + ) from exc + + if metadata_obj is None: + raise RuntimeError( + f"Existing processed output {output_path} is missing parquet footer metadata. " + "Rebuild it with --replace-model/--replace-env or re-run with --clean." + ) + + rel_key = output_path.relative_to(output_dir).as_posix() + index_entry = index_files.get(rel_key) + if not isinstance(index_entry, Mapping): + return + expected_row_count = index_entry.get("row_count") + if expected_row_count is None: + return + try: + expected = int(expected_row_count) + except (TypeError, ValueError): + return + actual = int(metadata_obj.num_rows) + if actual != expected: + raise RuntimeError( + f"Existing processed output {output_path} has {actual} parquet rows but env_index.json records {expected}. " + "Rebuild it with --replace-model/--replace-env or re-run with --clean." + ) + + def _print_records_table( discovered: Sequence[discovery.RunRecord], selected: Sequence[discovery.RunRecord], @@ -348,6 +509,9 @@ def _print_records_table( *, exclude_datasets: Sequence[str] = (), exclude_models: Sequence[str] = (), + skipped_incomplete: int = 0, + skipped_by_delta: int = 0, + skipped_by_exclusion: int = 0, ) -> None: """Pretty-print job discovery vs planned processing.""" exclude_set = normalize_dataset_ids(exclude_datasets, label="process exclude dataset") @@ -373,47 +537,53 @@ def _print_records_table( selected_by_model[model_id] = selected_by_model.get(model_id, 0) + 1 models = sorted(set(total_by_model.keys()) | set(selected_by_model.keys())) - selected_models = sorted(m for m, c in selected_by_model.items() if c > 0) - discovered_jobs_total = sum(total_by_model.get(m, 0) for m in models) - selected_jobs_total = sum(selected_by_model.get(m, 0) for m in models) + selected_models = sorted(model_id for model_id, count in selected_by_model.items() if count > 0) + discovered_jobs_total = sum(total_by_model.get(model_id, 0) for model_id in models) + selected_jobs_total = sum(selected_by_model.get(model_id, 0) for model_id in models) try: from rich.console import Console - from rich.table import Table from rich.markup import escape + from rich.table import Table except Exception: suffix = " (complete runs only)" if only_complete_runs else "" logger.info( - "Processing %d job(s) across %d model(s)%s (found %d job(s) across %d model(s)).", + "Processing %d job(s) across %d model(s)%s (found %d job(s) across %d model(s)); " + "skipped incomplete=%d excluded=%d existing=%d.", selected_jobs_total, len(selected_models), suffix, discovered_jobs_total, len(models), + skipped_incomplete, + skipped_by_exclusion, + skipped_by_delta, ) for model_id in models: - comp = completed_by_model.get(model_id, 0) - tot = total_by_model.get(model_id, 0) - sel = selected_by_model.get(model_id, 0) - logger.info(" - %s: selected=%d; %d/%d completed", model_id, sel, comp, tot) + completed = completed_by_model.get(model_id, 0) + total = total_by_model.get(model_id, 0) + selected_count = selected_by_model.get(model_id, 0) + logger.info(" - %s: selected=%d; %d/%d completed", model_id, selected_count, completed, total) return console = Console() title = f"Processing {selected_jobs_total} job(s) across {len(selected_models)} model(s)" if only_complete_runs: title += " (complete runs only)" - found_suffix = "after filters" if (exclude_set or only_complete_runs) else "pre-aggregation" - title += f" [dim](found {discovered_jobs_total} job(s) across {len(models)} model(s); {found_suffix})[/dim]" + title += ( + f" [dim](found {discovered_jobs_total} eligible job(s); skipped incomplete={skipped_incomplete}, " + f"excluded={skipped_by_exclusion}, existing={skipped_by_delta})[/dim]" + ) table = Table(title=title, show_header=True, header_style="bold cyan", caption=None) table.add_column("Model", style="magenta") table.add_column("Jobs (completed/total)", style="green", justify="right") table.add_column("Selected", style="cyan", justify="right") for model_id in models: - comp = completed_by_model.get(model_id, 0) - tot = total_by_model.get(model_id, 0) - sel = selected_by_model.get(model_id, 0) - table.add_row(escape(str(model_id)), f"{comp}/{tot}", str(sel)) + completed = completed_by_model.get(model_id, 0) + total = total_by_model.get(model_id, 0) + selected_count = selected_by_model.get(model_id, 0) + table.add_row(escape(str(model_id)), f"{completed}/{total}", str(selected_count)) console.print(table) @@ -434,29 +604,23 @@ def _record_is_excluded(record: discovery.RunRecord, exclude_set: set[str]) -> b def _record_model_is_excluded(record: discovery.RunRecord, exclude_model_set: set[str]) -> bool: - model_id = str(record.model_id or "").strip() - return model_is_excluded(model_id, exclude_model_set) - - -__all__ = ["ProcessOptions", "ProcessResult", "run_process"] + return model_is_excluded(str(record.model_id or "").strip(), exclude_model_set) -def _process_env_group( - work_items: Sequence[_RecordWork], -) -> tuple[list[AggregatedEnvRows], int]: - """Load and aggregate all rows for a single environment.""" +def _process_env_group(item: PlannedWorkItem) -> tuple[list[AggregatedEnvRows], int]: + """Load and aggregate all rows for a single selected dataset.""" row_buffer: list[dict[str, Any]] = [] - for work in work_items: + identities: list[RunIdentity] = [] + for planned in item.records: row_batch = rows.load_rows( - work.normalized, - extra_columns=work.extra_columns, - drop_columns=work.drop_columns, - answer_column=work.answer_column, + planned.normalized, + extra_columns=planned.extra_columns, + drop_columns=planned.drop_columns, + answer_column=planned.answer_column, ) row_buffer.extend(row_batch) - aggregated = aggregate.aggregate_rows_by_env( - row_buffer, - ) + identities.append(planned.normalized.identity) + aggregated = aggregate.aggregate_rows_by_env(row_buffer, identities=identities) return aggregated, len(row_buffer) @@ -464,81 +628,12 @@ def _source_updated_at(record: discovery.RunRecord) -> str: return record.manifest.updated_at or record.manifest.created_at or "" -def _filter_env_groups_by_delta( - env_groups: Sequence[_EnvGroupSelection], - index_runs: Mapping[str, Mapping[str, Any]], - index_files: Mapping[str, Mapping[str, Any]], - *, - output_dir: Path, -) -> list[_EnvGroupSelection]: - filtered: list[_EnvGroupSelection] = [] - for group in env_groups: - expected_path = writer.build_output_path(output_dir, model_id=group.model_key, env_id=group.env_key) - expected_rel = expected_path.relative_to(output_dir).as_posix() - prior_file = index_files.get(expected_rel, {}) - if not prior_file: - filtered.append(group) - continue - prior_updated_at = str(prior_file.get("updated_at") or prior_file.get("created_at") or "") - if group.job_run_id not in index_runs: - filtered.append(group) - continue - if _is_newer_timestamp(group.run_timestamp, prior_updated_at): - filtered.append(group) - continue - return filtered - - -def _filter_env_groups_by_exclusion( - env_groups: Sequence[_EnvGroupSelection], - exclude_datasets: Sequence[str], -) -> list[_EnvGroupSelection]: - exclude_set = normalize_dataset_ids(exclude_datasets, label="process exclude dataset") - if not exclude_set: - return list(env_groups) - filtered: list[_EnvGroupSelection] = [] - for group in env_groups: - if _env_is_excluded(str(group.env_key or ""), exclude_set): - continue - filtered.append(group) - return filtered - - -def _filter_env_groups_by_model_exclusion( - env_groups: Sequence[_EnvGroupSelection], - exclude_models: Sequence[str], -) -> list[_EnvGroupSelection]: - exclude_set = normalize_model_ids(exclude_models, label="process exclude model") - if not exclude_set: - return list(env_groups) - filtered: list[_EnvGroupSelection] = [] - for group in env_groups: - model_id = str(group.model_key or "").strip() - if model_is_excluded(model_id, exclude_set): - continue - filtered.append(group) - return filtered - - def _env_is_excluded(env_id: str, exclude_set: set[str]) -> bool: env_identifier = str(env_id or "").strip() base_env_id, _ = rollout.derive_base_env_id(env_identifier) return dataset_is_excluded(env_identifier, exclude_set, base_dataset_id=base_env_id) -def _is_newer_timestamp(current: str, prior: str) -> bool: - if not prior: - return True if current else False - if not current: - return False - try: - current_dt = datetime.fromisoformat(current.replace("Z", "+00:00")) - prior_dt = datetime.fromisoformat(prior.replace("Z", "+00:00")) - except Exception: - return current != prior - return current_dt > prior_dt - - def _strip_env_group_rows(group: AggregatedEnvRows) -> AggregatedEnvRows: return AggregatedEnvRows( env_id=group.env_id, @@ -550,72 +645,6 @@ def _strip_env_group_rows(group: AggregatedEnvRows) -> AggregatedEnvRows: ) -def _normalize_records( - records: Sequence[discovery.RunRecord], - env_export_map: Mapping[str, EnvironmentExportConfig], -) -> list[_NormalizedRecord]: - normalized_records: list[_NormalizedRecord] = [] - for record in records: - env_export = _resolve_env_export(record.manifest_env_id, env_export_map) - extra_columns = _resolve_columns(env_export.extra_columns if env_export else ()) - drop_columns = _resolve_columns(env_export.drop_columns if env_export else ()) - answer_column = env_export.answer_column if env_export else None - - normalized = metadata.load_normalized_metadata(record) - model_id = normalized.model_id - if not model_id: - raise RuntimeError( - "Missing model_id for run " - f"(job_run_id={record.manifest.job_run_id}, job_id={record.job_id}, " - f"results_dir={record.results_dir}, manifest={record.manifest.manifest_path})" - ) - - env_key = normalized.base_env_id or normalized.manifest_env_id or record.manifest_env_id or record.job_id - normalized_records.append( - _NormalizedRecord( - record=record, - normalized=normalized, - extra_columns=extra_columns, - drop_columns=drop_columns, - answer_column=answer_column, - model_key=model_id, - env_key=env_key, - job_run_id=record.manifest.job_run_id, - run_timestamp=_source_updated_at(record), - ) - ) - return normalized_records - - -def _select_latest_env_groups( - records: Sequence[_NormalizedRecord], -) -> list[_EnvGroupSelection]: - env_groups: dict[tuple[str, str], dict[str, list[_NormalizedRecord]]] = {} - run_timestamps: dict[str, str] = {} - for record in records: - env_groups.setdefault((record.model_key, record.env_key), {}).setdefault(record.job_run_id, []).append(record) - run_timestamps.setdefault(record.job_run_id, record.run_timestamp) - - selected: list[_EnvGroupSelection] = [] - for (model_key, env_key), run_groups in env_groups.items(): - if not run_groups: - continue - latest_run_id = max( - run_groups.keys(), - key=lambda run_id: _run_sort_key(run_timestamps.get(run_id, ""), run_id), - ) - selected.append( - _EnvGroupSelection( - model_key=model_key, - env_key=env_key, - job_run_id=latest_run_id, - run_timestamp=run_timestamps.get(latest_run_id, ""), - records=run_groups[latest_run_id], - ) - ) - return selected - - def _run_sort_key(timestamp: str, job_run_id: str) -> tuple[int, datetime, str]: if not timestamp: return (0, datetime.min.replace(tzinfo=UTC), job_run_id) @@ -644,3 +673,14 @@ def _confirm_clean_process( raise RuntimeError("Aborted clean process.") from None if response != "clean": raise RuntimeError("Aborted clean process.") + + +__all__ = [ + "PlannedRecord", + "PlannedWorkItem", + "ProcessOptions", + "ProcessResult", + "SelectionResult", + "run_process", + "select_work_items", +] diff --git a/medarc_verifiers/cli/process/rows.py b/medarc_verifiers/cli/process/rows.py index 80d27264..2b4f35f3 100644 --- a/medarc_verifiers/cli/process/rows.py +++ b/medarc_verifiers/cli/process/rows.py @@ -213,19 +213,18 @@ def _attach_metadata( version_info_json: str | None, ) -> MutableMapping[str, Any]: record = metadata.record + identity = metadata.identity error_value = record.reason if record.status == "failed" else None - env_identifier = metadata.base_env_id or metadata.manifest_env_id - row.update( { - "env_id": env_identifier, - "manifest_env_id": metadata.manifest_env_id, - "base_env_id": metadata.base_env_id, + "env_id": identity.output_env_id, + "manifest_env_id": identity.manifest_env_id, + "base_env_id": identity.base_env_id, "job_run_id": record.manifest.job_run_id, "run_id": record.job_id, - "model_id": metadata.model_id, + "model_id": identity.model_id, "version_info": version_info_json, "status": record.status, "error": error_value, diff --git a/tests/test_cli/test_process_pipeline.py b/tests/test_cli/test_process_pipeline.py index 83e3a0b4..a973f391 100644 --- a/tests/test_cli/test_process_pipeline.py +++ b/tests/test_cli/test_process_pipeline.py @@ -127,8 +127,8 @@ def _write_run( "env_variant_id": env_id, "env_args": {}, "results_dir": "demo-job", - } "status": status, + } ], } _write_json(run_dir / "run_manifest.json", manifest) @@ -521,7 +521,7 @@ def test_run_process_empty_runs_returns_result(tmp_path: Path) -> None: assert result.hf_summary is None -def test_process_latest_only_selects_latest_and_delta_skips(tmp_path: Path) -> None: +def test_process_latest_only_selects_latest_and_skips_existing_outputs(tmp_path: Path) -> None: runs_dir = _write_run(tmp_path, run_id="run-1", updated_at="2024-01-01T00:00:00Z", reward=0.1) _write_run(tmp_path, run_id="run-2", updated_at="2024-01-02T00:00:00Z", reward=0.9) output_dir = tmp_path / "processed" @@ -545,6 +545,151 @@ def test_process_latest_only_selects_latest_and_delta_skips(tmp_path: Path) -> N assert result_repeat.env_summaries == [] assert result_repeat.rows_processed == 0 + _write_run(tmp_path, run_id="run-3", updated_at="2024-01-04T00:00:00Z", reward=0.4) + result_newer_raw = run_process(options) + assert result_newer_raw.env_summaries == [] + assert result_newer_raw.rows_processed == 0 + + +def test_process_replace_model_rebuilds_existing_output(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-1", + updated_at="2024-01-01T00:00:00Z", + reward=0.1, + env_id="demo-env", + model_id="model-a", + ) + _write_run( + tmp_path, + run_id="run-2", + updated_at="2024-01-01T00:00:00Z", + reward=0.2, + env_id="demo-env", + model_id="model-b", + ) + output_dir = tmp_path / "processed" + + run_process(ProcessOptions(runs_dir=runs_dir, output_dir=output_dir, dry_run=False, max_workers=1)) + _write_run( + tmp_path, + run_id="run-3", + updated_at="2024-01-03T00:00:00Z", + reward=0.9, + env_id="demo-env", + model_id="model-a", + ) + _write_run( + tmp_path, + run_id="run-4", + updated_at="2024-01-03T00:00:00Z", + reward=0.8, + env_id="demo-env", + model_id="model-b", + ) + + result = run_process( + ProcessOptions( + runs_dir=runs_dir, + output_dir=output_dir, + replace_models=("model-a",), + dry_run=False, + max_workers=1, + ) + ) + + rebuilt = {summary.model_id for summary in result.env_summaries} + assert rebuilt == {"model-a"} + model_a_table = pq.read_table(output_dir / "model-a" / "demo-env.parquet") + model_b_table = pq.read_table(output_dir / "model-b" / "demo-env.parquet") + assert model_a_table.column("reward").to_pylist() == [0.9] + assert model_b_table.column("reward").to_pylist() == [0.2] + + +def test_process_replace_model_and_env_rebuild_only_intersection(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-1", + updated_at="2024-01-01T00:00:00Z", + reward=0.1, + env_id="env-a", + model_id="model-a", + ) + _write_run( + tmp_path, + run_id="run-2", + updated_at="2024-01-01T00:00:00Z", + reward=0.2, + env_id="env-b", + model_id="model-a", + ) + _write_run( + tmp_path, + run_id="run-3", + updated_at="2024-01-01T00:00:00Z", + reward=0.3, + env_id="env-a", + model_id="model-b", + ) + output_dir = tmp_path / "processed" + run_process(ProcessOptions(runs_dir=runs_dir, output_dir=output_dir, dry_run=False, max_workers=1)) + + _write_run( + tmp_path, + run_id="run-4", + updated_at="2024-01-03T00:00:00Z", + reward=0.7, + env_id="env-a", + model_id="model-a", + ) + _write_run( + tmp_path, + run_id="run-5", + updated_at="2024-01-03T00:00:00Z", + reward=0.8, + env_id="env-b", + model_id="model-a", + ) + _write_run( + tmp_path, + run_id="run-6", + updated_at="2024-01-03T00:00:00Z", + reward=0.9, + env_id="env-a", + model_id="model-b", + ) + + result = run_process( + ProcessOptions( + runs_dir=runs_dir, + output_dir=output_dir, + replace_models=("model-a",), + replace_envs=("env-a",), + dry_run=False, + max_workers=1, + ) + ) + + assert {(summary.model_id, summary.env_id) for summary in result.env_summaries} == {("model-a", "env-a")} + assert pq.read_table(output_dir / "model-a" / "env-a.parquet").column("reward").to_pylist() == [0.7] + assert pq.read_table(output_dir / "model-a" / "env-b.parquet").column("reward").to_pylist() == [0.2] + assert pq.read_table(output_dir / "model-b" / "env-a.parquet").column("reward").to_pylist() == [0.3] + + +def test_process_fails_fast_on_existing_row_count_mismatch(tmp_path: Path) -> None: + runs_dir = _setup_run(tmp_path) + output_dir = tmp_path / "processed" + result = run_process(ProcessOptions(runs_dir=runs_dir, output_dir=output_dir, dry_run=False, max_workers=1)) + summary = result.env_summaries[0] + rel_path = summary.output_path.relative_to(output_dir).as_posix() + payload = json.loads((output_dir / "env_index.json").read_text(encoding="utf-8")) + payload["files"][rel_path]["row_count"] = summary.row_count + 1 + (output_dir / "env_index.json").write_text(json.dumps(payload), encoding="utf-8") + + with pytest.raises(RuntimeError, match="env_index.json records"): + run_process(ProcessOptions(runs_dir=runs_dir, output_dir=output_dir, dry_run=False, max_workers=1)) + + def test_process_ignores_invalid_superseded_run(tmp_path: Path) -> None: runs_dir = _write_run( tmp_path, From 0592eaca7c2c6a50fb95cd6d707f8fea6b3a9c1d Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sat, 28 Feb 2026 13:36:17 -0500 Subject: [PATCH 04/29] Add explicit process replace filters --- medarc_verifiers/cli/main.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/medarc_verifiers/cli/main.py b/medarc_verifiers/cli/main.py index 146d29fd..327f5993 100644 --- a/medarc_verifiers/cli/main.py +++ b/medarc_verifiers/cli/main.py @@ -287,6 +287,18 @@ def build_process_parser() -> argparse.ArgumentParser: parser.add_argument("--processed-at", default=None, help="Override processed_at timestamp (ISO8601).") parser.add_argument("--dry-run", action="store_true", default=None, help="Plan processing without writing outputs.") parser.add_argument( + parser.add_argument( + "--replace-model", + action="append", + default=None, + help="Rebuild existing processed outputs for these model ids (repeatable; comma-separated values allowed).", + ) + parser.add_argument( + "--replace-env", + action="append", + default=None, + help="Rebuild existing processed outputs for these env ids (repeatable; comma-separated values allowed).", + ) "--process-incomplete", dest="process_incomplete", action="store_true", @@ -564,6 +576,10 @@ def _run_process_mode(argv: Sequence[str]) -> int: normalize_dataset_ids(args.exclude_dataset, label="process exclude dataset") if args.exclude_model: normalize_model_ids(args.exclude_model, label="process exclude model") + for flag, attr in (("--replace-model", "replace_model"), ("--replace-env", "replace_env")): + if _option_was_provided(argv, flag) and not getattr(args, attr, None): + parser.error(f"{flag} requires at least one non-empty value.") + except ValueError as exc: parser.error(str(exc)) winrate_args: argparse.Namespace | None = None @@ -594,6 +610,8 @@ def _run_process_mode(argv: Sequence[str]) -> int: processed_with_args = { "status": args.status or [], "exclude_datasets": args.exclude_dataset or [], + "replace_models": args.replace_model or [], + "replace_envs": args.replace_env or [], "exclude_models": args.exclude_model or [], "dry_run": bool(args.dry_run), "clean": bool(args.clean), @@ -611,6 +629,8 @@ def _run_process_mode(argv: Sequence[str]) -> int: output_dir=args.output_dir, exclude_datasets=tuple(args.exclude_dataset or ()), exclude_models=tuple(args.exclude_model or ()), + replace_models=tuple(args.replace_model or ()), + replace_envs=tuple(args.replace_env or ()), processed_at=args.processed_at, processed_with_args=processed_with_args, status_filter=args.status or (), @@ -829,6 +849,8 @@ def _load_and_apply_config( "winrate": { "include_models": "include_model", "exclude_models": "exclude_model", + "replace_models": "replace_model", + "replace_envs": "replace_env", "exclude_datasets": "exclude_dataset", }, }[mode] @@ -934,10 +956,16 @@ def _finalize_config_args(args: argparse.Namespace, *, mode: Literal["process", def _upload_winrate_outputs( *, output_dir: Path, + if mode == "process" and hasattr(args, "replace_model"): + args.replace_model = _parse_repeatable_csv(args.replace_model) + if mode == "process" and hasattr(args, "replace_env"): + args.replace_env = _parse_repeatable_csv(args.replace_env) output_paths: Sequence[Path], repo_id: str, token: str | None, branch: str | None, + "replace_model": [], + "replace_env": [], private: bool, winrate_dir: str | None, assume_yes: bool = False, From 958df70dd965126d72ef37155519c7d3b82ca460 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sat, 28 Feb 2026 13:36:55 -0500 Subject: [PATCH 05/29] Split process CLI orchestration --- medarc_verifiers/cli/main.py | 373 ++++++++++++++++++++++++++--------- tests/test_cli/test_main.py | 129 ++++++++---- 2 files changed, 378 insertions(+), 124 deletions(-) diff --git a/medarc_verifiers/cli/main.py b/medarc_verifiers/cli/main.py index 327f5993..d79b601b 100644 --- a/medarc_verifiers/cli/main.py +++ b/medarc_verifiers/cli/main.py @@ -25,7 +25,6 @@ DEFAULT_ENV_DIR, DEFAULT_PROCESSED_DIR, DEFAULT_RUNS_RAW_DIR, - DEFAULT_WINRATE_DIR, PROCESS_COMMAND, WINRATE_COMMAND, ) @@ -286,7 +285,6 @@ def build_process_parser() -> argparse.ArgumentParser: ) parser.add_argument("--processed-at", default=None, help="Override processed_at timestamp (ISO8601).") parser.add_argument("--dry-run", action="store_true", default=None, help="Plan processing without writing outputs.") - parser.add_argument( parser.add_argument( "--replace-model", action="append", @@ -299,6 +297,7 @@ def build_process_parser() -> argparse.ArgumentParser: default=None, help="Rebuild existing processed outputs for these env ids (repeatable; comma-separated values allowed).", ) + parser.add_argument( "--process-incomplete", dest="process_incomplete", action="store_true", @@ -309,7 +308,7 @@ def build_process_parser() -> argparse.ArgumentParser: "--winrate", type=Path, default=None, - help="Run winrate after processing using the provided winrate config file.", + help="Run winrate after processing using the provided config file. If omitted, an embedded winrate section in --config is used.", ) parser.add_argument( "--max-workers", @@ -376,7 +375,7 @@ def build_winrate_parser() -> argparse.ArgumentParser: "--output-dir", type=Path, default=None, - help=f"Directory to store winrate outputs (default: {DEFAULT_WINRATE_DIR}).", + help="Directory to store winrate outputs (default: /winrate).", ) parser.add_argument( "--output", @@ -565,37 +564,57 @@ def _run_batch_mode(argv: Sequence[str]) -> int: def _run_process_mode(argv: Sequence[str]) -> int: + parser, args = _resolve_process_args(argv) + winrate_args = _resolve_embedded_winrate(args, parser=parser) + + try: + env_export_map = _load_env_export_map(args.env_config_root) + except Exception as exc: # noqa: BLE001 + logger.warning("Failed to load environment export configs: %s", exc) + env_export_map = {} + + options = _build_process_options(args) + + try: + result = run_process(options, env_export_map=env_export_map) + except Exception as exc: # noqa: BLE001 + logger.exception("Process pipeline failed: %s", exc) + return 1 + + _log_process_result(result) + return _run_process_post_steps(args, parser=parser, options=options, winrate_args=winrate_args) + + +def _resolve_process_args(argv: Sequence[str]) -> tuple[argparse.ArgumentParser, argparse.Namespace]: parser = build_process_parser() args = parser.parse_args(argv) if args.config: _load_and_apply_config(args, args.config, mode="process", parser=parser) _finalize_config_args(args, mode="process") + _validate_process_args(args, argv=argv, parser=parser) + return parser, args + + +def _validate_process_args( + args: argparse.Namespace, + *, + argv: Sequence[str], + parser: argparse.ArgumentParser, +) -> None: + for flag, attr in (("--replace-model", "replace_model"), ("--replace-env", "replace_env")): + if _option_was_provided(argv, flag) and not getattr(args, attr, None): + parser.error(f"{flag} requires at least one non-empty value.") try: if args.exclude_dataset: normalize_dataset_ids(args.exclude_dataset, label="process exclude dataset") if args.exclude_model: normalize_model_ids(args.exclude_model, label="process exclude model") - for flag, attr in (("--replace-model", "replace_model"), ("--replace-env", "replace_env")): - if _option_was_provided(argv, flag) and not getattr(args, attr, None): - parser.error(f"{flag} requires at least one non-empty value.") - except ValueError as exc: parser.error(str(exc)) - winrate_args: argparse.Namespace | None = None - if args.winrate: - winrate_path = Path(args.winrate).expanduser() - if not winrate_path.exists(): - parser.error(f"Winrate config path '{winrate_path}' does not exist.") - args.winrate = winrate_path - winrate_args = _build_winrate_args_from_config(winrate_path, parser=parser) - try: - env_export_map = _load_env_export_map(args.env_config_root) - except Exception as exc: # noqa: BLE001 - logger.warning("Failed to load environment export configs: %s", exc) - env_export_map = {} +def _build_process_options(args: argparse.Namespace) -> ProcessOptions: hf_config = HFSyncConfig.from_cli( repo=args.hf_repo, branch=args.hf_branch, @@ -606,13 +625,12 @@ def _run_process_mode(argv: Sequence[str]) -> int: retries=args.hf_retries, max_files_per_commit=args.hf_max_files_per_commit, ) - processed_with_args = { "status": args.status or [], "exclude_datasets": args.exclude_dataset or [], + "exclude_models": args.exclude_model or [], "replace_models": args.replace_model or [], "replace_envs": args.replace_env or [], - "exclude_models": args.exclude_model or [], "dry_run": bool(args.dry_run), "clean": bool(args.clean), "only_complete_runs": not bool(args.process_incomplete), @@ -623,8 +641,7 @@ def _run_process_mode(argv: Sequence[str]) -> int: "hf_max_files_per_commit": args.hf_max_files_per_commit, "max_workers": args.max_workers, } - - options = ProcessOptions( + return ProcessOptions( runs_dir=args.runs_dir, output_dir=args.output_dir, exclude_datasets=tuple(args.exclude_dataset or ()), @@ -643,65 +660,94 @@ def _run_process_mode(argv: Sequence[str]) -> int: max_workers=args.max_workers, ) + +def _resolve_embedded_winrate( + args: argparse.Namespace, + *, + parser: argparse.ArgumentParser, +) -> argparse.Namespace | None: + embedded_winrate = False + if args.config and args.winrate is None: + try: + embedded_winrate = _config_has_embedded_winrate(Path(args.config).expanduser()) + except (FileNotFoundError, ValueError) as exc: + parser.error(str(exc)) + + if args.winrate: + winrate_path = Path(args.winrate).expanduser() + if not winrate_path.exists(): + parser.error(f"Winrate config path '{winrate_path}' does not exist.") + args.winrate = winrate_path + return _build_winrate_args_from_config(winrate_path, parser=parser) + + if embedded_winrate: + args.winrate = Path(args.config).expanduser() + return _build_winrate_args_from_config(Path(args.config).expanduser(), parser=parser) + return None + + +def _run_process_post_steps( + args: argparse.Namespace, + *, + parser: argparse.ArgumentParser, + options: ProcessOptions, + winrate_args: argparse.Namespace | None, +) -> int: + if not args.winrate: + return 0 + if options.dry_run: + logger.info("Skipping winrate post-step for dry-run process.") + return 0 + + if winrate_args is None: + winrate_args = _build_winrate_args_from_config(Path(args.winrate), parser=parser) + winrate_args.processed_dir = options.output_dir + if not getattr(winrate_args, "_output_dir_explicit", False): + winrate_args.output_dir = _default_winrate_output_dir(options.output_dir) + winrate_args.hf_repo = None + winrate_args.hf_processed_pull = False + + winrate_cfg = WinrateConfig( + missing_policy=winrate_args.missing_policy, + epsilon=winrate_args.epsilon, + min_common=winrate_args.min_common, + weight_policy=winrate_args.weight_policy, + weight_cap=winrate_args.weight_cap, + dataset_coverage=winrate_args.dataset_coverage, + include_models=tuple(winrate_args.include_model or ()), + exclude_models=tuple(winrate_args.exclude_model or ()), + exclude_datasets=tuple(winrate_args.exclude_dataset or ()), + partial_datasets=winrate_args.partial_datasets, + ) try: - result = run_process(options, env_export_map=env_export_map) + winrate_result = run_winrate( + processed_dir=options.output_dir, + output_dir=winrate_args.output_dir, + output_path=winrate_args.output, + output_name=winrate_args.output_name, + config=winrate_cfg, + processed_at=winrate_args.processed_at, + hf_config=None, + hf_processed_pull=False, + ) except Exception as exc: # noqa: BLE001 - logger.exception("Process pipeline failed: %s", exc) + logger.exception("Win rate computation failed: %s", exc) return 1 - _log_process_result(result) + logger.info("Computed win rates for %d dataset(s): %s", len(winrate_result.datasets), winrate_result.output_path) + print_winrate_summary_markdown(winrate_result.result) - if args.winrate: - if options.dry_run: - logger.info("Skipping winrate post-step for dry-run process.") - return 0 - if winrate_args is None: - winrate_args = _build_winrate_args_from_config(Path(args.winrate), parser=parser) - winrate_args.processed_dir = options.output_dir - winrate_args.hf_repo = None - winrate_args.hf_processed_pull = False - winrate_cfg = WinrateConfig( - missing_policy=winrate_args.missing_policy, - epsilon=winrate_args.epsilon, - min_common=winrate_args.min_common, - weight_policy=winrate_args.weight_policy, - weight_cap=winrate_args.weight_cap, - dataset_coverage=winrate_args.dataset_coverage, - include_models=tuple(winrate_args.include_model or ()), - exclude_models=tuple(winrate_args.exclude_model or ()), - exclude_datasets=tuple(winrate_args.exclude_dataset or ()), - partial_datasets=winrate_args.partial_datasets, - ) - try: - winrate_result = run_winrate( - processed_dir=options.output_dir, - output_dir=winrate_args.output_dir, - output_path=winrate_args.output, - output_name=winrate_args.output_name, - config=winrate_cfg, - processed_at=winrate_args.processed_at, - hf_config=None, - hf_processed_pull=False, - ) - except Exception as exc: # noqa: BLE001 - logger.exception("Win rate computation failed: %s", exc) - return 1 - logger.info( - "Computed win rates for %d dataset(s): %s", len(winrate_result.datasets), winrate_result.output_path + if options.hf_config and options.hf_config.repo_id: + _upload_winrate_outputs( + output_dir=winrate_args.output_dir, + output_paths=winrate_result.output_paths, + repo_id=options.hf_config.repo_id, + token=options.hf_config.token, + branch=options.hf_config.branch, + private=bool(options.hf_config.private), + winrate_dir=winrate_args.hf_winrate_dir, + assume_yes=bool(args.yes), ) - print_winrate_summary_markdown(winrate_result.result) - if options.hf_config and options.hf_config.repo_id: - _upload_winrate_outputs( - output_dir=winrate_args.output_dir, - output_paths=winrate_result.output_paths, - repo_id=options.hf_config.repo_id, - token=options.hf_config.token, - branch=options.hf_config.branch, - private=bool(options.hf_config.private), - winrate_dir=winrate_args.hf_winrate_dir, - assume_yes=bool(args.yes), - ) - return 0 @@ -741,7 +787,147 @@ def _set_if_unset(args: argparse.Namespace, attr: str, value: Any) -> None: def _load_config_payload(path: Path, *, mode: Literal["process", "winrate"]) -> dict[str, Any]: label = "Process config" if mode == "process" else "Winrate config" - return dict(load_mapping_file(path, label=label)) + raw_payload = dict(load_mapping_file(path, label=label)) + return _expand_embedded_pipeline_config(raw_payload, mode=mode) + + +def _expand_embedded_pipeline_config(payload: dict[str, Any], *, mode: Literal["process", "winrate"]) -> dict[str, Any]: + expanded = dict(payload) + process_section = payload.get("process") + if isinstance(process_section, Mapping): + _merge_process_section(expanded, process_section, mode=mode) + + process_output_dir = _resolve_processed_dir_from_payload(expanded, mode=mode) + + winrate_section = payload.get("winrate") + if isinstance(winrate_section, Mapping): + if mode == "process": + expanded.pop("winrate", None) + if mode == "winrate": + _merge_winrate_section(expanded, winrate_section, process_output_dir=process_output_dir) + elif isinstance(winrate_section, bool) and mode == "process": + expanded.pop("winrate", None) + + if mode == "winrate" and "processed_dir" not in expanded and process_output_dir is not None: + expanded["processed_dir"] = process_output_dir + + return expanded + + +def _merge_process_section( + expanded: dict[str, Any], + process_section: Mapping[str, Any], + *, + mode: Literal["process", "winrate"], +) -> None: + resolved = None + if "dir" in process_section: + resolved = _resolve_process_dir_value(process_section["dir"], runs_dir=expanded.get("runs_dir")) + if mode == "process" and "output_dir" not in expanded and resolved is not None: + expanded["output_dir"] = resolved + if mode == "winrate" and "processed_dir" not in expanded and resolved is not None: + expanded["processed_dir"] = resolved + if mode == "winrate" and "processed_dir" not in expanded and "output_dir" in process_section: + expanded["processed_dir"] = process_section["output_dir"] + key_map = {"runs_dir": "runs_dir"} + if mode == "process": + key_map.update( + { + "output_dir": "output_dir", + "env_config_root": "env_config_root", + "processed_at": "processed_at", + "status": "status", + "exclude_datasets": "exclude_datasets", + "exclude_models": "exclude_models", + "replace_models": "replace_models", + "replace_envs": "replace_envs", + "dry_run": "dry_run", + "clean": "clean", + "yes": "yes", + "process_incomplete": "process_incomplete", + "max_workers": "max_workers", + } + ) + for key, target in key_map.items(): + if key in process_section and target not in expanded: + expanded[target] = process_section[key] + + +def _merge_winrate_section( + expanded: dict[str, Any], + winrate_section: Mapping[str, Any], + *, + process_output_dir: Path | None, +) -> None: + if "dir" in winrate_section and "output_dir" not in expanded: + resolved = _resolve_winrate_dir_value(winrate_section["dir"], process_output_dir=process_output_dir) + if resolved is not None: + expanded["output_dir"] = resolved + key_map = { + "processed_dir": "processed_dir", + "output_dir": "output_dir", + "output_name": "output_name", + "processed_at": "processed_at", + "missing_policy": "missing_policy", + "epsilon": "epsilon", + "min_common": "min_common", + "weight_policy": "weight_policy", + "weight_cap": "weight_cap", + "dataset_coverage": "dataset_coverage", + "include_model": "include_models", + "include_models": "include_models", + "exclude_model": "exclude_models", + "exclude_models": "exclude_models", + "exclude_dataset": "exclude_datasets", + "exclude_datasets": "exclude_datasets", + "partial_datasets": "partial_datasets", + "hf_processed_pull": "hf_processed_pull", + "hf_winrate_dir": "hf_winrate_dir", + } + for key, target in key_map.items(): + if key in winrate_section and target not in expanded: + expanded[target] = winrate_section[key] + + +def _resolve_processed_dir_from_payload(payload: Mapping[str, Any], *, mode: Literal["process", "winrate"]) -> Path | None: + if "processed_dir" in payload and payload["processed_dir"] is not None: + return Path(str(payload["processed_dir"])) + if mode == "process" and "output_dir" in payload and payload["output_dir"] is not None: + return Path(str(payload["output_dir"])) + process_section = payload.get("process") + if isinstance(process_section, Mapping) and "dir" in process_section: + return _resolve_process_dir_value(process_section["dir"], runs_dir=payload.get("runs_dir")) + return None + + +def _resolve_process_dir_value(value: Any, *, runs_dir: Any | None) -> Path | None: + raw = str(value).strip() + if not raw: + return None + candidate = Path(raw) + if candidate.is_absolute(): + return candidate + runs_base = Path(str(runs_dir)).parent if runs_dir is not None else DEFAULT_RUNS_RAW_DIR.parent + return runs_base / candidate + + +def _resolve_winrate_dir_value(value: Any, *, process_output_dir: Path | None) -> Path | None: + raw = str(value).strip() + if not raw: + return None + candidate = Path(raw) + if candidate.is_absolute(): + return candidate + base = process_output_dir if process_output_dir is not None else DEFAULT_PROCESSED_DIR + return base / candidate + + +def _config_has_embedded_winrate(path: Path) -> bool: + payload = dict(load_mapping_file(path, label="Process config")) + winrate_payload = payload.get("winrate") + if isinstance(winrate_payload, Mapping): + return bool(winrate_payload.get("enabled", True)) + return bool(winrate_payload) if isinstance(winrate_payload, bool) else False def _normalize_mode_payload(payload: dict[str, Any], *, mode: Literal["process", "winrate"]) -> None: @@ -845,12 +1031,16 @@ def _load_and_apply_config( "winrate": {"epsilon": "epsilon"}, }[mode] repeatable_fields = { - "process": {"status": "status", "exclude_datasets": "exclude_dataset", "exclude_models": "exclude_model"}, - "winrate": { - "include_models": "include_model", + "process": { + "status": "status", + "exclude_datasets": "exclude_dataset", "exclude_models": "exclude_model", "replace_models": "replace_model", "replace_envs": "replace_env", + }, + "winrate": { + "include_models": "include_model", + "exclude_models": "exclude_model", "exclude_datasets": "exclude_dataset", }, }[mode] @@ -904,6 +1094,7 @@ def _build_winrate_args_from_config(path: Path, *, parser: argparse.ArgumentPars hf_private=None, ) _load_and_apply_config(args, path, mode="winrate", parser=parser) + args._output_dir_explicit = args.output_dir is not None _finalize_config_args(args, mode="winrate") return args @@ -923,10 +1114,11 @@ def _finalize_config_args(args: argparse.Namespace, *, mode: Literal["process", "process_incomplete": False, "exclude_dataset": [], "exclude_model": [], + "replace_model": [], + "replace_env": [], }, "winrate": { "processed_dir": DEFAULT_PROCESSED_DIR, - "output_dir": DEFAULT_WINRATE_DIR, "missing_policy": "neg-inf", "epsilon": 1e-9, "min_common": 0, @@ -946,26 +1138,30 @@ def _finalize_config_args(args: argparse.Namespace, *, mode: Literal["process", for attr, default in defaults.items(): if getattr(args, attr, None) is None: setattr(args, attr, default) + if mode == "winrate" and getattr(args, "output_dir", None) is None: + args.output_dir = _default_winrate_output_dir(Path(args.processed_dir)) if hasattr(args, "exclude_dataset"): args.exclude_dataset = _parse_repeatable_csv(args.exclude_dataset) if mode == "process" and hasattr(args, "exclude_model"): args.exclude_model = _parse_repeatable_csv(args.exclude_model) + if mode == "process" and hasattr(args, "replace_model"): + args.replace_model = _parse_repeatable_csv(args.replace_model) + if mode == "process" and hasattr(args, "replace_env"): + args.replace_env = _parse_repeatable_csv(args.replace_env) + + +def _default_winrate_output_dir(processed_dir: Path) -> Path: + return Path(processed_dir) / "winrate" def _upload_winrate_outputs( *, output_dir: Path, - if mode == "process" and hasattr(args, "replace_model"): - args.replace_model = _parse_repeatable_csv(args.replace_model) - if mode == "process" and hasattr(args, "replace_env"): - args.replace_env = _parse_repeatable_csv(args.replace_env) output_paths: Sequence[Path], repo_id: str, token: str | None, branch: str | None, - "replace_model": [], - "replace_env": [], private: bool, winrate_dir: str | None, assume_yes: bool = False, @@ -1013,6 +1209,7 @@ def _run_winrate_mode(argv: Sequence[str]) -> int: if args.config: _load_and_apply_config(args, args.config, mode="winrate", parser=parser) + args._output_dir_explicit = args.output_dir is not None _finalize_config_args(args, mode="winrate") hf_config = HFSyncConfig.from_cli( diff --git a/tests/test_cli/test_main.py b/tests/test_cli/test_main.py index dc1c5771..8eb636cf 100644 --- a/tests/test_cli/test_main.py +++ b/tests/test_cli/test_main.py @@ -1794,10 +1794,11 @@ def test_process_cli_applies_config_defaults(monkeypatch: pytest.MonkeyPatch, tm cfg_path = tmp_path / "process.yaml" cfg_path.write_text( f""" - runs_dir: runs-from-config - output_dir: processed-from-config - env_config_root: {env_root} - max_workers: 2 + runs_dir: runs/raw-from-config + process: + dir: processed + env_config_root: {env_root} + max_workers: 2 hf: repo: medarc/demo branch: main @@ -1821,8 +1822,8 @@ def fake_run(options, env_export_map): assert exit_code == 0 options = captured["options"] - assert options.runs_dir == Path("runs-from-config") - assert options.output_dir == Path("processed-from-config") + assert options.runs_dir == Path("runs/raw-from-config") + assert options.output_dir == Path("runs/processed") assert options.max_workers == 2 assert options.hf_pull_policy == "pull" assert options.hf_config is not None @@ -1842,15 +1843,18 @@ def test_winrate_cli_applies_config_defaults(monkeypatch: pytest.MonkeyPatch, tm cfg_path = tmp_path / "winrate.yaml" cfg_path.write_text( """ - processed_dir: runs-from-config - output_name: from-config - missing_policy: zero - epsilon: 0.123 - min_common: 7 - weight_policy: equal - weight_cap: 99 - include_models: [alpha, beta] - exclude_model: gamma + runs_dir: runs/raw-from-config + process: + dir: processed + winrate: + output_name: from-config + missing_policy: zero + epsilon: 0.123 + min_common: 7 + weight_policy: equal + weight_cap: 99 + include_models: [alpha, beta] + exclude_model: gamma hf: repo: medarc/demo branch: main @@ -1891,7 +1895,8 @@ def fake_sync_files_to_hub(**kwargs): exit_code = main.main(["winrate", "--config", str(cfg_path), "--processed-at", "2024-01-01T00:00:00Z"]) assert exit_code == 0 - assert captured["run_kwargs"]["processed_dir"] == Path("runs-from-config") + assert captured["run_kwargs"]["processed_dir"] == Path("runs/processed") + assert captured["run_kwargs"]["output_dir"] == Path("runs/processed") / "winrate" cfg = captured["run_kwargs"]["config"] assert cfg.missing_policy == "zero" assert cfg.epsilon == pytest.approx(0.123) @@ -1939,16 +1944,18 @@ def test_process_cli_requires_winrate_config_path(tmp_path: Path) -> None: ) -def test_process_cli_runs_winrate_post_step(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - cfg_path = tmp_path / "winrate.yaml" +def test_process_cli_runs_embedded_winrate_post_step(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + cfg_path = tmp_path / "process.yaml" cfg_path.write_text( """ - processed_dir: ignored - output_dir: winrate-out - output_name: from-config - missing_policy: zero - hf: - winrate_dir: winrate-post + runs_dir: runs/raw + process: + dir: processed + winrate: + dir: scorecards + output_name: from-config + missing_policy: zero + hf_winrate_dir: winrate-post """, encoding="utf-8", ) @@ -2002,11 +2009,7 @@ def fake_sync_files_to_hub( exit_code = main.main( [ "process", - "--runs-dir", - str(tmp_path / "runs"), - "--output-dir", - str(tmp_path / "processed"), - "--winrate", + "--config", str(cfg_path), "--hf-repo", "medarc/shared", @@ -2015,8 +2018,8 @@ def fake_sync_files_to_hub( ] ) assert exit_code == 0 - assert captured["run_kwargs"]["processed_dir"] == Path(tmp_path / "processed") - assert captured["run_kwargs"]["output_dir"] == Path("winrate-out") + assert captured["run_kwargs"]["processed_dir"] == Path("runs/processed") + assert captured["run_kwargs"]["output_dir"] == Path("runs/processed/scorecards") assert captured["run_kwargs"]["hf_config"] is None assert captured["run_kwargs"]["hf_processed_pull"] is False upload = captured.get("upload") @@ -2027,17 +2030,69 @@ def fake_sync_files_to_hub( assert upload["path_in_repo_prefix"] == "winrate-post" +def test_process_cli_defaults_winrate_output_dir_under_processed( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + cfg_path = tmp_path / "process.yaml" + cfg_path.write_text( + """ + runs_dir: runs/raw + process: + dir: processed + winrate: + missing_policy: zero + """, + encoding="utf-8", + ) + + captured: dict[str, Any] = {} + + def fake_run_process(options, env_export_map): + captured["options"] = options + return ProcessResult(records_processed=0, rows_processed=0, env_groups=[], env_summaries=[], hf_summary=None) + + def fake_run_winrate( + *, processed_dir, output_dir, output_path, output_name, config, processed_at, hf_config, hf_processed_pull + ): + captured["run_kwargs"] = { + "processed_dir": processed_dir, + "output_dir": output_dir, + } + return SimpleNamespace( + output_path=Path(output_dir) / "winrate.json", + output_paths=[Path(output_dir) / "winrate.json"], + result={"models": {}}, + datasets=[], + ) + + monkeypatch.setattr(main, "run_process", fake_run_process) + monkeypatch.setattr(main, "run_winrate", fake_run_winrate) + monkeypatch.setattr(main, "print_winrate_summary_markdown", lambda *_args, **_kwargs: None) + + exit_code = main.main( + [ + "process", + "--config", + str(cfg_path), + ] + ) + assert exit_code == 0 + assert captured["run_kwargs"]["processed_dir"] == Path("runs/processed") + assert captured["run_kwargs"]["output_dir"] == Path("runs/processed/winrate") + + def test_process_config_sets_winrate_path(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: cfg_path = tmp_path / "process.yaml" - winrate_cfg = tmp_path / "winrate.yaml" fake_runs_dir = tmp_path / "runs" / "raw" fake_runs_dir.mkdir(parents=True) - winrate_cfg.write_text("output_dir: runs/winrate\n", encoding="utf-8") cfg_path.write_text( f""" runs_dir: {fake_runs_dir} - output_dir: runs/processed - winrate: {winrate_cfg} + process: + dir: processed + winrate: + enabled: true """, encoding="utf-8", ) @@ -2071,7 +2126,9 @@ def fake_run_winrate( ] ) assert exit_code == 0 - assert captured["run_kwargs"]["processed_dir"] == Path("runs/processed") + expected_processed_dir = fake_runs_dir.parent / "processed" + assert captured["run_kwargs"]["processed_dir"] == expected_processed_dir + assert captured["run_kwargs"]["output_dir"] == expected_processed_dir / "winrate" def test_process_cli_rejects_include_prompt_completion(tmp_path: Path) -> None: From 8928f33b846d656986c85a3e03305ff47d79b3e6 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sat, 28 Feb 2026 13:38:33 -0500 Subject: [PATCH 06/29] Break up process row loading --- medarc_verifiers/cli/process/rows.py | 136 +++++++++++++++++---------- 1 file changed, 84 insertions(+), 52 deletions(-) diff --git a/medarc_verifiers/cli/process/rows.py b/medarc_verifiers/cli/process/rows.py index 2b4f35f3..e27896a7 100644 --- a/medarc_verifiers/cli/process/rows.py +++ b/medarc_verifiers/cli/process/rows.py @@ -38,76 +38,108 @@ def load_rows( drop = {column for column in drop_columns or () if column} drop.update(DEFAULT_DROP_COLUMNS) drop.update(PROMPT_COMPLETION_COLUMNS) + decoded_rows, example_counts = _decode_results_jsonl(results_path) + multi_rollout = _detect_multi_rollout_shape(example_counts) + version_info_json = _encode_metadata_json_column(metadata.raw_metadata.get("version_info")) - # First pass: decode and clean rows, and count example_id occurrences to - # detect multiple rollouts within a single JSONL (example_id repetition). + rows: list[dict[str, Any]] = [] + seen_per_example: dict[Any, int] = {} + for line_number, payload in decoded_rows: + cleaned, extras = _clean_payload_row( + payload, + extras_keys=extras_keys, + drop=drop, + answer_column=answer_column, + ) + rollout_index = _resolve_rollout_index( + payload, + metadata, + multi_rollout=multi_rollout, + seen_per_example=seen_per_example, + ) + if extras_keys and extras: + cleaned["extras"] = json.dumps(extras, sort_keys=True) + else: + cleaned["extras"] = None + enriched = _attach_row_metadata( + cleaned, + metadata, + line_number=line_number, + rollout_index=rollout_index, + version_info_json=version_info_json, + ) + rows.append(enriched) + + return rows + + +def _decode_results_jsonl(path: Path) -> tuple[list[tuple[int, Mapping[str, Any]]], dict[Any, int]]: + """Decode results.jsonl and count example_id occurrences for rollout detection.""" decoded_rows: list[tuple[int, Mapping[str, Any]]] = [] example_counts: dict[Any, int] = {} try: - with results_path.open("r", encoding="utf-8") as handle: + with path.open("r", encoding="utf-8") as handle: for line_number, raw_line in enumerate(handle, start=1): line = raw_line.strip() if not line: continue - payload = _decode_line(line, results_path, line_number) + payload = _decode_line(line, path, line_number) decoded_rows.append((line_number, payload)) ex_id = payload.get("example_id") - # Count occurrences to infer intra-file rollout structure. try: example_counts[ex_id] = example_counts.get(ex_id, 0) + 1 except TypeError: - # Non-hashable example_id shouldn't happen (schema requires - # primitive), but guard just in case. pass except ValueError: raise except OSError as exc: # noqa: FBT003 - logger.warning("Failed to read %s: %s", results_path, exc) - return [] + logger.warning("Failed to read %s: %s", path, exc) + return [], {} + return decoded_rows, example_counts - multi_rollout = any(count > 1 for count in example_counts.values()) - version_info_json = _encode_metadata_json_column(metadata.raw_metadata.get("version_info")) - # Second pass: enrich rows. If the file contains multiple rollouts, compute - # a data-driven rollout_index by counting seen occurrences per example_id. - # Otherwise, retain the suffix/dir-derived rollout_index from metadata. - rows: list[dict[str, Any]] = [] - seen_per_example: dict[Any, int] = {} - for line_number, payload in decoded_rows: - extras = _extract_extras(payload, extras_keys=extras_keys) - cleaned = _clean_row(payload, drop=drop, extras_keys=extras_keys) - cleaned.pop("rollout_index", None) - _map_answer_column(cleaned, payload, answer_column=answer_column) - _flatten_token_usage(cleaned) - payload_rollout_index = _coerce_rollout_index(payload.get("rollout_index")) - if payload_rollout_index is not None: - rollout_index = payload_rollout_index - cleaned["rollout_index"] = payload_rollout_index - elif multi_rollout: - ex_id = payload.get("example_id") - try: - seen = seen_per_example.get(ex_id, 0) - rollout_index = seen # 0-based occurrence index - seen_per_example[ex_id] = seen + 1 - except TypeError: - # Fallback to metadata rollout_index if example_id is unusable as key - rollout_index = metadata.rollout_index - else: - rollout_index = metadata.rollout_index - if extras_keys and extras: - cleaned["extras"] = json.dumps(extras, sort_keys=True) - else: - cleaned["extras"] = None - enriched = _attach_metadata( - cleaned, - metadata, - line_number=line_number, - rollout_index=rollout_index, - version_info_json=version_info_json, - ) - rows.append(enriched) +def _detect_multi_rollout_shape(example_counts: Mapping[Any, int]) -> bool: + return any(count > 1 for count in example_counts.values()) - return rows + +def _clean_payload_row( + payload: Mapping[str, Any], + *, + extras_keys: set[str], + drop: set[str], + answer_column: str | None, +) -> tuple[MutableMapping[str, Any], Mapping[str, Any]]: + extras = _extract_extras(payload, extras_keys=extras_keys) + cleaned = _clean_row(payload, drop=drop, extras_keys=extras_keys) + cleaned.pop("rollout_index", None) + _map_answer_column(cleaned, payload, answer_column=answer_column) + _normalize_token_usage(cleaned) + payload_rollout_index = _coerce_rollout_index(payload.get("rollout_index")) + if payload_rollout_index is not None: + cleaned["rollout_index"] = payload_rollout_index + return cleaned, extras + + +def _resolve_rollout_index( + payload: Mapping[str, Any], + metadata: NormalizedMetadata, + *, + multi_rollout: bool, + seen_per_example: MutableMapping[Any, int], +) -> int: + payload_rollout_index = _coerce_rollout_index(payload.get("rollout_index")) + if payload_rollout_index is not None: + return payload_rollout_index + if not multi_rollout: + return metadata.rollout_index + + ex_id = payload.get("example_id") + try: + seen = seen_per_example.get(ex_id, 0) + seen_per_example[ex_id] = seen + 1 + return seen + except TypeError: + return metadata.rollout_index def _map_answer_column( @@ -204,7 +236,7 @@ def _coerce_rollout_index(value: Any) -> int | None: return None -def _attach_metadata( +def _attach_row_metadata( row: MutableMapping[str, Any], metadata: NormalizedMetadata, *, @@ -237,7 +269,7 @@ def _attach_metadata( return row -def _flatten_token_usage(row: MutableMapping[str, Any]) -> None: +def _normalize_token_usage(row: MutableMapping[str, Any]) -> None: """Flatten token_usage dict into explicit columns and drop the original field.""" if "token_usage" not in row: return From d31bcb278bba59bb5c39bfdf5fb1aeb8c46f395e Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sat, 28 Feb 2026 13:41:35 -0500 Subject: [PATCH 07/29] Separate process workspace preparation --- medarc_verifiers/cli/process/pipeline.py | 36 ++----------- medarc_verifiers/cli/process/workspace.py | 60 +++++++++++++++++++++ tests/test_cli/test_process_pipeline.py | 47 +++++++++++++++++ tests/test_cli/test_process_workspace.py | 63 +++++++++++++++++++++++ 4 files changed, 175 insertions(+), 31 deletions(-) diff --git a/medarc_verifiers/cli/process/pipeline.py b/medarc_verifiers/cli/process/pipeline.py index a9dbdc46..c82a8cc7 100644 --- a/medarc_verifiers/cli/process/pipeline.py +++ b/medarc_verifiers/cli/process/pipeline.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field from datetime import UTC, datetime from pathlib import Path -from typing import Any, Callable, Iterable, Mapping, Sequence +from typing import Any, Iterable, Mapping, Sequence import pyarrow.parquet as pq @@ -113,19 +113,13 @@ def run_process( env_export_map = env_export_map or {} def _run_pipeline() -> ProcessResult: - if not options.dry_run and options.clean: - _confirm_clean_process( - options.output_dir, - assume_yes=options.assume_yes, - is_tty=sys.stdin.isatty(), - prompt_func=input, - ) - workspace.clear_output_dir(options.output_dir) - if not options.dry_run and options.hf_config and options.hf_config.repo_id and not options.clean: - workspace.prepare_hf_baseline( + if not options.dry_run: + workspace.prepare_output_workspace( output_dir=options.output_dir, hf_config=options.hf_config, pull_policy=options.hf_pull_policy, + clean=options.clean, + assume_yes=options.assume_yes, is_tty=sys.stdin.isatty(), prompt_func=input, ) @@ -655,26 +649,6 @@ def _run_sort_key(timestamp: str, job_run_id: str) -> tuple[int, datetime, str]: return (0, datetime.min.replace(tzinfo=UTC), job_run_id) -def _confirm_clean_process( - output_dir: Path, - *, - assume_yes: bool, - is_tty: bool, - prompt_func: Callable[[str], str] | None, -) -> None: - if assume_yes: - return - if not is_tty or prompt_func is None: - raise RuntimeError("Refusing to clean processed outputs without confirmation. Re-run with --yes to confirm.") - prompt = f"--clean will delete all contents of {output_dir} and rebuild from runs. Type 'clean' to continue: " - try: - response = prompt_func(prompt).strip().lower() - except (EOFError, KeyboardInterrupt): # noqa: PERF203 - raise RuntimeError("Aborted clean process.") from None - if response != "clean": - raise RuntimeError("Aborted clean process.") - - __all__ = [ "PlannedRecord", "PlannedWorkItem", diff --git a/medarc_verifiers/cli/process/workspace.py b/medarc_verifiers/cli/process/workspace.py index 20254104..a497e353 100644 --- a/medarc_verifiers/cli/process/workspace.py +++ b/medarc_verifiers/cli/process/workspace.py @@ -21,6 +21,12 @@ class BaselineResult: snapshot_dir: Path | None = None +@dataclass(slots=True) +class WorkspacePreparationResult: + cleaned: bool = False + baseline_result: BaselineResult | None = None + + def ensure_output_dir(output_dir: Path) -> None: output_dir.mkdir(parents=True, exist_ok=True) @@ -33,6 +39,37 @@ def is_nonempty_dir(path: Path) -> bool: return False +def prepare_output_workspace( + *, + output_dir: Path, + hf_config: HFSyncConfig | None, + pull_policy: str | None, + clean: bool, + assume_yes: bool, + is_tty: bool, + prompt_func: Callable[[str], str] | None = None, +) -> WorkspacePreparationResult: + """Prepare local processed outputs before selection reads local inventory state.""" + ensure_output_dir(output_dir) + + if clean: + confirm_clean_output_dir(output_dir, assume_yes=assume_yes, is_tty=is_tty, prompt_func=prompt_func) + clear_output_dir(output_dir) + return WorkspacePreparationResult(cleaned=True) + + if hf_config and hf_config.repo_id: + baseline_result = prepare_hf_baseline( + output_dir=output_dir, + hf_config=hf_config, + pull_policy=pull_policy, + is_tty=is_tty, + prompt_func=prompt_func, + ) + return WorkspacePreparationResult(cleaned=False, baseline_result=baseline_result) + + return WorkspacePreparationResult(cleaned=False) + + def prepare_hf_baseline( *, output_dir: Path, @@ -107,6 +144,26 @@ def prepare_hf_baseline( raise ValueError(f"Unsupported HF pull policy: {policy}") +def confirm_clean_output_dir( + output_dir: Path, + *, + assume_yes: bool, + is_tty: bool, + prompt_func: Callable[[str], str] | None, +) -> None: + if assume_yes: + return + if not is_tty or prompt_func is None: + raise RuntimeError("Refusing to clean processed outputs without confirmation. Re-run with --yes to confirm.") + prompt = f"--clean will delete all contents of {output_dir} and rebuild from runs. Type 'clean' to continue: " + try: + response = prompt_func(prompt).strip().lower() + except (EOFError, KeyboardInterrupt): # noqa: PERF203 + raise RuntimeError("Aborted clean process.") from None + if response != "clean": + raise RuntimeError("Aborted clean process.") + + def _resolve_pull_policy(pull_policy: str | None, *, is_tty: bool) -> str: if pull_policy: return pull_policy @@ -228,8 +285,11 @@ def clear_output_dir(output_dir: Path) -> None: __all__ = [ "BaselineResult", + "WorkspacePreparationResult", "clear_output_dir", + "confirm_clean_output_dir", "ensure_output_dir", "is_nonempty_dir", + "prepare_output_workspace", "prepare_hf_baseline", ] diff --git a/tests/test_cli/test_process_pipeline.py b/tests/test_cli/test_process_pipeline.py index a973f391..d3c4368e 100644 --- a/tests/test_cli/test_process_pipeline.py +++ b/tests/test_cli/test_process_pipeline.py @@ -769,6 +769,53 @@ def test_process_clean_clears_outputs(tmp_path: Path) -> None: assert (output_dir / "env_index.json").exists() +def test_run_process_reads_local_index_after_workspace_prep( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + runs_dir = _setup_run(tmp_path) + output_dir = tmp_path / "processed" + observed: list[str] = [] + + def fake_prepare_output_workspace(**kwargs): + observed.append("workspace") + model_dir = kwargs["output_dir"] / "gpt-mini" + model_dir.mkdir(parents=True, exist_ok=True) + existing_path = model_dir / "demo-env.parquet" + existing_path.write_text("placeholder", encoding="utf-8") + (kwargs["output_dir"] / "env_index.json").write_text( + json.dumps( + { + "version": 2, + "files": { + "gpt-mini/demo-env.parquet": { + "env_id": "demo-env", + "model_id": "gpt-mini", + } + }, + } + ), + encoding="utf-8", + ) + + def fake_read_env_index_files(processed_dir: Path): + observed.append("index") + assert observed == ["workspace", "index"] + return {"gpt-mini/demo-env.parquet": {"env_id": "demo-env", "model_id": "gpt-mini"}} + + monkeypatch.setattr("medarc_verifiers.cli.process.workspace.prepare_output_workspace", fake_prepare_output_workspace) + monkeypatch.setattr("medarc_verifiers.cli.process.env_index.read_env_index_files", fake_read_env_index_files) + monkeypatch.setattr( + "medarc_verifiers.cli.process.pipeline._validate_existing_output_integrity", + lambda *_args, **_kwargs: None, + ) + + result = run_process(ProcessOptions(runs_dir=runs_dir, output_dir=output_dir, dry_run=False, max_workers=1)) + + assert observed == ["workspace", "index"] + assert result.env_summaries == [] + + def test_run_process_ignores_legacy_run_output_path(tmp_path: Path) -> None: runs_dir = _setup_run(tmp_path) run_dir = runs_dir / "run-1" diff --git a/tests/test_cli/test_process_workspace.py b/tests/test_cli/test_process_workspace.py index fa1444a7..aa9679ad 100644 --- a/tests/test_cli/test_process_workspace.py +++ b/tests/test_cli/test_process_workspace.py @@ -50,6 +50,69 @@ def _fake_download_hf_repo(**_kwargs) -> Path: assert copied in result.files_copied +def test_prepare_output_workspace_clean_skips_hf_baseline( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + output_dir = tmp_path / "output" + output_dir.mkdir() + sentinel = output_dir / "stale.txt" + sentinel.write_text("stale", encoding="utf-8") + + def _fail_prepare_hf_baseline(**_kwargs) -> workspace.BaselineResult: + raise AssertionError("prepare_hf_baseline should not run when clean=True") + + monkeypatch.setattr(workspace, "prepare_hf_baseline", _fail_prepare_hf_baseline) + + result = workspace.prepare_output_workspace( + output_dir=output_dir, + hf_config=HFSyncConfig(repo_id="demo/repo"), + pull_policy="pull", + clean=True, + assume_yes=True, + is_tty=False, + prompt_func=None, + ) + + assert result.cleaned is True + assert result.baseline_result is None + assert not sentinel.exists() + + +def test_prepare_output_workspace_runs_hf_baseline_before_local_reads( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + output_dir = tmp_path / "output" + snapshot_dir = tmp_path / "snapshot" + snapshot_dir.mkdir() + parquet_path = _write_snapshot(snapshot_dir) + + def _fake_prepare_hf_baseline(**_kwargs) -> workspace.BaselineResult: + copied = output_dir / parquet_path.relative_to(snapshot_dir) + copied.parent.mkdir(parents=True, exist_ok=True) + copied.write_text(parquet_path.read_text(encoding="utf-8"), encoding="utf-8") + (output_dir / "env_index.json").write_text((snapshot_dir / "env_index.json").read_text(encoding="utf-8")) + return workspace.BaselineResult(policy="pull", files_copied=[copied], snapshot_dir=snapshot_dir) + + monkeypatch.setattr(workspace, "prepare_hf_baseline", _fake_prepare_hf_baseline) + + result = workspace.prepare_output_workspace( + output_dir=output_dir, + hf_config=HFSyncConfig(repo_id="demo/repo"), + pull_policy="pull", + clean=False, + assume_yes=False, + is_tty=False, + prompt_func=None, + ) + + assert result.cleaned is False + assert result.baseline_result is not None + assert (output_dir / "env_index.json").exists() + assert (output_dir / "model-a" / "env-a.parquet").exists() + + def test_prepare_hf_baseline_pull_keeps_unrelated_local(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: snapshot_dir = tmp_path / "snapshot" snapshot_dir.mkdir() From 11ad91903466f111b57547bd2875294ab7c9be82 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sun, 1 Mar 2026 12:22:43 -0500 Subject: [PATCH 08/29] Harden process selection against stale broken runs --- medarc_verifiers/cli/process/metadata.py | 146 +++++++++++++++++------ medarc_verifiers/cli/process/pipeline.py | 104 +++++++++++++--- tests/test_cli/test_process_pipeline.py | 33 +++++ 3 files changed, 232 insertions(+), 51 deletions(-) diff --git a/medarc_verifiers/cli/process/metadata.py b/medarc_verifiers/cli/process/metadata.py index 99f5fd72..3d70fb02 100644 --- a/medarc_verifiers/cli/process/metadata.py +++ b/medarc_verifiers/cli/process/metadata.py @@ -60,12 +60,101 @@ class RunIdentity: output_env_id: str +@dataclass(frozen=True, slots=True) +class ResolvedRunIdentity: + """Selection-time identity that tolerates missing model ids.""" + + model_id: str | None + manifest_env_id: str + base_env_id: str + rollout_index: int | None + job_run_id: str + output_env_id: str + + +@dataclass(frozen=True, slots=True) +class _ResolvedMetadataContext: + raw_metadata: Mapping[str, Any] + manifest_env_id: str + metadata_env_id: str | None + base_env_id: str + rollout_index: int + model_id: str | None + metadata_model: str | None + env_args: Mapping[str, Any] + sampling_args: Mapping[str, Any] + num_examples: int | None + rollouts_per_example: int | None + + +def resolve_run_identity( + record: RunRecord, + *, + combine_rollouts: bool = True, +) -> ResolvedRunIdentity: + """Resolve a run identity for selection without requiring model_id.""" + context = _resolve_metadata_context(record, combine_rollouts=combine_rollouts) + resolved_rollout_index = ( + context.rollout_index + if context.rollout_index != 0 or context.manifest_env_id != context.base_env_id + else None + ) + return ResolvedRunIdentity( + model_id=context.model_id, + manifest_env_id=context.manifest_env_id, + base_env_id=context.base_env_id, + rollout_index=resolved_rollout_index, + job_run_id=record.manifest.job_run_id, + output_env_id=context.base_env_id or context.manifest_env_id or record.job_id, + ) + + def load_normalized_metadata( record: RunRecord, *, combine_rollouts: bool = True, ) -> NormalizedMetadata: """Merge manifest fields with metadata.json (when present).""" + context = _resolve_metadata_context(record, combine_rollouts=combine_rollouts) + if not context.model_id: + raise RuntimeError(format_missing_model_id_error(record)) + resolved_rollout_index = ( + context.rollout_index + if context.rollout_index != 0 or context.manifest_env_id != context.base_env_id + else None + ) + identity = RunIdentity( + model_id=context.model_id, + manifest_env_id=context.manifest_env_id, + base_env_id=context.base_env_id, + rollout_index=resolved_rollout_index, + job_run_id=record.manifest.job_run_id, + output_env_id=context.base_env_id or context.manifest_env_id or record.job_id, + ) + + return NormalizedMetadata( + identity=identity, + record=record, + metadata_path=record.metadata_path if record.has_metadata else None, + raw_metadata=context.raw_metadata, + manifest_env_id=context.manifest_env_id, + metadata_env_id=context.metadata_env_id, + base_env_id=context.base_env_id, + rollout_index=identity.rollout_index or 0, + model_id=identity.model_id, + metadata_model=context.metadata_model, + env_args=context.env_args, + sampling_args=context.sampling_args, + num_examples=context.num_examples, + rollouts_per_example=context.rollouts_per_example, + ) + + +def _resolve_metadata_context( + record: RunRecord, + *, + combine_rollouts: bool, +) -> _ResolvedMetadataContext: metadata_payload, raw_metadata = _load_metadata(record) metadata_env_id = metadata_payload.env_id if metadata_payload else None metadata_model = metadata_payload.model if metadata_payload else None @@ -77,7 +166,6 @@ def load_normalized_metadata( primary=record.sampling_args, fallback=metadata_payload.sampling_args if metadata_payload else None, ) - manifest_env_id = ( _extract_env_config_id(record.env_config) or record.manifest_env_id or metadata_env_id or record.job_id ) @@ -85,50 +173,31 @@ def load_normalized_metadata( manifest_env_id, combine_rollouts=combine_rollouts, ) - # If we didn't capture a rollout index from the manifest env id, - # try to derive it from the results directory name (common when - # manifests keep base env id, but the on-disk folder encodes the rollout). if rollout_index == 0 and record.results_dir_name: alt_index = extract_rollout_index(record.results_dir_name) if alt_index: rollout_index = alt_index - - model_id = record.model_id or metadata_model - if not model_id: - raise RuntimeError( - "Missing model_id for run " - f"(job_run_id={record.manifest.job_run_id}, job_id={record.job_id}, " - f"results_dir={record.results_dir}, manifest={record.manifest.manifest_path})" - ) - resolved_rollout_index = rollout_index if rollout_index != 0 or manifest_env_id != base_env_id else None - identity = RunIdentity( - model_id=model_id, - manifest_env_id=manifest_env_id, - base_env_id=base_env_id, - rollout_index=resolved_rollout_index, - job_run_id=record.manifest.job_run_id, - output_env_id=base_env_id or manifest_env_id or record.job_id, - ) - num_examples = record.num_examples or (metadata_payload.num_examples if metadata_payload else None) - rollouts_per_example = record.rollouts_per_example or ( - metadata_payload.rollouts_per_example if metadata_payload else None - ) - - return NormalizedMetadata( - identity=identity, - record=record, - metadata_path=record.metadata_path if record.has_metadata else None, + return _ResolvedMetadataContext( raw_metadata=raw_metadata, manifest_env_id=manifest_env_id, metadata_env_id=metadata_env_id, base_env_id=base_env_id, - rollout_index=identity.rollout_index or 0, - model_id=identity.model_id, + rollout_index=rollout_index, + model_id=record.model_id or metadata_model, metadata_model=metadata_model, env_args=env_args, sampling_args=sampling_args, - num_examples=num_examples, - rollouts_per_example=rollouts_per_example, + num_examples=record.num_examples or (metadata_payload.num_examples if metadata_payload else None), + rollouts_per_example=record.rollouts_per_example + or (metadata_payload.rollouts_per_example if metadata_payload else None), + ) + + +def format_missing_model_id_error(record: RunRecord) -> str: + return ( + "Missing model_id for run " + f"(job_run_id={record.manifest.job_run_id}, job_id={record.job_id}, " + f"results_dir={record.results_dir}, manifest={record.manifest.manifest_path})" ) @@ -193,4 +262,11 @@ def _extract_env_config_id(env_config: Mapping[str, Any] | None) -> str | None: return None -__all__ = ["NormalizedMetadata", "RunIdentity", "load_normalized_metadata"] +__all__ = [ + "NormalizedMetadata", + "ResolvedRunIdentity", + "RunIdentity", + "format_missing_model_id_error", + "load_normalized_metadata", + "resolve_run_identity", +] diff --git a/medarc_verifiers/cli/process/pipeline.py b/medarc_verifiers/cli/process/pipeline.py index c82a8cc7..2e1c46be 100644 --- a/medarc_verifiers/cli/process/pipeline.py +++ b/medarc_verifiers/cli/process/pipeline.py @@ -90,7 +90,26 @@ class PlannedWorkItem: identity: RunIdentity records: list[PlannedRecord] - env_export_config: EnvironmentExportConfig + + +@dataclass(frozen=True, slots=True) +class SelectionRecord: + """Selection-time record settings before full normalization.""" + + record: discovery.RunRecord + identity: metadata.ResolvedRunIdentity + combine_rollouts: bool + extra_columns: Sequence[str] + drop_columns: Sequence[str] + answer_column: str | None + + +@dataclass(frozen=True, slots=True) +class SelectionWorkItem: + """A selected work item before metadata normalization.""" + + identity: metadata.ResolvedRunIdentity + records: list[SelectionRecord] @dataclass(frozen=True, slots=True) @@ -287,8 +306,9 @@ def select_work_items( continue eligible_records.append(record) - planned_records = [_plan_record(record, env_export_map) for record in eligible_records] - work_items = _select_latest_work_items(planned_records) + planned_records = [_plan_selection_record(record, env_export_map) for record in eligible_records] + _raise_for_latest_invalid_selection(planned_records) + work_items = _materialize_work_items(_select_latest_work_items([record for record in planned_records if record.identity.model_id])) work_items, skipped_by_exclusion = _apply_exclusions( work_items, @@ -325,45 +345,97 @@ def _resolve_columns(env_columns: Sequence[str]) -> Sequence[str]: return tuple(str(column).strip() for column in env_columns if str(column).strip()) -def _plan_record( +def _plan_selection_record( record: discovery.RunRecord, env_export_map: Mapping[str, EnvironmentExportConfig], -) -> PlannedRecord: +) -> SelectionRecord: env_export = _resolve_env_export(record.manifest_env_id, env_export_map) - normalized = metadata.load_normalized_metadata(record, combine_rollouts=bool(env_export.combine_rollouts)) - return PlannedRecord( - normalized=normalized, + combine_rollouts = bool(env_export.combine_rollouts) + identity = metadata.resolve_run_identity(record, combine_rollouts=combine_rollouts) + return SelectionRecord( + record=record, + identity=identity, + combine_rollouts=combine_rollouts, extra_columns=_resolve_columns(env_export.extra_columns), drop_columns=_resolve_columns(env_export.drop_columns), answer_column=env_export.answer_column, ) -def _select_latest_work_items(records: Sequence[PlannedRecord]) -> list[PlannedWorkItem]: - grouped: dict[tuple[str, str], dict[str, list[PlannedRecord]]] = {} +def _raise_for_latest_invalid_selection(records: Sequence[SelectionRecord]) -> None: + latest_by_env: dict[str, SelectionRecord] = {} + for planned in records: + output_env_id = planned.identity.output_env_id + current = latest_by_env.get(output_env_id) + if current is None or _run_sort_key( + _source_updated_at(planned.record), + planned.record.manifest.job_run_id, + ) > _run_sort_key(_source_updated_at(current.record), current.record.manifest.job_run_id): + latest_by_env[output_env_id] = planned + + invalid_latest = [ + planned for planned in latest_by_env.values() if not planned.identity.model_id + ] + if not invalid_latest: + return + + failing = sorted( + invalid_latest, + key=lambda planned: ( + planned.identity.output_env_id, + _run_sort_key(_source_updated_at(planned.record), planned.record.manifest.job_run_id), + ), + )[-1] + raise RuntimeError(metadata.format_missing_model_id_error(failing.record)) + + +def _select_latest_work_items(records: Sequence[SelectionRecord]) -> list[SelectionWorkItem]: + grouped: dict[tuple[str, str], dict[str, list[SelectionRecord]]] = {} run_timestamps: dict[str, str] = {} for planned in records: - identity = planned.normalized.identity + identity = planned.identity + if not identity.model_id: + continue group_key = (identity.model_id, identity.output_env_id) grouped.setdefault(group_key, {}).setdefault(identity.job_run_id, []).append(planned) - run_timestamps.setdefault(identity.job_run_id, _source_updated_at(planned.normalized.record)) + run_timestamps.setdefault(identity.job_run_id, _source_updated_at(planned.record)) - selected: list[PlannedWorkItem] = [] + selected: list[SelectionWorkItem] = [] for _, run_groups in grouped.items(): latest_run_id = max(run_groups.keys(), key=lambda run_id: _run_sort_key(run_timestamps.get(run_id, ""), run_id)) latest_records = run_groups[latest_run_id] representative = latest_records[0] selected.append( - PlannedWorkItem( - identity=representative.normalized.identity, + SelectionWorkItem( + identity=representative.identity, records=list(latest_records), - env_export_config=EnvironmentExportConfig(), ) ) return selected +def _materialize_work_items(items: Sequence[SelectionWorkItem]) -> list[PlannedWorkItem]: + materialized: list[PlannedWorkItem] = [] + for item in items: + records: list[PlannedRecord] = [] + for selected in item.records: + normalized = metadata.load_normalized_metadata( + selected.record, + combine_rollouts=selected.combine_rollouts, + ) + records.append( + PlannedRecord( + normalized=normalized, + extra_columns=selected.extra_columns, + drop_columns=selected.drop_columns, + answer_column=selected.answer_column, + ) + ) + materialized.append(PlannedWorkItem(identity=records[0].normalized.identity, records=records)) + return materialized + + def _apply_exclusions( work_items: Sequence[PlannedWorkItem], *, diff --git a/tests/test_cli/test_process_pipeline.py b/tests/test_cli/test_process_pipeline.py index d3c4368e..122abbd2 100644 --- a/tests/test_cli/test_process_pipeline.py +++ b/tests/test_cli/test_process_pipeline.py @@ -147,6 +147,19 @@ def _write_run( return runs_dir +def _remove_model_id(tmp_path: Path, run_id: str) -> None: + manifest_path = tmp_path / "runs" / run_id / "run_manifest.json" + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + manifest["jobs"][0]["model_id"] = None + manifest["models"] = {} + manifest_path.write_text(json.dumps(manifest), encoding="utf-8") + + metadata_path = tmp_path / "runs" / run_id / "demo-job" / "metadata.json" + metadata = json.loads(metadata_path.read_text(encoding="utf-8")) + metadata.pop("model", None) + metadata_path.write_text(json.dumps(metadata), encoding="utf-8") + + def test_run_process_respects_env_export_defaults(tmp_path: Path) -> None: runs_dir = _setup_run(tmp_path) options = ProcessOptions( @@ -708,6 +721,26 @@ def test_process_ignores_invalid_superseded_run(tmp_path: Path) -> None: assert table.column("reward").to_pylist() == [0.9] +def test_process_ignores_superseded_run_missing_model_id(tmp_path: Path) -> None: + runs_dir = _write_run(tmp_path, run_id="run-1", updated_at="2024-01-01T00:00:00Z", reward=0.1) + _remove_model_id(tmp_path, "run-1") + _write_run(tmp_path, run_id="run-2", updated_at="2024-01-02T00:00:00Z", reward=0.9) + + result = run_process(ProcessOptions(runs_dir=runs_dir, output_dir=tmp_path / "processed", dry_run=False, max_workers=1)) + + table = pq.read_table(result.env_summaries[0].output_path) + assert table.column("reward").to_pylist() == [0.9] + + +def test_process_latest_missing_model_id_fails_clearly(tmp_path: Path) -> None: + runs_dir = _write_run(tmp_path, run_id="run-1", updated_at="2024-01-01T00:00:00Z", reward=0.1) + _write_run(tmp_path, run_id="run-2", updated_at="2024-01-02T00:00:00Z", reward=0.9) + _remove_model_id(tmp_path, "run-2") + + with pytest.raises(RuntimeError, match=r"Missing model_id for run \(job_run_id=run-2, job_id=demo-job,"): + run_process(ProcessOptions(runs_dir=runs_dir, output_dir=tmp_path / "processed", dry_run=False, max_workers=1)) + + def test_process_ignores_invalid_incomplete_run_by_default(tmp_path: Path) -> None: runs_dir = _write_run( tmp_path, From 10079ee01352853ad550f4d347d3ce228dca4b02 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sun, 1 Mar 2026 12:23:32 -0500 Subject: [PATCH 09/29] Use row identities in process aggregation --- medarc_verifiers/cli/process/aggregate.py | 13 ++++--- tests/test_cli/test_process_aggregate.py | 44 +++++++++++++++++++++++ 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/medarc_verifiers/cli/process/aggregate.py b/medarc_verifiers/cli/process/aggregate.py index 7d7b1b68..f6a25966 100644 --- a/medarc_verifiers/cli/process/aggregate.py +++ b/medarc_verifiers/cli/process/aggregate.py @@ -7,7 +7,7 @@ from typing import Any, Iterable, Mapping from medarc_verifiers.cli.process.metadata import RunIdentity -from medarc_verifiers.cli.process.rollout import derive_base_env_id +from medarc_verifiers.cli.process.rollout import extract_rollout_index logger = logging.getLogger(__name__) @@ -141,8 +141,8 @@ def _group_uses_rollout_suffixes(rows: list[Mapping[str, Any]], *, base_env_id: manifest_env_id = row.get("manifest_env_id") if not isinstance(manifest_env_id, str) or not manifest_env_id: continue - derived_base, _ = derive_base_env_id(manifest_env_id) - if derived_base and derived_base == base_env_id and manifest_env_id != derived_base: + row_base_env_id = str(row.get("base_env_id") or base_env_id or "") + if row_base_env_id and manifest_env_id != row_base_env_id: return True return False @@ -155,8 +155,11 @@ def _ensure_rollout_index_from_suffix(rows: list[Mapping[str, Any]], *, base_env manifest_env_id = row.get("manifest_env_id") if not isinstance(manifest_env_id, str) or not manifest_env_id: continue - derived_base, derived_index = derive_base_env_id(manifest_env_id) - if not derived_base or derived_base != base_env_id: + row_base_env_id = str(row.get("base_env_id") or base_env_id or "") + if not row_base_env_id or manifest_env_id == row_base_env_id: + continue + derived_index = extract_rollout_index(manifest_env_id) + if derived_index <= 0: continue try: row["rollout_index"] = derived_index diff --git a/tests/test_cli/test_process_aggregate.py b/tests/test_cli/test_process_aggregate.py index a8675ad3..b214ad18 100644 --- a/tests/test_cli/test_process_aggregate.py +++ b/tests/test_cli/test_process_aggregate.py @@ -1,5 +1,6 @@ from __future__ import annotations +from medarc_verifiers.cli.process.metadata import RunIdentity from medarc_verifiers.cli.process.aggregate import ( AggregatedEnvRows, aggregate_rows_by_env, @@ -117,3 +118,46 @@ def test_aggregate_rows_fills_missing_rollout_index_from_suffix() -> None: grouped = aggregate_rows_by_env(rows) assert sorted({row["rollout_index"] for row in grouped[0].rows}) == [0, 1] + + +def test_aggregate_rows_use_attached_identities_for_fake_rollouts() -> None: + rows = [ + { + "env_id": "env-a", + "base_env_id": "env-a", + "manifest_env_id": "env-a-rollout7", + "model_id": "model-a", + "job_run_id": "run-1", + }, + { + "env_id": "env-a", + "base_env_id": "env-a", + "manifest_env_id": "env-a-rollout3", + "model_id": "model-a", + "job_run_id": "run-2", + }, + ] + identities = [ + RunIdentity( + model_id="model-a", + manifest_env_id="env-a-rollout7", + base_env_id="env-a", + rollout_index=7, + job_run_id="run-1", + output_env_id="env-a", + ), + RunIdentity( + model_id="model-a", + manifest_env_id="env-a-rollout3", + base_env_id="env-a", + rollout_index=3, + job_run_id="run-2", + output_env_id="env-a", + ), + ] + + grouped = aggregate_rows_by_env(rows, identities=identities) + + assert len(grouped) == 1 + assert grouped[0].env_id == "env-a" + assert sorted({row["rollout_index"] for row in grouped[0].rows}) == [0, 1] From c4a3855c84b87b2836ce1fba3e59767f0ac669b9 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sun, 1 Mar 2026 12:42:23 -0500 Subject: [PATCH 10/29] check in files agents missed --- docs/medarc-eval-process.md | 105 ++++++-- docs/medarc-eval-winrate.md | 63 +++-- docs/medarc-eval.md | 2 +- docs/medarc-verifiers-architecture.md | 2 +- medarc_verifiers/cli/_constants.py | 2 +- medarc_verifiers/cli/_manifest_tools.py | 341 +++++++++++++++++++----- medarc_verifiers/cli/hf/sync.py | 26 +- tests/test_cli/test_main.py | 58 +++- tests/test_cli/test_manifest_tools.py | 100 +++++++ tests/test_cli/test_process_hf_sync.py | 58 ++++ 10 files changed, 640 insertions(+), 117 deletions(-) diff --git a/docs/medarc-eval-process.md b/docs/medarc-eval-process.md index 7e3ca3c6..a0fe66d2 100644 --- a/docs/medarc-eval-process.md +++ b/docs/medarc-eval-process.md @@ -28,12 +28,12 @@ medarc-eval process --dry-run ``` runs/processed/ ├── env_index.json # Dataset inventory for winrate/analysis -├── medqa/ -│ ├── gpt-4o.parquet -│ └── gpt-4o-mini.parquet -├── pubmedqa/ -│ ├── gpt-4o.parquet -│ └── gpt-4o-mini.parquet +├── gpt-4o/ +│ ├── medqa.parquet +│ └── pubmedqa.parquet +├── gpt-4o-mini/ +│ ├── medqa.parquet +│ └── pubmedqa.parquet └── ... ``` @@ -86,13 +86,19 @@ Store common options in a YAML file: ```yaml # process-config.yaml runs_dir: runs/raw -output_dir: runs/processed -max_workers: 8 -process_incomplete: false -exclude_datasets: - - med_dialog -exclude_models: - - deprecated-v1 + +process: + dir: processed + max_workers: 8 + process_incomplete: false + exclude_datasets: + - med_dialog + exclude_models: + - deprecated-v1 + +winrate: + enabled: true + dir: winrate ``` ```bash @@ -101,6 +107,35 @@ medarc-eval process --config process-config.yaml CLI flags override config values. +Supported config schema for `medarc-eval process`: + +- Top-level `runs_dir`: raw run root. +- Top-level `process:`: process-specific defaults. +- Optional top-level `winrate:`: embedded post-process winrate step. +- Optional top-level `hf:`: shared HF settings. For embedded winrate uploads, use `hf.winrate_dir`. + +Path shortcuts: + +- `process.dir` is shorthand for `process.output_dir`, resolved relative to the parent of `runs_dir`. +- `winrate.dir` is shorthand for the embedded winrate output directory, resolved under the processed output dir. + +Example: + +```yaml +runs_dir: runs/raw + +process: + dir: processed + max_workers: 8 + +winrate: + dir: scorecards + +hf: + repo: your-org/medical-benchmarks + winrate_dir: scorecards/latest +``` + ## Hugging Face Integration Sync processed datasets to/from the Hugging Face Hub: @@ -108,7 +143,8 @@ Sync processed datasets to/from the Hugging Face Hub: ```yaml # process-config.yaml runs_dir: runs/raw -output_dir: runs/processed +process: + dir: processed hf: repo: your-org/medical-benchmarks @@ -139,10 +175,10 @@ When `--hf-repo` is set, processed files are automatically uploaded after comple Process and compute win rates in one step: ```bash -medarc-eval process --winrate winrate-config.yaml +medarc-eval process --config process-config.yaml ``` -This runs `medarc-eval winrate` automatically after processing completes. +This runs `medarc-eval winrate` automatically after processing completes when the config contains a `winrate:` section. ## Example Workflows @@ -180,6 +216,23 @@ medarc-eval process # env_index.json tracks what's already processed ``` +### Replace Existing Outputs + +Rebuild existing outputs for specific models or datasets without using `--clean`: + +```bash +# Rebuild every processed dataset for one model +medarc-eval process --replace-model gpt-4o + +# Rebuild every model for one dataset +medarc-eval process --replace-env medqa + +# Rebuild only the intersection +medarc-eval process --replace-model gpt-4o --replace-env medqa +``` + +When both flags are present, processing only rebuilds outputs that match both filters. + ## Troubleshooting ### "No runs found" @@ -193,6 +246,26 @@ Check that: By default, only jobs with `completed` status are included. Use `--process-incomplete` to include partial results. +### Integrity-check failures for existing parquet files + +If processing stops with an error like: + +```text +Existing processed output ... has N parquet rows but env_index.json records M. +``` + +the local processed snapshot is inconsistent. Fix it by rebuilding the affected output: + +```bash +medarc-eval process --replace-model gpt-4o --replace-env medqa +``` + +Or rebuild everything: + +```bash +medarc-eval process --clean --yes +``` + ## Next Steps After processing, [compute win rates](medarc-eval-winrate.md) to compare model performance. diff --git a/docs/medarc-eval-winrate.md b/docs/medarc-eval-winrate.md index 2882c0a5..49d527c3 100644 --- a/docs/medarc-eval-winrate.md +++ b/docs/medarc-eval-winrate.md @@ -12,7 +12,7 @@ medarc-eval winrate medarc-eval winrate --list-models # Specify directories -medarc-eval winrate --processed-dir runs/processed --output-dir runs/winrate +medarc-eval winrate --processed-dir runs/processed --output-dir runs/processed/winrate ``` ## Prerequisites @@ -37,7 +37,7 @@ The final win rate aggregates across all benchmarks using configurable weighting ## Output Files ``` -runs/winrate/ +runs/processed/winrate/ ├── winrates-2026-01-14T12-00-00.json # Timestamped results ├── winrates-2026-01-14T12-00-00.csv # Spreadsheet-friendly ├── latest.json # Always points to newest @@ -95,30 +95,35 @@ The JSON output includes: ## Using a Config File ```yaml -# winrate-config.yaml -processed_dir: runs/processed -output_dir: runs/winrate - -# Calculation settings -missing_policy: neg-inf -epsilon: 1.0e-9 -min_common: 10 -weight_policy: ln - -# Model filtering -exclude_model: - - baseline-model - - deprecated-v1 - -# Dataset filtering -exclude_datasets: - - med_dialog +# process-config.yaml +runs_dir: runs/raw + +process: + dir: processed + +winrate: + dir: winrate + missing_policy: neg-inf + epsilon: 1.0e-9 + min_common: 10 + weight_policy: ln + exclude_model: + - baseline-model + - deprecated-v1 + exclude_datasets: + - med_dialog ``` ```bash -medarc-eval winrate --config winrate-config.yaml +medarc-eval winrate --config process-config.yaml ``` +Supported config schema for `medarc-eval winrate`: + +- Top-level `process:` can provide `dir` or `output_dir`; this becomes the default `processed_dir`. +- Top-level `winrate:` provides winrate-specific defaults. +- Top-level `hf:` provides shared HF settings. Use `hf.winrate_dir` to control where winrate artifacts upload inside the repo. + ## Example Workflows ### Compare Specific Models @@ -179,12 +184,16 @@ medarc-eval winrate \ ### Full Config with HF ```yaml -# winrate-config.yaml -processed_dir: runs/processed -output_dir: runs/winrate +# process-config.yaml +runs_dir: runs/raw -missing_policy: neg-inf -weight_policy: ln +process: + dir: processed + +winrate: + dir: winrate + missing_policy: neg-inf + weight_policy: ln hf: repo: your-org/processed-data # Pull processed from here; upload winrate here @@ -194,6 +203,8 @@ hf: private: true ``` +`hf.winrate_dir` and `--hf-winrate-dir` both set the path inside the HF repo where `latest.json`, `latest.csv`, and timestamped winrate outputs are uploaded. + ## Interpreting Results ### Win Rate Table (CSV) diff --git a/docs/medarc-eval.md b/docs/medarc-eval.md index a9e48a39..395d251f 100644 --- a/docs/medarc-eval.md +++ b/docs/medarc-eval.md @@ -27,7 +27,7 @@ medarc-eval winrate (bench or single) (process) (winrate) | | | v v v - runs/raw/ runs/processed/ runs/winrate/ + runs/raw/ runs/processed/ runs/processed/winrate/ ``` ## Commands diff --git a/docs/medarc-verifiers-architecture.md b/docs/medarc-verifiers-architecture.md index 2eadd5d7..568a8c20 100644 --- a/docs/medarc-verifiers-architecture.md +++ b/docs/medarc-verifiers-architecture.md @@ -16,7 +16,7 @@ At a high level, everything funnels into a three-stage workflow: 1. **Run** evals (single or batch) → `runs/raw//...` 2. **Process** raw outputs → `runs/processed//.parquet` + `env_index.json` -3. **Winrate** on processed outputs → `runs/winrate/*.json` and `*.csv` +3. **Winrate** on processed outputs → `runs/processed/winrate/*.json` and `*.csv` ## Important side effects (auto-installed patches) diff --git a/medarc_verifiers/cli/_constants.py b/medarc_verifiers/cli/_constants.py index 41a840dd..a466e47b 100644 --- a/medarc_verifiers/cli/_constants.py +++ b/medarc_verifiers/cli/_constants.py @@ -20,4 +20,4 @@ DEFAULT_ENV_CONFIG_ROOT = Path("configs") / "envs" DEFAULT_RUNS_RAW_DIR = Path("runs") / "raw" DEFAULT_PROCESSED_DIR = Path("runs") / "processed" -DEFAULT_WINRATE_DIR = Path("runs") / "winrate" +DEFAULT_WINRATE_DIR = DEFAULT_PROCESSED_DIR / "winrate" diff --git a/medarc_verifiers/cli/_manifest_tools.py b/medarc_verifiers/cli/_manifest_tools.py index 5ba9effc..836fd9d2 100644 --- a/medarc_verifiers/cli/_manifest_tools.py +++ b/medarc_verifiers/cli/_manifest_tools.py @@ -2,14 +2,16 @@ from __future__ import annotations +import os import json import logging +import sys +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from pathlib import Path -from typing import Sequence +from typing import Any, Mapping, Sequence from medarc_verifiers.cli._manifest import MANIFEST_FILENAME, RunManifestModel, SUPPORTED_MANIFEST_VERSIONS -from medarc_verifiers.cli.utils.shared import count_jsonl_rows logger = logging.getLogger(__name__) @@ -41,90 +43,179 @@ def validate_manifests_in_runs(runs_dir: Path | str, *, strict: bool = False) -> if not runs_path.exists(): return ManifestValidationResult(manifests_checked=0, jobs_checked=0, issues=[]) - for run_dir in sorted(path for path in runs_path.iterdir() if path.is_dir()): - manifest_path = run_dir / MANIFEST_FILENAME - if not manifest_path.exists(): - continue - manifests_checked += 1 - try: - payload = json.loads(manifest_path.read_text(encoding="utf-8")) - except Exception as exc: # noqa: BLE001 - issues.append( + run_dirs = sorted(path for path in runs_path.iterdir() if path.is_dir()) + logger.info("Scanning manifests under %s...", runs_path) + + manifest_run_dirs = [run_dir for run_dir in run_dirs if (run_dir / MANIFEST_FILENAME).exists()] + if not manifest_run_dirs: + return ManifestValidationResult(manifests_checked=0, jobs_checked=0, issues=[]) + + max_workers = min(len(manifest_run_dirs), max(1, (os.cpu_count() or 4) * 4)) + if max_workers <= 1: + results = [_validate_run_dir(run_dir, strict=strict) for run_dir in manifest_run_dirs] + else: + results = list(_validate_run_dirs_parallel(manifest_run_dirs, strict=strict, max_workers=max_workers)) + + for result in results: + manifests_checked += result.manifests_checked + jobs_checked += result.jobs_checked + issues.extend(result.issues) + + issues.sort(key=lambda item: (item.run_id, item.job_id, item.kind, item.message)) + return ManifestValidationResult(manifests_checked=manifests_checked, jobs_checked=jobs_checked, issues=issues) + + +def _validate_run_dirs_parallel( + run_dirs: Sequence[Path], + *, + strict: bool, + max_workers: int, +) -> list[ManifestValidationResult]: + results: list[ManifestValidationResult] = [] + progress, task_id = _create_manifest_scan_progress(len(run_dirs)) + executor: ThreadPoolExecutor | None = None + futures = [] + try: + executor = ThreadPoolExecutor(max_workers=max_workers) + futures = [executor.submit(_validate_run_dir, run_dir, strict=strict) for run_dir in run_dirs] + if progress is not None and task_id is not None: + with progress: + for future in as_completed(futures): + results.append(future.result()) + progress.update(task_id, advance=1) + else: + for future in as_completed(futures): + results.append(future.result()) + except KeyboardInterrupt: + logger.warning("Manifest scanning interrupted; cancelling validation workers.") + for future in futures: + future.cancel() + if executor is not None: + executor.shutdown(wait=False, cancel_futures=True) + executor = None + raise + finally: + if executor is not None: + executor.shutdown(wait=True, cancel_futures=False) + return results + + +def _create_manifest_scan_progress(total: int) -> tuple[object | None, object | None]: + if total <= 0 or not sys.stderr.isatty(): + return None, None + try: + from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn, TimeElapsedColumn + + progress = Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + transient=True, + ) + task_id = progress.add_task("Scanning manifests", total=total) + return progress, task_id + except Exception: + return None, None + + +def _validate_run_dir(run_dir: Path, *, strict: bool) -> ManifestValidationResult: + issues: list[ManifestValidationIssue] = [] + manifest_path = run_dir / MANIFEST_FILENAME + if not manifest_path.exists(): + return ManifestValidationResult(manifests_checked=0, jobs_checked=0, issues=[]) + + try: + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + except Exception as exc: # noqa: BLE001 + return ManifestValidationResult( + manifests_checked=1, + jobs_checked=0, + issues=[ ManifestValidationIssue( run_id=run_dir.name, job_id="", kind="error", message=f"Failed to parse manifest: {exc}", ) - ) - continue + ], + ) - version = payload.get("version") - if version not in SUPPORTED_MANIFEST_VERSIONS: - issues.append( + version = payload.get("version") + if version not in SUPPORTED_MANIFEST_VERSIONS: + return ManifestValidationResult( + manifests_checked=1, + jobs_checked=0, + issues=[ ManifestValidationIssue( run_id=run_dir.name, job_id="", kind="error", message=f"Unsupported manifest version: {version}", ) + ], + ) + + model = RunManifestModel.model_validate(payload) + artifacts_root = str(getattr(model, "artifacts_root", ".") or ".") + jobs_checked = 0 + + for entry in model.jobs: + jobs_checked += 1 + results_path, metadata_path, used_fallback = _resolve_job_artifact_paths( + run_dir=run_dir, + artifacts_root=artifacts_root, + job_id=entry.job_id, + results_relpath=entry.results_relpath, + metadata_relpath=entry.metadata_relpath, + ) + if used_fallback: + issues.append( + ManifestValidationIssue( + run_id=model.run_id, + job_id=entry.job_id, + kind="warning", + message="Manifest artifact path missing; fallback to run-relative job directory would be used.", + ) ) - continue - model = RunManifestModel.model_validate(payload) - artifacts_root = str(getattr(model, "artifacts_root", ".") or ".") - - for entry in model.jobs: - jobs_checked += 1 - results_path, metadata_path, used_fallback = _resolve_job_artifact_paths( - run_dir=run_dir, - artifacts_root=artifacts_root, - job_id=entry.job_id, - results_relpath=entry.results_relpath, - metadata_relpath=entry.metadata_relpath, - ) - if used_fallback: - issues.append( - ManifestValidationIssue( - run_id=model.run_id, - job_id=entry.job_id, - kind="warning", - message="Manifest artifact path missing; fallback to run-relative job directory would be used.", - ) + if not results_path.exists(): + kind = "error" if strict else "warning" + issues.append( + ManifestValidationIssue( + run_id=model.run_id, + job_id=entry.job_id, + kind=kind, + message=f"Missing results.jsonl at {results_path}", ) - if not results_path.exists(): + ) + if results_path.exists(): + for message in _quick_validate_results_jsonl( + results_path, + num_examples=entry.num_examples, + rollouts_per_example=entry.rollouts_per_example, + ): kind = "error" if strict else "warning" issues.append( ManifestValidationIssue( run_id=model.run_id, job_id=entry.job_id, kind=kind, - message=f"Missing results.jsonl at {results_path}", + message=message, ) ) - if entry.row_count is not None and results_path.exists(): - row_count = count_jsonl_rows(results_path) - if row_count is not None and int(row_count) != int(entry.row_count): - kind = "error" if strict else "warning" - issues.append( - ManifestValidationIssue( - run_id=model.run_id, - job_id=entry.job_id, - kind=kind, - message=f"row_count mismatch: manifest={entry.row_count} actual={row_count}", - ) - ) - # metadata is optional; only flag when declared explicitly in v3. - if entry.metadata_relpath and not metadata_path.exists(): - kind = "error" if strict else "warning" - issues.append( - ManifestValidationIssue( - run_id=model.run_id, - job_id=entry.job_id, - kind=kind, - message=f"Missing metadata.json at {metadata_path}", - ) + if entry.metadata_relpath and not metadata_path.exists(): + kind = "error" if strict else "warning" + issues.append( + ManifestValidationIssue( + run_id=model.run_id, + job_id=entry.job_id, + kind=kind, + message=f"Missing metadata.json at {metadata_path}", ) - return ManifestValidationResult(manifests_checked=manifests_checked, jobs_checked=jobs_checked, issues=issues) + ) + + return ManifestValidationResult(manifests_checked=1, jobs_checked=jobs_checked, issues=issues) def _resolve_job_artifact_paths( @@ -153,6 +244,132 @@ def _resolve_job_artifact_paths( return results_path, metadata_path, used_fallback +def _quick_validate_results_jsonl( + path: Path, + *, + num_examples: int | None, + rollouts_per_example: int | None, +) -> list[str]: + first_line = _read_first_nonempty_line(path) + last_line = _read_last_nonempty_line(path) + if first_line is None or last_line is None: + return [f"results.jsonl at {path} is empty"] + + issues: list[str] = [] + first_payload = _decode_probe_line(first_line, path=path, position="first", issues=issues) + last_payload = _decode_probe_line(last_line, path=path, position="last", issues=issues) + if first_payload is None or last_payload is None: + return issues + + for position, payload in (("first", first_payload), ("last", last_payload)): + if "example_id" not in payload: + issues.append(f"{position} JSONL row in {path} is missing example_id") + _validate_rollout_index( + first_payload, + path=path, + position="first", + rollouts_per_example=rollouts_per_example, + issues=issues, + ) + _validate_rollout_index( + last_payload, + path=path, + position="last", + rollouts_per_example=rollouts_per_example, + issues=issues, + ) + + return issues + + +def _decode_probe_line( + raw_line: str, + *, + path: Path, + position: str, + issues: list[str], +) -> Mapping[str, Any] | None: + try: + payload = json.loads(raw_line) + except json.JSONDecodeError as exc: + issues.append(f"failed to parse {position} JSONL row in {path}: {exc.msg}") + return None + if not isinstance(payload, Mapping): + issues.append(f"{position} JSONL row in {path} is not a JSON object") + return None + return payload + + +def _read_first_nonempty_line(path: Path) -> str | None: + with path.open("r", encoding="utf-8") as handle: + for line in handle: + candidate = line.strip() + if candidate: + return candidate + return None + + +def _read_last_nonempty_line(path: Path) -> str | None: + with path.open("rb") as handle: + handle.seek(0, os.SEEK_END) + file_size = handle.tell() + if file_size <= 0: + return None + + chunk_size = 8192 + buffer = b"" + position = file_size + while position > 0: + read_size = min(chunk_size, position) + position -= read_size + handle.seek(position) + buffer = handle.read(read_size) + buffer + lines = buffer.splitlines() + for raw_line in reversed(lines): + candidate = raw_line.strip() + if candidate: + return candidate.decode("utf-8") + return None + + +def _validate_rollout_index( + payload: Mapping[str, Any], + *, + path: Path, + position: str, + rollouts_per_example: int | None, + issues: list[str], +) -> None: + rollout_index = _coerce_int(payload.get("rollout_index")) + if rollout_index is None: + return + if rollout_index < 0: + issues.append(f"{position} JSONL row in {path} has negative rollout_index={payload.get('rollout_index')!r}") + return + if rollouts_per_example and rollout_index >= rollouts_per_example: + issues.append( + f"{position} JSONL row in {path} has out-of-range rollout_index={payload.get('rollout_index')!r}; " + f"expected < {rollouts_per_example}" + ) + + +def _coerce_int(value: Any) -> int | None: + if value is None or isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float): + if value.is_integer(): + return int(value) + return None + if isinstance(value, str): + try: + return int(value.strip()) + except ValueError: + return None + return None + + def format_validation_issues(issues: Sequence[ManifestValidationIssue]) -> list[str]: lines: list[str] = [] for issue in issues: diff --git a/medarc_verifiers/cli/hf/sync.py b/medarc_verifiers/cli/hf/sync.py index 48ed64a5..0ae54fb7 100644 --- a/medarc_verifiers/cli/hf/sync.py +++ b/medarc_verifiers/cli/hf/sync.py @@ -172,11 +172,14 @@ def sync_files_to_hub( is_tty: bool = False, assume_yes: bool = False, prompt_func: Callable[[str], str] | None = None, -) -> None: - """Upload explicit file paths from output_dir to a HF dataset repo.""" +) -> bool: + """Upload explicit file paths from output_dir to a HF dataset repo. + + Returns False only when upload is skipped because repo creation was declined. + """ if not repo_id: logger.debug("HF sync skipped: no repo_id provided.") - return + return True file_list = [] for path in files: rel_path = Path(path).as_posix() if not isinstance(path, str) else Path(path).as_posix() @@ -184,10 +187,10 @@ def sync_files_to_hub( file_list.append(rel_path) if not file_list: logger.debug("HF sync skipped: no files provided.") - return + return True if dry_run: logger.debug("HF sync dry-run; skipping push.") - return + return True try: from huggingface_hub import CommitOperationAdd, HfApi # type: ignore[import-not-found] @@ -241,9 +244,11 @@ def sync_files_to_hub( prompt_func=prompt_func, ) if not should_create: - raise RuntimeError( - f"HF dataset repo '{repo_id}' not found. Create it on the Hub or re-run with --yes to allow creation." - ) from exc + logger.warning( + "HF dataset repo '%s' not found; skipping upload because repo creation was declined.", + repo_id, + ) + return False api.create_repo( repo_id=repo_id, repo_type="dataset", @@ -269,6 +274,7 @@ def sync_files_to_hub( delay, ) time.sleep(delay) + return True def _normalize_repo_path_prefix(value: str | None) -> str | None: @@ -345,7 +351,7 @@ def sync_to_hub( ) message = f"Update {summary.total_files} file(s) from medarc-eval process" - sync_files_to_hub( + uploaded = sync_files_to_hub( repo_id=config.repo_id, output_dir=output_dir, files=files, @@ -361,6 +367,8 @@ def sync_to_hub( assume_yes=assume_yes, prompt_func=prompt_func, ) + if not uploaded: + return None return summary diff --git a/tests/test_cli/test_main.py b/tests/test_cli/test_main.py index 8eb636cf..f5f9c5bd 100644 --- a/tests/test_cli/test_main.py +++ b/tests/test_cli/test_main.py @@ -1859,6 +1859,7 @@ def test_winrate_cli_applies_config_defaults(monkeypatch: pytest.MonkeyPatch, tm repo: medarc/demo branch: main token: secret-token + winrate_dir: scorecards/latest """, encoding="utf-8", ) @@ -1910,7 +1911,7 @@ def fake_sync_files_to_hub(**kwargs): upload = captured.get("upload") assert upload is not None assert upload["repo_id"] == "medarc/demo" - assert upload["path_in_repo_prefix"] == "winrate" + assert upload["path_in_repo_prefix"] == "scorecards/latest" exit_code = main.main( [ @@ -1928,6 +1929,61 @@ def fake_sync_files_to_hub(**kwargs): assert cfg.epsilon == pytest.approx(0.5) +def test_expand_embedded_process_config_promotes_process_section() -> None: + payload = { + "runs_dir": "runs/raw", + "process": { + "dir": "processed", + "max_workers": 8, + "replace_models": ["model-a"], + }, + "winrate": {"dir": "scorecards"}, + } + + expanded = main._expand_embedded_pipeline_config(payload, mode="process") + + assert expanded["runs_dir"] == "runs/raw" + assert expanded["output_dir"] == Path("runs/processed") + assert expanded["max_workers"] == 8 + assert expanded["replace_models"] == ["model-a"] + assert "winrate" not in expanded + assert payload["process"]["dir"] == "processed" + + +def test_expand_embedded_winrate_config_resolves_relative_dirs() -> None: + payload = { + "runs_dir": "artifacts/raw", + "process": {"dir": "processed"}, + "winrate": { + "dir": "scorecards", + "missing_policy": "zero", + "hf_winrate_dir": "uploads/winrate", + }, + } + + expanded = main._expand_embedded_pipeline_config(payload, mode="winrate") + + assert expanded["processed_dir"] == Path("artifacts/processed") + assert expanded["output_dir"] == Path("artifacts/processed/scorecards") + assert expanded["missing_policy"] == "zero" + assert expanded["hf_winrate_dir"] == "uploads/winrate" + + +def test_expand_embedded_winrate_config_keeps_explicit_dirs() -> None: + payload = { + "processed_dir": "custom/processed", + "output_dir": "custom/winrate", + "runs_dir": "artifacts/raw", + "process": {"dir": "processed"}, + "winrate": {"dir": "scorecards"}, + } + + expanded = main._expand_embedded_pipeline_config(payload, mode="winrate") + + assert expanded["processed_dir"] == "custom/processed" + assert expanded["output_dir"] == "custom/winrate" + + def test_process_cli_requires_winrate_config_path(tmp_path: Path) -> None: missing_path = tmp_path / "missing.yaml" with pytest.raises(SystemExit): diff --git a/tests/test_cli/test_manifest_tools.py b/tests/test_cli/test_manifest_tools.py index 3a813b1d..4274fb1e 100644 --- a/tests/test_cli/test_manifest_tools.py +++ b/tests/test_cli/test_manifest_tools.py @@ -11,6 +11,43 @@ def _write_json(path: Path, payload: dict) -> None: path.write_text(json.dumps(payload), encoding="utf-8") +def _write_manifest( + run_dir: Path, + *, + num_examples: int | None = None, + rollouts_per_example: int | None = None, +) -> None: + payload = { + "version": 3, + "run_id": "demo-run", + "name": "demo", + "config_source": "cfg.yaml", + "config_checksum": "x", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + "artifacts_root": ".", + "models": {}, + "env_templates": {}, + "jobs": [ + { + "job_id": "job-1", + "model_id": "m", + "env_id": "e", + "env_template_id": "e:t", + "env_variant_id": "e", + "env_args": {}, + "results_relpath": "job-1/results.jsonl", + "metadata_relpath": "job-1/metadata.json", + "status": "completed", + "num_examples": num_examples, + "rollouts_per_example": rollouts_per_example, + } + ], + "summary": {"total": 1, "completed": 1, "pending": 0, "failed": 0, "running": 0, "skipped": 0}, + } + _write_json(run_dir / "run_manifest.json", payload) + + def test_validate_manifests_reports_broken_paths(tmp_path: Path) -> None: runs_dir = tmp_path / "runs" / "raw" run_dir = runs_dir / "demo-run" @@ -49,3 +86,66 @@ def test_validate_manifests_reports_broken_paths(tmp_path: Path) -> None: assert result.manifests_checked == 1 assert result.jobs_checked == 1 assert any(issue.kind == "warning" and "fallback" in issue.message.lower() for issue in result.issues) + + +def test_validate_manifests_accepts_partial_rollout_file(tmp_path: Path) -> None: + runs_dir = tmp_path / "runs" / "raw" + run_dir = runs_dir / "demo-run" + job_dir = run_dir / "job-1" + _write_json(job_dir / "metadata.json", {"env_id": "demo"}) + (job_dir / "results.jsonl").write_text( + "\n".join( + [ + '{"example_id": 1, "rollout_index": 0}', + '{"example_id": 2, "rollout_index": 0}', + '{"example_id": 1, "rollout_index": 1}', + '{"example_id": 2, "rollout_index": 1}', + '{"example_id": 1, "rollout_index": 2}', + ] + ) + + "\n", + encoding="utf-8", + ) + _write_manifest(run_dir, num_examples=2, rollouts_per_example=3) + + result = validate_manifests_in_runs(runs_dir, strict=False) + + assert result.manifests_checked == 1 + assert result.jobs_checked == 1 + assert result.issues == [] + + +def test_validate_manifests_reports_out_of_range_rollout_index(tmp_path: Path) -> None: + runs_dir = tmp_path / "runs" / "raw" + run_dir = runs_dir / "demo-run" + job_dir = run_dir / "job-1" + _write_json(job_dir / "metadata.json", {"env_id": "demo"}) + (job_dir / "results.jsonl").write_text( + "\n".join( + [ + '{"example_id": 1, "rollout_index": 0}', + '{"example_id": 2, "rollout_index": 0}', + '{"example_id": 1, "rollout_index": 3}', + ] + ) + + "\n", + encoding="utf-8", + ) + _write_manifest(run_dir, num_examples=2, rollouts_per_example=3) + + result = validate_manifests_in_runs(runs_dir, strict=False) + + assert any("out-of-range rollout_index" in issue.message for issue in result.issues) + + +def test_validate_manifests_reports_malformed_last_jsonl_row(tmp_path: Path) -> None: + runs_dir = tmp_path / "runs" / "raw" + run_dir = runs_dir / "demo-run" + job_dir = run_dir / "job-1" + _write_json(job_dir / "metadata.json", {"env_id": "demo"}) + (job_dir / "results.jsonl").write_text('{"example_id": 1}\n{"example_id": ', encoding="utf-8") + _write_manifest(run_dir, num_examples=1, rollouts_per_example=1) + + result = validate_manifests_in_runs(runs_dir, strict=False) + + assert any("failed to parse last JSONL row" in issue.message for issue in result.issues) diff --git a/tests/test_cli/test_process_hf_sync.py b/tests/test_cli/test_process_hf_sync.py index 0de114a4..d677f1aa 100644 --- a/tests/test_cli/test_process_hf_sync.py +++ b/tests/test_cli/test_process_hf_sync.py @@ -174,3 +174,61 @@ def create_commit(self, **_kwargs: object) -> None: assert captured.get("create_repo") is not None assert captured.get("create_commit") is True assert captured["create_commit_calls"] == 2 + + +def test_sync_files_to_hub_skips_when_repo_creation_declined( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + (tmp_path / "artifact.json").write_text("{}", encoding="utf-8") + + captured: dict[str, object] = {"create_commit_calls": 0} + + class FakeResponse: + status_code = 404 + + class FakeRepoNotFound(Exception): + def __init__(self) -> None: + super().__init__("Repository Not Found") + self.response = FakeResponse() + + class FakeOp: + def __init__(self, *args: object, **kwargs: object) -> None: + captured["op"] = (args, kwargs) + + class FakeApi: + def __init__(self, token: str | None = None) -> None: + captured["token"] = token + + def create_repo(self, **kwargs: object) -> None: + captured["create_repo"] = kwargs + + def create_commit(self, **_kwargs: object) -> None: + captured["create_commit_calls"] = int(captured["create_commit_calls"]) + 1 + raise FakeRepoNotFound() + + import sys + import types + + fake_module = types.SimpleNamespace(CommitOperationAdd=FakeOp, HfApi=FakeApi) + monkeypatch.setitem(sys.modules, "huggingface_hub", fake_module) + + with caplog.at_level("WARNING"): + uploaded = hf_sync.sync_files_to_hub( + repo_id="local/missing", + output_dir=tmp_path, + files=["artifact.json"], + token="secret-token", + private=True, + message="msg", + dry_run=False, + is_tty=True, + assume_yes=False, + prompt_func=lambda _prompt: "n", + ) + + assert uploaded is False + assert captured["create_commit_calls"] == 1 + assert captured.get("create_repo") is None + assert "skipping upload because repo creation was declined" in caplog.text From 29fe167807d5201bc72daa5e876fa2316cf6683e Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sun, 1 Mar 2026 13:17:41 -0500 Subject: [PATCH 11/29] Gate process runs by manifest missing pct --- docs/medarc-eval-process.md | 24 ++++++--- medarc_verifiers/cli/main.py | 29 ++++++----- medarc_verifiers/cli/process/__init__.py | 4 +- medarc_verifiers/cli/process/pipeline.py | 65 ++++++++++++++++-------- tests/test_cli/test_main.py | 60 ++++++++++++++++++++++ tests/test_cli/test_process_discovery.py | 61 +++++++++++++--------- 6 files changed, 175 insertions(+), 68 deletions(-) diff --git a/docs/medarc-eval-process.md b/docs/medarc-eval-process.md index a0fe66d2..6538ce30 100644 --- a/docs/medarc-eval-process.md +++ b/docs/medarc-eval-process.md @@ -53,16 +53,28 @@ runs/processed/ ### By Completion Status -By default, only completed jobs are processed: +By default, `medarc-eval process` only selects runs whose manifest status is one of: -```bash -# Include incomplete runs -medarc-eval process --process-incomplete +- `completed` +- `succeeded` +- `success` + +To override that default, pass one or more explicit status filters: -# Filter by specific status +```bash medarc-eval process --status completed --status failed ``` +You can also gate partially complete runs by their manifest summary totals: + +```bash +# Default tolerance is 2.5 percent missing +medarc-eval process --max-run-missing-pct 2.5 + +# Effectively disable the gate +medarc-eval process --max-run-missing-pct 100 +``` + ### Latest Runs Only When multiple runs exist for the same (model, environment) pair, processing uses the latest by default. @@ -90,7 +102,7 @@ runs_dir: runs/raw process: dir: processed max_workers: 8 - process_incomplete: false + max_run_missing_pct: 2.5 exclude_datasets: - med_dialog exclude_models: diff --git a/medarc_verifiers/cli/main.py b/medarc_verifiers/cli/main.py index d79b601b..f413c92e 100644 --- a/medarc_verifiers/cli/main.py +++ b/medarc_verifiers/cli/main.py @@ -35,7 +35,7 @@ from medarc_verifiers.cli._schemas import EnvironmentConfigSchema, EnvironmentExportConfig from medarc_verifiers.cli._single_run import run_single_mode from medarc_verifiers.cli.hf import HFSyncConfig, sync_files_to_hub -from medarc_verifiers.cli.process import ProcessOptions, ProcessResult, run_process +from medarc_verifiers.cli.process import PROCESS_DEFAULT_STATUS_FILTER, ProcessOptions, ProcessResult, run_process from medarc_verifiers.cli.utils.config_io import load_mapping_file from medarc_verifiers.cli.utils.overrides import build_cli_override from medarc_verifiers.cli.utils.shared import ( @@ -298,11 +298,10 @@ def build_process_parser() -> argparse.ArgumentParser: help="Rebuild existing processed outputs for these env ids (repeatable; comma-separated values allowed).", ) parser.add_argument( - "--process-incomplete", - dest="process_incomplete", - action="store_true", + "--max-run-missing-pct", + type=float, default=None, - help="Include runs where run_manifest.json summary has completed < total.", + help="Skip run directories whose manifest-level missing percentage exceeds this threshold (default: 2.5).", ) parser.add_argument( "--winrate", @@ -610,6 +609,10 @@ def _validate_process_args( normalize_dataset_ids(args.exclude_dataset, label="process exclude dataset") if args.exclude_model: normalize_model_ids(args.exclude_model, label="process exclude model") + if args.max_run_missing_pct is not None: + value = float(args.max_run_missing_pct) + if value < 0: + parser.error("--max-run-missing-pct must be non-negative.") except ValueError as exc: parser.error(str(exc)) @@ -625,15 +628,16 @@ def _build_process_options(args: argparse.Namespace) -> ProcessOptions: retries=args.hf_retries, max_files_per_commit=args.hf_max_files_per_commit, ) + status_filter = tuple(args.status) if args.status is not None else PROCESS_DEFAULT_STATUS_FILTER processed_with_args = { - "status": args.status or [], + "status": list(status_filter), + "max_run_missing_pct": float(args.max_run_missing_pct), "exclude_datasets": args.exclude_dataset or [], "exclude_models": args.exclude_model or [], "replace_models": args.replace_model or [], "replace_envs": args.replace_env or [], "dry_run": bool(args.dry_run), "clean": bool(args.clean), - "only_complete_runs": not bool(args.process_incomplete), "hf_repo": args.hf_repo, "hf_pull_policy": args.hf_pull_policy, "hf_request_timeout": args.hf_request_timeout, @@ -650,8 +654,8 @@ def _build_process_options(args: argparse.Namespace) -> ProcessOptions: replace_envs=tuple(args.replace_env or ()), processed_at=args.processed_at, processed_with_args=processed_with_args, - status_filter=args.status or (), - only_complete_runs=not bool(args.process_incomplete), + status_filter=status_filter, + max_run_missing_pct=float(args.max_run_missing_pct), dry_run=bool(args.dry_run), clean=bool(args.clean), assume_yes=bool(args.yes), @@ -844,8 +848,8 @@ def _merge_process_section( "dry_run": "dry_run", "clean": "clean", "yes": "yes", - "process_incomplete": "process_incomplete", "max_workers": "max_workers", + "max_run_missing_pct": "max_run_missing_pct", } ) for key, target in key_map.items(): @@ -1013,7 +1017,6 @@ def _load_and_apply_config( "dry_run": "dry_run", "clean": "clean", "yes": "yes", - "process_incomplete": "process_incomplete", "hf_private": "hf_private", }, "winrate": {"hf_processed_pull": "hf_processed_pull", "hf_private": "hf_private"}, @@ -1027,7 +1030,7 @@ def _load_and_apply_config( "winrate": {"min_common": "min_common", "weight_cap": "weight_cap"}, }[mode] float_fields = { - "process": {"hf_request_timeout": "hf_request_timeout"}, + "process": {"hf_request_timeout": "hf_request_timeout", "max_run_missing_pct": "max_run_missing_pct"}, "winrate": {"epsilon": "epsilon"}, }[mode] repeatable_fields = { @@ -1111,7 +1114,7 @@ def _finalize_config_args(args: argparse.Namespace, *, mode: Literal["process", "dry_run": False, "clean": False, "yes": False, - "process_incomplete": False, + "max_run_missing_pct": 2.5, "exclude_dataset": [], "exclude_model": [], "replace_model": [], diff --git a/medarc_verifiers/cli/process/__init__.py b/medarc_verifiers/cli/process/__init__.py index 6c20133e..35cb601d 100644 --- a/medarc_verifiers/cli/process/__init__.py +++ b/medarc_verifiers/cli/process/__init__.py @@ -1,5 +1,5 @@ """Process command pipeline for exporting MedARC runs.""" -from .pipeline import ProcessOptions, ProcessResult, run_process +from .pipeline import PROCESS_DEFAULT_STATUS_FILTER, ProcessOptions, ProcessResult, run_process -__all__ = ["ProcessOptions", "ProcessResult", "run_process"] +__all__ = ["PROCESS_DEFAULT_STATUS_FILTER", "ProcessOptions", "ProcessResult", "run_process"] diff --git a/medarc_verifiers/cli/process/pipeline.py b/medarc_verifiers/cli/process/pipeline.py index 2e1c46be..1a94453e 100644 --- a/medarc_verifiers/cli/process/pipeline.py +++ b/medarc_verifiers/cli/process/pipeline.py @@ -27,6 +27,7 @@ ) logger = logging.getLogger(__name__) +PROCESS_DEFAULT_STATUS_FILTER: tuple[str, ...] = ("completed", "succeeded", "success") @dataclass(slots=True) @@ -35,7 +36,7 @@ class ProcessOptions: runs_dir: Path output_dir: Path - only_complete_runs: bool = True + max_run_missing_pct: float = 2.5 exclude_datasets: Sequence[str] = field(default_factory=tuple) exclude_models: Sequence[str] = field(default_factory=tuple) replace_models: Sequence[str] = field(default_factory=tuple) @@ -53,6 +54,7 @@ class ProcessOptions: def __post_init__(self) -> None: self.runs_dir = Path(self.runs_dir) self.output_dir = Path(self.output_dir) + self.max_run_missing_pct = float(self.max_run_missing_pct) self.max_workers = max(1, int(self.max_workers)) if not self.processed_at: self.processed_at = datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z") @@ -117,7 +119,7 @@ class SelectionResult: """Complete output of the selection phase.""" work_items: list[PlannedWorkItem] - skipped_incomplete: int + skipped_by_missing_pct: int skipped_by_delta: int skipped_by_exclusion: int total_discovered: int @@ -159,10 +161,10 @@ def _run_pipeline() -> ProcessResult: _print_records_table( discovered, selected_records, - options.only_complete_runs, + options.max_run_missing_pct, exclude_datasets=options.exclude_datasets, exclude_models=options.exclude_models, - skipped_incomplete=selection.skipped_incomplete, + skipped_by_missing_pct=selection.skipped_by_missing_pct, skipped_by_delta=selection.skipped_by_delta, skipped_by_exclusion=selection.skipped_by_exclusion, ) @@ -299,10 +301,10 @@ def select_work_items( ) -> SelectionResult: """Filter discovered runs down to selected work items before row loading begins.""" eligible_records: list[discovery.RunRecord] = [] - skipped_incomplete = 0 + skipped_by_missing_pct = 0 for record in discovered: - if options.only_complete_runs and not _manifest_is_complete(record.manifest): - skipped_incomplete += 1 + if not _manifest_within_missing_pct(record.manifest, options.max_run_missing_pct): + skipped_by_missing_pct += 1 continue eligible_records.append(record) @@ -320,7 +322,7 @@ def select_work_items( return SelectionResult( work_items=work_items, - skipped_incomplete=skipped_incomplete, + skipped_by_missing_pct=skipped_by_missing_pct, skipped_by_delta=skipped_by_delta, skipped_by_exclusion=skipped_by_exclusion, total_discovered=len(discovered), @@ -571,11 +573,11 @@ def _validate_existing_output_integrity( def _print_records_table( discovered: Sequence[discovery.RunRecord], selected: Sequence[discovery.RunRecord], - only_complete_runs: bool, + max_run_missing_pct: float, *, exclude_datasets: Sequence[str] = (), exclude_models: Sequence[str] = (), - skipped_incomplete: int = 0, + skipped_by_missing_pct: int = 0, skipped_by_delta: int = 0, skipped_by_exclusion: int = 0, ) -> None: @@ -585,7 +587,7 @@ def _print_records_table( eligible_discovered = [ rec for rec in discovered - if (not only_complete_runs or _manifest_is_complete(rec.manifest)) + if _manifest_within_missing_pct(rec.manifest, max_run_missing_pct) and not (exclude_set and _record_is_excluded(rec, exclude_set)) and not (exclude_model_set and _record_model_is_excluded(rec, exclude_model_set)) ] @@ -612,16 +614,15 @@ def _print_records_table( from rich.markup import escape from rich.table import Table except Exception: - suffix = " (complete runs only)" if only_complete_runs else "" logger.info( - "Processing %d job(s) across %d model(s)%s (found %d job(s) across %d model(s)); " - "skipped incomplete=%d excluded=%d existing=%d.", + "Processing %d job(s) across %d model(s) (max_run_missing_pct=%s; found %d job(s) across %d model(s)); " + "skipped by missing pct=%d excluded=%d existing=%d.", selected_jobs_total, len(selected_models), - suffix, + _format_missing_pct(max_run_missing_pct), discovered_jobs_total, len(models), - skipped_incomplete, + skipped_by_missing_pct, skipped_by_exclusion, skipped_by_delta, ) @@ -633,11 +634,12 @@ def _print_records_table( return console = Console() - title = f"Processing {selected_jobs_total} job(s) across {len(selected_models)} model(s)" - if only_complete_runs: - title += " (complete runs only)" + title = ( + f"Processing {selected_jobs_total} job(s) across {len(selected_models)} model(s) " + f"[dim](max_run_missing_pct={_format_missing_pct(max_run_missing_pct)})[/dim]" + ) title += ( - f" [dim](found {discovered_jobs_total} eligible job(s); skipped incomplete={skipped_incomplete}, " + f" [dim](found {discovered_jobs_total} eligible job(s); skipped by missing pct={skipped_by_missing_pct}, " f"excluded={skipped_by_exclusion}, existing={skipped_by_delta})[/dim]" ) table = Table(title=title, show_header=True, header_style="bold cyan", caption=None) @@ -654,8 +656,26 @@ def _print_records_table( console.print(table) -def _manifest_is_complete(manifest: discovery.RunManifestInfo) -> bool: - return not (manifest.summary_total_known and manifest.summary_completed != manifest.summary_total) +def _manifest_missing_pct(manifest: discovery.RunManifestInfo) -> float | None: + if not manifest.summary_total_known: + return None + total = int(manifest.summary_total or 0) + if total <= 0: + return None + completed = max(int(manifest.summary_completed or 0), 0) + missing = max(total - completed, 0) + return 100.0 * missing / total + + +def _manifest_within_missing_pct(manifest: discovery.RunManifestInfo, max_missing_pct: float) -> bool: + missing_pct = _manifest_missing_pct(manifest) + if missing_pct is None: + return True + return missing_pct <= float(max_missing_pct) + + +def _format_missing_pct(value: float) -> str: + return f"{float(value):g}" def _record_is_excluded(record: discovery.RunRecord, exclude_set: set[str]) -> bool: @@ -722,6 +742,7 @@ def _run_sort_key(timestamp: str, job_run_id: str) -> tuple[int, datetime, str]: __all__ = [ + "PROCESS_DEFAULT_STATUS_FILTER", "PlannedRecord", "PlannedWorkItem", "ProcessOptions", diff --git a/tests/test_cli/test_main.py b/tests/test_cli/test_main.py index f5f9c5bd..b56c0f90 100644 --- a/tests/test_cli/test_main.py +++ b/tests/test_cli/test_main.py @@ -2000,6 +2000,65 @@ def test_process_cli_requires_winrate_config_path(tmp_path: Path) -> None: ) +def test_process_cli_defaults_status_filter_to_completed(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + captured: dict[str, Any] = {} + + def fake_run_process(options, env_export_map): + captured["options"] = options + return ProcessResult(records_processed=0, rows_processed=0, env_groups=[], env_summaries=[], hf_summary=None) + + monkeypatch.setattr(main, "run_process", fake_run_process) + + exit_code = main.main( + [ + "process", + "--runs-dir", + str(tmp_path / "runs"), + "--output-dir", + str(tmp_path / "processed"), + "--dry-run", + ] + ) + + assert exit_code == 0 + options = captured["options"] + assert options.status_filter == ("completed", "succeeded", "success") + assert options.processed_with_args["status"] == ["completed", "succeeded", "success"] + assert options.max_run_missing_pct == pytest.approx(2.5) + assert options.processed_with_args["max_run_missing_pct"] == pytest.approx(2.5) + + +def test_process_cli_uses_explicit_status_filter(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + captured: dict[str, Any] = {} + + def fake_run_process(options, env_export_map): + captured["options"] = options + return ProcessResult(records_processed=0, rows_processed=0, env_groups=[], env_summaries=[], hf_summary=None) + + monkeypatch.setattr(main, "run_process", fake_run_process) + + exit_code = main.main( + [ + "process", + "--runs-dir", + str(tmp_path / "runs"), + "--output-dir", + str(tmp_path / "processed"), + "--status", + "failed", + "--max-run-missing-pct", + "100", + "--dry-run", + ] + ) + + assert exit_code == 0 + options = captured["options"] + assert options.status_filter == ("failed",) + assert options.processed_with_args["status"] == ["failed"] + assert options.max_run_missing_pct == pytest.approx(100.0) + + def test_process_cli_runs_embedded_winrate_post_step(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: cfg_path = tmp_path / "process.yaml" cfg_path.write_text( @@ -2205,6 +2264,7 @@ def test_process_cli_rejects_include_prompt_completion(tmp_path: Path) -> None: ("field", "value"), [ ("max_workers", "not-an-int"), + ("max_run_missing_pct", "not-a-float"), ("hf_request_timeout", "not-a-float"), ("hf_retries", "not-an-int"), ("hf_max_files_per_commit", "not-an-int"), diff --git a/tests/test_cli/test_process_discovery.py b/tests/test_cli/test_process_discovery.py index 5b532fcd..7ffa93db 100644 --- a/tests/test_cli/test_process_discovery.py +++ b/tests/test_cli/test_process_discovery.py @@ -3,7 +3,8 @@ import json from pathlib import Path -from medarc_verifiers.cli.process.discovery import discover_run_records +from medarc_verifiers.cli.process.discovery import RunManifestInfo, discover_run_records +from medarc_verifiers.cli.process.pipeline import _manifest_missing_pct, _manifest_within_missing_pct def _write_json(path: Path, payload: dict) -> None: @@ -33,6 +34,23 @@ def _base_manifest( } +def _manifest_info(*, completed: int, total: int, total_known: bool) -> RunManifestInfo: + return RunManifestInfo( + job_run_id="job-run-123", + run_name="example-run", + summary_completed=completed, + summary_total=total, + summary_total_known=total_known, + manifest_path=Path("/tmp/run_manifest.json"), + run_dir=Path("/tmp/job-run-123"), + created_at="2024-01-01T00:00:00Z", + updated_at="2024-01-01T00:05:00Z", + config_source="configs/example.yaml", + config_checksum="abc123", + run_summary_path=Path("/tmp/run_summary.json"), + ) + + def test_discover_run_records_basic(tmp_path: Path) -> None: runs_dir = tmp_path / "runs" run_dir = runs_dir / "job-run-123" @@ -129,32 +147,25 @@ def test_discover_run_records_filters_status(tmp_path: Path) -> None: assert filtered_none == [] -def test_discover_run_records_only_complete_runs_missing_total(tmp_path: Path) -> None: - runs_dir = tmp_path / "runs" - run_dir = runs_dir / "job-run-123" - results_dir = run_dir / "model-env-job" +def test_manifest_missing_pct_skips_when_above_threshold() -> None: + manifest = _manifest_info(completed=97, total=100, total_known=True) - manifest_payload = _base_manifest( - [ - { - "job_id": "model-env-job", - "model_id": "gpt-4", - "env_id": "demo-env-module", - "env_template_id": "demo-env-template", - "env_variant_id": "demo-env", - "env_args": {}, - "results_relpath": "model-env-job/results.jsonl", - } - ], - models={"gpt-4": {"sampling_args": {}}}, - env_templates={"demo-env-template": {"module": "demo-env-module"}}, - ) - _write_json(run_dir / "run_manifest.json", manifest_payload) - results_dir.mkdir(parents=True, exist_ok=True) - (results_dir / "results.jsonl").write_text("{}", encoding="utf-8") + assert _manifest_missing_pct(manifest) == 3.0 + assert _manifest_within_missing_pct(manifest, 0.0) is False - records = discover_run_records(runs_dir, only_complete_runs=True) - assert len(records) == 1 + +def test_manifest_missing_pct_keeps_runs_within_threshold() -> None: + manifest = _manifest_info(completed=39, total=40, total_known=True) + + assert _manifest_missing_pct(manifest) == 2.5 + assert _manifest_within_missing_pct(manifest, 2.5) is True + + +def test_manifest_missing_pct_unknown_total_is_permissive() -> None: + manifest = _manifest_info(completed=0, total=0, total_known=False) + + assert _manifest_missing_pct(manifest) is None + assert _manifest_within_missing_pct(manifest, 0.0) is True def test_discover_run_records_missing_summary_uses_manifest_status(tmp_path: Path) -> None: From 8b4dc750355959f6db9bfb424a4dbe658d5cdb77 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sun, 1 Mar 2026 13:17:47 -0500 Subject: [PATCH 12/29] Report winrate dataset missingness --- docs/medarc-eval-winrate.md | 3 + medarc_verifiers/cli/winrate/api.py | 142 ++++++++++++++++++++++++- tests/test_cli/test_process_winrate.py | 42 ++++++++ 3 files changed, 184 insertions(+), 3 deletions(-) diff --git a/docs/medarc-eval-winrate.md b/docs/medarc-eval-winrate.md index 49d527c3..b231d3d0 100644 --- a/docs/medarc-eval-winrate.md +++ b/docs/medarc-eval-winrate.md @@ -34,6 +34,9 @@ For each pair of models (A, B) on each benchmark: The final win rate aggregates across all benchmarks using configurable weighting. +Winrate also emits a missingness summary so partial dataset coverage is visible. The report counts missing +`(dataset, model)` pairs after rollout averaging, including both absent rows and null reward values. + ## Output Files ``` diff --git a/medarc_verifiers/cli/winrate/api.py b/medarc_verifiers/cli/winrate/api.py index d44ad2df..88f375cb 100644 --- a/medarc_verifiers/cli/winrate/api.py +++ b/medarc_verifiers/cli/winrate/api.py @@ -64,6 +64,28 @@ class ModelCentricResult: datasets: dict[str, dict[str, Any]] +@dataclass(slots=True) +class DatasetModelMissingness: + """Missing reward coverage for one (dataset, model) pair.""" + + dataset: str + model: str + expected_n: int + present_nonnull_n: int + missing_count: int + missing_pct: float + + +@dataclass(slots=True) +class MissingnessSummary: + """Aggregate missingness summary across retained datasets.""" + + n_pairs_total: int + n_pairs_with_missing: int + missing_cells_total: int + worst_offenders: list[DatasetModelMissingness] + + def read_dataset_lazy( parquet_path: Path | str | Sequence[Path | str | PLDataFrame | PLLazyFrame] | PLDataFrame | PLLazyFrame, ) -> pl.LazyFrame: @@ -288,6 +310,7 @@ def compute_winrates( n_questions_by_ds: dict[str, int] = {} models_by_ds: dict[str, list[str]] = {} models_present_by_ds: dict[str, set[str]] = {} + missingness_by_ds: dict[str, list[DatasetModelMissingness]] = {} seen_models: set[str] = set() seen_model_case_map: dict[str, str] = {} @@ -300,7 +323,7 @@ def compute_winrates( dataset_iter = datasets for dataset_name, parquet_path in dataset_iter: - stats, models_present = _process_dataset( + stats, models_present, missingness = _process_dataset( dataset_name, parquet_path, cfg, @@ -329,6 +352,7 @@ def compute_winrates( avg_rewards_by_dataset[dataset_name] = stats.avg_reward_per_model n_questions_by_ds[dataset_name] = stats.n_questions models_by_ds[dataset_name] = stats.models + missingness_by_ds[dataset_name] = missingness if not known_model_set: if include_set: @@ -349,6 +373,7 @@ def compute_winrates( per_dataset_model_means=per_dataset_model_means, avg_rewards_by_dataset=avg_rewards_by_dataset, models_by_ds=models_by_ds, + missingness_by_ds=missingness_by_ds, include_map=include_map, seen_model_case_map=seen_model_case_map, ) @@ -373,6 +398,7 @@ def compute_winrates( avg_rewards_by_dataset.pop(dataset_name, None) n_questions_by_ds.pop(dataset_name, None) models_by_ds.pop(dataset_name, None) + missingness_by_ds.pop(dataset_name, None) if not per_dataset_pairwise: _raise_user_error( "No datasets remain after enforcing dataset_coverage=all-models. " @@ -385,6 +411,8 @@ def compute_winrates( coverage=dataset_coverage, ) + _emit_missingness_report(_summarize_missingness(missingness_by_ds)) + return build_model_centric_result( per_dataset_pairwise=per_dataset_pairwise, per_dataset_model_means=per_dataset_model_means, @@ -583,7 +611,7 @@ def _process_dataset( include_map: Mapping[str, str], seen_model_case_map: Mapping[str, str], partial_datasets: str, -) -> tuple[DatasetStats | None, list[str]]: +) -> tuple[DatasetStats | None, list[str], list[DatasetModelMissingness]]: """Read and process a dataset, raising on failure and honoring selection policies.""" try: lf = read_dataset_lazy(parquet_path) @@ -599,7 +627,7 @@ def _process_dataset( if missing_required and partial_datasets == "strict": missing_labels = [include_map.get(model, model) for model in missing_required] _emit_note(f"Dropping dataset {dataset_name} (missing include models: {missing_labels}).") - return None, models_present + return None, models_present, [] if include_set: models_filtered = [models_present_map[model] for model in target_models if model in models_present_map] @@ -648,6 +676,7 @@ def canonical_label(normalized_id: str) -> str: else: pairwise[key] = (1.0 - wr, n_used) avg_reward_per_model = _mean_reward_per_model(df_avg, allowed=models) + missingness = _compute_dataset_missingness(dataset_name, df_filtered, models) return ( DatasetStats( pairwise=pairwise, @@ -656,12 +685,54 @@ def canonical_label(normalized_id: str) -> str: avg_reward_per_model=avg_reward_per_model, ), models_present, + missingness, ) except Exception as exc: # noqa: BLE001 message = f"Failed to process dataset {dataset_name} at {_format_parquet_source(parquet_path)}: {exc}" _raise_user_error(message, exc) +def _compute_dataset_missingness( + dataset_name: str, + df_avg: pl.DataFrame, + models: Sequence[str], +) -> list[DatasetModelMissingness]: + deduped_models = list(dict.fromkeys(str(model) for model in models)) + if not deduped_models: + return [] + + expected_n = 0 + present_nonnull_by_model: dict[str, int] = {} + if not df_avg.is_empty() and EXAMPLE_ID_COL in df_avg.columns: + expected_n = int(df_avg.select(pl.col(EXAMPLE_ID_COL).n_unique()).item()) + if MODEL_COL in df_avg.columns: + grouped = ( + df_avg.filter(pl.col("reward_mean").is_not_null()) + .group_by(MODEL_COL) # type: ignore[arg-type] + .agg(pl.col(EXAMPLE_ID_COL).n_unique().alias("present_nonnull_n")) + ) + present_nonnull_by_model = { + str(model): int(present_nonnull or 0) for model, present_nonnull in grouped.iter_rows() + } + + missingness: list[DatasetModelMissingness] = [] + for model in deduped_models: + present_nonnull_n = max(present_nonnull_by_model.get(model, 0), 0) + missing_count = max(expected_n - present_nonnull_n, 0) + missing_pct = (100.0 * missing_count / expected_n) if expected_n > 0 else 0.0 + missingness.append( + DatasetModelMissingness( + dataset=dataset_name, + model=model, + expected_n=expected_n, + present_nonnull_n=present_nonnull_n, + missing_count=missing_count, + missing_pct=missing_pct, + ) + ) + return missingness + + def _mean_reward_per_model(df_avg: pl.DataFrame, allowed: Sequence[str] | None = None) -> dict[str, float | None]: """Average reward_mean per model inside a dataset.""" if df_avg.is_empty() or MODEL_COL not in df_avg.columns: @@ -745,6 +816,7 @@ def _canonicalize_dataset_model_labels( per_dataset_model_means: dict[str, dict[str, float]], avg_rewards_by_dataset: dict[str, dict[str, float | None]], models_by_ds: dict[str, list[str]], + missingness_by_ds: dict[str, list[DatasetModelMissingness]], include_map: Mapping[str, str], seen_model_case_map: Mapping[str, str], ) -> None: @@ -806,6 +878,70 @@ def canonical(value: str) -> str: deduped.append(canonical_model) models_by_ds[dataset] = deduped + for dataset, rows in list(missingness_by_ds.items()): + canonical_rows: list[DatasetModelMissingness] = [] + for row in rows: + canonical_rows.append( + DatasetModelMissingness( + dataset=row.dataset, + model=canonical(row.model), + expected_n=row.expected_n, + present_nonnull_n=row.present_nonnull_n, + missing_count=row.missing_count, + missing_pct=row.missing_pct, + ) + ) + missingness_by_ds[dataset] = canonical_rows + + +def _summarize_missingness( + missingness_by_ds: Mapping[str, Sequence[DatasetModelMissingness]], +) -> MissingnessSummary: + rows = [row for dataset_rows in missingness_by_ds.values() for row in dataset_rows] + rows_with_missing = [row for row in rows if row.missing_count > 0] + worst_offenders = sorted( + rows_with_missing, + key=lambda row: (-row.missing_pct, -row.missing_count, row.dataset, row.model), + )[:10] + return MissingnessSummary( + n_pairs_total=len(rows), + n_pairs_with_missing=len(rows_with_missing), + missing_cells_total=sum(row.missing_count for row in rows), + worst_offenders=worst_offenders, + ) + + +def _emit_missingness_report(summary: MissingnessSummary) -> None: + logger.info( + "Winrate missingness summary: n_pairs_total=%d n_pairs_with_missing=%d missing_cells_total=%d", + summary.n_pairs_total, + summary.n_pairs_with_missing, + summary.missing_cells_total, + ) + console = _get_console() + if not console or not getattr(console, "is_terminal", False) or not summary.worst_offenders: + return + try: + from rich.table import Table + except Exception: + return + + table = Table(title="Winrate missingness (top offenders)") + table.add_column("dataset", style="cyan") + table.add_column("model", style="magenta") + table.add_column("missing", justify="right") + table.add_column("expected", justify="right") + table.add_column("missing %", justify="right") + for row in summary.worst_offenders: + table.add_row( + row.dataset, + row.model, + str(row.missing_count), + str(row.expected_n), + f"{row.missing_pct:.1f}", + ) + console.print(table) + def _format_parquet_source( parquet_path: Path | str | Sequence[Path | str] | PLDataFrame | PLLazyFrame, diff --git a/tests/test_cli/test_process_winrate.py b/tests/test_cli/test_process_winrate.py index 3b885e69..a29e8bed 100644 --- a/tests/test_cli/test_process_winrate.py +++ b/tests/test_cli/test_process_winrate.py @@ -360,6 +360,48 @@ def test_partial_datasets_include_uses_consistent_canonical_labels(tmp_path: Pat assert payload["models"]["Model_A"]["vs"]["Model_B"]["n_datasets"] == 2 +def test_compute_dataset_missingness_counts_null_rewards() -> None: + df_avg = pl.DataFrame( + { + "example_id": ["q1", "q2", "q3", "q1", "q2", "q3"], + "model_id": ["model_a", "model_a", "model_a", "model_b", "model_b", "model_b"], + "reward_mean": [1.0, 0.5, 0.0, 0.8, None, 0.2], + } + ) + + rows = winrate_api._compute_dataset_missingness("dataset", df_avg, ["model_a", "model_b"]) + by_model = {row.model: row for row in rows} + + assert by_model["model_a"].expected_n == 3 + assert by_model["model_a"].present_nonnull_n == 3 + assert by_model["model_a"].missing_count == 0 + assert by_model["model_a"].missing_pct == pytest.approx(0.0) + assert by_model["model_b"].expected_n == 3 + assert by_model["model_b"].present_nonnull_n == 2 + assert by_model["model_b"].missing_count == 1 + assert by_model["model_b"].missing_pct == pytest.approx(100 / 3) + + +def test_compute_dataset_missingness_marks_absent_included_model_fully_missing() -> None: + df_avg = pl.DataFrame( + { + "example_id": ["q1", "q2"], + "model_id": ["model_a", "model_a"], + "reward_mean": [1.0, 0.5], + } + ) + + rows = winrate_api._compute_dataset_missingness("dataset", df_avg, ["model_a", "model_b"]) + by_model = {row.model: row for row in rows} + + assert by_model["model_a"].missing_count == 0 + assert by_model["model_a"].missing_pct == pytest.approx(0.0) + assert by_model["model_b"].expected_n == 2 + assert by_model["model_b"].present_nonnull_n == 0 + assert by_model["model_b"].missing_count == 2 + assert by_model["model_b"].missing_pct == pytest.approx(100.0) + + def test_filter_models_is_case_insensitive() -> None: filtered = winrate_api._filter_models( ["Model_A", "Model_B", "Model_C"], From e9b36aa4cea6fafb1197dd9d9d7bf284f1807a70 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sun, 1 Mar 2026 13:39:28 -0500 Subject: [PATCH 13/29] Tighten process missing-pct selection --- medarc_verifiers/cli/main.py | 8 +- medarc_verifiers/cli/process/discovery.py | 12 +-- medarc_verifiers/cli/process/pipeline.py | 21 ++-- tests/test_cli/test_main.py | 53 ++++++++++ tests/test_cli/test_process_pipeline.py | 120 ++++++++++++++++++++++ 5 files changed, 192 insertions(+), 22 deletions(-) diff --git a/medarc_verifiers/cli/main.py b/medarc_verifiers/cli/main.py index f413c92e..e5c13add 100644 --- a/medarc_verifiers/cli/main.py +++ b/medarc_verifiers/cli/main.py @@ -628,10 +628,12 @@ def _build_process_options(args: argparse.Namespace) -> ProcessOptions: retries=args.hf_retries, max_files_per_commit=args.hf_max_files_per_commit, ) - status_filter = tuple(args.status) if args.status is not None else PROCESS_DEFAULT_STATUS_FILTER + status_values = list(args.status or []) + status_filter = tuple(status_values) if status_values else PROCESS_DEFAULT_STATUS_FILTER + max_run_missing_pct = float(args.max_run_missing_pct) if args.max_run_missing_pct is not None else 2.5 processed_with_args = { "status": list(status_filter), - "max_run_missing_pct": float(args.max_run_missing_pct), + "max_run_missing_pct": max_run_missing_pct, "exclude_datasets": args.exclude_dataset or [], "exclude_models": args.exclude_model or [], "replace_models": args.replace_model or [], @@ -655,7 +657,7 @@ def _build_process_options(args: argparse.Namespace) -> ProcessOptions: processed_at=args.processed_at, processed_with_args=processed_with_args, status_filter=status_filter, - max_run_missing_pct=float(args.max_run_missing_pct), + max_run_missing_pct=max_run_missing_pct, dry_run=bool(args.dry_run), clean=bool(args.clean), assume_yes=bool(args.yes), diff --git a/medarc_verifiers/cli/process/discovery.py b/medarc_verifiers/cli/process/discovery.py index fc583f10..933f0aec 100644 --- a/medarc_verifiers/cli/process/discovery.py +++ b/medarc_verifiers/cli/process/discovery.py @@ -20,7 +20,6 @@ logger = logging.getLogger(__name__) DEFAULT_STATUS = "unknown" -_COMPLETED_STATUSES = {"completed", "succeeded", "success"} @dataclass(frozen=True, slots=True) @@ -78,17 +77,15 @@ def discover_run_records( runs_dir: Path | str, *, filter_status: Sequence[str] | None = None, - only_complete_runs: bool = False, ) -> list[RunRecord]: """Return all discovered run records within the provided runs directory.""" - return list(iter_run_records(runs_dir, filter_status=filter_status, only_complete_runs=only_complete_runs)) + return list(iter_run_records(runs_dir, filter_status=filter_status)) def iter_run_records( runs_dir: Path | str, *, filter_status: Sequence[str] | None = None, - only_complete_runs: bool = False, ) -> Iterator[RunRecord]: """Yield run records for each job entry found under the runs directory.""" runs_path = Path(runs_dir) @@ -108,13 +105,6 @@ def iter_run_records( manifest_info, job_entries = _load_manifest(run_dir) if manifest_info is None: continue - if ( - only_complete_runs - and manifest_info.summary_total_known - and manifest_info.summary_completed != manifest_info.summary_total - ): - # Skip entire run if not fully completed - continue summary_map = _load_run_summary(run_dir) for job_entry in job_entries: summary_entry = summary_map.get(job_entry.job_id or "") diff --git a/medarc_verifiers/cli/process/pipeline.py b/medarc_verifiers/cli/process/pipeline.py index 1a94453e..f284b9df 100644 --- a/medarc_verifiers/cli/process/pipeline.py +++ b/medarc_verifiers/cli/process/pipeline.py @@ -149,7 +149,6 @@ def _run_pipeline() -> ProcessResult: discovered = discovery.discover_run_records( options.runs_dir, filter_status=options.status_filter or None, - only_complete_runs=False, ) selection = select_work_items( discovered, @@ -301,10 +300,16 @@ def select_work_items( ) -> SelectionResult: """Filter discovered runs down to selected work items before row loading begins.""" eligible_records: list[discovery.RunRecord] = [] - skipped_by_missing_pct = 0 + skipped_run_ids_by_missing_pct: set[str] = set() + missing_pct_allowed_by_run_id: dict[str, bool] = {} for record in discovered: - if not _manifest_within_missing_pct(record.manifest, options.max_run_missing_pct): - skipped_by_missing_pct += 1 + run_id = record.manifest.job_run_id + allow_run = missing_pct_allowed_by_run_id.get(run_id) + if allow_run is None: + allow_run = _manifest_within_missing_pct(record.manifest, options.max_run_missing_pct) + missing_pct_allowed_by_run_id[run_id] = allow_run + if not allow_run: + skipped_run_ids_by_missing_pct.add(run_id) continue eligible_records.append(record) @@ -322,7 +327,7 @@ def select_work_items( return SelectionResult( work_items=work_items, - skipped_by_missing_pct=skipped_by_missing_pct, + skipped_by_missing_pct=len(skipped_run_ids_by_missing_pct), skipped_by_delta=skipped_by_delta, skipped_by_exclusion=skipped_by_exclusion, total_discovered=len(discovered), @@ -615,8 +620,8 @@ def _print_records_table( from rich.table import Table except Exception: logger.info( - "Processing %d job(s) across %d model(s) (max_run_missing_pct=%s; found %d job(s) across %d model(s)); " - "skipped by missing pct=%d excluded=%d existing=%d.", + "Processing %d job(s) across %d model(s) (max_run_missing_pct=%s; found %d eligible job(s) across %d model(s)); " + "skipped run(s) by missing pct=%d excluded=%d existing=%d.", selected_jobs_total, len(selected_models), _format_missing_pct(max_run_missing_pct), @@ -639,7 +644,7 @@ def _print_records_table( f"[dim](max_run_missing_pct={_format_missing_pct(max_run_missing_pct)})[/dim]" ) title += ( - f" [dim](found {discovered_jobs_total} eligible job(s); skipped by missing pct={skipped_by_missing_pct}, " + f" [dim](found {discovered_jobs_total} eligible job(s); skipped run(s) by missing pct={skipped_by_missing_pct}, " f"excluded={skipped_by_exclusion}, existing={skipped_by_delta})[/dim]" ) table = Table(title=title, show_header=True, header_style="bold cyan", caption=None) diff --git a/tests/test_cli/test_main.py b/tests/test_cli/test_main.py index b56c0f90..b5c100cb 100644 --- a/tests/test_cli/test_main.py +++ b/tests/test_cli/test_main.py @@ -2059,6 +2059,59 @@ def fake_run_process(options, env_export_map): assert options.max_run_missing_pct == pytest.approx(100.0) +def test_process_cli_rejects_negative_max_run_missing_pct( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + with pytest.raises(SystemExit) as excinfo: + main.main( + [ + "process", + "--runs-dir", + str(tmp_path / "runs"), + "--output-dir", + str(tmp_path / "processed"), + "--max-run-missing-pct", + "-1", + ] + ) + + assert excinfo.value.code == 2 + err = capsys.readouterr().err + assert "--max-run-missing-pct must be non-negative." in err + + +def test_process_config_empty_status_uses_default_filter( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + cfg_path = tmp_path / "process.yaml" + cfg_path.write_text( + """ + runs_dir: runs/raw + process: + dir: processed + status: [] + """, + encoding="utf-8", + ) + + captured: dict[str, Any] = {} + + def fake_run_process(options, env_export_map): + captured["options"] = options + return ProcessResult(records_processed=0, rows_processed=0, env_groups=[], env_summaries=[], hf_summary=None) + + monkeypatch.setattr(main, "run_process", fake_run_process) + + exit_code = main.main(["process", "--config", str(cfg_path), "--dry-run"]) + + assert exit_code == 0 + options = captured["options"] + assert options.status_filter == ("completed", "succeeded", "success") + assert options.processed_with_args["status"] == ["completed", "succeeded", "success"] + + def test_process_cli_runs_embedded_winrate_post_step(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: cfg_path = tmp_path / "process.yaml" cfg_path.write_text( diff --git a/tests/test_cli/test_process_pipeline.py b/tests/test_cli/test_process_pipeline.py index 122abbd2..76ac0c3c 100644 --- a/tests/test_cli/test_process_pipeline.py +++ b/tests/test_cli/test_process_pipeline.py @@ -9,6 +9,8 @@ from medarc_verifiers.cli._manifest import MANIFEST_VERSION from medarc_verifiers.cli._schemas import EnvironmentExportConfig from medarc_verifiers.cli.process import ProcessOptions, run_process +from medarc_verifiers.cli.process.discovery import RunManifestInfo, RunRecord +from medarc_verifiers.cli.process.pipeline import select_work_items from medarc_verifiers.cli.winrate import WinrateConfig from medarc_verifiers.cli.winrate import discover_datasets, run_winrate from medarc_verifiers.cli.hf import HFSyncConfig @@ -20,6 +22,77 @@ def _write_json(path: Path, payload: dict) -> None: path.write_text(json.dumps(payload), encoding="utf-8") +def _manifest_info( + *, + run_id: str, + completed: int, + total: int, + total_known: bool = True, + updated_at: str = "2024-01-01T00:00:00Z", +) -> RunManifestInfo: + run_dir = Path("/tmp") / run_id + return RunManifestInfo( + job_run_id=run_id, + run_name=run_id, + summary_completed=completed, + summary_total=total, + summary_total_known=total_known, + manifest_path=run_dir / "run_manifest.json", + run_dir=run_dir, + created_at="2024-01-01T00:00:00Z", + updated_at=updated_at, + config_source="configs/demo.yaml", + config_checksum="abc123", + run_summary_path=run_dir / "run_summary.json", + ) + + +def _run_record( + *, + run_id: str, + job_id: str, + env_id: str, + model_id: str = "gpt-mini", + completed: int = 1, + total: int = 1, + total_known: bool = True, + updated_at: str = "2024-01-01T00:00:00Z", +) -> RunRecord: + run_dir = Path("/tmp") / run_id + results_dir = run_dir / job_id + return RunRecord( + manifest=_manifest_info( + run_id=run_id, + completed=completed, + total=total, + total_known=total_known, + updated_at=updated_at, + ), + job_id=job_id, + model_id=model_id, + manifest_env_id=env_id, + results_dir_name=job_id, + results_dir=results_dir, + metadata_path=results_dir / "metadata.json", + results_path=results_dir / "results.jsonl", + summary_path=results_dir / "summary.json", + has_metadata=False, + has_results=True, + has_summary=True, + status="completed", + duration_seconds=1.0, + reason=None, + started_at="2024-01-01T00:00:00Z", + ended_at="2024-01-01T00:00:01Z", + num_examples=1, + rollouts_per_example=1, + env_args={}, + sampling_args={}, + env_config={"id": env_id, "module": env_id}, + model_config={}, + ) + + def _setup_run(tmp_path: Path) -> Path: runs_dir = tmp_path / "runs" run_dir = runs_dir / "run-1" @@ -278,6 +351,53 @@ def test_run_process_excludes_datasets(tmp_path: Path) -> None: assert result.env_groups[0].base_env_id == "keep-env" +def test_select_work_items_counts_missing_pct_skips_per_run_not_per_job() -> None: + discovered = [ + _run_record( + run_id="run-skipped", + job_id="job-a", + env_id="demo-env-a", + completed=8, + total=10, + ), + _run_record( + run_id="run-skipped", + job_id="job-b", + env_id="demo-env-b", + completed=8, + total=10, + ), + _run_record( + run_id="run-kept", + job_id="job-c", + env_id="demo-env-c", + completed=10, + total=10, + updated_at="2024-01-01T00:05:00Z", + ), + ] + options = ProcessOptions( + runs_dir=Path("/tmp/runs"), + output_dir=Path("/tmp/processed"), + max_run_missing_pct=10.0, + dry_run=True, + max_workers=1, + ) + + selection = select_work_items( + discovered, + options=options, + env_export_map={}, + index_files={}, + ) + + assert selection.skipped_by_missing_pct == 1 + assert selection.total_discovered == 3 + assert len(selection.work_items) == 1 + assert selection.work_items[0].identity.job_run_id == "run-kept" + assert selection.work_items[0].identity.output_env_id == "demo-env-c" + + def test_run_process_excludes_models(tmp_path: Path) -> None: _write_run( tmp_path, From 75f7d311330ad8d3dbbbec666128e223cfd80c3b Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sun, 1 Mar 2026 15:11:46 -0500 Subject: [PATCH 14/29] small fixes --- docs/medarc-eval-process.md | 20 ++++++++++---------- medarc_verifiers/cli/process/pipeline.py | 5 ++--- tests/test_cli/test_main.py | 8 ++++---- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/docs/medarc-eval-process.md b/docs/medarc-eval-process.md index 6538ce30..490289f8 100644 --- a/docs/medarc-eval-process.md +++ b/docs/medarc-eval-process.md @@ -5,7 +5,7 @@ Convert raw benchmark outputs into analysis-ready parquet files. This step prepa ## Quick Start ```bash -# Process all completed runs (uses defaults) +# Process all completed jobs (uses defaults) medarc-eval process # Specify directories explicitly @@ -17,7 +17,7 @@ medarc-eval process --dry-run ## What Processing Does -1. **Discovers** completed jobs in `runs/raw/` +1. **Discovers** jobs in `runs/raw/` and filters by manifest status (default: `completed`) 2. **Extracts** results from each job's output files 3. **Normalizes** data into a consistent schema 4. **Writes** parquet files organized by environment and model @@ -43,7 +43,7 @@ runs/processed/ |------|-------------|---------| | `--runs-dir PATH` | Directory containing raw runs | `runs/raw` | | `--output-dir PATH` | Where to write processed files | `runs/processed` | -| `--max-workers N` | Parallel processing threads | 4 | +| `--max-workers N` | Parallel worker processes | 4 | | `--dry-run` | Show what would be processed | - | | `--yes` | Skip confirmation prompts | - | | `--exclude-dataset NAME` | Skip processing specific datasets/env ids (repeatable) | - | @@ -53,11 +53,9 @@ runs/processed/ ### By Completion Status -By default, `medarc-eval process` only selects runs whose manifest status is one of: +By default, `medarc-eval process` only selects jobs whose manifest status is `completed`. -- `completed` -- `succeeded` -- `success` +Note: successful jobs are written to `run_manifest.json` with `status: completed`. To override that default, pass one or more explicit status filters: @@ -251,12 +249,14 @@ When both flags are present, processing only rebuilds outputs that match both fi Check that: 1. `--runs-dir` points to the correct location -2. Runs have completed (check `run_manifest.json` status) -3. Use `--process-incomplete` if runs are still in progress +2. Runs have completed (check `run_manifest.json` `jobs[*].status` and `summary.completed` / `summary.total`) +3. Use `--status pending` or `--status running` to include non-completed jobs ### Missing data in output -By default, only jobs with `completed` status are included. Use `--process-incomplete` to include partial results. +By default, only jobs with `completed` status are included. In addition, `--max-run-missing-pct` skips run directories missing more than 2.5% of their expected job outputs (model/env combinations), based on `run_manifest.json` `summary.completed` / `summary.total`. This is a manifest-level gate; it does not validate `results.jsonl` row counts. + +Use `--max-run-missing-pct 100` to disable the gate, or pass explicit `--status` values to include other statuses. ### Integrity-check failures for existing parquet files diff --git a/medarc_verifiers/cli/process/pipeline.py b/medarc_verifiers/cli/process/pipeline.py index f284b9df..a08abf16 100644 --- a/medarc_verifiers/cli/process/pipeline.py +++ b/medarc_verifiers/cli/process/pipeline.py @@ -27,7 +27,7 @@ ) logger = logging.getLogger(__name__) -PROCESS_DEFAULT_STATUS_FILTER: tuple[str, ...] = ("completed", "succeeded", "success") +PROCESS_DEFAULT_STATUS_FILTER: tuple[str, ...] = ("completed",) @dataclass(slots=True) @@ -599,11 +599,10 @@ def _print_records_table( total_by_model: dict[str, int] = {} completed_by_model: dict[str, int] = {} selected_by_model: dict[str, int] = {} - completed_statuses = {"completed", "succeeded", "success"} for rec in eligible_discovered: model_id = rec.model_id or "unknown" total_by_model[model_id] = total_by_model.get(model_id, 0) + 1 - if (rec.status or "").lower() in completed_statuses: + if (rec.status or "").lower() in PROCESS_DEFAULT_STATUS_FILTER: completed_by_model[model_id] = completed_by_model.get(model_id, 0) + 1 for rec in selected: model_id = rec.model_id or "unknown" diff --git a/tests/test_cli/test_main.py b/tests/test_cli/test_main.py index b5c100cb..03f9b287 100644 --- a/tests/test_cli/test_main.py +++ b/tests/test_cli/test_main.py @@ -2022,8 +2022,8 @@ def fake_run_process(options, env_export_map): assert exit_code == 0 options = captured["options"] - assert options.status_filter == ("completed", "succeeded", "success") - assert options.processed_with_args["status"] == ["completed", "succeeded", "success"] + assert options.status_filter == ("completed",) + assert options.processed_with_args["status"] == ["completed"] assert options.max_run_missing_pct == pytest.approx(2.5) assert options.processed_with_args["max_run_missing_pct"] == pytest.approx(2.5) @@ -2108,8 +2108,8 @@ def fake_run_process(options, env_export_map): assert exit_code == 0 options = captured["options"] - assert options.status_filter == ("completed", "succeeded", "success") - assert options.processed_with_args["status"] == ["completed", "succeeded", "success"] + assert options.status_filter == ("completed",) + assert options.processed_with_args["status"] == ["completed"] def test_process_cli_runs_embedded_winrate_post_step(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: From 9c6991685284ce3ab1196d67d29726408ee92a9d Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sun, 1 Mar 2026 17:24:35 -0500 Subject: [PATCH 15/29] improved results missing percent logic --- docs/medarc-eval-process.md | 27 +- medarc_verifiers/cli/main.py | 46 ++- medarc_verifiers/cli/process/discovery.py | 2 + medarc_verifiers/cli/process/metadata.py | 25 +- medarc_verifiers/cli/process/pipeline.py | 239 ++++++++++---- tests/test_cli/test_main.py | 111 ++++++- tests/test_cli/test_process_discovery.py | 24 +- tests/test_cli/test_process_metadata.py | 52 +++ tests/test_cli/test_process_pipeline.py | 384 ++++++++++++++++++---- tests/test_cli/test_process_rows.py | 1 + 10 files changed, 726 insertions(+), 185 deletions(-) diff --git a/docs/medarc-eval-process.md b/docs/medarc-eval-process.md index 490289f8..d89c5923 100644 --- a/docs/medarc-eval-process.md +++ b/docs/medarc-eval-process.md @@ -63,16 +63,23 @@ To override that default, pass one or more explicit status filters: medarc-eval process --status completed --status failed ``` -You can also gate partially complete runs by their manifest summary totals: +You can also gate partially complete outputs by missing `results.jsonl` rows: ```bash # Default tolerance is 2.5 percent missing -medarc-eval process --max-run-missing-pct 2.5 +medarc-eval process --max-results-missing-pct 2.5 # Effectively disable the gate -medarc-eval process --max-run-missing-pct 100 +medarc-eval process --max-results-missing-pct 100 ``` +This gate uses manifest job metadata only: + +- `expected_rows = num_examples * rollouts_per_example` +- `observed_rows = row_count` + +It is computed per selected job record and enforced only on the latest selected run for each processed model/environment output. It does not use manifest `summary.completed` / `summary.total`, and it does not fall back to older runs if the latest one is too incomplete. + ### Latest Runs Only When multiple runs exist for the same (model, environment) pair, processing uses the latest by default. @@ -100,7 +107,7 @@ runs_dir: runs/raw process: dir: processed max_workers: 8 - max_run_missing_pct: 2.5 + max_results_missing_pct: 2.5 exclude_datasets: - med_dialog exclude_models: @@ -249,14 +256,20 @@ When both flags are present, processing only rebuilds outputs that match both fi Check that: 1. `--runs-dir` points to the correct location -2. Runs have completed (check `run_manifest.json` `jobs[*].status` and `summary.completed` / `summary.total`) +2. Runs have completed (check `run_manifest.json` `jobs[*].status`) 3. Use `--status pending` or `--status running` to include non-completed jobs ### Missing data in output -By default, only jobs with `completed` status are included. In addition, `--max-run-missing-pct` skips run directories missing more than 2.5% of their expected job outputs (model/env combinations), based on `run_manifest.json` `summary.completed` / `summary.total`. This is a manifest-level gate; it does not validate `results.jsonl` row counts. +By default, only jobs with `completed` status are included. In addition, `--max-results-missing-pct` fails if a selected latest job record is missing more than 2.5% of its expected `results.jsonl` rows, using manifest job fields: + +- `row_count` +- `num_examples` +- `rollouts_per_example` + +The gate is per selected record, not per whole run manifest. If the latest selected run for a model/dataset is too incomplete, processing fails fast instead of silently falling back to an older run. Records with unknown expected rows or unknown `row_count` are not gated. -Use `--max-run-missing-pct 100` to disable the gate, or pass explicit `--status` values to include other statuses. +Use `--max-results-missing-pct 100` to disable the gate, or pass explicit `--status` values to include other statuses. ### Integrity-check failures for existing parquet files diff --git a/medarc_verifiers/cli/main.py b/medarc_verifiers/cli/main.py index e5c13add..041dbe60 100644 --- a/medarc_verifiers/cli/main.py +++ b/medarc_verifiers/cli/main.py @@ -298,10 +298,15 @@ def build_process_parser() -> argparse.ArgumentParser: help="Rebuild existing processed outputs for these env ids (repeatable; comma-separated values allowed).", ) parser.add_argument( - "--max-run-missing-pct", + "--max-results-missing-pct", type=float, default=None, - help="Skip run directories whose manifest-level missing percentage exceeds this threshold (default: 2.5).", + help=( + "Fail if a selected latest job record is missing more than this percentage of expected results.jsonl rows " + "based on manifest job fields (row_count, num_examples, rollouts_per_example). " + "Computed per selected job record and enforced only on the latest selected run; does not use " + "manifest summary.completed/summary.total or fall back to older runs (default: 2.5)." + ), ) parser.add_argument( "--winrate", @@ -609,10 +614,10 @@ def _validate_process_args( normalize_dataset_ids(args.exclude_dataset, label="process exclude dataset") if args.exclude_model: normalize_model_ids(args.exclude_model, label="process exclude model") - if args.max_run_missing_pct is not None: - value = float(args.max_run_missing_pct) + if args.max_results_missing_pct is not None: + value = float(args.max_results_missing_pct) if value < 0: - parser.error("--max-run-missing-pct must be non-negative.") + parser.error("--max-results-missing-pct must be non-negative.") except ValueError as exc: parser.error(str(exc)) @@ -630,10 +635,10 @@ def _build_process_options(args: argparse.Namespace) -> ProcessOptions: ) status_values = list(args.status or []) status_filter = tuple(status_values) if status_values else PROCESS_DEFAULT_STATUS_FILTER - max_run_missing_pct = float(args.max_run_missing_pct) if args.max_run_missing_pct is not None else 2.5 + max_results_missing_pct = float(args.max_results_missing_pct) if args.max_results_missing_pct is not None else 2.5 processed_with_args = { "status": list(status_filter), - "max_run_missing_pct": max_run_missing_pct, + "max_results_missing_pct": max_results_missing_pct, "exclude_datasets": args.exclude_dataset or [], "exclude_models": args.exclude_model or [], "replace_models": args.replace_model or [], @@ -657,7 +662,7 @@ def _build_process_options(args: argparse.Namespace) -> ProcessOptions: processed_at=args.processed_at, processed_with_args=processed_with_args, status_filter=status_filter, - max_run_missing_pct=max_run_missing_pct, + max_results_missing_pct=max_results_missing_pct, dry_run=bool(args.dry_run), clean=bool(args.clean), assume_yes=bool(args.yes), @@ -794,9 +799,21 @@ def _set_if_unset(args: argparse.Namespace, attr: str, value: Any) -> None: def _load_config_payload(path: Path, *, mode: Literal["process", "winrate"]) -> dict[str, Any]: label = "Process config" if mode == "process" else "Winrate config" raw_payload = dict(load_mapping_file(path, label=label)) + if mode == "process": + _reject_removed_process_config_keys(raw_payload) return _expand_embedded_pipeline_config(raw_payload, mode=mode) +def _reject_removed_process_config_keys(payload: Mapping[str, Any]) -> None: + if "max_run_missing_pct" in payload: + raise ValueError("Process config field 'max_run_missing_pct' was removed; use 'max_results_missing_pct'.") + process_section = payload.get("process") + if isinstance(process_section, Mapping) and "max_run_missing_pct" in process_section: + raise ValueError( + "Process config field 'process.max_run_missing_pct' was removed; use 'process.max_results_missing_pct'." + ) + + def _expand_embedded_pipeline_config(payload: dict[str, Any], *, mode: Literal["process", "winrate"]) -> dict[str, Any]: expanded = dict(payload) process_section = payload.get("process") @@ -851,7 +868,7 @@ def _merge_process_section( "clean": "clean", "yes": "yes", "max_workers": "max_workers", - "max_run_missing_pct": "max_run_missing_pct", + "max_results_missing_pct": "max_results_missing_pct", } ) for key, target in key_map.items(): @@ -895,7 +912,9 @@ def _merge_winrate_section( expanded[target] = winrate_section[key] -def _resolve_processed_dir_from_payload(payload: Mapping[str, Any], *, mode: Literal["process", "winrate"]) -> Path | None: +def _resolve_processed_dir_from_payload( + payload: Mapping[str, Any], *, mode: Literal["process", "winrate"] +) -> Path | None: if "processed_dir" in payload and payload["processed_dir"] is not None: return Path(str(payload["processed_dir"])) if mode == "process" and "output_dir" in payload and payload["output_dir"] is not None: @@ -1032,7 +1051,10 @@ def _load_and_apply_config( "winrate": {"min_common": "min_common", "weight_cap": "weight_cap"}, }[mode] float_fields = { - "process": {"hf_request_timeout": "hf_request_timeout", "max_run_missing_pct": "max_run_missing_pct"}, + "process": { + "hf_request_timeout": "hf_request_timeout", + "max_results_missing_pct": "max_results_missing_pct", + }, "winrate": {"epsilon": "epsilon"}, }[mode] repeatable_fields = { @@ -1116,7 +1138,7 @@ def _finalize_config_args(args: argparse.Namespace, *, mode: Literal["process", "dry_run": False, "clean": False, "yes": False, - "max_run_missing_pct": 2.5, + "max_results_missing_pct": 2.5, "exclude_dataset": [], "exclude_model": [], "replace_model": [], diff --git a/medarc_verifiers/cli/process/discovery.py b/medarc_verifiers/cli/process/discovery.py index 933f0aec..4056bb71 100644 --- a/medarc_verifiers/cli/process/discovery.py +++ b/medarc_verifiers/cli/process/discovery.py @@ -67,6 +67,7 @@ class RunRecord: ended_at: str | None num_examples: int | None rollouts_per_example: int | None + row_count: int | None env_args: Mapping[str, Any] sampling_args: Mapping[str, Any] env_config: Mapping[str, Any] | None @@ -186,6 +187,7 @@ def _build_run_record( ended_at=job_entry.ended_at, num_examples=job_entry.num_examples, rollouts_per_example=job_entry.rollouts_per_example, + row_count=job_entry.row_count, env_args=env_args, sampling_args=sampling_args, env_config=env_config, diff --git a/medarc_verifiers/cli/process/metadata.py b/medarc_verifiers/cli/process/metadata.py index 3d70fb02..e591db0d 100644 --- a/medarc_verifiers/cli/process/metadata.py +++ b/medarc_verifiers/cli/process/metadata.py @@ -95,9 +95,7 @@ def resolve_run_identity( """Resolve a run identity for selection without requiring model_id.""" context = _resolve_metadata_context(record, combine_rollouts=combine_rollouts) resolved_rollout_index = ( - context.rollout_index - if context.rollout_index != 0 or context.manifest_env_id != context.base_env_id - else None + context.rollout_index if context.rollout_index != 0 or context.manifest_env_id != context.base_env_id else None ) return ResolvedRunIdentity( model_id=context.model_id, @@ -119,9 +117,7 @@ def load_normalized_metadata( if not context.model_id: raise RuntimeError(format_missing_model_id_error(record)) resolved_rollout_index = ( - context.rollout_index - if context.rollout_index != 0 or context.manifest_env_id != context.base_env_id - else None + context.rollout_index if context.rollout_index != 0 or context.manifest_env_id != context.base_env_id else None ) identity = RunIdentity( model_id=context.model_id, @@ -187,9 +183,14 @@ def _resolve_metadata_context( metadata_model=metadata_model, env_args=env_args, sampling_args=sampling_args, - num_examples=record.num_examples or (metadata_payload.num_examples if metadata_payload else None), - rollouts_per_example=record.rollouts_per_example - or (metadata_payload.rollouts_per_example if metadata_payload else None), + num_examples=_prefer_manifest_value( + record.num_examples, + metadata_payload.num_examples if metadata_payload else None, + ), + rollouts_per_example=_prefer_manifest_value( + record.rollouts_per_example, + metadata_payload.rollouts_per_example if metadata_payload else None, + ), ) @@ -251,6 +252,12 @@ def _merge_mappings( return result +def _prefer_manifest_value(primary: int | None, fallback: int | None) -> int | None: + if primary is not None: + return primary + return fallback + + def _extract_env_config_id(env_config: Mapping[str, Any] | None) -> str | None: if not env_config: return None diff --git a/medarc_verifiers/cli/process/pipeline.py b/medarc_verifiers/cli/process/pipeline.py index a08abf16..1cd3befd 100644 --- a/medarc_verifiers/cli/process/pipeline.py +++ b/medarc_verifiers/cli/process/pipeline.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import logging import sys from concurrent.futures import ProcessPoolExecutor, as_completed @@ -18,7 +19,7 @@ from medarc_verifiers.cli.process import aggregate, discovery, env_index, metadata, rollout, rows, workspace, writer from medarc_verifiers.cli.process.aggregate import AggregatedEnvRows from medarc_verifiers.cli.process.metadata import RunIdentity -from medarc_verifiers.cli.process.writer import EnvWriteSummary, WriterConfig +from medarc_verifiers.cli.process.writer import EXPORTER_METADATA_KEY, EnvWriteSummary, WriterConfig from medarc_verifiers.cli.utils.shared import ( dataset_is_excluded, model_is_excluded, @@ -36,14 +37,14 @@ class ProcessOptions: runs_dir: Path output_dir: Path - max_run_missing_pct: float = 2.5 + max_results_missing_pct: float = 2.5 exclude_datasets: Sequence[str] = field(default_factory=tuple) exclude_models: Sequence[str] = field(default_factory=tuple) replace_models: Sequence[str] = field(default_factory=tuple) replace_envs: Sequence[str] = field(default_factory=tuple) processed_at: str | None = None processed_with_args: Mapping[str, Any] = field(default_factory=dict) - status_filter: Sequence[str] = field(default_factory=tuple) + status_filter: Sequence[str] = field(default_factory=lambda: PROCESS_DEFAULT_STATUS_FILTER) dry_run: bool = False clean: bool = False assume_yes: bool = False @@ -54,7 +55,7 @@ class ProcessOptions: def __post_init__(self) -> None: self.runs_dir = Path(self.runs_dir) self.output_dir = Path(self.output_dir) - self.max_run_missing_pct = float(self.max_run_missing_pct) + self.max_results_missing_pct = float(self.max_results_missing_pct) self.max_workers = max(1, int(self.max_workers)) if not self.processed_at: self.processed_at = datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z") @@ -119,7 +120,6 @@ class SelectionResult: """Complete output of the selection phase.""" work_items: list[PlannedWorkItem] - skipped_by_missing_pct: int skipped_by_delta: int skipped_by_exclusion: int total_discovered: int @@ -160,10 +160,9 @@ def _run_pipeline() -> ProcessResult: _print_records_table( discovered, selected_records, - options.max_run_missing_pct, + options.max_results_missing_pct, exclude_datasets=options.exclude_datasets, exclude_models=options.exclude_models, - skipped_by_missing_pct=selection.skipped_by_missing_pct, skipped_by_delta=selection.skipped_by_delta, skipped_by_exclusion=selection.skipped_by_exclusion, ) @@ -191,7 +190,9 @@ def _run_pipeline() -> ProcessResult: env_groups: list[AggregatedEnvRows] = [] env_summaries: list[EnvWriteSummary] = [] rows_processed = 0 - work_items = sorted(selection.work_items, key=lambda item: (item.identity.model_id, item.identity.output_env_id)) + work_items = sorted( + selection.work_items, key=lambda item: (item.identity.model_id, item.identity.output_env_id) + ) try: if options.max_workers <= 1 or len(work_items) <= 1: @@ -299,23 +300,11 @@ def select_work_items( index_files: Mapping[str, Mapping[str, Any]], ) -> SelectionResult: """Filter discovered runs down to selected work items before row loading begins.""" - eligible_records: list[discovery.RunRecord] = [] - skipped_run_ids_by_missing_pct: set[str] = set() - missing_pct_allowed_by_run_id: dict[str, bool] = {} - for record in discovered: - run_id = record.manifest.job_run_id - allow_run = missing_pct_allowed_by_run_id.get(run_id) - if allow_run is None: - allow_run = _manifest_within_missing_pct(record.manifest, options.max_run_missing_pct) - missing_pct_allowed_by_run_id[run_id] = allow_run - if not allow_run: - skipped_run_ids_by_missing_pct.add(run_id) - continue - eligible_records.append(record) - - planned_records = [_plan_selection_record(record, env_export_map) for record in eligible_records] + planned_records = [_plan_selection_record(record, env_export_map) for record in discovered] _raise_for_latest_invalid_selection(planned_records) - work_items = _materialize_work_items(_select_latest_work_items([record for record in planned_records if record.identity.model_id])) + work_items = _materialize_work_items( + _select_latest_work_items([record for record in planned_records if record.identity.model_id]) + ) work_items, skipped_by_exclusion = _apply_exclusions( work_items, @@ -324,10 +313,10 @@ def select_work_items( ) _validate_replace_targets(work_items, options) work_items, skipped_by_delta = _apply_additive_delta(work_items, options=options, index_files=index_files) + _validate_selected_results_completeness(work_items, max_results_missing_pct=options.max_results_missing_pct) return SelectionResult( work_items=work_items, - skipped_by_missing_pct=len(skipped_run_ids_by_missing_pct), skipped_by_delta=skipped_by_delta, skipped_by_exclusion=skipped_by_exclusion, total_discovered=len(discovered), @@ -380,9 +369,7 @@ def _raise_for_latest_invalid_selection(records: Sequence[SelectionRecord]) -> N ) > _run_sort_key(_source_updated_at(current.record), current.record.manifest.job_run_id): latest_by_env[output_env_id] = planned - invalid_latest = [ - planned for planned in latest_by_env.values() if not planned.identity.model_id - ] + invalid_latest = [planned for planned in latest_by_env.values() if not planned.identity.model_id] if not invalid_latest: return @@ -469,18 +456,21 @@ def _validate_replace_targets(work_items: Sequence[PlannedWorkItem], options: Pr return if options.replace_models: - matched_models = {item.identity.model_id for item in work_items if item.identity.model_id in options.replace_models} + matched_models = { + item.identity.model_id for item in work_items if item.identity.model_id in options.replace_models + } if not matched_models: raise RuntimeError( "No selected processed outputs match --replace-model values: " f"{', '.join(sorted(options.replace_models))}." ) if options.replace_envs: - matched_envs = {item.identity.output_env_id for item in work_items if item.identity.output_env_id in options.replace_envs} + matched_envs = { + item.identity.output_env_id for item in work_items if item.identity.output_env_id in options.replace_envs + } if not matched_envs: raise RuntimeError( - "No selected processed outputs match --replace-env values: " - f"{', '.join(sorted(options.replace_envs))}." + f"No selected processed outputs match --replace-env values: {', '.join(sorted(options.replace_envs))}." ) if options.replace_models and options.replace_envs: intersection = [ @@ -517,7 +507,16 @@ def _apply_additive_delta( if _should_replace_existing_output(item.identity, options): filtered.append(item) continue - _validate_existing_output_integrity(output_path, output_dir=options.output_dir, index_files=index_files) + parquet_metadata = _read_existing_output_metadata(output_path) + _validate_existing_output_integrity( + output_path, + output_dir=options.output_dir, + index_files=index_files, + parquet_metadata=parquet_metadata, + ) + if not _existing_output_matches_selected_runs(item, parquet_metadata): + filtered.append(item) + continue skipped += 1 return filtered, skipped @@ -536,12 +535,7 @@ def _should_replace_existing_output(identity: RunIdentity, options: ProcessOptio return identity.output_env_id in options.replace_envs -def _validate_existing_output_integrity( - output_path: Path, - *, - output_dir: Path, - index_files: Mapping[str, Mapping[str, Any]], -) -> None: +def _read_existing_output_metadata(output_path: Path) -> pq.FileMetaData: try: metadata_obj = pq.ParquetFile(output_path).metadata except Exception as exc: # noqa: BLE001 @@ -555,6 +549,17 @@ def _validate_existing_output_integrity( f"Existing processed output {output_path} is missing parquet footer metadata. " "Rebuild it with --replace-model/--replace-env or re-run with --clean." ) + return metadata_obj + + +def _validate_existing_output_integrity( + output_path: Path, + *, + output_dir: Path, + index_files: Mapping[str, Mapping[str, Any]], + parquet_metadata: pq.FileMetaData | None = None, +) -> None: + metadata_obj = parquet_metadata or _read_existing_output_metadata(output_path) rel_key = output_path.relative_to(output_dir).as_posix() index_entry = index_files.get(rel_key) @@ -575,14 +580,39 @@ def _validate_existing_output_integrity( ) +def _existing_output_matches_selected_runs(item: PlannedWorkItem, parquet_metadata: pq.FileMetaData) -> bool: + existing_run_ids = _extract_exporter_source_runs(parquet_metadata) + if existing_run_ids is None: + return False + selected_run_ids = {planned.normalized.record.manifest.job_run_id for planned in item.records} + return existing_run_ids == selected_run_ids + + +def _extract_exporter_source_runs(parquet_metadata: pq.FileMetaData) -> set[str] | None: + metadata_map = parquet_metadata.metadata + if not metadata_map: + return None + payload = metadata_map.get(EXPORTER_METADATA_KEY) + if not payload: + return None + try: + exporter_metadata = json.loads(payload.decode("utf-8")) + except Exception: # noqa: BLE001 + return None + source_runs = exporter_metadata.get("source_runs") + if not isinstance(source_runs, list): + return None + run_ids = {str(run_id).strip() for run_id in source_runs if str(run_id).strip()} + return run_ids or None + + def _print_records_table( discovered: Sequence[discovery.RunRecord], selected: Sequence[discovery.RunRecord], - max_run_missing_pct: float, + max_results_missing_pct: float, *, exclude_datasets: Sequence[str] = (), exclude_models: Sequence[str] = (), - skipped_by_missing_pct: int = 0, skipped_by_delta: int = 0, skipped_by_exclusion: int = 0, ) -> None: @@ -592,8 +622,7 @@ def _print_records_table( eligible_discovered = [ rec for rec in discovered - if _manifest_within_missing_pct(rec.manifest, max_run_missing_pct) - and not (exclude_set and _record_is_excluded(rec, exclude_set)) + if not (exclude_set and _record_is_excluded(rec, exclude_set)) and not (exclude_model_set and _record_model_is_excluded(rec, exclude_model_set)) ] total_by_model: dict[str, int] = {} @@ -619,14 +648,13 @@ def _print_records_table( from rich.table import Table except Exception: logger.info( - "Processing %d job(s) across %d model(s) (max_run_missing_pct=%s; found %d eligible job(s) across %d model(s)); " - "skipped run(s) by missing pct=%d excluded=%d existing=%d.", + "Processing %d job(s) across %d model(s) (max_results_missing_pct=%s; found %d eligible job(s) across %d model(s)); " + "excluded=%d existing=%d.", selected_jobs_total, len(selected_models), - _format_missing_pct(max_run_missing_pct), + _format_missing_pct(max_results_missing_pct), discovered_jobs_total, len(models), - skipped_by_missing_pct, skipped_by_exclusion, skipped_by_delta, ) @@ -640,11 +668,11 @@ def _print_records_table( console = Console() title = ( f"Processing {selected_jobs_total} job(s) across {len(selected_models)} model(s) " - f"[dim](max_run_missing_pct={_format_missing_pct(max_run_missing_pct)})[/dim]" + f"[dim](max_results_missing_pct={_format_missing_pct(max_results_missing_pct)})[/dim]" ) title += ( - f" [dim](found {discovered_jobs_total} eligible job(s); skipped run(s) by missing pct={skipped_by_missing_pct}, " - f"excluded={skipped_by_exclusion}, existing={skipped_by_delta})[/dim]" + f" [dim](found {discovered_jobs_total} eligible job(s); excluded={skipped_by_exclusion}, " + f"existing={skipped_by_delta})[/dim]" ) table = Table(title=title, show_header=True, header_style="bold cyan", caption=None) table.add_column("Model", style="magenta") @@ -660,24 +688,6 @@ def _print_records_table( console.print(table) -def _manifest_missing_pct(manifest: discovery.RunManifestInfo) -> float | None: - if not manifest.summary_total_known: - return None - total = int(manifest.summary_total or 0) - if total <= 0: - return None - completed = max(int(manifest.summary_completed or 0), 0) - missing = max(total - completed, 0) - return 100.0 * missing / total - - -def _manifest_within_missing_pct(manifest: discovery.RunManifestInfo, max_missing_pct: float) -> bool: - missing_pct = _manifest_missing_pct(manifest) - if missing_pct is None: - return True - return missing_pct <= float(max_missing_pct) - - def _format_missing_pct(value: float) -> str: return f"{float(value):g}" @@ -697,6 +707,99 @@ def _record_model_is_excluded(record: discovery.RunRecord, exclude_model_set: se return model_is_excluded(str(record.model_id or "").strip(), exclude_model_set) +def _validate_selected_results_completeness( + work_items: Sequence[PlannedWorkItem], + *, + max_results_missing_pct: float, +) -> None: + missing_files: list[str] = [] + violations: list[str] = [] + ungateable = 0 + + for item in work_items: + for planned in item.records: + normalized = planned.normalized + record = normalized.record + if not record.results_path.exists(): + missing_files.append( + "model_id={model_id} output_env_id={output_env_id} manifest_env_id={manifest_env_id} " + "job_run_id={job_run_id} job_id={job_id} results_path={results_path}".format( + model_id=item.identity.model_id, + output_env_id=item.identity.output_env_id, + manifest_env_id=normalized.manifest_env_id, + job_run_id=record.manifest.job_run_id, + job_id=record.job_id, + results_path=record.results_path, + ) + ) + continue + + expected_rows = _expected_results_rows(normalized) + observed_rows = record.row_count + if expected_rows is None or observed_rows is None: + ungateable += 1 + continue + + missing_pct = _results_missing_pct(expected_rows=expected_rows, observed_rows=observed_rows) + if missing_pct > max_results_missing_pct: + violations.append( + "model_id={model_id} output_env_id={output_env_id} manifest_env_id={manifest_env_id} " + "job_run_id={job_run_id} job_id={job_id} expected_rows={expected_rows} " + "observed_rows={observed_rows} missing_pct={missing_pct:.2f} threshold={threshold:g}".format( + model_id=item.identity.model_id, + output_env_id=item.identity.output_env_id, + manifest_env_id=normalized.manifest_env_id, + job_run_id=record.manifest.job_run_id, + job_id=record.job_id, + expected_rows=expected_rows, + observed_rows=observed_rows, + missing_pct=missing_pct, + threshold=float(max_results_missing_pct), + ) + ) + + if ungateable: + logger.warning( + "Results row completeness gate could not be applied to %d selected record(s) because expected_rows " + "(num_examples * rollouts_per_example) or manifest row_count was unknown.", + ungateable, + ) + + if not missing_files and not violations: + return + + message_parts: list[str] = [] + if missing_files: + missing_lines = "\n".join(f" - {line}" for line in missing_files) + message_parts.append("Selected records are missing results.jsonl files:\n" + missing_lines) + if violations: + violation_lines = "\n".join(f" - {line}" for line in violations) + message_parts.append( + "Selected records exceeded --max-results-missing-pct based on manifest row_count and expected rows:\n" + + violation_lines + ) + raise RuntimeError("\n\n".join(message_parts)) + + +def _expected_results_rows(normalized: metadata.NormalizedMetadata) -> int | None: + num_examples = normalized.num_examples + rollouts_per_example = normalized.rollouts_per_example + if num_examples is None or rollouts_per_example is None: + return None + if num_examples == -1: + return None + if num_examples <= 0 or rollouts_per_example <= 0: + return None + return int(num_examples) * int(rollouts_per_example) + + +def _results_missing_pct(*, expected_rows: int, observed_rows: int) -> float: + if expected_rows <= 0: + return 0.0 + missing_rows = max(int(expected_rows) - max(int(observed_rows), 0), 0) + return 100.0 * missing_rows / int(expected_rows) + + def _process_env_group(item: PlannedWorkItem) -> tuple[list[AggregatedEnvRows], int]: """Load and aggregate all rows for a single selected dataset.""" row_buffer: list[dict[str, Any]] = [] diff --git a/tests/test_cli/test_main.py b/tests/test_cli/test_main.py index 03f9b287..1483a929 100644 --- a/tests/test_cli/test_main.py +++ b/tests/test_cli/test_main.py @@ -2024,8 +2024,8 @@ def fake_run_process(options, env_export_map): options = captured["options"] assert options.status_filter == ("completed",) assert options.processed_with_args["status"] == ["completed"] - assert options.max_run_missing_pct == pytest.approx(2.5) - assert options.processed_with_args["max_run_missing_pct"] == pytest.approx(2.5) + assert options.max_results_missing_pct == pytest.approx(2.5) + assert options.processed_with_args["max_results_missing_pct"] == pytest.approx(2.5) def test_process_cli_uses_explicit_status_filter(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: @@ -2046,7 +2046,7 @@ def fake_run_process(options, env_export_map): str(tmp_path / "processed"), "--status", "failed", - "--max-run-missing-pct", + "--max-results-missing-pct", "100", "--dry-run", ] @@ -2056,10 +2056,10 @@ def fake_run_process(options, env_export_map): options = captured["options"] assert options.status_filter == ("failed",) assert options.processed_with_args["status"] == ["failed"] - assert options.max_run_missing_pct == pytest.approx(100.0) + assert options.max_results_missing_pct == pytest.approx(100.0) -def test_process_cli_rejects_negative_max_run_missing_pct( +def test_process_cli_rejects_negative_max_results_missing_pct( tmp_path: Path, capsys: pytest.CaptureFixture[str], ) -> None: @@ -2071,14 +2071,14 @@ def test_process_cli_rejects_negative_max_run_missing_pct( str(tmp_path / "runs"), "--output-dir", str(tmp_path / "processed"), - "--max-run-missing-pct", + "--max-results-missing-pct", "-1", ] ) assert excinfo.value.code == 2 err = capsys.readouterr().err - assert "--max-run-missing-pct must be non-negative." in err + assert "--max-results-missing-pct must be non-negative." in err def test_process_config_empty_status_uses_default_filter( @@ -2317,7 +2317,7 @@ def test_process_cli_rejects_include_prompt_completion(tmp_path: Path) -> None: ("field", "value"), [ ("max_workers", "not-an-int"), - ("max_run_missing_pct", "not-a-float"), + ("max_results_missing_pct", "not-a-float"), ("hf_request_timeout", "not-a-float"), ("hf_retries", "not-an-int"), ("hf_max_files_per_commit", "not-an-int"), @@ -2348,6 +2348,97 @@ def test_process_cli_rejects_invalid_typed_config_values( assert value in err +def test_process_cli_rejects_removed_top_level_max_run_missing_pct_config_key( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + cfg_path = tmp_path / "process-removed-top-level.yaml" + cfg_path.write_text( + """ + runs_dir: runs/raw + output_dir: runs/processed + max_run_missing_pct: 2.5 + """, + encoding="utf-8", + ) + + with pytest.raises(SystemExit) as excinfo: + main.main(["process", "--config", str(cfg_path)]) + + assert excinfo.value.code == 2 + err = capsys.readouterr().err + assert "Process config field 'max_run_missing_pct' was removed" in err + assert "max_results_missing_pct" in err + + +def test_process_cli_rejects_removed_embedded_max_run_missing_pct_config_key( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + cfg_path = tmp_path / "process-removed-embedded.yaml" + cfg_path.write_text( + """ + runs_dir: runs/raw + process: + dir: processed + max_run_missing_pct: 2.5 + """, + encoding="utf-8", + ) + + with pytest.raises(SystemExit) as excinfo: + main.main(["process", "--config", str(cfg_path)]) + + assert excinfo.value.code == 2 + err = capsys.readouterr().err + assert "Process config field 'process.max_run_missing_pct' was removed" in err + assert "process.max_results_missing_pct" in err + + +def test_winrate_cli_ignores_removed_process_only_missing_pct_key( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + cfg_path = tmp_path / "winrate-process-key.yaml" + cfg_path.write_text( + """ + processed_dir: runs/processed + process: + max_run_missing_pct: 2.5 + """, + encoding="utf-8", + ) + + captured: dict[str, Any] = {} + + def fake_run_winrate( + *, processed_dir, output_dir, output_path, output_name, config, processed_at, hf_config, hf_processed_pull + ): + captured["processed_dir"] = processed_dir + return SimpleNamespace( + output_path=tmp_path / "out.json", + output_paths=[tmp_path / "out.json"], + result={"models": {}}, + datasets=[], + ) + + monkeypatch.setattr(main, "run_winrate", fake_run_winrate) + monkeypatch.setattr(main, "print_winrate_summary_markdown", lambda *_args, **_kwargs: None) + + exit_code = main.main( + [ + "winrate", + "--config", + str(cfg_path), + "--processed-at", + "2024-01-01T00:00:00Z", + ] + ) + + assert exit_code == 0 + assert captured["processed_dir"] == Path("runs/processed") + + @pytest.mark.parametrize( ("field", "value"), [ @@ -2403,9 +2494,7 @@ def fake_run(options, env_export_map): monkeypatch.setattr(main, "_load_env_export_map", lambda *_args, **_kwargs: {}) monkeypatch.setattr(main, "run_process", fake_run) - exit_code = main.main( - ["process", "--config", str(cfg_path), "--max-workers", "2", "--dry-run"] - ) + exit_code = main.main(["process", "--config", str(cfg_path), "--max-workers", "2", "--dry-run"]) assert exit_code == 0 assert captured["options"].max_workers == 2 diff --git a/tests/test_cli/test_process_discovery.py b/tests/test_cli/test_process_discovery.py index 7ffa93db..12e8da7e 100644 --- a/tests/test_cli/test_process_discovery.py +++ b/tests/test_cli/test_process_discovery.py @@ -4,7 +4,6 @@ from pathlib import Path from medarc_verifiers.cli.process.discovery import RunManifestInfo, discover_run_records -from medarc_verifiers.cli.process.pipeline import _manifest_missing_pct, _manifest_within_missing_pct def _write_json(path: Path, payload: dict) -> None: @@ -73,6 +72,7 @@ def test_discover_run_records_basic(tmp_path: Path) -> None: "ended_at": "2024-01-01T00:01:00Z", "num_examples": 10, "rollouts_per_example": 2, + "row_count": 20, } ], models={"gpt-4": {"sampling_args": {"temperature": 0.2}}}, @@ -108,6 +108,7 @@ def test_discover_run_records_basic(tmp_path: Path) -> None: assert record.has_summary is True assert record.env_args == {"fold": "dev"} assert record.sampling_args == {"temperature": 0.2} + assert record.row_count == 20 assert record.manifest.job_run_id == "job-run-123" @@ -147,27 +148,6 @@ def test_discover_run_records_filters_status(tmp_path: Path) -> None: assert filtered_none == [] -def test_manifest_missing_pct_skips_when_above_threshold() -> None: - manifest = _manifest_info(completed=97, total=100, total_known=True) - - assert _manifest_missing_pct(manifest) == 3.0 - assert _manifest_within_missing_pct(manifest, 0.0) is False - - -def test_manifest_missing_pct_keeps_runs_within_threshold() -> None: - manifest = _manifest_info(completed=39, total=40, total_known=True) - - assert _manifest_missing_pct(manifest) == 2.5 - assert _manifest_within_missing_pct(manifest, 2.5) is True - - -def test_manifest_missing_pct_unknown_total_is_permissive() -> None: - manifest = _manifest_info(completed=0, total=0, total_known=False) - - assert _manifest_missing_pct(manifest) is None - assert _manifest_within_missing_pct(manifest, 0.0) is True - - def test_discover_run_records_missing_summary_uses_manifest_status(tmp_path: Path) -> None: runs_dir = tmp_path / "runs" run_dir = runs_dir / "job-run-123" diff --git a/tests/test_cli/test_process_metadata.py b/tests/test_cli/test_process_metadata.py index 30fc1718..404fb7cf 100644 --- a/tests/test_cli/test_process_metadata.py +++ b/tests/test_cli/test_process_metadata.py @@ -62,6 +62,7 @@ def _make_record( ended_at="2024-01-01T00:00:50Z", num_examples=num_examples, rollouts_per_example=rollouts_per_example, + row_count=1, env_args=env_args or {}, sampling_args=sampling_args or {}, env_config=env_config or {}, @@ -210,3 +211,54 @@ def test_load_normalized_metadata_validation_failure_sanitizes_raw_metadata(tmp_ "endpoint_id": "cluster-a", "base_url": "https://example.invalid/v1", } + + +def test_load_normalized_metadata_keeps_zero_num_examples_from_manifest(tmp_path: Path) -> None: + record = _make_record(tmp_path, manifest_env_id="demo-env", num_examples=0, rollouts_per_example=1) + _write_json( + record.metadata_path, + { + "env_id": "demo-env", + "num_examples": 20, + "rollouts_per_example": 3, + }, + ) + + normalized = load_normalized_metadata(record) + + assert normalized.num_examples == 0 + assert normalized.rollouts_per_example == 1 + + +def test_load_normalized_metadata_keeps_zero_rollouts_from_manifest(tmp_path: Path) -> None: + record = _make_record(tmp_path, manifest_env_id="demo-env", num_examples=10, rollouts_per_example=0) + _write_json( + record.metadata_path, + { + "env_id": "demo-env", + "num_examples": 20, + "rollouts_per_example": 3, + }, + ) + + normalized = load_normalized_metadata(record) + + assert normalized.num_examples == 10 + assert normalized.rollouts_per_example == 0 + + +def test_load_normalized_metadata_keeps_all_examples_sentinel_from_manifest(tmp_path: Path) -> None: + record = _make_record(tmp_path, manifest_env_id="demo-env", num_examples=-1, rollouts_per_example=1) + _write_json( + record.metadata_path, + { + "env_id": "demo-env", + "num_examples": 20, + "rollouts_per_example": 3, + }, + ) + + normalized = load_normalized_metadata(record) + + assert normalized.num_examples == -1 + assert normalized.rollouts_per_example == 1 diff --git a/tests/test_cli/test_process_pipeline.py b/tests/test_cli/test_process_pipeline.py index 76ac0c3c..a790c29c 100644 --- a/tests/test_cli/test_process_pipeline.py +++ b/tests/test_cli/test_process_pipeline.py @@ -9,7 +9,7 @@ from medarc_verifiers.cli._manifest import MANIFEST_VERSION from medarc_verifiers.cli._schemas import EnvironmentExportConfig from medarc_verifiers.cli.process import ProcessOptions, run_process -from medarc_verifiers.cli.process.discovery import RunManifestInfo, RunRecord +from medarc_verifiers.cli.process.discovery import RunManifestInfo, RunRecord, discover_run_records from medarc_verifiers.cli.process.pipeline import select_work_items from medarc_verifiers.cli.winrate import WinrateConfig from medarc_verifiers.cli.winrate import discover_datasets, run_winrate @@ -57,6 +57,9 @@ def _run_record( total: int = 1, total_known: bool = True, updated_at: str = "2024-01-01T00:00:00Z", + row_count: int | None = 1, + num_examples: int | None = 1, + rollouts_per_example: int | None = 1, ) -> RunRecord: run_dir = Path("/tmp") / run_id results_dir = run_dir / job_id @@ -84,8 +87,9 @@ def _run_record( reason=None, started_at="2024-01-01T00:00:00Z", ended_at="2024-01-01T00:00:01Z", - num_examples=1, - rollouts_per_example=1, + num_examples=num_examples, + rollouts_per_example=rollouts_per_example, + row_count=row_count, env_args={}, sampling_args={}, env_config={"id": env_id, "module": env_id}, @@ -125,6 +129,10 @@ def _setup_run(tmp_path: Path) -> Path: "env_variant_id": "demo-env-rollout3", "env_args": {}, "results_dir": "demo-job", + "status": "completed", + "num_examples": 1, + "rollouts_per_example": 1, + "row_count": 1, } ], } @@ -168,6 +176,10 @@ def _write_run( model_id: str = "gpt-mini", status: str = "completed", results_text: str | None = None, + row_count: int | None = 1, + num_examples: int | None = 1, + rollouts_per_example: int | None = 1, + write_results: bool = True, ) -> Path: runs_dir = tmp_path / "runs" run_dir = runs_dir / run_id @@ -201,6 +213,9 @@ def _write_run( "env_args": {}, "results_dir": "demo-job", "status": status, + "row_count": row_count, + "num_examples": num_examples, + "rollouts_per_example": rollouts_per_example, } ], } @@ -209,14 +224,17 @@ def _write_run( "env_id": env_id, "env_args": {}, "sampling_args": {}, + "num_examples": num_examples, + "rollouts_per_example": rollouts_per_example, } _write_json(results_dir / "metadata.json", metadata) results_path = results_dir / "results.jsonl" - results_path.parent.mkdir(parents=True, exist_ok=True) - if results_text is None: - row = {"example_id": f"ex-{run_id}", "reward": reward} - results_text = json.dumps(row) + "\n" - results_path.write_text(results_text, encoding="utf-8") + if write_results: + results_path.parent.mkdir(parents=True, exist_ok=True) + if results_text is None: + row = {"example_id": f"ex-{run_id}", "reward": reward} + results_text = json.dumps(row) + "\n" + results_path.write_text(results_text, encoding="utf-8") return runs_dir @@ -351,51 +369,287 @@ def test_run_process_excludes_datasets(tmp_path: Path) -> None: assert result.env_groups[0].base_env_id == "keep-env" -def test_select_work_items_counts_missing_pct_skips_per_run_not_per_job() -> None: - discovered = [ - _run_record( - run_id="run-skipped", - job_id="job-a", - env_id="demo-env-a", - completed=8, - total=10, - ), - _run_record( - run_id="run-skipped", - job_id="job-b", - env_id="demo-env-b", - completed=8, - total=10, - ), - _run_record( - run_id="run-kept", - job_id="job-c", - env_id="demo-env-c", - completed=10, - total=10, - updated_at="2024-01-01T00:05:00Z", - ), - ] +def test_process_allows_results_missing_pct_within_threshold(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-98pct", + updated_at="2024-01-01T00:00:00Z", + reward=1.0, + row_count=98, + num_examples=100, + rollouts_per_example=1, + ) options = ProcessOptions( - runs_dir=Path("/tmp/runs"), - output_dir=Path("/tmp/processed"), - max_run_missing_pct=10.0, + runs_dir=runs_dir, + output_dir=tmp_path / "processed", + max_results_missing_pct=2.5, dry_run=True, max_workers=1, ) - selection = select_work_items( - discovered, - options=options, - env_export_map={}, - index_files={}, + result = run_process(options) + + assert result.records_processed == 1 + assert result.rows_processed == 1 + + +def test_process_rejects_results_missing_pct_above_threshold(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-90pct", + updated_at="2024-01-01T00:00:00Z", + reward=1.0, + row_count=90, + num_examples=100, + rollouts_per_example=1, + ) + options = ProcessOptions( + runs_dir=runs_dir, + output_dir=tmp_path / "processed", + max_results_missing_pct=2.5, + dry_run=True, + max_workers=1, ) - assert selection.skipped_by_missing_pct == 1 - assert selection.total_discovered == 3 - assert len(selection.work_items) == 1 - assert selection.work_items[0].identity.job_run_id == "run-kept" - assert selection.work_items[0].identity.output_env_id == "demo-env-c" + with pytest.raises(RuntimeError) as excinfo: + run_process(options) + + message = str(excinfo.value) + assert "run-90pct" in message + assert "expected_rows=100" in message + assert "observed_rows=90" in message + assert "missing_pct=10.00" in message + assert "threshold=2.5" in message + + +def test_process_allows_ungateable_record_when_expected_rows_unknown(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-unknown-expected", + updated_at="2024-01-01T00:00:00Z", + reward=1.0, + row_count=10, + num_examples=None, + rollouts_per_example=1, + ) + options = ProcessOptions( + runs_dir=runs_dir, + output_dir=tmp_path / "processed", + dry_run=True, + max_workers=1, + ) + + result = run_process(options) + + assert result.records_processed == 1 + + +def test_process_allows_ungateable_record_when_row_count_unknown(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-unknown-observed", + updated_at="2024-01-01T00:00:00Z", + reward=1.0, + row_count=None, + num_examples=100, + rollouts_per_example=1, + ) + options = ProcessOptions( + runs_dir=runs_dir, + output_dir=tmp_path / "processed", + dry_run=True, + max_workers=1, + ) + + result = run_process(options) + + assert result.records_processed == 1 + + +def test_process_latest_record_that_fails_gate_does_not_fall_back(tmp_path: Path) -> None: + _write_run( + tmp_path, + run_id="run-older-ok", + updated_at="2024-01-01T00:00:00Z", + reward=1.0, + row_count=100, + num_examples=100, + rollouts_per_example=1, + ) + runs_dir = _write_run( + tmp_path, + run_id="run-newer-bad", + updated_at="2024-01-02T00:00:00Z", + reward=0.0, + row_count=90, + num_examples=100, + rollouts_per_example=1, + ) + options = ProcessOptions( + runs_dir=runs_dir, + output_dir=tmp_path / "processed", + max_results_missing_pct=2.5, + dry_run=True, + max_workers=1, + ) + + with pytest.raises(RuntimeError) as excinfo: + run_process(options) + + message = str(excinfo.value) + assert "run-newer-bad" in message + assert "run-older-ok" not in message + + +def test_process_rejects_missing_results_jsonl_for_selected_latest_record(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-missing-results", + updated_at="2024-01-02T00:00:00Z", + reward=1.0, + row_count=100, + num_examples=100, + rollouts_per_example=1, + write_results=False, + ) + options = ProcessOptions( + runs_dir=runs_dir, + output_dir=tmp_path / "processed", + dry_run=True, + max_workers=1, + ) + + with pytest.raises(RuntimeError) as excinfo: + run_process(options) + + message = str(excinfo.value) + assert "missing results.jsonl files" in message + assert "run-missing-results" in message + + +def test_process_gate_ignores_excluded_record(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-excluded-bad", + updated_at="2024-01-02T00:00:00Z", + reward=1.0, + env_id="skip-env", + row_count=90, + num_examples=100, + rollouts_per_example=1, + ) + options = ProcessOptions( + runs_dir=runs_dir, + output_dir=tmp_path / "processed", + exclude_datasets=("skip-env",), + max_results_missing_pct=2.5, + dry_run=True, + max_workers=1, + ) + + result = run_process(options) + + assert result.records_processed == 0 + + +def test_process_stale_delta_output_does_not_mask_newer_incomplete_run(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-initial", + updated_at="2024-01-01T00:00:00Z", + reward=1.0, + row_count=100, + num_examples=100, + rollouts_per_example=1, + ) + output_dir = tmp_path / "processed" + initial = run_process(ProcessOptions(runs_dir=runs_dir, output_dir=output_dir, dry_run=False, max_workers=1)) + assert initial.records_processed == 1 + + _write_run( + tmp_path, + run_id="run-newer-bad", + updated_at="2024-01-02T00:00:00Z", + reward=0.0, + row_count=90, + num_examples=100, + rollouts_per_example=1, + ) + + with pytest.raises(RuntimeError) as excinfo: + run_process( + ProcessOptions( + runs_dir=runs_dir, + output_dir=output_dir, + max_results_missing_pct=2.5, + dry_run=False, + max_workers=1, + ) + ) + + message = str(excinfo.value) + assert "run-newer-bad" in message + assert "missing_pct=10.00" in message + + +def test_process_emits_single_warning_for_ungateable_selected_records( + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-unknown-observed", + updated_at="2024-01-01T00:00:00Z", + reward=1.0, + row_count=None, + num_examples=100, + rollouts_per_example=1, + ) + caplog.set_level("WARNING") + + result = run_process( + ProcessOptions( + runs_dir=runs_dir, + output_dir=tmp_path / "processed", + dry_run=True, + max_workers=1, + ) + ) + + assert result.records_processed == 1 + warnings = [ + record for record in caplog.records if "Results row completeness gate could not be applied" in record.msg + ] + assert len(warnings) == 1 + + +def test_select_work_items_rollout_gate_error_includes_output_and_manifest_ids(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-rollout-bad", + updated_at="2024-01-02T00:00:00Z", + reward=1.0, + env_id="demo-env-rollout3", + row_count=90, + num_examples=100, + rollouts_per_example=1, + ) + discovered = discover_run_records(runs_dir, filter_status=("completed",)) + options = ProcessOptions( + runs_dir=runs_dir, + output_dir=tmp_path / "processed", + max_results_missing_pct=2.5, + dry_run=True, + max_workers=1, + ) + + with pytest.raises(RuntimeError) as excinfo: + select_work_items(discovered, options=options, env_export_map={}, index_files={}) + + message = str(excinfo.value) + assert "output_env_id=demo-env" in message + assert "manifest_env_id=demo-env-rollout3" in message + assert "job_id=demo-job" in message def test_run_process_excludes_models(tmp_path: Path) -> None: @@ -680,8 +934,10 @@ def test_process_latest_only_selects_latest_and_skips_existing_outputs(tmp_path: _write_run(tmp_path, run_id="run-3", updated_at="2024-01-04T00:00:00Z", reward=0.4) result_newer_raw = run_process(options) - assert result_newer_raw.env_summaries == [] - assert result_newer_raw.rows_processed == 0 + assert result_newer_raw.env_summaries + newer_table = pq.read_table(result_newer_raw.env_summaries[0].output_path) + assert set(newer_table.column("job_run_id").to_pylist()) == {"run-3"} + assert newer_table.column("reward").to_pylist() == [0.4] def test_process_replace_model_rebuilds_existing_output(tmp_path: Path) -> None: @@ -732,11 +988,11 @@ def test_process_replace_model_rebuilds_existing_output(tmp_path: Path) -> None: ) rebuilt = {summary.model_id for summary in result.env_summaries} - assert rebuilt == {"model-a"} + assert rebuilt == {"model-a", "model-b"} model_a_table = pq.read_table(output_dir / "model-a" / "demo-env.parquet") model_b_table = pq.read_table(output_dir / "model-b" / "demo-env.parquet") assert model_a_table.column("reward").to_pylist() == [0.9] - assert model_b_table.column("reward").to_pylist() == [0.2] + assert model_b_table.column("reward").to_pylist() == [0.8] def test_process_replace_model_and_env_rebuild_only_intersection(tmp_path: Path) -> None: @@ -803,10 +1059,14 @@ def test_process_replace_model_and_env_rebuild_only_intersection(tmp_path: Path) ) ) - assert {(summary.model_id, summary.env_id) for summary in result.env_summaries} == {("model-a", "env-a")} + assert {(summary.model_id, summary.env_id) for summary in result.env_summaries} == { + ("model-a", "env-a"), + ("model-a", "env-b"), + ("model-b", "env-a"), + } assert pq.read_table(output_dir / "model-a" / "env-a.parquet").column("reward").to_pylist() == [0.7] - assert pq.read_table(output_dir / "model-a" / "env-b.parquet").column("reward").to_pylist() == [0.2] - assert pq.read_table(output_dir / "model-b" / "env-a.parquet").column("reward").to_pylist() == [0.3] + assert pq.read_table(output_dir / "model-a" / "env-b.parquet").column("reward").to_pylist() == [0.8] + assert pq.read_table(output_dir / "model-b" / "env-a.parquet").column("reward").to_pylist() == [0.9] def test_process_fails_fast_on_existing_row_count_mismatch(tmp_path: Path) -> None: @@ -846,7 +1106,9 @@ def test_process_ignores_superseded_run_missing_model_id(tmp_path: Path) -> None _remove_model_id(tmp_path, "run-1") _write_run(tmp_path, run_id="run-2", updated_at="2024-01-02T00:00:00Z", reward=0.9) - result = run_process(ProcessOptions(runs_dir=runs_dir, output_dir=tmp_path / "processed", dry_run=False, max_workers=1)) + result = run_process( + ProcessOptions(runs_dir=runs_dir, output_dir=tmp_path / "processed", dry_run=False, max_workers=1) + ) table = pq.read_table(result.env_summaries[0].output_path) assert table.column("reward").to_pylist() == [0.9] @@ -896,7 +1158,7 @@ def test_process_selected_missing_results_still_fail(tmp_path: Path) -> None: missing_results = runs_dir / "run-1" / "demo-job" / "results.jsonl" missing_results.unlink() - with pytest.raises(FileNotFoundError, match="Missing results.jsonl"): + with pytest.raises(RuntimeError, match="Selected records are missing results.jsonl files:"): run_process(ProcessOptions(runs_dir=runs_dir, output_dir=tmp_path / "processed", dry_run=False, max_workers=1)) @@ -956,12 +1218,22 @@ def fake_read_env_index_files(processed_dir: Path): assert observed == ["workspace", "index"] return {"gpt-mini/demo-env.parquet": {"env_id": "demo-env", "model_id": "gpt-mini"}} - monkeypatch.setattr("medarc_verifiers.cli.process.workspace.prepare_output_workspace", fake_prepare_output_workspace) + monkeypatch.setattr( + "medarc_verifiers.cli.process.workspace.prepare_output_workspace", fake_prepare_output_workspace + ) monkeypatch.setattr("medarc_verifiers.cli.process.env_index.read_env_index_files", fake_read_env_index_files) + monkeypatch.setattr( + "medarc_verifiers.cli.process.pipeline._read_existing_output_metadata", + lambda *_args, **_kwargs: object(), + ) monkeypatch.setattr( "medarc_verifiers.cli.process.pipeline._validate_existing_output_integrity", lambda *_args, **_kwargs: None, ) + monkeypatch.setattr( + "medarc_verifiers.cli.process.pipeline._existing_output_matches_selected_runs", + lambda *_args, **_kwargs: True, + ) result = run_process(ProcessOptions(runs_dir=runs_dir, output_dir=output_dir, dry_run=False, max_workers=1)) diff --git a/tests/test_cli/test_process_rows.py b/tests/test_cli/test_process_rows.py index 1a5c8efc..c000c37c 100644 --- a/tests/test_cli/test_process_rows.py +++ b/tests/test_cli/test_process_rows.py @@ -54,6 +54,7 @@ def _build_record(tmp_path: Path, *, status: str = "completed", reason: str | No ended_at="2024-05-01T00:00:42Z", num_examples=10, rollouts_per_example=1, + row_count=1, env_args={"split": "dev", "extra_body": {}}, sampling_args={"temperature": 0.2}, env_config={}, From 0a8f5bc2829caf537ae166326737f820fc97ba16 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Sun, 1 Mar 2026 18:40:14 -0500 Subject: [PATCH 16/29] fixes and use hf_token env --- docs/medarc-eval-process.md | 14 +++- docs/medarc-eval-winrate.md | 82 ++++++++++++++------- docs/medarc-verifiers-architecture.md | 9 ++- medarc_verifiers/cli/main.py | 28 ++++++- medarc_verifiers/cli/process/pipeline.py | 10 +-- medarc_verifiers/cli/process/writer.py | 4 +- tests/test_cli/test_main.py | 94 ++++++++++++++++++++++++ tests/test_cli/test_process_pipeline.py | 59 ++++++++++++++- tests/test_process_writer_schema.py | 2 + 9 files changed, 258 insertions(+), 44 deletions(-) diff --git a/docs/medarc-eval-process.md b/docs/medarc-eval-process.md index d89c5923..abc6434d 100644 --- a/docs/medarc-eval-process.md +++ b/docs/medarc-eval-process.md @@ -19,8 +19,8 @@ medarc-eval process --dry-run 1. **Discovers** jobs in `runs/raw/` and filters by manifest status (default: `completed`) 2. **Extracts** results from each job's output files -3. **Normalizes** data into a consistent schema -4. **Writes** parquet files organized by environment and model +3. **Normalizes** data into a fixed output schema +4. **Writes** parquet files organized by model and environment 5. **Creates** an index (`env_index.json`) for downstream tools ### Output Structure @@ -37,6 +37,8 @@ runs/processed/ └── ... ``` +On-disk model and env path components are slugified, so filenames may not exactly match raw ids. + ## Common Options | Flag | Description | Default | @@ -80,6 +82,8 @@ This gate uses manifest job metadata only: It is computed per selected job record and enforced only on the latest selected run for each processed model/environment output. It does not use manifest `summary.completed` / `summary.total`, and it does not fall back to older runs if the latest one is too incomplete. +Selected records with missing `results.jsonl` fail processing immediately. + ### Latest Runs Only When multiple runs exist for the same (model, environment) pair, processing uses the latest by default. @@ -170,6 +174,8 @@ hf: private: true ``` +`hf.token` accepts either a literal token string or an environment reference like `$HF_TOKEN` / `${HF_TOKEN}`. + ### Pull Before Processing ```bash @@ -183,6 +189,8 @@ medarc-eval process --hf-repo your-org/data --hf-pull-policy pull medarc-eval process --hf-repo your-org/data --hf-pull-policy clean ``` +`prompt` only prompts when the local processed dir is already non-empty. If the output dir is empty, process pulls the HF baseline immediately. + ### Push After Processing When `--hf-repo` is set, processed files are automatically uploaded after completion. @@ -233,6 +241,8 @@ medarc-eval process # env_index.json tracks what's already processed ``` +Incremental skipping only reuses an existing parquet when its footer metadata `source_runs` still matches the newly selected run ids and the existing row count still matches `env_index.json`. + ### Replace Existing Outputs Rebuild existing outputs for specific models or datasets without using `--clean`: diff --git a/docs/medarc-eval-winrate.md b/docs/medarc-eval-winrate.md index b231d3d0..47c28f92 100644 --- a/docs/medarc-eval-winrate.md +++ b/docs/medarc-eval-winrate.md @@ -27,10 +27,11 @@ medarc-eval process ## How Win Rates Work For each pair of models (A, B) on each benchmark: -1. Find questions both models answered -2. Compare scores on each question -3. Count: A wins, B wins, ties -4. Win rate = (A wins + 0.5 × ties) / total +1. Average rollouts per `(example_id, model_id)` +2. Compare questions where at least one model has a reward +3. If one side is missing, fill it according to `--missing-policy` (`neg-inf` or `zero`) +4. Count: A wins, B wins, ties +5. Win rate = (A wins + 0.5 × ties) / total used questions The final win rate aggregates across all benchmarks using configurable weighting. @@ -41,19 +42,20 @@ Winrate also emits a missingness summary so partial dataset coverage is visible. ``` runs/processed/winrate/ -├── winrates-2026-01-14T12-00-00.json # Timestamped results -├── winrates-2026-01-14T12-00-00.csv # Spreadsheet-friendly +├── winrates-20260114T120000Z.json # Timestamped results +├── winrates-20260114T120000Z.csv # Spreadsheet-friendly ├── latest.json # Always points to newest └── latest.csv ``` +If you pass `--output /path/to/file.json`, winrate writes only that JSON file and skips `latest.json` plus all CSV outputs. + ### Output Format The JSON output includes: - Per-model aggregate win rates -- Pairwise comparison matrices -- Per-benchmark breakdowns -- Computation metadata +- Per-opponent `vs` breakdowns +- Per-dataset average rewards and question counts ## Common Options @@ -95,6 +97,8 @@ The JSON output includes: | `--partial-datasets strict` | When `--include-model` is set, drop datasets missing any included model | | `--partial-datasets include` | When `--include-model` is set, keep datasets and treat missing models as all-missing | +`--partial-datasets include` is usually paired with `--dataset-coverage per-model`. With the default `all-models` coverage, datasets missing any required model are still dropped later. + ## Using a Config File ```yaml @@ -206,39 +210,62 @@ hf: private: true ``` +`hf.token` accepts either a literal token string or an environment reference like `$HF_TOKEN` / `${HF_TOKEN}`. + `hf.winrate_dir` and `--hf-winrate-dir` both set the path inside the HF repo where `latest.json`, `latest.csv`, and timestamped winrate outputs are uploaded. ## Interpreting Results ### Win Rate Table (CSV) -| model | win_rate | vs_gpt-4o | vs_gpt-4o-mini | vs_claude | -|-------|----------|-----------|----------------|-----------| -| gpt-4o | 0.72 | - | 0.85 | 0.58 | -| gpt-4o-mini | 0.45 | 0.15 | - | 0.32 | -| claude-3-5-sonnet | 0.68 | 0.42 | 0.68 | - | +| model | weighted_winrate | simple_winrate | medqa | pubmedqa | num_datasets | +|-------|------------------|----------------|-------|-----------|--------------| +| gpt-4o | 0.72 | 0.70 | 0.84 | 0.77 | 2 | +| gpt-4o-mini | 0.45 | 0.43 | 0.61 | 0.39 | 2 | -- **win_rate**: Aggregate win rate across all models -- **vs_X columns**: Pairwise win rate against model X -- Values > 0.5 mean the row model wins more often +- **weighted_winrate** / **simple_winrate**: Aggregate mean winrate across retained datasets +- Dataset columns: Average reward on that dataset, not pairwise winrate columns +- `num_datasets`: Number of datasets retained for that model after filtering/coverage rules ### JSON Structure ```json { - "models": ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet"], - "aggregate_winrates": { - "gpt-4o": 0.72, - "gpt-4o-mini": 0.45, - "claude-3-5-sonnet": 0.68 - }, - "pairwise": { + "models": { "gpt-4o": { - "gpt-4o-mini": {"win_rate": 0.85, "wins": 850, "losses": 150, "ties": 0}, - "claude-3-5-sonnet": {"win_rate": 0.58, ...} + "mean_winrate": { + "simple_mean": 0.72, + "weighted_mean": 0.74, + "n_datasets": 2 + }, + "vs": { + "gpt-4o-mini": { + "mean_winrate": { + "simple_mean": 0.85, + "weighted_mean": 0.84 + }, + "per_dataset": { + "medqa": 0.90, + "pubmedqa": 0.80 + }, + "n_datasets": 2 + } + }, + "avg_reward_per_dataset": { + "medqa": 0.84, + "pubmedqa": 0.77 + } + } + }, + "datasets": { + "medqa": { + "avg_reward_per_model": { + "gpt-4o": 0.84, + "gpt-4o-mini": 0.61 + }, + "n_questions": 1273 } }, - "per_benchmark": { ... } } ``` @@ -254,3 +281,4 @@ hf: - Check `--min-common` isn't filtering out comparisons - Review `--missing-policy` (use `neg-inf` to penalize missing answers) - Verify models were evaluated on the same benchmark variants +- If using `--partial-datasets include`, also consider `--dataset-coverage per-model` diff --git a/docs/medarc-verifiers-architecture.md b/docs/medarc-verifiers-architecture.md index 568a8c20..2492ac08 100644 --- a/docs/medarc-verifiers-architecture.md +++ b/docs/medarc-verifiers-architecture.md @@ -173,7 +173,7 @@ Entry point: `medarc_verifiers/cli/process/pipeline.py` (via `run_process()`). - This suffix-derived rollout index is only used when rollouts are faked this way. Native verifiers rollouts (below) use the per-row JSONL field. - `medarc_verifiers/cli/process/rollout.py` 4. **Load rows from `results.jsonl`**: - - Drops large fields (`prompt`, `completion`) by default. + - Always drops large fields (`prompt`, `completion`). - Allows selecting extra per-env columns into a JSON-encoded `extras` column. - If the JSONL provides a per-row `rollout_index` (native verifiers multi-rollout runs), it is treated as authoritative and preserved. - If `rollout_index` is missing but the JSONL contains multiple rows per `example_id`, computes a data-driven `rollout_index` based on occurrence count. @@ -184,7 +184,8 @@ Entry point: `medarc_verifiers/cli/process/pipeline.py` (via `run_process()`). - When aggregating fake rollouts (manifest env ids include rollout suffixes), ensures every row has a `rollout_index` (derived from the suffix if missing) and normalizes indices to `0..K-1` within the dataset. - When aggregating native verifiers rollouts (no rollout suffixes), preserves `rollout_index` values as provided by `results.jsonl` (no normalization). 6. **Write Parquet**: - - Output path is `//.parquet`. + - Output path is `//.parquet`. + - Output columns are restricted to a fixed allowlist schema for downstream compatibility. - Adds exporter metadata under a Parquet schema metadata key. - Writes `env_index.json` (v2) and `dataset_infos.json` for HF datasets UX. - `medarc_verifiers/cli/process/writer.py`, `medarc_verifiers/cli/process/env_index.py` @@ -200,13 +201,15 @@ Processing can use `env_index.json` to do incremental updates (delta processing) Docs: `docs/medarc-eval-winrate.md`. -`medarc-eval winrate` reads dataset inventory from `env_index.json`, then computes pairwise model comparisons. +`medarc-eval winrate` reads dataset inventory from `env_index.json`, averages rollouts per `(example_id, model_id)`, then computes pairwise model comparisons. - Dataset discovery via `env_index.json`: `medarc_verifiers/cli/winrate/runner.py` - Core math + weighting policies: `medarc_verifiers/cli/winrate/api.py` - Outputs: - timestamped `winrates-.json` and `.csv` - `latest.json` and `latest.csv` + - JSON shape is model-centric: top-level `models` and `datasets` + - CSV contains aggregate winrates plus per-dataset average rewards, not pairwise `vs_*` columns ## Shared building blocks used by environments diff --git a/medarc_verifiers/cli/main.py b/medarc_verifiers/cli/main.py index 041dbe60..3161005c 100644 --- a/medarc_verifiers/cli/main.py +++ b/medarc_verifiers/cli/main.py @@ -4,6 +4,7 @@ import argparse import logging +import os import sys from pathlib import Path from textwrap import dedent @@ -796,6 +797,27 @@ def _set_if_unset(args: argparse.Namespace, attr: str, value: Any) -> None: setattr(args, attr, value) +def _resolve_config_string_value(key: str, value: Any) -> str: + resolved = str(value) + if key != "hf_token": + return resolved + + trimmed = resolved.strip() + env_var: str | None = None + if trimmed.startswith("${") and trimmed.endswith("}") and len(trimmed) > 3: + env_var = trimmed[2:-1].strip() + elif trimmed.startswith("$") and len(trimmed) > 1: + env_var = trimmed[1:].strip() + + if not env_var: + return resolved + + env_value = os.getenv(env_var) + if env_value is None: + raise ValueError(f"Config field 'hf.token' references unset environment variable '{env_var}'.") + return env_value + + def _load_config_payload(path: Path, *, mode: Literal["process", "winrate"]) -> dict[str, Any]: label = "Process config" if mode == "process" else "Winrate config" raw_payload = dict(load_mapping_file(path, label=label)) @@ -1077,7 +1099,11 @@ def _load_and_apply_config( _set_if_unset(args, attr, Path(str(payload[key]))) for key, attr in string_fields.items(): if key in payload and _is_unset(args, attr): - _set_if_unset(args, attr, str(payload[key])) + try: + resolved = _resolve_config_string_value(key, payload[key]) + except ValueError as exc: + parser.error(str(exc)) + _set_if_unset(args, attr, resolved) for key, attr in boolean_fields.items(): if key in payload and _is_unset(args, attr): _set_if_unset(args, attr, bool(payload[key])) diff --git a/medarc_verifiers/cli/process/pipeline.py b/medarc_verifiers/cli/process/pipeline.py index 1cd3befd..cb342c39 100644 --- a/medarc_verifiers/cli/process/pipeline.py +++ b/medarc_verifiers/cli/process/pipeline.py @@ -359,17 +359,17 @@ def _plan_selection_record( def _raise_for_latest_invalid_selection(records: Sequence[SelectionRecord]) -> None: - latest_by_env: dict[str, SelectionRecord] = {} + latest_by_target: dict[tuple[str, str], SelectionRecord] = {} for planned in records: - output_env_id = planned.identity.output_env_id - current = latest_by_env.get(output_env_id) + selection_key = (planned.identity.output_env_id, planned.record.job_id) + current = latest_by_target.get(selection_key) if current is None or _run_sort_key( _source_updated_at(planned.record), planned.record.manifest.job_run_id, ) > _run_sort_key(_source_updated_at(current.record), current.record.manifest.job_run_id): - latest_by_env[output_env_id] = planned + latest_by_target[selection_key] = planned - invalid_latest = [planned for planned in latest_by_env.values() if not planned.identity.model_id] + invalid_latest = [planned for planned in latest_by_target.values() if not planned.identity.model_id] if not invalid_latest: return diff --git a/medarc_verifiers/cli/process/writer.py b/medarc_verifiers/cli/process/writer.py index 1b9d4d55..a9256cdb 100644 --- a/medarc_verifiers/cli/process/writer.py +++ b/medarc_verifiers/cli/process/writer.py @@ -51,7 +51,7 @@ EXPECTED_POLARS_DTYPES: dict[str, pl.DataType] = { "env_id": pl.String, "error": pl.String, - "example_id": pl.Int64, + "example_id": pl.String, "answer": pl.String, "extras": pl.String, "generation_ms": pl.Float64, @@ -79,7 +79,7 @@ [ pa.field("env_id", pa.large_string()), pa.field("error", pa.large_string()), - pa.field("example_id", pa.int64()), + pa.field("example_id", pa.large_string()), pa.field("answer", pa.large_string()), pa.field("extras", pa.large_string()), pa.field("generation_ms", pa.float64()), diff --git a/tests/test_cli/test_main.py b/tests/test_cli/test_main.py index 1483a929..d1d14592 100644 --- a/tests/test_cli/test_main.py +++ b/tests/test_cli/test_main.py @@ -1839,6 +1839,35 @@ def fake_run(options, env_export_map): assert options.hf_config.token == "override" +def test_process_cli_resolves_hf_token_env_reference(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + cfg_path = tmp_path / "process.yaml" + cfg_path.write_text( + """ + runs_dir: runs/raw-from-config + process: + dir: processed + hf: + repo: medarc/demo + token: $HF_TOKEN + """, + encoding="utf-8", + ) + monkeypatch.setenv("HF_TOKEN", "env-secret") + + captured: dict[str, Any] = {} + + def fake_run(options, env_export_map): + captured["options"] = options + return ProcessResult(records_processed=0, rows_processed=0, env_groups=[], env_summaries=[], hf_summary=None) + + monkeypatch.setattr("medarc_verifiers.cli.main.run_process", fake_run) + + exit_code = main.main(["process", "--config", str(cfg_path), "--dry-run"]) + assert exit_code == 0 + assert captured["options"].hf_config is not None + assert captured["options"].hf_config.token == "env-secret" + + def test_winrate_cli_applies_config_defaults(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: cfg_path = tmp_path / "winrate.yaml" cfg_path.write_text( @@ -1929,6 +1958,71 @@ def fake_sync_files_to_hub(**kwargs): assert cfg.epsilon == pytest.approx(0.5) +def test_winrate_cli_resolves_hf_token_braced_env_reference( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + cfg_path = tmp_path / "winrate.yaml" + cfg_path.write_text( + """ + processed_dir: runs/processed + hf: + repo: medarc/demo + token: ${HF_TOKEN} + """, + encoding="utf-8", + ) + monkeypatch.setenv("HF_TOKEN", "env-secret") + + captured: dict[str, Any] = {} + + def fake_run_winrate( + *, processed_dir, output_dir, output_path, output_name, config, processed_at, hf_config, hf_processed_pull + ): + captured["hf_config"] = hf_config + return SimpleNamespace( + output_path=tmp_path / "out.json", + output_paths=[tmp_path / "out.json"], + result={"models": {}}, + datasets=[], + ) + + monkeypatch.setattr(main, "run_winrate", fake_run_winrate) + monkeypatch.setattr(main, "print_winrate_summary_markdown", lambda *_args, **_kwargs: None) + monkeypatch.setattr(main, "sync_files_to_hub", lambda **_kwargs: None) + + exit_code = main.main(["winrate", "--config", str(cfg_path), "--processed-at", "2024-01-01T00:00:00Z"]) + assert exit_code == 0 + assert captured["hf_config"] is not None + assert captured["hf_config"].token == "env-secret" + + +def test_process_cli_rejects_unset_hf_token_env_reference( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + cfg_path = tmp_path / "process.yaml" + cfg_path.write_text( + """ + runs_dir: runs/raw-from-config + process: + dir: processed + hf: + repo: medarc/demo + token: $HF_TOKEN + """, + encoding="utf-8", + ) + monkeypatch.delenv("HF_TOKEN", raising=False) + + with pytest.raises(SystemExit) as excinfo: + main.main(["process", "--config", str(cfg_path), "--dry-run"]) + + assert excinfo.value.code == 2 + assert "references unset environment variable 'HF_TOKEN'" in capsys.readouterr().err + + def test_expand_embedded_process_config_promotes_process_section() -> None: payload = { "runs_dir": "runs/raw", diff --git a/tests/test_cli/test_process_pipeline.py b/tests/test_cli/test_process_pipeline.py index a790c29c..602a7934 100644 --- a/tests/test_cli/test_process_pipeline.py +++ b/tests/test_cli/test_process_pipeline.py @@ -180,10 +180,11 @@ def _write_run( num_examples: int | None = 1, rollouts_per_example: int | None = 1, write_results: bool = True, + job_id: str = "demo-job", ) -> Path: runs_dir = tmp_path / "runs" run_dir = runs_dir / run_id - results_dir = run_dir / "demo-job" + results_dir = run_dir / job_id manifest = { "version": MANIFEST_VERSION, "run_id": run_id, @@ -205,13 +206,13 @@ def _write_run( }, "jobs": [ { - "job_id": "demo-job", + "job_id": job_id, "model_id": model_id, "env_id": env_id, "env_template_id": "demo-env-template", "env_variant_id": env_id, "env_args": {}, - "results_dir": "demo-job", + "results_dir": job_id, "status": status, "row_count": row_count, "num_examples": num_examples, @@ -245,7 +246,8 @@ def _remove_model_id(tmp_path: Path, run_id: str) -> None: manifest["models"] = {} manifest_path.write_text(json.dumps(manifest), encoding="utf-8") - metadata_path = tmp_path / "runs" / run_id / "demo-job" / "metadata.json" + job_id = manifest["jobs"][0]["job_id"] + metadata_path = tmp_path / "runs" / run_id / job_id / "metadata.json" metadata = json.loads(metadata_path.read_text(encoding="utf-8")) metadata.pop("model", None) metadata_path.write_text(json.dumps(metadata), encoding="utf-8") @@ -314,6 +316,24 @@ def test_run_process_writes_version_info_column(tmp_path: Path) -> None: assert payload["vf_version"] == "0.1.10" +def test_run_process_preserves_string_example_id_in_parquet(tmp_path: Path) -> None: + runs_dir = _setup_run(tmp_path) + output_dir = tmp_path / "processed" + + result = run_process( + ProcessOptions( + runs_dir=runs_dir, + output_dir=output_dir, + dry_run=False, + max_workers=1, + ) + ) + + table = pq.read_table(result.env_summaries[0].output_path) + assert table.column("example_id").to_pylist() == ["ex-1"] + assert str(table.schema.field("example_id").type) == "large_string" + + def test_run_process_backward_compat_without_version_info(tmp_path: Path) -> None: runs_dir = _write_run( tmp_path, @@ -1123,6 +1143,37 @@ def test_process_latest_missing_model_id_fails_clearly(tmp_path: Path) -> None: run_process(ProcessOptions(runs_dir=runs_dir, output_dir=tmp_path / "processed", dry_run=False, max_workers=1)) +def test_process_latest_missing_model_id_not_masked_by_newer_other_job(tmp_path: Path) -> None: + runs_dir = _write_run( + tmp_path, + run_id="run-model-a-old", + updated_at="2024-01-01T00:00:00Z", + reward=0.1, + model_id="model-a", + job_id="job-model-a", + ) + _write_run( + tmp_path, + run_id="run-model-a-bad", + updated_at="2024-01-02T00:00:00Z", + reward=0.2, + model_id="model-a", + job_id="job-model-a", + ) + _remove_model_id(tmp_path, "run-model-a-bad") + _write_run( + tmp_path, + run_id="run-model-b-good", + updated_at="2024-01-03T00:00:00Z", + reward=0.9, + model_id="model-b", + job_id="job-model-b", + ) + + with pytest.raises(RuntimeError, match=r"Missing model_id for run \(job_run_id=run-model-a-bad, job_id=job-model-a,"): + run_process(ProcessOptions(runs_dir=runs_dir, output_dir=tmp_path / "processed", dry_run=False, max_workers=1)) + + def test_process_ignores_invalid_incomplete_run_by_default(tmp_path: Path) -> None: runs_dir = _write_run( tmp_path, diff --git a/tests/test_process_writer_schema.py b/tests/test_process_writer_schema.py index c38a18aa..b4607c5f 100644 --- a/tests/test_process_writer_schema.py +++ b/tests/test_process_writer_schema.py @@ -44,6 +44,7 @@ def test_process_writer_emits_stable_schema_with_all_null_values(tmp_path) -> No summaries = writer.write_env_groups([group], config, write_index=False) schema = pq.ParquetFile(summaries[0].output_path).schema_arrow + assert str(schema.field("example_id").type) == "large_string" assert str(schema.field("extras").type) == "large_string" assert str(schema.field("answer").type) == "large_string" assert str(schema.field("error").type) == "large_string" @@ -64,6 +65,7 @@ def test_process_writer_emits_stable_schema_for_empty_groups(tmp_path) -> None: summaries = writer.write_env_groups([group], config, write_index=False) schema = pq.ParquetFile(summaries[0].output_path).schema_arrow + assert str(schema.field("example_id").type) == "large_string" assert str(schema.field("extras").type) == "large_string" assert str(schema.field("answer").type) == "large_string" assert str(schema.field("error").type) == "large_string" From 3ae2d9848e21baeaabee9be21955c65dbf02a472 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Mon, 2 Mar 2026 14:11:47 -0500 Subject: [PATCH 17/29] bug fixes and improvements --- medarc_verifiers/cli/process/discovery.py | 2 + medarc_verifiers/cli/process/metadata.py | 41 +++++++++++++++++++ medarc_verifiers/cli/process/pipeline.py | 34 ++++++++++++++- tests/test_cli/test_process_discovery.py | 2 + tests/test_cli/test_process_metadata.py | 50 +++++++++++++++++++++++ tests/test_cli/test_process_pipeline.py | 32 +++++++++++++++ tests/test_cli/test_process_rows.py | 1 + 7 files changed, 161 insertions(+), 1 deletion(-) diff --git a/medarc_verifiers/cli/process/discovery.py b/medarc_verifiers/cli/process/discovery.py index 4056bb71..7aba00f8 100644 --- a/medarc_verifiers/cli/process/discovery.py +++ b/medarc_verifiers/cli/process/discovery.py @@ -65,6 +65,7 @@ class RunRecord: reason: str | None started_at: str | None ended_at: str | None + avg_reward: float | None num_examples: int | None rollouts_per_example: int | None row_count: int | None @@ -185,6 +186,7 @@ def _build_run_record( reason=reason or job_entry.reason, started_at=job_entry.started_at, ended_at=job_entry.ended_at, + avg_reward=job_entry.avg_reward, num_examples=job_entry.num_examples, rollouts_per_example=job_entry.rollouts_per_example, row_count=job_entry.row_count, diff --git a/medarc_verifiers/cli/process/metadata.py b/medarc_verifiers/cli/process/metadata.py index e591db0d..118e63f8 100644 --- a/medarc_verifiers/cli/process/metadata.py +++ b/medarc_verifiers/cli/process/metadata.py @@ -4,6 +4,7 @@ import json import logging +import math from dataclasses import dataclass from pathlib import Path from typing import Any, Mapping, MutableMapping @@ -21,6 +22,7 @@ class _MetadataPayload(BaseModel): env_id: str | None = None model: str | None = None + avg_reward: float | None = None version_info: dict[str, str | None] | None = None env_args: dict[str, Any] = Field(default_factory=dict) num_examples: int | None = None @@ -152,6 +154,7 @@ def _resolve_metadata_context( combine_rollouts: bool, ) -> _ResolvedMetadataContext: metadata_payload, raw_metadata = _load_metadata(record) + _warn_manifest_metadata_result_mismatch(record, metadata_payload) metadata_env_id = metadata_payload.env_id if metadata_payload else None metadata_model = metadata_payload.model if metadata_payload else None env_args = _merge_mappings( @@ -258,6 +261,44 @@ def _prefer_manifest_value(primary: int | None, fallback: int | None) -> int | N return fallback +def _warn_manifest_metadata_result_mismatch(record: RunRecord, metadata_payload: _MetadataPayload | None) -> None: + if metadata_payload is None: + return + + mismatches: list[str] = [] + if _has_float_mismatch(record.avg_reward, metadata_payload.avg_reward): + mismatches.append( + f"avg_reward manifest={record.avg_reward!r} metadata={metadata_payload.avg_reward!r}" + ) + if _has_int_mismatch(record.num_examples, metadata_payload.num_examples): + mismatches.append( + f"num_examples manifest={record.num_examples!r} metadata={metadata_payload.num_examples!r}" + ) + if not mismatches: + return + + logger.warning( + "Manifest/metadata result mismatch for process input " + "(job_run_id=%s, job_id=%s, metadata=%s): %s", + record.manifest.job_run_id, + record.job_id, + record.metadata_path, + "; ".join(mismatches), + ) + + +def _has_float_mismatch(left: float | None, right: float | None) -> bool: + if left is None or right is None: + return False + return not math.isclose(left, right, rel_tol=1e-9, abs_tol=1e-9) + + +def _has_int_mismatch(left: int | None, right: int | None) -> bool: + if left is None or right is None: + return False + return left != right + + def _extract_env_config_id(env_config: Mapping[str, Any] | None) -> str | None: if not env_config: return None diff --git a/medarc_verifiers/cli/process/pipeline.py b/medarc_verifiers/cli/process/pipeline.py index cb342c39..505de389 100644 --- a/medarc_verifiers/cli/process/pipeline.py +++ b/medarc_verifiers/cli/process/pipeline.py @@ -21,6 +21,7 @@ from medarc_verifiers.cli.process.metadata import RunIdentity from medarc_verifiers.cli.process.writer import EXPORTER_METADATA_KEY, EnvWriteSummary, WriterConfig from medarc_verifiers.cli.utils.shared import ( + count_jsonl_rows, dataset_is_excluded, model_is_excluded, normalize_dataset_ids, @@ -735,7 +736,7 @@ def _validate_selected_results_completeness( continue expected_rows = _expected_results_rows(normalized) - observed_rows = record.row_count + observed_rows = _completeness_observed_rows(record, expected_rows=expected_rows, threshold=max_results_missing_pct) if expected_rows is None or observed_rows is None: ungateable += 1 continue @@ -800,6 +801,37 @@ def _results_missing_pct(*, expected_rows: int, observed_rows: int) -> float: return 100.0 * missing_rows / int(expected_rows) +def _completeness_observed_rows( + record: discovery.RunRecord, + *, + expected_rows: int | None, + threshold: float, +) -> int | None: + observed_rows = record.row_count + if expected_rows is None or observed_rows is None: + return observed_rows + + missing_pct = _results_missing_pct(expected_rows=expected_rows, observed_rows=observed_rows) + if missing_pct <= threshold: + return observed_rows + + actual_rows = count_jsonl_rows(record.results_path) + if actual_rows is None or actual_rows == observed_rows: + return observed_rows + + logger.warning( + "Manifest row_count mismatch for process input " + "(job_run_id=%s, job_id=%s, results_path=%s): manifest row_count=%s actual_rows=%s. " + "Using actual_rows for completeness validation.", + record.manifest.job_run_id, + record.job_id, + record.results_path, + observed_rows, + actual_rows, + ) + return actual_rows + + def _process_env_group(item: PlannedWorkItem) -> tuple[list[AggregatedEnvRows], int]: """Load and aggregate all rows for a single selected dataset.""" row_buffer: list[dict[str, Any]] = [] diff --git a/tests/test_cli/test_process_discovery.py b/tests/test_cli/test_process_discovery.py index 12e8da7e..a41a6bed 100644 --- a/tests/test_cli/test_process_discovery.py +++ b/tests/test_cli/test_process_discovery.py @@ -70,6 +70,7 @@ def test_discover_run_records_basic(tmp_path: Path) -> None: "status": "completed", "started_at": "2024-01-01T00:00:30Z", "ended_at": "2024-01-01T00:01:00Z", + "avg_reward": 0.75, "num_examples": 10, "rollouts_per_example": 2, "row_count": 20, @@ -108,6 +109,7 @@ def test_discover_run_records_basic(tmp_path: Path) -> None: assert record.has_summary is True assert record.env_args == {"fold": "dev"} assert record.sampling_args == {"temperature": 0.2} + assert record.avg_reward == 0.75 assert record.row_count == 20 assert record.manifest.job_run_id == "job-run-123" diff --git a/tests/test_cli/test_process_metadata.py b/tests/test_cli/test_process_metadata.py index 404fb7cf..e69b8d46 100644 --- a/tests/test_cli/test_process_metadata.py +++ b/tests/test_cli/test_process_metadata.py @@ -1,8 +1,11 @@ from __future__ import annotations import json +import logging from pathlib import Path +import pytest + from medarc_verifiers.cli.process.discovery import RunManifestInfo, RunRecord from medarc_verifiers.cli.process.metadata import load_normalized_metadata @@ -19,6 +22,7 @@ def _make_record( results_dir_name: str = "job-abc", env_args: dict | None = None, sampling_args: dict | None = None, + avg_reward: float | None = None, num_examples: int | None = 10, rollouts_per_example: int | None = None, has_metadata: bool = True, @@ -60,6 +64,7 @@ def _make_record( reason=None, started_at="2024-01-01T00:00:10Z", ended_at="2024-01-01T00:00:50Z", + avg_reward=avg_reward, num_examples=num_examples, rollouts_per_example=rollouts_per_example, row_count=1, @@ -76,6 +81,7 @@ def test_load_normalized_metadata_prefers_manifest_fields(tmp_path: Path) -> Non tmp_path, env_args={"difficulty": "hard"}, sampling_args={"temperature": 0.1}, + avg_reward=0.8, rollouts_per_example=None, ) _write_json( @@ -85,6 +91,7 @@ def test_load_normalized_metadata_prefers_manifest_fields(tmp_path: Path) -> Non "model": "gpt-4o-mini", "env_args": {"difficulty": "easy", "split": "dev"}, "sampling_args": {"temperature": 0.9, "top_p": 0.95}, + "avg_reward": 0.8, "num_examples": 20, "rollouts_per_example": 2, }, @@ -262,3 +269,46 @@ def test_load_normalized_metadata_keeps_all_examples_sentinel_from_manifest(tmp_ assert normalized.num_examples == -1 assert normalized.rollouts_per_example == 1 + + +def test_load_normalized_metadata_warns_on_avg_reward_and_num_examples_mismatch( + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + record = _make_record(tmp_path, manifest_env_id="demo-env", avg_reward=0.8, num_examples=10) + _write_json( + record.metadata_path, + { + "env_id": "demo-env", + "avg_reward": 0.7, + "num_examples": 12, + }, + ) + + with caplog.at_level(logging.WARNING): + normalized = load_normalized_metadata(record) + + assert normalized.num_examples == 10 + assert "Manifest/metadata result mismatch for process input" in caplog.text + assert "avg_reward manifest=0.8 metadata=0.7" in caplog.text + assert "num_examples manifest=10 metadata=12" in caplog.text + + +def test_load_normalized_metadata_does_not_warn_when_result_fields_match( + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + record = _make_record(tmp_path, manifest_env_id="demo-env", avg_reward=0.8, num_examples=10) + _write_json( + record.metadata_path, + { + "env_id": "demo-env", + "avg_reward": 0.8, + "num_examples": 10, + }, + ) + + with caplog.at_level(logging.WARNING): + load_normalized_metadata(record) + + assert "Manifest/metadata result mismatch for process input" not in caplog.text diff --git a/tests/test_cli/test_process_pipeline.py b/tests/test_cli/test_process_pipeline.py index 602a7934..a8d9666b 100644 --- a/tests/test_cli/test_process_pipeline.py +++ b/tests/test_cli/test_process_pipeline.py @@ -87,6 +87,7 @@ def _run_record( reason=None, started_at="2024-01-01T00:00:00Z", ended_at="2024-01-01T00:00:01Z", + avg_reward=1.0, num_examples=num_examples, rollouts_per_example=rollouts_per_example, row_count=row_count, @@ -643,6 +644,37 @@ def test_process_emits_single_warning_for_ungateable_selected_records( assert len(warnings) == 1 +def test_process_uses_actual_results_rows_when_manifest_row_count_is_stale( + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + results_text = "".join(json.dumps({"example_id": f"ex-{index}", "reward": 1.0}) + "\n" for index in range(100)) + runs_dir = _write_run( + tmp_path, + run_id="run-stale-row-count", + updated_at="2024-01-01T00:00:00Z", + reward=1.0, + row_count=90, + num_examples=100, + rollouts_per_example=1, + results_text=results_text, + ) + caplog.set_level("WARNING") + + result = run_process( + ProcessOptions( + runs_dir=runs_dir, + output_dir=tmp_path / "processed", + dry_run=True, + max_workers=1, + ) + ) + + assert result.records_processed == 1 + assert "Manifest row_count mismatch for process input" in caplog.text + assert "manifest row_count=90 actual_rows=100" in caplog.text + + def test_select_work_items_rollout_gate_error_includes_output_and_manifest_ids(tmp_path: Path) -> None: runs_dir = _write_run( tmp_path, diff --git a/tests/test_cli/test_process_rows.py b/tests/test_cli/test_process_rows.py index c000c37c..11cfdb86 100644 --- a/tests/test_cli/test_process_rows.py +++ b/tests/test_cli/test_process_rows.py @@ -52,6 +52,7 @@ def _build_record(tmp_path: Path, *, status: str = "completed", reason: str | No reason=reason, started_at="2024-05-01T00:00:30Z", ended_at="2024-05-01T00:00:42Z", + avg_reward=0.5, num_examples=10, rollouts_per_example=1, row_count=1, From ab661a205df179840d008b30f20a67c244025dc2 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Wed, 4 Mar 2026 11:24:51 -0500 Subject: [PATCH 18/29] restart interupted uploads --- docs/medarc-eval-process.md | 14 ++ medarc_verifiers/cli/hf/__init__.py | 4 + medarc_verifiers/cli/hf/sync.py | 268 ++++++++++++++++++---- medarc_verifiers/cli/main.py | 2 +- medarc_verifiers/cli/process/pipeline.py | 14 +- medarc_verifiers/cli/process/workspace.py | 76 ++++-- tests/test_cli/test_main.py | 5 + tests/test_cli/test_process_hf_sync.py | 188 ++++++++++++++- tests/test_cli/test_process_pipeline.py | 126 +++++++++- tests/test_cli/test_process_workspace.py | 185 +++++++++++++++ 10 files changed, 814 insertions(+), 68 deletions(-) diff --git a/docs/medarc-eval-process.md b/docs/medarc-eval-process.md index abc6434d..1b57c7ba 100644 --- a/docs/medarc-eval-process.md +++ b/docs/medarc-eval-process.md @@ -187,10 +187,24 @@ medarc-eval process --hf-repo your-org/data --hf-pull-policy pull # Start fresh (ignore remote) medarc-eval process --hf-repo your-org/data --hf-pull-policy clean + +# Resume a previously failed HF upload without pulling or cleaning +medarc-eval process --hf-repo your-org/data --hf-pull-policy continue-upload ``` `prompt` only prompts when the local processed dir is already non-empty. If the output dir is empty, process pulls the HF baseline immediately. +When `prompt` is used with a non-empty local processed dir, the menu may show: + +- `pull`: download missing baseline data without deleting local files +- `clean`: redownload everything after deleting local files +- `upload`: keep local processed outputs and resume/upload pending HF artifacts + +`upload` is shown only when local parquet files appear to be missing remotely or have a different remote `lfs.sha256`. Recovery uploads the union of: + +- parquet files that were already pending before the current run started +- files touched by the current process run, including `env_index.json` and `dataset_infos.json` when rewritten + ### Push After Processing When `--hf-repo` is set, processed files are automatically uploaded after completion. diff --git a/medarc_verifiers/cli/hf/__init__.py b/medarc_verifiers/cli/hf/__init__.py index 11e0a6aa..47009eb4 100644 --- a/medarc_verifiers/cli/hf/__init__.py +++ b/medarc_verifiers/cli/hf/__init__.py @@ -3,6 +3,8 @@ from .sync import ( # noqa: F401 HFSyncConfig, HFSyncSummary, + collect_changed_output_files, + compute_pending_parquet_uploads, download_hf_repo, sync_files_to_hub, sync_to_hub, @@ -11,6 +13,8 @@ __all__ = [ "HFSyncConfig", "HFSyncSummary", + "collect_changed_output_files", + "compute_pending_parquet_uploads", "sync_files_to_hub", "sync_to_hub", "download_hf_repo", diff --git a/medarc_verifiers/cli/hf/sync.py b/medarc_verifiers/cli/hf/sync.py index 0ae54fb7..44db7314 100644 --- a/medarc_verifiers/cli/hf/sync.py +++ b/medarc_verifiers/cli/hf/sync.py @@ -2,12 +2,13 @@ from __future__ import annotations +import hashlib import logging import tempfile import time from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence from medarc_verifiers.utils.pathing import resolve_under @@ -62,6 +63,29 @@ def _is_repo_not_found_error(exc: BaseException) -> bool: return False +def _status_code_from_exc(exc: BaseException) -> int | None: + response = getattr(exc, "response", None) + status_code = getattr(response, "status_code", None) + if status_code is None: + status_code = getattr(exc, "status_code", None) + try: + return int(status_code) if status_code is not None else None + except Exception: + return None + + +def _is_transient_hf_error(exc: BaseException) -> bool: + status_code = _status_code_from_exc(exc) + if status_code == 429 or (status_code is not None and 500 <= status_code < 600): + return True + try: + import httpx # type: ignore[import-not-found] + + return isinstance(exc, (httpx.TimeoutException, httpx.TransportError)) + except Exception: + return False + + def _confirm_create_repo( *, repo_id: str, @@ -155,6 +179,181 @@ class HFSyncSummary: files: Sequence[str] +def _local_sha256(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as handle: + for chunk in iter(lambda: handle.read(1024 * 1024), b""): + digest.update(chunk) + return digest.hexdigest() + + +def _repo_tree_entry_path(entry: Any) -> str | None: + for attr in ("path", "rfilename"): + value = getattr(entry, attr, None) + if isinstance(value, str) and value.strip(): + return Path(value).as_posix() + if isinstance(entry, dict): + value = entry.get("path") or entry.get("rfilename") + if isinstance(value, str) and value.strip(): + return Path(value).as_posix() + return None + + +def _repo_tree_entry_lfs_sha256(entry: Any) -> str | None: + lfs = getattr(entry, "lfs", None) + if lfs is None and isinstance(entry, dict): + lfs = entry.get("lfs") + if isinstance(lfs, dict): + sha256 = lfs.get("sha256") + return str(sha256) if sha256 else None + sha256 = getattr(lfs, "sha256", None) + return str(sha256) if sha256 else None + + +def _normalize_output_files(output_dir: Path, files: Iterable[str | Path]) -> list[str]: + normalized: list[str] = [] + for path in files: + candidate = Path(path) + if candidate.is_absolute(): + try: + rel_path = candidate.relative_to(output_dir) + except ValueError: + continue + else: + # Accept caller inputs like "runs/processed/foo.parquet" when output_dir is also relative. + output_parts = output_dir.parts + if output_parts and candidate.parts[: len(output_parts)] == output_parts: + try: + rel_path = candidate.relative_to(output_dir) + except ValueError: + continue + else: + rel_path = candidate + rel_text = rel_path.as_posix() + if rel_text: + normalized.append(rel_text) + return sorted(set(normalized)) + + +def _prepare_upload_file_entries(output_dir: Path, files: Sequence[str | Path]) -> list[tuple[str, Path]]: + output_dir = output_dir.resolve() + prepared: list[tuple[str, Path]] = [] + seen: set[str] = set() + for path in files: + candidate = Path(path) + raw_text = candidate.as_posix() + if not raw_text: + continue + if candidate.is_absolute(): + try: + rel_path = candidate.resolve().relative_to(output_dir).as_posix() + except ValueError as exc: + raise ValueError(f"Upload file path must be under output_dir: {candidate}") from exc + else: + resolved = resolve_under(output_dir, raw_text) + if resolved is None: + raise ValueError(f"Upload file path must be relative to output_dir without traversal: {raw_text!r}") + try: + rel_path = resolved.resolve().relative_to(output_dir).as_posix() + except ValueError as exc: + raise ValueError(f"Upload file path resolves outside output_dir: {raw_text!r}") from exc + local_path = (output_dir / rel_path).resolve() + try: + local_path.relative_to(output_dir) + except ValueError as exc: + raise ValueError(f"Upload file path resolves outside output_dir: {raw_text!r}") from exc + if rel_path in seen: + continue + prepared.append((rel_path, local_path)) + seen.add(rel_path) + return prepared + + +def collect_changed_output_files( + env_summaries: Sequence[EnvWriteSummary], + *, + output_dir: Path, + metadata_paths: Sequence[Path] | None = None, +) -> list[str]: + changed_paths = {summary.output_path for summary in env_summaries if summary.changed} + if metadata_paths: + for path in metadata_paths: + candidate = Path(path) + if not candidate.is_absolute(): + output_parts = output_dir.parts + if output_parts and candidate.parts[: len(output_parts)] != output_parts: + candidate = output_dir / candidate + changed_paths.add(candidate) + return _normalize_output_files(output_dir, changed_paths) + + +def _collect_changed_output_files( + env_summaries: Sequence[EnvWriteSummary], + *, + output_dir: Path, + metadata_paths: Sequence[Path] | None = None, +) -> list[str]: + return collect_changed_output_files(env_summaries, output_dir=output_dir, metadata_paths=metadata_paths) + + +def compute_pending_parquet_uploads( + output_dir: Path, + repo_id: str, + branch: str | None, + token: str | None, +) -> set[str]: + """Return local parquet paths that are missing remotely or differ from remote lfs.sha256.""" + output_dir = Path(output_dir) + local_parquets = sorted(path for path in output_dir.rglob("*.parquet") if path.is_file()) + if not local_parquets: + return set() + + try: + from huggingface_hub import HfApi # type: ignore[import-not-found] + except Exception as exc: # noqa: BLE001 + raise ImportError("huggingface_hub is required for HF upload recovery.") from exc + + api = HfApi(token=token) + list_kwargs = { + "repo_id": repo_id, + "repo_type": "dataset", + "revision": branch, + "recursive": True, + "expand": True, + } + try: + try: + tree_entries = list(api.list_repo_tree(**list_kwargs)) + except TypeError as exc: + if "expand" not in str(exc): + raise + list_kwargs.pop("expand", None) + tree_entries = list(api.list_repo_tree(**list_kwargs)) + except Exception as exc: # noqa: BLE001 + if _is_repo_not_found_error(exc): + tree_entries = [] + else: + raise + + remote_parquets: dict[str, str | None] = {} + for entry in tree_entries: + rel_path = _repo_tree_entry_path(entry) + if not rel_path or not rel_path.endswith(".parquet"): + continue + remote_parquets[rel_path] = _repo_tree_entry_lfs_sha256(entry) + + pending: set[str] = set() + for parquet_path in local_parquets: + rel_path = parquet_path.relative_to(output_dir).as_posix() + if rel_path not in remote_parquets: + pending.add(rel_path) + continue + remote_sha256 = remote_parquets[rel_path] + if remote_sha256 is None or remote_sha256 != _local_sha256(parquet_path): + pending.add(rel_path) + return pending + + def sync_files_to_hub( *, repo_id: str, @@ -180,11 +379,9 @@ def sync_files_to_hub( if not repo_id: logger.debug("HF sync skipped: no repo_id provided.") return True - file_list = [] - for path in files: - rel_path = Path(path).as_posix() if not isinstance(path, str) else Path(path).as_posix() - if rel_path: - file_list.append(rel_path) + output_dir = Path(output_dir) + prepared_files = _prepare_upload_file_entries(output_dir, files) + file_list = [rel_path for rel_path, _ in prepared_files] if not file_list: logger.debug("HF sync skipped: no files provided.") return True @@ -203,6 +400,8 @@ def sync_files_to_hub( api = HfApi(token=token) repo_prefix = _normalize_repo_path_prefix(path_in_repo_prefix) + file_map = dict(prepared_files) + if max_files_per_commit is None or max_files_per_commit <= 0: batches = [file_list] else: @@ -210,13 +409,11 @@ def sync_files_to_hub( file_list[index : index + max_files_per_commit] for index in range(0, len(file_list), max_files_per_commit) ] - output_dir = Path(output_dir) - for batch_index, batch_files in enumerate(batches, start=1): operations = [ CommitOperationAdd( path_in_repo=_join_repo_path(repo_prefix, rel_path), - path_or_fileobj=str(output_dir / rel_path), + path_or_fileobj=str(file_map[rel_path]), ) for rel_path in batch_files ] @@ -257,13 +454,7 @@ def sync_files_to_hub( ) # Retry the commit immediately after repo creation. continue - try: - import httpx # type: ignore[import-not-found] - - is_retryable = isinstance(exc, (httpx.TimeoutException, httpx.TransportError)) - except Exception: - is_retryable = False - if not is_retryable or attempt >= int(retries): + if not _is_transient_hf_error(exc) or attempt >= int(retries): raise delay = _sleep_backoff_seconds(attempt) logger.warning( @@ -303,6 +494,7 @@ def sync_to_hub( *, output_dir: Path, metadata_paths: Sequence[Path] | None = None, + files: Sequence[str | Path] | None = None, is_tty: bool = False, assume_yes: bool = False, prompt_func: Callable[[str], str] | None = None, @@ -311,37 +503,27 @@ def sync_to_hub( if not config.repo_id: logger.debug("HF sync skipped: no repo_id provided.") return None - if not env_summaries: - logger.debug("HF sync skipped: no environment summaries available.") - return None - if all(summary.dry_run for summary in env_summaries): - logger.debug("HF sync skipped: only dry-run summaries available.") + if config.dry_run: + logger.debug("HF sync dry-run; skipping summary generation and upload.") return None + output_dir = Path(output_dir) changed = [summary for summary in env_summaries if summary.changed] - if not changed: - logger.debug("HF sync skipped: no changed outputs.") - return None + if files is None: + if not env_summaries: + logger.debug("HF sync skipped: no environment summaries available.") + return None + if all(summary.dry_run for summary in env_summaries): + logger.debug("HF sync skipped: only dry-run summaries available.") + return None + files = collect_changed_output_files(env_summaries, output_dir=output_dir, metadata_paths=metadata_paths) + else: + files = _normalize_output_files(output_dir, files) - output_dir = Path(output_dir) - changed_paths = {summary.output_path for summary in changed} - if metadata_paths: - for path in metadata_paths: - candidate = Path(path) - if not candidate.is_absolute(): - output_parts = output_dir.parts - if output_parts and candidate.parts[: len(output_parts)] != output_parts: - candidate = output_dir / candidate - changed_paths.add(candidate) + if not files: + logger.debug("HF sync skipped: no files selected for upload.") + return None - files = [] - for path in changed_paths: - try: - rel_path = path.relative_to(output_dir) - except ValueError: - continue - files.append(rel_path.as_posix()) - files = sorted(set(files)) summary = HFSyncSummary( repo_id=config.repo_id, strategy="file", @@ -418,6 +600,8 @@ def download_hf_repo( __all__ = [ "HFSyncSummary", "HFSyncConfig", + "collect_changed_output_files", + "compute_pending_parquet_uploads", "sync_files_to_hub", "sync_to_hub", ] diff --git a/medarc_verifiers/cli/main.py b/medarc_verifiers/cli/main.py index 3161005c..97ca6e50 100644 --- a/medarc_verifiers/cli/main.py +++ b/medarc_verifiers/cli/main.py @@ -325,7 +325,7 @@ def build_process_parser() -> argparse.ArgumentParser: parser.add_argument("--hf-repo", default=None, help="Hugging Face repo id for dataset sync.") parser.add_argument( "--hf-pull-policy", - choices=("prompt", "pull", "clean"), + choices=("prompt", "pull", "clean", "continue-upload"), default=None, help="Baseline policy when output dir is non-empty in HF mode.", ) diff --git a/medarc_verifiers/cli/process/pipeline.py b/medarc_verifiers/cli/process/pipeline.py index 505de389..b23fc16d 100644 --- a/medarc_verifiers/cli/process/pipeline.py +++ b/medarc_verifiers/cli/process/pipeline.py @@ -135,8 +135,9 @@ def run_process( env_export_map = env_export_map or {} def _run_pipeline() -> ProcessResult: + baseline_result: workspace.BaselineResult | None = None if not options.dry_run: - workspace.prepare_output_workspace( + preparation = workspace.prepare_output_workspace( output_dir=options.output_dir, hf_config=options.hf_config, pull_policy=options.hf_pull_policy, @@ -145,6 +146,8 @@ def _run_pipeline() -> ProcessResult: is_tty=sys.stdin.isatty(), prompt_func=input, ) + if preparation is not None: + baseline_result = preparation.baseline_result index_files = {} if options.clean else env_index.read_env_index_files(options.output_dir) discovered = discovery.discover_run_records( @@ -266,11 +269,20 @@ def _run_pipeline() -> ProcessResult: hf_summary: HFSyncSummary | None = None if options.hf_config: + files_to_upload: list[str] | None = None + if baseline_result is not None and baseline_result.policy == "continue-upload": + touched_files = hf_sync.collect_changed_output_files( + env_summaries, + output_dir=options.output_dir, + metadata_paths=metadata_paths, + ) + files_to_upload = sorted(set(baseline_result.pending_parquet_uploads) | set(touched_files)) hf_summary = hf_sync.sync_to_hub( env_summaries, options.hf_config, output_dir=options.output_dir, metadata_paths=metadata_paths, + files=files_to_upload, is_tty=sys.stdin.isatty(), assume_yes=options.assume_yes, prompt_func=input, diff --git a/medarc_verifiers/cli/process/workspace.py b/medarc_verifiers/cli/process/workspace.py index a497e353..d5669ff5 100644 --- a/medarc_verifiers/cli/process/workspace.py +++ b/medarc_verifiers/cli/process/workspace.py @@ -3,14 +3,17 @@ from __future__ import annotations import json +import logging import shutil from dataclasses import dataclass, field from pathlib import Path from typing import Callable, Iterable, Sequence -from medarc_verifiers.cli.hf import HFSyncConfig, download_hf_repo +from medarc_verifiers.cli.hf import HFSyncConfig, compute_pending_parquet_uploads, download_hf_repo from medarc_verifiers.utils.pathing import resolve_under +logger = logging.getLogger(__name__) + @dataclass(slots=True) class BaselineResult: @@ -18,6 +21,7 @@ class BaselineResult: files_copied: list[Path] = field(default_factory=list) files_overwritten: list[Path] = field(default_factory=list) files_skipped: list[Path] = field(default_factory=list) + pending_parquet_uploads: set[str] = field(default_factory=set) snapshot_dir: Path | None = None @@ -84,8 +88,10 @@ def prepare_hf_baseline( return BaselineResult(policy="local") policy = _resolve_pull_policy(pull_policy, is_tty=is_tty) - result = BaselineResult(policy=policy) if not is_nonempty_dir(output_dir): + if policy == "continue-upload": + logger.warning("HF continue-upload requested with an empty output dir; falling back to pull.") + result = BaselineResult(policy="pull" if policy == "continue-upload" else policy) snapshot_dir = download_hf_repo( repo_id=hf_config.repo_id, branch=hf_config.branch, @@ -98,10 +104,36 @@ def prepare_hf_baseline( _copy_snapshot(snapshot_dir, output_dir, result, overwrite=True) return result + result = BaselineResult(policy=policy) + if policy == "prompt": + try: + result.pending_parquet_uploads = compute_pending_parquet_uploads( + output_dir=output_dir, + repo_id=hf_config.repo_id, + branch=hf_config.branch, + token=hf_config.token, + ) + except Exception as exc: # noqa: BLE001 + logger.warning("HF upload recovery check failed before prompt; hiding upload option: %s", exc) + elif policy == "continue-upload": + try: + result.pending_parquet_uploads = compute_pending_parquet_uploads( + output_dir=output_dir, + repo_id=hf_config.repo_id, + branch=hf_config.branch, + token=hf_config.token, + ) + except Exception as exc: # noqa: BLE001 + logger.warning("HF upload recovery check failed for continue-upload; uploading only current touched files: %s", exc) + prompt_conflicts = False if policy == "prompt": - choice = _prompt_baseline_choice(prompt_func, is_tty=is_tty) - policy = choice + choice = _prompt_baseline_choice( + prompt_func, + is_tty=is_tty, + show_upload=bool(result.pending_parquet_uploads), + ) + policy = "continue-upload" if choice == "upload" else choice result.policy = policy prompt_conflicts = policy == "pull" @@ -141,6 +173,9 @@ def prepare_hf_baseline( ) return result + if policy == "continue-upload": + return result + raise ValueError(f"Unsupported HF pull policy: {policy}") @@ -170,17 +205,27 @@ def _resolve_pull_policy(pull_policy: str | None, *, is_tty: bool) -> str: return "prompt" if is_tty else "pull" -def _prompt_baseline_choice(prompt_func: Callable[[str], str] | None, *, is_tty: bool) -> str: +def _prompt_baseline_choice( + prompt_func: Callable[[str], str] | None, + *, + is_tty: bool, + show_upload: bool = False, +) -> str: if not is_tty or prompt_func is None: return "pull" + choices = ["pull", "clean"] + if show_upload: + choices.append("upload") if prompt_func is not input: - prompt = ( + prompt_lines = [ "HF baseline exists locally.\n" - " pull -> download missing data without deleting local files\n" - " clean -> redownload everything after deleting local files\n" - "Choose [pull/clean]: " - ) - return _read_choice(prompt_func, prompt, {"pull", "clean"}) + " pull -> download missing data without deleting local files\n" + " clean -> redownload everything after deleting local files\n" + ] + if show_upload: + prompt_lines.append(" upload -> keep local files and resume/upload pending HF artifacts\n") + prompt_lines.append(f"Choose [{'/'.join(choices)}]: ") + return _read_choice(prompt_func, "".join(prompt_lines), choices) from rich.console import Console from rich.prompt import Prompt @@ -188,7 +233,12 @@ def _prompt_baseline_choice(prompt_func: Callable[[str], str] | None, *, is_tty: console.print("[bold yellow]HF baseline exists locally.[/bold yellow]") console.print(" [cyan]pull[/cyan] -> download missing data without deleting local files") console.print(" [cyan]clean[/cyan] -> redownload everything after deleting local files") - return Prompt.ask("Choose", choices=["pull", "clean"], default="pull") + if show_upload: + console.print(" [cyan]upload[/cyan] -> keep local files and resume/upload pending HF artifacts") + try: + return Prompt.ask("Choose", choices=choices, default="pull") + except (EOFError, KeyboardInterrupt): # noqa: PERF203 + raise RuntimeError("Aborted HF baseline selection.") from None def _prompt_overwrite_file(prompt_func: Callable[[str], str] | None, *, path: Path, is_tty: bool) -> bool: @@ -204,7 +254,7 @@ def _read_choice(prompt_func: Callable[[str], str], prompt: str, choices: Sequen while True: try: response = prompt_func(prompt).strip().lower() - except EOFError: # noqa: PERF203 + except (EOFError, KeyboardInterrupt): # noqa: PERF203 raise RuntimeError("Aborted HF baseline selection.") from None if response in choices_set: return response diff --git a/tests/test_cli/test_main.py b/tests/test_cli/test_main.py index d1d14592..c78b6cad 100644 --- a/tests/test_cli/test_main.py +++ b/tests/test_cli/test_main.py @@ -1838,6 +1838,11 @@ def fake_run(options, env_export_map): assert options.hf_config is not None assert options.hf_config.token == "override" + exit_code = main.main(["process", "--config", str(cfg_path), "--hf-pull-policy", "continue-upload", "--dry-run"]) + assert exit_code == 0 + options = captured["options"] + assert options.hf_pull_policy == "continue-upload" + def test_process_cli_resolves_hf_token_env_reference(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: cfg_path = tmp_path / "process.yaml" diff --git a/tests/test_cli/test_process_hf_sync.py b/tests/test_cli/test_process_hf_sync.py index d677f1aa..a463e27f 100644 --- a/tests/test_cli/test_process_hf_sync.py +++ b/tests/test_cli/test_process_hf_sync.py @@ -1,15 +1,18 @@ from __future__ import annotations +import hashlib from pathlib import Path +from types import SimpleNamespace import pytest from medarc_verifiers.cli import hf as hf_sync +from medarc_verifiers.cli.hf import sync as hf_sync_impl from medarc_verifiers.cli.process.aggregate import aggregate_rows_by_env from medarc_verifiers.cli.process.writer import WriterConfig, write_env_groups -def test_sync_to_hub_dry_run_builds_summary(tmp_path: Path) -> None: +def test_sync_to_hub_dry_run_returns_none(tmp_path: Path) -> None: rows = [ {"base_env_id": "env-a", "env_id": "env-a", "job_run_id": "run-1", "example_id": "ex-1", "rollout_index": 0} ] @@ -30,11 +33,7 @@ def test_sync_to_hub_dry_run_builds_summary(tmp_path: Path) -> None: output_dir=tmp_path, metadata_paths=[tmp_path / "env_index.json", tmp_path / "dataset_infos.json"], ) - assert summary is not None - assert summary.total_rows == len(rows) - assert summary.total_files == 3 - assert "env_index.json" in summary.files - assert "dataset_infos.json" in summary.files + assert summary is None def test_sync_to_hub_uses_token(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: @@ -87,7 +86,7 @@ def create_commit(self, **_kwargs: object) -> None: assert captured.get("create_commit") is True -def test_sync_to_hub_does_not_double_prefix_metadata_paths( +def test_sync_to_hub_dry_run_with_relative_output_paths_returns_none( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, ) -> None: @@ -113,10 +112,179 @@ def test_sync_to_hub_does_not_double_prefix_metadata_paths( output_dir=output_dir, metadata_paths=[output_dir / "env_index.json", output_dir / "dataset_infos.json"], ) + assert summary is None + + +@pytest.mark.parametrize( + ("remote_case", "expected_pending"), + [ + ("missing", {"model-a/env-a.parquet"}), + ("match", set()), + ("mismatch", {"model-a/env-a.parquet"}), + ("no-lfs", {"model-a/env-a.parquet"}), + ], +) +def test_compute_pending_parquet_uploads_detects_remote_state( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + remote_case: str, + expected_pending: set[str], +) -> None: + parquet_path = tmp_path / "model-a" / "env-a.parquet" + parquet_path.parent.mkdir(parents=True, exist_ok=True) + parquet_path.write_text("local-data", encoding="utf-8") + local_sha = hashlib.sha256(parquet_path.read_bytes()).hexdigest() + + class FakeLFS: + def __init__(self, sha256: str | None) -> None: + self.sha256 = sha256 + + class FakeTreeEntry: + def __init__(self, path: str, lfs: object | None) -> None: + self.path = path + self.lfs = lfs + + class FakeApi: + def __init__(self, token: str | None = None) -> None: + self.token = token + + def list_repo_tree(self, **_kwargs: object) -> list[FakeTreeEntry]: + if remote_case == "missing": + return [] + if remote_case == "no-lfs": + return [FakeTreeEntry("model-a/env-a.parquet", None)] + sha256 = local_sha if remote_case == "match" else "0" * 64 + return [FakeTreeEntry("model-a/env-a.parquet", FakeLFS(sha256))] + + import sys + + monkeypatch.setitem(sys.modules, "huggingface_hub", SimpleNamespace(HfApi=FakeApi)) + + pending = hf_sync.compute_pending_parquet_uploads( + output_dir=tmp_path, + repo_id="demo/repo", + branch="main", + token="secret-token", + ) + + assert pending == expected_pending + + +def test_sync_to_hub_explicit_files_uploads_exact_list(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + (tmp_path / "keep.parquet").write_text("1", encoding="utf-8") + (tmp_path / "meta.json").write_text("{}", encoding="utf-8") + + captured: dict[str, object] = {} + + class FakeOp: + def __init__(self, *args: object, **kwargs: object) -> None: + captured.setdefault("ops", []).append((args, kwargs)) + + class FakeApi: + def __init__(self, token: str | None = None) -> None: + captured["token"] = token + + def create_repo(self, **_kwargs: object) -> None: + captured["create_repo"] = True + + def create_commit(self, **kwargs: object) -> None: + captured["create_commit"] = kwargs + + import sys + + monkeypatch.setitem(sys.modules, "huggingface_hub", SimpleNamespace(CommitOperationAdd=FakeOp, HfApi=FakeApi)) + + summary = hf_sync.sync_to_hub( + [], + hf_sync.HFSyncConfig(repo_id="local/test", token="secret-token"), + output_dir=tmp_path, + files=[tmp_path / "keep.parquet", "meta.json"], + ) + assert summary is not None - assert "env_index.json" in summary.files - assert "dataset_infos.json" in summary.files - assert "runs/processed/env_index.json" not in summary.files + assert summary.files == ["keep.parquet", "meta.json"] + assert summary.total_files == 2 + assert summary.total_rows == 0 + assert captured["token"] == "secret-token" + assert captured.get("create_commit") is not None + + +def test_sync_to_hub_explicit_files_respects_dry_run(tmp_path: Path) -> None: + (tmp_path / "keep.parquet").write_text("1", encoding="utf-8") + + summary = hf_sync.sync_to_hub( + [], + hf_sync.HFSyncConfig(repo_id="local/test", dry_run=True), + output_dir=tmp_path, + files=["keep.parquet"], + ) + + assert summary is None + + +@pytest.mark.parametrize("bad_path", ["/tmp/escape.txt", "../escape.txt"]) +def test_sync_files_to_hub_rejects_unsafe_paths(tmp_path: Path, bad_path: str) -> None: + with pytest.raises(ValueError, match="output_dir|traversal"): + hf_sync.sync_files_to_hub( + repo_id="local/test", + output_dir=tmp_path, + files=[bad_path], + token=None, + private=False, + message="msg", + dry_run=True, + ) + + +def test_transient_hf_errors_include_statuses_timeouts_and_transport() -> None: + import httpx + + class StatusError(Exception): + def __init__(self, status_code: int) -> None: + super().__init__(f"status={status_code}") + self.response = SimpleNamespace(status_code=status_code) + + assert hf_sync_impl._is_transient_hf_error(StatusError(429)) is True + assert hf_sync_impl._is_transient_hf_error(StatusError(503)) is True + assert hf_sync_impl._is_transient_hf_error(httpx.TimeoutException("timeout")) is True + assert hf_sync_impl._is_transient_hf_error(httpx.TransportError("transport")) is True + + +def test_compute_pending_parquet_uploads_retries_without_expand_on_compat_error( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + parquet_path = tmp_path / "model-a" / "env-a.parquet" + parquet_path.parent.mkdir(parents=True, exist_ok=True) + parquet_path.write_text("local-data", encoding="utf-8") + + calls: list[dict[str, object]] = [] + + class FakeApi: + def __init__(self, token: str | None = None) -> None: + self.token = token + + def list_repo_tree(self, **kwargs: object): + calls.append(kwargs) + if "expand" in kwargs: + raise TypeError("unexpected keyword argument 'expand'") + return [SimpleNamespace(path="model-a/env-a.parquet", lfs=None)] + + import sys + + monkeypatch.setitem(sys.modules, "huggingface_hub", SimpleNamespace(HfApi=FakeApi)) + + pending = hf_sync.compute_pending_parquet_uploads( + output_dir=tmp_path, + repo_id="demo/repo", + branch="main", + token="secret-token", + ) + + assert pending == {"model-a/env-a.parquet"} + assert len(calls) == 2 + assert "expand" in calls[0] + assert "expand" not in calls[1] def test_sync_files_to_hub_creates_repo_with_confirmation(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: diff --git a/tests/test_cli/test_process_pipeline.py b/tests/test_cli/test_process_pipeline.py index a8d9666b..51d0d7af 100644 --- a/tests/test_cli/test_process_pipeline.py +++ b/tests/test_cli/test_process_pipeline.py @@ -8,12 +8,13 @@ from medarc_verifiers.cli._manifest import MANIFEST_VERSION from medarc_verifiers.cli._schemas import EnvironmentExportConfig +from medarc_verifiers.cli.hf import HFSyncConfig from medarc_verifiers.cli.process import ProcessOptions, run_process +from medarc_verifiers.cli.process import workspace from medarc_verifiers.cli.process.discovery import RunManifestInfo, RunRecord, discover_run_records from medarc_verifiers.cli.process.pipeline import select_work_items from medarc_verifiers.cli.winrate import WinrateConfig from medarc_verifiers.cli.winrate import discover_datasets, run_winrate -from medarc_verifiers.cli.hf import HFSyncConfig from medarc_verifiers.cli.process.writer import ALLOWED_COLUMNS @@ -415,6 +416,7 @@ def test_process_allows_results_missing_pct_within_threshold(tmp_path: Path) -> def test_process_rejects_results_missing_pct_above_threshold(tmp_path: Path) -> None: + results_text = "".join(json.dumps({"example_id": f"ex-{index}", "reward": 1.0}) + "\n" for index in range(90)) runs_dir = _write_run( tmp_path, run_id="run-90pct", @@ -423,6 +425,7 @@ def test_process_rejects_results_missing_pct_above_threshold(tmp_path: Path) -> row_count=90, num_examples=100, rollouts_per_example=1, + results_text=results_text, ) options = ProcessOptions( runs_dir=runs_dir, @@ -587,6 +590,7 @@ def test_process_stale_delta_output_does_not_mask_newer_incomplete_run(tmp_path: initial = run_process(ProcessOptions(runs_dir=runs_dir, output_dir=output_dir, dry_run=False, max_workers=1)) assert initial.records_processed == 1 + results_text = "".join(json.dumps({"example_id": f"ex-{index}", "reward": 0.0}) + "\n" for index in range(90)) _write_run( tmp_path, run_id="run-newer-bad", @@ -595,6 +599,7 @@ def test_process_stale_delta_output_does_not_mask_newer_incomplete_run(tmp_path: row_count=90, num_examples=100, rollouts_per_example=1, + results_text=results_text, ) with pytest.raises(RuntimeError) as excinfo: @@ -992,6 +997,125 @@ def test_process_latest_only_selects_latest_and_skips_existing_outputs(tmp_path: assert newer_table.column("reward").to_pylist() == [0.4] +def test_run_process_continue_upload_syncs_pending_parquets_without_new_deltas( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + runs_dir = _write_run(tmp_path, run_id="run-1", updated_at="2024-01-01T00:00:00Z", reward=0.1, env_id="demo-env") + output_dir = tmp_path / "processed" + + first_result = run_process( + ProcessOptions( + runs_dir=runs_dir, + output_dir=output_dir, + dry_run=False, + max_workers=1, + ) + ) + pending_path = first_result.env_summaries[0].output_path.relative_to(output_dir).as_posix() + captured: dict[str, object] = {} + + def fake_prepare_output_workspace(**_kwargs: object) -> workspace.WorkspacePreparationResult: + return workspace.WorkspacePreparationResult( + baseline_result=workspace.BaselineResult( + policy="continue-upload", + pending_parquet_uploads={pending_path}, + ) + ) + + def fake_sync_to_hub( + env_summaries, + config, + *, + output_dir, + metadata_paths=None, + files=None, + **_kwargs, + ): + captured["env_summaries"] = list(env_summaries) + captured["files"] = list(files or []) + return None + + monkeypatch.setattr("medarc_verifiers.cli.process.pipeline.workspace.prepare_output_workspace", fake_prepare_output_workspace) + monkeypatch.setattr("medarc_verifiers.cli.process.pipeline.hf_sync.sync_to_hub", fake_sync_to_hub) + + result = run_process( + ProcessOptions( + runs_dir=runs_dir, + output_dir=output_dir, + dry_run=False, + max_workers=1, + hf_config=HFSyncConfig(repo_id="demo/repo"), + hf_pull_policy="continue-upload", + ) + ) + + assert result.env_summaries == [] + assert captured["env_summaries"] == [] + assert captured["files"] == [pending_path] + + +def test_run_process_continue_upload_unions_pending_and_current_touched_files( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + runs_dir = _write_run(tmp_path, run_id="run-1", updated_at="2024-01-01T00:00:00Z", reward=0.1, env_id="demo-env") + output_dir = tmp_path / "processed" + first_result = run_process( + ProcessOptions( + runs_dir=runs_dir, + output_dir=output_dir, + dry_run=False, + max_workers=1, + ) + ) + current_path = first_result.env_summaries[0].output_path.relative_to(output_dir).as_posix() + pending_path = "stale-model/stale-env.parquet" + stale_path = output_dir / pending_path + stale_path.parent.mkdir(parents=True, exist_ok=True) + stale_path.write_text("stale", encoding="utf-8") + _write_run(tmp_path, run_id="run-2", updated_at="2024-01-02T00:00:00Z", reward=0.9, env_id="demo-env") + + captured: dict[str, object] = {} + + def fake_prepare_output_workspace(**_kwargs: object) -> workspace.WorkspacePreparationResult: + return workspace.WorkspacePreparationResult( + baseline_result=workspace.BaselineResult( + policy="continue-upload", + pending_parquet_uploads={pending_path}, + ) + ) + + def fake_sync_to_hub( + env_summaries, + config, + *, + output_dir, + metadata_paths=None, + files=None, + **_kwargs, + ): + captured["files"] = list(files or []) + return None + + monkeypatch.setattr("medarc_verifiers.cli.process.pipeline.workspace.prepare_output_workspace", fake_prepare_output_workspace) + monkeypatch.setattr("medarc_verifiers.cli.process.pipeline.hf_sync.sync_to_hub", fake_sync_to_hub) + + result = run_process( + ProcessOptions( + runs_dir=runs_dir, + output_dir=output_dir, + dry_run=False, + max_workers=1, + hf_config=HFSyncConfig(repo_id="demo/repo"), + hf_pull_policy="continue-upload", + ) + ) + + assert result.env_summaries + assert set(captured["files"]) == {pending_path, current_path, "dataset_infos.json", "env_index.json"} + + def test_process_replace_model_rebuilds_existing_output(tmp_path: Path) -> None: runs_dir = _write_run( tmp_path, diff --git a/tests/test_cli/test_process_workspace.py b/tests/test_cli/test_process_workspace.py index aa9679ad..d8df1af5 100644 --- a/tests/test_cli/test_process_workspace.py +++ b/tests/test_cli/test_process_workspace.py @@ -176,6 +176,7 @@ def _fake_download_hf_repo(**_kwargs) -> Path: return snapshot_dir monkeypatch.setattr(workspace, "download_hf_repo", _fake_download_hf_repo) + monkeypatch.setattr(workspace, "compute_pending_parquet_uploads", lambda **_kwargs: set()) hf_config = HFSyncConfig(repo_id="demo/repo") output_dir = tmp_path / "output" output_dir.mkdir() @@ -199,6 +200,190 @@ def _prompt(_message: str) -> str: assert local_path.read_text(encoding="utf-8") == "remote" +def test_prepare_hf_baseline_prompt_offers_upload_when_pending_exists( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + output_dir = tmp_path / "output" + output_dir.mkdir() + _write_snapshot(output_dir, content="local") + + prompts: list[str] = [] + + def _prompt(message: str) -> str: + prompts.append(message) + return "upload" + + def _fail_download(**_kwargs) -> Path: + raise AssertionError("download_hf_repo should not be called for upload recovery") + + monkeypatch.setattr(workspace, "download_hf_repo", _fail_download) + monkeypatch.setattr( + workspace, + "compute_pending_parquet_uploads", + lambda **_kwargs: {"model-a/env-a.parquet"}, + ) + + result = workspace.prepare_hf_baseline( + output_dir=output_dir, + hf_config=HFSyncConfig(repo_id="demo/repo"), + pull_policy="prompt", + is_tty=True, + prompt_func=_prompt, + ) + + assert result.policy == "continue-upload" + assert result.pending_parquet_uploads == {"model-a/env-a.parquet"} + assert prompts and "upload" in prompts[0] + + +def test_prepare_hf_baseline_continue_upload_skips_download( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + output_dir = tmp_path / "output" + output_dir.mkdir() + _write_snapshot(output_dir, content="local") + + def _fail_download(**_kwargs) -> Path: + raise AssertionError("download_hf_repo should not be called for continue-upload") + + monkeypatch.setattr(workspace, "download_hf_repo", _fail_download) + monkeypatch.setattr( + workspace, + "compute_pending_parquet_uploads", + lambda **_kwargs: {"model-a/env-a.parquet"}, + ) + + result = workspace.prepare_hf_baseline( + output_dir=output_dir, + hf_config=HFSyncConfig(repo_id="demo/repo"), + pull_policy="continue-upload", + is_tty=False, + prompt_func=None, + ) + + assert result.policy == "continue-upload" + assert result.pending_parquet_uploads == {"model-a/env-a.parquet"} + + +def test_prepare_hf_baseline_prompt_hides_upload_when_recovery_check_fails( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + output_dir = tmp_path / "output" + output_dir.mkdir() + _write_snapshot(output_dir, content="local") + + prompts: list[str] = [] + + def _prompt(message: str) -> str: + prompts.append(message) + return "pull" + + monkeypatch.setattr( + workspace, + "compute_pending_parquet_uploads", + lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("hf down")), + ) + monkeypatch.setattr(workspace, "download_hf_repo", lambda **_kwargs: tmp_path / "snapshot") + + with caplog.at_level("WARNING"): + result = workspace.prepare_hf_baseline( + output_dir=output_dir, + hf_config=HFSyncConfig(repo_id="demo/repo"), + pull_policy="prompt", + is_tty=True, + prompt_func=_prompt, + ) + + assert result.policy == "pull" + assert result.pending_parquet_uploads == set() + assert prompts and "upload" not in prompts[0] + assert "HF upload recovery check failed before prompt" in caplog.text + + +def test_prepare_hf_baseline_continue_upload_empty_dir_warns_and_pulls( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + snapshot_dir = tmp_path / "snapshot" + snapshot_dir.mkdir() + _write_snapshot(snapshot_dir, content="remote") + monkeypatch.setattr(workspace, "download_hf_repo", lambda **_kwargs: snapshot_dir) + + with caplog.at_level("WARNING"): + result = workspace.prepare_hf_baseline( + output_dir=tmp_path / "output", + hf_config=HFSyncConfig(repo_id="demo/repo"), + pull_policy="continue-upload", + is_tty=False, + prompt_func=None, + ) + + assert result.policy == "pull" + assert "falling back to pull" in caplog.text + + +def test_prepare_hf_baseline_continue_upload_degrades_when_recovery_check_fails( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, +) -> None: + output_dir = tmp_path / "output" + output_dir.mkdir() + _write_snapshot(output_dir, content="local") + + monkeypatch.setattr( + workspace, + "compute_pending_parquet_uploads", + lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("hf down")), + ) + + with caplog.at_level("WARNING"): + result = workspace.prepare_hf_baseline( + output_dir=output_dir, + hf_config=HFSyncConfig(repo_id="demo/repo"), + pull_policy="continue-upload", + is_tty=False, + prompt_func=None, + ) + + assert result.policy == "continue-upload" + assert result.pending_parquet_uploads == set() + assert "uploading only current touched files" in caplog.text + + +@pytest.mark.parametrize("exc_type", [EOFError, KeyboardInterrupt]) +def test_prepare_hf_baseline_prompt_aborts_cleanly( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + exc_type: type[BaseException], +) -> None: + output_dir = tmp_path / "output" + output_dir.mkdir() + _write_snapshot(output_dir, content="local") + monkeypatch.setattr( + workspace, + "compute_pending_parquet_uploads", + lambda **_kwargs: {"model-a/env-a.parquet"}, + ) + + def _prompt(_message: str) -> str: + raise exc_type + + with pytest.raises(RuntimeError, match="Aborted HF baseline selection."): + workspace.prepare_hf_baseline( + output_dir=output_dir, + hf_config=HFSyncConfig(repo_id="demo/repo"), + pull_policy="prompt", + is_tty=True, + prompt_func=_prompt, + ) + + def test_prepare_hf_baseline_pull_skips_when_local_baseline_present( monkeypatch: pytest.MonkeyPatch, tmp_path: Path, From 7971a949a37d47798f01282d62552c86b9eae487 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Thu, 12 Mar 2026 13:42:44 -0400 Subject: [PATCH 19/29] Improve grading and xml parsing --- medarc_verifiers/parsers/xml_parser.py | 19 +++++++++ .../rewards/multiple_choice_accuracy.py | 40 ++++++++++++++----- tests/test_mcq_accuracy.py | 12 ++++++ tests/test_xml_parser.py | 12 ++++++ 4 files changed, 74 insertions(+), 9 deletions(-) diff --git a/medarc_verifiers/parsers/xml_parser.py b/medarc_verifiers/parsers/xml_parser.py index 6a1176fc..6eb1f2e6 100644 --- a/medarc_verifiers/parsers/xml_parser.py +++ b/medarc_verifiers/parsers/xml_parser.py @@ -61,6 +61,25 @@ def parse(self, completion: Messages | str, strip: bool = True, last: bool = Fal return parsed return None + def parse_answer(self, completion: Messages | str) -> str | None: + """Extract the last answer field from a completion.""" + if isinstance(completion, str): + parsed = self.parse(completion, last=True) + if parsed is not None and hasattr(parsed, self.answer_field): + value = getattr(parsed, self.answer_field) + if value is not None: + return value + return None + + for msg in reversed(self.get_assistant_messages(completion)): + content = str(msg.get("content", "")) + parsed = self.parse(content, last=True) + if parsed is not None and hasattr(parsed, self.answer_field): + value = getattr(parsed, self.answer_field) + if value is not None: + return value + return None + def _has_any_field(self, parsed: Any) -> bool: for _, alternatives in self._fields: for alt in alternatives: diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index 71e123a8..53d5e749 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -15,6 +15,26 @@ from typing import Optional +_UNICODE_PUNCT_TRANSLATIONS = str.maketrans( + { + "\u00A0": " ", # no-break space + "\u2010": "-", # hyphen + "\u2011": "-", # non-breaking hyphen + "\u2012": "-", # figure dash + "\u2013": "-", # en dash + "\u2014": "-", # em dash + "\u2015": "-", # horizontal bar + "\u2212": "-", # minus sign + "\u2018": "'", + "\u2019": "'", + "\u201C": '"', + "\u201D": '"', + } +) + +_WHITESPACE_RE = re.compile(r"\s+") + + @dataclass class MCQAccuracyResult: """Result of multiple-choice accuracy grading.""" @@ -32,14 +52,16 @@ class MCQAccuracyResult: """The correct answer for reference, if available.""" -def _nfkc_casefold(text: str) -> str: - """Unicode normalize + casefold for robust text comparison.""" - return unicodedata.normalize("NFKC", text or "").casefold() +def normalize_for_structure(text: str) -> str: + """Canonicalize text for structural matching without collapsing whitespace.""" + text = unicodedata.normalize("NFKC", text or "") + text = text.translate(_UNICODE_PUNCT_TRANSLATIONS) + return text.casefold() -def _normalize_spaces(text: str) -> str: - """Collapse multiple whitespace to single space.""" - return re.sub(r"\s+", " ", text).strip() +def normalize_for_match(text: str) -> str: + """Canonicalize text for answer-text equivalence matching.""" + return _WHITESPACE_RE.sub(" ", normalize_for_structure(text)).strip() def _strip_tex(text: str) -> str: @@ -261,10 +283,10 @@ def _result( llm_answer_original = llm_answer # Normalize: casefold only (preserve whitespace structure for sentence detection) - llm_answer = _nfkc_casefold(llm_answer) + llm_answer = normalize_for_structure(llm_answer) answer_letter = _norm_letter(answer_letter) - answer_text = _nfkc_casefold(_normalize_spaces(answer_text or "")) + answer_text = normalize_for_match(answer_text or "") if answer_letter is None: raise ValueError(f"Invalid answer_letter '{answer_letter=}'. Must be a single letter or digit string.") @@ -286,7 +308,7 @@ def _result( # Strategy 3: Anchored token (prefix matches first, fallback to generic anchors) prefix_matches = [] if prefix: - prefix_norm = _nfkc_casefold(prefix).strip() + prefix_norm = normalize_for_structure(prefix).strip() if prefix_norm: flexible_prefix = re.escape(prefix_norm).replace(r"\ ", r"\s+") prefix_pattern = re.compile( diff --git a/tests/test_mcq_accuracy.py b/tests/test_mcq_accuracy.py index 209e7ece..7764f3f9 100644 --- a/tests/test_mcq_accuracy.py +++ b/tests/test_mcq_accuracy.py @@ -563,6 +563,18 @@ def test_answer_text_requires_exact_formatting_beyond_normalization(response, an assert not multiple_choice_accuracy(response, answer_letter="D", answer_text=answer_text, accept_answer_text=True) +@pytest.mark.parametrize( + "response, answer_text", + [ + ("Proliferation of surfactant‑secreting cells", "Proliferation of surfactant-secreting cells"), + ("Anti‑D IgG", "Anti-D IgG"), + ("Upslope of T‑wave", "Upslope of T-wave"), + ], +) +def test_answer_text_matches_unicode_dash_variants(response, answer_text): + assert multiple_choice_accuracy(response, answer_letter="D", answer_text=answer_text, accept_answer_text=True) + + def test_multiple_answers_last_explicit_anchor_wins(): response = "Answer: B. After reconsideration, final answer: C" assert multiple_choice_accuracy(response, answer_letter="C", answer_text="Option C") diff --git a/tests/test_xml_parser.py b/tests/test_xml_parser.py index 8bb8978a..dc74b036 100644 --- a/tests/test_xml_parser.py +++ b/tests/test_xml_parser.py @@ -28,6 +28,18 @@ def test_parse_string_handles_tags() -> None: assert parsed.think == "inner" +def test_parse_answer_uses_last_tag_in_message_content() -> None: + parser = XMLParser(["answer"]) + completion = [ + { + "role": "assistant", + "content": 'Follow "The answer is X" exactly.\n\nThe answer is C', + } + ] + + assert parser.parse_answer(completion) == "C" + + def test_init_with_think_does_not_warn(caplog: pytest.LogCaptureFixture) -> None: with caplog.at_level("WARNING"): XMLParser(["think", "answer"]) From 683fe4aebb800ecdaa11fec6ef6b8ba6d9707f6d Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Thu, 12 Mar 2026 15:03:18 -0400 Subject: [PATCH 20/29] Improve negation handling in MCQ grading --- .../rewards/multiple_choice_accuracy.py | 31 +++++++++++++++++-- tests/test_mcq_accuracy.py | 15 +++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index 53d5e749..4e59f3c0 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -161,6 +161,12 @@ def _remove_think_tags(completion_text: str) -> str: # Negation words that invalidate nearby matches NEGATION_PATTERN = re.compile(r"\b(?:not|isn['’]t)\b", re.IGNORECASE) +# Negation/correction phrases that immediately precede an option or answer text +NEGATION_BEFORE_MATCH_PATTERN = re.compile( + r"(?:\bnot\b|\bisn['’]t\b|\baren['’]t\b|\bwasn['’]t\b|\bweren['’]t\b|\bincorrect\b|\bwrong\b|\bfalse\b|\bexcept(?:\s+for)?\b|\brather\s+than\b)(?:\W+\w+){0,3}\W*$", + re.IGNORECASE, +) + # Negative-context phrases that indicate an option mention is NOT a selected answer NEGATIVE_AFTER_OPTION_PATTERN = re.compile( r"^\s*(?:is|are|was|were)\s+(?:incorrect|wrong|false|not\s+correct)\b|^\s*not\s+correct\b", @@ -199,7 +205,7 @@ def _negated_near(text: str, match: re.Match) -> bool: """ sentence_start, sentence_end, match_start, _match_end = _get_sentence_containing_match(text, match) prefix = text[sentence_start:match_start] - return bool(NEGATION_PATTERN.search(prefix)) + return bool(NEGATION_BEFORE_MATCH_PATTERN.search(prefix)) def _negative_after_option(text: str, match: re.Match) -> bool: @@ -209,6 +215,22 @@ def _negative_after_option(text: str, match: re.Match) -> bool: return bool(NEGATIVE_AFTER_OPTION_PATTERN.search(suffix)) +def _contradicted_by_later_option(text: str, match: re.Match) -> bool: + """Check for same-sentence corrections like 'C, but D is correct' or 'C rather than D'.""" + _sentence_start, sentence_end, _match_start, match_end = _get_sentence_containing_match(text, match) + suffix = text[match_end:sentence_end] + current = _norm_letter(match.group("opt") if getattr(match.re, "groupindex", None) and "opt" in match.re.groupindex else match.group(1)) + contrast_pattern = re.compile( + r"\b(?:but|however|instead|rather)\b.{0,40}?(? str: """Return a short tail slice (last sentence/line) to reduce option-token noise.""" boundaries = list(SENTENCE_BOUNDARY.finditer(text)) @@ -321,9 +343,10 @@ def _result( if anchored_matches and answer_letter: last_match = anchored_matches[-1] predicted = _norm_letter(last_match.group("opt")) - if last_match.group("neg") is None and _token_kind_matches_answer_letter(predicted, answer_letter): + contradicted = _contradicted_by_later_option(llm_answer, last_match) + if last_match.group("neg") is None and not contradicted and _token_kind_matches_answer_letter(predicted, answer_letter): explicit_choice_found = True - if predicted == answer_letter and last_match.group("neg") is None: + if predicted == answer_letter and last_match.group("neg") is None and not contradicted: return _result(True, "anchored_token", predicted, answer_letter, return_details) # Strategy 4: Last token in the answer tail, ignore negative contexts like "C is incorrect", @@ -340,6 +363,8 @@ def _result( continue if _negative_after_option(tail, token_match): continue + if _contradicted_by_later_option(tail, token_match): + continue if predicted == answer_letter: return _result(True, "last_token", predicted, answer_letter, return_details) diff --git a/tests/test_mcq_accuracy.py b/tests/test_mcq_accuracy.py index 7764f3f9..5d222e8e 100644 --- a/tests/test_mcq_accuracy.py +++ b/tests/test_mcq_accuracy.py @@ -372,6 +372,21 @@ def test_last_token_isnt_same_sentence_blocks(): assert not multiple_choice_accuracy("It isn't C, but maybe C", answer_letter="C", answer_text="Option C") +def test_answer_text_rather_than_prefix_blocks(): + response = "The diagnosis is viral rather than bacterial pneumonia." + assert not multiple_choice_accuracy(response, answer_letter="B", answer_text="bacterial pneumonia") + + +def test_answer_text_wrong_prefix_blocks(): + response = "The wrong diagnosis is bacterial pneumonia." + assert not multiple_choice_accuracy(response, answer_letter="B", answer_text="bacterial pneumonia") + + +def test_anchored_token_contradicted_by_later_option_blocks(): + response = "Answer: C, but D is correct." + assert not multiple_choice_accuracy(response, answer_letter="C", answer_text="Option C") + + def test_answer_text_does_not_override_explicit_wrong_choice(): response = ( "The other options do not account for the renal findings as well:\n" From 93f789454ad73b49d635f33b5489c44ac745c0bb Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Thu, 12 Mar 2026 15:21:11 -0400 Subject: [PATCH 21/29] Cleaner reasoning stripping --- .../rewards/multiple_choice_accuracy.py | 36 +++++++------------ tests/test_mcq_accuracy.py | 10 ++++++ 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index 4e59f3c0..ab2d77e4 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -100,39 +100,29 @@ def _token_kind_matches_answer_letter(predicted: Optional[str], answer_letter: s return predicted.isalpha() -_THINK_OPEN_RE = re.compile(r"", re.IGNORECASE) -_THINK_CLOSE_RE = re.compile(r"", re.IGNORECASE) -_THINK_PAIR_RE = re.compile(r".*?", re.DOTALL | re.IGNORECASE) +_THINK_OPEN_RE = re.compile(r"<\s*think\b[^>]*>", re.IGNORECASE) +_THINK_CLOSE_RE = re.compile(r"", re.IGNORECASE) def _remove_think_tags(completion_text: str) -> str: """Extract the answer section from completion text, handling think tags properly. Behavior is intentionally conservative: - - If there is exactly one well-formed ... pair AND no unclosed later, - return everything after that closing tag. + - If there is any explicit closing , return everything after the last closing tag. + - If there is an unclosed with no closing tag, treat the output as missing + a final-answer region. - Otherwise, return the full response. """ text = completion_text or "" - # Fast path: most outputs won't contain think tags. - # Some models emit an unpaired closing tag () and then the final answer. - # In that case, treat the closing tag as the end of reasoning and keep only the tail. - if _THINK_OPEN_RE.search(text) is None: - closes = list(_THINK_CLOSE_RE.finditer(text)) - if closes: - return text[closes[-1].end() :].strip() - return text.strip() - - # Count properly closed pairs, but stop early once we know there are 2+. - it = _THINK_PAIR_RE.finditer(text) - first = next(it, None) - if first is None: - return text.strip() - if next(it, None) is not None: - return text.strip() - - return text[first.end() :].strip() + closes = list(_THINK_CLOSE_RE.finditer(text)) + if closes: + return text[closes[-1].end() :].lstrip() + + if _THINK_OPEN_RE.search(text): + return "" + + return text # Anchored patterns like "final answer: C" or "the answer is D" diff --git a/tests/test_mcq_accuracy.py b/tests/test_mcq_accuracy.py index 5d222e8e..9313a972 100644 --- a/tests/test_mcq_accuracy.py +++ b/tests/test_mcq_accuracy.py @@ -221,6 +221,16 @@ def test_unpaired_think_close_with_spurious_match(): assert multiple_choice_accuracy(response, answer_letter="A", answer_text="Option A") +def test_multiple_think_blocks_use_last_close(): + response = "first draft second\n\nFinal answer: B" + assert multiple_choice_accuracy(response, answer_letter="B", answer_text="Option B") + + +def test_unclosed_think_open_returns_empty(): + response = "reasoning only Final answer: C" + assert not multiple_choice_accuracy(response, answer_letter="C", answer_text="Option C") + + def test_cot_prevents_early_letter_matching(): # Should not match A or B from the reasoning cot_response = """ From 60360c5a1191d251f019dfcb78e822cd8eb0d66c Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Thu, 12 Mar 2026 15:35:16 -0400 Subject: [PATCH 22/29] Improve MCQ answer text normalization --- .../rewards/multiple_choice_accuracy.py | 41 ++++++++++++++++--- tests/test_mcq_accuracy.py | 14 +++++++ 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index ab2d77e4..1b5f28b1 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -17,7 +17,7 @@ _UNICODE_PUNCT_TRANSLATIONS = str.maketrans( { - "\u00A0": " ", # no-break space + "\u00a0": " ", # no-break space "\u2010": "-", # hyphen "\u2011": "-", # non-breaking hyphen "\u2012": "-", # figure dash @@ -27,12 +27,14 @@ "\u2212": "-", # minus sign "\u2018": "'", "\u2019": "'", - "\u201C": '"', - "\u201D": '"', + "\u201c": '"', + "\u201d": '"', } ) _WHITESPACE_RE = re.compile(r"\s+") +_INNER_PUNCT_SPACING_RE = re.compile(r"\s*([()\[\]{}.,;:])\s*") +_TRAILING_TERMINAL_PUNCT_RE = re.compile(r"(?<=\w)[.!?,;:]+$") @dataclass @@ -64,6 +66,13 @@ def normalize_for_match(text: str) -> str: return _WHITESPACE_RE.sub(" ", normalize_for_structure(text)).strip() +def normalize_for_answer_text_match(text: str) -> str: + """Canonicalize text for answer-text matching while tolerating minor punctuation drift.""" + text = normalize_for_match(text) + text = _INNER_PUNCT_SPACING_RE.sub(r"\1", text) + return _TRAILING_TERMINAL_PUNCT_RE.sub("", text) + + def _strip_tex(text: str) -> str: """Remove LaTeX formatting if pylatexenc is available.""" try: @@ -209,7 +218,9 @@ def _contradicted_by_later_option(text: str, match: re.Match) -> bool: """Check for same-sentence corrections like 'C, but D is correct' or 'C rather than D'.""" _sentence_start, sentence_end, _match_start, match_end = _get_sentence_containing_match(text, match) suffix = text[match_end:sentence_end] - current = _norm_letter(match.group("opt") if getattr(match.re, "groupindex", None) and "opt" in match.re.groupindex else match.group(1)) + current = _norm_letter( + match.group("opt") if getattr(match.re, "groupindex", None) and "opt" in match.re.groupindex else match.group(1) + ) contrast_pattern = re.compile( r"\b(?:but|however|instead|rather)\b.{0,40}?(? Date: Mon, 16 Mar 2026 01:54:20 -0400 Subject: [PATCH 23/29] Tighten compact multi-answer detection --- .../rewards/multiple_choice_accuracy.py | 28 ++++++++- tests/test_mcq_accuracy.py | 63 +++++++++++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index 1b5f28b1..ec23df11 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -176,6 +176,19 @@ def _remove_think_tags(completion_text: str) -> str: # Handles both single newlines (for line breaks in CoT) and double newlines (paragraphs) SENTENCE_BOUNDARY = re.compile(r"[.!?]\s+|\n+") +# Compact-list glue that should cause the last-token fallback to reject a tail as +# multi-answer rather than selecting the final option. +COMPACT_MULTI_OPTION_GLUE_PATTERN = re.compile( + r""" + \b(?:and|or|both|y|e|ou|und|et|plus)\b + | + \b(?:as\ well\ as|together\ with|followed\ by|correct\ choices?\s+are|choices?\s+are)\b + | + [,:;/&+\-|] + """, + re.IGNORECASE | re.VERBOSE, +) + def _get_sentence_containing_match(text: str, match: re.Match) -> str: """Return (sentence_start, sentence_end, match_start, match_end) in the original text.""" @@ -250,6 +263,19 @@ def _tail_region(text: str, max_tokens: int = 64) -> str: return tail +def _is_compact_multi_option_list(text: str) -> bool: + """Return True for short multi-option tails like 'A, C' or '> **A** and C'.""" + text = (text or "").strip() + matches = list(TOKEN_PATTERN.finditer(text)) + if len(matches) < 2: + return False + + residue = TOKEN_PATTERN.sub(" ", text) + residue = COMPACT_MULTI_OPTION_GLUE_PATTERN.sub(" ", residue) + residue = re.sub(r"[\s\[\]\(\)\{\}<>*_`~.!?]+", " ", residue) + return residue.strip() == "" + + def multiple_choice_accuracy( llm_answer: str, answer_letter: str, @@ -357,7 +383,7 @@ def _result( # Strategy 4: Last token in the answer tail, ignore negative contexts like "C is incorrect", if not explicit_choice_found and answer_letter: tail = _tail_region(llm_answer) - tail_tokens = list(TOKEN_PATTERN.finditer(tail)) + tail_tokens = [] if _is_compact_multi_option_list(tail) else list(TOKEN_PATTERN.finditer(tail)) if tail_tokens: # Take the last non-negated, non-negative-context token. for token_match in reversed(tail_tokens): diff --git a/tests/test_mcq_accuracy.py b/tests/test_mcq_accuracy.py index ca0f6adc..dd4329c6 100644 --- a/tests/test_mcq_accuracy.py +++ b/tests/test_mcq_accuracy.py @@ -43,6 +43,23 @@ def test_last_token_with_period(): assert multiple_choice_accuracy("My selection is B.", answer_letter="B", answer_text="Some text") +@pytest.mark.parametrize( + ("response", "answer_letter"), + [ + ("My selection is [C]", "C"), + ("My selection is C]", "C"), + ("My selection is C)", "C"), + ("My selection is (C)", "C"), + ("My selection is [2]", "2"), + ("My selection is 2]", "2"), + ("My selection is 2)", "2"), + ("My selection is (2)", "2"), + ], +) +def test_last_token_bracket_like_variants(response: str, answer_letter: str): + assert multiple_choice_accuracy(response, answer_letter=answer_letter, answer_text="Option") + + def test_last_token_multiple_letters_takes_last(): # A and B appear in reasoning, D is the final answer assert multiple_choice_accuracy("A is wrong. B seems unlikely. D", answer_letter="D", answer_text="Final option") @@ -60,6 +77,34 @@ def test_last_token_wrong(): assert not multiple_choice_accuracy("My answer is A", answer_letter="B", answer_text="Correct") +@pytest.mark.parametrize( + "response", + [ + "A, C", + "A; C", + "A: C", + "A. C", + "A) C", + "B, C, E", + "D / G / J", + "(A), (C)", + "[B], [C], [E]", + "A or C", + "B and E", + "B, and E", + "B, & D", + "both A and C", + "A & C", + "A + C", + "A y C", + "A e D", + "A ou C", + ], +) +def test_last_token_rejects_compact_multi_option_lists(response: str): + assert not multiple_choice_accuracy(response, answer_letter="C", answer_text="Option", accept_answer_text=False) + + def test_last_token_disabled_when_explicit_anchor_exists_even_if_wrong(): # Regression: do NOT allow last_token to override an explicit (wrong) anchored choice. response = ( @@ -90,6 +135,24 @@ def test_answer_text_in_sentence(): ) +@pytest.mark.parametrize("response", ["All of the above", "The answer is all of the above."]) +def test_answer_text_all_of_the_above_is_not_rejected(response: str): + assert multiple_choice_accuracy(response, answer_letter="D", answer_text="All of the above") + + +@pytest.mark.parametrize("response", ["None of the above", "The answer is none of the above."]) +def test_answer_text_none_of_the_above_is_not_rejected(response: str): + assert multiple_choice_accuracy(response, answer_letter="E", answer_text="None of the above") + + +def test_multi_answer_tail_does_not_count_as_all_of_the_above(): + assert not multiple_choice_accuracy("A and B", answer_letter="D", answer_text="All of the above") + + +def test_all_of_the_above_does_not_match_plain_option_text(): + assert not multiple_choice_accuracy("All of the above", answer_letter="C", answer_text="acute appendicitis") + + def test_answer_text_case_insensitive(): assert multiple_choice_accuracy( "The diagnosis is DIABETES MELLITUS TYPE 2", answer_letter="D", answer_text="Diabetes Mellitus Type 2" From 7aab7a5428ba40e3bfdfac282ad111af99be2b32 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Mon, 16 Mar 2026 21:15:38 -0400 Subject: [PATCH 24/29] catch more edge cases --- .../rewards/multiple_choice_accuracy.py | 296 +++++++++++++++--- tests/test_mcq_accuracy.py | 114 +++++++ 2 files changed, 369 insertions(+), 41 deletions(-) diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index ec23df11..ac18c61c 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -139,7 +139,7 @@ def _remove_think_tags(completion_text: str) -> str: r"(?:\bfinal\s+answer\b|\banswer\b|\bans\b|\bchoice\b|\boption\b|\bselected\b|\bi\s+choose\b|\bi\s+pick\b|\btherefore\b|\bthus\b|\bso\b|\bconclusion\b|\bin\s+conclusion\b|\bmost\s+likely\b|\bbest[-\s]+supported\s+answer\b|)\s*" r"[:\-–—]?\s*(?:is\s*)?(?Pnot\s+|isn['’]t\s+)?" r"(?:[*_`~]+\s*)*" # allow markdown wrappers before the option - r"\(?\s*(?P[A-Za-z]|\d{1,2})\s*\)?" # option token, possibly parenthesized + r"[\(\[\{<【]*\s*(?P[A-Za-z]|\d{1,2})\s*[\)\]\}>】]*" # option token, possibly wrapped r"\s*[\)\.:]?\s*" # optional delimiter (e.g., 'B.' or 'B)') r"(?:[*_`~]+\s*)*" # allow markdown wrappers after the option r"(?![\w+\-/])", @@ -148,17 +148,37 @@ def _remove_think_tags(completion_text: str) -> str: # Any letter/number token that looks like an option -TOKEN_PATTERN = re.compile(r"(?】]*[\)\.:]?(?![\w+\-/])", + re.IGNORECASE, +) # Leading option token like "B. Answer text" or "C) ..." at the start of the response LEADING_OPTION_PATTERN = re.compile( r"^\s*(?:>\s*)?(?:(?:[-*+]\s+)|(?:\d{1,3}[.)]\s+))?\s*" # blockquote / list prefixes - r"(?:[*_`~]+)?\s*\(?\s*([A-Za-z]|\d{1,2})\s*[\)\.:]\s*\)?\s*(?:[*_`~]+)?\s*(?!\w)", # markdown wrappers + r"(?:[*_`~]+)?\s*\(?\s*([A-Za-z]|\d{1,2})\s*" # markdown wrappers before the option + r"(?:" + r"[\)\.:]\s*\)?\s*(?:[*_`~]+)?\s*" # B. Answer text / C) ... + r"|" + r"(?=\s*(?:\(|[-–—]))" # A (Answer text) / A - Answer text / A – Answer text + r")" + r"(?!\w)", + re.IGNORECASE, +) + +# Standalone final-line option token like "C", "(C)", or "\boxed{C}". +TERMINAL_OPTION_LINE_PATTERN = re.compile( + r"^\s*(?:>\s*)?(?:(?:[-*+]\s+)|(?:\d{1,3}[.)]\s+))?\s*" + r"(?:\\boxed\{\s*)?(?:\s*)?" + r"[\(\[\{<【]*\s*[*_`~]*\s*(?P[A-Za-z]|\d{1,2})\s*[*_`~]*[\)\]\}>】]*" + r"\s*(?:\s*)?(?:\}\s*)?\s*[.!?]?\s*$", re.IGNORECASE, ) -# Negation words that invalidate nearby matches -NEGATION_PATTERN = re.compile(r"\b(?:not|isn['’]t)\b", re.IGNORECASE) +FINAL_CLAUSE_TERMINAL_OPTION_RE = re.compile( + r"(?[A-Za-z]|\d{1,2})\s*[*_`~]*[\)\]\}>】]*\s*[.!?]?\s*$", + re.IGNORECASE, +) # Negation/correction phrases that immediately precede an option or answer text NEGATION_BEFORE_MATCH_PATTERN = re.compile( @@ -172,6 +192,13 @@ def _remove_think_tags(completion_text: str) -> str: re.IGNORECASE, ) +CONTRAST_PATTERN = re.compile( + r"\b(?:but|however|instead(?!\s+of\b))\b" + r".{0,40}?" + r"(? bool: current = _norm_letter( match.group("opt") if getattr(match.re, "groupindex", None) and "opt" in match.re.groupindex else match.group(1) ) - contrast_pattern = re.compile( - r"\b(?:but|however|instead|rather)\b.{0,40}?(? str: return tail +def _last_nonempty_line(text: str) -> str: + """Return the last non-empty line, if any.""" + for line in reversed((text or "").splitlines()): + if line.strip(): + return line.strip() + return "" + + +def _option_candidate_invalid(text: str, match: re.Match) -> bool: + """Return True if an option-like match is negated or contradicted in local context.""" + return _negated_near(text, match) or _negative_after_option(text, match) or _contradicted_by_later_option(text, match) + + +def _ignore_prior_option_like_token(prefix: str, prior_match: re.Match) -> bool: + """Ignore harmless single-letter artifacts before a terminal final-clause answer. + + This is limited to natural-language cases like: + - leading pronoun "I" + - article "a" before a normal word + - trailing "'s" in contractions like "it's" + """ + raw = prior_match.group(1).casefold() + if raw == "i" and prior_match.start() == 0: + return True + if raw == "a" and re.match(r"\s+[a-z]{2,}\b", prefix[prior_match.end() :]): + return True + if raw == "s" and prior_match.start() > 0 and prefix[prior_match.start() - 1] in {"'", "’"}: + return True + return False + + +def _extract_terminal_option_line(line: str) -> Optional[str]: + """Extract a standalone option token from the last line.""" + if not line: + return None + + match = TERMINAL_OPTION_LINE_PATTERN.fullmatch(line) + if match: + predicted = _norm_letter(match.group("opt")) + if predicted is None: + return None + + tokens = list(TOKEN_PATTERN.finditer(line)) + if len(tokens) != 1: + return None + + token_match = tokens[0] + if _option_candidate_invalid(line, token_match): + return None + + return predicted + + leading_match = LEADING_OPTION_PATTERN.match(line) + if not leading_match or _is_compact_multi_option_list(line): + return None + + predicted = _norm_letter(leading_match.group(1)) + if predicted is None: + return None + + if _option_candidate_invalid(line, leading_match): + return None + + return predicted + + +def _extract_short_final_clause_option(text: str, max_words: int = 12) -> Optional[str]: + """Extract a terminal option token from a short final clause like 'I think it's C'.""" + clause = _tail_region(text).strip() + if not clause or len(clause.split()) > max_words: + return None + if _is_compact_multi_option_list(clause): + return None + + match = FINAL_CLAUSE_TERMINAL_OPTION_RE.search(clause) + if not match: + return None + + token_match = match + if _option_candidate_invalid(clause, token_match): + return None + + # Reject short clauses that contain another meaningful option token before the final token. + prefix = clause[:token_match.start()] + for prior_match in TOKEN_PATTERN.finditer(prefix): + token = _norm_letter(prior_match.group(1)) + if token is None: + continue + + if _ignore_prior_option_like_token(prefix, prior_match): + continue + + return None + + return _norm_letter(match.group("opt")) + + +# Connector words that join two option tokens into a multi-answer phrase. +# Catches "A and C", "A to C", "A through C", "neither A nor C", etc. +_MULTI_ANSWER_CONNECTOR_WORD_RE = re.compile( + r"\b(?:and|or|nor|to|through|then|plus)\b" + r"|\bas\s+well\s+as\b" + r"|\btogether\s+with\b" + r"|\bfollowed\s+by\b", + re.IGNORECASE, +) + +def _anchored_match_in_multi_answer_phrase(text: str, matches: list[re.Match], idx: int) -> bool: + """Return True if anchored match *idx* is part of a local multi-answer phrase.""" + match = matches[idx] + current = _norm_letter(match.group("opt")) + if current is None: + return False + + if idx > 0: + between = text[matches[idx - 1].end() : match.start()] + if len(between.split()) <= 5 and _MULTI_ANSWER_CONNECTOR_WORD_RE.search(between): + return True + + if idx < len(matches) - 1: + between = text[match.end() : matches[idx + 1].start()] + if len(between.split()) <= 5 and _MULTI_ANSWER_CONNECTOR_WORD_RE.search(between): + return True + + pre_text = text[max(0, match.start() - 20) : match.start()] + if bool( + re.search( + r"(?]*" + r"[\s,;]*" + r"(?:and|or|nor|to|through|then|plus)\s*$", + pre_text, + re.IGNORECASE, + ) + ): + return True + + sentence_start, sentence_end, match_start, match_end = _get_sentence_containing_match(text, match) + sentence = text[sentence_start:sentence_end] + local_match_start = match_start - sentence_start + local_match_end = match_end - sentence_start + + sentence_tokens = [] + for token_match in TOKEN_PATTERN.finditer(sentence): + token = _norm_letter(token_match.group(1)) + if token is None: + continue + sentence_tokens.append((token, token_match)) + + for token, token_match in sentence_tokens: + if token == current: + continue + between = "" + if token_match.end() <= local_match_start: + between = sentence[token_match.end() : local_match_start] + elif local_match_end <= token_match.start(): + between = sentence[local_match_end : token_match.start()] + if not between: + continue + if len(between.split()) <= 5 and _MULTI_ANSWER_CONNECTOR_WORD_RE.search(between): + return True + + return False + + def _is_compact_multi_option_list(text: str) -> bool: """Return True for short multi-option tails like 'A, C' or '> **A** and C'.""" text = (text or "").strip() @@ -276,6 +463,32 @@ def _is_compact_multi_option_list(text: str) -> bool: return residue.strip() == "" +def _contains_multiple_option_led_sentences(text: str, answer_letter: str) -> bool: + """Return True when different sentences/lines each start with different option labels. + + This catches payloads like "(A) ... . (D) ..." or "A. ...\\nD. ...", which should + not be accepted for a single-answer MCQ unless a later anchored final answer overrides + them. + """ + + text = text or "" + distinct: set[str] = set() + starts = [0] + starts.extend(match.end() for match in SENTENCE_BOUNDARY.finditer(text)) + + for start in starts: + match = LEADING_OPTION_PATTERN.match(text[start:]) + if not match: + continue + token = _norm_letter(match.group(1)) + if token is None or not _token_kind_matches_answer_letter(token, answer_letter): + continue + distinct.add(token) + if len(distinct) > 1: + return True + return False + + def multiple_choice_accuracy( llm_answer: str, answer_letter: str, @@ -290,7 +503,7 @@ def multiple_choice_accuracy( 1. Direct answer: Response is just the option letter/number 2. Anchored token: Use the last occurrence of a provided prefix, otherwise general anchor phrases - 3. Last token: Take the last letter/number found anywhere + 3. Last token: Parse a terminal option line or short final clause near the end 4. Answer text: Match the full answer text (if long enough) Args: @@ -340,13 +553,14 @@ def _result( raise ValueError(f"Invalid answer_letter '{answer_letter=}'. Must be a single letter or digit string.") explicit_choice_found = False + multiple_option_led_sentences = _contains_multiple_option_led_sentences(llm_answer_original, answer_letter) # Strategy 1: Only answer letter anywhere (without anchoring) if answer_letter == _norm_letter(llm_answer): return _result(True, "direct_answer", llm_answer, answer_letter, return_details) # Strategy 2: Accept leading option token like "B. answer ..." - leading_match = LEADING_OPTION_PATTERN.match(llm_answer_original) + leading_match = None if multiple_option_led_sentences else LEADING_OPTION_PATTERN.match(llm_answer_original) if leading_match and answer_letter: predicted = _norm_letter(leading_match.group(1)) if _token_kind_matches_answer_letter(predicted, answer_letter): @@ -368,40 +582,40 @@ def _result( anchored_matches = prefix_matches if prefix_matches else list(ANCHOR_PATTERN.finditer(llm_answer)) if anchored_matches and answer_letter: - last_match = anchored_matches[-1] - predicted = _norm_letter(last_match.group("opt")) - contradicted = _contradicted_by_later_option(llm_answer, last_match) - if ( - last_match.group("neg") is None - and not contradicted - and _token_kind_matches_answer_letter(predicted, answer_letter) - ): - explicit_choice_found = True - if predicted == answer_letter and last_match.group("neg") is None and not contradicted: - return _result(True, "anchored_token", predicted, answer_letter, return_details) + for idx in range(len(anchored_matches) - 1, -1, -1): + match = anchored_matches[idx] + predicted = _norm_letter(match.group("opt")) + if predicted is None: + continue + if match.group("neg") is not None: + continue + if _contradicted_by_later_option(llm_answer, match): + continue + if _anchored_match_in_multi_answer_phrase(llm_answer, anchored_matches, idx): + continue + + if _token_kind_matches_answer_letter(predicted, answer_letter): + explicit_choice_found = True + if predicted == answer_letter: + return _result(True, "anchored_token", predicted, answer_letter, return_details) + break + + # Strategy 4: Parse a terminal option line or short final clause near the end. + if not explicit_choice_found and answer_letter and not multiple_option_led_sentences: + predicted = _extract_terminal_option_line(_last_nonempty_line(llm_answer)) + if predicted == answer_letter: + return _result(True, "last_token", predicted, answer_letter, return_details) - # Strategy 4: Last token in the answer tail, ignore negative contexts like "C is incorrect", - if not explicit_choice_found and answer_letter: - tail = _tail_region(llm_answer) - tail_tokens = [] if _is_compact_multi_option_list(tail) else list(TOKEN_PATTERN.finditer(tail)) - if tail_tokens: - # Take the last non-negated, non-negative-context token. - for token_match in reversed(tail_tokens): - predicted = _norm_letter(token_match.group(1)) - if predicted is None: - continue - if _negated_near(tail, token_match): - continue - if _negative_after_option(tail, token_match): - continue - if _contradicted_by_later_option(tail, token_match): - continue - if predicted == answer_letter: - return _result(True, "last_token", predicted, answer_letter, return_details) + predicted = _extract_short_final_clause_option(llm_answer) + if predicted == answer_letter: + return _result(True, "last_token", predicted, answer_letter, return_details) # Strategy 5: Exact answer text match if there's no explicit choice found # Only search at beginning and end to avoid matching reasoning in the middle if accept_answer_text and answer_text and not explicit_choice_found: + if multiple_option_led_sentences: + return _result(False, "none", None, None, return_details) + # Calculate search regions based on token count answer_tokens = len(answer_text.split()) buffer_tokens = answer_tokens + 15 # Extra tokens for preamble like "The answer is:" diff --git a/tests/test_mcq_accuracy.py b/tests/test_mcq_accuracy.py index dd4329c6..efc302b7 100644 --- a/tests/test_mcq_accuracy.py +++ b/tests/test_mcq_accuracy.py @@ -440,6 +440,21 @@ def test_leading_option_with_no_and_punctuation_should_pass(): assert multiple_choice_accuracy("B) No.", answer_letter="B", answer_text="No") +@pytest.mark.parametrize( + ("response", "answer_text"), + [ + ("A (Nadolol)", "Nadolol"), + ("A - Nadolol", "Nadolol"), + ("A – Nadolol", "Nadolol"), + ], +) +def test_leading_option_with_parenthetical_or_dash_answer_text(response: str, answer_text: str): + result = multiple_choice_accuracy(response, answer_letter="A", answer_text=answer_text, return_details=True) + assert result.is_correct is True + assert result.method == "anchored_token" + assert result.matched_answer == "A" + + def test_last_token_negation_same_sentence_blocks(): # No anchor phrase, so it falls to last_token. # Because "Not" is in the same sentence, the final "C" should be blocked. @@ -474,6 +489,105 @@ def test_anchored_token_contradicted_by_later_option_blocks(): assert not multiple_choice_accuracy(response, answer_letter="C", answer_text="Option C") +def test_anchored_token_instead_correction_blocks(): + response = "Answer: C, instead D is correct." + assert not multiple_choice_accuracy(response, answer_letter="C", answer_text="Option C") + + +def test_anchored_token_instead_of_preference_does_not_block(): + response = "Answer: C instead of D." + assert multiple_choice_accuracy(response, answer_letter="C", answer_text="Option C") + assert not multiple_choice_accuracy(response, answer_letter="D", answer_text="Option D") + + +def test_multi_answer_anchors_elsewhere_do_not_poison_final_anchor(): + response = "Option A and Option C were considered earlier. Final answer: B" + result = multiple_choice_accuracy(response, answer_letter="B", answer_text="Option B", return_details=True) + assert result.is_correct is True + assert result.method == "anchored_token" + assert result.matched_answer == "B" + + +def test_multi_answer_anchors_elsewhere_do_not_allow_later_tail_token_override(): + response = "Option A and Option C are wrong. Final answer: B. D" + assert multiple_choice_accuracy(response, answer_letter="B", answer_text="Option B") + assert not multiple_choice_accuracy(response, answer_letter="D", answer_text="Option D") + + +@pytest.mark.parametrize( + "response", + [ + ( + "Option (A) and Option (I) are both correct statements concerning feeding for this patient, but " + "since the prompt asks for a singular choice that is true, the most directly relevant and universally " + "accepted principle would be (A) Enteral nutrition may decrease infection due to the prevention of " + "bacterial translocation, highlighting a key benefit of enteral feeding in acute pancreatitis management." + ), + ( + "Answer: option A and option I are both correct. If I must pick one, I would lean toward A because " + "enteral nutrition may decrease infection due to the prevention of bacterial translocation." + ), + ( + "Option A as well as option I are valid here. The better-supported statement is A: " + "Enteral nutrition may decrease infection due to the prevention of bacterial translocation." + ), + ( + "Choice A or choice I could both be defended. The most directly relevant principle would be A " + "(Enteral nutrition may decrease infection due to the prevention of bacterial translocation)." + ), + ( + "Selected options: A and I. Since only one answer is requested, I would prefer A - " + "Enteral nutrition may decrease infection due to the prevention of bacterial translocation." + ), + ( + "Option (A), together with option (I), is correct for feeding in severe acute pancreatitis; " + "among them, (A) Enteral nutrition may decrease infection due to the prevention of bacterial " + "translocation is the most important principle." + ), + ], +) +def test_answer_text_fallback_allows_disambiguated_multi_candidate_payloads(response: str): + result_a = multiple_choice_accuracy( + response, + answer_letter="A", + answer_text="Enteral nutrition may decrease infection due to the prevention of bacterial translocation.", + accept_answer_text=True, + return_details=True, + ) + assert result_a.is_correct is True + assert result_a.method == "answer_text" + + result_i = multiple_choice_accuracy( + response, + answer_letter="I", + answer_text="Feeding should begin within 24-48 hours.", + accept_answer_text=True, + return_details=True, + ) + assert result_i.is_correct is False + assert result_i.method == "none" + + +@pytest.mark.parametrize( + "response", + [ + "(A) Naloxone is a synthetic N-allyl derivative of oxymorphone. (D) Naloxone is not rapidly absorbed after oral administration.", + "A. Naloxone is a synthetic N-allyl derivative of oxymorphone.\nD. Naloxone is not rapidly absorbed after oral administration.", + "(A) First statement. (C) Second statement.", + ], +) +def test_answer_text_fallback_rejects_multiple_option_led_sentences(response: str): + result_a = multiple_choice_accuracy( + response, + answer_letter="A", + answer_text="Naloxone is a synthetic N-allyl derivative of oxymorphone.", + accept_answer_text=True, + return_details=True, + ) + assert result_a.is_correct is False + assert result_a.method == "none" + + def test_answer_text_does_not_override_explicit_wrong_choice(): response = ( "The other options do not account for the renal findings as well:\n" From 97d3282fedc6bfddcfd4574fe2646cbbdd1edb18 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 17 Mar 2026 00:05:58 -0400 Subject: [PATCH 25/29] performance fix for long answers --- .../rewards/multiple_choice_accuracy.py | 16 +++++++------- tests/test_mcq_accuracy.py | 21 ++++++++++++++++++- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index ac18c61c..dedb89c1 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -153,19 +153,22 @@ def _remove_think_tags(completion_text: str) -> str: re.IGNORECASE, ) -# Leading option token like "B. Answer text" or "C) ..." at the start of the response -LEADING_OPTION_PATTERN = re.compile( - r"^\s*(?:>\s*)?(?:(?:[-*+]\s+)|(?:\d{1,3}[.)]\s+))?\s*" # blockquote / list prefixes +_LEADING_OPTION_PATTERN_BODY = ( + r"\s*(?:>\s*)?(?:(?:[-*+]\s+)|(?:\d{1,3}[.)]\s+))?\s*" # blockquote / list prefixes r"(?:[*_`~]+)?\s*\(?\s*([A-Za-z]|\d{1,2})\s*" # markdown wrappers before the option r"(?:" r"[\)\.:]\s*\)?\s*(?:[*_`~]+)?\s*" # B. Answer text / C) ... r"|" r"(?=\s*(?:\(|[-–—]))" # A (Answer text) / A - Answer text / A – Answer text r")" - r"(?!\w)", - re.IGNORECASE, + r"(?!\w)" ) +# Leading option token like "B. Answer text" or "C) ..." at the start of the response +LEADING_OPTION_PATTERN = re.compile(rf"^{_LEADING_OPTION_PATTERN_BODY}", re.IGNORECASE) +# Same pattern without ^ so we can match from sentence offsets without slicing `text[start:]`. +SENTENCE_LEADING_OPTION_PATTERN = re.compile(_LEADING_OPTION_PATTERN_BODY, re.IGNORECASE) + # Standalone final-line option token like "C", "(C)", or "\boxed{C}". TERMINAL_OPTION_LINE_PATTERN = re.compile( r"^\s*(?:>\s*)?(?:(?:[-*+]\s+)|(?:\d{1,3}[.)]\s+))?\s*" @@ -475,9 +478,8 @@ def _contains_multiple_option_led_sentences(text: str, answer_letter: str) -> bo distinct: set[str] = set() starts = [0] starts.extend(match.end() for match in SENTENCE_BOUNDARY.finditer(text)) - for start in starts: - match = LEADING_OPTION_PATTERN.match(text[start:]) + match = SENTENCE_LEADING_OPTION_PATTERN.match(text, pos=start) if not match: continue token = _norm_letter(match.group(1)) diff --git a/tests/test_mcq_accuracy.py b/tests/test_mcq_accuracy.py index efc302b7..c643a9ca 100644 --- a/tests/test_mcq_accuracy.py +++ b/tests/test_mcq_accuracy.py @@ -1,8 +1,14 @@ """Tests for the simplified MCQ accuracy grader.""" +import time + import pytest -from medarc_verifiers.rewards.multiple_choice_accuracy import MCQAccuracyResult, multiple_choice_accuracy +from medarc_verifiers.rewards.multiple_choice_accuracy import ( + MCQAccuracyResult, + _contains_multiple_option_led_sentences, + multiple_choice_accuracy, +) def test_anchored_final_answer_colon(): @@ -588,6 +594,19 @@ def test_answer_text_fallback_rejects_multiple_option_led_sentences(response: st assert result_a.method == "none" +def test_multiple_option_led_sentence_scan_handles_large_payload_linearly(): + response = ("Reasoning sentence with details. " * 12000) + "Final answer: C" + started = time.perf_counter() + assert _contains_multiple_option_led_sentences(response, answer_letter="C") is False + elapsed = time.perf_counter() - started + assert elapsed < 1.0 + + +def test_large_reasoning_payload_still_accepts_final_answer(): + response = ("Reasoning sentence with details. " * 12000) + "Final answer: C" + assert multiple_choice_accuracy(response, answer_letter="C", answer_text="Option C") + + def test_answer_text_does_not_override_explicit_wrong_choice(): response = ( "The other options do not account for the renal findings as well:\n" From 044f0b6f5f9cc9b03d51a6f4f58e25352869e10b Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 17 Mar 2026 11:04:57 -0400 Subject: [PATCH 26/29] only use tex stripping if we detect tex-like text --- .../rewards/multiple_choice_accuracy.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index dedb89c1..d09bd532 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -12,6 +12,7 @@ import re import unicodedata from dataclasses import dataclass +from functools import lru_cache from typing import Optional @@ -35,6 +36,7 @@ _WHITESPACE_RE = re.compile(r"\s+") _INNER_PUNCT_SPACING_RE = re.compile(r"\s*([()\[\]{}.,;:])\s*") _TRAILING_TERMINAL_PUNCT_RE = re.compile(r"(?<=\w)[.!?,;:]+$") +_LIKELY_TEX_RE = re.compile(r"\\[A-Za-z]+|\\[$\\()\\[\\]{}]|[$]") @dataclass @@ -73,12 +75,20 @@ def normalize_for_answer_text_match(text: str) -> str: return _TRAILING_TERMINAL_PUNCT_RE.sub("", text) +@lru_cache(maxsize=1) +def _latex_to_text_converter(): + from pylatexenc.latex2text import LatexNodes2Text + + return LatexNodes2Text(math_mode="text") + + def _strip_tex(text: str) -> str: """Remove LaTeX formatting if pylatexenc is available.""" - try: - from pylatexenc.latex2text import LatexNodes2Text + if not text or not _LIKELY_TEX_RE.search(text): + return text - return LatexNodes2Text(math_mode="text").latex_to_text(text) + try: + return _latex_to_text_converter().latex_to_text(text) except Exception: return text From ac4fcb1e1ee4da90ee7ed82114caf9009d69041b Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Tue, 17 Mar 2026 11:27:29 -0400 Subject: [PATCH 27/29] temp performance checking code --- .../rewards/multiple_choice_accuracy.py | 90 ++++++++++++++++++- 1 file changed, 88 insertions(+), 2 deletions(-) diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index d09bd532..44ec2c8e 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -10,9 +10,13 @@ """ import re +import os +import sys +import time import unicodedata from dataclasses import dataclass from functools import lru_cache +from functools import wraps from typing import Optional @@ -39,6 +43,61 @@ _LIKELY_TEX_RE = re.compile(r"\\[A-Za-z]+|\\[$\\()\\[\\]{}]|[$]") +def _mcq_perf_trace_enabled() -> bool: + return os.getenv("MEDARC_MCQ_PERF_TRACE", "").strip().lower() in {"1", "true", "yes", "on"} + + +def _mcq_perf_trace_min_seconds() -> float: + raw = os.getenv("MEDARC_MCQ_PERF_TRACE_MIN_MS", "").strip() + if not raw: + return 0.0 + try: + return max(float(raw) / 1000.0, 0.0) + except ValueError: + return 0.0 + + +def _mcq_perf_trace_summary(args: tuple, kwargs: dict) -> str: + parts: list[str] = [] + for idx, value in enumerate(args[:3]): + if isinstance(value, str): + parts.append(f"arg{idx}_len={len(value)}") + elif isinstance(value, re.Match): + try: + start, end = value.span() + parts.append(f"arg{idx}_span={start}:{end}") + except Exception: + parts.append(f"arg{idx}=match") + elif isinstance(value, list): + parts.append(f"arg{idx}_len={len(value)}") + if "answer_letter" in kwargs and isinstance(kwargs["answer_letter"], str): + parts.append(f"answer_letter={kwargs['answer_letter']!r}") + return " ".join(parts) + + +def _trace_scan_perf(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not _mcq_perf_trace_enabled(): + return func(*args, **kwargs) + + started = time.perf_counter() + try: + return func(*args, **kwargs) + finally: + elapsed = time.perf_counter() - started + if elapsed >= _mcq_perf_trace_min_seconds(): + summary = _mcq_perf_trace_summary(args, kwargs) + print( + f"[mcq-perf] {func.__name__} elapsed_ms={elapsed * 1000:.3f}" + + (f" {summary}" if summary else ""), + file=sys.stderr, + flush=True, + ) + + return wrapper + + @dataclass class MCQAccuracyResult: """Result of multiple-choice accuracy grading.""" @@ -82,6 +141,7 @@ def _latex_to_text_converter(): return LatexNodes2Text(math_mode="text") +@_trace_scan_perf def _strip_tex(text: str) -> str: """Remove LaTeX formatting if pylatexenc is available.""" if not text or not _LIKELY_TEX_RE.search(text): @@ -123,6 +183,7 @@ def _token_kind_matches_answer_letter(predicted: Optional[str], answer_letter: s _THINK_CLOSE_RE = re.compile(r"", re.IGNORECASE) +@_trace_scan_perf def _remove_think_tags(completion_text: str) -> str: """Extract the answer section from completion text, handling think tags properly. @@ -229,7 +290,10 @@ def _remove_think_tags(completion_text: str) -> str: re.IGNORECASE | re.VERBOSE, ) +_MULTIPLE_OPTION_LED_SCAN_MAX_CHARS = 8000 + +@_trace_scan_perf def _get_sentence_containing_match(text: str, match: re.Match) -> str: """Return (sentence_start, sentence_end, match_start, match_end) in the original text.""" if getattr(match.re, "groupindex", None) and "opt" in match.re.groupindex: @@ -248,6 +312,7 @@ def _get_sentence_containing_match(text: str, match: re.Match) -> str: return sentence_start, sentence_end, match_start, match_end +@_trace_scan_perf def _negated_near(text: str, match: re.Match) -> bool: """Check for negation that appears before the match within the same sentence. @@ -260,6 +325,7 @@ def _negated_near(text: str, match: re.Match) -> bool: return bool(NEGATION_BEFORE_MATCH_PATTERN.search(prefix)) +@_trace_scan_perf def _negative_after_option(text: str, match: re.Match) -> bool: """Check if an option token is immediately followed by negative context like 'C is incorrect'.""" _sentence_start, sentence_end, _match_start, match_end = _get_sentence_containing_match(text, match) @@ -267,6 +333,7 @@ def _negative_after_option(text: str, match: re.Match) -> bool: return bool(NEGATIVE_AFTER_OPTION_PATTERN.search(suffix)) +@_trace_scan_perf def _contradicted_by_later_option(text: str, match: re.Match) -> bool: """Check for same-sentence corrections like 'C, but D is correct' or 'C rather than D'.""" _sentence_start, sentence_end, _match_start, match_end = _get_sentence_containing_match(text, match) @@ -281,6 +348,7 @@ def _contradicted_by_later_option(text: str, match: re.Match) -> bool: return contrasted is not None and contrasted != current +@_trace_scan_perf def _tail_region(text: str, max_tokens: int = 64) -> str: """Return a short tail slice (last sentence/line) to reduce option-token noise.""" boundaries = list(SENTENCE_BOUNDARY.finditer(text)) @@ -299,6 +367,7 @@ def _tail_region(text: str, max_tokens: int = 64) -> str: return tail +@_trace_scan_perf def _last_nonempty_line(text: str) -> str: """Return the last non-empty line, if any.""" for line in reversed((text or "").splitlines()): @@ -330,6 +399,7 @@ def _ignore_prior_option_like_token(prefix: str, prior_match: re.Match) -> bool: return False +@_trace_scan_perf def _extract_terminal_option_line(line: str) -> Optional[str]: """Extract a standalone option token from the last line.""" if not line: @@ -365,6 +435,7 @@ def _extract_terminal_option_line(line: str) -> Optional[str]: return predicted +@_trace_scan_perf def _extract_short_final_clause_option(text: str, max_words: int = 12) -> Optional[str]: """Extract a terminal option token from a short final clause like 'I think it's C'.""" clause = _tail_region(text).strip() @@ -406,6 +477,7 @@ def _extract_short_final_clause_option(text: str, max_words: int = 12) -> Option re.IGNORECASE, ) +@_trace_scan_perf def _anchored_match_in_multi_answer_phrase(text: str, matches: list[re.Match], idx: int) -> bool: """Return True if anchored match *idx* is part of a local multi-answer phrase.""" match = matches[idx] @@ -463,6 +535,7 @@ def _anchored_match_in_multi_answer_phrase(text: str, matches: list[re.Match], i return False +@_trace_scan_perf def _is_compact_multi_option_list(text: str) -> bool: """Return True for short multi-option tails like 'A, C' or '> **A** and C'.""" text = (text or "").strip() @@ -476,6 +549,7 @@ def _is_compact_multi_option_list(text: str) -> bool: return residue.strip() == "" +@_trace_scan_perf def _contains_multiple_option_led_sentences(text: str, answer_letter: str) -> bool: """Return True when different sentences/lines each start with different option labels. @@ -501,6 +575,7 @@ def _contains_multiple_option_led_sentences(text: str, answer_letter: str) -> bo return False +@_trace_scan_perf def multiple_choice_accuracy( llm_answer: str, answer_letter: str, @@ -565,14 +640,25 @@ def _result( raise ValueError(f"Invalid answer_letter '{answer_letter=}'. Must be a single letter or digit string.") explicit_choice_found = False - multiple_option_led_sentences = _contains_multiple_option_led_sentences(llm_answer_original, answer_letter) # Strategy 1: Only answer letter anywhere (without anchoring) if answer_letter == _norm_letter(llm_answer): return _result(True, "direct_answer", llm_answer, answer_letter, return_details) + multiple_option_led_sentences = False + leading_match = LEADING_OPTION_PATTERN.match(llm_answer_original) + if leading_match: + # Only pay for the multi-sentence scan when the response actually starts like a + # leading-option answer. For very large payloads, disable the leading-option shortcut + # rather than scanning the whole response. + if len(llm_answer_original) <= _MULTIPLE_OPTION_LED_SCAN_MAX_CHARS: + multiple_option_led_sentences = _contains_multiple_option_led_sentences(llm_answer_original, answer_letter) + if multiple_option_led_sentences: + leading_match = None + else: + leading_match = None + # Strategy 2: Accept leading option token like "B. answer ..." - leading_match = None if multiple_option_led_sentences else LEADING_OPTION_PATTERN.match(llm_answer_original) if leading_match and answer_letter: predicted = _norm_letter(leading_match.group(1)) if _token_kind_matches_answer_letter(predicted, answer_letter): From 62a5629dfb7c097f91004414e2dde502a93ba8d5 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Wed, 18 Mar 2026 12:35:43 -0400 Subject: [PATCH 28/29] refactor --- .../rewards/multiple_choice_accuracy.py | 376 ++++++++++-------- tests/test_mcq_accuracy.py | 18 + 2 files changed, 222 insertions(+), 172 deletions(-) diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index 44ec2c8e..e4a76b75 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -44,10 +44,12 @@ def _mcq_perf_trace_enabled() -> bool: + """Return whether lightweight MCQ performance tracing is enabled.""" return os.getenv("MEDARC_MCQ_PERF_TRACE", "").strip().lower() in {"1", "true", "yes", "on"} def _mcq_perf_trace_min_seconds() -> float: + """Return the minimum elapsed time required before a helper emits a trace line.""" raw = os.getenv("MEDARC_MCQ_PERF_TRACE_MIN_MS", "").strip() if not raw: return 0.0 @@ -58,6 +60,7 @@ def _mcq_perf_trace_min_seconds() -> float: def _mcq_perf_trace_summary(args: tuple, kwargs: dict) -> str: + """Build a compact summary string for performance trace logging.""" parts: list[str] = [] for idx, value in enumerate(args[:3]): if isinstance(value, str): @@ -76,8 +79,11 @@ def _mcq_perf_trace_summary(args: tuple, kwargs: dict) -> str: def _trace_scan_perf(func): + """Wrap a helper so it can emit elapsed-time traces when tracing is enabled.""" + @wraps(func) def wrapper(*args, **kwargs): + """Execute the wrapped helper and optionally log its runtime.""" if not _mcq_perf_trace_enabled(): return func(*args, **kwargs) @@ -89,8 +95,7 @@ def wrapper(*args, **kwargs): if elapsed >= _mcq_perf_trace_min_seconds(): summary = _mcq_perf_trace_summary(args, kwargs) print( - f"[mcq-perf] {func.__name__} elapsed_ms={elapsed * 1000:.3f}" - + (f" {summary}" if summary else ""), + f"[mcq-perf] {func.__name__} elapsed_ms={elapsed * 1000:.3f}" + (f" {summary}" if summary else ""), file=sys.stderr, flush=True, ) @@ -115,6 +120,7 @@ class MCQAccuracyResult: """The correct answer for reference, if available.""" +@_trace_scan_perf def normalize_for_structure(text: str) -> str: """Canonicalize text for structural matching without collapsing whitespace.""" text = unicodedata.normalize("NFKC", text or "") @@ -122,11 +128,13 @@ def normalize_for_structure(text: str) -> str: return text.casefold() +@_trace_scan_perf def normalize_for_match(text: str) -> str: """Canonicalize text for answer-text equivalence matching.""" return _WHITESPACE_RE.sub(" ", normalize_for_structure(text)).strip() +@_trace_scan_perf def normalize_for_answer_text_match(text: str) -> str: """Canonicalize text for answer-text matching while tolerating minor punctuation drift.""" text = normalize_for_match(text) @@ -134,8 +142,10 @@ def normalize_for_answer_text_match(text: str) -> str: return _TRAILING_TERMINAL_PUNCT_RE.sub("", text) +@_trace_scan_perf @lru_cache(maxsize=1) def _latex_to_text_converter(): + """Construct and cache the pylatexenc converter used for TeX stripping.""" from pylatexenc.latex2text import LatexNodes2Text return LatexNodes2Text(math_mode="text") @@ -153,6 +163,7 @@ def _strip_tex(text: str) -> str: return text +@_trace_scan_perf def _norm_letter(letter: str) -> Optional[str]: """Normalize a token to uppercase letter or digit string.""" letter = (letter or "").strip() @@ -165,6 +176,7 @@ def _norm_letter(letter: str) -> Optional[str]: return None +@_trace_scan_perf def _token_kind_matches_answer_letter(predicted: Optional[str], answer_letter: str) -> bool: """Return True if predicted token type matches the task's option type. @@ -225,19 +237,14 @@ def _remove_think_tags(completion_text: str) -> str: ) _LEADING_OPTION_PATTERN_BODY = ( - r"\s*(?:>\s*)?(?:(?:[-*+]\s+)|(?:\d{1,3}[.)]\s+))?\s*" # blockquote / list prefixes - r"(?:[*_`~]+)?\s*\(?\s*([A-Za-z]|\d{1,2})\s*" # markdown wrappers before the option - r"(?:" - r"[\)\.:]\s*\)?\s*(?:[*_`~]+)?\s*" # B. Answer text / C) ... - r"|" - r"(?=\s*(?:\(|[-–—]))" # A (Answer text) / A - Answer text / A – Answer text - r")" - r"(?!\w)" + r"\s*(?:>\s*)?(?:(?:[-*+]\s+)|(?:\d{1,3}[.)]\s+))?\s*" + r"(?:[*_`~]+)?\s*\(?\s*([A-Za-z]|\d{1,2})\s*" + r"(?:[\)\.:]\s*\)?\s*(?:[*_`~]+)?\s*|(?=\s*(?:\(|[-–—])))(?!\w)" ) # Leading option token like "B. Answer text" or "C) ..." at the start of the response LEADING_OPTION_PATTERN = re.compile(rf"^{_LEADING_OPTION_PATTERN_BODY}", re.IGNORECASE) -# Same pattern without ^ so we can match from sentence offsets without slicing `text[start:]`. +# Same pattern without ^ so we can match from sentence offsets without slicing text. SENTENCE_LEADING_OPTION_PATTERN = re.compile(_LEADING_OPTION_PATTERN_BODY, re.IGNORECASE) # Standalone final-line option token like "C", "(C)", or "\boxed{C}". @@ -277,24 +284,28 @@ def _remove_think_tags(completion_text: str) -> str: # Handles both single newlines (for line breaks in CoT) and double newlines (paragraphs) SENTENCE_BOUNDARY = re.compile(r"[.!?]\s+|\n+") +_MULTI_OPTION_CONNECTOR_WORD_PATTERN = ( + r"\b(?:and|or|nor|to|through|then|plus|both|y|e|ou|und|et)\b" + r"|\bas\s+well\s+as\b" + r"|\btogether\s+with\b" + r"|\bfollowed\s+by\b" + r"|\bcorrect\ choices?\s+are\b" + r"|\bchoices?\s+are\b" +) +_MULTI_OPTION_CONNECTOR_RE = re.compile(_MULTI_OPTION_CONNECTOR_WORD_PATTERN, re.IGNORECASE) + # Compact-list glue that should cause the last-token fallback to reject a tail as # multi-answer rather than selecting the final option. COMPACT_MULTI_OPTION_GLUE_PATTERN = re.compile( - r""" - \b(?:and|or|both|y|e|ou|und|et|plus)\b - | - \b(?:as\ well\ as|together\ with|followed\ by|correct\ choices?\s+are|choices?\s+are)\b - | - [,:;/&+\-|] - """, - re.IGNORECASE | re.VERBOSE, + rf"(?:{_MULTI_OPTION_CONNECTOR_WORD_PATTERN}|[,:;/&+\-|])", + re.IGNORECASE, ) -_MULTIPLE_OPTION_LED_SCAN_MAX_CHARS = 8000 +_MULTIPLE_OPTION_LED_SCAN_MAX_CHARS = 10_000 @_trace_scan_perf -def _get_sentence_containing_match(text: str, match: re.Match) -> str: +def _get_sentence_containing_match(text: str, match: re.Match) -> tuple[int, int, int, int]: """Return (sentence_start, sentence_end, match_start, match_end) in the original text.""" if getattr(match.re, "groupindex", None) and "opt" in match.re.groupindex: match_start, match_end = match.span("opt") @@ -312,40 +323,59 @@ def _get_sentence_containing_match(text: str, match: re.Match) -> str: return sentence_start, sentence_end, match_start, match_end -@_trace_scan_perf -def _negated_near(text: str, match: re.Match) -> bool: - """Check for negation that appears before the match within the same sentence. +@dataclass +class _SentenceMatchContext: + """Sentence-local context around a matched option or answer-text span.""" - This is used for answer_text matching to avoid blocking answers that legitimately contain - words like "not" (e.g., "do not resuscitate") while still blocking cases like - "not ". - """ - sentence_start, sentence_end, match_start, _match_end = _get_sentence_containing_match(text, match) - prefix = text[sentence_start:match_start] - return bool(NEGATION_BEFORE_MATCH_PATTERN.search(prefix)) + prefix: str + suffix: str + token: Optional[str] @_trace_scan_perf -def _negative_after_option(text: str, match: re.Match) -> bool: - """Check if an option token is immediately followed by negative context like 'C is incorrect'.""" - _sentence_start, sentence_end, _match_start, match_end = _get_sentence_containing_match(text, match) - suffix = text[match_end:sentence_end] - return bool(NEGATIVE_AFTER_OPTION_PATTERN.search(suffix)) +def _match_token(match: re.Match) -> Optional[str]: + """Extract and normalize the option token captured by a regex match.""" + if getattr(match.re, "groupindex", None) and "opt" in match.re.groupindex: + return _norm_letter(match.group("opt")) + try: + return _norm_letter(match.group(1)) + except Exception: + return None @_trace_scan_perf -def _contradicted_by_later_option(text: str, match: re.Match) -> bool: - """Check for same-sentence corrections like 'C, but D is correct' or 'C rather than D'.""" - _sentence_start, sentence_end, _match_start, match_end = _get_sentence_containing_match(text, match) - suffix = text[match_end:sentence_end] - current = _norm_letter( - match.group("opt") if getattr(match.re, "groupindex", None) and "opt" in match.re.groupindex else match.group(1) +def _sentence_match_context(text: str, match: re.Match) -> _SentenceMatchContext: + """Return the same-sentence prefix, suffix, and normalized token for a regex match.""" + sentence_start, sentence_end, match_start, match_end = _get_sentence_containing_match(text, match) + return _SentenceMatchContext( + prefix=text[sentence_start:match_start], + suffix=text[match_end:sentence_end], + token=_match_token(match), ) - later = CONTRAST_PATTERN.search(suffix) + + +@_trace_scan_perf +def _match_is_negated(context: _SentenceMatchContext) -> bool: + """Return True when a negation phrase appears before the match in the same sentence.""" + return bool(NEGATION_BEFORE_MATCH_PATTERN.search(context.prefix)) + + +@_trace_scan_perf +def _match_has_negative_suffix(context: _SentenceMatchContext) -> bool: + """Return True when the match is immediately followed by rejecting language.""" + return bool(NEGATIVE_AFTER_OPTION_PATTERN.search(context.suffix)) + + +@_trace_scan_perf +def _match_is_contradicted(context: _SentenceMatchContext) -> bool: + """Return True when a later contrast in the sentence points to a different option.""" + if context.token is None: + return False + later = CONTRAST_PATTERN.search(context.suffix) if not later: return False contrasted = _norm_letter(later.group(1)) - return contrasted is not None and contrasted != current + return contrasted is not None and contrasted != context.token @_trace_scan_perf @@ -376,12 +406,15 @@ def _last_nonempty_line(text: str) -> str: return "" +@_trace_scan_perf def _option_candidate_invalid(text: str, match: re.Match) -> bool: """Return True if an option-like match is negated or contradicted in local context.""" - return _negated_near(text, match) or _negative_after_option(text, match) or _contradicted_by_later_option(text, match) + context = _sentence_match_context(text, match) + return _match_is_negated(context) or _match_has_negative_suffix(context) or _match_is_contradicted(context) -def _ignore_prior_option_like_token(prefix: str, prior_match: re.Match) -> bool: +@_trace_scan_perf +def _is_harmless_prefix_option_token(prefix: str, prior_match: re.Match) -> bool: """Ignore harmless single-letter artifacts before a terminal final-clause answer. This is limited to natural-language cases like: @@ -400,126 +433,123 @@ def _ignore_prior_option_like_token(prefix: str, prior_match: re.Match) -> bool: @_trace_scan_perf -def _extract_terminal_option_line(line: str) -> Optional[str]: - """Extract a standalone option token from the last line.""" - if not line: - return None - - match = TERMINAL_OPTION_LINE_PATTERN.fullmatch(line) - if match: - predicted = _norm_letter(match.group("opt")) - if predicted is None: - return None +def _has_connector_between(text: str, max_words: int = 5) -> bool: + """Return True when a short span looks like connector text between option tokens.""" + text = text.strip() + return bool(text) and len(text.split()) <= max_words and bool(_MULTI_OPTION_CONNECTOR_RE.search(text)) - tokens = list(TOKEN_PATTERN.finditer(line)) - if len(tokens) != 1: - return None - token_match = tokens[0] - if _option_candidate_invalid(line, token_match): - return None +@_trace_scan_perf +def _normalized_option_matches(text: str) -> list[tuple[str, re.Match]]: + """Return normalized option tokens paired with their regex matches in order.""" + matches: list[tuple[str, re.Match]] = [] + for match in TOKEN_PATTERN.finditer(text): + token = _norm_letter(match.group(1)) + if token is not None: + matches.append((token, match)) + return matches + + +@lru_cache(maxsize=64) +def _prefix_pattern(prefix_norm: str) -> re.Pattern: + """Compile and cache the anchored prefix regex for a normalized answer prefix.""" + flexible_prefix = re.escape(prefix_norm).replace(r"\ ", r"\s+") + return re.compile( + rf"{flexible_prefix}\s*[:\-–—]?\s*(?:is\s*)?(?Pnot\s+|isn['’]t\s+)?\(?\s*(?P[A-Za-z]|\d{{1,2}})\s*[\)\.:]?(?![\w+\-/])", + re.IGNORECASE, + ) - return predicted - leading_match = LEADING_OPTION_PATTERN.match(line) - if not leading_match or _is_compact_multi_option_list(line): +@_trace_scan_perf +def _extract_standalone_terminal_option(region: str) -> Optional[str]: + """Extract a standalone terminal token like ``C`` or ``(C)`` from a region.""" + match = TERMINAL_OPTION_LINE_PATTERN.fullmatch(region) + if not match: return None - predicted = _norm_letter(leading_match.group(1)) + predicted = _norm_letter(match.group("opt")) if predicted is None: return None - if _option_candidate_invalid(line, leading_match): + tokens = list(TOKEN_PATTERN.finditer(region)) + if len(tokens) != 1 or _option_candidate_invalid(region, tokens[0]): return None - return predicted @_trace_scan_perf -def _extract_short_final_clause_option(text: str, max_words: int = 12) -> Optional[str]: - """Extract a terminal option token from a short final clause like 'I think it's C'.""" - clause = _tail_region(text).strip() - if not clause or len(clause.split()) > max_words: - return None - if _is_compact_multi_option_list(clause): +def _extract_leading_terminal_option(region: str) -> Optional[str]: + """Extract a leading-option form like ``C. text`` from a region.""" + leading_match = LEADING_OPTION_PATTERN.match(region) + if not leading_match: return None - match = FINAL_CLAUSE_TERMINAL_OPTION_RE.search(clause) - if not match: + predicted = _norm_letter(leading_match.group(1)) + if predicted is None or _option_candidate_invalid(region, leading_match): return None + return predicted - token_match = match - if _option_candidate_invalid(clause, token_match): + +@_trace_scan_perf +def _extract_final_clause_terminal_option(region: str) -> Optional[str]: + """Extract a final-clause token like ``I think it's C`` from a short region.""" + match = FINAL_CLAUSE_TERMINAL_OPTION_RE.search(region) + if not match or _option_candidate_invalid(region, match): return None - # Reject short clauses that contain another meaningful option token before the final token. - prefix = clause[:token_match.start()] - for prior_match in TOKEN_PATTERN.finditer(prefix): - token = _norm_letter(prior_match.group(1)) - if token is None: + prefix = region[: match.start()] + for _token, prior_match in _normalized_option_matches(prefix): + if _is_harmless_prefix_option_token(prefix, prior_match): continue + return None - if _ignore_prior_option_like_token(prefix, prior_match): - continue + return _norm_letter(match.group("opt")) + +@_trace_scan_perf +def _extract_terminal_option_line(line: str) -> Optional[str]: + """Extract a standalone option token from the last line.""" + if not line or _is_compact_multi_option_list(line): return None + predicted = _extract_standalone_terminal_option(line) + if predicted is not None: + return predicted + return _extract_leading_terminal_option(line) - return _norm_letter(match.group("opt")) +@_trace_scan_perf +def _extract_short_final_clause_option(text: str, max_words: int = 12) -> Optional[str]: + """Extract a terminal option token from a short final clause like 'I think it's C'.""" + clause = _tail_region(text).strip() + if not clause or len(clause.split()) > max_words or _is_compact_multi_option_list(clause): + return None + return _extract_final_clause_terminal_option(clause) -# Connector words that join two option tokens into a multi-answer phrase. -# Catches "A and C", "A to C", "A through C", "neither A nor C", etc. -_MULTI_ANSWER_CONNECTOR_WORD_RE = re.compile( - r"\b(?:and|or|nor|to|through|then|plus)\b" - r"|\bas\s+well\s+as\b" - r"|\btogether\s+with\b" - r"|\bfollowed\s+by\b", - re.IGNORECASE, -) @_trace_scan_perf def _anchored_match_in_multi_answer_phrase(text: str, matches: list[re.Match], idx: int) -> bool: """Return True if anchored match *idx* is part of a local multi-answer phrase.""" match = matches[idx] - current = _norm_letter(match.group("opt")) + current = _match_token(match) if current is None: return False if idx > 0: between = text[matches[idx - 1].end() : match.start()] - if len(between.split()) <= 5 and _MULTI_ANSWER_CONNECTOR_WORD_RE.search(between): + if _has_connector_between(between): return True if idx < len(matches) - 1: between = text[match.end() : matches[idx + 1].start()] - if len(between.split()) <= 5 and _MULTI_ANSWER_CONNECTOR_WORD_RE.search(between): + if _has_connector_between(between): return True - pre_text = text[max(0, match.start() - 20) : match.start()] - if bool( - re.search( - r"(?]*" - r"[\s,;]*" - r"(?:and|or|nor|to|through|then|plus)\s*$", - pre_text, - re.IGNORECASE, - ) - ): - return True - sentence_start, sentence_end, match_start, match_end = _get_sentence_containing_match(text, match) sentence = text[sentence_start:sentence_end] local_match_start = match_start - sentence_start local_match_end = match_end - sentence_start - sentence_tokens = [] - for token_match in TOKEN_PATTERN.finditer(sentence): - token = _norm_letter(token_match.group(1)) - if token is None: - continue - sentence_tokens.append((token, token_match)) - - for token, token_match in sentence_tokens: + for token, token_match in _normalized_option_matches(sentence): if token == current: continue between = "" @@ -529,7 +559,7 @@ def _anchored_match_in_multi_answer_phrase(text: str, matches: list[re.Match], i between = sentence[local_match_end : token_match.start()] if not between: continue - if len(between.split()) <= 5 and _MULTI_ANSWER_CONNECTOR_WORD_RE.search(between): + if _has_connector_between(between): return True return False @@ -539,8 +569,7 @@ def _anchored_match_in_multi_answer_phrase(text: str, matches: list[re.Match], i def _is_compact_multi_option_list(text: str) -> bool: """Return True for short multi-option tails like 'A, C' or '> **A** and C'.""" text = (text or "").strip() - matches = list(TOKEN_PATTERN.finditer(text)) - if len(matches) < 2: + if len(list(TOKEN_PATTERN.finditer(text))) < 2: return False residue = TOKEN_PATTERN.sub(" ", text) @@ -629,33 +658,41 @@ def _result( llm_answer = _strip_tex(llm_answer) answer_text = _strip_tex(answer_text) - llm_answer_original = llm_answer - - # Normalize: casefold only (preserve whitespace structure for sentence detection) - llm_answer = normalize_for_structure(llm_answer) + # Keep two views of the response: + # - structural_text preserves original spacing for sentence/line-sensitive heuristics + # - normalized_answer casefolds and normalizes punctuation for anchor/text matching + structural_text = llm_answer + normalized_answer = normalize_for_structure(llm_answer) answer_letter = _norm_letter(answer_letter) answer_text = normalize_for_match(answer_text or "") if answer_letter is None: raise ValueError(f"Invalid answer_letter '{answer_letter=}'. Must be a single letter or digit string.") + # Once we see any explicit option selection of the right token kind, we stop lower-confidence + # fallbacks from overriding it with a tail token or answer-text mention. explicit_choice_found = False # Strategy 1: Only answer letter anywhere (without anchoring) - if answer_letter == _norm_letter(llm_answer): - return _result(True, "direct_answer", llm_answer, answer_letter, return_details) + if answer_letter == _norm_letter(normalized_answer): + return _result(True, "direct_answer", normalized_answer, answer_letter, return_details) + # A response that begins like "B. ..." gets special handling: we may disable both the leading + # shortcut and later tail/text fallbacks if it actually looks like multiple labeled options. + leading_match = LEADING_OPTION_PATTERN.match(structural_text) multiple_option_led_sentences = False - leading_match = LEADING_OPTION_PATTERN.match(llm_answer_original) + if leading_match: - # Only pay for the multi-sentence scan when the response actually starts like a - # leading-option answer. For very large payloads, disable the leading-option shortcut - # rather than scanning the whole response. - if len(llm_answer_original) <= _MULTIPLE_OPTION_LED_SCAN_MAX_CHARS: - multiple_option_led_sentences = _contains_multiple_option_led_sentences(llm_answer_original, answer_letter) - if multiple_option_led_sentences: - leading_match = None - else: + # Only pay for the additional answer scan when the payload actually starts with a leading + # option pattern; otherwise we leave this guard disabled for the cheaper later paths. + multiple_option_led_sentences = len( + structural_text + ) <= _MULTIPLE_OPTION_LED_SCAN_MAX_CHARS and _contains_multiple_option_led_sentences( + structural_text, answer_letter + ) + # If the response looks like multiple labeled answer statements, do not treat the first + # label as the chosen answer. + if multiple_option_led_sentences: leading_match = None # Strategy 2: Accept leading option token like "B. answer ..." @@ -671,25 +708,22 @@ def _result( if prefix: prefix_norm = normalize_for_structure(prefix).strip() if prefix_norm: - flexible_prefix = re.escape(prefix_norm).replace(r"\ ", r"\s+") - prefix_pattern = re.compile( - rf"{flexible_prefix}\s*[:\-–—]?\s*(?:is\s*)?(?Pnot\s+|isn['’]t\s+)?\(?\s*(?P[A-Za-z]|\d{{1,2}})\s*[\)\.:]?(?![\w+\-/])", - re.IGNORECASE, - ) - prefix_matches = list(prefix_pattern.finditer(llm_answer)) - - anchored_matches = prefix_matches if prefix_matches else list(ANCHOR_PATTERN.finditer(llm_answer)) + prefix_matches = list(_prefix_pattern(prefix_norm).finditer(normalized_answer)) + + anchored_matches = prefix_matches if prefix_matches else list(ANCHOR_PATTERN.finditer(normalized_answer)) if anchored_matches and answer_letter: + # Walk anchored matches from the end so later corrections like "Answer: B ... final answer: C" + # resolve to the last non-negated, non-multi-answer anchor. for idx in range(len(anchored_matches) - 1, -1, -1): match = anchored_matches[idx] - predicted = _norm_letter(match.group("opt")) + predicted = _match_token(match) if predicted is None: continue if match.group("neg") is not None: continue - if _contradicted_by_later_option(llm_answer, match): + if _match_is_contradicted(_sentence_match_context(normalized_answer, match)): continue - if _anchored_match_in_multi_answer_phrase(llm_answer, anchored_matches, idx): + if _anchored_match_in_multi_answer_phrase(normalized_answer, anchored_matches, idx): continue if _token_kind_matches_answer_letter(predicted, answer_letter): @@ -699,18 +733,21 @@ def _result( break # Strategy 4: Parse a terminal option line or short final clause near the end. + # Tail parsing is lower confidence than explicit anchors, so it only runs when no explicit + # option token has already been observed. if not explicit_choice_found and answer_letter and not multiple_option_led_sentences: - predicted = _extract_terminal_option_line(_last_nonempty_line(llm_answer)) + predicted = _extract_terminal_option_line(_last_nonempty_line(normalized_answer)) if predicted == answer_letter: return _result(True, "last_token", predicted, answer_letter, return_details) - predicted = _extract_short_final_clause_option(llm_answer) + predicted = _extract_short_final_clause_option(normalized_answer) if predicted == answer_letter: return _result(True, "last_token", predicted, answer_letter, return_details) # Strategy 5: Exact answer text match if there's no explicit choice found # Only search at beginning and end to avoid matching reasoning in the middle if accept_answer_text and answer_text and not explicit_choice_found: + # A multi-option-led payload is too ambiguous for answer-text fallback. if multiple_option_led_sentences: return _result(False, "none", None, None, return_details) @@ -718,7 +755,7 @@ def _result( answer_tokens = len(answer_text.split()) buffer_tokens = answer_tokens + 15 # Extra tokens for preamble like "The answer is:" - llm_tokens = llm_answer.split() + llm_tokens = normalized_answer.split() beginning_tokens = llm_tokens[:buffer_tokens] end_tokens = llm_tokens[-buffer_tokens:] if len(llm_tokens) > buffer_tokens else llm_tokens @@ -726,34 +763,29 @@ def _result( beginning_region = " ".join(beginning_tokens) end_region = " ".join(end_tokens) - # Make answer_text flexible for whitespace variations + # First try the normalized answer text directly, then a slightly looser punctuation-tolerant + # variant, but only in the beginning/end windows rather than the full reasoning trace. flexible_answer = re.escape(answer_text).replace(r"\ ", r"\s+") - pattern = re.compile(rf"(? Date: Sun, 19 Apr 2026 12:58:34 -0400 Subject: [PATCH 29/29] refactor again --- .../rewards/multiple_choice_accuracy.py | 1313 +++++++++-------- tests/test_mcq_accuracy.py | 103 +- 2 files changed, 771 insertions(+), 645 deletions(-) diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index e4a76b75..cdee4780 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -1,35 +1,37 @@ -""" -LLM multiple-choice question accuracy reward. +"""MCQ raw-text grading with tail-authoritative long-response handling.""" -Main use case: Handle models that either return the letter/number (preferred) -or return the entire answer text verbatim (fallback). - -Supports chain-of-thought by prioritizing anchored patterns like "answer is X" -before falling back to last token or text matching. Attempts to recognize -negations to avoid false positives (e.g., "the answer is not C"). -""" +from __future__ import annotations import re -import os -import sys -import time import unicodedata from dataclasses import dataclass from functools import lru_cache -from functools import wraps from typing import Optional +# Responses longer than this switch into tail long-mode behavior. +LONG_RESPONSE_THRESHOLD_CHARS = 4_000 +# Long-mode explicit-answer and answer-text scans are limited to this terminal slice. +TERMINAL_WINDOW_CHARS = 4_000 +# The looser last-token fallback only inspects this shorter tail inside the terminal slice. +STRONG_TAIL_WINDOW_CHARS = 2_000 +# Local ambiguity checks can look this far backward from a candidate. +LOCAL_CONTEXT_BEFORE_CHARS = 160 +# Local ambiguity checks can look this far forward from a candidate. +LOCAL_CONTEXT_AFTER_CHARS = 240 +# Tail-choice fallback is only allowed when the trailing segment is this short or shorter. +TAIL_CHOICE_MAX_WORDS = 16 + _UNICODE_PUNCT_TRANSLATIONS = str.maketrans( { - "\u00a0": " ", # no-break space - "\u2010": "-", # hyphen - "\u2011": "-", # non-breaking hyphen - "\u2012": "-", # figure dash - "\u2013": "-", # en dash - "\u2014": "-", # em dash - "\u2015": "-", # horizontal bar - "\u2212": "-", # minus sign + "\u00a0": " ", + "\u2010": "-", + "\u2011": "-", + "\u2012": "-", + "\u2013": "-", + "\u2014": "-", + "\u2015": "-", + "\u2212": "-", "\u2018": "'", "\u2019": "'", "\u201c": '"', @@ -38,122 +40,128 @@ ) _WHITESPACE_RE = re.compile(r"\s+") -_INNER_PUNCT_SPACING_RE = re.compile(r"\s*([()\[\]{}.,;:])\s*") -_TRAILING_TERMINAL_PUNCT_RE = re.compile(r"(?<=\w)[.!?,;:]+$") _LIKELY_TEX_RE = re.compile(r"\\[A-Za-z]+|\\[$\\()\\[\\]{}]|[$]") - - -def _mcq_perf_trace_enabled() -> bool: - """Return whether lightweight MCQ performance tracing is enabled.""" - return os.getenv("MEDARC_MCQ_PERF_TRACE", "").strip().lower() in {"1", "true", "yes", "on"} - - -def _mcq_perf_trace_min_seconds() -> float: - """Return the minimum elapsed time required before a helper emits a trace line.""" - raw = os.getenv("MEDARC_MCQ_PERF_TRACE_MIN_MS", "").strip() - if not raw: - return 0.0 - try: - return max(float(raw) / 1000.0, 0.0) - except ValueError: - return 0.0 - - -def _mcq_perf_trace_summary(args: tuple, kwargs: dict) -> str: - """Build a compact summary string for performance trace logging.""" - parts: list[str] = [] - for idx, value in enumerate(args[:3]): - if isinstance(value, str): - parts.append(f"arg{idx}_len={len(value)}") - elif isinstance(value, re.Match): - try: - start, end = value.span() - parts.append(f"arg{idx}_span={start}:{end}") - except Exception: - parts.append(f"arg{idx}=match") - elif isinstance(value, list): - parts.append(f"arg{idx}_len={len(value)}") - if "answer_letter" in kwargs and isinstance(kwargs["answer_letter"], str): - parts.append(f"answer_letter={kwargs['answer_letter']!r}") - return " ".join(parts) - - -def _trace_scan_perf(func): - """Wrap a helper so it can emit elapsed-time traces when tracing is enabled.""" - - @wraps(func) - def wrapper(*args, **kwargs): - """Execute the wrapped helper and optionally log its runtime.""" - if not _mcq_perf_trace_enabled(): - return func(*args, **kwargs) - - started = time.perf_counter() - try: - return func(*args, **kwargs) - finally: - elapsed = time.perf_counter() - started - if elapsed >= _mcq_perf_trace_min_seconds(): - summary = _mcq_perf_trace_summary(args, kwargs) - print( - f"[mcq-perf] {func.__name__} elapsed_ms={elapsed * 1000:.3f}" + (f" {summary}" if summary else ""), - file=sys.stderr, - flush=True, - ) - - return wrapper +_THINK_OPEN_RE = re.compile(r"<\s*think\b[^>]*>", re.IGNORECASE) +_THINK_CLOSE_RE = re.compile(r"", re.IGNORECASE) +_ANSWER_TAG_RE = re.compile(r"", re.IGNORECASE) + +# Any standalone option-like token. This is intentionally broad and gets filtered by +# local ambiguity checks before it can count as a chosen answer. +_OPTION_TOKEN_RE = re.compile(r"(?[A-Za-z]|\d{1,2})(?![\w+\-/])", re.IGNORECASE) +# Anchored cues that usually indicate the model is committing to a final answer. +_ANCHOR_RE = re.compile( + r"(?P, return everything after the last closing tag. - - If there is an unclosed with no closing tag, treat the output as missing - a final-answer region. - - Otherwise, return the full response. - """ - text = completion_text or "" +def _result( + is_correct: bool, + method: str, + predicted: Optional[str], + actual: Optional[str], + return_details: bool, +) -> bool | MCQAccuracyResult: + """Return either a bare boolean or the structured grading result.""" + if not return_details: + return is_correct + return MCQAccuracyResult( + is_correct=is_correct, + method=method, + matched_answer=predicted, + correct_answer=actual, + ) - closes = list(_THINK_CLOSE_RE.finditer(text)) - if closes: - return text[closes[-1].end() :].lstrip() +def _remove_think_tags(text: str) -> str: + """Drop internal reasoning and keep only the answer region after the last `` tag.""" + text = text or "" + last_close_end: Optional[int] = None + for match in _THINK_CLOSE_RE.finditer(text): + last_close_end = match.end() + if last_close_end is not None: + return text[last_close_end:].lstrip() if _THINK_OPEN_RE.search(text): return "" - return text -# Anchored patterns like "final answer: C" or "the answer is D" -ANCHOR_PATTERN = re.compile( - r"(?:\bfinal\s+answer\b|\banswer\b|\bans\b|\bchoice\b|\boption\b|\bselected\b|\bi\s+choose\b|\bi\s+pick\b|\btherefore\b|\bthus\b|\bso\b|\bconclusion\b|\bin\s+conclusion\b|\bmost\s+likely\b|\bbest[-\s]+supported\s+answer\b|)\s*" - r"[:\-–—]?\s*(?:is\s*)?(?Pnot\s+|isn['’]t\s+)?" - r"(?:[*_`~]+\s*)*" # allow markdown wrappers before the option - r"[\(\[\{<【]*\s*(?P[A-Za-z]|\d{1,2})\s*[\)\]\}>】]*" # option token, possibly wrapped - r"\s*[\)\.:]?\s*" # optional delimiter (e.g., 'B.' or 'B)') - r"(?:[*_`~]+\s*)*" # allow markdown wrappers after the option - r"(?![\w+\-/])", - re.IGNORECASE, -) - - -# Any letter/number token that looks like an option -TOKEN_PATTERN = re.compile( - r"(?】]*[\)\.:]?(?![\w+\-/])", - re.IGNORECASE, -) - -_LEADING_OPTION_PATTERN_BODY = ( - r"\s*(?:>\s*)?(?:(?:[-*+]\s+)|(?:\d{1,3}[.)]\s+))?\s*" - r"(?:[*_`~]+)?\s*\(?\s*([A-Za-z]|\d{1,2})\s*" - r"(?:[\)\.:]\s*\)?\s*(?:[*_`~]+)?\s*|(?=\s*(?:\(|[-–—])))(?!\w)" -) - -# Leading option token like "B. Answer text" or "C) ..." at the start of the response -LEADING_OPTION_PATTERN = re.compile(rf"^{_LEADING_OPTION_PATTERN_BODY}", re.IGNORECASE) -# Same pattern without ^ so we can match from sentence offsets without slicing text. -SENTENCE_LEADING_OPTION_PATTERN = re.compile(_LEADING_OPTION_PATTERN_BODY, re.IGNORECASE) - -# Standalone final-line option token like "C", "(C)", or "\boxed{C}". -TERMINAL_OPTION_LINE_PATTERN = re.compile( - r"^\s*(?:>\s*)?(?:(?:[-*+]\s+)|(?:\d{1,3}[.)]\s+))?\s*" - r"(?:\\boxed\{\s*)?(?:\s*)?" - r"[\(\[\{<【]*\s*[*_`~]*\s*(?P[A-Za-z]|\d{1,2})\s*[*_`~]*[\)\]\}>】]*" - r"\s*(?:\s*)?(?:\}\s*)?\s*[.!?]?\s*$", - re.IGNORECASE, -) +def _strip_outer_wrappers(text: str) -> str: + """Peel simple answer wrappers like markdown, quotes, brackets, or `` tags.""" + text = (text or "").strip() + changed = True + while text and changed: + changed = False + lowered = text.lower() + + # Strip explicit answer wrappers before more generic marker peeling. + if lowered[:8] == "" and lowered[-9:] == "": + text = text[8:-9].strip() + changed = True + continue -FINAL_CLAUSE_TERMINAL_OPTION_RE = re.compile( - r"(?[A-Za-z]|\d{1,2})\s*[*_`~]*[\)\]\}>】]*\s*[.!?]?\s*$", - re.IGNORECASE, -) + if lowered[:7] == "\\boxed{" and text.endswith("}"): + text = text[7:-1].strip() + changed = True + continue -# Negation/correction phrases that immediately precede an option or answer text -NEGATION_BEFORE_MATCH_PATTERN = re.compile( - r"(?:\bnot\b|\bisn['’]t\b|\baren['’]t\b|\bwasn['’]t\b|\bweren['’]t\b|\bincorrect\b|\bwrong\b|\bfalse\b|\bexcept(?:\s+for)?\b|\brather\s+than\b)(?:\W+\w+){0,3}\W*$", - re.IGNORECASE, -) + for marker in _OUTER_MARKERS: + if text.startswith(marker) and text.endswith(marker) and len(text) > len(marker) * 2: + text = text[len(marker) : -len(marker)].strip() + changed = True + break + if changed: + continue -# Negative-context phrases that indicate an option mention is NOT a selected answer -NEGATIVE_AFTER_OPTION_PATTERN = re.compile( - r"^\s*(?:is|are|was|were)\s+(?:incorrect|wrong|false|not\s+correct)\b|^\s*not\s+correct\b", - re.IGNORECASE, -) + for opener, closer in _OUTER_WRAPPER_PAIRS: + if text.startswith(opener) and text.endswith(closer) and len(text) > len(opener) + len(closer): + text = text[len(opener) : -len(closer)].strip() + changed = True + break -CONTRAST_PATTERN = re.compile( - r"\b(?:but|however|instead(?!\s+of\b))\b" - r".{0,40}?" - r"(? tuple[int, int]: + """Return the line boundaries that contain the span `[start, end)`.""" + line_start = text.rfind("\n", 0, start) + 1 + line_end = text.find("\n", end) + if line_end == -1: + line_end = len(text) + return line_start, line_end + + +def _previous_nonempty_line_start(text: str, line_start: int) -> int: + """Walk backward to the previous non-empty line start, if one exists.""" + cursor = line_start + while cursor > 0: + prev_end = cursor - 1 + prev_start = text.rfind("\n", 0, prev_end) + 1 + if text[prev_start:prev_end].strip(): + return prev_start + cursor = prev_start + return line_start + + +def _next_nonempty_line_end(text: str, line_end: int) -> int: + """Walk forward to the next non-empty line end, if one exists.""" + cursor = line_end + while cursor < len(text): + next_start = cursor + 1 if cursor < len(text) and text[cursor] == "\n" else cursor + next_end = text.find("\n", next_start) + if next_end == -1: + next_end = len(text) + if text[next_start:next_end].strip(): + return next_end + if next_end == len(text): + break + cursor = next_end + return line_end -_MULTIPLE_OPTION_LED_SCAN_MAX_CHARS = 10_000 +def _local_context(text: str, start: int, end: int) -> tuple[str, int, int]: + """Return a bounded local region around a candidate plus its relative offsets.""" + line_start, line_end = _line_bounds(text, start, end) + context_start = _previous_nonempty_line_start(text, line_start) + context_end = _next_nonempty_line_end(text, line_end) + # Prefer whole nearby lines, then cap to fixed windows so long CoTs stay cheap. + context_start = max(context_start, start - LOCAL_CONTEXT_BEFORE_CHARS) + context_end = min(context_end, end + LOCAL_CONTEXT_AFTER_CHARS) + return text[context_start:context_end], start - context_start, end - context_start -@_trace_scan_perf -def _get_sentence_containing_match(text: str, match: re.Match) -> tuple[int, int, int, int]: - """Return (sentence_start, sentence_end, match_start, match_end) in the original text.""" - if getattr(match.re, "groupindex", None) and "opt" in match.re.groupindex: - match_start, match_end = match.span("opt") - else: - try: - match_start, match_end = match.span(1) - except Exception: - match_start, match_end = match.span() - boundaries_before = [m.end() for m in SENTENCE_BOUNDARY.finditer(text[:match_start])] - boundaries_after = [m.start() for m in SENTENCE_BOUNDARY.finditer(text[match_end:])] +def _candidate_is_negated(context: str, rel_start: int, rel_end: int) -> bool: + """Detect local negation patterns that should invalidate a candidate option.""" + prefix = context[max(0, rel_start - 48) : rel_start] + suffix = context[rel_end : min(len(context), rel_end + 40)] + prefix = normalize_for_match(prefix).rstrip(" ([{<【") + suffix = normalize_for_match(suffix) - sentence_start = boundaries_before[-1] if boundaries_before else 0 - sentence_end = match_end + boundaries_after[0] if boundaries_after else len(text) - return sentence_start, sentence_end, match_start, match_end + if _NEGATION_PREFIX_RE.search(prefix): + return True + if prefix.endswith("rather than") or prefix.endswith("except"): + return True + if "wrong diagnosis is" in prefix[-32:] or "incorrect diagnosis is" in prefix[-32:]: + return True + for prefix_text in _AFTER_REJECTION_PREFIXES: + if suffix.startswith(prefix_text): + return True -@dataclass -class _SentenceMatchContext: - """Sentence-local context around a matched option or answer-text span.""" + return False - prefix: str - suffix: str - token: Optional[str] +def _looks_like_option_connector(between_norm: str) -> bool: + """Return True when the text between two options is just list/connector glue.""" + between_norm = between_norm.strip() + if not between_norm: + return True -@_trace_scan_perf -def _match_token(match: re.Match) -> Optional[str]: - """Extract and normalize the option token captured by a regex match.""" - if getattr(match.re, "groupindex", None) and "opt" in match.re.groupindex: - return _norm_letter(match.group("opt")) - try: - return _norm_letter(match.group(1)) - except Exception: - return None + between_norm = re.sub(r"\b(?:option|choice)\b", " ", between_norm).strip() + stripped = between_norm.strip(",;:./&+()[]{}<>-\\ ") + if not stripped: + return True + return stripped in _COMPACT_OPTION_CONNECTORS + + +def _is_harmless_option_match(text: str, match: re.Match[str]) -> bool: + """Ignore stray single-letter matches like pronoun `I` or apostrophe fragments.""" + token = match.group("opt").casefold() + start = match.start("opt") + end = match.end("opt") + + if token == "i": + before = text[start - 1] if start > 0 else " " + after = text[end] if end < len(text) else " " + if before in {" ", "\n", "\t", ",", ";", ".", "(", "["} and after in { + " ", + "\n", + "\t", + ",", + ";", + ".", + "!", + "?", + ")", + "]", + }: + return True + if token == "i" and start == 0: + return True + if start > 0 and text[start - 1] in {"'", "’"}: + return True + if end < len(text) and text[end] in {"'", "’"}: + return True + return False -@_trace_scan_perf -def _sentence_match_context(text: str, match: re.Match) -> _SentenceMatchContext: - """Return the same-sentence prefix, suffix, and normalized token for a regex match.""" - sentence_start, sentence_end, match_start, match_end = _get_sentence_containing_match(text, match) - return _SentenceMatchContext( - prefix=text[sentence_start:match_start], - suffix=text[match_end:sentence_end], - token=_match_token(match), - ) +def _candidate_has_local_competing_option( + context: str, rel_start: int, rel_end: int, token: str, answer_letter: str +) -> bool: + """Reject candidates that are locally entangled with another option token.""" + selected_span = (rel_start, rel_end) + for match in _OPTION_TOKEN_RE.finditer(context): + if _is_harmless_option_match(context, match): + continue + other = _norm_option(match.group("opt")) + if other is None or not _option_kind_matches(other, answer_letter) or other == token: + continue -@_trace_scan_perf -def _match_is_negated(context: _SentenceMatchContext) -> bool: - """Return True when a negation phrase appears before the match in the same sentence.""" - return bool(NEGATION_BEFORE_MATCH_PATTERN.search(context.prefix)) + if match.end() <= selected_span[0]: + between = context[match.end() : selected_span[0]] + elif selected_span[1] <= match.start(): + between = context[selected_span[1] : match.start()] + else: + continue + between_norm = normalize_for_match(between) + if len(between_norm) > 24: + continue + # Treat only very short glue like commas, "and", or "or" as true ambiguity. + if _looks_like_option_connector(between_norm): + return True -@_trace_scan_perf -def _match_has_negative_suffix(context: _SentenceMatchContext) -> bool: - """Return True when the match is immediately followed by rejecting language.""" - return bool(NEGATIVE_AFTER_OPTION_PATTERN.search(context.suffix)) + return False -@_trace_scan_perf -def _match_is_contradicted(context: _SentenceMatchContext) -> bool: - """Return True when a later contrast in the sentence points to a different option.""" - if context.token is None: - return False - later = CONTRAST_PATTERN.search(context.suffix) - if not later: +def _candidate_is_contradicted(context: str, rel_end: int, token: str, answer_letter: str) -> bool: + """Reject candidates that are immediately revised to a different option.""" + suffix = normalize_for_match(context[rel_end : min(len(context), rel_end + 80)]) + if not any(hint in suffix for hint in _CONTRAST_HINTS): return False - contrasted = _norm_letter(later.group(1)) - return contrasted is not None and contrasted != context.token - - -@_trace_scan_perf -def _tail_region(text: str, max_tokens: int = 64) -> str: - """Return a short tail slice (last sentence/line) to reduce option-token noise.""" - boundaries = list(SENTENCE_BOUNDARY.finditer(text)) - tail = text[boundaries[-1].end() :] if boundaries else text - tail = tail.strip() - - if not tail: - for line in reversed(text.splitlines()): - if line.strip(): - tail = line.strip() - break - - tokens = tail.split() - if len(tokens) > max_tokens: - tail = " ".join(tokens[-max_tokens:]) - return tail - - -@_trace_scan_perf -def _last_nonempty_line(text: str) -> str: - """Return the last non-empty line, if any.""" - for line in reversed((text or "").splitlines()): - if line.strip(): - return line.strip() - return "" + for match in _OPTION_TOKEN_RE.finditer(suffix): + other = _norm_option(match.group("opt")) + if other is None or not _option_kind_matches(other, answer_letter): + continue + if other != token: + return True + return False -@_trace_scan_perf -def _option_candidate_invalid(text: str, match: re.Match) -> bool: - """Return True if an option-like match is negated or contradicted in local context.""" - context = _sentence_match_context(text, match) - return _match_is_negated(context) or _match_has_negative_suffix(context) or _match_is_contradicted(context) +def _candidate_is_valid(text: str, candidate: _Candidate, answer_letter: str) -> bool: + """Apply the local negation, ambiguity, and contradiction filters to a candidate.""" + context, rel_start, rel_end = _local_context(text, candidate.start, candidate.end) + return not ( + _candidate_is_negated(context, rel_start, rel_end) + or _candidate_has_local_competing_option(context, rel_start, rel_end, candidate.token, answer_letter) + or _candidate_is_contradicted(context, rel_end, candidate.token, answer_letter) + ) -@_trace_scan_perf -def _is_harmless_prefix_option_token(prefix: str, prior_match: re.Match) -> bool: - """Ignore harmless single-letter artifacts before a terminal final-clause answer. - This is limited to natural-language cases like: - - leading pronoun "I" - - article "a" before a normal word - - trailing "'s" in contractions like "it's" - """ - raw = prior_match.group(1).casefold() - if raw == "i" and prior_match.start() == 0: - return True - if raw == "a" and re.match(r"\s+[a-z]{2,}\b", prefix[prior_match.end() :]): - return True - if raw == "s" and prior_match.start() > 0 and prefix[prior_match.start() - 1] in {"'", "’"}: - return True - return False +def _extract_exact_option(text: str, answer_letter: str) -> Optional[str]: + """Accept responses that are exactly one standalone option token.""" + stripped = _strip_outer_wrappers(text) + match = _EXACT_OPTION_RE.fullmatch(stripped) + if not match: + return None + predicted = _norm_option(match.group("opt")) + if predicted is None or not _option_kind_matches(predicted, answer_letter): + return None + return predicted -@_trace_scan_perf -def _has_connector_between(text: str, max_words: int = 5) -> bool: - """Return True when a short span looks like connector text between option tokens.""" - text = text.strip() - return bool(text) and len(text.split()) <= max_words and bool(_MULTI_OPTION_CONNECTOR_RE.search(text)) +def _extract_exact_answer_text(text: str, answer_text: str) -> Optional[str]: + """Accept responses that are exactly the answer text after wrapper normalization.""" + if not answer_text: + return None + stripped = _strip_outer_wrappers(text) + if normalize_for_answer_text_match(stripped) != answer_text: + return None + return answer_text -@_trace_scan_perf -def _normalized_option_matches(text: str) -> list[tuple[str, re.Match]]: - """Return normalized option tokens paired with their regex matches in order.""" - matches: list[tuple[str, re.Match]] = [] - for match in TOKEN_PATTERN.finditer(text): - token = _norm_letter(match.group(1)) - if token is not None: - matches.append((token, match)) - return matches +def _extract_exact_option_plus_text(text: str, answer_letter: str, answer_text: str) -> Optional[str]: + """Accept short option-led answers like `B. Correct answer text`.""" + stripped = _strip_outer_wrappers(text) + match = _LEADING_OPTION_RE.fullmatch(stripped) + if not match: + return None + predicted = _norm_option(match.group("opt")) + if predicted is None or not _option_kind_matches(predicted, answer_letter): + return None + if normalize_for_answer_text_match(match.group("rest")) != answer_text: + return None + return predicted @lru_cache(maxsize=64) -def _prefix_pattern(prefix_norm: str) -> re.Pattern: - """Compile and cache the anchored prefix regex for a normalized answer prefix.""" +def _prefix_pattern(prefix_norm: str) -> re.Pattern[str]: + """Compile the caller-provided anchor prefix into the same option-capture shape.""" flexible_prefix = re.escape(prefix_norm).replace(r"\ ", r"\s+") return re.compile( - rf"{flexible_prefix}\s*[:\-–—]?\s*(?:is\s*)?(?Pnot\s+|isn['’]t\s+)?\(?\s*(?P[A-Za-z]|\d{{1,2}})\s*[\)\.:]?(?![\w+\-/])", + rf"(?:^|(?not\s+|isn't\s+|isnt\s+)?" + rf"(?:(?:option|choice)\s+)?" + rf"(?:[*_`~]+\s*)*(?:\\boxed\{{\s*)?[\(\[\{{<【]*\s*(?P[A-Za-z]|\d{{1,2}})\s*" + rf"[\)\]\}}>】]*\s*(?:\}}\s*)?(?:[*_`~]+\s*)?(?![\w+\-/])", re.IGNORECASE, ) -@_trace_scan_perf -def _extract_standalone_terminal_option(region: str) -> Optional[str]: - """Extract a standalone terminal token like ``C`` or ``(C)`` from a region.""" - match = TERMINAL_OPTION_LINE_PATTERN.fullmatch(region) +def _latest_explicit_candidate(text: str, answer_letter: str, prefix: Optional[str]) -> Optional[_Candidate]: + """Return the latest valid anchored candidate, preferring a caller-specified prefix.""" + if prefix: + prefix_norm = normalize_for_match(prefix) + if prefix_norm: + saw_prefix_match = False + latest_valid: Optional[_Candidate] = None + for match in _prefix_pattern(prefix_norm).finditer(text): + if not _prefix_match_has_standalone_start(text, match.start()): + continue + saw_prefix_match = True + if match.groupdict().get("neg"): + continue + token = _norm_option(match.group("opt")) + if token is None or not _option_kind_matches(token, answer_letter): + continue + candidate = _Candidate( + token=token, + start=match.start("opt"), + end=match.end("opt"), + method="anchored_token", + ) + if _candidate_is_valid(text, candidate, answer_letter): + latest_valid = candidate + # If the caller supplied an explicit prefix, do not fall back to generic anchors + # once that prefix appears at all. + if saw_prefix_match: + return latest_valid + + latest_valid = None + for match in _ANCHOR_RE.finditer(text): + if match.groupdict().get("neg"): + continue + token = _norm_option(match.group("opt")) + if token is None or not _option_kind_matches(token, answer_letter): + continue + candidate = _Candidate(token=token, start=match.start("opt"), end=match.end("opt"), method="anchored_token") + if _candidate_is_valid(text, candidate, answer_letter): + latest_valid = candidate + + return latest_valid + + +def _prefix_match_has_standalone_start(text: str, start: int) -> bool: + """Require prefix matches to start at a token boundary rather than inside a word.""" + cursor = start - 1 + while cursor >= 0 and text[cursor].isspace(): + cursor -= 1 + return cursor < 0 or not text[cursor].isalnum() + + +def _leading_option_candidate(text: str, answer_letter: str, answer_text: str) -> Optional[_Candidate]: + """Parse a short option-led answer that starts with the selected option token.""" + source = text + offset = 0 + if "\n" in text: + # For multi-line responses, only trust the final non-empty line as a leading-option answer. + source = _last_nonempty_line(text) + if not source: + return None + offset = text.rfind(source) + match = _LEADING_OPTION_RE.match(source) + else: + match = _LEADING_OPTION_RE.match(source) + if not match: + source = _last_nonempty_line(text) + if not source: + return None + offset = text.rfind(source) + match = _LEADING_OPTION_RE.match(source) if not match: return None - predicted = _norm_letter(match.group("opt")) - if predicted is None: + token = _norm_option(match.group("opt")) + if token is None or not _option_kind_matches(token, answer_letter): return None - tokens = list(TOKEN_PATTERN.finditer(region)) - if len(tokens) != 1 or _option_candidate_invalid(region, tokens[0]): + # Plain prose like "I think B works" should not be treated as an option-led format. + separator = source[match.end("opt") : match.start("rest")] + rest = match.group("rest").lstrip() + if not any(char in separator for char in ")]}>】.:-*_`~\\") and not rest.startswith( + ("(", "[", "{", "<", "【", '"', "'", "\\boxed{") + ): return None - return predicted - -@_trace_scan_perf -def _extract_leading_terminal_option(region: str) -> Optional[str]: - """Extract a leading-option form like ``C. text`` from a region.""" - leading_match = LEADING_OPTION_PATTERN.match(region) - if not leading_match: + # Reject enumerated multi-option payloads like "A. ...\nD. ...". + if _contains_multiple_option_led_sentences(text, answer_letter): return None - predicted = _norm_letter(leading_match.group(1)) - if predicted is None or _option_candidate_invalid(region, leading_match): - return None - return predicted - - -@_trace_scan_perf -def _extract_final_clause_terminal_option(region: str) -> Optional[str]: - """Extract a final-clause token like ``I think it's C`` from a short region.""" - match = FINAL_CLAUSE_TERMINAL_OPTION_RE.search(region) - if not match or _option_candidate_invalid(region, match): + candidate = _Candidate( + token=token, + start=offset + match.start("opt"), + end=offset + match.end("opt"), + method="anchored_token", + ) + if not _candidate_is_valid(text, candidate, answer_letter): return None + return candidate - prefix = region[: match.start()] - for _token, prior_match in _normalized_option_matches(prefix): - if _is_harmless_prefix_option_token(prefix, prior_match): - continue - return None - return _norm_letter(match.group("opt")) +def _last_nonempty_line(text: str) -> str: + """Return the final non-empty line from the response, if any.""" + for line in reversed((text or "").splitlines()): + if line.strip(): + return line.strip() + return "" -@_trace_scan_perf -def _extract_terminal_option_line(line: str) -> Optional[str]: - """Extract a standalone option token from the last line.""" - if not line or _is_compact_multi_option_list(line): - return None - predicted = _extract_standalone_terminal_option(line) - if predicted is not None: - return predicted - return _extract_leading_terminal_option(line) +def _is_compact_multi_option_list(text: str, answer_letter: str) -> bool: + """Detect short tails like `A, C` or `B and D` that should fail closed.""" + matches = [ + match + for match in _OPTION_TOKEN_RE.finditer(text) + if _option_kind_matches(_norm_option(match.group("opt")), answer_letter) + ] + if len(matches) < 2: + return False + if len(text.strip()) > 40: + return False -@_trace_scan_perf -def _extract_short_final_clause_option(text: str, max_words: int = 12) -> Optional[str]: - """Extract a terminal option token from a short final clause like 'I think it's C'.""" - clause = _tail_region(text).strip() - if not clause or len(clause.split()) > max_words or _is_compact_multi_option_list(clause): - return None - return _extract_final_clause_terminal_option(clause) + for idx in range(len(matches) - 1): + between = normalize_for_match(text[matches[idx].end() : matches[idx + 1].start()]) + if not _looks_like_option_connector(between): + return False + return True -@_trace_scan_perf -def _anchored_match_in_multi_answer_phrase(text: str, matches: list[re.Match], idx: int) -> bool: - """Return True if anchored match *idx* is part of a local multi-answer phrase.""" - match = matches[idx] - current = _match_token(match) - if current is None: - return False - if idx > 0: - between = text[matches[idx - 1].end() : match.start()] - if _has_connector_between(between): - return True +def _tail_choice_text(text: str) -> str: + """Extract the short trailing segment that feeds the tail-choice fallback.""" + region = (text or "").strip() + if not region: + return "" - if idx < len(matches) - 1: - between = text[match.end() : matches[idx + 1].start()] - if _has_connector_between(between): - return True + parts = re.split(r"\n+|[.!?]\s+", region) + tail_choice = parts[-1].strip() if parts else region + if not tail_choice: + tail_choice = _last_nonempty_line(region) + # Long trailing prose is too ambiguous for the tail-choice heuristic. + if len(tail_choice.split()) > TAIL_CHOICE_MAX_WORDS: + return "" + return tail_choice - sentence_start, sentence_end, match_start, match_end = _get_sentence_containing_match(text, match) - sentence = text[sentence_start:sentence_end] - local_match_start = match_start - sentence_start - local_match_end = match_end - sentence_start - for token, token_match in _normalized_option_matches(sentence): - if token == current: +def _contains_multiple_option_led_sentences(text: str, answer_letter: str) -> bool: + """Detect multi-line or multi-sentence payloads that enumerate different option labels.""" + distinct: set[str] = set() + # Newline-separated enumerations are common in model outputs, so keep lines intact in that case. + chunks = (text or "").splitlines() if "\n" in (text or "") else re.split(r"[.!?]\s+", text or "") + for chunk in chunks: + match = _SENTENCE_OPTION_START_RE.match(chunk.strip()) + if not match: continue - between = "" - if token_match.end() <= local_match_start: - between = sentence[token_match.end() : local_match_start] - elif local_match_end <= token_match.start(): - between = sentence[local_match_end : token_match.start()] - if not between: + token = _norm_option(match.group("opt")) + if token is None or not _option_kind_matches(token, answer_letter): continue - if _has_connector_between(between): + distinct.add(token) + if len(distinct) > 1: return True - return False -@_trace_scan_perf -def _is_compact_multi_option_list(text: str) -> bool: - """Return True for short multi-option tails like 'A, C' or '> **A** and C'.""" - text = (text or "").strip() - if len(list(TOKEN_PATTERN.finditer(text))) < 2: - return False +def _tail_candidate(region: str, answer_letter: str) -> Optional[_Candidate]: + """Extract a last-line or tail-choice option token from the terminal region.""" + line = _last_nonempty_line(region) + # Prefer an exact last-line option like "(C)" before falling back to a looser tail-choice scan. + if line and not _is_compact_multi_option_list(line, answer_letter): + match = _TERMINAL_OPTION_LINE_RE.fullmatch(line) + if match: + token = _norm_option(match.group("opt")) + if token is not None and _option_kind_matches(token, answer_letter): + line_offset = region.rfind(line) + start = line_offset + match.start("opt") + end = line_offset + match.end("opt") + candidate = _Candidate(token=token, start=start, end=end, method="last_token") + if _candidate_is_valid(region, candidate, answer_letter): + return candidate + + tail_choice = _tail_choice_text(region) + if not tail_choice or _is_compact_multi_option_list(tail_choice, answer_letter): + return None - residue = TOKEN_PATTERN.sub(" ", text) - residue = COMPACT_MULTI_OPTION_GLUE_PATTERN.sub(" ", residue) - residue = re.sub(r"[\s\[\]\(\)\{\}<>*_`~.!?]+", " ", residue) - return residue.strip() == "" + match = _TAIL_CHOICE_OPTION_RE.search(tail_choice) + if not match: + return None + token = _norm_option(match.group("opt")) + if token is None or not _option_kind_matches(token, answer_letter): + return None -@_trace_scan_perf -def _contains_multiple_option_led_sentences(text: str, answer_letter: str) -> bool: - """Return True when different sentences/lines each start with different option labels. + tail_choice_offset = region.rfind(tail_choice) + candidate = _Candidate( + token=token, + start=tail_choice_offset + match.start("opt"), + end=tail_choice_offset + match.end("opt"), + method="last_token", + ) + if not _candidate_is_valid(region, candidate, answer_letter): + return None + return candidate - This catches payloads like "(A) ... . (D) ..." or "A. ...\\nD. ...", which should - not be accepted for a single-answer MCQ unless a later anchored final answer overrides - them. - """ - text = text or "" - distinct: set[str] = set() - starts = [0] - starts.extend(match.end() for match in SENTENCE_BOUNDARY.finditer(text)) - for start in starts: - match = SENTENCE_LEADING_OPTION_PATTERN.match(text, pos=start) - if not match: - continue - token = _norm_letter(match.group(1)) - if token is None or not _token_kind_matches_answer_letter(token, answer_letter): - continue - distinct.add(token) - if len(distinct) > 1: - return True - return False +def _answer_text_pattern(answer_text: str) -> re.Pattern[str]: + """Compile a whitespace-tolerant exact-answer-text regex.""" + flexible_answer = re.escape(answer_text).replace(r"\ ", r"\s+") + return re.compile(rf"(? Optional[str]: + """Return the latest valid exact answer-text match inside a search region.""" + region_struct = normalize_for_structure(region) + if not answer_text or not region_struct: + return None + + latest_valid: Optional[str] = None + for match in _answer_text_pattern(answer_text).finditer(region_struct): + if _answer_text_match_is_valid(region_struct, match.start(), match.end(), answer_letter): + latest_valid = answer_text + + return latest_valid + + +def _answer_text_match_is_valid(region_struct: str, start: int, end: int, answer_letter: str) -> bool: + """Reject answer-text matches that sit inside obvious negation or option-list structure.""" + prefix = region_struct[max(0, start - 64) : start].rstrip() + if _NEGATION_PREFIX_RE.search(prefix): + return False + if prefix.endswith("rather than") or prefix.endswith("except"): + return False + if "wrong diagnosis is" in prefix[-40:] or "incorrect diagnosis is" in prefix[-40:]: + return False + + line_start, line_end = _line_bounds(region_struct, start, end) + raw_line = region_struct[line_start:line_end] + rel_start = start - line_start + rel_end = end - line_start + leading_match = _LEADING_OPTION_RE.match(raw_line.strip()) + if leading_match is not None: + token = _norm_option(leading_match.group("opt")) + if token is not None and _option_kind_matches(token, answer_letter): + return False + + # Bulleted or numbered option-analysis lines often mention distractor answer text verbatim. + if _BULLET_OR_LIST_LINE_RE.match(raw_line): + before_match = raw_line[:rel_start] + after_match = raw_line[rel_end:].lstrip(" *_`~)]}>】") + if ":" in before_match or any(marker in before_match for marker in (" - ", " – ", " — ")): + return False + if after_match.startswith((":", "-", "–", "—")): + return False + + return True + + +def _answer_text_regions(text: str, answer_text: str, is_long: bool) -> list[str]: + """Choose the bounded regions where answer-text fallback is allowed to search.""" + if is_long: + # In long mode, the tail is authoritative because earlier reasoning is frequently revised. + return [text[-TERMINAL_WINDOW_CHARS:]] + + if len(text) <= 800: + return [text] + + # For shorter responses, search bounded tail/head windows but align them to line + # boundaries so local validation still sees bullet markers and nearby list structure. + window = max(600, min(1_400, len(answer_text) + 400)) + line_slack = 200 + + # Only stretch to the next line break when it is still close to the window edge. + head_end = text.find("\n", window, min(len(text), window + line_slack + 1)) + if head_end == -1: + head_end = min(len(text), window) + head = text[:head_end] + + tail_start = max(0, len(text) - window) + aligned_tail_start = text.rfind("\n", max(0, tail_start - line_slack), tail_start) + if aligned_tail_start != -1: + tail_start = aligned_tail_start + 1 + tail = text[tail_start:] + + if head == tail: + return [head] + return [tail, head] -@_trace_scan_perf def multiple_choice_accuracy( llm_answer: str, answer_letter: str, @@ -614,178 +771,112 @@ def multiple_choice_accuracy( strip_tex: bool = True, return_details: bool = False, ) -> bool | MCQAccuracyResult: - """ - Grade a multiple-choice answer with layered strategies: - - 1. Direct answer: Response is just the option letter/number - 2. Anchored token: Use the last occurrence of a provided prefix, otherwise general anchor phrases - 3. Last token: Parse a terminal option line or short final clause near the end - 4. Answer text: Match the full answer text (if long enough) - - Args: - llm_answer: The model's response text - answer_letter: The correct answer letter/number (e.g., "C" or "3") - answer_text: The full correct answer text - prefix: Optional prefix to strip (e.g., "The answer is: ") - accept_answer_text: Whether to fall back to text matching - strip_tex: Whether to strip LaTeX formatting - return_details: If True, return MCQAccuracyResult dataclass instead of bool - - Returns: - bool (if return_details=False) or MCQAccuracyResult (if return_details=True) - """ - - def _result( - is_correct: bool, method: str, predicted: str | None, actual: str | None, return_details: bool - ) -> bool | MCQAccuracyResult: - """Helper to format return value.""" - if not return_details: - return is_correct - return MCQAccuracyResult( - is_correct=is_correct, - method=method, - matched_answer=predicted, - correct_answer=actual, - ) + """Grade an MCQ answer using short-mode scans and tail-authoritative long-mode scans.""" if not llm_answer: return _result(False, "none", None, None, return_details) - # Normalize the response - llm_answer = _remove_think_tags(llm_answer) - + # Strip reasoning wrappers and normalize before any extraction logic runs. + processed_answer = _remove_think_tags(llm_answer) + processed_answer = _ANSWER_TAG_RE.sub(" ", processed_answer) if strip_tex: - llm_answer = _strip_tex(llm_answer) - answer_text = _strip_tex(answer_text) + processed_answer = _strip_tex(processed_answer) + answer_text = _strip_tex(answer_text or "") - # Keep two views of the response: - # - structural_text preserves original spacing for sentence/line-sensitive heuristics - # - normalized_answer casefolds and normalizes punctuation for anchor/text matching - structural_text = llm_answer - normalized_answer = normalize_for_structure(llm_answer) + structural_text = normalize_for_structure(processed_answer).strip() + answer_letter = _norm_option(answer_letter) + answer_text = normalize_for_answer_text_match(answer_text or "") + exact_answer_text_allowed = accept_answer_text and bool(answer_text) + answer_text_fallback_allowed = accept_answer_text and _answer_text_supports_fallback(answer_text) - answer_letter = _norm_letter(answer_letter) - answer_text = normalize_for_match(answer_text or "") if answer_letter is None: raise ValueError(f"Invalid answer_letter '{answer_letter=}'. Must be a single letter or digit string.") - # Once we see any explicit option selection of the right token kind, we stop lower-confidence - # fallbacks from overriding it with a tail token or answer-text mention. - explicit_choice_found = False - - # Strategy 1: Only answer letter anywhere (without anchoring) - if answer_letter == _norm_letter(normalized_answer): - return _result(True, "direct_answer", normalized_answer, answer_letter, return_details) - - # A response that begins like "B. ..." gets special handling: we may disable both the leading - # shortcut and later tail/text fallbacks if it actually looks like multiple labeled options. - leading_match = LEADING_OPTION_PATTERN.match(structural_text) - multiple_option_led_sentences = False - - if leading_match: - # Only pay for the additional answer scan when the payload actually starts with a leading - # option pattern; otherwise we leave this guard disabled for the cheaper later paths. - multiple_option_led_sentences = len( - structural_text - ) <= _MULTIPLE_OPTION_LED_SCAN_MAX_CHARS and _contains_multiple_option_led_sentences( - structural_text, answer_letter + if not structural_text: + return _result(False, "none", None, None, return_details) + + # Strategy 1: exact standalone option, e.g. "C" or "(2)". + direct_option = _extract_exact_option(structural_text, answer_letter) + if direct_option == answer_letter: + return _result( + True, + "direct_answer", + direct_option.casefold(), + answer_letter, + return_details, + ) + + # Strategy 2: exact answer text after wrapper normalization. This remains allowed + # even for numeric answer text, so parsed outputs like "\boxed{4}" can still match + # the gold content answer text before a mismatched standalone numeral fails closed. + if exact_answer_text_allowed: + direct_text = _extract_exact_answer_text(structural_text, answer_text) + if direct_text is not None: + return _result(True, "answer_text", direct_text, answer_text, return_details) + + if direct_option is not None: + return _result( + False, + "direct_answer", + direct_option.casefold(), + answer_letter, + return_details, + ) + + # Strategy 3: short option-led answer that also includes the answer text. + option_plus_text = _extract_exact_option_plus_text(structural_text, answer_letter, answer_text) + if option_plus_text is not None: + return _result( + option_plus_text == answer_letter, + "anchored_token", + option_plus_text, + answer_letter, + return_details, + ) + + is_long = len(structural_text) > LONG_RESPONSE_THRESHOLD_CHARS + terminal_region = structural_text[-TERMINAL_WINDOW_CHARS:] if is_long else structural_text + strong_tail_region = terminal_region[-STRONG_TAIL_WINDOW_CHARS:] if is_long else structural_text + + # Strategy 4: anchored commitments like "final answer: C". + explicit_candidate = _latest_explicit_candidate(terminal_region, answer_letter, prefix) + if explicit_candidate is not None: + return _result( + explicit_candidate.token == answer_letter, + explicit_candidate.method, + explicit_candidate.token, + answer_letter, + return_details, + ) + + # Strategy 5: leading-option forms are only trusted in short responses. + if not is_long: + leading_candidate = _leading_option_candidate(structural_text, answer_letter, answer_text) + if leading_candidate is not None: + return _result( + leading_candidate.token == answer_letter, + leading_candidate.method, + leading_candidate.token, + answer_letter, + return_details, + ) + + # Strategy 6: tail-only token fallback from the last line or short tail choice text. + tail_candidate = _tail_candidate(strong_tail_region, answer_letter) + if tail_candidate is not None: + return _result( + tail_candidate.token == answer_letter, + tail_candidate.method, + tail_candidate.token, + answer_letter, + return_details, ) - # If the response looks like multiple labeled answer statements, do not treat the first - # label as the chosen answer. - if multiple_option_led_sentences: - leading_match = None - - # Strategy 2: Accept leading option token like "B. answer ..." - if leading_match and answer_letter: - predicted = _norm_letter(leading_match.group(1)) - if _token_kind_matches_answer_letter(predicted, answer_letter): - explicit_choice_found = True - if predicted == answer_letter: - return _result(True, "anchored_token", predicted, answer_letter, return_details) - - # Strategy 3: Anchored token (prefix matches first, fallback to generic anchors) - prefix_matches = [] - if prefix: - prefix_norm = normalize_for_structure(prefix).strip() - if prefix_norm: - prefix_matches = list(_prefix_pattern(prefix_norm).finditer(normalized_answer)) - - anchored_matches = prefix_matches if prefix_matches else list(ANCHOR_PATTERN.finditer(normalized_answer)) - if anchored_matches and answer_letter: - # Walk anchored matches from the end so later corrections like "Answer: B ... final answer: C" - # resolve to the last non-negated, non-multi-answer anchor. - for idx in range(len(anchored_matches) - 1, -1, -1): - match = anchored_matches[idx] - predicted = _match_token(match) - if predicted is None: - continue - if match.group("neg") is not None: - continue - if _match_is_contradicted(_sentence_match_context(normalized_answer, match)): - continue - if _anchored_match_in_multi_answer_phrase(normalized_answer, anchored_matches, idx): - continue - - if _token_kind_matches_answer_letter(predicted, answer_letter): - explicit_choice_found = True - if predicted == answer_letter: - return _result(True, "anchored_token", predicted, answer_letter, return_details) - break - # Strategy 4: Parse a terminal option line or short final clause near the end. - # Tail parsing is lower confidence than explicit anchors, so it only runs when no explicit - # option token has already been observed. - if not explicit_choice_found and answer_letter and not multiple_option_led_sentences: - predicted = _extract_terminal_option_line(_last_nonempty_line(normalized_answer)) - if predicted == answer_letter: - return _result(True, "last_token", predicted, answer_letter, return_details) - - predicted = _extract_short_final_clause_option(normalized_answer) - if predicted == answer_letter: - return _result(True, "last_token", predicted, answer_letter, return_details) - - # Strategy 5: Exact answer text match if there's no explicit choice found - # Only search at beginning and end to avoid matching reasoning in the middle - if accept_answer_text and answer_text and not explicit_choice_found: - # A multi-option-led payload is too ambiguous for answer-text fallback. - if multiple_option_led_sentences: - return _result(False, "none", None, None, return_details) - - # Calculate search regions based on token count - answer_tokens = len(answer_text.split()) - buffer_tokens = answer_tokens + 15 # Extra tokens for preamble like "The answer is:" - - llm_tokens = normalized_answer.split() - - beginning_tokens = llm_tokens[:buffer_tokens] - end_tokens = llm_tokens[-buffer_tokens:] if len(llm_tokens) > buffer_tokens else llm_tokens - - beginning_region = " ".join(beginning_tokens) - end_region = " ".join(end_tokens) - - # First try the normalized answer text directly, then a slightly looser punctuation-tolerant - # variant, but only in the beginning/end windows rather than the full reasoning trace. - flexible_answer = re.escape(answer_text).replace(r"\ ", r"\s+") - normed_answer_text = normalize_for_answer_text_match(answer_text) - pattern = re.compile(rf"(?