From 8e6602d28361058d1fe15b95686fd587032577f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Wed, 27 May 2026 14:10:56 +0200 Subject: [PATCH 1/9] Fix Excel-serial parsing in load_questions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit bioscancast_questions.csv stores created_date as an Excel serial day number (e.g. 45712). pd.to_datetime without unit=D + origin=1899-12-30 treated those integers as nanoseconds past 1970, yielding garbage dates like 1970-01-01 00:00:00.000045712. The bug was latent — no caller had yet relied on the parsed date — but the new orchestrator's build_forecast_question factory needs an accurate created_at. After the fix, q7 resolves to 2025-02-24 as expected. --- bioscancast/stages/eval_stage/loaders.py | 220 ++++++++++++----------- 1 file changed, 113 insertions(+), 107 deletions(-) diff --git a/bioscancast/stages/eval_stage/loaders.py b/bioscancast/stages/eval_stage/loaders.py index 024d670..92da2ea 100644 --- a/bioscancast/stages/eval_stage/loaders.py +++ b/bioscancast/stages/eval_stage/loaders.py @@ -1,108 +1,114 @@ -from __future__ import annotations - -from pathlib import Path -from typing import Union - -import pandas as pd - - -PathLike = Union[str, Path] - - -def _read_csv(path: PathLike) -> pd.DataFrame: - """ - Read one of the BioScanCast CSV files with the correct separator, - encoding, and decimal format. - """ - return pd.read_csv( - path, - sep=";", - encoding="cp1252", - decimal=",", - ) - - -def _clean_text_columns(df: pd.DataFrame) -> pd.DataFrame: - """ - Normalize common text fields so matching is stable across files. 
- - This mainly helps with: - - spacing - - dash variants - - accidental surrounding whitespace - """ - text_columns = df.select_dtypes(include="object").columns - - for col in text_columns: - df[col] = ( - df[col] - .astype(str) - .str.replace("\u2013", "-", regex=False) # en dash -> hyphen - .str.replace("\u2014", "-", regex=False) # em dash -> hyphen - .str.strip() - ) - - return df - - -def load_questions(path: PathLike) -> pd.DataFrame: - """ - Load the question metadata CSV. - - Expected columns: - - question_id - - topic - - question_text - - question_type - - resolution_criteria - - created_date - - question_status - - resolved_option - - comparison_to_outcome - - takeaways - - relevant_links - """ - df = _read_csv(path) - df = _clean_text_columns(df) - - if "created_date" in df.columns: - df["created_date"] = pd.to_datetime(df["created_date"], errors="coerce") - - if "question_status" in df.columns: - df["question_status"] = df["question_status"].str.lower() - - return df - - -def load_forecasts(path: PathLike) -> pd.DataFrame: - """ - Load the forecasts CSV. - - Expected columns: - - question_id - - forecast_version - - option - - probability - """ - df = _read_csv(path) - df = _clean_text_columns(df) - - if "probability" not in df.columns: - raise ValueError("Forecast file must contain a 'probability' column.") - - df["probability"] = pd.to_numeric(df["probability"], errors="coerce") - - if df["probability"].isna().any(): - bad_rows = df[df["probability"].isna()] - raise ValueError( - "Some forecast probabilities could not be parsed as numbers. 
" - f"Problematic rows: {bad_rows.index.tolist()}" - ) - - if "forecast_version" in df.columns: - df["forecast_version"] = df["forecast_version"].str.strip() - - if "question_id" in df.columns: - df["question_id"] = df["question_id"].str.strip() - +from __future__ import annotations + +from pathlib import Path +from typing import Union + +import pandas as pd + + +PathLike = Union[str, Path] + + +def _read_csv(path: PathLike) -> pd.DataFrame: + """ + Read one of the BioScanCast CSV files with the correct separator, + encoding, and decimal format. + """ + return pd.read_csv( + path, + sep=";", + encoding="cp1252", + decimal=",", + ) + + +def _clean_text_columns(df: pd.DataFrame) -> pd.DataFrame: + """ + Normalize common text fields so matching is stable across files. + + This mainly helps with: + - spacing + - dash variants + - accidental surrounding whitespace + """ + text_columns = df.select_dtypes(include="object").columns + + for col in text_columns: + df[col] = ( + df[col] + .astype(str) + .str.replace("\u2013", "-", regex=False) # en dash -> hyphen + .str.replace("\u2014", "-", regex=False) # em dash -> hyphen + .str.strip() + ) + + return df + + +def load_questions(path: PathLike) -> pd.DataFrame: + """ + Load the question metadata CSV. + + Expected columns: + - question_id + - topic + - question_text + - question_type + - resolution_criteria + - created_date + - question_status + - resolved_option + - comparison_to_outcome + - takeaways + - relevant_links + """ + df = _read_csv(path) + df = _clean_text_columns(df) + + if "created_date" in df.columns: + # The CSV stores created_date as an Excel serial day (e.g. 45712 → + # 2025-02-24). Without unit="D" + origin="1899-12-30", pandas treats + # the integer as nanoseconds past 1970 and produces garbage dates + # like 1970-01-01 00:00:00.000045712. 
+ df["created_date"] = pd.to_datetime( + df["created_date"], unit="D", origin="1899-12-30", errors="coerce", + ) + + if "question_status" in df.columns: + df["question_status"] = df["question_status"].str.lower() + + return df + + +def load_forecasts(path: PathLike) -> pd.DataFrame: + """ + Load the forecasts CSV. + + Expected columns: + - question_id + - forecast_version + - option + - probability + """ + df = _read_csv(path) + df = _clean_text_columns(df) + + if "probability" not in df.columns: + raise ValueError("Forecast file must contain a 'probability' column.") + + df["probability"] = pd.to_numeric(df["probability"], errors="coerce") + + if df["probability"].isna().any(): + bad_rows = df[df["probability"].isna()] + raise ValueError( + "Some forecast probabilities could not be parsed as numbers. " + f"Problematic rows: {bad_rows.index.tolist()}" + ) + + if "forecast_version" in df.columns: + df["forecast_version"] = df["forecast_version"].str.strip() + + if "question_id" in df.columns: + df["question_id"] = df["question_id"].str.strip() + return df \ No newline at end of file From e5f2efd584bef193ac376e2229e55476beb6d8f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Wed, 27 May 2026 14:14:40 +0200 Subject: [PATCH 2/9] Add build_forecast_question factory and load_question_by_id The orchestrator (next commit) needs to turn a CSV row into a typed ForecastQuestion. 
Maps: - created_date -> tz-aware UTC datetime (already parsed by load_questions) - topic "Pathogen (Region)" -> lowercased pathogen + region - question_text "by Month day, year" -> target_date via regex; falls back to "by Month year" giving the first of next month - question_type + keyword hints in text -> event_type (case_count / death_count / outbreak_declared / None) - resolution_criteria passes through - as_of_date is a factory kwarg, not a CSV column; orchestrator passes it from --as-of-date Tested against all 11 rows of bioscancast_questions.csv; q7 produces ForecastQuestion(id=q7, pathogen=mpox, region=world, target_date=2025-02-28, event_type=case_count, ...). --- bioscancast/stages/eval_stage/loaders.py | 174 ++++++++++++++++++++++- 1 file changed, 172 insertions(+), 2 deletions(-) diff --git a/bioscancast/stages/eval_stage/loaders.py b/bioscancast/stages/eval_stage/loaders.py index 92da2ea..ee1f5bb 100644 --- a/bioscancast/stages/eval_stage/loaders.py +++ b/bioscancast/stages/eval_stage/loaders.py @@ -1,13 +1,107 @@ from __future__ import annotations +import logging +import re +from datetime import datetime, timezone from pathlib import Path -from typing import Union +from typing import Optional, Union import pandas as pd +from bioscancast.filtering.models import ForecastQuestion + PathLike = Union[str, Path] +logger = logging.getLogger(__name__) + + +_MONTH_NUMBERS: dict[str, int] = { + "january": 1, "february": 2, "march": 3, "april": 4, "may": 5, "june": 6, + "july": 7, "august": 8, "september": 9, "october": 10, "november": 11, + "december": 12, +} + +# "by February 28, 2025" / "by Feb 28th 2025" / "by February 2025" +_TARGET_DATE_RE = re.compile( + r"by\s+(?P<month>" + + "|".join(_MONTH_NUMBERS.keys()) + + r"|" + + "|".join(m[:3] for m in _MONTH_NUMBERS.keys()) + + r")" + + r"(?:\s+(?P<day>\d{1,2})(?:st|nd|rd|th)?)?" 
+ + r"(?:,)?\s+(?P<year>\d{4})", + re.IGNORECASE, +) + + +def _parse_target_date(text: str) -> Optional[datetime]: + """Extract a target_date from a question's natural-language text. + + Matches phrasings like "by February 28, 2025", "by Feb 28 2025", and the + month-only fallback "by February 2025" (resolves to the 1st of the + following month, conservatively interpreting "by" as inclusive of the + named month). Returns None if no clear pattern matches. + """ + m = _TARGET_DATE_RE.search(text) + if not m: + return None + month_name = m.group("month").lower() + if len(month_name) == 3: + # Map 3-letter abbreviation back to canonical month + for full, num in _MONTH_NUMBERS.items(): + if full.startswith(month_name): + month = num + break + else: + return None + else: + month = _MONTH_NUMBERS[month_name] + year = int(m.group("year")) + day_str = m.group("day") + if day_str: + day = int(day_str) + try: + return datetime(year, month, day, tzinfo=timezone.utc) + except ValueError: + return None + # Month-only fallback: anchor to the last day of the month would require + # a calendar lookup; the simpler "1st of next month" gives the same + # cutoff semantics for a "by <month>" question. + next_month = month + 1 + next_year = year + if next_month == 13: + next_month = 1 + next_year += 1 + return datetime(next_year, next_month, 1, tzinfo=timezone.utc) + + +def _split_topic(topic: str) -> tuple[Optional[str], Optional[str]]: + """Split a topic like "Mpox (World)" → ("mpox", "world"). + + Returns (pathogen, region). If no parenthetical exists, the whole topic + becomes the pathogen and region is None. 
+ """ + if not topic or pd.isna(topic): + return None, None + topic = topic.strip() + if "(" in topic and topic.endswith(")"): + head, _, tail = topic.rpartition(" (") + return head.strip().lower() or None, tail.rstrip(")").strip().lower() or None + return topic.lower() or None, None + + +def _infer_event_type(question_type: str, question_text: str) -> Optional[str]: + """Map the CSV's question_type plus keyword hints to an event_type.""" + text = (question_text or "").lower() + if "deaths" in text or "death" in text: + return "death_count" + if "cases" in text or "case " in text: + return "case_count" + if "outbreak" in text: + return "outbreak_declared" + return None + def _read_csv(path: PathLike) -> pd.DataFrame: """ @@ -111,4 +205,80 @@ def load_forecasts(path: PathLike) -> pd.DataFrame: if "question_id" in df.columns: df["question_id"] = df["question_id"].str.strip() - return df \ No newline at end of file + return df + + +def build_forecast_question( + row: pd.Series, + *, + as_of_date: Optional[datetime] = None, +) -> ForecastQuestion: + """Convert one row of the question CSV into a ForecastQuestion. + + Used by the orchestrator (`bioscancast.main`) to turn a CSV row into the + typed object the search/filter/insight stages expect. ``as_of_date`` is + passed through verbatim and is the historical-replay cutoff; ``None`` + means live mode. 
+ """ + qid = str(row["question_id"]).strip() + text = str(row["question_text"]).strip() + + created_value = row.get("created_date") + if isinstance(created_value, pd.Timestamp) and not pd.isna(created_value): + created_at = created_value.to_pydatetime() + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=timezone.utc) + elif isinstance(created_value, datetime): + created_at = ( + created_value if created_value.tzinfo + else created_value.replace(tzinfo=timezone.utc) + ) + else: + logger.warning( + "Question %s has unparseable created_date %r; defaulting to now()", + qid, created_value, + ) + created_at = datetime.now(timezone.utc) + + topic = row.get("topic", "") + pathogen, region = _split_topic(str(topic) if not pd.isna(topic) else "") + + target_date = _parse_target_date(text) + event_type = _infer_event_type(str(row.get("question_type", "")), text) + + resolution_criteria_val = row.get("resolution_criteria") + resolution_criteria = ( + str(resolution_criteria_val).strip() + if resolution_criteria_val is not None and not pd.isna(resolution_criteria_val) + else None + ) + + return ForecastQuestion( + id=qid, + text=text, + created_at=created_at, + target_date=target_date, + region=region, + pathogen=pathogen, + event_type=event_type, + resolution_criteria=resolution_criteria, + as_of_date=as_of_date, + ) + + +def load_question_by_id( + path: PathLike, + question_id: str, + *, + as_of_date: Optional[datetime] = None, +) -> ForecastQuestion: + """Load a single question from the CSV by its question_id.""" + df = load_questions(path) + matches = df[df["question_id"].astype(str).str.strip() == question_id.strip()] + if matches.empty: + available = sorted(df["question_id"].astype(str).str.strip().tolist()) + raise KeyError( + f"question_id {question_id!r} not found in {path}. 
" + f"Available: {available}" + ) + return build_forecast_question(matches.iloc[0], as_of_date=as_of_date) \ No newline at end of file From 996963b48e92aa2bcefed800497d5f8d20d4fe48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Wed, 27 May 2026 14:15:52 +0200 Subject: [PATCH 3/9] Add orchestration/test_questions.csv with q7 + q12 Branch-local question fixture for the new end-to-end orchestrator's live smoke tests. Two rows: - q7: verbatim copy of the row in bioscancast/stages/eval_stage/ bioscancast_questions.csv. Resolved at 126,441 mpox cases globally by Feb 28 2025. Run with --as-of-date 2025-02-28 to exercise historical replay. - q12: new live question on the current East Africa Ebola outbreak, target_date 2026-06-30. Run with no --as-of-date for live mode. Kept separate from bioscancast_questions.csv so the canonical CSV stays an unmodified record of what human forecasters actually evaluated. --- bioscancast/orchestration/test_questions.csv | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 bioscancast/orchestration/test_questions.csv diff --git a/bioscancast/orchestration/test_questions.csv b/bioscancast/orchestration/test_questions.csv new file mode 100644 index 0000000..84e3ff5 --- /dev/null +++ b/bioscancast/orchestration/test_questions.csv @@ -0,0 +1,3 @@ +question_id;topic;question_text;question_type;resolution_criteria;created_date;question_status;resolved_option;comparison_to_outcome;takeaways;relevant_links +q7;Mpox (World);How many confirmed cumulative Mpox cases will be reported globally by February 28, 2025?;range;The question resolves based on cumulative confirmed Mpox cases.;45712;resolved;126,001-128,500;The final case count was 126,441.;Community adjustment over time demonstrated responsiveness to new data.;https://ourworldindata.org/mpox +q12;Ebola (East Africa);How many cumulative confirmed Ebola cases will be reported in the current East Africa outbreak by June 30, 2026?;range;The number of cumulative confirmed 
cases reported by Africa CDC and/or the WHO Disease Outbreak News for the current East Africa Ebola outbreak as of June 30, 2026.;46169;unresolved;TBD;TBD;TBD;https://africacdc.org/ From 174b3eadfce8dad4e4352d4b343a49e7e0b5c68a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Wed, 27 May 2026 14:17:18 +0200 Subject: [PATCH 4/9] Add OpenAI price table and cost estimator bioscancast/llm/pricing.py introduces: - MODEL_PRICES: USD/1M-token snapshot dated 2026-05-27 for the models actually used by stage configs (gpt-4o-mini, gpt-4o, text-embedding-3- small/large) plus a date-pinned gpt-4o-2024-08-06 alias. - estimate_cost(model, input_tokens, output_tokens, cached_input_tokens): computes USD spend with a 50% discount on cached prefix per OpenAI's standard prompt-cache pricing. - estimate_cost_from_summary(): consumes the dict shape that InsightRunResult.budget_summary already produces. Sources cited in the module docstring. Unknown model raises UnknownModelError so the orchestrator surfaces stale price tables loudly rather than under-reporting cost. --- bioscancast/llm/pricing.py | 110 +++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 bioscancast/llm/pricing.py diff --git a/bioscancast/llm/pricing.py b/bioscancast/llm/pricing.py new file mode 100644 index 0000000..d0ae65a --- /dev/null +++ b/bioscancast/llm/pricing.py @@ -0,0 +1,110 @@ +"""USD price table for OpenAI models and a per-call cost estimator. + +The orchestrator uses this to surface estimated cost per pipeline run. +Prices are a point-in-time snapshot — refresh when OpenAI changes rates or +when new model identifiers enter the stage configs. + +Snapshot taken 2026-05-27 from OpenAI's public API pricing pages +(https://devtk.ai/en/models/gpt-4o-mini/ and +https://www.cloudzero.com/blog/openai-pricing/). All numbers are USD per +1,000,000 tokens. 
The cached-input rate is OpenAI's standard 50% discount +on the cached prefix of an input; embedding models have no separate cached +rate. +""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ModelPrice: + """USD per 1,000,000 tokens for one model.""" + + input: float + cached_input: float + output: float + + +MODEL_PRICES: dict[str, ModelPrice] = { + # Cheap chat workhorse — used by search (query decomposition + filter + # rescue) and insight (chunk extraction). + "gpt-4o-mini": ModelPrice(input=0.15, cached_input=0.075, output=0.60), + # Strong model — scaffolded for issue #26 refinement but not in + # production use as of 2026-05-27. + "gpt-4o": ModelPrice(input=2.50, cached_input=1.25, output=10.00), + "gpt-4o-2024-08-06": ModelPrice(input=2.50, cached_input=1.25, output=10.00), + # Embeddings (insight retrieval). + "text-embedding-3-small": ModelPrice(input=0.02, cached_input=0.02, output=0.0), + "text-embedding-3-large": ModelPrice(input=0.13, cached_input=0.13, output=0.0), +} + + +class UnknownModelError(KeyError): + """Raised when a model name is not in MODEL_PRICES.""" + + +def estimate_cost( + model: str, + input_tokens: int, + output_tokens: int, + cached_input_tokens: int = 0, +) -> float: + """Estimate USD cost of an LLM call. + + Args: + model: Identifier matching a key in MODEL_PRICES. + input_tokens: Total input tokens (including any cached portion). + output_tokens: Output tokens generated. + cached_input_tokens: Subset of input_tokens that hit the prompt + cache; must be <= input_tokens. The non-cached remainder is + billed at the full input rate. + + Raises: + UnknownModelError: If ``model`` is not in MODEL_PRICES — refresh + this module when adding a new model to a stage config. + ValueError: If cached_input_tokens > input_tokens or any token + count is negative. 
+ """ + if input_tokens < 0 or output_tokens < 0 or cached_input_tokens < 0: + raise ValueError("Token counts must be non-negative") + if cached_input_tokens > input_tokens: + raise ValueError( + f"cached_input_tokens ({cached_input_tokens}) exceeds " + f"input_tokens ({input_tokens})" + ) + try: + price = MODEL_PRICES[model] + except KeyError as exc: + raise UnknownModelError( + f"Model {model!r} is not in MODEL_PRICES; refresh " + f"bioscancast/llm/pricing.py with current rates." + ) from exc + + fresh_input = input_tokens - cached_input_tokens + return ( + fresh_input * price.input + + cached_input_tokens * price.cached_input + + output_tokens * price.output + ) / 1_000_000.0 + + +def estimate_cost_from_summary(summary: dict) -> float: + """Estimate USD cost from an InsightRunResult.budget_summary-style dict. + + Expects a dict with a ``per_model`` key whose value is + ``{model: {input_tokens, output_tokens, [cached_input_tokens]}}`` — + the shape the insight pipeline already produces. Unknown models are + not skipped: estimate_cost raises UnknownModelError (a KeyError + subclass), so callers can decide whether to suppress or surface it. + """ + per_model = summary.get("per_model") or {} + total = 0.0 + for model, counts in per_model.items(): + total += estimate_cost( + model, + input_tokens=int(counts.get("input_tokens", 0)), + output_tokens=int(counts.get("output_tokens", 0)), + cached_input_tokens=int(counts.get("cached_input_tokens", 0)), + ) + return total From bb8d6d8d1886fae82509c2b54ff999e70e379df7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Wed, 27 May 2026 14:18:28 +0200 Subject: [PATCH 5/9] Add bioscancast.orchestration package and persistence helpers New module with the run-directory layout (data/runs/{question_id}/{run_id}/) and per-stage JSON dump helpers: save_question / save_search / save_filtered / save_documents / save_insight / save_manifest. 
_json_default and the asdict pattern are lifted from scripts/eval_insight_on_real_docs.py so the orchestrator and the eval harness share serialization conventions. --- bioscancast/orchestration/__init__.py | 7 ++ bioscancast/orchestration/persistence.py | 83 ++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 bioscancast/orchestration/__init__.py create mode 100644 bioscancast/orchestration/persistence.py diff --git a/bioscancast/orchestration/__init__.py b/bioscancast/orchestration/__init__.py new file mode 100644 index 0000000..23cc5aa --- /dev/null +++ b/bioscancast/orchestration/__init__.py @@ -0,0 +1,7 @@ +"""End-to-end pipeline orchestration helpers. + +The main orchestrator lives in :mod:`bioscancast.main`. This package +holds the persistence layer (run-directory layout, per-stage JSON dumps, +manifest write/append) so :mod:`bioscancast.main` stays focused on +stage composition. +""" diff --git a/bioscancast/orchestration/persistence.py b/bioscancast/orchestration/persistence.py new file mode 100644 index 0000000..f0d8f3a --- /dev/null +++ b/bioscancast/orchestration/persistence.py @@ -0,0 +1,83 @@ +"""Run-directory layout and per-stage JSON dump helpers. + +The orchestrator writes each stage's output and a running manifest to +``data/runs/{question_id}/{run_id}/`` so a crashed or interrupted run +still has partial artifacts for debugging. +""" + +from __future__ import annotations + +import json +from dataclasses import asdict, is_dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Iterable + +from bioscancast.filtering.models import ForecastQuestion +from bioscancast.insight.pipeline import InsightRunResult + + +def _json_default(obj: Any) -> Any: + """Default JSON encoder for datetimes and dataclasses. + + Lifted from scripts/eval_insight_on_real_docs.py so the orchestrator + and the eval harness use the same conventions for run artifacts. 
+ """ + if isinstance(obj, datetime): + return obj.isoformat() + if is_dataclass(obj): + return asdict(obj) + raise TypeError(f"not serializable: {type(obj).__name__}") + + +def _dump(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, default=_json_default) + + +def make_run_dir(out_root: Path, question_id: str, run_id: str) -> Path: + """Create and return ``out_root/question_id/run_id/``.""" + run_dir = Path(out_root) / question_id / run_id + run_dir.mkdir(parents=True, exist_ok=True) + return run_dir + + +def save_question(run_dir: Path, question: ForecastQuestion) -> None: + _dump(run_dir / "question.json", asdict(question)) + + +def save_search(run_dir: Path, results: Iterable[Any]) -> None: + _dump(run_dir / "search.json", [asdict(r) for r in results]) + + +def save_filtered(run_dir: Path, docs: Iterable[Any]) -> None: + _dump(run_dir / "filtered.json", [asdict(d) for d in docs]) + + +def save_documents(run_dir: Path, docs: Iterable[Any]) -> None: + _dump(run_dir / "documents.json", [asdict(d) for d in docs]) + + +def save_insight(run_dir: Path, result: InsightRunResult) -> None: + """Serialize an InsightRunResult to insight.json. + + The dataclass nests InsightRecord and ChunkReference dataclasses; + asdict handles both. budget_summary is already a dict. + """ + payload = { + "records": [asdict(r) for r in result.records], + "budget_summary": result.budget_summary, + "documents_processed": result.documents_processed, + "documents_skipped": result.documents_skipped, + "notes": list(result.notes), + } + _dump(run_dir / "insight.json", payload) + + +def save_manifest(run_dir: Path, manifest: dict) -> None: + """Write the manifest. Called repeatedly — once before each stage and + again after every stage completes — so a crashed run keeps the + partial timings/config that did make it through. 
+ """ + _dump(run_dir / "manifest.json", manifest) From 30e02514ec5ae92e5f9446eab3f0f7368f8da160 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Wed, 27 May 2026 14:21:22 +0200 Subject: [PATCH 6/9] Implement end-to-end pipeline orchestrator in bioscancast/main.py Replaces the 14-line commented sketch with a real argparse-driven orchestrator that chains all four stages for a single ForecastQuestion: python -m bioscancast.main q7 --as-of-date 2025-02-28 -v Pipeline: 1. load_question_by_id reads the CSV row and builds a ForecastQuestion via the new factory (applying any CLI overrides). 2. SearchStagePipeline runs with a usage-tracking LLM wrapper so search + filter token usage is accumulated for cost reporting. 3. FilteringPipeline reuses the same wrapped client. 4. ExtractionPipeline gets as_of_date=question.as_of_date so the fetcher uses Wayback snapshots in historical-replay mode. 5. InsightPipeline receives the raw (unwrapped) client; its BudgetTracker already tracks usage, so wrapping would double-count. After all stages, search/filter usage and insight per_model usage are merged and fed to bioscancast.llm.pricing.estimate_cost for a single USD figure in the final epilogue and manifest. Persistence: data/runs/{question_id}/{run_id}/ question.json, search.json, filtered.json, documents.json, insight.json, manifest.json The manifest is rewritten after every stage so a crashed run keeps partial timings + config. On any stage exception the manifest pins the failing stage and re-raises wrapped in PipelineError; main() exits 1. Empty intermediate output is not an error - logged and passed through (the insight stage already handles zero documents). 
--- bioscancast/main.py | 424 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 413 insertions(+), 11 deletions(-) diff --git a/bioscancast/main.py b/bioscancast/main.py index 6dd54a9..68e5c80 100644 --- a/bioscancast/main.py +++ b/bioscancast/main.py @@ -1,14 +1,416 @@ -# question = ForecastQuestion( -# id="Q123", -# text="Will country X report more than 50 confirmed human cases of pathogen Y by 30 June 2026?", -# created_at=now() -# ) +"""End-to-end pipeline orchestrator: search -> filter -> extract -> insight. -# search_results = run_search_stage(question, config) +Runs all four stages against a single forecast question, persisting each +stage's output and a running manifest under ``data/runs/{qid}/{run_id}/`` +so a crashed run still has partial artifacts for debugging. -# filtered_docs = run_filtering_stage(question, search_results, config) +Usage: -# downstream -# extracted_texts = run_extraction_stage(filtered_docs, config) -# insights_df = run_insight_stage(question, extracted_texts, config) -# forecast = run_forecasting_stage(question, insights_df, config) \ No newline at end of file + python -m bioscancast.main QUESTION_ID [--csv PATH] [--as-of-date Y-M-D] ... + +See ``--help`` for the full flag list. 
+""" + +from __future__ import annotations + +import argparse +import logging +import os +import sys +import time +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import asdict +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Iterator, Optional + +try: + from dotenv import load_dotenv + load_dotenv() +except ImportError: + pass + +from bioscancast.extraction.config import ExtractionConfig +from bioscancast.extraction.pipeline import ExtractionPipeline +from bioscancast.filtering.config import FILTER_CONFIG +from bioscancast.filtering.models import ForecastQuestion +from bioscancast.filtering.pipeline import FilteringPipeline +from bioscancast.insight.config import InsightConfig +from bioscancast.insight.pipeline import InsightPipeline, InsightRunResult +from bioscancast.llm.base import LLMResponse +from bioscancast.llm.openai_client import OpenAILLMClient +from bioscancast.llm.pricing import ( + UnknownModelError, + estimate_cost, +) +from bioscancast.orchestration import persistence +from bioscancast.stages.eval_stage.loaders import load_question_by_id +from bioscancast.stages.search_stage.backends.tavily_backend import TavilyBackend +from bioscancast.stages.search_stage.cache import SearchCache +from bioscancast.stages.search_stage.pipeline import SearchStagePipeline + + +DEFAULT_CSV = "bioscancast/stages/eval_stage/bioscancast_questions.csv" +DEFAULT_OUT_ROOT = "data/runs" + +logger = logging.getLogger("bioscancast.main") + + +class PipelineError(RuntimeError): + """Raised when a stage fails; wraps the underlying exception with stage name.""" + + def __init__(self, stage: str, original: BaseException) -> None: + super().__init__(f"Pipeline failed in stage {stage!r}: {original}") + self.stage = stage + self.original = original + + +class _UsageTrackingClient: + """Wraps an LLMClient and accumulates per-model token usage. 
+ + Used for the search and filter stages so the orchestrator can include + their cost in the final estimate. The insight stage already tracks its + own usage via BudgetTracker — passing the raw client to insight avoids + double-counting. + """ + + def __init__(self, inner) -> None: + self._inner = inner + self.per_model: dict[str, dict[str, int]] = defaultdict( + lambda: {"input_tokens": 0, "output_tokens": 0, "calls": 0} + ) + + def generate_json(self, **kwargs) -> LLMResponse: + response = self._inner.generate_json(**kwargs) + bucket = self.per_model[response.model] + bucket["input_tokens"] += response.input_tokens + bucket["output_tokens"] += response.output_tokens + bucket["calls"] += 1 + return response + + def embed(self, texts: list[str], *, model: str) -> list[list[float]]: + # The OpenAI embed() call doesn't expose usage today; the insight + # pipeline doesn't currently track embedding cost either. If + # embeddings become a material cost line, hook tokenizer-based + # estimation in here. + return self._inner.embed(texts, model=model) + + +# ---------------------------------------------------------------------------- +# CLI parsing and question construction +# ---------------------------------------------------------------------------- + + +def _parse_args(argv: list[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser( + prog="bioscancast.main", + description="Run the BioScanCast pipeline end-to-end for one question.", + ) + parser.add_argument( + "question_id", + help="ID of the question in the CSV (e.g. q7).", + ) + parser.add_argument( + "--csv", + default=DEFAULT_CSV, + help=f"Path to the question CSV. Default: {DEFAULT_CSV}", + ) + parser.add_argument( + "--out-root", + default=DEFAULT_OUT_ROOT, + help=f"Run-artifact root directory. 
Default: {DEFAULT_OUT_ROOT}", + ) + parser.add_argument( + "--run-id", + default=None, + help="Override the UTC-timestamp run directory name.", + ) + parser.add_argument( + "--as-of-date", + default=None, + help="Historical-replay cutoff (YYYY-MM-DD). Omit for live mode.", + ) + parser.add_argument( + "--target-date", + default=None, + help="Override CSV-derived target_date (YYYY-MM-DD).", + ) + parser.add_argument("--region", default=None, help="Override region field.") + parser.add_argument("--pathogen", default=None, help="Override pathogen field.") + parser.add_argument( + "--event-type", default=None, help="Override event_type field." + ) + parser.add_argument( + "--no-cache", + action="store_true", + help="Disable the search-stage cache.", + ) + parser.add_argument( + "--max-input-tokens", + type=int, + default=None, + help="Override InsightConfig.max_input_tokens_per_run.", + ) + parser.add_argument( + "-v", "--verbose", + action="store_true", + help="Set log level to INFO.", + ) + return parser.parse_args(argv) + + +def _parse_date(arg: Optional[str]) -> Optional[datetime]: + if arg is None: + return None + return datetime.strptime(arg, "%Y-%m-%d").replace(tzinfo=timezone.utc) + + +def _apply_overrides(q: ForecastQuestion, args: argparse.Namespace) -> ForecastQuestion: + target_date = _parse_date(args.target_date) + return ForecastQuestion( + id=q.id, + text=q.text, + created_at=q.created_at, + target_date=target_date if target_date is not None else q.target_date, + region=args.region or q.region, + pathogen=args.pathogen or q.pathogen, + event_type=args.event_type or q.event_type, + resolution_criteria=q.resolution_criteria, + as_of_date=q.as_of_date, + ) + + +# ---------------------------------------------------------------------------- +# Stage timing and console output +# ---------------------------------------------------------------------------- + + +@contextmanager +def _stage_timer(manifest: dict, stage: str) -> Iterator[None]: + 
manifest["stage_timings"].setdefault(stage, None) + manifest["current_stage"] = stage + print(f"[{stage}] starting...", flush=True) + t0 = time.perf_counter() + try: + yield + finally: + elapsed = time.perf_counter() - t0 + manifest["stage_timings"][stage] = round(elapsed, 3) + + +def _log_summary(stage: str, summary: str, elapsed: Optional[float]) -> None: + tail = f"({elapsed:.1f}s)" if elapsed is not None else "" + print(f"[{stage}] {summary} {tail}", flush=True) + + +# ---------------------------------------------------------------------------- +# Cost estimation +# ---------------------------------------------------------------------------- + + +def _merge_usage(*usage_dicts: dict[str, dict[str, int]]) -> dict[str, dict[str, int]]: + """Combine per-model usage dicts from different stages.""" + merged: dict[str, dict[str, int]] = defaultdict( + lambda: {"input_tokens": 0, "output_tokens": 0, "calls": 0} + ) + for usage in usage_dicts: + for model, counts in (usage or {}).items(): + for k in ("input_tokens", "output_tokens", "calls"): + merged[model][k] += int(counts.get(k, 0)) + return dict(merged) + + +def _estimate_total_cost(per_model: dict[str, dict[str, int]]) -> tuple[float, list[str]]: + """Return (usd_total, list_of_warnings).""" + total = 0.0 + warnings: list[str] = [] + for model, counts in per_model.items(): + try: + total += estimate_cost( + model, + input_tokens=int(counts.get("input_tokens", 0)), + output_tokens=int(counts.get("output_tokens", 0)), + ) + except UnknownModelError as exc: + warnings.append(str(exc)) + return total, warnings + + +# ---------------------------------------------------------------------------- +# Main pipeline +# ---------------------------------------------------------------------------- + + +def run_pipeline(args: argparse.Namespace) -> InsightRunResult: + csv_path = Path(args.csv) + if not csv_path.exists(): + raise FileNotFoundError(f"Question CSV not found: {csv_path}") + + for var in ("OPENAI_API_KEY", 
"TAVILY_API_KEY"): + if not os.environ.get(var): + raise RuntimeError( + f"Missing required environment variable {var}. " + f"Set it in your shell or in a .env file." + ) + + question = load_question_by_id( + csv_path, + args.question_id, + as_of_date=_parse_date(args.as_of_date), + ) + question = _apply_overrides(question, args) + + run_id = args.run_id or datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + out_root = Path(args.out_root) + run_dir = persistence.make_run_dir(out_root, question.id, run_id) + + print(f"=== Pipeline run: {question.id} ({run_id}) ===") + print( + f"question: {question.text}\n" + f"pathogen: {question.pathogen!r}\n" + f"region: {question.region!r}\n" + f"target: {question.target_date.date() if question.target_date else None}\n" + f"as_of: {question.as_of_date.date() if question.as_of_date else None}\n" + f"artifacts: {run_dir}\n", + ) + + persistence.save_question(run_dir, question) + + insight_config = InsightConfig() + if args.max_input_tokens: + insight_config.max_input_tokens_per_run = args.max_input_tokens + + manifest: dict[str, Any] = { + "run_id": run_id, + "question_id": question.id, + "csv_path": str(csv_path), + "started_at": datetime.now(timezone.utc).isoformat(), + "completed_at": None, + "stage_timings": {}, + "current_stage": None, + "errored_stage": None, + "error_message": None, + "config": { + "filter": dict(FILTER_CONFIG), + "extraction": asdict(ExtractionConfig()), + "insight": asdict(insight_config), + }, + } + persistence.save_manifest(run_dir, manifest) + + shared_llm_raw = OpenAILLMClient() + shared_llm = _UsageTrackingClient(shared_llm_raw) + + try: + with _stage_timer(manifest, "search"): + backend = TavilyBackend() + cache = None if args.no_cache else SearchCache() + search_pipeline = SearchStagePipeline( + search_backend=backend, + llm_client=shared_llm, + cache=cache, + backend_name="tavily", + ) + search_results = search_pipeline.run(question) + persistence.save_search(run_dir, search_results) + if 
cache: + cache.close() + _log_summary( + "search", f"{len(search_results)} results", + manifest["stage_timings"]["search"], + ) + persistence.save_manifest(run_dir, manifest) + + with _stage_timer(manifest, "filter"): + filter_pipeline = FilteringPipeline(llm_client=shared_llm) + filtered_docs = filter_pipeline.run(question, search_results) + persistence.save_filtered(run_dir, filtered_docs) + _log_summary( + "filter", f"{len(filtered_docs)} docs", + manifest["stage_timings"]["filter"], + ) + persistence.save_manifest(run_dir, manifest) + + with _stage_timer(manifest, "extract"): + extraction_pipeline = ExtractionPipeline(as_of_date=question.as_of_date) + documents = extraction_pipeline.run(filtered_docs) + persistence.save_documents(run_dir, documents) + n_ok = sum(1 for d in documents if d.status == "success") + _log_summary( + "extract", f"{n_ok}/{len(documents)} success", + manifest["stage_timings"]["extract"], + ) + persistence.save_manifest(run_dir, manifest) + + with _stage_timer(manifest, "insight"): + insight_pipeline = InsightPipeline( + llm_client=shared_llm_raw, + config=insight_config, + ) + insight_result = insight_pipeline.run(question, documents) + persistence.save_insight(run_dir, insight_result) + budget = insight_result.budget_summary + _log_summary( + "insight", + f"{len(insight_result.records)} records | " + f"in={budget.get('total_input_tokens', 0):,} " + f"out={budget.get('total_output_tokens', 0):,}", + manifest["stage_timings"]["insight"], + ) + + except Exception as exc: + manifest["errored_stage"] = manifest.get("current_stage") + manifest["error_message"] = f"{type(exc).__name__}: {exc}" + persistence.save_manifest(run_dir, manifest) + raise PipelineError(manifest["current_stage"] or "unknown", exc) from exc + + # Combine usage across stages and estimate cost. 
+ combined_usage = _merge_usage( + dict(shared_llm.per_model), + insight_result.budget_summary.get("per_model") or {}, + ) + cost_usd, cost_warnings = _estimate_total_cost(combined_usage) + + manifest["current_stage"] = None + manifest["completed_at"] = datetime.now(timezone.utc).isoformat() + manifest["combined_usage"] = combined_usage + manifest["estimated_cost_usd"] = round(cost_usd, 6) + if cost_warnings: + manifest["cost_estimate_warnings"] = cost_warnings + persistence.save_manifest(run_dir, manifest) + + print() + print(f"=== Pipeline complete: {question.id} ===") + for stage in ("search", "filter", "extract", "insight"): + elapsed = manifest["stage_timings"].get(stage) + if elapsed is not None: + print(f" {stage:<8} {elapsed:>7.2f}s") + print(f" estimated cost: ${cost_usd:.4f}") + if cost_warnings: + for w in cost_warnings: + print(f" ! {w}") + print(f" artifacts: {run_dir}") + + return insight_result + + +def main(argv: Optional[list[str]] = None) -> int: + args = _parse_args(argv if argv is not None else sys.argv[1:]) + logging.basicConfig( + level=logging.INFO if args.verbose else logging.WARNING, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + try: + run_pipeline(args) + except PipelineError as exc: + print(f"ERROR: {exc}", file=sys.stderr) + return 1 + except (FileNotFoundError, KeyError, RuntimeError) as exc: + print(f"ERROR: {exc}", file=sys.stderr) + return 2 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) From 23cf6fc567651d5c61bca6ceef7bd7e1372e778d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Wed, 27 May 2026 14:22:27 +0200 Subject: [PATCH 7/9] Add scripts/run_pipeline.py wrapper and gitignore data/runs/ Thin scripts/run_pipeline.py wrapper around bioscancast.main:main so the new orchestrator follows the same `scripts/run_*.py` convention as the existing per-stage runners. 
Both invocations are equivalent: python -m bioscancast.main q7 python scripts/run_pipeline.py q7 data/runs/ added to .gitignore so per-run artifacts (some quite large - documents.json includes every chunk text) don't pollute the repo. --- .gitignore | 1 + scripts/run_pipeline.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+) create mode 100644 scripts/run_pipeline.py diff --git a/.gitignore b/.gitignore index ccb93a7..43b0233 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ build/ # Data / cache data/cache/ +data/runs/ *.sqlite # Docling eval — keep FINDINGS.md and sources/, ignore generated run artifacts diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py new file mode 100644 index 0000000..d76fecd --- /dev/null +++ b/scripts/run_pipeline.py @@ -0,0 +1,21 @@ +"""Thin wrapper around ``python -m bioscancast.main`` so the orchestrator +matches the existing per-stage runner convention (scripts/run_*.py). + +Usage: + python scripts/run_pipeline.py q7 --as-of-date 2025-02-28 -v +""" + +from __future__ import annotations + +import os +import sys + +# Add project root to path so `bioscancast` imports work when run from +# anywhere. +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from bioscancast.main import main # noqa: E402 + + +if __name__ == "__main__": + sys.exit(main()) From 5de0d46bd3c92a641c2f6ae643c1d50cd47abb52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Wed, 27 May 2026 14:33:22 +0200 Subject: [PATCH 8/9] Handle sets and dated OpenAI model aliases (surfaced by live runs) Two small fixes uncovered by the first live runs of the orchestrator: 1. persistence._json_default crashed on FILTER_CONFIG's set values (blocked_domains, low_value_url_keywords, etc.) when serializing the manifest. Now sorts sets to lists. 2. pricing.MODEL_PRICES needs the dated aliases OpenAI returns in response.model. 
A request to "gpt-4o-mini" comes back tagged "gpt-4o-mini-2024-07-18", which was missing from the table and produced a $0 cost estimate with a noisy warning. Added that alias plus a couple of known gpt-4o dated variants. q7 historical-replay run subsequently cost $0.0030, q12 live run cost $0.0049, both correctly reported in the manifest. --- bioscancast/llm/pricing.py | 15 +++++++++++---- bioscancast/orchestration/persistence.py | 8 +++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/bioscancast/llm/pricing.py b/bioscancast/llm/pricing.py index d0ae65a..56ef62a 100644 --- a/bioscancast/llm/pricing.py +++ b/bioscancast/llm/pricing.py @@ -26,14 +26,21 @@ class ModelPrice: output: float +_GPT_4O_MINI = ModelPrice(input=0.15, cached_input=0.075, output=0.60) +_GPT_4O = ModelPrice(input=2.50, cached_input=1.25, output=10.00) + MODEL_PRICES: dict[str, ModelPrice] = { # Cheap chat workhorse — used by search (query decomposition + filter - # rescue) and insight (chunk extraction). - "gpt-4o-mini": ModelPrice(input=0.15, cached_input=0.075, output=0.60), + # rescue) and insight (chunk extraction). OpenAI returns the dated + # alias in response.model even when the request used the floating + # name; keep both keyed to the same price. + "gpt-4o-mini": _GPT_4O_MINI, + "gpt-4o-mini-2024-07-18": _GPT_4O_MINI, # Strong model — scaffolded for issue #26 refinement but not in # production use as of 2026-05-27. - "gpt-4o": ModelPrice(input=2.50, cached_input=1.25, output=10.00), - "gpt-4o-2024-08-06": ModelPrice(input=2.50, cached_input=1.25, output=10.00), + "gpt-4o": _GPT_4O, + "gpt-4o-2024-08-06": _GPT_4O, + "gpt-4o-2024-05-13": _GPT_4O, # Embeddings (insight retrieval). 
"text-embedding-3-small": ModelPrice(input=0.02, cached_input=0.02, output=0.0), "text-embedding-3-large": ModelPrice(input=0.13, cached_input=0.13, output=0.0), diff --git a/bioscancast/orchestration/persistence.py b/bioscancast/orchestration/persistence.py index f0d8f3a..6118e7a 100644 --- a/bioscancast/orchestration/persistence.py +++ b/bioscancast/orchestration/persistence.py @@ -18,15 +18,17 @@ def _json_default(obj: Any) -> Any: - """Default JSON encoder for datetimes and dataclasses. + """Default JSON encoder for datetimes, dataclasses, and sets. - Lifted from scripts/eval_insight_on_real_docs.py so the orchestrator - and the eval harness use the same conventions for run artifacts. + Lifted from scripts/eval_insight_on_real_docs.py and extended to + handle the set/frozenset values that live in FILTER_CONFIG. """ if isinstance(obj, datetime): return obj.isoformat() if is_dataclass(obj): return asdict(obj) + if isinstance(obj, (set, frozenset)): + return sorted(obj) raise TypeError(f"not serializable: {type(obj).__name__}") From 5d7db792940f5ac0959c533c4ca25a332a64c111 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Mod=C3=A9e?= Date: Thu, 28 May 2026 08:09:08 +0200 Subject: [PATCH 9/9] Add per-stage cost line-items to orchestrator epilogue The epilogue previously printed a single total-cost figure. This splits cost by stage (search / filter / insight) so a cost spike in any one stage is visible at a glance during iteration. Mechanism: _UsageTrackingClient gains a snapshot() method; run_pipeline snapshots the shared tracker after search and after filter and diffs them (_usage_delta) to attribute usage to each stage. Insight reports its own budget_summary per_model as before. The extract stage makes no LLM calls, so it shows timing only. Manifest gains stage_usage and stage_costs_usd alongside the existing combined_usage and estimated_cost_usd. 
Epilogue now reads e.g.: search 23.87s $0.0009 filter 3.06s $0.0003 extract 59.69s insight 12.28s $0.0030 total cost: $0.0042 Implements item 11 from the roadmap. 447 tests still passing. --- bioscancast/main.py | 58 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/bioscancast/main.py b/bioscancast/main.py index 68e5c80..8c40c81 100644 --- a/bioscancast/main.py +++ b/bioscancast/main.py @@ -96,6 +96,11 @@ def embed(self, texts: list[str], *, model: str) -> list[list[float]]: # estimation in here. return self._inner.embed(texts, model=model) + def snapshot(self) -> dict[str, dict[str, int]]: + """Return a plain-dict copy of cumulative usage for delta-ing + between stages.""" + return {model: dict(counts) for model, counts in self.per_model.items()} + # ---------------------------------------------------------------------------- # CLI parsing and question construction @@ -221,6 +226,24 @@ def _merge_usage(*usage_dicts: dict[str, dict[str, int]]) -> dict[str, dict[str, return dict(merged) +def _usage_delta( + before: dict[str, dict[str, int]], + after: dict[str, dict[str, int]], +) -> dict[str, dict[str, int]]: + """Per-model usage that accrued between two snapshots of the same + tracker. 
Models with no change are omitted.""" + delta: dict[str, dict[str, int]] = {} + for model, counts in after.items(): + b = before.get(model, {}) + d = { + k: int(counts.get(k, 0)) - int(b.get(k, 0)) + for k in ("input_tokens", "output_tokens", "calls") + } + if any(d.values()): + delta[model] = d + return delta + + def _estimate_total_cost(per_model: dict[str, dict[str, int]]) -> tuple[float, list[str]]: """Return (usd_total, list_of_warnings).""" total = 0.0 @@ -302,6 +325,11 @@ def run_pipeline(args: argparse.Namespace) -> InsightRunResult: shared_llm_raw = OpenAILLMClient() shared_llm = _UsageTrackingClient(shared_llm_raw) + # Per-stage usage is captured by snapshotting the shared tracker + # before/after each stage that uses it. Search and filter share one + # client; insight reports its own budget_summary. + stage_usage: dict[str, dict[str, dict[str, int]]] = {} + try: with _stage_timer(manifest, "search"): backend = TavilyBackend() @@ -316,6 +344,8 @@ def run_pipeline(args: argparse.Namespace) -> InsightRunResult: persistence.save_search(run_dir, search_results) if cache: cache.close() + usage_after_search = shared_llm.snapshot() + stage_usage["search"] = usage_after_search _log_summary( "search", f"{len(search_results)} results", manifest["stage_timings"]["search"], @@ -326,6 +356,7 @@ def run_pipeline(args: argparse.Namespace) -> InsightRunResult: filter_pipeline = FilteringPipeline(llm_client=shared_llm) filtered_docs = filter_pipeline.run(question, search_results) persistence.save_filtered(run_dir, filtered_docs) + stage_usage["filter"] = _usage_delta(usage_after_search, shared_llm.snapshot()) _log_summary( "filter", f"{len(filtered_docs)} docs", manifest["stage_timings"]["filter"], @@ -350,6 +381,7 @@ def run_pipeline(args: argparse.Namespace) -> InsightRunResult: ) insight_result = insight_pipeline.run(question, documents) persistence.save_insight(run_dir, insight_result) + stage_usage["insight"] = insight_result.budget_summary.get("per_model") or {} 
budget = insight_result.budget_summary _log_summary( "insight", @@ -365,16 +397,29 @@ def run_pipeline(args: argparse.Namespace) -> InsightRunResult: persistence.save_manifest(run_dir, manifest) raise PipelineError(manifest["current_stage"] or "unknown", exc) from exc - # Combine usage across stages and estimate cost. + # Per-stage cost from the captured usage snapshots. + stage_costs: dict[str, float] = {} + cost_warnings: list[str] = [] + for stage in ("search", "filter", "insight"): + usage = stage_usage.get(stage) or {} + cost, warns = _estimate_total_cost(usage) + stage_costs[stage] = round(cost, 6) + cost_warnings.extend(warns) + + # Combine usage across stages and estimate total cost. combined_usage = _merge_usage( dict(shared_llm.per_model), insight_result.budget_summary.get("per_model") or {}, ) - cost_usd, cost_warnings = _estimate_total_cost(combined_usage) + cost_usd, _ = _estimate_total_cost(combined_usage) + # Dedup warnings (same unknown model can surface in multiple stages). 
+ cost_warnings = list(dict.fromkeys(cost_warnings)) manifest["current_stage"] = None manifest["completed_at"] = datetime.now(timezone.utc).isoformat() manifest["combined_usage"] = combined_usage + manifest["stage_usage"] = stage_usage + manifest["stage_costs_usd"] = stage_costs manifest["estimated_cost_usd"] = round(cost_usd, 6) if cost_warnings: manifest["cost_estimate_warnings"] = cost_warnings @@ -384,9 +429,12 @@ def run_pipeline(args: argparse.Namespace) -> InsightRunResult: print(f"=== Pipeline complete: {question.id} ===") for stage in ("search", "filter", "extract", "insight"): elapsed = manifest["stage_timings"].get(stage) - if elapsed is not None: - print(f" {stage:<8} {elapsed:>7.2f}s") - print(f" estimated cost: ${cost_usd:.4f}") + if elapsed is None: + continue + stage_cost = stage_costs.get(stage) + cost_str = f" ${stage_cost:.4f}" if stage_cost is not None else "" + print(f" {stage:<8} {elapsed:>7.2f}s{cost_str}") + print(f" total cost: ${cost_usd:.4f}") if cost_warnings: for w in cost_warnings: print(f" ! {w}")