diff --git a/docs/medarc-eval-process.md b/docs/medarc-eval-process.md index 7e3ca3c6..1b57c7ba 100644 --- a/docs/medarc-eval-process.md +++ b/docs/medarc-eval-process.md @@ -5,7 +5,7 @@ Convert raw benchmark outputs into analysis-ready parquet files. This step prepa ## Quick Start ```bash -# Process all completed runs (uses defaults) +# Process all completed jobs (uses defaults) medarc-eval process # Specify directories explicitly @@ -17,10 +17,10 @@ medarc-eval process --dry-run ## What Processing Does -1. **Discovers** completed jobs in `runs/raw/` +1. **Discovers** jobs in `runs/raw/` and filters by manifest status (default: `completed`) 2. **Extracts** results from each job's output files -3. **Normalizes** data into a consistent schema -4. **Writes** parquet files organized by environment and model +3. **Normalizes** data into a fixed output schema +4. **Writes** parquet files organized by model and environment 5. **Creates** an index (`env_index.json`) for downstream tools ### Output Structure @@ -28,22 +28,24 @@ medarc-eval process --dry-run ``` runs/processed/ ├── env_index.json # Dataset inventory for winrate/analysis -├── medqa/ -│ ├── gpt-4o.parquet -│ └── gpt-4o-mini.parquet -├── pubmedqa/ -│ ├── gpt-4o.parquet -│ └── gpt-4o-mini.parquet +├── gpt-4o/ +│ ├── medqa.parquet +│ └── pubmedqa.parquet +├── gpt-4o-mini/ +│ ├── medqa.parquet +│ └── pubmedqa.parquet └── ... ``` +On-disk model and env path components are slugified, so filenames may not exactly match raw ids. + ## Common Options | Flag | Description | Default | |------|-------------|---------| | `--runs-dir PATH` | Directory containing raw runs | `runs/raw` | | `--output-dir PATH` | Where to write processed files | `runs/processed` | -| `--max-workers N` | Parallel processing threads | 4 | +| `--max-workers N` | Parallel worker processes | 4 | | `--dry-run` | Show what would be processed | - | | `--yes` | Skip confirmation prompts | - | | `--exclude-dataset NAME` | Skip processing specific datasets/env ids (repeatable) | - | @@ -53,16 +55,35 @@ runs/processed/ ### By Completion Status -By default, only completed jobs are processed: +By default, `medarc-eval process` only selects jobs whose manifest status is `completed`. -```bash -# Include incomplete runs -medarc-eval process --process-incomplete +Note: successful jobs are written to `run_manifest.json` with `status: completed`. -# Filter by specific status +To override that default, pass one or more explicit status filters: + +```bash medarc-eval process --status completed --status failed ``` +You can also gate partially complete outputs by missing `results.jsonl` rows: + +```bash +# Default tolerance is 2.5 percent missing +medarc-eval process --max-results-missing-pct 2.5 + +# Effectively disable the gate +medarc-eval process --max-results-missing-pct 100 +``` + +This gate uses manifest job metadata only: + +- `expected_rows = num_examples * rollouts_per_example` +- `observed_rows = row_count` + +It is computed per selected job record and enforced only on the latest selected run for each processed model/environment output. It does not use manifest `summary.completed` / `summary.total`, and it does not fall back to older runs if the latest one is too incomplete. + +Selected records with missing `results.jsonl` fail processing immediately. + ### Latest Runs Only When multiple runs exist for the same (model, environment) pair, processing uses the latest by default. @@ -86,13 +107,19 @@ Store common options in a YAML file: ```yaml # process-config.yaml runs_dir: runs/raw -output_dir: runs/processed -max_workers: 8 -process_incomplete: false -exclude_datasets: - - med_dialog -exclude_models: - - deprecated-v1 + +process: + dir: processed + max_workers: 8 + max_results_missing_pct: 2.5 + exclude_datasets: + - med_dialog + exclude_models: + - deprecated-v1 + +winrate: + enabled: true + dir: winrate ``` ```bash @@ -101,6 +128,35 @@ medarc-eval process --config process-config.yaml CLI flags override config values. +Supported config schema for `medarc-eval process`: + +- Top-level `runs_dir`: raw run root. +- Top-level `process:`: process-specific defaults. +- Optional top-level `winrate:`: embedded post-process winrate step. +- Optional top-level `hf:`: shared HF settings. For embedded winrate uploads, use `hf.winrate_dir`. + +Path shortcuts: + +- `process.dir` is shorthand for `process.output_dir`, resolved relative to the parent of `runs_dir`. +- `winrate.dir` is shorthand for the embedded winrate output directory, resolved under the processed output dir. + +Example: + +```yaml +runs_dir: runs/raw + +process: + dir: processed + max_workers: 8 + +winrate: + dir: scorecards + +hf: + repo: your-org/medical-benchmarks + winrate_dir: scorecards/latest +``` + ## Hugging Face Integration Sync processed datasets to/from the Hugging Face Hub: @@ -108,7 +164,8 @@ Sync processed datasets to/from the Hugging Face Hub: ```yaml # process-config.yaml runs_dir: runs/raw -output_dir: runs/processed +process: + dir: processed hf: repo: your-org/medical-benchmarks @@ -117,6 +174,8 @@ hf: private: true ``` +`hf.token` accepts either a literal token string or an environment reference like `$HF_TOKEN` / `${HF_TOKEN}`. + ### Pull Before Processing ```bash @@ -128,8 +187,24 @@ medarc-eval process --hf-repo your-org/data --hf-pull-policy pull # Start fresh (ignore remote) medarc-eval process --hf-repo your-org/data --hf-pull-policy clean + +# Resume a previously failed HF upload without pulling or cleaning +medarc-eval process --hf-repo your-org/data --hf-pull-policy continue-upload ``` +`prompt` only prompts when the local processed dir is already non-empty. If the output dir is empty, process pulls the HF baseline immediately. + +When `prompt` is used with a non-empty local processed dir, the menu may show: + +- `pull`: download missing baseline data without deleting local files +- `clean`: redownload everything after deleting local files +- `upload`: keep local processed outputs and resume/upload pending HF artifacts + +`upload` is shown only when local parquet files appear to be missing remotely or have a different remote `lfs.sha256`. Recovery uploads the union of: + +- parquet files that were already pending before the current run started +- files touched by the current process run, including `env_index.json` and `dataset_infos.json` when rewritten + ### Push After Processing When `--hf-repo` is set, processed files are automatically uploaded after completion. @@ -139,10 +214,10 @@ When `--hf-repo` is set, processed files are automatically uploaded after comple Process and compute win rates in one step: ```bash -medarc-eval process --winrate winrate-config.yaml +medarc-eval process --config process-config.yaml ``` -This runs `medarc-eval winrate` automatically after processing completes. +This runs `medarc-eval winrate` automatically after processing completes when the config contains a `winrate:` section. ## Example Workflows @@ -180,18 +255,65 @@ medarc-eval process # env_index.json tracks what's already processed ``` +Incremental skipping only reuses an existing parquet when its footer metadata `source_runs` still matches the newly selected run ids and the existing row count still matches `env_index.json`. + +### Replace Existing Outputs + +Rebuild existing outputs for specific models or datasets without using `--clean`: + +```bash +# Rebuild every processed dataset for one model +medarc-eval process --replace-model gpt-4o + +# Rebuild every model for one dataset +medarc-eval process --replace-env medqa + +# Rebuild only the intersection +medarc-eval process --replace-model gpt-4o --replace-env medqa +``` + +When both flags are present, processing only rebuilds outputs that match both filters. + ## Troubleshooting ### "No runs found" Check that: 1. `--runs-dir` points to the correct location -2. Runs have completed (check `run_manifest.json` status) -3. Use `--process-incomplete` if runs are still in progress +2. Runs have completed (check `run_manifest.json` `jobs[*].status`) +3. Use `--status pending` or `--status running` to include non-completed jobs ### Missing data in output -By default, only jobs with `completed` status are included. Use `--process-incomplete` to include partial results. +By default, only jobs with `completed` status are included. In addition, `--max-results-missing-pct` fails if a selected latest job record is missing more than 2.5% of its expected `results.jsonl` rows, using manifest job fields: + +- `row_count` +- `num_examples` +- `rollouts_per_example` + +The gate is per selected record, not per whole run manifest. If the latest selected run for a model/dataset is too incomplete, processing fails fast instead of silently falling back to an older run. Records with unknown expected rows or unknown `row_count` are not gated. + +Use `--max-results-missing-pct 100` to disable the gate, or pass explicit `--status` values to include other statuses. + +### Integrity-check failures for existing parquet files + +If processing stops with an error like: + +```text +Existing processed output ... has N parquet rows but env_index.json records M. +``` + +the local processed snapshot is inconsistent. Fix it by rebuilding the affected output: + +```bash +medarc-eval process --replace-model gpt-4o --replace-env medqa +``` + +Or rebuild everything: + +```bash +medarc-eval process --clean --yes +``` ## Next Steps diff --git a/docs/medarc-eval-winrate.md b/docs/medarc-eval-winrate.md index d1f50e99..47c28f92 100644 --- a/docs/medarc-eval-winrate.md +++ b/docs/medarc-eval-winrate.md @@ -12,7 +12,7 @@ medarc-eval winrate medarc-eval winrate --list-models # Specify directories -medarc-eval winrate --processed-dir runs/processed --output-dir runs/winrate +medarc-eval winrate --processed-dir runs/processed --output-dir runs/processed/winrate ``` ## Prerequisites @@ -27,30 +27,35 @@ medarc-eval process ## How Win Rates Work For each pair of models (A, B) on each benchmark: -1. Find questions both models answered -2. Compare scores on each question -3. Count: A wins, B wins, ties -4. Win rate = (A wins + 0.5 × ties) / total +1. Average rollouts per `(example_id, model_id)` +2. Compare questions where at least one model has a reward +3. If one side is missing, fill it according to `--missing-policy` (`neg-inf` or `zero`) +4. Count: A wins, B wins, ties +5. Win rate = (A wins + 0.5 × ties) / total used questions The final win rate aggregates across all benchmarks using configurable weighting. +Winrate also emits a missingness summary so partial dataset coverage is visible. The report counts missing +`(dataset, model)` pairs after rollout averaging, including both absent rows and null reward values. + ## Output Files ``` -runs/winrate/ -├── winrates-2026-01-14T12-00-00.json # Timestamped results -├── winrates-2026-01-14T12-00-00.csv # Spreadsheet-friendly +runs/processed/winrate/ +├── winrates-20260114T120000Z.json # Timestamped results +├── winrates-20260114T120000Z.csv # Spreadsheet-friendly ├── latest.json # Always points to newest └── latest.csv ``` +If you pass `--output /path/to/file.json`, winrate writes only that JSON file and skips `latest.json` plus all CSV outputs. + ### Output Format The JSON output includes: - Per-model aggregate win rates -- Pairwise comparison matrices -- Per-benchmark breakdowns -- Computation metadata +- Per-opponent `vs` breakdowns +- Per-dataset average rewards and question counts ## Common Options @@ -92,33 +97,40 @@ The JSON output includes: | `--partial-datasets strict` | When `--include-model` is set, drop datasets missing any included model | | `--partial-datasets include` | When `--include-model` is set, keep datasets and treat missing models as all-missing | +`--partial-datasets include` is usually paired with `--dataset-coverage per-model`. With the default `all-models` coverage, datasets missing any required model are still dropped later. + ## Using a Config File ```yaml -# winrate-config.yaml -processed_dir: runs/processed -output_dir: runs/winrate - -# Calculation settings -missing_policy: neg-inf -epsilon: 1.0e-9 -min_common: 10 -weight_policy: ln - -# Model filtering -exclude_model: - - baseline-model - - deprecated-v1 - -# Dataset filtering -exclude_datasets: - - med_dialog +# process-config.yaml +runs_dir: runs/raw + +process: + dir: processed + +winrate: + dir: winrate + missing_policy: neg-inf + epsilon: 1.0e-9 + min_common: 10 + weight_policy: ln + exclude_model: + - baseline-model + - deprecated-v1 + exclude_datasets: + - med_dialog ``` ```bash -medarc-eval winrate --config winrate-config.yaml +medarc-eval winrate --config process-config.yaml ``` +Supported config schema for `medarc-eval winrate`: + +- Top-level `process:` can provide `dir` or `output_dir`; this becomes the default `processed_dir`. +- Top-level `winrate:` provides winrate-specific defaults. +- Top-level `hf:` provides shared HF settings. Use `hf.winrate_dir` to control where winrate artifacts upload inside the repo. + ## Example Workflows ### Compare Specific Models @@ -161,7 +173,7 @@ medarc-eval winrate --weight-policy ln ```bash medarc-eval winrate \ - --hf-processed-repo your-org/processed-benchmarks \ + --hf-repo your-org/processed-benchmarks \ --hf-processed-pull \ --hf-token $HF_TOKEN ``` @@ -170,7 +182,8 @@ medarc-eval winrate \ ```bash medarc-eval winrate \ - --hf-winrate-repo your-org/winrate-results \ + --hf-repo your-org/processed-benchmarks \ + --hf-winrate-dir winrate \ --hf-token $HF_TOKEN \ --hf-private ``` @@ -178,52 +191,81 @@ medarc-eval winrate \ ### Full Config with HF ```yaml -# winrate-config.yaml -processed_dir: runs/processed -output_dir: runs/winrate +# process-config.yaml +runs_dir: runs/raw + +process: + dir: processed -missing_policy: neg-inf -weight_policy: ln +winrate: + dir: winrate + missing_policy: neg-inf + weight_policy: ln hf: - repo: your-org/processed-data # Pull processed from here - winrate_repo: your-org/winrate-results # Upload results here + repo: your-org/processed-data # Pull processed from here; upload winrate here + winrate_dir: winrate # Subdirectory in repo for winrate artifacts (default: winrate) branch: main token: ${HF_TOKEN} private: true ``` +`hf.token` accepts either a literal token string or an environment reference like `$HF_TOKEN` / `${HF_TOKEN}`. + +`hf.winrate_dir` and `--hf-winrate-dir` both set the path inside the HF repo where `latest.json`, `latest.csv`, and timestamped winrate outputs are uploaded. + ## Interpreting Results ### Win Rate Table (CSV) -| model | win_rate | vs_gpt-4o | vs_gpt-4o-mini | vs_claude | -|-------|----------|-----------|----------------|-----------| -| gpt-4o | 0.72 | - | 0.85 | 0.58 | -| gpt-4o-mini | 0.45 | 0.15 | - | 0.32 | -| claude-3-5-sonnet | 0.68 | 0.42 | 0.68 | - | +| model | weighted_winrate | simple_winrate | medqa | pubmedqa | num_datasets | +|-------|------------------|----------------|-------|-----------|--------------| +| gpt-4o | 0.72 | 0.70 | 0.84 | 0.77 | 2 | +| gpt-4o-mini | 0.45 | 0.43 | 0.61 | 0.39 | 2 | -- **win_rate**: Aggregate win rate across all models -- **vs_X columns**: Pairwise win rate against model X -- Values > 0.5 mean the row model wins more often +- **weighted_winrate** / **simple_winrate**: Aggregate mean winrate across retained datasets +- Dataset columns: Average reward on that dataset, not pairwise winrate columns +- `num_datasets`: Number of datasets retained for that model after filtering/coverage rules ### JSON Structure ```json { - "models": ["gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet"], - "aggregate_winrates": { - "gpt-4o": 0.72, - "gpt-4o-mini": 0.45, - "claude-3-5-sonnet": 0.68 - }, - "pairwise": { + "models": { "gpt-4o": { - "gpt-4o-mini": {"win_rate": 0.85, "wins": 850, "losses": 150, "ties": 0}, - "claude-3-5-sonnet": {"win_rate": 0.58, ...} + "mean_winrate": { + "simple_mean": 0.72, + "weighted_mean": 0.74, + "n_datasets": 2 + }, + "vs": { + "gpt-4o-mini": { + "mean_winrate": { + "simple_mean": 0.85, + "weighted_mean": 0.84 + }, + "per_dataset": { + "medqa": 0.90, + "pubmedqa": 0.80 + }, + "n_datasets": 2 + } + }, + "avg_reward_per_dataset": { + "medqa": 0.84, + "pubmedqa": 0.77 + } + } + }, + "datasets": { + "medqa": { + "avg_reward_per_model": { + "gpt-4o": 0.84, + "gpt-4o-mini": 0.61 + }, + "n_questions": 1273 } }, - "per_benchmark": { ... } } ``` @@ -239,3 +281,4 @@ hf: - Check `--min-common` isn't filtering out comparisons - Review `--missing-policy` (use `neg-inf` to penalize missing answers) - Verify models were evaluated on the same benchmark variants +- If using `--partial-datasets include`, also consider `--dataset-coverage per-model` diff --git a/docs/medarc-eval.md b/docs/medarc-eval.md index a9e48a39..395d251f 100644 --- a/docs/medarc-eval.md +++ b/docs/medarc-eval.md @@ -27,7 +27,7 @@ medarc-eval winrate (bench or single) (process) (winrate) | | | v v v - runs/raw/ runs/processed/ runs/winrate/ + runs/raw/ runs/processed/ runs/processed/winrate/ ``` ## Commands diff --git a/docs/medarc-verifiers-architecture.md b/docs/medarc-verifiers-architecture.md index 7eddf092..d9f25cd2 100644 --- a/docs/medarc-verifiers-architecture.md +++ b/docs/medarc-verifiers-architecture.md @@ -16,7 +16,7 @@ At a high level, everything funnels into a three-stage workflow: 1. **Run** evals (single or batch) → `runs/raw//...` 2. **Process** raw outputs → `runs/processed//.parquet` + `env_index.json` -3. **Winrate** on processed outputs → `runs/winrate/*.json` and `*.csv` +3. **Winrate** on processed outputs → `runs/processed/winrate/*.json` and `*.csv` ## Important side effects (auto-installed patches) @@ -173,7 +173,7 @@ Entry point: `medarc_verifiers/cli/process/pipeline.py` (via `run_process()`). - This suffix-derived rollout index is only used when rollouts are faked this way. Native verifiers rollouts (below) use the per-row JSONL field. - `medarc_verifiers/cli/process/rollout.py` 4. **Load rows from `results.jsonl`**: - - Drops large fields (`prompt`, `completion`) by default. + - Always drops large fields (`prompt`, `completion`). - Allows selecting extra per-env columns into a JSON-encoded `extras` column. - If the JSONL provides a per-row `rollout_index` (native verifiers multi-rollout runs), it is treated as authoritative and preserved. - If `rollout_index` is missing but the JSONL contains multiple rows per `example_id`, computes a data-driven `rollout_index` based on occurrence count. @@ -184,7 +184,8 @@ Entry point: `medarc_verifiers/cli/process/pipeline.py` (via `run_process()`). - When aggregating fake rollouts (manifest env ids include rollout suffixes), ensures every row has a `rollout_index` (derived from the suffix if missing) and normalizes indices to `0..K-1` within the dataset. - When aggregating native verifiers rollouts (no rollout suffixes), preserves `rollout_index` values as provided by `results.jsonl` (no normalization). 6. **Write Parquet**: - - Output path is `//.parquet`. + - Output path is `//.parquet`. + - Output columns are restricted to a fixed allowlist schema for downstream compatibility. - Adds exporter metadata under a Parquet schema metadata key. - Writes `env_index.json` (v2) and `dataset_infos.json` for HF datasets UX. - `medarc_verifiers/cli/process/writer.py`, `medarc_verifiers/cli/process/env_index.py` @@ -200,13 +201,15 @@ Processing can use `env_index.json` to do incremental updates (delta processing) Docs: `docs/medarc-eval-winrate.md`. -`medarc-eval winrate` reads dataset inventory from `env_index.json`, then computes pairwise model comparisons. +`medarc-eval winrate` reads dataset inventory from `env_index.json`, averages rollouts per `(example_id, model_id)`, then computes pairwise model comparisons. - Dataset discovery via `env_index.json`: `medarc_verifiers/cli/winrate/runner.py` - Core math + weighting policies: `medarc_verifiers/cli/winrate/api.py` - Outputs: - timestamped `winrates-.json` and `.csv` - `latest.json` and `latest.csv` + - JSON shape is model-centric: top-level `models` and `datasets` + - CSV contains aggregate winrates plus per-dataset average rewards, not pairwise `vs_*` columns ## Shared building blocks used by environments diff --git a/medarc_verifiers/cli/_constants.py b/medarc_verifiers/cli/_constants.py index 41a840dd..a466e47b 100644 --- a/medarc_verifiers/cli/_constants.py +++ b/medarc_verifiers/cli/_constants.py @@ -20,4 +20,4 @@ DEFAULT_ENV_CONFIG_ROOT = Path("configs") / "envs" DEFAULT_RUNS_RAW_DIR = Path("runs") / "raw" DEFAULT_PROCESSED_DIR = Path("runs") / "processed" -DEFAULT_WINRATE_DIR = Path("runs") / "winrate" +DEFAULT_WINRATE_DIR = DEFAULT_PROCESSED_DIR / "winrate" diff --git a/medarc_verifiers/cli/_manifest_tools.py b/medarc_verifiers/cli/_manifest_tools.py index 5ba9effc..836fd9d2 100644 --- a/medarc_verifiers/cli/_manifest_tools.py +++ b/medarc_verifiers/cli/_manifest_tools.py @@ -2,14 +2,16 @@ from __future__ import annotations +import os import json import logging +import sys +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from pathlib import Path -from typing import Sequence +from typing import Any, Mapping, Sequence from medarc_verifiers.cli._manifest import MANIFEST_FILENAME, RunManifestModel, SUPPORTED_MANIFEST_VERSIONS -from medarc_verifiers.cli.utils.shared import count_jsonl_rows logger = logging.getLogger(__name__) @@ -41,90 +43,179 @@ def validate_manifests_in_runs(runs_dir: Path | str, *, strict: bool = False) -> if not runs_path.exists(): return ManifestValidationResult(manifests_checked=0, jobs_checked=0, issues=[]) - for run_dir in sorted(path for path in runs_path.iterdir() if path.is_dir()): - manifest_path = run_dir / MANIFEST_FILENAME - if not manifest_path.exists(): - continue - manifests_checked += 1 - try: - payload = json.loads(manifest_path.read_text(encoding="utf-8")) - except Exception as exc: # noqa: BLE001 - issues.append( + run_dirs = sorted(path for path in runs_path.iterdir() if path.is_dir()) + logger.info("Scanning manifests under %s...", runs_path) + + manifest_run_dirs = [run_dir for run_dir in run_dirs if (run_dir / MANIFEST_FILENAME).exists()] + if not manifest_run_dirs: + return ManifestValidationResult(manifests_checked=0, jobs_checked=0, issues=[]) + + max_workers = min(len(manifest_run_dirs), max(1, (os.cpu_count() or 4) * 4)) + if max_workers <= 1: + results = [_validate_run_dir(run_dir, strict=strict) for run_dir in manifest_run_dirs] + else: + results = list(_validate_run_dirs_parallel(manifest_run_dirs, strict=strict, max_workers=max_workers)) + + for result in results: + manifests_checked += result.manifests_checked + jobs_checked += result.jobs_checked + issues.extend(result.issues) + + issues.sort(key=lambda item: (item.run_id, item.job_id, item.kind, item.message)) + return ManifestValidationResult(manifests_checked=manifests_checked, jobs_checked=jobs_checked, issues=issues) + + +def _validate_run_dirs_parallel( + run_dirs: Sequence[Path], + *, + strict: bool, + max_workers: int, +) -> list[ManifestValidationResult]: + results: list[ManifestValidationResult] = [] + progress, task_id = _create_manifest_scan_progress(len(run_dirs)) + executor: ThreadPoolExecutor | None = None + futures = [] + try: + executor = ThreadPoolExecutor(max_workers=max_workers) + futures = [executor.submit(_validate_run_dir, run_dir, strict=strict) for run_dir in run_dirs] + if progress is not None and task_id is not None: + with progress: + for future in as_completed(futures): + results.append(future.result()) + progress.update(task_id, advance=1) + else: + for future in as_completed(futures): + results.append(future.result()) + except KeyboardInterrupt: + logger.warning("Manifest scanning interrupted; cancelling validation workers.") + for future in futures: + future.cancel() + if executor is not None: + executor.shutdown(wait=False, cancel_futures=True) + executor = None + raise + finally: + if executor is not None: + executor.shutdown(wait=True, cancel_futures=False) + return results + + +def _create_manifest_scan_progress(total: int) -> tuple[object | None, object | None]: + if total <= 0 or not sys.stderr.isatty(): + return None, None + try: + from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn, TimeElapsedColumn + + progress = Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + transient=True, + ) + task_id = progress.add_task("Scanning manifests", total=total) + return progress, task_id + except Exception: + return None, None + + +def _validate_run_dir(run_dir: Path, *, strict: bool) -> ManifestValidationResult: + issues: list[ManifestValidationIssue] = [] + manifest_path = run_dir / MANIFEST_FILENAME + if not manifest_path.exists(): + return ManifestValidationResult(manifests_checked=0, jobs_checked=0, issues=[]) + + try: + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + except Exception as exc: # noqa: BLE001 + return ManifestValidationResult( + manifests_checked=1, + jobs_checked=0, + issues=[ ManifestValidationIssue( run_id=run_dir.name, job_id="", kind="error", message=f"Failed to parse manifest: {exc}", ) - ) - continue + ], + ) - version = payload.get("version") - if version not in SUPPORTED_MANIFEST_VERSIONS: - issues.append( + version = payload.get("version") + if version not in SUPPORTED_MANIFEST_VERSIONS: + return ManifestValidationResult( + manifests_checked=1, + jobs_checked=0, + issues=[ ManifestValidationIssue( run_id=run_dir.name, job_id="", kind="error", message=f"Unsupported manifest version: {version}", ) + ], + ) + + model = RunManifestModel.model_validate(payload) + artifacts_root = str(getattr(model, "artifacts_root", ".") or ".") + jobs_checked = 0 + + for entry in model.jobs: + jobs_checked += 1 + results_path, metadata_path, used_fallback = _resolve_job_artifact_paths( + run_dir=run_dir, + artifacts_root=artifacts_root, + job_id=entry.job_id, + results_relpath=entry.results_relpath, + metadata_relpath=entry.metadata_relpath, + ) + if used_fallback: + issues.append( + ManifestValidationIssue( + run_id=model.run_id, + job_id=entry.job_id, + kind="warning", + message="Manifest artifact path missing; fallback to run-relative job directory would be used.", + ) ) - continue - model = RunManifestModel.model_validate(payload) - artifacts_root = str(getattr(model, "artifacts_root", ".") or ".") - - for entry in model.jobs: - jobs_checked += 1 - results_path, metadata_path, used_fallback = _resolve_job_artifact_paths( - run_dir=run_dir, - artifacts_root=artifacts_root, - job_id=entry.job_id, - results_relpath=entry.results_relpath, - metadata_relpath=entry.metadata_relpath, - ) - if used_fallback: - issues.append( - ManifestValidationIssue( - run_id=model.run_id, - job_id=entry.job_id, - kind="warning", - message="Manifest artifact path missing; fallback to run-relative job directory would be used.", - ) + if not results_path.exists(): + kind = "error" if strict else "warning" + issues.append( + ManifestValidationIssue( + run_id=model.run_id, + job_id=entry.job_id, + kind=kind, + message=f"Missing results.jsonl at {results_path}", ) - if not results_path.exists(): + ) + if results_path.exists(): + for message in _quick_validate_results_jsonl( + results_path, + num_examples=entry.num_examples, + rollouts_per_example=entry.rollouts_per_example, + ): kind = "error" if strict else "warning" issues.append( ManifestValidationIssue( run_id=model.run_id, job_id=entry.job_id, kind=kind, - message=f"Missing results.jsonl at {results_path}", + message=message, ) ) - if entry.row_count is not None and results_path.exists(): - row_count = count_jsonl_rows(results_path) - if row_count is not None and int(row_count) != int(entry.row_count): - kind = "error" if strict else "warning" - issues.append( - ManifestValidationIssue( - run_id=model.run_id, - job_id=entry.job_id, - kind=kind, - message=f"row_count mismatch: manifest={entry.row_count} actual={row_count}", - ) - ) - # metadata is optional; only flag when declared explicitly in v3. - if entry.metadata_relpath and not metadata_path.exists(): - kind = "error" if strict else "warning" - issues.append( - ManifestValidationIssue( - run_id=model.run_id, - job_id=entry.job_id, - kind=kind, - message=f"Missing metadata.json at {metadata_path}", - ) + if entry.metadata_relpath and not metadata_path.exists(): + kind = "error" if strict else "warning" + issues.append( + ManifestValidationIssue( + run_id=model.run_id, + job_id=entry.job_id, + kind=kind, + message=f"Missing metadata.json at {metadata_path}", ) - return ManifestValidationResult(manifests_checked=manifests_checked, jobs_checked=jobs_checked, issues=issues) + ) + + return ManifestValidationResult(manifests_checked=1, jobs_checked=jobs_checked, issues=issues) def _resolve_job_artifact_paths( @@ -153,6 +244,132 @@ def _resolve_job_artifact_paths( return results_path, metadata_path, used_fallback +def _quick_validate_results_jsonl( + path: Path, + *, + num_examples: int | None, + rollouts_per_example: int | None, +) -> list[str]: + first_line = _read_first_nonempty_line(path) + last_line = _read_last_nonempty_line(path) + if first_line is None or last_line is None: + return [f"results.jsonl at {path} is empty"] + + issues: list[str] = [] + first_payload = _decode_probe_line(first_line, path=path, position="first", issues=issues) + last_payload = _decode_probe_line(last_line, path=path, position="last", issues=issues) + if first_payload is None or last_payload is None: + return issues + + for position, payload in (("first", first_payload), ("last", last_payload)): + if "example_id" not in payload: + issues.append(f"{position} JSONL row in {path} is missing example_id") + _validate_rollout_index( + first_payload, + path=path, + position="first", + rollouts_per_example=rollouts_per_example, + issues=issues, + ) + _validate_rollout_index( + last_payload, + path=path, + position="last", + rollouts_per_example=rollouts_per_example, + issues=issues, + ) + + return issues + + +def _decode_probe_line( + raw_line: str, + *, + path: Path, + position: str, + issues: list[str], +) -> Mapping[str, Any] | None: + try: + payload = json.loads(raw_line) + except json.JSONDecodeError as exc: + issues.append(f"failed to parse {position} JSONL row in {path}: {exc.msg}") + return None + if not isinstance(payload, Mapping): + issues.append(f"{position} JSONL row in {path} is not a JSON object") + return None + return payload + + +def _read_first_nonempty_line(path: Path) -> str | None: + with path.open("r", encoding="utf-8") as handle: + for line in handle: + candidate = line.strip() + if candidate: + return candidate + return None + + +def _read_last_nonempty_line(path: Path) -> str | None: + with path.open("rb") as handle: + handle.seek(0, os.SEEK_END) + file_size = handle.tell() + if file_size <= 0: + return None + + chunk_size = 8192 + buffer = b"" + position = file_size + while position > 0: + read_size = min(chunk_size, position) + position -= read_size + handle.seek(position) + buffer = handle.read(read_size) + buffer + lines = buffer.splitlines() + for raw_line in reversed(lines): + candidate = raw_line.strip() + if candidate: + return candidate.decode("utf-8") + return None + + +def _validate_rollout_index( + payload: Mapping[str, Any], + *, + path: Path, + position: str, + rollouts_per_example: int | None, + issues: list[str], +) -> None: + rollout_index = _coerce_int(payload.get("rollout_index")) + if rollout_index is None: + return + if rollout_index < 0: + issues.append(f"{position} JSONL row in {path} has negative rollout_index={payload.get('rollout_index')!r}") + return + if rollouts_per_example and rollout_index >= rollouts_per_example: + issues.append( + f"{position} JSONL row in {path} has out-of-range rollout_index={payload.get('rollout_index')!r}; " + f"expected < {rollouts_per_example}" + ) + + +def _coerce_int(value: Any) -> int | None: + if value is None or isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float): + if value.is_integer(): + return int(value) + return None + if isinstance(value, str): + try: + return int(value.strip()) + except ValueError: + return None + return None + + def format_validation_issues(issues: Sequence[ManifestValidationIssue]) -> list[str]: lines: list[str] = [] for issue in issues: diff --git a/medarc_verifiers/cli/hf/__init__.py b/medarc_verifiers/cli/hf/__init__.py index 11e0a6aa..47009eb4 100644 --- a/medarc_verifiers/cli/hf/__init__.py +++ b/medarc_verifiers/cli/hf/__init__.py @@ -3,6 +3,8 @@ from .sync import ( # noqa: F401 HFSyncConfig, HFSyncSummary, + collect_changed_output_files, + compute_pending_parquet_uploads, download_hf_repo, sync_files_to_hub, sync_to_hub, @@ -11,6 +13,8 @@ __all__ = [ "HFSyncConfig", "HFSyncSummary", + "collect_changed_output_files", + "compute_pending_parquet_uploads", "sync_files_to_hub", "sync_to_hub", "download_hf_repo", diff --git a/medarc_verifiers/cli/hf/sync.py b/medarc_verifiers/cli/hf/sync.py index 9f462f9d..44db7314 100644 --- a/medarc_verifiers/cli/hf/sync.py +++ b/medarc_verifiers/cli/hf/sync.py @@ -2,12 +2,15 @@ from __future__ import annotations +import hashlib import logging import tempfile import time from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence + +from medarc_verifiers.utils.pathing import resolve_under if TYPE_CHECKING: from medarc_verifiers.cli.process.writer import EnvWriteSummary @@ -60,6 +63,29 @@ def _is_repo_not_found_error(exc: BaseException) -> bool: return False +def _status_code_from_exc(exc: BaseException) -> int | None: + response = getattr(exc, "response", None) + status_code = getattr(response, "status_code", None) + if status_code is None: + status_code = getattr(exc, "status_code", None) + try: + return int(status_code) if status_code is not None else None + except Exception: + return None + + +def _is_transient_hf_error(exc: BaseException) -> bool: + status_code = _status_code_from_exc(exc) + if status_code == 429 or (status_code is not None and 500 <= status_code < 600): + return True + try: + import httpx # type: ignore[import-not-found] + + return isinstance(exc, (httpx.TimeoutException, httpx.TransportError)) + except Exception: + return False + + def _confirm_create_repo( *, repo_id: str, @@ -153,6 +179,181 @@ class HFSyncSummary: files: Sequence[str] +def _local_sha256(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as handle: + for chunk in iter(lambda: handle.read(1024 * 1024), b""): + digest.update(chunk) + return digest.hexdigest() + + +def _repo_tree_entry_path(entry: Any) -> str | None: + for attr in ("path", "rfilename"): + value = getattr(entry, attr, None) + if isinstance(value, str) and value.strip(): + return Path(value).as_posix() + if isinstance(entry, dict): + value = entry.get("path") or entry.get("rfilename") + if isinstance(value, str) and value.strip(): + return Path(value).as_posix() + return None + + +def _repo_tree_entry_lfs_sha256(entry: Any) -> str | None: + lfs = getattr(entry, "lfs", None) + if lfs is None and isinstance(entry, dict): + lfs = entry.get("lfs") + if isinstance(lfs, dict): + sha256 = lfs.get("sha256") + return str(sha256) if sha256 else None + sha256 = getattr(lfs, "sha256", None) + return str(sha256) if sha256 else None + + +def _normalize_output_files(output_dir: Path, files: Iterable[str | Path]) -> list[str]: + normalized: list[str] = [] + for path in files: + candidate = Path(path) + if candidate.is_absolute(): + try: + rel_path = candidate.relative_to(output_dir) + except ValueError: + continue + else: + # Accept caller inputs like "runs/processed/foo.parquet" when output_dir is also relative. + output_parts = output_dir.parts + if output_parts and candidate.parts[: len(output_parts)] == output_parts: + try: + rel_path = candidate.relative_to(output_dir) + except ValueError: + continue + else: + rel_path = candidate + rel_text = rel_path.as_posix() + if rel_text: + normalized.append(rel_text) + return sorted(set(normalized)) + + +def _prepare_upload_file_entries(output_dir: Path, files: Sequence[str | Path]) -> list[tuple[str, Path]]: + output_dir = output_dir.resolve() + prepared: list[tuple[str, Path]] = [] + seen: set[str] = set() + for path in files: + candidate = Path(path) + raw_text = candidate.as_posix() + if not raw_text: + continue + if candidate.is_absolute(): + try: + rel_path = candidate.resolve().relative_to(output_dir).as_posix() + except ValueError as exc: + raise ValueError(f"Upload file path must be under output_dir: {candidate}") from exc + else: + resolved = resolve_under(output_dir, raw_text) + if resolved is None: + raise ValueError(f"Upload file path must be relative to output_dir without traversal: {raw_text!r}") + try: + rel_path = resolved.resolve().relative_to(output_dir).as_posix() + except ValueError as exc: + raise ValueError(f"Upload file path resolves outside output_dir: {raw_text!r}") from exc + local_path = (output_dir / rel_path).resolve() + try: + local_path.relative_to(output_dir) + except ValueError as exc: + raise ValueError(f"Upload file path resolves outside output_dir: {raw_text!r}") from exc + if rel_path in seen: + continue + prepared.append((rel_path, local_path)) + seen.add(rel_path) + return prepared + + +def collect_changed_output_files( + env_summaries: Sequence[EnvWriteSummary], + *, + output_dir: Path, + metadata_paths: Sequence[Path] | None = None, +) -> list[str]: + changed_paths = {summary.output_path for summary in env_summaries if summary.changed} + if metadata_paths: + for path in metadata_paths: + candidate = Path(path) + if not candidate.is_absolute(): + output_parts = output_dir.parts + if output_parts and candidate.parts[: len(output_parts)] != output_parts: + candidate = output_dir / candidate + changed_paths.add(candidate) + return _normalize_output_files(output_dir, changed_paths) + + +def _collect_changed_output_files( + env_summaries: Sequence[EnvWriteSummary], + *, + output_dir: Path, + metadata_paths: Sequence[Path] | None = None, +) -> list[str]: + return collect_changed_output_files(env_summaries, output_dir=output_dir, metadata_paths=metadata_paths) + + +def compute_pending_parquet_uploads( + output_dir: Path, + repo_id: str, + branch: str | None, + token: str | None, +) -> set[str]: + """Return local parquet paths that are missing remotely or differ from remote lfs.sha256.""" + output_dir = Path(output_dir) + local_parquets = sorted(path for path in output_dir.rglob("*.parquet") if path.is_file()) + if not local_parquets: + return set() + + try: + from huggingface_hub import HfApi # type: ignore[import-not-found] + except Exception as exc: # noqa: BLE001 + raise ImportError("huggingface_hub is required for HF upload recovery.") from exc + + api = HfApi(token=token) + list_kwargs = { + "repo_id": repo_id, + "repo_type": "dataset", + "revision": branch, + "recursive": True, + "expand": True, + } + try: + try: + tree_entries = list(api.list_repo_tree(**list_kwargs)) + except TypeError as exc: + if "expand" not in str(exc): + raise + list_kwargs.pop("expand", None) + tree_entries = list(api.list_repo_tree(**list_kwargs)) + except Exception as exc: # noqa: BLE001 + if _is_repo_not_found_error(exc): + tree_entries = [] + else: + raise + + remote_parquets: dict[str, str | None] = {} + for entry in tree_entries: + rel_path = _repo_tree_entry_path(entry) + if not rel_path or not rel_path.endswith(".parquet"): + continue + remote_parquets[rel_path] = _repo_tree_entry_lfs_sha256(entry) + + pending: set[str] = set() + for parquet_path in local_parquets: + rel_path = parquet_path.relative_to(output_dir).as_posix() + if rel_path not in remote_parquets: + pending.add(rel_path) + continue + remote_sha256 = remote_parquets[rel_path] + if remote_sha256 is None or remote_sha256 != _local_sha256(parquet_path): + pending.add(rel_path) + return pending + + def sync_files_to_hub( *, repo_id: str, @@ -166,25 +367,27 @@ def sync_files_to_hub( request_timeout_s: float | None = None, retries: int = 3, max_files_per_commit: int | None = None, + path_in_repo_prefix: str | None = None, is_tty: bool = False, assume_yes: bool = False, prompt_func: Callable[[str], str] | None = None, -) -> None: - """Upload explicit file paths from output_dir to a HF dataset repo.""" +) -> bool: + """Upload explicit file paths from output_dir to a HF dataset repo. + + Returns False only when upload is skipped because repo creation was declined. + """ if not repo_id: logger.debug("HF sync skipped: no repo_id provided.") - return - file_list = [] - for path in files: - rel_path = Path(path).as_posix() if not isinstance(path, str) else Path(path).as_posix() - if rel_path: - file_list.append(rel_path) + return True + output_dir = Path(output_dir) + prepared_files = _prepare_upload_file_entries(output_dir, files) + file_list = [rel_path for rel_path, _ in prepared_files] if not file_list: logger.debug("HF sync skipped: no files provided.") - return + return True if dry_run: logger.debug("HF sync dry-run; skipping push.") - return + return True try: from huggingface_hub import CommitOperationAdd, HfApi # type: ignore[import-not-found] @@ -195,6 +398,9 @@ def sync_files_to_hub( _configure_hf_http_timeout(float(request_timeout_s)) api = HfApi(token=token) + repo_prefix = _normalize_repo_path_prefix(path_in_repo_prefix) + + file_map = dict(prepared_files) if max_files_per_commit is None or max_files_per_commit <= 0: batches = [file_list] @@ -203,11 +409,12 @@ def sync_files_to_hub( file_list[index : index + max_files_per_commit] for index in range(0, len(file_list), max_files_per_commit) ] - output_dir = Path(output_dir) - for batch_index, batch_files in enumerate(batches, start=1): operations = [ - CommitOperationAdd(path_in_repo=rel_path, path_or_fileobj=str(output_dir / rel_path)) + CommitOperationAdd( + path_in_repo=_join_repo_path(repo_prefix, rel_path), + path_or_fileobj=str(file_map[rel_path]), + ) for rel_path in batch_files ] commit_message = message @@ -234,9 +441,11 @@ def sync_files_to_hub( prompt_func=prompt_func, ) if not should_create: - raise RuntimeError( - f"HF dataset repo '{repo_id}' not found. Create it on the Hub or re-run with --yes to allow creation." - ) from exc + logger.warning( + "HF dataset repo '%s' not found; skipping upload because repo creation was declined.", + repo_id, + ) + return False api.create_repo( repo_id=repo_id, repo_type="dataset", @@ -245,13 +454,7 @@ def sync_files_to_hub( ) # Retry the commit immediately after repo creation. continue - try: - import httpx # type: ignore[import-not-found] - - is_retryable = isinstance(exc, (httpx.TimeoutException, httpx.TransportError)) - except Exception: - is_retryable = False - if not is_retryable or attempt >= int(retries): + if not _is_transient_hf_error(exc) or attempt >= int(retries): raise delay = _sleep_backoff_seconds(attempt) logger.warning( @@ -262,6 +465,27 @@ def sync_files_to_hub( delay, ) time.sleep(delay) + return True + + +def _normalize_repo_path_prefix(value: str | None) -> str | None: + if value is None: + return None + raw = str(value).strip().replace("\\", "/").strip("/") + if not raw: + return None + candidate = resolve_under(Path("."), raw) + if candidate is None: + raise ValueError(f"Invalid path_in_repo_prefix: {value!r}") + normalized = candidate.as_posix().lstrip("./") + return normalized or None + + +def _join_repo_path(prefix: str | None, rel_path: str) -> str: + rel = rel_path.strip().replace("\\", "/").lstrip("/") + if not prefix: + return rel + return f"{prefix}/{rel}" if rel else prefix def sync_to_hub( @@ -270,6 +494,7 @@ def sync_to_hub( *, output_dir: Path, metadata_paths: Sequence[Path] | None = None, + files: Sequence[str | Path] | None = None, is_tty: bool = False, assume_yes: bool = False, prompt_func: Callable[[str], str] | None = None, @@ -278,37 +503,27 @@ def sync_to_hub( if not config.repo_id: logger.debug("HF sync skipped: no repo_id provided.") return None - if not env_summaries: - logger.debug("HF sync skipped: no environment summaries available.") - return None - if all(summary.dry_run for summary in env_summaries): - logger.debug("HF sync skipped: only dry-run summaries available.") + if config.dry_run: + logger.debug("HF sync dry-run; skipping summary generation and upload.") return None + output_dir = Path(output_dir) changed = [summary for summary in env_summaries if summary.changed] - if not changed: - logger.debug("HF sync skipped: no changed outputs.") - return None + if files is None: + if not env_summaries: + logger.debug("HF sync skipped: no environment summaries available.") + return None + if all(summary.dry_run for summary in env_summaries): + logger.debug("HF sync skipped: only dry-run summaries available.") + return None + files = collect_changed_output_files(env_summaries, output_dir=output_dir, metadata_paths=metadata_paths) + else: + files = _normalize_output_files(output_dir, files) - output_dir = Path(output_dir) - changed_paths = {summary.output_path for summary in changed} - if metadata_paths: - for path in metadata_paths: - candidate = Path(path) - if not candidate.is_absolute(): - output_parts = output_dir.parts - if output_parts and candidate.parts[: len(output_parts)] != output_parts: - candidate = output_dir / candidate - changed_paths.add(candidate) + if not files: + logger.debug("HF sync skipped: no files selected for upload.") + return None - files = [] - for path in changed_paths: - try: - rel_path = path.relative_to(output_dir) - except ValueError: - continue - files.append(rel_path.as_posix()) - files = sorted(set(files)) summary = HFSyncSummary( repo_id=config.repo_id, strategy="file", @@ -318,7 +533,7 @@ def sync_to_hub( ) message = f"Update {summary.total_files} file(s) from medarc-eval process" - sync_files_to_hub( + uploaded = sync_files_to_hub( repo_id=config.repo_id, output_dir=output_dir, files=files, @@ -334,6 +549,8 @@ def sync_to_hub( assume_yes=assume_yes, prompt_func=prompt_func, ) + if not uploaded: + return None return summary @@ -383,6 +600,8 @@ def download_hf_repo( __all__ = [ "HFSyncSummary", "HFSyncConfig", + "collect_changed_output_files", + "compute_pending_parquet_uploads", "sync_files_to_hub", "sync_to_hub", ] diff --git a/medarc_verifiers/cli/main.py b/medarc_verifiers/cli/main.py index 72d0a20e..97ca6e50 100644 --- a/medarc_verifiers/cli/main.py +++ b/medarc_verifiers/cli/main.py @@ -4,6 +4,7 @@ import argparse import logging +import os import sys from pathlib import Path from textwrap import dedent @@ -25,7 +26,6 @@ DEFAULT_ENV_DIR, DEFAULT_PROCESSED_DIR, DEFAULT_RUNS_RAW_DIR, - DEFAULT_WINRATE_DIR, PROCESS_COMMAND, WINRATE_COMMAND, ) @@ -33,11 +33,10 @@ from medarc_verifiers.cli._job_executor import ExecutorSettings, JobExecutionResult, execute_jobs from medarc_verifiers.cli._manifest import MANIFEST_FILENAME, ManifestJobEntry, RunManifest, compute_snapshot_checksum from medarc_verifiers.cli._manifest_planner import ManifestPlanner -from medarc_verifiers.cli._manifest_tools import format_validation_issues, validate_manifests_in_runs from medarc_verifiers.cli._schemas import EnvironmentConfigSchema, EnvironmentExportConfig from medarc_verifiers.cli._single_run import run_single_mode from medarc_verifiers.cli.hf import HFSyncConfig, sync_files_to_hub -from medarc_verifiers.cli.process import ProcessOptions, ProcessResult, run_process +from medarc_verifiers.cli.process import PROCESS_DEFAULT_STATUS_FILTER, ProcessOptions, ProcessResult, run_process from medarc_verifiers.cli.utils.config_io import load_mapping_file from medarc_verifiers.cli.utils.overrides import build_cli_override from medarc_verifiers.cli.utils.shared import ( @@ -47,6 +46,7 @@ slugify, validate_simple_name, ) +from medarc_verifiers.utils.pathing import resolve_under from medarc_verifiers.cli.winrate import ( WinrateConfig, _resolve_source, @@ -287,29 +287,33 @@ def build_process_parser() -> argparse.ArgumentParser: parser.add_argument("--processed-at", default=None, help="Override processed_at timestamp (ISO8601).") parser.add_argument("--dry-run", action="store_true", default=None, help="Plan processing without writing outputs.") parser.add_argument( - "--validate-manifest", - action=argparse.BooleanOptionalAction, + "--replace-model", + action="append", default=None, - help="Validate run manifests before processing (default: enabled).", + help="Rebuild existing processed outputs for these model ids (repeatable; comma-separated values allowed).", ) parser.add_argument( - "--strict-manifest", - action="store_true", + "--replace-env", + action="append", default=None, - help="Treat manifest validation problems as errors.", + help="Rebuild existing processed outputs for these env ids (repeatable; comma-separated values allowed).", ) parser.add_argument( - "--process-incomplete", - dest="process_incomplete", - action="store_true", + "--max-results-missing-pct", + type=float, default=None, - help="Include runs where run_manifest.json summary has completed < total.", + help=( + "Fail if a selected latest job record is missing more than this percentage of expected results.jsonl rows " + "based on manifest job fields (row_count, num_examples, rollouts_per_example). " + "Computed per selected job record and enforced only on the latest selected run; does not use " + "manifest summary.completed/summary.total or fall back to older runs (default: 2.5)." + ), ) parser.add_argument( "--winrate", type=Path, default=None, - help="Run winrate after processing using the provided winrate config file.", + help="Run winrate after processing using the provided config file. If omitted, an embedded winrate section in --config is used.", ) parser.add_argument( "--max-workers", @@ -321,7 +325,7 @@ def build_process_parser() -> argparse.ArgumentParser: parser.add_argument("--hf-repo", default=None, help="Hugging Face repo id for dataset sync.") parser.add_argument( "--hf-pull-policy", - choices=("prompt", "pull", "clean"), + choices=("prompt", "pull", "clean", "continue-upload"), default=None, help="Baseline policy when output dir is non-empty in HF mode.", ) @@ -376,7 +380,7 @@ def build_winrate_parser() -> argparse.ArgumentParser: "--output-dir", type=Path, default=None, - help=f"Directory to store winrate outputs (default: {DEFAULT_WINRATE_DIR}).", + help="Directory to store winrate outputs (default: /winrate).", ) parser.add_argument( "--output", @@ -465,7 +469,7 @@ def build_winrate_parser() -> argparse.ArgumentParser: "per-model uses the legacy behavior where each model may be averaged over a different dataset set." ), ) - parser.add_argument("--hf-processed-repo", help="Hugging Face repo id for processed dataset download.") + parser.add_argument("--hf-repo", help="Hugging Face repo id used for processed download and winrate upload.") parser.add_argument( "--hf-processed-pull", action="store_true", @@ -474,7 +478,11 @@ def build_winrate_parser() -> argparse.ArgumentParser: ) parser.add_argument("--hf-branch", help="Target HF branch or revision for processed download.") parser.add_argument("--hf-token", help="Auth token for HF operations.") - parser.add_argument("--hf-winrate-repo", help="Hugging Face repo id for winrate artifact upload.") + parser.add_argument( + "--hf-winrate-dir", + default=None, + help="Path under the HF repo where winrate artifacts are uploaded (default: winrate).", + ) parser.add_argument( "--hf-private", action=argparse.BooleanOptionalAction, @@ -561,33 +569,61 @@ def _run_batch_mode(argv: Sequence[str]) -> int: def _run_process_mode(argv: Sequence[str]) -> int: + parser, args = _resolve_process_args(argv) + winrate_args = _resolve_embedded_winrate(args, parser=parser) + + try: + env_export_map = _load_env_export_map(args.env_config_root) + except Exception as exc: # noqa: BLE001 + logger.warning("Failed to load environment export configs: %s", exc) + env_export_map = {} + + options = _build_process_options(args) + + try: + result = run_process(options, env_export_map=env_export_map) + except Exception as exc: # noqa: BLE001 + logger.exception("Process pipeline failed: %s", exc) + return 1 + + _log_process_result(result) + return _run_process_post_steps(args, parser=parser, options=options, winrate_args=winrate_args) + + +def _resolve_process_args(argv: Sequence[str]) -> tuple[argparse.ArgumentParser, argparse.Namespace]: parser = build_process_parser() args = parser.parse_args(argv) if args.config: _load_and_apply_config(args, args.config, mode="process", parser=parser) _finalize_config_args(args, mode="process") + _validate_process_args(args, argv=argv, parser=parser) + return parser, args + + +def _validate_process_args( + args: argparse.Namespace, + *, + argv: Sequence[str], + parser: argparse.ArgumentParser, +) -> None: + for flag, attr in (("--replace-model", "replace_model"), ("--replace-env", "replace_env")): + if _option_was_provided(argv, flag) and not getattr(args, attr, None): + parser.error(f"{flag} requires at least one non-empty value.") try: if args.exclude_dataset: normalize_dataset_ids(args.exclude_dataset, label="process exclude dataset") if args.exclude_model: normalize_model_ids(args.exclude_model, label="process exclude model") + if args.max_results_missing_pct is not None: + value = float(args.max_results_missing_pct) + if value < 0: + parser.error("--max-results-missing-pct must be non-negative.") except ValueError as exc: parser.error(str(exc)) - winrate_args: argparse.Namespace | None = None - if args.winrate: - winrate_path = Path(args.winrate).expanduser() - if not winrate_path.exists(): - parser.error(f"Winrate config path '{winrate_path}' does not exist.") - args.winrate = winrate_path - winrate_args = _build_winrate_args_from_config(winrate_path, parser=parser) - try: - env_export_map = _load_env_export_map(args.env_config_root) - except Exception as exc: # noqa: BLE001 - logger.warning("Failed to load environment export configs: %s", exc) - env_export_map = {} +def _build_process_options(args: argparse.Namespace) -> ProcessOptions: hf_config = HFSyncConfig.from_cli( repo=args.hf_repo, branch=args.hf_branch, @@ -598,16 +634,18 @@ def _run_process_mode(argv: Sequence[str]) -> int: retries=args.hf_retries, max_files_per_commit=args.hf_max_files_per_commit, ) - + status_values = list(args.status or []) + status_filter = tuple(status_values) if status_values else PROCESS_DEFAULT_STATUS_FILTER + max_results_missing_pct = float(args.max_results_missing_pct) if args.max_results_missing_pct is not None else 2.5 processed_with_args = { - "status": args.status or [], + "status": list(status_filter), + "max_results_missing_pct": max_results_missing_pct, "exclude_datasets": args.exclude_dataset or [], "exclude_models": args.exclude_model or [], + "replace_models": args.replace_model or [], + "replace_envs": args.replace_env or [], "dry_run": bool(args.dry_run), "clean": bool(args.clean), - "validate_manifest": bool(args.validate_manifest), - "strict_manifest": bool(args.strict_manifest), - "only_complete_runs": not bool(args.process_incomplete), "hf_repo": args.hf_repo, "hf_pull_policy": args.hf_pull_policy, "hf_request_timeout": args.hf_request_timeout, @@ -615,16 +653,17 @@ def _run_process_mode(argv: Sequence[str]) -> int: "hf_max_files_per_commit": args.hf_max_files_per_commit, "max_workers": args.max_workers, } - - options = ProcessOptions( + return ProcessOptions( runs_dir=args.runs_dir, output_dir=args.output_dir, exclude_datasets=tuple(args.exclude_dataset or ()), exclude_models=tuple(args.exclude_model or ()), + replace_models=tuple(args.replace_model or ()), + replace_envs=tuple(args.replace_env or ()), processed_at=args.processed_at, processed_with_args=processed_with_args, - status_filter=args.status or (), - only_complete_runs=not bool(args.process_incomplete), + status_filter=status_filter, + max_results_missing_pct=max_results_missing_pct, dry_run=bool(args.dry_run), clean=bool(args.clean), assume_yes=bool(args.yes), @@ -633,80 +672,94 @@ def _run_process_mode(argv: Sequence[str]) -> int: max_workers=args.max_workers, ) - if args.validate_manifest: - validation = validate_manifests_in_runs(options.runs_dir, strict=bool(args.strict_manifest)) - for line in format_validation_issues(validation.issues): - if line.startswith("[ERROR]"): - logger.error("%s", line) - else: - logger.warning("%s", line) - logger.info( - "Manifest preflight: checked %d manifest(s), %d job(s), %d issue(s).", - validation.manifests_checked, - validation.jobs_checked, - len(validation.issues), - ) - if validation.has_errors: - logger.error("Manifest validation failed in strict mode; aborting process.") - return 1 +def _resolve_embedded_winrate( + args: argparse.Namespace, + *, + parser: argparse.ArgumentParser, +) -> argparse.Namespace | None: + embedded_winrate = False + if args.config and args.winrate is None: + try: + embedded_winrate = _config_has_embedded_winrate(Path(args.config).expanduser()) + except (FileNotFoundError, ValueError) as exc: + parser.error(str(exc)) + + if args.winrate: + winrate_path = Path(args.winrate).expanduser() + if not winrate_path.exists(): + parser.error(f"Winrate config path '{winrate_path}' does not exist.") + args.winrate = winrate_path + return _build_winrate_args_from_config(winrate_path, parser=parser) + + if embedded_winrate: + args.winrate = Path(args.config).expanduser() + return _build_winrate_args_from_config(Path(args.config).expanduser(), parser=parser) + return None + + +def _run_process_post_steps( + args: argparse.Namespace, + *, + parser: argparse.ArgumentParser, + options: ProcessOptions, + winrate_args: argparse.Namespace | None, +) -> int: + if not args.winrate: + return 0 + if options.dry_run: + logger.info("Skipping winrate post-step for dry-run process.") + return 0 + + if winrate_args is None: + winrate_args = _build_winrate_args_from_config(Path(args.winrate), parser=parser) + winrate_args.processed_dir = options.output_dir + if not getattr(winrate_args, "_output_dir_explicit", False): + winrate_args.output_dir = _default_winrate_output_dir(options.output_dir) + winrate_args.hf_repo = None + winrate_args.hf_processed_pull = False + + winrate_cfg = WinrateConfig( + missing_policy=winrate_args.missing_policy, + epsilon=winrate_args.epsilon, + min_common=winrate_args.min_common, + weight_policy=winrate_args.weight_policy, + weight_cap=winrate_args.weight_cap, + dataset_coverage=winrate_args.dataset_coverage, + include_models=tuple(winrate_args.include_model or ()), + exclude_models=tuple(winrate_args.exclude_model or ()), + exclude_datasets=tuple(winrate_args.exclude_dataset or ()), + partial_datasets=winrate_args.partial_datasets, + ) try: - result = run_process(options, env_export_map=env_export_map) + winrate_result = run_winrate( + processed_dir=options.output_dir, + output_dir=winrate_args.output_dir, + output_path=winrate_args.output, + output_name=winrate_args.output_name, + config=winrate_cfg, + processed_at=winrate_args.processed_at, + hf_config=None, + hf_processed_pull=False, + ) except Exception as exc: # noqa: BLE001 - logger.exception("Process pipeline failed: %s", exc) + logger.exception("Win rate computation failed: %s", exc) return 1 - _log_process_result(result) + logger.info("Computed win rates for %d dataset(s): %s", len(winrate_result.datasets), winrate_result.output_path) + print_winrate_summary_markdown(winrate_result.result) - if args.winrate: - if options.dry_run: - logger.info("Skipping winrate post-step for dry-run process.") - return 0 - if winrate_args is None: - winrate_args = _build_winrate_args_from_config(Path(args.winrate), parser=parser) - winrate_args.processed_dir = options.output_dir - winrate_args.hf_processed_repo = None - winrate_args.hf_processed_pull = False - winrate_cfg = WinrateConfig( - missing_policy=winrate_args.missing_policy, - epsilon=winrate_args.epsilon, - min_common=winrate_args.min_common, - weight_policy=winrate_args.weight_policy, - weight_cap=winrate_args.weight_cap, - dataset_coverage=winrate_args.dataset_coverage, - include_models=tuple(winrate_args.include_model or ()), - exclude_models=tuple(winrate_args.exclude_model or ()), - exclude_datasets=tuple(winrate_args.exclude_dataset or ()), - partial_datasets=winrate_args.partial_datasets, - ) - try: - winrate_result = run_winrate( - processed_dir=options.output_dir, - output_dir=winrate_args.output_dir, - output_path=winrate_args.output, - output_name=winrate_args.output_name, - config=winrate_cfg, - processed_at=winrate_args.processed_at, - hf_config=None, - hf_processed_pull=False, - ) - except Exception as exc: # noqa: BLE001 - logger.exception("Win rate computation failed: %s", exc) - return 1 - logger.info( - "Computed win rates for %d dataset(s): %s", len(winrate_result.datasets), winrate_result.output_path + if options.hf_config and options.hf_config.repo_id: + _upload_winrate_outputs( + output_dir=winrate_args.output_dir, + output_paths=winrate_result.output_paths, + repo_id=options.hf_config.repo_id, + token=options.hf_config.token, + branch=options.hf_config.branch, + private=bool(options.hf_config.private), + winrate_dir=winrate_args.hf_winrate_dir, + assume_yes=bool(args.yes), ) - print_winrate_summary_markdown(winrate_result.result) - if winrate_args.hf_winrate_repo: - _upload_winrate_outputs( - output_dir=winrate_args.output_dir, - output_paths=winrate_result.output_paths, - repo_id=winrate_args.hf_winrate_repo, - token=winrate_args.hf_token, - private=bool(winrate_args.hf_private), - assume_yes=bool(args.yes), - ) - return 0 @@ -744,18 +797,199 @@ def _set_if_unset(args: argparse.Namespace, attr: str, value: Any) -> None: setattr(args, attr, value) +def _resolve_config_string_value(key: str, value: Any) -> str: + resolved = str(value) + if key != "hf_token": + return resolved + + trimmed = resolved.strip() + env_var: str | None = None + if trimmed.startswith("${") and trimmed.endswith("}") and len(trimmed) > 3: + env_var = trimmed[2:-1].strip() + elif trimmed.startswith("$") and len(trimmed) > 1: + env_var = trimmed[1:].strip() + + if not env_var: + return resolved + + env_value = os.getenv(env_var) + if env_value is None: + raise ValueError(f"Config field 'hf.token' references unset environment variable '{env_var}'.") + return env_value + + def _load_config_payload(path: Path, *, mode: Literal["process", "winrate"]) -> dict[str, Any]: label = "Process config" if mode == "process" else "Winrate config" - return dict(load_mapping_file(path, label=label)) + raw_payload = dict(load_mapping_file(path, label=label)) + if mode == "process": + _reject_removed_process_config_keys(raw_payload) + return _expand_embedded_pipeline_config(raw_payload, mode=mode) + + +def _reject_removed_process_config_keys(payload: Mapping[str, Any]) -> None: + if "max_run_missing_pct" in payload: + raise ValueError("Process config field 'max_run_missing_pct' was removed; use 'max_results_missing_pct'.") + process_section = payload.get("process") + if isinstance(process_section, Mapping) and "max_run_missing_pct" in process_section: + raise ValueError( + "Process config field 'process.max_run_missing_pct' was removed; use 'process.max_results_missing_pct'." + ) + + +def _expand_embedded_pipeline_config(payload: dict[str, Any], *, mode: Literal["process", "winrate"]) -> dict[str, Any]: + expanded = dict(payload) + process_section = payload.get("process") + if isinstance(process_section, Mapping): + _merge_process_section(expanded, process_section, mode=mode) + + process_output_dir = _resolve_processed_dir_from_payload(expanded, mode=mode) + + winrate_section = payload.get("winrate") + if isinstance(winrate_section, Mapping): + if mode == "process": + expanded.pop("winrate", None) + if mode == "winrate": + _merge_winrate_section(expanded, winrate_section, process_output_dir=process_output_dir) + elif isinstance(winrate_section, bool) and mode == "process": + expanded.pop("winrate", None) + + if mode == "winrate" and "processed_dir" not in expanded and process_output_dir is not None: + expanded["processed_dir"] = process_output_dir + + return expanded + + +def _merge_process_section( + expanded: dict[str, Any], + process_section: Mapping[str, Any], + *, + mode: Literal["process", "winrate"], +) -> None: + resolved = None + if "dir" in process_section: + resolved = _resolve_process_dir_value(process_section["dir"], runs_dir=expanded.get("runs_dir")) + if mode == "process" and "output_dir" not in expanded and resolved is not None: + expanded["output_dir"] = resolved + if mode == "winrate" and "processed_dir" not in expanded and resolved is not None: + expanded["processed_dir"] = resolved + if mode == "winrate" and "processed_dir" not in expanded and "output_dir" in process_section: + expanded["processed_dir"] = process_section["output_dir"] + key_map = {"runs_dir": "runs_dir"} + if mode == "process": + key_map.update( + { + "output_dir": "output_dir", + "env_config_root": "env_config_root", + "processed_at": "processed_at", + "status": "status", + "exclude_datasets": "exclude_datasets", + "exclude_models": "exclude_models", + "replace_models": "replace_models", + "replace_envs": "replace_envs", + "dry_run": "dry_run", + "clean": "clean", + "yes": "yes", + "max_workers": "max_workers", + "max_results_missing_pct": "max_results_missing_pct", + } + ) + for key, target in key_map.items(): + if key in process_section and target not in expanded: + expanded[target] = process_section[key] + + +def _merge_winrate_section( + expanded: dict[str, Any], + winrate_section: Mapping[str, Any], + *, + process_output_dir: Path | None, +) -> None: + if "dir" in winrate_section and "output_dir" not in expanded: + resolved = _resolve_winrate_dir_value(winrate_section["dir"], process_output_dir=process_output_dir) + if resolved is not None: + expanded["output_dir"] = resolved + key_map = { + "processed_dir": "processed_dir", + "output_dir": "output_dir", + "output_name": "output_name", + "processed_at": "processed_at", + "missing_policy": "missing_policy", + "epsilon": "epsilon", + "min_common": "min_common", + "weight_policy": "weight_policy", + "weight_cap": "weight_cap", + "dataset_coverage": "dataset_coverage", + "include_model": "include_models", + "include_models": "include_models", + "exclude_model": "exclude_models", + "exclude_models": "exclude_models", + "exclude_dataset": "exclude_datasets", + "exclude_datasets": "exclude_datasets", + "partial_datasets": "partial_datasets", + "hf_processed_pull": "hf_processed_pull", + "hf_winrate_dir": "hf_winrate_dir", + } + for key, target in key_map.items(): + if key in winrate_section and target not in expanded: + expanded[target] = winrate_section[key] + + +def _resolve_processed_dir_from_payload( + payload: Mapping[str, Any], *, mode: Literal["process", "winrate"] +) -> Path | None: + if "processed_dir" in payload and payload["processed_dir"] is not None: + return Path(str(payload["processed_dir"])) + if mode == "process" and "output_dir" in payload and payload["output_dir"] is not None: + return Path(str(payload["output_dir"])) + process_section = payload.get("process") + if isinstance(process_section, Mapping) and "dir" in process_section: + return _resolve_process_dir_value(process_section["dir"], runs_dir=payload.get("runs_dir")) + return None + + +def _resolve_process_dir_value(value: Any, *, runs_dir: Any | None) -> Path | None: + raw = str(value).strip() + if not raw: + return None + candidate = Path(raw) + if candidate.is_absolute(): + return candidate + runs_base = Path(str(runs_dir)).parent if runs_dir is not None else DEFAULT_RUNS_RAW_DIR.parent + return runs_base / candidate + + +def _resolve_winrate_dir_value(value: Any, *, process_output_dir: Path | None) -> Path | None: + raw = str(value).strip() + if not raw: + return None + candidate = Path(raw) + if candidate.is_absolute(): + return candidate + base = process_output_dir if process_output_dir is not None else DEFAULT_PROCESSED_DIR + return base / candidate + + +def _config_has_embedded_winrate(path: Path) -> bool: + payload = dict(load_mapping_file(path, label="Process config")) + winrate_payload = payload.get("winrate") + if isinstance(winrate_payload, Mapping): + return bool(winrate_payload.get("enabled", True)) + return bool(winrate_payload) if isinstance(winrate_payload, bool) else False def _normalize_mode_payload(payload: dict[str, Any], *, mode: Literal["process", "winrate"]) -> None: + if mode == "winrate": + if "hf_processed_repo" in payload and "hf_repo" not in payload: + payload["hf_repo"] = payload["hf_processed_repo"] + if "hf_winrate_repo" in payload: + raise ValueError("Winrate config field 'hf_winrate_repo' was removed; use 'hf.repo' and 'hf.winrate_dir'.") + hf_payload = payload.get("hf") if isinstance(hf_payload, Mapping): for key, value in hf_payload.items(): if mode == "winrate": if key == "repo": - payload.setdefault("hf_processed_repo", value) + payload.setdefault("hf_repo", value) continue if key == "branch": payload.setdefault("hf_branch", value) @@ -766,6 +1000,10 @@ def _normalize_mode_payload(payload: dict[str, Any], *, mode: Literal["process", if key == "private": payload.setdefault("hf_private", value) continue + if key == "winrate_repo": + raise ValueError( + "Winrate config field 'hf.winrate_repo' was removed; use 'hf.repo' and 'hf.winrate_dir'." + ) payload.setdefault(f"hf_{key}", value) if "exclude_datasets" not in payload and "exclude_dataset" in payload: @@ -783,9 +1021,9 @@ def _load_and_apply_config( ) -> None: try: payload = _load_config_payload(path, mode=mode) + _normalize_mode_payload(payload, mode=mode) except (FileNotFoundError, ValueError) as exc: parser.error(str(exc)) - _normalize_mode_payload(payload, mode=mode) path_fields = { "process": { @@ -811,8 +1049,8 @@ def _load_and_apply_config( "weight_policy": "weight_policy", "partial_datasets": "partial_datasets", "dataset_coverage": "dataset_coverage", - "hf_processed_repo": "hf_processed_repo", - "hf_winrate_repo": "hf_winrate_repo", + "hf_repo": "hf_repo", + "hf_winrate_dir": "hf_winrate_dir", "hf_branch": "hf_branch", "hf_token": "hf_token", }, @@ -822,9 +1060,6 @@ def _load_and_apply_config( "dry_run": "dry_run", "clean": "clean", "yes": "yes", - "process_incomplete": "process_incomplete", - "validate_manifest": "validate_manifest", - "strict_manifest": "strict_manifest", "hf_private": "hf_private", }, "winrate": {"hf_processed_pull": "hf_processed_pull", "hf_private": "hf_private"}, @@ -838,11 +1073,20 @@ def _load_and_apply_config( "winrate": {"min_common": "min_common", "weight_cap": "weight_cap"}, }[mode] float_fields = { - "process": {"hf_request_timeout": "hf_request_timeout"}, + "process": { + "hf_request_timeout": "hf_request_timeout", + "max_results_missing_pct": "max_results_missing_pct", + }, "winrate": {"epsilon": "epsilon"}, }[mode] repeatable_fields = { - "process": {"status": "status", "exclude_datasets": "exclude_dataset", "exclude_models": "exclude_model"}, + "process": { + "status": "status", + "exclude_datasets": "exclude_dataset", + "exclude_models": "exclude_model", + "replace_models": "replace_model", + "replace_envs": "replace_env", + }, "winrate": { "include_models": "include_model", "exclude_models": "exclude_model", @@ -855,7 +1099,11 @@ def _load_and_apply_config( _set_if_unset(args, attr, Path(str(payload[key]))) for key, attr in string_fields.items(): if key in payload and _is_unset(args, attr): - _set_if_unset(args, attr, str(payload[key])) + try: + resolved = _resolve_config_string_value(key, payload[key]) + except ValueError as exc: + parser.error(str(exc)) + _set_if_unset(args, attr, resolved) for key, attr in boolean_fields.items(): if key in payload and _is_unset(args, attr): _set_if_unset(args, attr, bool(payload[key])) @@ -891,14 +1139,15 @@ def _build_winrate_args_from_config(path: Path, *, parser: argparse.ArgumentPars exclude_model=None, exclude_dataset=None, partial_datasets=None, - hf_processed_repo=None, + hf_repo=None, hf_processed_pull=None, - hf_winrate_repo=None, + hf_winrate_dir=None, hf_branch=None, hf_token=None, hf_private=None, ) _load_and_apply_config(args, path, mode="winrate", parser=parser) + args._output_dir_explicit = args.output_dir is not None _finalize_config_args(args, mode="winrate") return args @@ -915,15 +1164,14 @@ def _finalize_config_args(args: argparse.Namespace, *, mode: Literal["process", "dry_run": False, "clean": False, "yes": False, - "process_incomplete": False, - "validate_manifest": True, - "strict_manifest": False, + "max_results_missing_pct": 2.5, "exclude_dataset": [], "exclude_model": [], + "replace_model": [], + "replace_env": [], }, "winrate": { "processed_dir": DEFAULT_PROCESSED_DIR, - "output_dir": DEFAULT_WINRATE_DIR, "missing_policy": "neg-inf", "epsilon": 1e-9, "min_common": 0, @@ -935,6 +1183,7 @@ def _finalize_config_args(args: argparse.Namespace, *, mode: Literal["process", "exclude_dataset": [], "partial_datasets": "strict", "hf_processed_pull": False, + "hf_winrate_dir": "winrate", "hf_private": False, "yes": False, }, @@ -942,11 +1191,21 @@ def _finalize_config_args(args: argparse.Namespace, *, mode: Literal["process", for attr, default in defaults.items(): if getattr(args, attr, None) is None: setattr(args, attr, default) + if mode == "winrate" and getattr(args, "output_dir", None) is None: + args.output_dir = _default_winrate_output_dir(Path(args.processed_dir)) if hasattr(args, "exclude_dataset"): args.exclude_dataset = _parse_repeatable_csv(args.exclude_dataset) if mode == "process" and hasattr(args, "exclude_model"): args.exclude_model = _parse_repeatable_csv(args.exclude_model) + if mode == "process" and hasattr(args, "replace_model"): + args.replace_model = _parse_repeatable_csv(args.replace_model) + if mode == "process" and hasattr(args, "replace_env"): + args.replace_env = _parse_repeatable_csv(args.replace_env) + + +def _default_winrate_output_dir(processed_dir: Path) -> Path: + return Path(processed_dir) / "winrate" def _upload_winrate_outputs( @@ -955,11 +1214,19 @@ def _upload_winrate_outputs( output_paths: Sequence[Path], repo_id: str, token: str | None, + branch: str | None, private: bool, + winrate_dir: str | None, assume_yes: bool = False, ) -> None: if not output_paths: return + raw_dir = "winrate" if winrate_dir is None else str(winrate_dir).strip() + if not raw_dir: + raw_dir = "winrate" + if resolve_under(Path("."), raw_dir) is None: + logger.error("Invalid winrate_dir '%s'; skipping upload.", winrate_dir) + return output_dir = Path(output_dir) files: list[str] = [] for path in output_paths: @@ -981,6 +1248,8 @@ def _upload_winrate_outputs( token=token, private=private, message=message, + branch=branch, + path_in_repo_prefix=raw_dir, is_tty=sys.stdin.isatty(), assume_yes=assume_yes, prompt_func=input, @@ -993,20 +1262,21 @@ def _run_winrate_mode(argv: Sequence[str]) -> int: if args.config: _load_and_apply_config(args, args.config, mode="winrate", parser=parser) + args._output_dir_explicit = args.output_dir is not None _finalize_config_args(args, mode="winrate") hf_config = HFSyncConfig.from_cli( - repo=args.hf_processed_repo, + repo=args.hf_repo, branch=args.hf_branch, token=args.hf_token, - private=False, + private=bool(args.hf_private), dry_run=False, ) if args.list_models: source_dir, datasets, source_desc = _resolve_source( args.processed_dir, - hf_config=hf_config if args.hf_processed_repo else None, + hf_config=hf_config if args.hf_repo else None, hf_processed_pull=bool(args.hf_processed_pull), ) if args.exclude_dataset: @@ -1054,13 +1324,15 @@ def _run_winrate_mode(argv: Sequence[str]) -> int: logger.info("Computed win rates for %d dataset(s): %s", len(winrate_result.datasets), winrate_result.output_path) print_winrate_summary_markdown(winrate_result.result) - if args.hf_winrate_repo: + if args.hf_repo: _upload_winrate_outputs( output_dir=args.output_dir, output_paths=winrate_result.output_paths, - repo_id=args.hf_winrate_repo, + repo_id=args.hf_repo, token=args.hf_token, + branch=args.hf_branch, private=bool(args.hf_private), + winrate_dir=args.hf_winrate_dir, assume_yes=bool(args.yes), ) return 0 diff --git a/medarc_verifiers/cli/process/__init__.py b/medarc_verifiers/cli/process/__init__.py index 6c20133e..35cb601d 100644 --- a/medarc_verifiers/cli/process/__init__.py +++ b/medarc_verifiers/cli/process/__init__.py @@ -1,5 +1,5 @@ """Process command pipeline for exporting MedARC runs.""" -from .pipeline import ProcessOptions, ProcessResult, run_process +from .pipeline import PROCESS_DEFAULT_STATUS_FILTER, ProcessOptions, ProcessResult, run_process -__all__ = ["ProcessOptions", "ProcessResult", "run_process"] +__all__ = ["PROCESS_DEFAULT_STATUS_FILTER", "ProcessOptions", "ProcessResult", "run_process"] diff --git a/medarc_verifiers/cli/process/aggregate.py b/medarc_verifiers/cli/process/aggregate.py index b00d5ff2..f6a25966 100644 --- a/medarc_verifiers/cli/process/aggregate.py +++ b/medarc_verifiers/cli/process/aggregate.py @@ -6,7 +6,8 @@ from dataclasses import dataclass from typing import Any, Iterable, Mapping -from medarc_verifiers.cli.process.rollout import derive_base_env_id +from medarc_verifiers.cli.process.metadata import RunIdentity +from medarc_verifiers.cli.process.rollout import extract_rollout_index logger = logging.getLogger(__name__) @@ -25,9 +26,17 @@ class AggregatedEnvRows: def aggregate_rows_by_env( rows: Iterable[Mapping[str, Any]], + *, + identities: Iterable[RunIdentity] | None = None, ) -> list[AggregatedEnvRows]: """Group enriched rows by (model_id, base_env_id), capturing unioned schemas.""" groups: dict[tuple[str, str], dict[str, Any]] = {} + identity_list = list(identities or ()) + fake_rollout_groups = { + (identity.model_id, identity.output_env_id) + for identity in identity_list + if identity.rollout_index is not None + } for row in rows: base_env_id = str(row.get("base_env_id") or row.get("env_id") or "") @@ -68,7 +77,15 @@ def aggregate_rows_by_env( # processing "fake rollouts" that are created by running separate jobs with rollout suffixes # (e.g., env-a-rollout7) and then combining them under a shared base_env_id. normalized_rows: list[Mapping[str, Any]] = list(group["rows"]) # shallow copy - if _group_uses_rollout_suffixes(normalized_rows, base_env_id=group["base_env_id"] or key[1]): + if key in fake_rollout_groups: + _ensure_rollout_index_from_identities( + normalized_rows, + identities=identity_list, + model_id=group["model_id"], + base_env_id=group["base_env_id"] or key[1], + ) + _normalize_rollout_indices(normalized_rows) + elif _group_uses_rollout_suffixes(normalized_rows, base_env_id=group["base_env_id"] or key[1]): _ensure_rollout_index_from_suffix(normalized_rows, base_env_id=group["base_env_id"] or key[1]) _normalize_rollout_indices(normalized_rows) candidate_env_id = group["env_id"] or group["base_env_id"] or "" @@ -85,13 +102,47 @@ def aggregate_rows_by_env( return aggregated +def _ensure_rollout_index_from_identities( + rows: list[Mapping[str, Any]], + *, + identities: list[RunIdentity], + model_id: str, + base_env_id: str, +) -> None: + rollout_by_manifest_env: dict[str, int] = {} + for identity in identities: + if identity.model_id != model_id or identity.output_env_id != base_env_id: + continue + if identity.rollout_index is None: + continue + rollout_by_manifest_env[identity.manifest_env_id] = identity.rollout_index + + if not rollout_by_manifest_env: + return + + for row in rows: + value = row.get("rollout_index") + if _coerce_rollout_index(value) is not None: + continue + manifest_env_id = row.get("manifest_env_id") + if not isinstance(manifest_env_id, str): + continue + resolved = rollout_by_manifest_env.get(manifest_env_id) + if resolved is None: + continue + try: + row["rollout_index"] = resolved + except TypeError: + continue + + def _group_uses_rollout_suffixes(rows: list[Mapping[str, Any]], *, base_env_id: str) -> bool: for row in rows: manifest_env_id = row.get("manifest_env_id") if not isinstance(manifest_env_id, str) or not manifest_env_id: continue - derived_base, _ = derive_base_env_id(manifest_env_id) - if derived_base and derived_base == base_env_id and manifest_env_id != derived_base: + row_base_env_id = str(row.get("base_env_id") or base_env_id or "") + if row_base_env_id and manifest_env_id != row_base_env_id: return True return False @@ -104,8 +155,11 @@ def _ensure_rollout_index_from_suffix(rows: list[Mapping[str, Any]], *, base_env manifest_env_id = row.get("manifest_env_id") if not isinstance(manifest_env_id, str) or not manifest_env_id: continue - derived_base, derived_index = derive_base_env_id(manifest_env_id) - if not derived_base or derived_base != base_env_id: + row_base_env_id = str(row.get("base_env_id") or base_env_id or "") + if not row_base_env_id or manifest_env_id == row_base_env_id: + continue + derived_index = extract_rollout_index(manifest_env_id) + if derived_index <= 0: continue try: row["rollout_index"] = derived_index diff --git a/medarc_verifiers/cli/process/discovery.py b/medarc_verifiers/cli/process/discovery.py index fc583f10..7aba00f8 100644 --- a/medarc_verifiers/cli/process/discovery.py +++ b/medarc_verifiers/cli/process/discovery.py @@ -20,7 +20,6 @@ logger = logging.getLogger(__name__) DEFAULT_STATUS = "unknown" -_COMPLETED_STATUSES = {"completed", "succeeded", "success"} @dataclass(frozen=True, slots=True) @@ -66,8 +65,10 @@ class RunRecord: reason: str | None started_at: str | None ended_at: str | None + avg_reward: float | None num_examples: int | None rollouts_per_example: int | None + row_count: int | None env_args: Mapping[str, Any] sampling_args: Mapping[str, Any] env_config: Mapping[str, Any] | None @@ -78,17 +79,15 @@ def discover_run_records( runs_dir: Path | str, *, filter_status: Sequence[str] | None = None, - only_complete_runs: bool = False, ) -> list[RunRecord]: """Return all discovered run records within the provided runs directory.""" - return list(iter_run_records(runs_dir, filter_status=filter_status, only_complete_runs=only_complete_runs)) + return list(iter_run_records(runs_dir, filter_status=filter_status)) def iter_run_records( runs_dir: Path | str, *, filter_status: Sequence[str] | None = None, - only_complete_runs: bool = False, ) -> Iterator[RunRecord]: """Yield run records for each job entry found under the runs directory.""" runs_path = Path(runs_dir) @@ -108,13 +107,6 @@ def iter_run_records( manifest_info, job_entries = _load_manifest(run_dir) if manifest_info is None: continue - if ( - only_complete_runs - and manifest_info.summary_total_known - and manifest_info.summary_completed != manifest_info.summary_total - ): - # Skip entire run if not fully completed - continue summary_map = _load_run_summary(run_dir) for job_entry in job_entries: summary_entry = summary_map.get(job_entry.job_id or "") @@ -194,8 +186,10 @@ def _build_run_record( reason=reason or job_entry.reason, started_at=job_entry.started_at, ended_at=job_entry.ended_at, + avg_reward=job_entry.avg_reward, num_examples=job_entry.num_examples, rollouts_per_example=job_entry.rollouts_per_example, + row_count=job_entry.row_count, env_args=env_args, sampling_args=sampling_args, env_config=env_config, diff --git a/medarc_verifiers/cli/process/env_index.py b/medarc_verifiers/cli/process/env_index.py index 89c85c37..86fecd50 100644 --- a/medarc_verifiers/cli/process/env_index.py +++ b/medarc_verifiers/cli/process/env_index.py @@ -54,21 +54,9 @@ def read_env_index_inventory(processed_dir: Path) -> EnvIndexInventory: """Read env_index.json and return a dataset inventory.""" index_path = processed_dir / "env_index.json" payload = load_env_index(index_path) - version = payload.get("version") if isinstance(payload, Mapping) else None - if version == 2: + if isinstance(payload, Mapping) and int(payload.get("version") or 0) == 2: return _inventory_from_v2(payload, processed_dir) - return EnvIndexInventory(env_paths={}, version=int(version or 1)) - - -def read_env_index_runs(processed_dir: Path) -> tuple[int, dict[str, Mapping[str, Any]]]: - """Return env_index version and run metadata map.""" - index_path = processed_dir / "env_index.json" - payload = load_env_index(index_path) - version = int(payload.get("version") or 1) if isinstance(payload, Mapping) else 1 - runs = payload.get("runs") if isinstance(payload, Mapping) else None - if version != 2 or not isinstance(runs, Mapping): - return version, {} - return version, {str(k): v for k, v in runs.items() if isinstance(v, Mapping)} + return EnvIndexInventory(env_paths={}, version=0) def read_env_index_files(processed_dir: Path) -> dict[str, Mapping[str, Any]]: @@ -118,7 +106,6 @@ def read_env_index_models(processed_dir: Path) -> set[str]: __all__ = [ "EnvIndexInventory", "read_env_index_inventory", - "read_env_index_runs", "read_env_index_files", "read_env_index_models", ] diff --git a/medarc_verifiers/cli/process/metadata.py b/medarc_verifiers/cli/process/metadata.py index 6bfae643..118e63f8 100644 --- a/medarc_verifiers/cli/process/metadata.py +++ b/medarc_verifiers/cli/process/metadata.py @@ -4,6 +4,7 @@ import json import logging +import math from dataclasses import dataclass from pathlib import Path from typing import Any, Mapping, MutableMapping @@ -21,6 +22,7 @@ class _MetadataPayload(BaseModel): env_id: str | None = None model: str | None = None + avg_reward: float | None = None version_info: dict[str, str | None] | None = None env_args: dict[str, Any] = Field(default_factory=dict) num_examples: int | None = None @@ -32,6 +34,7 @@ class _MetadataPayload(BaseModel): class NormalizedMetadata: """Normalized view of metadata.json merged with manifest discovery data.""" + identity: "RunIdentity" record: RunRecord metadata_path: Path | None raw_metadata: Mapping[str, Any] @@ -47,13 +50,111 @@ class NormalizedMetadata: rollouts_per_example: int | None +@dataclass(frozen=True, slots=True) +class RunIdentity: + """Canonical identity for selecting and exporting a discovered run record.""" + + model_id: str + manifest_env_id: str + base_env_id: str + rollout_index: int | None + job_run_id: str + output_env_id: str + + +@dataclass(frozen=True, slots=True) +class ResolvedRunIdentity: + """Selection-time identity that tolerates missing model ids.""" + + model_id: str | None + manifest_env_id: str + base_env_id: str + rollout_index: int | None + job_run_id: str + output_env_id: str + + +@dataclass(frozen=True, slots=True) +class _ResolvedMetadataContext: + raw_metadata: Mapping[str, Any] + manifest_env_id: str + metadata_env_id: str | None + base_env_id: str + rollout_index: int + model_id: str | None + metadata_model: str | None + env_args: Mapping[str, Any] + sampling_args: Mapping[str, Any] + num_examples: int | None + rollouts_per_example: int | None + + +def resolve_run_identity( + record: RunRecord, + *, + combine_rollouts: bool = True, +) -> ResolvedRunIdentity: + """Resolve a run identity for selection without requiring model_id.""" + context = _resolve_metadata_context(record, combine_rollouts=combine_rollouts) + resolved_rollout_index = ( + context.rollout_index if context.rollout_index != 0 or context.manifest_env_id != context.base_env_id else None + ) + return ResolvedRunIdentity( + model_id=context.model_id, + manifest_env_id=context.manifest_env_id, + base_env_id=context.base_env_id, + rollout_index=resolved_rollout_index, + job_run_id=record.manifest.job_run_id, + output_env_id=context.base_env_id or context.manifest_env_id or record.job_id, + ) + + def load_normalized_metadata( record: RunRecord, *, combine_rollouts: bool = True, ) -> NormalizedMetadata: """Merge manifest fields with metadata.json (when present).""" + context = _resolve_metadata_context(record, combine_rollouts=combine_rollouts) + if not context.model_id: + raise RuntimeError(format_missing_model_id_error(record)) + resolved_rollout_index = ( + context.rollout_index if context.rollout_index != 0 or context.manifest_env_id != context.base_env_id else None + ) + identity = RunIdentity( + model_id=context.model_id, + manifest_env_id=context.manifest_env_id, + base_env_id=context.base_env_id, + rollout_index=resolved_rollout_index, + job_run_id=record.manifest.job_run_id, + output_env_id=context.base_env_id or context.manifest_env_id or record.job_id, + ) + + return NormalizedMetadata( + identity=identity, + record=record, + metadata_path=record.metadata_path if record.has_metadata else None, + raw_metadata=context.raw_metadata, + manifest_env_id=context.manifest_env_id, + metadata_env_id=context.metadata_env_id, + base_env_id=context.base_env_id, + rollout_index=identity.rollout_index or 0, + model_id=identity.model_id, + metadata_model=context.metadata_model, + env_args=context.env_args, + sampling_args=context.sampling_args, + num_examples=context.num_examples, + rollouts_per_example=context.rollouts_per_example, + ) + + +def _resolve_metadata_context( + record: RunRecord, + *, + combine_rollouts: bool, +) -> _ResolvedMetadataContext: metadata_payload, raw_metadata = _load_metadata(record) + _warn_manifest_metadata_result_mismatch(record, metadata_payload) metadata_env_id = metadata_payload.env_id if metadata_payload else None metadata_model = metadata_payload.model if metadata_payload else None env_args = _merge_mappings( @@ -64,7 +165,6 @@ def load_normalized_metadata( primary=record.sampling_args, fallback=metadata_payload.sampling_args if metadata_payload else None, ) - manifest_env_id = ( _extract_env_config_id(record.env_config) or record.manifest_env_id or metadata_env_id or record.job_id ) @@ -72,34 +172,36 @@ def load_normalized_metadata( manifest_env_id, combine_rollouts=combine_rollouts, ) - # If we didn't capture a rollout index from the manifest env id, - # try to derive it from the results directory name (common when - # manifests keep base env id, but the on-disk folder encodes the rollout). if rollout_index == 0 and record.results_dir_name: alt_index = extract_rollout_index(record.results_dir_name) if alt_index: rollout_index = alt_index - - model_id = record.model_id or metadata_model - num_examples = record.num_examples or (metadata_payload.num_examples if metadata_payload else None) - rollouts_per_example = record.rollouts_per_example or ( - metadata_payload.rollouts_per_example if metadata_payload else None - ) - - return NormalizedMetadata( - record=record, - metadata_path=record.metadata_path if record.has_metadata else None, + return _ResolvedMetadataContext( raw_metadata=raw_metadata, manifest_env_id=manifest_env_id, metadata_env_id=metadata_env_id, base_env_id=base_env_id, rollout_index=rollout_index, - model_id=model_id, + model_id=record.model_id or metadata_model, metadata_model=metadata_model, env_args=env_args, sampling_args=sampling_args, - num_examples=num_examples, - rollouts_per_example=rollouts_per_example, + num_examples=_prefer_manifest_value( + record.num_examples, + metadata_payload.num_examples if metadata_payload else None, + ), + rollouts_per_example=_prefer_manifest_value( + record.rollouts_per_example, + metadata_payload.rollouts_per_example if metadata_payload else None, + ), + ) + + +def format_missing_model_id_error(record: RunRecord) -> str: + return ( + "Missing model_id for run " + f"(job_run_id={record.manifest.job_run_id}, job_id={record.job_id}, " + f"results_dir={record.results_dir}, manifest={record.manifest.manifest_path})" ) @@ -153,6 +255,50 @@ def _merge_mappings( return result +def _prefer_manifest_value(primary: int | None, fallback: int | None) -> int | None: + if primary is not None: + return primary + return fallback + + +def _warn_manifest_metadata_result_mismatch(record: RunRecord, metadata_payload: _MetadataPayload | None) -> None: + if metadata_payload is None: + return + + mismatches: list[str] = [] + if _has_float_mismatch(record.avg_reward, metadata_payload.avg_reward): + mismatches.append( + f"avg_reward manifest={record.avg_reward!r} metadata={metadata_payload.avg_reward!r}" + ) + if _has_int_mismatch(record.num_examples, metadata_payload.num_examples): + mismatches.append( + f"num_examples manifest={record.num_examples!r} metadata={metadata_payload.num_examples!r}" + ) + if not mismatches: + return + + logger.warning( + "Manifest/metadata result mismatch for process input " + "(job_run_id=%s, job_id=%s, metadata=%s): %s", + record.manifest.job_run_id, + record.job_id, + record.metadata_path, + "; ".join(mismatches), + ) + + +def _has_float_mismatch(left: float | None, right: float | None) -> bool: + if left is None or right is None: + return False + return not math.isclose(left, right, rel_tol=1e-9, abs_tol=1e-9) + + +def _has_int_mismatch(left: int | None, right: int | None) -> bool: + if left is None or right is None: + return False + return left != right + + def _extract_env_config_id(env_config: Mapping[str, Any] | None) -> str | None: if not env_config: return None @@ -164,4 +310,11 @@ def _extract_env_config_id(env_config: Mapping[str, Any] | None) -> str | None: return None -__all__ = ["NormalizedMetadata", "load_normalized_metadata"] +__all__ = [ + "NormalizedMetadata", + "ResolvedRunIdentity", + "RunIdentity", + "format_missing_model_id_error", + "load_normalized_metadata", + "resolve_run_identity", +] diff --git a/medarc_verifiers/cli/process/pipeline.py b/medarc_verifiers/cli/process/pipeline.py index 36609ae5..b23fc16d 100644 --- a/medarc_verifiers/cli/process/pipeline.py +++ b/medarc_verifiers/cli/process/pipeline.py @@ -1,31 +1,27 @@ -"""Top-level pipeline wiring discovery, row loading, aggregation, and writing.""" +"""Top-level pipeline wiring discovery, selection, row loading, aggregation, and writing.""" from __future__ import annotations +import json import logging import sys from concurrent.futures import ProcessPoolExecutor, as_completed from dataclasses import dataclass, field from datetime import UTC, datetime from pathlib import Path -from typing import Any, Callable, Iterable, Mapping, Sequence +from typing import Any, Iterable, Mapping, Sequence + +import pyarrow.parquet as pq -from medarc_verifiers.cli._schemas import EnvironmentExportConfig from medarc_verifiers.cli import hf as hf_sync -from medarc_verifiers.cli.process import ( - aggregate, - discovery, - env_index, - metadata, - rows, - rollout, - writer, - workspace, -) -from medarc_verifiers.cli.process.aggregate import AggregatedEnvRows +from medarc_verifiers.cli._schemas import EnvironmentExportConfig from medarc_verifiers.cli.hf import HFSyncConfig, HFSyncSummary -from medarc_verifiers.cli.process.writer import EnvWriteSummary, WriterConfig +from medarc_verifiers.cli.process import aggregate, discovery, env_index, metadata, rollout, rows, workspace, writer +from medarc_verifiers.cli.process.aggregate import AggregatedEnvRows +from medarc_verifiers.cli.process.metadata import RunIdentity +from medarc_verifiers.cli.process.writer import EXPORTER_METADATA_KEY, EnvWriteSummary, WriterConfig from medarc_verifiers.cli.utils.shared import ( + count_jsonl_rows, dataset_is_excluded, model_is_excluded, normalize_dataset_ids, @@ -33,6 +29,7 @@ ) logger = logging.getLogger(__name__) +PROCESS_DEFAULT_STATUS_FILTER: tuple[str, ...] = ("completed",) @dataclass(slots=True) @@ -41,12 +38,14 @@ class ProcessOptions: runs_dir: Path output_dir: Path - only_complete_runs: bool = True + max_results_missing_pct: float = 2.5 exclude_datasets: Sequence[str] = field(default_factory=tuple) exclude_models: Sequence[str] = field(default_factory=tuple) + replace_models: Sequence[str] = field(default_factory=tuple) + replace_envs: Sequence[str] = field(default_factory=tuple) processed_at: str | None = None processed_with_args: Mapping[str, Any] = field(default_factory=dict) - status_filter: Sequence[str] = field(default_factory=tuple) + status_filter: Sequence[str] = field(default_factory=lambda: PROCESS_DEFAULT_STATUS_FILTER) dry_run: bool = False clean: bool = False assume_yes: bool = False @@ -57,12 +56,15 @@ class ProcessOptions: def __post_init__(self) -> None: self.runs_dir = Path(self.runs_dir) self.output_dir = Path(self.output_dir) + self.max_results_missing_pct = float(self.max_results_missing_pct) self.max_workers = max(1, int(self.max_workers)) if not self.processed_at: self.processed_at = datetime.now(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z") self.status_filter = tuple(str(status) for status in self.status_filter) self.exclude_datasets = tuple(str(value) for value in self.exclude_datasets if str(value).strip()) self.exclude_models = tuple(str(value) for value in self.exclude_models if str(value).strip()) + self.replace_models = tuple(str(value) for value in self.replace_models if str(value).strip()) + self.replace_envs = tuple(str(value) for value in self.replace_envs if str(value).strip()) @dataclass(slots=True) @@ -76,8 +78,8 @@ class ProcessResult: hf_summary: HFSyncSummary | None -@dataclass(slots=True) -class _RecordWork: +@dataclass(frozen=True, slots=True) +class PlannedRecord: """Per-record settings for row loading.""" normalized: metadata.NormalizedMetadata @@ -86,26 +88,42 @@ class _RecordWork: answer_column: str | None -@dataclass(slots=True) -class _NormalizedRecord: +@dataclass(frozen=True, slots=True) +class PlannedWorkItem: + """A single selected (model, env) output to process.""" + + identity: RunIdentity + records: list[PlannedRecord] + + +@dataclass(frozen=True, slots=True) +class SelectionRecord: + """Selection-time record settings before full normalization.""" + record: discovery.RunRecord - normalized: metadata.NormalizedMetadata + identity: metadata.ResolvedRunIdentity + combine_rollouts: bool extra_columns: Sequence[str] drop_columns: Sequence[str] answer_column: str | None - model_key: str - env_key: str - job_run_id: str - run_timestamp: str -@dataclass(slots=True) -class _EnvGroupSelection: - model_key: str - env_key: str - job_run_id: str - run_timestamp: str - records: list[_NormalizedRecord] +@dataclass(frozen=True, slots=True) +class SelectionWorkItem: + """A selected work item before metadata normalization.""" + + identity: metadata.ResolvedRunIdentity + records: list[SelectionRecord] + + +@dataclass(frozen=True, slots=True) +class SelectionResult: + """Complete output of the selection phase.""" + + work_items: list[PlannedWorkItem] + skipped_by_delta: int + skipped_by_exclusion: int + total_discovered: int def run_process( @@ -117,101 +135,54 @@ def run_process( env_export_map = env_export_map or {} def _run_pipeline() -> ProcessResult: - if not options.dry_run and options.clean: - _confirm_clean_process( - options.output_dir, - assume_yes=options.assume_yes, - is_tty=sys.stdin.isatty(), - prompt_func=input, - ) - workspace.clear_output_dir(options.output_dir) - if not options.dry_run and options.hf_config and options.hf_config.repo_id and not options.clean: - workspace.prepare_hf_baseline( + baseline_result: workspace.BaselineResult | None = None + if not options.dry_run: + preparation = workspace.prepare_output_workspace( output_dir=options.output_dir, hf_config=options.hf_config, pull_policy=options.hf_pull_policy, + clean=options.clean, + assume_yes=options.assume_yes, is_tty=sys.stdin.isatty(), prompt_func=input, ) + if preparation is not None: + baseline_result = preparation.baseline_result - index_version, index_runs = env_index.read_env_index_runs(options.output_dir) - index_files = env_index.read_env_index_files(options.output_dir) - if options.clean: - index_version = 0 - index_runs = {} - index_files = {} - + index_files = {} if options.clean else env_index.read_env_index_files(options.output_dir) discovered = discovery.discover_run_records( options.runs_dir, filter_status=options.status_filter or None, - only_complete_runs=False, ) - - use_delta = index_version == 2 and not options.clean - if index_version != 2 and not options.clean: - logger.info("Delta processing disabled: missing or legacy env_index.json; running full reprocess.") - records: list[discovery.RunRecord] = list(discovered) - if options.only_complete_runs: - records = [ - record - for record in records - if not ( - record.manifest.summary_total_known - and record.manifest.summary_completed != record.manifest.summary_total - ) - ] - normalized_records = _normalize_records(records, env_export_map) - env_groups = _select_latest_env_groups(normalized_records) - if use_delta: - env_groups = _filter_env_groups_by_delta( - env_groups, - index_runs, - index_files, - output_dir=options.output_dir, - ) - if options.exclude_datasets: - env_groups = _filter_env_groups_by_exclusion(env_groups, options.exclude_datasets) - if options.exclude_models: - env_groups = _filter_env_groups_by_model_exclusion(env_groups, options.exclude_models) - records = [item.record for group in env_groups for item in group.records] - + selection = select_work_items( + discovered, + options=options, + env_export_map=env_export_map, + index_files=index_files, + ) + selected_records = [planned.normalized.record for item in selection.work_items for planned in item.records] _print_records_table( discovered, - records, - options.only_complete_runs, + selected_records, + options.max_results_missing_pct, exclude_datasets=options.exclude_datasets, exclude_models=options.exclude_models, + skipped_by_delta=selection.skipped_by_delta, + skipped_by_exclusion=selection.skipped_by_exclusion, ) - grouped: dict[tuple[str, str], list[_RecordWork]] = {} run_metadata: dict[str, dict[str, Any]] = {} - record_items = [item for group in env_groups for item in group.records] - record_iter: Iterable[_NormalizedRecord] = record_items - try: - from rich.progress import track - - record_iter = track(record_items, description="Reading run outputs", transient=True) - except Exception: - pass - - for record in record_iter: - normalized = record.normalized - grouped.setdefault((record.model_key, record.env_key), []).append( - _RecordWork( - normalized=normalized, - extra_columns=record.extra_columns, - drop_columns=record.drop_columns, - answer_column=record.answer_column, + for item in selection.work_items: + for planned in item.records: + record = planned.normalized.record + run_metadata.setdefault( + record.manifest.job_run_id, + { + "created_at": record.manifest.created_at, + "updated_at": _source_updated_at(record), + "config_checksum": record.manifest.config_checksum, + }, ) - ) - run_metadata.setdefault( - record.job_run_id, - { - "created_at": record.record.manifest.created_at, - "updated_at": _source_updated_at(record.record), - "config_checksum": record.record.manifest.config_checksum, - }, - ) writer_config = WriterConfig( output_dir=options.output_dir, @@ -223,20 +194,22 @@ def _run_pipeline() -> ProcessResult: env_groups: list[AggregatedEnvRows] = [] env_summaries: list[EnvWriteSummary] = [] rows_processed = 0 + work_items = sorted( + selection.work_items, key=lambda item: (item.identity.model_id, item.identity.output_env_id) + ) - env_items = sorted(grouped.items()) try: - if options.max_workers <= 1 or len(env_items) <= 1: - env_iter: Iterable[tuple[tuple[str, str], list[_RecordWork]]] = env_items + if options.max_workers <= 1 or len(work_items) <= 1: + work_iter: Iterable[PlannedWorkItem] = work_items try: from rich.progress import track - env_iter = track(env_items, description="Processing datasets", transient=True) + work_iter = track(work_items, description="Processing datasets", transient=True) except Exception: - env_iter = env_items + work_iter = work_items - for _, work_items in env_iter: - aggregated, row_count = _process_env_group(work_items) + for item in work_iter: + aggregated, row_count = _process_env_group(item) rows_processed += row_count env_groups.extend(aggregated) summaries = writer.write_env_groups(aggregated, writer_config, write_index=False) @@ -249,8 +222,8 @@ def _run_pipeline() -> ProcessResult: futures = [] try: executor = ProcessPoolExecutor(max_workers=options.max_workers) - for _, work_items in env_items: - futures.append(executor.submit(_process_env_group, work_items)) + for item in work_items: + futures.append(executor.submit(_process_env_group, item)) future_iter: Iterable[Any] = as_completed(futures) try: @@ -273,8 +246,8 @@ def _run_pipeline() -> ProcessResult: group.rows.clear() except KeyboardInterrupt: logger.warning("Processing cancelled by user; shutting down workers.") - for f in futures: - f.cancel() + for future in futures: + future.cancel() if executor is not None: executor.shutdown(cancel_futures=True) raise @@ -296,11 +269,20 @@ def _run_pipeline() -> ProcessResult: hf_summary: HFSyncSummary | None = None if options.hf_config: + files_to_upload: list[str] | None = None + if baseline_result is not None and baseline_result.policy == "continue-upload": + touched_files = hf_sync.collect_changed_output_files( + env_summaries, + output_dir=options.output_dir, + metadata_paths=metadata_paths, + ) + files_to_upload = sorted(set(baseline_result.pending_parquet_uploads) | set(touched_files)) hf_summary = hf_sync.sync_to_hub( env_summaries, options.hf_config, output_dir=options.output_dir, metadata_paths=metadata_paths, + files=files_to_upload, is_tty=sys.stdin.isatty(), assume_yes=options.assume_yes, prompt_func=input, @@ -310,7 +292,7 @@ def _run_pipeline() -> ProcessResult: env_groups = [_strip_env_group_rows(group) for group in env_groups] return ProcessResult( - records_processed=len(records), + records_processed=len(selected_records), rows_processed=rows_processed, env_groups=env_groups, env_summaries=env_summaries, @@ -323,31 +305,329 @@ def _run_pipeline() -> ProcessResult: return _run_pipeline() +def select_work_items( + discovered: Sequence[discovery.RunRecord], + *, + options: ProcessOptions, + env_export_map: Mapping[str, EnvironmentExportConfig], + index_files: Mapping[str, Mapping[str, Any]], +) -> SelectionResult: + """Filter discovered runs down to selected work items before row loading begins.""" + planned_records = [_plan_selection_record(record, env_export_map) for record in discovered] + _raise_for_latest_invalid_selection(planned_records) + work_items = _materialize_work_items( + _select_latest_work_items([record for record in planned_records if record.identity.model_id]) + ) + + work_items, skipped_by_exclusion = _apply_exclusions( + work_items, + exclude_datasets=options.exclude_datasets, + exclude_models=options.exclude_models, + ) + _validate_replace_targets(work_items, options) + work_items, skipped_by_delta = _apply_additive_delta(work_items, options=options, index_files=index_files) + _validate_selected_results_completeness(work_items, max_results_missing_pct=options.max_results_missing_pct) + + return SelectionResult( + work_items=work_items, + skipped_by_delta=skipped_by_delta, + skipped_by_exclusion=skipped_by_exclusion, + total_discovered=len(discovered), + ) + + def _resolve_env_export( manifest_env_id: str | None, env_export_map: Mapping[str, EnvironmentExportConfig], -) -> EnvironmentExportConfig | None: +) -> EnvironmentExportConfig: if not manifest_env_id: - return None + return EnvironmentExportConfig() if manifest_env_id in env_export_map: return env_export_map[manifest_env_id] base_env_id, _ = rollout.derive_base_env_id(manifest_env_id) if base_env_id and base_env_id in env_export_map: return env_export_map[base_env_id] - return None + return EnvironmentExportConfig() def _resolve_columns(env_columns: Sequence[str]) -> Sequence[str]: return tuple(str(column).strip() for column in env_columns if str(column).strip()) +def _plan_selection_record( + record: discovery.RunRecord, + env_export_map: Mapping[str, EnvironmentExportConfig], +) -> SelectionRecord: + env_export = _resolve_env_export(record.manifest_env_id, env_export_map) + combine_rollouts = bool(env_export.combine_rollouts) + identity = metadata.resolve_run_identity(record, combine_rollouts=combine_rollouts) + return SelectionRecord( + record=record, + identity=identity, + combine_rollouts=combine_rollouts, + extra_columns=_resolve_columns(env_export.extra_columns), + drop_columns=_resolve_columns(env_export.drop_columns), + answer_column=env_export.answer_column, + ) + + +def _raise_for_latest_invalid_selection(records: Sequence[SelectionRecord]) -> None: + latest_by_target: dict[tuple[str, str], SelectionRecord] = {} + for planned in records: + selection_key = (planned.identity.output_env_id, planned.record.job_id) + current = latest_by_target.get(selection_key) + if current is None or _run_sort_key( + _source_updated_at(planned.record), + planned.record.manifest.job_run_id, + ) > _run_sort_key(_source_updated_at(current.record), current.record.manifest.job_run_id): + latest_by_target[selection_key] = planned + + invalid_latest = [planned for planned in latest_by_target.values() if not planned.identity.model_id] + if not invalid_latest: + return + + failing = sorted( + invalid_latest, + key=lambda planned: ( + planned.identity.output_env_id, + _run_sort_key(_source_updated_at(planned.record), planned.record.manifest.job_run_id), + ), + )[-1] + raise RuntimeError(metadata.format_missing_model_id_error(failing.record)) + + +def _select_latest_work_items(records: Sequence[SelectionRecord]) -> list[SelectionWorkItem]: + grouped: dict[tuple[str, str], dict[str, list[SelectionRecord]]] = {} + run_timestamps: dict[str, str] = {} + + for planned in records: + identity = planned.identity + if not identity.model_id: + continue + group_key = (identity.model_id, identity.output_env_id) + grouped.setdefault(group_key, {}).setdefault(identity.job_run_id, []).append(planned) + run_timestamps.setdefault(identity.job_run_id, _source_updated_at(planned.record)) + + selected: list[SelectionWorkItem] = [] + for _, run_groups in grouped.items(): + latest_run_id = max(run_groups.keys(), key=lambda run_id: _run_sort_key(run_timestamps.get(run_id, ""), run_id)) + latest_records = run_groups[latest_run_id] + representative = latest_records[0] + selected.append( + SelectionWorkItem( + identity=representative.identity, + records=list(latest_records), + ) + ) + return selected + + +def _materialize_work_items(items: Sequence[SelectionWorkItem]) -> list[PlannedWorkItem]: + materialized: list[PlannedWorkItem] = [] + for item in items: + records: list[PlannedRecord] = [] + for selected in item.records: + normalized = metadata.load_normalized_metadata( + selected.record, + combine_rollouts=selected.combine_rollouts, + ) + records.append( + PlannedRecord( + normalized=normalized, + extra_columns=selected.extra_columns, + drop_columns=selected.drop_columns, + answer_column=selected.answer_column, + ) + ) + materialized.append(PlannedWorkItem(identity=records[0].normalized.identity, records=records)) + return materialized + + +def _apply_exclusions( + work_items: Sequence[PlannedWorkItem], + *, + exclude_datasets: Sequence[str], + exclude_models: Sequence[str], +) -> tuple[list[PlannedWorkItem], int]: + exclude_dataset_set = normalize_dataset_ids(exclude_datasets, label="process exclude dataset") + exclude_model_set = normalize_model_ids(exclude_models, label="process exclude model") + filtered: list[PlannedWorkItem] = [] + skipped = 0 + for item in work_items: + if exclude_dataset_set and _env_is_excluded(item.identity.output_env_id, exclude_dataset_set): + skipped += 1 + continue + if exclude_model_set and model_is_excluded(item.identity.model_id, exclude_model_set): + skipped += 1 + continue + filtered.append(item) + return filtered, skipped + + +def _validate_replace_targets(work_items: Sequence[PlannedWorkItem], options: ProcessOptions) -> None: + if not options.replace_models and not options.replace_envs: + return + + if options.replace_models: + matched_models = { + item.identity.model_id for item in work_items if item.identity.model_id in options.replace_models + } + if not matched_models: + raise RuntimeError( + "No selected processed outputs match --replace-model values: " + f"{', '.join(sorted(options.replace_models))}." + ) + if options.replace_envs: + matched_envs = { + item.identity.output_env_id for item in work_items if item.identity.output_env_id in options.replace_envs + } + if not matched_envs: + raise RuntimeError( + f"No selected processed outputs match --replace-env values: {', '.join(sorted(options.replace_envs))}." + ) + if options.replace_models and options.replace_envs: + intersection = [ + item + for item in work_items + if item.identity.model_id in options.replace_models and item.identity.output_env_id in options.replace_envs + ] + if not intersection: + raise RuntimeError( + "No selected processed outputs match the intersection of --replace-model and --replace-env." + ) + + +def _apply_additive_delta( + work_items: Sequence[PlannedWorkItem], + *, + options: ProcessOptions, + index_files: Mapping[str, Mapping[str, Any]], +) -> tuple[list[PlannedWorkItem], int]: + if options.clean: + return list(work_items), 0 + + filtered: list[PlannedWorkItem] = [] + skipped = 0 + for item in work_items: + output_path = writer.build_output_path( + options.output_dir, + model_id=item.identity.model_id, + env_id=item.identity.output_env_id, + ) + if not output_path.exists(): + filtered.append(item) + continue + if _should_replace_existing_output(item.identity, options): + filtered.append(item) + continue + parquet_metadata = _read_existing_output_metadata(output_path) + _validate_existing_output_integrity( + output_path, + output_dir=options.output_dir, + index_files=index_files, + parquet_metadata=parquet_metadata, + ) + if not _existing_output_matches_selected_runs(item, parquet_metadata): + filtered.append(item) + continue + skipped += 1 + return filtered, skipped + + +def _should_replace_existing_output(identity: RunIdentity, options: ProcessOptions) -> bool: + if options.clean: + return True + has_model_filter = bool(options.replace_models) + has_env_filter = bool(options.replace_envs) + if not has_model_filter and not has_env_filter: + return False + if has_model_filter and has_env_filter: + return identity.model_id in options.replace_models and identity.output_env_id in options.replace_envs + if has_model_filter: + return identity.model_id in options.replace_models + return identity.output_env_id in options.replace_envs + + +def _read_existing_output_metadata(output_path: Path) -> pq.FileMetaData: + try: + metadata_obj = pq.ParquetFile(output_path).metadata + except Exception as exc: # noqa: BLE001 + raise RuntimeError( + f"Existing processed output {output_path} is unreadable. " + "Rebuild it with --replace-model/--replace-env or re-run with --clean." + ) from exc + + if metadata_obj is None: + raise RuntimeError( + f"Existing processed output {output_path} is missing parquet footer metadata. " + "Rebuild it with --replace-model/--replace-env or re-run with --clean." + ) + return metadata_obj + + +def _validate_existing_output_integrity( + output_path: Path, + *, + output_dir: Path, + index_files: Mapping[str, Mapping[str, Any]], + parquet_metadata: pq.FileMetaData | None = None, +) -> None: + metadata_obj = parquet_metadata or _read_existing_output_metadata(output_path) + + rel_key = output_path.relative_to(output_dir).as_posix() + index_entry = index_files.get(rel_key) + if not isinstance(index_entry, Mapping): + return + expected_row_count = index_entry.get("row_count") + if expected_row_count is None: + return + try: + expected = int(expected_row_count) + except (TypeError, ValueError): + return + actual = int(metadata_obj.num_rows) + if actual != expected: + raise RuntimeError( + f"Existing processed output {output_path} has {actual} parquet rows but env_index.json records {expected}. " + "Rebuild it with --replace-model/--replace-env or re-run with --clean." + ) + + +def _existing_output_matches_selected_runs(item: PlannedWorkItem, parquet_metadata: pq.FileMetaData) -> bool: + existing_run_ids = _extract_exporter_source_runs(parquet_metadata) + if existing_run_ids is None: + return False + selected_run_ids = {planned.normalized.record.manifest.job_run_id for planned in item.records} + return existing_run_ids == selected_run_ids + + +def _extract_exporter_source_runs(parquet_metadata: pq.FileMetaData) -> set[str] | None: + metadata_map = parquet_metadata.metadata + if not metadata_map: + return None + payload = metadata_map.get(EXPORTER_METADATA_KEY) + if not payload: + return None + try: + exporter_metadata = json.loads(payload.decode("utf-8")) + except Exception: # noqa: BLE001 + return None + source_runs = exporter_metadata.get("source_runs") + if not isinstance(source_runs, list): + return None + run_ids = {str(run_id).strip() for run_id in source_runs if str(run_id).strip()} + return run_ids or None + + def _print_records_table( discovered: Sequence[discovery.RunRecord], selected: Sequence[discovery.RunRecord], - only_complete_runs: bool, + max_results_missing_pct: float, *, exclude_datasets: Sequence[str] = (), exclude_models: Sequence[str] = (), + skipped_by_delta: int = 0, + skipped_by_exclusion: int = 0, ) -> None: """Pretty-print job discovery vs planned processing.""" exclude_set = normalize_dataset_ids(exclude_datasets, label="process exclude dataset") @@ -355,71 +635,74 @@ def _print_records_table( eligible_discovered = [ rec for rec in discovered - if (not only_complete_runs or _manifest_is_complete(rec.manifest)) - and not (exclude_set and _record_is_excluded(rec, exclude_set)) + if not (exclude_set and _record_is_excluded(rec, exclude_set)) and not (exclude_model_set and _record_model_is_excluded(rec, exclude_model_set)) ] total_by_model: dict[str, int] = {} completed_by_model: dict[str, int] = {} selected_by_model: dict[str, int] = {} - completed_statuses = {"completed", "succeeded", "success"} for rec in eligible_discovered: model_id = rec.model_id or "unknown" total_by_model[model_id] = total_by_model.get(model_id, 0) + 1 - if (rec.status or "").lower() in completed_statuses: + if (rec.status or "").lower() in PROCESS_DEFAULT_STATUS_FILTER: completed_by_model[model_id] = completed_by_model.get(model_id, 0) + 1 for rec in selected: model_id = rec.model_id or "unknown" selected_by_model[model_id] = selected_by_model.get(model_id, 0) + 1 models = sorted(set(total_by_model.keys()) | set(selected_by_model.keys())) - selected_models = sorted(m for m, c in selected_by_model.items() if c > 0) - discovered_jobs_total = sum(total_by_model.get(m, 0) for m in models) - selected_jobs_total = sum(selected_by_model.get(m, 0) for m in models) + selected_models = sorted(model_id for model_id, count in selected_by_model.items() if count > 0) + discovered_jobs_total = sum(total_by_model.get(model_id, 0) for model_id in models) + selected_jobs_total = sum(selected_by_model.get(model_id, 0) for model_id in models) try: from rich.console import Console - from rich.table import Table from rich.markup import escape + from rich.table import Table except Exception: - suffix = " (complete runs only)" if only_complete_runs else "" logger.info( - "Processing %d job(s) across %d model(s)%s (found %d job(s) across %d model(s)).", + "Processing %d job(s) across %d model(s) (max_results_missing_pct=%s; found %d eligible job(s) across %d model(s)); " + "excluded=%d existing=%d.", selected_jobs_total, len(selected_models), - suffix, + _format_missing_pct(max_results_missing_pct), discovered_jobs_total, len(models), + skipped_by_exclusion, + skipped_by_delta, ) for model_id in models: - comp = completed_by_model.get(model_id, 0) - tot = total_by_model.get(model_id, 0) - sel = selected_by_model.get(model_id, 0) - logger.info(" - %s: selected=%d; %d/%d completed", model_id, sel, comp, tot) + completed = completed_by_model.get(model_id, 0) + total = total_by_model.get(model_id, 0) + selected_count = selected_by_model.get(model_id, 0) + logger.info(" - %s: selected=%d; %d/%d completed", model_id, selected_count, completed, total) return console = Console() - title = f"Processing {selected_jobs_total} job(s) across {len(selected_models)} model(s)" - if only_complete_runs: - title += " (complete runs only)" - found_suffix = "after filters" if (exclude_set or only_complete_runs) else "pre-aggregation" - title += f" [dim](found {discovered_jobs_total} job(s) across {len(models)} model(s); {found_suffix})[/dim]" + title = ( + f"Processing {selected_jobs_total} job(s) across {len(selected_models)} model(s) " + f"[dim](max_results_missing_pct={_format_missing_pct(max_results_missing_pct)})[/dim]" + ) + title += ( + f" [dim](found {discovered_jobs_total} eligible job(s); excluded={skipped_by_exclusion}, " + f"existing={skipped_by_delta})[/dim]" + ) table = Table(title=title, show_header=True, header_style="bold cyan", caption=None) table.add_column("Model", style="magenta") table.add_column("Jobs (completed/total)", style="green", justify="right") table.add_column("Selected", style="cyan", justify="right") for model_id in models: - comp = completed_by_model.get(model_id, 0) - tot = total_by_model.get(model_id, 0) - sel = selected_by_model.get(model_id, 0) - table.add_row(escape(str(model_id)), f"{comp}/{tot}", str(sel)) + completed = completed_by_model.get(model_id, 0) + total = total_by_model.get(model_id, 0) + selected_count = selected_by_model.get(model_id, 0) + table.add_row(escape(str(model_id)), f"{completed}/{total}", str(selected_count)) console.print(table) -def _manifest_is_complete(manifest: discovery.RunManifestInfo) -> bool: - return not (manifest.summary_total_known and manifest.summary_completed != manifest.summary_total) +def _format_missing_pct(value: float) -> str: + return f"{float(value):g}" def _record_is_excluded(record: discovery.RunRecord, exclude_set: set[str]) -> bool: @@ -434,90 +717,152 @@ def _record_is_excluded(record: discovery.RunRecord, exclude_set: set[str]) -> b def _record_model_is_excluded(record: discovery.RunRecord, exclude_model_set: set[str]) -> bool: - model_id = str(record.model_id or "").strip() - return model_is_excluded(model_id, exclude_model_set) + return model_is_excluded(str(record.model_id or "").strip(), exclude_model_set) -__all__ = ["ProcessOptions", "ProcessResult", "run_process"] +def _validate_selected_results_completeness( + work_items: Sequence[PlannedWorkItem], + *, + max_results_missing_pct: float, +) -> None: + missing_files: list[str] = [] + violations: list[str] = [] + ungateable = 0 + + for item in work_items: + for planned in item.records: + normalized = planned.normalized + record = normalized.record + if not record.results_path.exists(): + missing_files.append( + "model_id={model_id} output_env_id={output_env_id} manifest_env_id={manifest_env_id} " + "job_run_id={job_run_id} job_id={job_id} results_path={results_path}".format( + model_id=item.identity.model_id, + output_env_id=item.identity.output_env_id, + manifest_env_id=normalized.manifest_env_id, + job_run_id=record.manifest.job_run_id, + job_id=record.job_id, + results_path=record.results_path, + ) + ) + continue + + expected_rows = _expected_results_rows(normalized) + observed_rows = _completeness_observed_rows(record, expected_rows=expected_rows, threshold=max_results_missing_pct) + if expected_rows is None or observed_rows is None: + ungateable += 1 + continue + + missing_pct = _results_missing_pct(expected_rows=expected_rows, observed_rows=observed_rows) + if missing_pct > max_results_missing_pct: + violations.append( + "model_id={model_id} output_env_id={output_env_id} manifest_env_id={manifest_env_id} " + "job_run_id={job_run_id} job_id={job_id} expected_rows={expected_rows} " + "observed_rows={observed_rows} missing_pct={missing_pct:.2f} threshold={threshold:g}".format( + model_id=item.identity.model_id, + output_env_id=item.identity.output_env_id, + manifest_env_id=normalized.manifest_env_id, + job_run_id=record.manifest.job_run_id, + job_id=record.job_id, + expected_rows=expected_rows, + observed_rows=observed_rows, + missing_pct=missing_pct, + threshold=float(max_results_missing_pct), + ) + ) + if ungateable: + logger.warning( + "Results row completeness gate could not be applied to %d selected record(s) because expected_rows " + "(num_examples * rollouts_per_example) or manifest row_count was unknown.", + ungateable, + ) -def _process_env_group( - work_items: Sequence[_RecordWork], -) -> tuple[list[AggregatedEnvRows], int]: - """Load and aggregate all rows for a single environment.""" - row_buffer: list[dict[str, Any]] = [] - for work in work_items: - row_batch = rows.load_rows( - work.normalized, - extra_columns=work.extra_columns, - drop_columns=work.drop_columns, - answer_column=work.answer_column, + if not missing_files and not violations: + return + + message_parts: list[str] = [] + if missing_files: + missing_lines = "\n".join(f" - {line}" for line in missing_files) + message_parts.append("Selected records are missing results.jsonl files:\n" + missing_lines) + if violations: + violation_lines = "\n".join(f" - {line}" for line in violations) + message_parts.append( + "Selected records exceeded --max-results-missing-pct based on manifest row_count and expected rows:\n" + + violation_lines ) - row_buffer.extend(row_batch) - aggregated = aggregate.aggregate_rows_by_env( - row_buffer, - ) - return aggregated, len(row_buffer) + raise RuntimeError("\n\n".join(message_parts)) -def _source_updated_at(record: discovery.RunRecord) -> str: - return record.manifest.updated_at or record.manifest.created_at or "" +def _expected_results_rows(normalized: metadata.NormalizedMetadata) -> int | None: + num_examples = normalized.num_examples + rollouts_per_example = normalized.rollouts_per_example + if num_examples is None or rollouts_per_example is None: + return None + if num_examples == -1: + return None + if num_examples <= 0 or rollouts_per_example <= 0: + return None + return int(num_examples) * int(rollouts_per_example) -def _filter_env_groups_by_delta( - env_groups: Sequence[_EnvGroupSelection], - index_runs: Mapping[str, Mapping[str, Any]], - index_files: Mapping[str, Mapping[str, Any]], +def _results_missing_pct(*, expected_rows: int, observed_rows: int) -> float: + if expected_rows <= 0: + return 0.0 + missing_rows = max(int(expected_rows) - max(int(observed_rows), 0), 0) + return 100.0 * missing_rows / int(expected_rows) + + +def _completeness_observed_rows( + record: discovery.RunRecord, *, - output_dir: Path, -) -> list[_EnvGroupSelection]: - filtered: list[_EnvGroupSelection] = [] - for group in env_groups: - expected_path = writer.build_output_path(output_dir, model_id=group.model_key, env_id=group.env_key) - expected_rel = expected_path.relative_to(output_dir).as_posix() - prior_file = index_files.get(expected_rel, {}) - if not prior_file: - filtered.append(group) - continue - prior_updated_at = str(prior_file.get("updated_at") or prior_file.get("created_at") or "") - if group.job_run_id not in index_runs: - filtered.append(group) - continue - if _is_newer_timestamp(group.run_timestamp, prior_updated_at): - filtered.append(group) - continue - return filtered + expected_rows: int | None, + threshold: float, +) -> int | None: + observed_rows = record.row_count + if expected_rows is None or observed_rows is None: + return observed_rows + + missing_pct = _results_missing_pct(expected_rows=expected_rows, observed_rows=observed_rows) + if missing_pct <= threshold: + return observed_rows + + actual_rows = count_jsonl_rows(record.results_path) + if actual_rows is None or actual_rows == observed_rows: + return observed_rows + + logger.warning( + "Manifest row_count mismatch for process input " + "(job_run_id=%s, job_id=%s, results_path=%s): manifest row_count=%s actual_rows=%s. " + "Using actual_rows for completeness validation.", + record.manifest.job_run_id, + record.job_id, + record.results_path, + observed_rows, + actual_rows, + ) + return actual_rows -def _filter_env_groups_by_exclusion( - env_groups: Sequence[_EnvGroupSelection], - exclude_datasets: Sequence[str], -) -> list[_EnvGroupSelection]: - exclude_set = normalize_dataset_ids(exclude_datasets, label="process exclude dataset") - if not exclude_set: - return list(env_groups) - filtered: list[_EnvGroupSelection] = [] - for group in env_groups: - if _env_is_excluded(str(group.env_key or ""), exclude_set): - continue - filtered.append(group) - return filtered +def _process_env_group(item: PlannedWorkItem) -> tuple[list[AggregatedEnvRows], int]: + """Load and aggregate all rows for a single selected dataset.""" + row_buffer: list[dict[str, Any]] = [] + identities: list[RunIdentity] = [] + for planned in item.records: + row_batch = rows.load_rows( + planned.normalized, + extra_columns=planned.extra_columns, + drop_columns=planned.drop_columns, + answer_column=planned.answer_column, + ) + row_buffer.extend(row_batch) + identities.append(planned.normalized.identity) + aggregated = aggregate.aggregate_rows_by_env(row_buffer, identities=identities) + return aggregated, len(row_buffer) -def _filter_env_groups_by_model_exclusion( - env_groups: Sequence[_EnvGroupSelection], - exclude_models: Sequence[str], -) -> list[_EnvGroupSelection]: - exclude_set = normalize_model_ids(exclude_models, label="process exclude model") - if not exclude_set: - return list(env_groups) - filtered: list[_EnvGroupSelection] = [] - for group in env_groups: - model_id = str(group.model_key or "").strip() - if model_is_excluded(model_id, exclude_set): - continue - filtered.append(group) - return filtered +def _source_updated_at(record: discovery.RunRecord) -> str: + return record.manifest.updated_at or record.manifest.created_at or "" def _env_is_excluded(env_id: str, exclude_set: set[str]) -> bool: @@ -526,19 +871,6 @@ def _env_is_excluded(env_id: str, exclude_set: set[str]) -> bool: return dataset_is_excluded(env_identifier, exclude_set, base_dataset_id=base_env_id) -def _is_newer_timestamp(current: str, prior: str) -> bool: - if not prior: - return True if current else False - if not current: - return False - try: - current_dt = datetime.fromisoformat(current.replace("Z", "+00:00")) - prior_dt = datetime.fromisoformat(prior.replace("Z", "+00:00")) - except Exception: - return current != prior - return current_dt > prior_dt - - def _strip_env_group_rows(group: AggregatedEnvRows) -> AggregatedEnvRows: return AggregatedEnvRows( env_id=group.env_id, @@ -550,72 +882,6 @@ def _strip_env_group_rows(group: AggregatedEnvRows) -> AggregatedEnvRows: ) -def _normalize_records( - records: Sequence[discovery.RunRecord], - env_export_map: Mapping[str, EnvironmentExportConfig], -) -> list[_NormalizedRecord]: - normalized_records: list[_NormalizedRecord] = [] - for record in records: - env_export = _resolve_env_export(record.manifest_env_id, env_export_map) - extra_columns = _resolve_columns(env_export.extra_columns if env_export else ()) - drop_columns = _resolve_columns(env_export.drop_columns if env_export else ()) - answer_column = env_export.answer_column if env_export else None - - normalized = metadata.load_normalized_metadata(record) - model_id = normalized.model_id - if not model_id: - raise RuntimeError( - "Missing model_id for run " - f"(job_run_id={record.manifest.job_run_id}, job_id={record.job_id}, " - f"results_dir={record.results_dir}, manifest={record.manifest.manifest_path})" - ) - - env_key = normalized.base_env_id or normalized.manifest_env_id or record.manifest_env_id or record.job_id - normalized_records.append( - _NormalizedRecord( - record=record, - normalized=normalized, - extra_columns=extra_columns, - drop_columns=drop_columns, - answer_column=answer_column, - model_key=model_id, - env_key=env_key, - job_run_id=record.manifest.job_run_id, - run_timestamp=_source_updated_at(record), - ) - ) - return normalized_records - - -def _select_latest_env_groups( - records: Sequence[_NormalizedRecord], -) -> list[_EnvGroupSelection]: - env_groups: dict[tuple[str, str], dict[str, list[_NormalizedRecord]]] = {} - run_timestamps: dict[str, str] = {} - for record in records: - env_groups.setdefault((record.model_key, record.env_key), {}).setdefault(record.job_run_id, []).append(record) - run_timestamps.setdefault(record.job_run_id, record.run_timestamp) - - selected: list[_EnvGroupSelection] = [] - for (model_key, env_key), run_groups in env_groups.items(): - if not run_groups: - continue - latest_run_id = max( - run_groups.keys(), - key=lambda run_id: _run_sort_key(run_timestamps.get(run_id, ""), run_id), - ) - selected.append( - _EnvGroupSelection( - model_key=model_key, - env_key=env_key, - job_run_id=latest_run_id, - run_timestamp=run_timestamps.get(latest_run_id, ""), - records=run_groups[latest_run_id], - ) - ) - return selected - - def _run_sort_key(timestamp: str, job_run_id: str) -> tuple[int, datetime, str]: if not timestamp: return (0, datetime.min.replace(tzinfo=UTC), job_run_id) @@ -626,21 +892,13 @@ def _run_sort_key(timestamp: str, job_run_id: str) -> tuple[int, datetime, str]: return (0, datetime.min.replace(tzinfo=UTC), job_run_id) -def _confirm_clean_process( - output_dir: Path, - *, - assume_yes: bool, - is_tty: bool, - prompt_func: Callable[[str], str] | None, -) -> None: - if assume_yes: - return - if not is_tty or prompt_func is None: - raise RuntimeError("Refusing to clean processed outputs without confirmation. Re-run with --yes to confirm.") - prompt = f"--clean will delete all contents of {output_dir} and rebuild from runs. Type 'clean' to continue: " - try: - response = prompt_func(prompt).strip().lower() - except (EOFError, KeyboardInterrupt): # noqa: PERF203 - raise RuntimeError("Aborted clean process.") from None - if response != "clean": - raise RuntimeError("Aborted clean process.") +__all__ = [ + "PROCESS_DEFAULT_STATUS_FILTER", + "PlannedRecord", + "PlannedWorkItem", + "ProcessOptions", + "ProcessResult", + "SelectionResult", + "run_process", + "select_work_items", +] diff --git a/medarc_verifiers/cli/process/rows.py b/medarc_verifiers/cli/process/rows.py index d06cb1e4..e27896a7 100644 --- a/medarc_verifiers/cli/process/rows.py +++ b/medarc_verifiers/cli/process/rows.py @@ -28,84 +28,118 @@ def load_rows( """Load results.jsonl rows and attach manifest metadata.""" record = metadata.record if not record.has_results: - logger.debug("Run %s missing results.jsonl; skipping.", record.job_id) - return [] + raise FileNotFoundError( + "Missing results.jsonl for selected run " + f"(job_run_id={record.manifest.job_run_id}, job_id={record.job_id}, path={record.results_path})" + ) results_path = record.results_path extras_keys = {column for column in extra_columns or () if column} drop = {column for column in drop_columns or () if column} drop.update(DEFAULT_DROP_COLUMNS) drop.update(PROMPT_COMPLETION_COLUMNS) + decoded_rows, example_counts = _decode_results_jsonl(results_path) + multi_rollout = _detect_multi_rollout_shape(example_counts) + version_info_json = _encode_metadata_json_column(metadata.raw_metadata.get("version_info")) + + rows: list[dict[str, Any]] = [] + seen_per_example: dict[Any, int] = {} + for line_number, payload in decoded_rows: + cleaned, extras = _clean_payload_row( + payload, + extras_keys=extras_keys, + drop=drop, + answer_column=answer_column, + ) + rollout_index = _resolve_rollout_index( + payload, + metadata, + multi_rollout=multi_rollout, + seen_per_example=seen_per_example, + ) + if extras_keys and extras: + cleaned["extras"] = json.dumps(extras, sort_keys=True) + else: + cleaned["extras"] = None + enriched = _attach_row_metadata( + cleaned, + metadata, + line_number=line_number, + rollout_index=rollout_index, + version_info_json=version_info_json, + ) + rows.append(enriched) - # First pass: decode and clean rows, and count example_id occurrences to - # detect multiple rollouts within a single JSONL (example_id repetition). + return rows + + +def _decode_results_jsonl(path: Path) -> tuple[list[tuple[int, Mapping[str, Any]]], dict[Any, int]]: + """Decode results.jsonl and count example_id occurrences for rollout detection.""" decoded_rows: list[tuple[int, Mapping[str, Any]]] = [] example_counts: dict[Any, int] = {} try: - with results_path.open("r", encoding="utf-8") as handle: + with path.open("r", encoding="utf-8") as handle: for line_number, raw_line in enumerate(handle, start=1): line = raw_line.strip() if not line: continue - payload = _decode_line(line, results_path, line_number) + payload = _decode_line(line, path, line_number) decoded_rows.append((line_number, payload)) ex_id = payload.get("example_id") - # Count occurrences to infer intra-file rollout structure. try: example_counts[ex_id] = example_counts.get(ex_id, 0) + 1 except TypeError: - # Non-hashable example_id shouldn't happen (schema requires - # primitive), but guard just in case. pass except ValueError: raise except OSError as exc: # noqa: FBT003 - logger.warning("Failed to read %s: %s", results_path, exc) - return [] + logger.warning("Failed to read %s: %s", path, exc) + return [], {} + return decoded_rows, example_counts - multi_rollout = any(count > 1 for count in example_counts.values()) - version_info_json = _encode_metadata_json_column(metadata.raw_metadata.get("version_info")) - # Second pass: enrich rows. If the file contains multiple rollouts, compute - # a data-driven rollout_index by counting seen occurrences per example_id. - # Otherwise, retain the suffix/dir-derived rollout_index from metadata. - rows: list[dict[str, Any]] = [] - seen_per_example: dict[Any, int] = {} - for line_number, payload in decoded_rows: - extras = _extract_extras(payload, extras_keys=extras_keys) - cleaned = _clean_row(payload, drop=drop, extras_keys=extras_keys) - cleaned.pop("rollout_index", None) - _map_answer_column(cleaned, payload, answer_column=answer_column) - _flatten_token_usage(cleaned) - payload_rollout_index = _coerce_rollout_index(payload.get("rollout_index")) - if payload_rollout_index is not None: - rollout_index = payload_rollout_index - cleaned["rollout_index"] = payload_rollout_index - elif multi_rollout: - ex_id = payload.get("example_id") - try: - seen = seen_per_example.get(ex_id, 0) - rollout_index = seen # 0-based occurrence index - seen_per_example[ex_id] = seen + 1 - except TypeError: - # Fallback to metadata rollout_index if example_id is unusable as key - rollout_index = metadata.rollout_index - else: - rollout_index = metadata.rollout_index - if extras_keys and extras: - cleaned["extras"] = json.dumps(extras, sort_keys=True) - else: - cleaned["extras"] = None - enriched = _attach_metadata( - cleaned, - metadata, - line_number=line_number, - rollout_index=rollout_index, - version_info_json=version_info_json, - ) - rows.append(enriched) +def _detect_multi_rollout_shape(example_counts: Mapping[Any, int]) -> bool: + return any(count > 1 for count in example_counts.values()) - return rows + +def _clean_payload_row( + payload: Mapping[str, Any], + *, + extras_keys: set[str], + drop: set[str], + answer_column: str | None, +) -> tuple[MutableMapping[str, Any], Mapping[str, Any]]: + extras = _extract_extras(payload, extras_keys=extras_keys) + cleaned = _clean_row(payload, drop=drop, extras_keys=extras_keys) + cleaned.pop("rollout_index", None) + _map_answer_column(cleaned, payload, answer_column=answer_column) + _normalize_token_usage(cleaned) + payload_rollout_index = _coerce_rollout_index(payload.get("rollout_index")) + if payload_rollout_index is not None: + cleaned["rollout_index"] = payload_rollout_index + return cleaned, extras + + +def _resolve_rollout_index( + payload: Mapping[str, Any], + metadata: NormalizedMetadata, + *, + multi_rollout: bool, + seen_per_example: MutableMapping[Any, int], +) -> int: + payload_rollout_index = _coerce_rollout_index(payload.get("rollout_index")) + if payload_rollout_index is not None: + return payload_rollout_index + if not multi_rollout: + return metadata.rollout_index + + ex_id = payload.get("example_id") + try: + seen = seen_per_example.get(ex_id, 0) + seen_per_example[ex_id] = seen + 1 + return seen + except TypeError: + return metadata.rollout_index def _map_answer_column( @@ -202,7 +236,7 @@ def _coerce_rollout_index(value: Any) -> int | None: return None -def _attach_metadata( +def _attach_row_metadata( row: MutableMapping[str, Any], metadata: NormalizedMetadata, *, @@ -211,19 +245,18 @@ def _attach_metadata( version_info_json: str | None, ) -> MutableMapping[str, Any]: record = metadata.record + identity = metadata.identity error_value = record.reason if record.status == "failed" else None - env_identifier = metadata.base_env_id or metadata.manifest_env_id - row.update( { - "env_id": env_identifier, - "manifest_env_id": metadata.manifest_env_id, - "base_env_id": metadata.base_env_id, + "env_id": identity.output_env_id, + "manifest_env_id": identity.manifest_env_id, + "base_env_id": identity.base_env_id, "job_run_id": record.manifest.job_run_id, "run_id": record.job_id, - "model_id": metadata.model_id, + "model_id": identity.model_id, "version_info": version_info_json, "status": record.status, "error": error_value, @@ -236,7 +269,7 @@ def _attach_metadata( return row -def _flatten_token_usage(row: MutableMapping[str, Any]) -> None: +def _normalize_token_usage(row: MutableMapping[str, Any]) -> None: """Flatten token_usage dict into explicit columns and drop the original field.""" if "token_usage" not in row: return diff --git a/medarc_verifiers/cli/process/workspace.py b/medarc_verifiers/cli/process/workspace.py index 20254104..d5669ff5 100644 --- a/medarc_verifiers/cli/process/workspace.py +++ b/medarc_verifiers/cli/process/workspace.py @@ -3,14 +3,17 @@ from __future__ import annotations import json +import logging import shutil from dataclasses import dataclass, field from pathlib import Path from typing import Callable, Iterable, Sequence -from medarc_verifiers.cli.hf import HFSyncConfig, download_hf_repo +from medarc_verifiers.cli.hf import HFSyncConfig, compute_pending_parquet_uploads, download_hf_repo from medarc_verifiers.utils.pathing import resolve_under +logger = logging.getLogger(__name__) + @dataclass(slots=True) class BaselineResult: @@ -18,9 +21,16 @@ class BaselineResult: files_copied: list[Path] = field(default_factory=list) files_overwritten: list[Path] = field(default_factory=list) files_skipped: list[Path] = field(default_factory=list) + pending_parquet_uploads: set[str] = field(default_factory=set) snapshot_dir: Path | None = None +@dataclass(slots=True) +class WorkspacePreparationResult: + cleaned: bool = False + baseline_result: BaselineResult | None = None + + def ensure_output_dir(output_dir: Path) -> None: output_dir.mkdir(parents=True, exist_ok=True) @@ -33,6 +43,37 @@ def is_nonempty_dir(path: Path) -> bool: return False +def prepare_output_workspace( + *, + output_dir: Path, + hf_config: HFSyncConfig | None, + pull_policy: str | None, + clean: bool, + assume_yes: bool, + is_tty: bool, + prompt_func: Callable[[str], str] | None = None, +) -> WorkspacePreparationResult: + """Prepare local processed outputs before selection reads local inventory state.""" + ensure_output_dir(output_dir) + + if clean: + confirm_clean_output_dir(output_dir, assume_yes=assume_yes, is_tty=is_tty, prompt_func=prompt_func) + clear_output_dir(output_dir) + return WorkspacePreparationResult(cleaned=True) + + if hf_config and hf_config.repo_id: + baseline_result = prepare_hf_baseline( + output_dir=output_dir, + hf_config=hf_config, + pull_policy=pull_policy, + is_tty=is_tty, + prompt_func=prompt_func, + ) + return WorkspacePreparationResult(cleaned=False, baseline_result=baseline_result) + + return WorkspacePreparationResult(cleaned=False) + + def prepare_hf_baseline( *, output_dir: Path, @@ -47,8 +88,10 @@ def prepare_hf_baseline( return BaselineResult(policy="local") policy = _resolve_pull_policy(pull_policy, is_tty=is_tty) - result = BaselineResult(policy=policy) if not is_nonempty_dir(output_dir): + if policy == "continue-upload": + logger.warning("HF continue-upload requested with an empty output dir; falling back to pull.") + result = BaselineResult(policy="pull" if policy == "continue-upload" else policy) snapshot_dir = download_hf_repo( repo_id=hf_config.repo_id, branch=hf_config.branch, @@ -61,10 +104,36 @@ def prepare_hf_baseline( _copy_snapshot(snapshot_dir, output_dir, result, overwrite=True) return result + result = BaselineResult(policy=policy) + if policy == "prompt": + try: + result.pending_parquet_uploads = compute_pending_parquet_uploads( + output_dir=output_dir, + repo_id=hf_config.repo_id, + branch=hf_config.branch, + token=hf_config.token, + ) + except Exception as exc: # noqa: BLE001 + logger.warning("HF upload recovery check failed before prompt; hiding upload option: %s", exc) + elif policy == "continue-upload": + try: + result.pending_parquet_uploads = compute_pending_parquet_uploads( + output_dir=output_dir, + repo_id=hf_config.repo_id, + branch=hf_config.branch, + token=hf_config.token, + ) + except Exception as exc: # noqa: BLE001 + logger.warning("HF upload recovery check failed for continue-upload; uploading only current touched files: %s", exc) + prompt_conflicts = False if policy == "prompt": - choice = _prompt_baseline_choice(prompt_func, is_tty=is_tty) - policy = choice + choice = _prompt_baseline_choice( + prompt_func, + is_tty=is_tty, + show_upload=bool(result.pending_parquet_uploads), + ) + policy = "continue-upload" if choice == "upload" else choice result.policy = policy prompt_conflicts = policy == "pull" @@ -104,26 +173,59 @@ def prepare_hf_baseline( ) return result + if policy == "continue-upload": + return result + raise ValueError(f"Unsupported HF pull policy: {policy}") +def confirm_clean_output_dir( + output_dir: Path, + *, + assume_yes: bool, + is_tty: bool, + prompt_func: Callable[[str], str] | None, +) -> None: + if assume_yes: + return + if not is_tty or prompt_func is None: + raise RuntimeError("Refusing to clean processed outputs without confirmation. Re-run with --yes to confirm.") + prompt = f"--clean will delete all contents of {output_dir} and rebuild from runs. Type 'clean' to continue: " + try: + response = prompt_func(prompt).strip().lower() + except (EOFError, KeyboardInterrupt): # noqa: PERF203 + raise RuntimeError("Aborted clean process.") from None + if response != "clean": + raise RuntimeError("Aborted clean process.") + + def _resolve_pull_policy(pull_policy: str | None, *, is_tty: bool) -> str: if pull_policy: return pull_policy return "prompt" if is_tty else "pull" -def _prompt_baseline_choice(prompt_func: Callable[[str], str] | None, *, is_tty: bool) -> str: +def _prompt_baseline_choice( + prompt_func: Callable[[str], str] | None, + *, + is_tty: bool, + show_upload: bool = False, +) -> str: if not is_tty or prompt_func is None: return "pull" + choices = ["pull", "clean"] + if show_upload: + choices.append("upload") if prompt_func is not input: - prompt = ( + prompt_lines = [ "HF baseline exists locally.\n" - " pull -> download missing data without deleting local files\n" - " clean -> redownload everything after deleting local files\n" - "Choose [pull/clean]: " - ) - return _read_choice(prompt_func, prompt, {"pull", "clean"}) + " pull -> download missing data without deleting local files\n" + " clean -> redownload everything after deleting local files\n" + ] + if show_upload: + prompt_lines.append(" upload -> keep local files and resume/upload pending HF artifacts\n") + prompt_lines.append(f"Choose [{'/'.join(choices)}]: ") + return _read_choice(prompt_func, "".join(prompt_lines), choices) from rich.console import Console from rich.prompt import Prompt @@ -131,7 +233,12 @@ def _prompt_baseline_choice(prompt_func: Callable[[str], str] | None, *, is_tty: console.print("[bold yellow]HF baseline exists locally.[/bold yellow]") console.print(" [cyan]pull[/cyan] -> download missing data without deleting local files") console.print(" [cyan]clean[/cyan] -> redownload everything after deleting local files") - return Prompt.ask("Choose", choices=["pull", "clean"], default="pull") + if show_upload: + console.print(" [cyan]upload[/cyan] -> keep local files and resume/upload pending HF artifacts") + try: + return Prompt.ask("Choose", choices=choices, default="pull") + except (EOFError, KeyboardInterrupt): # noqa: PERF203 + raise RuntimeError("Aborted HF baseline selection.") from None def _prompt_overwrite_file(prompt_func: Callable[[str], str] | None, *, path: Path, is_tty: bool) -> bool: @@ -147,7 +254,7 @@ def _read_choice(prompt_func: Callable[[str], str], prompt: str, choices: Sequen while True: try: response = prompt_func(prompt).strip().lower() - except EOFError: # noqa: PERF203 + except (EOFError, KeyboardInterrupt): # noqa: PERF203 raise RuntimeError("Aborted HF baseline selection.") from None if response in choices_set: return response @@ -228,8 +335,11 @@ def clear_output_dir(output_dir: Path) -> None: __all__ = [ "BaselineResult", + "WorkspacePreparationResult", "clear_output_dir", + "confirm_clean_output_dir", "ensure_output_dir", "is_nonempty_dir", + "prepare_output_workspace", "prepare_hf_baseline", ] diff --git a/medarc_verifiers/cli/process/writer.py b/medarc_verifiers/cli/process/writer.py index 1b9d4d55..a9256cdb 100644 --- a/medarc_verifiers/cli/process/writer.py +++ b/medarc_verifiers/cli/process/writer.py @@ -51,7 +51,7 @@ EXPECTED_POLARS_DTYPES: dict[str, pl.DataType] = { "env_id": pl.String, "error": pl.String, - "example_id": pl.Int64, + "example_id": pl.String, "answer": pl.String, "extras": pl.String, "generation_ms": pl.Float64, @@ -79,7 +79,7 @@ [ pa.field("env_id", pa.large_string()), pa.field("error", pa.large_string()), - pa.field("example_id", pa.int64()), + pa.field("example_id", pa.large_string()), pa.field("answer", pa.large_string()), pa.field("extras", pa.large_string()), pa.field("generation_ms", pa.float64()), diff --git a/medarc_verifiers/cli/winrate/api.py b/medarc_verifiers/cli/winrate/api.py index d44ad2df..88f375cb 100644 --- a/medarc_verifiers/cli/winrate/api.py +++ b/medarc_verifiers/cli/winrate/api.py @@ -64,6 +64,28 @@ class ModelCentricResult: datasets: dict[str, dict[str, Any]] +@dataclass(slots=True) +class DatasetModelMissingness: + """Missing reward coverage for one (dataset, model) pair.""" + + dataset: str + model: str + expected_n: int + present_nonnull_n: int + missing_count: int + missing_pct: float + + +@dataclass(slots=True) +class MissingnessSummary: + """Aggregate missingness summary across retained datasets.""" + + n_pairs_total: int + n_pairs_with_missing: int + missing_cells_total: int + worst_offenders: list[DatasetModelMissingness] + + def read_dataset_lazy( parquet_path: Path | str | Sequence[Path | str | PLDataFrame | PLLazyFrame] | PLDataFrame | PLLazyFrame, ) -> pl.LazyFrame: @@ -288,6 +310,7 @@ def compute_winrates( n_questions_by_ds: dict[str, int] = {} models_by_ds: dict[str, list[str]] = {} models_present_by_ds: dict[str, set[str]] = {} + missingness_by_ds: dict[str, list[DatasetModelMissingness]] = {} seen_models: set[str] = set() seen_model_case_map: dict[str, str] = {} @@ -300,7 +323,7 @@ def compute_winrates( dataset_iter = datasets for dataset_name, parquet_path in dataset_iter: - stats, models_present = _process_dataset( + stats, models_present, missingness = _process_dataset( dataset_name, parquet_path, cfg, @@ -329,6 +352,7 @@ def compute_winrates( avg_rewards_by_dataset[dataset_name] = stats.avg_reward_per_model n_questions_by_ds[dataset_name] = stats.n_questions models_by_ds[dataset_name] = stats.models + missingness_by_ds[dataset_name] = missingness if not known_model_set: if include_set: @@ -349,6 +373,7 @@ def compute_winrates( per_dataset_model_means=per_dataset_model_means, avg_rewards_by_dataset=avg_rewards_by_dataset, models_by_ds=models_by_ds, + missingness_by_ds=missingness_by_ds, include_map=include_map, seen_model_case_map=seen_model_case_map, ) @@ -373,6 +398,7 @@ def compute_winrates( avg_rewards_by_dataset.pop(dataset_name, None) n_questions_by_ds.pop(dataset_name, None) models_by_ds.pop(dataset_name, None) + missingness_by_ds.pop(dataset_name, None) if not per_dataset_pairwise: _raise_user_error( "No datasets remain after enforcing dataset_coverage=all-models. " @@ -385,6 +411,8 @@ def compute_winrates( coverage=dataset_coverage, ) + _emit_missingness_report(_summarize_missingness(missingness_by_ds)) + return build_model_centric_result( per_dataset_pairwise=per_dataset_pairwise, per_dataset_model_means=per_dataset_model_means, @@ -583,7 +611,7 @@ def _process_dataset( include_map: Mapping[str, str], seen_model_case_map: Mapping[str, str], partial_datasets: str, -) -> tuple[DatasetStats | None, list[str]]: +) -> tuple[DatasetStats | None, list[str], list[DatasetModelMissingness]]: """Read and process a dataset, raising on failure and honoring selection policies.""" try: lf = read_dataset_lazy(parquet_path) @@ -599,7 +627,7 @@ def _process_dataset( if missing_required and partial_datasets == "strict": missing_labels = [include_map.get(model, model) for model in missing_required] _emit_note(f"Dropping dataset {dataset_name} (missing include models: {missing_labels}).") - return None, models_present + return None, models_present, [] if include_set: models_filtered = [models_present_map[model] for model in target_models if model in models_present_map] @@ -648,6 +676,7 @@ def canonical_label(normalized_id: str) -> str: else: pairwise[key] = (1.0 - wr, n_used) avg_reward_per_model = _mean_reward_per_model(df_avg, allowed=models) + missingness = _compute_dataset_missingness(dataset_name, df_filtered, models) return ( DatasetStats( pairwise=pairwise, @@ -656,12 +685,54 @@ def canonical_label(normalized_id: str) -> str: avg_reward_per_model=avg_reward_per_model, ), models_present, + missingness, ) except Exception as exc: # noqa: BLE001 message = f"Failed to process dataset {dataset_name} at {_format_parquet_source(parquet_path)}: {exc}" _raise_user_error(message, exc) +def _compute_dataset_missingness( + dataset_name: str, + df_avg: pl.DataFrame, + models: Sequence[str], +) -> list[DatasetModelMissingness]: + deduped_models = list(dict.fromkeys(str(model) for model in models)) + if not deduped_models: + return [] + + expected_n = 0 + present_nonnull_by_model: dict[str, int] = {} + if not df_avg.is_empty() and EXAMPLE_ID_COL in df_avg.columns: + expected_n = int(df_avg.select(pl.col(EXAMPLE_ID_COL).n_unique()).item()) + if MODEL_COL in df_avg.columns: + grouped = ( + df_avg.filter(pl.col("reward_mean").is_not_null()) + .group_by(MODEL_COL) # type: ignore[arg-type] + .agg(pl.col(EXAMPLE_ID_COL).n_unique().alias("present_nonnull_n")) + ) + present_nonnull_by_model = { + str(model): int(present_nonnull or 0) for model, present_nonnull in grouped.iter_rows() + } + + missingness: list[DatasetModelMissingness] = [] + for model in deduped_models: + present_nonnull_n = max(present_nonnull_by_model.get(model, 0), 0) + missing_count = max(expected_n - present_nonnull_n, 0) + missing_pct = (100.0 * missing_count / expected_n) if expected_n > 0 else 0.0 + missingness.append( + DatasetModelMissingness( + dataset=dataset_name, + model=model, + expected_n=expected_n, + present_nonnull_n=present_nonnull_n, + missing_count=missing_count, + missing_pct=missing_pct, + ) + ) + return missingness + + def _mean_reward_per_model(df_avg: pl.DataFrame, allowed: Sequence[str] | None = None) -> dict[str, float | None]: """Average reward_mean per model inside a dataset.""" if df_avg.is_empty() or MODEL_COL not in df_avg.columns: @@ -745,6 +816,7 @@ def _canonicalize_dataset_model_labels( per_dataset_model_means: dict[str, dict[str, float]], avg_rewards_by_dataset: dict[str, dict[str, float | None]], models_by_ds: dict[str, list[str]], + missingness_by_ds: dict[str, list[DatasetModelMissingness]], include_map: Mapping[str, str], seen_model_case_map: Mapping[str, str], ) -> None: @@ -806,6 +878,70 @@ def canonical(value: str) -> str: deduped.append(canonical_model) models_by_ds[dataset] = deduped + for dataset, rows in list(missingness_by_ds.items()): + canonical_rows: list[DatasetModelMissingness] = [] + for row in rows: + canonical_rows.append( + DatasetModelMissingness( + dataset=row.dataset, + model=canonical(row.model), + expected_n=row.expected_n, + present_nonnull_n=row.present_nonnull_n, + missing_count=row.missing_count, + missing_pct=row.missing_pct, + ) + ) + missingness_by_ds[dataset] = canonical_rows + + +def _summarize_missingness( + missingness_by_ds: Mapping[str, Sequence[DatasetModelMissingness]], +) -> MissingnessSummary: + rows = [row for dataset_rows in missingness_by_ds.values() for row in dataset_rows] + rows_with_missing = [row for row in rows if row.missing_count > 0] + worst_offenders = sorted( + rows_with_missing, + key=lambda row: (-row.missing_pct, -row.missing_count, row.dataset, row.model), + )[:10] + return MissingnessSummary( + n_pairs_total=len(rows), + n_pairs_with_missing=len(rows_with_missing), + missing_cells_total=sum(row.missing_count for row in rows), + worst_offenders=worst_offenders, + ) + + +def _emit_missingness_report(summary: MissingnessSummary) -> None: + logger.info( + "Winrate missingness summary: n_pairs_total=%d n_pairs_with_missing=%d missing_cells_total=%d", + summary.n_pairs_total, + summary.n_pairs_with_missing, + summary.missing_cells_total, + ) + console = _get_console() + if not console or not getattr(console, "is_terminal", False) or not summary.worst_offenders: + return + try: + from rich.table import Table + except Exception: + return + + table = Table(title="Winrate missingness (top offenders)") + table.add_column("dataset", style="cyan") + table.add_column("model", style="magenta") + table.add_column("missing", justify="right") + table.add_column("expected", justify="right") + table.add_column("missing %", justify="right") + for row in summary.worst_offenders: + table.add_row( + row.dataset, + row.model, + str(row.missing_count), + str(row.expected_n), + f"{row.missing_pct:.1f}", + ) + console.print(table) + def _format_parquet_source( parquet_path: Path | str | Sequence[Path | str] | PLDataFrame | PLLazyFrame, diff --git a/medarc_verifiers/parsers/xml_parser.py b/medarc_verifiers/parsers/xml_parser.py index 6a1176fc..6eb1f2e6 100644 --- a/medarc_verifiers/parsers/xml_parser.py +++ b/medarc_verifiers/parsers/xml_parser.py @@ -61,6 +61,25 @@ def parse(self, completion: Messages | str, strip: bool = True, last: bool = Fal return parsed return None + def parse_answer(self, completion: Messages | str) -> str | None: + """Extract the last answer field from a completion.""" + if isinstance(completion, str): + parsed = self.parse(completion, last=True) + if parsed is not None and hasattr(parsed, self.answer_field): + value = getattr(parsed, self.answer_field) + if value is not None: + return value + return None + + for msg in reversed(self.get_assistant_messages(completion)): + content = str(msg.get("content", "")) + parsed = self.parse(content, last=True) + if parsed is not None and hasattr(parsed, self.answer_field): + value = getattr(parsed, self.answer_field) + if value is not None: + return value + return None + def _has_any_field(self, parsed: Any) -> bool: for _, alternatives in self._fields: for alt in alternatives: diff --git a/medarc_verifiers/rewards/multiple_choice_accuracy.py b/medarc_verifiers/rewards/multiple_choice_accuracy.py index 71e123a8..cdee4780 100644 --- a/medarc_verifiers/rewards/multiple_choice_accuracy.py +++ b/medarc_verifiers/rewards/multiple_choice_accuracy.py @@ -1,76 +1,190 @@ -""" -LLM multiple-choice question accuracy reward. +"""MCQ raw-text grading with tail-authoritative long-response handling.""" -Main use case: Handle models that either return the letter/number (preferred) -or return the entire answer text verbatim (fallback). - -Supports chain-of-thought by prioritizing anchored patterns like "answer is X" -before falling back to last token or text matching. Attempts to recognize -negations to avoid false positives (e.g., "the answer is not C"). -""" +from __future__ import annotations import re import unicodedata from dataclasses import dataclass +from functools import lru_cache from typing import Optional +# Responses longer than this switch into tail long-mode behavior. +LONG_RESPONSE_THRESHOLD_CHARS = 4_000 +# Long-mode explicit-answer and answer-text scans are limited to this terminal slice. +TERMINAL_WINDOW_CHARS = 4_000 +# The looser last-token fallback only inspects this shorter tail inside the terminal slice. +STRONG_TAIL_WINDOW_CHARS = 2_000 +# Local ambiguity checks can look this far backward from a candidate. +LOCAL_CONTEXT_BEFORE_CHARS = 160 +# Local ambiguity checks can look this far forward from a candidate. +LOCAL_CONTEXT_AFTER_CHARS = 240 +# Tail-choice fallback is only allowed when the trailing segment is this short or shorter. +TAIL_CHOICE_MAX_WORDS = 16 + +_UNICODE_PUNCT_TRANSLATIONS = str.maketrans( + { + "\u00a0": " ", + "\u2010": "-", + "\u2011": "-", + "\u2012": "-", + "\u2013": "-", + "\u2014": "-", + "\u2015": "-", + "\u2212": "-", + "\u2018": "'", + "\u2019": "'", + "\u201c": '"', + "\u201d": '"', + } +) + +_WHITESPACE_RE = re.compile(r"\s+") +_LIKELY_TEX_RE = re.compile(r"\\[A-Za-z]+|\\[$\\()\\[\\]{}]|[$]") +_THINK_OPEN_RE = re.compile(r"<\s*think\b[^>]*>", re.IGNORECASE) +_THINK_CLOSE_RE = re.compile(r"", re.IGNORECASE) +_ANSWER_TAG_RE = re.compile(r"", re.IGNORECASE) + +# Any standalone option-like token. This is intentionally broad and gets filtered by +# local ambiguity checks before it can count as a chosen answer. +_OPTION_TOKEN_RE = re.compile(r"(?[A-Za-z]|\d{1,2})(?![\w+\-/])", re.IGNORECASE) +# Anchored cues that usually indicate the model is committing to a final answer. +_ANCHOR_RE = re.compile( + r"(?P