@dataclass(frozen=True)
class DashboardEntry:
    """A curated authoritative source for a pathogen.

    The title and snippet are intended to be readable as a search result
    in their own right: pathogen name, the kind of data the page hosts,
    and the publisher. They feed both the keyword-overlap heuristic and
    the LLM-rescue path.

    Attributes:
        url: Address of the curated page.
        title: Pathogen-specific, search-result-style title.
        snippet: One-sentence description of the data the page hosts.
    """

    url: str
    title: str
    snippet: str


# "h5n1" and "avian influenza" are two lookup keys for the same curated
# sources. Defining the entries once keeps the aliases from drifting apart
# when a URL or snippet is corrected. A tuple is used so the shared
# definition itself cannot be mutated through either alias.
_AVIAN_INFLUENZA_DASHBOARDS: tuple[DashboardEntry, ...] = (
    DashboardEntry(
        url="https://www.cdc.gov/bird-flu/situation-summary/",
        title="CDC H5N1 bird flu situation summary: human cases and outbreaks in the United States",
        snippet="CDC tracking of H5N1 avian influenza human cases, affected livestock herds, and public-health response in the US.",
    ),
    DashboardEntry(
        url="https://www.who.int/teams/global-influenza-programme/avian-influenza",
        title="WHO Global Influenza Programme: avian influenza A(H5N1) human cases and surveillance",
        snippet="WHO monitoring of human H5N1 cases, animal-to-human spillover events, and global surveillance reporting.",
    ),
)


# Pathogen key (lower-case) -> curated dashboard entries for that pathogen.
DASHBOARD_LOOKUP: dict[str, list[DashboardEntry]] = {
    # Each alias gets its own list object so that mutating one key's list
    # never affects the other.
    "h5n1": list(_AVIAN_INFLUENZA_DASHBOARDS),
    "avian influenza": list(_AVIAN_INFLUENZA_DASHBOARDS),
    "mpox": [
        DashboardEntry(
            url="https://ourworldindata.org/mpox",
            title="Our World in Data mpox tracker: global confirmed cases and deaths",
            snippet="OWID dashboard tracking cumulative confirmed mpox cases and deaths globally, broken down by country and region, updated from national health agencies.",
        ),
        DashboardEntry(
            url="https://www.who.int/emergencies/situation-reports",
            title="WHO situation reports including the multi-country mpox outbreak",
            snippet="WHO situation reports with weekly case counts, country breakdowns, and public-health guidance for ongoing outbreaks including mpox.",
        ),
        DashboardEntry(
            url="https://www.cdc.gov/monkeypox/situation-summary/index.html",
            title="CDC mpox current situation summary: confirmed cases in the United States",
            snippet="CDC current situation summary for mpox, with US confirmed case counts, clade information, demographics, and outbreak response.",
        ),
    ],
    "ebola": [
        DashboardEntry(
            url="https://www.afro.who.int/health-topics/ebola-disease",
            title="WHO Africa Ebola virus disease outbreak surveillance and case counts",
            snippet="WHO regional office for Africa tracking of Ebola virus disease outbreaks, confirmed and suspected cases, deaths, and response across African countries.",
        ),
        DashboardEntry(
            url="https://www.cdc.gov/ebola/about/index.html",
            title="CDC Ebola virus disease outbreak history and case counts",
            snippet="CDC information on current and historical Ebola virus disease outbreaks worldwide, with case counts, deaths, and US public-health response.",
        ),
    ],
    "covid-19": [
        DashboardEntry(
            url="https://ourworldindata.org/coronavirus",
            title="Our World in Data COVID-19 tracker: global cases, deaths, and vaccinations",
            snippet="OWID dashboard tracking cumulative COVID-19 confirmed cases, deaths, hospitalizations, and vaccination coverage globally by country.",
        ),
        DashboardEntry(
            url="https://www.who.int/emergencies/diseases/novel-coronavirus-2019/situation-reports",
            title="WHO COVID-19 situation reports and global case counts",
            snippet="WHO situation reports with updates on COVID-19 confirmed cases, deaths, variant tracking, and country-level data.",
        ),
    ],
    "marburg": [
        DashboardEntry(
            url="https://www.who.int/news-room/fact-sheets/detail/marburg-virus-disease",
            title="WHO Marburg virus disease facts and outbreak case counts",
            snippet="WHO factsheet on Marburg virus disease including transmission, symptoms, case-fatality ratio, and historical outbreak case and death counts.",
        ),
        DashboardEntry(
            url="https://www.cdc.gov/marburg/index.html",
            title="CDC Marburg virus disease outbreaks and surveillance",
            snippet="CDC information on Marburg virus disease outbreaks worldwide, case counts, deaths, and US public-health surveillance.",
        ),
    ],
}
+ "cnn.com", + "nbcnews.com", + "cbsnews.com", + "abcnews.go.com", + "abcnews.com", + "npr.org", + "pbs.org", + "usatoday.com", + "latimes.com", + "politico.com", + "politico.eu", + "axios.com", + "thehill.com", + "forbes.com", + "bloomberg.com", + "ft.com", + "wsj.com", + "economist.com", + "time.com", + "theatlantic.com", + "newyorker.com", + "arstechnica.com", + "businessinsider.com", } TIER_4_DOMAINS: set[str] = { diff --git a/bioscancast/filtering/config.py b/bioscancast/filtering/config.py index f0477aa..187dee4 100644 --- a/bioscancast/filtering/config.py +++ b/bioscancast/filtering/config.py @@ -38,7 +38,14 @@ "domain": 0.20, "official_bonus": 0.20, }, - "heuristic_keep_threshold": 0.72, + # Lowered from 0.72 to 0.65 after q7/q12 live runs showed filter + # survival of 4.7% / 13.5% — the threshold was tighter than the + # heuristic's actual signal supports. Borderline candidates that + # cross the new threshold still go to the LLM rescue path; this + # change just stops dropping high-credibility-but-low-keyword-overlap + # results pre-LLM (e.g. apnews/theguardian/washingtonpost in q7). + # See issue #13. + "heuristic_keep_threshold": 0.65, "heuristic_borderline_threshold": 0.45, "reranker_weights": { @@ -49,6 +56,18 @@ "auto_reject_after_rerank": 0.30, "max_llm_filter_candidates": 10, + # When no LLM client is configured, the ambiguous "llm_needed" band + # (reranked priority between auto_reject and auto_keep) is normally + # rejected outright (fail-closed). With this flag enabled — for dev / + # offline / no-API-key runs — a borderline candidate is instead KEPT if it + # is an official domain OR its keyword-overlap relevance clears + # ``no_llm_fallback_relevance_threshold``. This approximates the LLM-rescue + # path without an API call, recovering the on-topic / authoritative tail + # without admitting the generic-news mass. Default OFF so production (which + # always has an LLM client) is unchanged. See issue #13. 
+ "no_llm_soft_fallback": False, + "no_llm_fallback_relevance_threshold": 0.5, + "max_docs_per_domain": 2, "max_docs_per_type": 5, diff --git a/bioscancast/filtering/heuristics.py b/bioscancast/filtering/heuristics.py index 8aa67bc..fe41dc5 100644 --- a/bioscancast/filtering/heuristics.py +++ b/bioscancast/filtering/heuristics.py @@ -121,6 +121,27 @@ def heuristic_filter( ) continue + # Dashboard-injected results are hand-curated in + # ``bioscancast/datasets/biosecurity_sources.py``; they bypass the + # keyword-overlap-driven heuristic which structurally undervalues + # their generic titles. See issue #14 and live-run data on q7/q12 + # where 4/4 injected dashboards had keyword_overlap == 0.000. + if result.retrieval_reason == "dashboard_lookup": + relevance_score = compute_heuristic_relevance(result, question) + credibility_score = compute_heuristic_credibility(result) + keep_list.append( + make_decision( + result=result, + keep=True, + stage="heuristic", + relevance_score=relevance_score, + credibility_score=credibility_score, + priority_score=1.0, + reason_codes=["dashboard_lookup_bypass"], + ) + ) + continue + relevance_score = compute_heuristic_relevance(result, question) credibility_score = compute_heuristic_credibility(result) priority_score = compute_priority_score(result, relevance_score, credibility_score) diff --git a/bioscancast/filtering/pipeline.py b/bioscancast/filtering/pipeline.py index e8c2e5b..80dc0d7 100644 --- a/bioscancast/filtering/pipeline.py +++ b/bioscancast/filtering/pipeline.py @@ -37,11 +37,24 @@ def run( llm_decisions: list[FilterDecision] = [] if llm_needed: if self.llm_client is None: - # Fail closed: reject ambiguous cases if no LLM client is configured. + # No LLM client. Default is fail-closed (reject the ambiguous + # band). When the soft-fallback flag is enabled, keep candidates + # that are official-domain or sufficiently relevant — see + # FILTER_CONFIG["no_llm_soft_fallback"] and issue #13. 
+ soft = FILTER_CONFIG.get("no_llm_soft_fallback", False) + rel_threshold = FILTER_CONFIG.get( + "no_llm_fallback_relevance_threshold", 0.5 + ) for d in llm_needed: - d.keep = False d.stage = "llm_skipped" - d.reason_codes.append("no_llm_client_configured") + result = result_map.get(d.result_id) + is_official = bool(result and result.is_official_domain) + if soft and (is_official or d.relevance_score >= rel_threshold): + d.keep = True + d.reason_codes.append("no_llm_soft_fallback_kept") + else: + d.keep = False + d.reason_codes.append("no_llm_client_configured") llm_decisions = llm_needed else: llm_decisions = llm_filter_candidates( diff --git a/bioscancast/filtering/postprocess.py b/bioscancast/filtering/postprocess.py index c0caba5..fd8d07c 100644 --- a/bioscancast/filtering/postprocess.py +++ b/bioscancast/filtering/postprocess.py @@ -49,6 +49,17 @@ def cap_per_domain_and_type( max_docs_per_domain: int, max_docs_per_type: int, ) -> List[FilteredDocument]: + """Limit how many docs from a single domain or file type survive. + + Dashboard-bypassed docs (selection_reasons contains + ``"dashboard_lookup_bypass"``) are always kept and do not consume a + slot against either cap. Curated dashboard injections are a separate + channel from organic search results; without this carve-out, a + dashboard sitting at synthetic priority 1.0 displaces a genuine + organic candidate on the same domain - which is exactly what + happened on q7 (WHO sitreps dashboard squeezed out the WHO research + event page that the baseline extracted records from). 
+ """ kept: list[FilteredDocument] = [] domain_counts = defaultdict(int) type_counts = defaultdict(int) @@ -56,14 +67,20 @@ def cap_per_domain_and_type( for doc in docs: doc_type = doc.file_type or "unknown" - if domain_counts[doc.domain] >= max_docs_per_domain: - continue - if type_counts[doc_type] >= max_docs_per_type: - continue + is_dashboard_bypass = "dashboard_lookup_bypass" in ( + doc.selection_reasons or [] + ) + + if not is_dashboard_bypass: + if domain_counts[doc.domain] >= max_docs_per_domain: + continue + if type_counts[doc_type] >= max_docs_per_type: + continue kept.append(doc) - domain_counts[doc.domain] += 1 - type_counts[doc_type] += 1 + if not is_dashboard_bypass: + domain_counts[doc.domain] += 1 + type_counts[doc_type] += 1 return kept diff --git a/bioscancast/insight/config.py b/bioscancast/insight/config.py index fad6418..6eaf37b 100644 --- a/bioscancast/insight/config.py +++ b/bioscancast/insight/config.py @@ -16,6 +16,8 @@ "max_chunks_per_document": 12, "extraction_max_output_tokens": 4096, "chunk_workers": 6, + "low_survival_doc_threshold": 5, + "low_survival_top_k": 20, } @@ -43,6 +45,18 @@ class InsightConfig: Set to 1 for sequential execution (useful for debugging or rate- limit-sensitive setups).""" + low_survival_doc_threshold: int = 5 + """When the filter passes fewer than this many usable documents to + insight, switch to ``low_survival_top_k`` for both retrieval and the + per-document chunk cap. q7 reached insight with only 2 surviving + documents; in that regime per-doc retrieval depth becomes the + bottleneck on coverage.""" + + low_survival_top_k: int = 20 + """Retrieval / per-doc cap used when usable documents are at or below + ``low_survival_doc_threshold``. 
Set to ``None`` (or equal to + ``retrieval_top_k``) to disable the adaptive lift.""" + @classmethod def from_dict(cls, d: dict) -> InsightConfig: """Create an InsightConfig from a dict, ignoring unknown keys.""" diff --git a/bioscancast/insight/extraction/chunk_extractor.py b/bioscancast/insight/extraction/chunk_extractor.py index dd24d71..d8a773d 100644 --- a/bioscancast/insight/extraction/chunk_extractor.py +++ b/bioscancast/insight/extraction/chunk_extractor.py @@ -221,6 +221,24 @@ def _quote_matches(quote: str, chunk_text: str) -> Optional[str]: if unwrap_quote in unwrap_chunk: return unwrap_quote + # Layer 4: case-insensitive substring. Catches the model lowercasing + # the leading letter of a sentence it quotes from mid-paragraph - + # otherwise verbatim drift that's very common (q12 live runs: + # "there are now 750 suspected cases..." vs the source's "There are + # now 750..."). Returns the chunk's own casing so the stored quote + # reflects the source. Crucially this does NOT recover content- + # insertion hallucinations: a fabricated continuation still fails the + # substring test regardless of case (verified against the q12 + # "...have been reported in Ituri, North Kivu" fabrication, whose real + # source text continues "...and 906 suspected cases"). + ci_chunk = norm_chunk.lower() + for candidate in (norm_quote, stripped): + if not candidate: + continue + idx = ci_chunk.find(candidate.lower()) + if idx >= 0: + return norm_chunk[idx: idx + len(candidate)] + return None diff --git a/bioscancast/insight/extraction/prompts.py b/bioscancast/insight/extraction/prompts.py index 0f74b4c..a750014 100644 --- a/bioscancast/insight/extraction/prompts.py +++ b/bioscancast/insight/extraction/prompts.py @@ -27,7 +27,12 @@ by the chunk text. Do NOT infer, speculate, or use outside knowledge. 2. For each fact, provide a verbatim quote from the chunk (max 200 \ characters) that supports the claim. The quote must be an exact \ -substring of the chunk text. 
+substring of the chunk text. The quote MUST be the sentence (or \ +sentence fragment) that carries the figure itself — it must contain \ +the metric_value either as digits (e.g. "82"), as a number-word \ +(e.g. "eighty-two", "a dozen"), or as a clear relative reference \ +(e.g. "a quarter of the population"). A contextual or supporting \ +sentence that mentions the topic but not the figure is NOT acceptable. 3. If the chunk contains no relevant facts, return an empty facts list. \ This is expected and common — most chunks are irrelevant. 4. Do NOT answer the forecast question. Your job is fact extraction, \ @@ -40,12 +45,13 @@ 6. For metric_name, use one of these canonical snake_case values when \ applicable (this lets downstream dedup merge facts about the same \ metric across sources): - - confirmed_cases (suspected, probable, possible all get \ -their own variants below) - - suspected_cases - - probable_cases - - confirmed_or_probable_cases - - deaths + - confirmed_cases (the "confirmed" tier — lab-confirmed) + - suspected_cases (the "not-yet-confirmed" tier — covers \ +"suspected", "probable", and "possible" reporting categories) + - confirmed_or_probable_cases (WHO/CDC's combined reporting bucket) + - deaths (lab-confirmed deaths) + - suspected_deaths (the "not-yet-confirmed" tier for deaths — \ +covers "suspected", "probable", "under investigation" reporting) - hospitalizations - recoveries - vaccinations_administered @@ -58,7 +64,10 @@ If none of these fit, invent a short snake_case label. Do NOT put \ qualifiers (sex, age, sub-region, time-period like "weekly") in \ metric_name — capture those in `summary` or `location` instead. \ -"cases", "reported cases", "total cases" all map to confirmed_cases. +"cases", "reported cases", "total cases" all map to confirmed_cases. \ +"suspected cases", "probable cases", "possible cases" all map to \ +suspected_cases. 
"deaths" alone maps to deaths; "suspected deaths", \ +"probable deaths", "deaths under investigation" map to suspected_deaths. 7. Be aware of cognitive biases that affect information processing: - Anchoring: do not over-weight the first number you encounter. - Availability: rare dramatic events are not necessarily more likely. diff --git a/bioscancast/insight/pipeline.py b/bioscancast/insight/pipeline.py index ea294fd..7d5c5dd 100644 --- a/bioscancast/insight/pipeline.py +++ b/bioscancast/insight/pipeline.py @@ -103,6 +103,30 @@ def run( result = InsightRunResult() embedding_cache: dict[str, list[float]] = {} + # Adaptive top-k: when the filter passes through only a handful + # of usable documents, lift retrieval depth so the per-doc chunk + # budget isn't the bottleneck on coverage. See InsightConfig + # docstrings for the rationale. + usable_doc_count = sum( + 1 for d in documents if d.status != "failed" and d.chunks + ) + if usable_doc_count <= config.low_survival_doc_threshold: + effective_top_k = max(config.retrieval_top_k, config.low_survival_top_k) + effective_max_chunks = max( + config.max_chunks_per_document, config.low_survival_top_k + ) + if effective_top_k != config.retrieval_top_k: + result.notes.append( + f"Low-survival adaptive top_k engaged: " + f"{usable_doc_count} usable docs (≤ threshold " + f"{config.low_survival_doc_threshold}); " + f"retrieval_top_k={effective_top_k} (default " + f"{config.retrieval_top_k})." 
+ ) + else: + effective_top_k = config.retrieval_top_k + effective_max_chunks = config.max_chunks_per_document + for doc in documents: # --- Skip check --- if doc.status == "failed" or not doc.chunks: @@ -126,7 +150,7 @@ def run( question, doc, self._llm, - top_k=config.retrieval_top_k, + top_k=effective_top_k, bm25_weight=config.bm25_weight, embedding_weight=config.embedding_weight, embedding_model=config.embedding_model, @@ -134,7 +158,7 @@ def run( ) # Cap chunks per document - scored_chunks = scored_chunks[: config.max_chunks_per_document] + scored_chunks = scored_chunks[:effective_max_chunks] # --- Per-chunk extraction (parallel within a doc) --- # Live tests on real biosecurity documents show the per-doc diff --git a/bioscancast/llm/pricing.py b/bioscancast/llm/pricing.py new file mode 100644 index 0000000..56ef62a --- /dev/null +++ b/bioscancast/llm/pricing.py @@ -0,0 +1,117 @@ +"""USD price table for OpenAI models and a per-call cost estimator. + +The orchestrator uses this to surface estimated cost per pipeline run. +Prices are a point-in-time snapshot — refresh when OpenAI changes rates or +when new model identifiers enter the stage configs. + +Snapshot taken 2026-05-27 from OpenAI's public API pricing pages +(https://devtk.ai/en/models/gpt-4o-mini/ and +https://www.cloudzero.com/blog/openai-pricing/). All numbers are USD per +1,000,000 tokens. The cached-input rate is OpenAI's standard 50% discount +on the cached prefix of an input; embedding models have no separate cached +rate. 
@dataclass(frozen=True)
class ModelPrice:
    """USD per 1,000,000 tokens for one model."""

    input: float
    cached_input: float
    output: float


_GPT_4O_MINI = ModelPrice(input=0.15, cached_input=0.075, output=0.60)
_GPT_4O = ModelPrice(input=2.50, cached_input=1.25, output=10.00)

MODEL_PRICES: dict[str, ModelPrice] = {
    # Cheap chat workhorse — used by search (query decomposition + filter
    # rescue) and insight (chunk extraction). OpenAI returns the dated
    # alias in response.model even when the request used the floating
    # name; keep both keyed to the same price.
    "gpt-4o-mini": _GPT_4O_MINI,
    "gpt-4o-mini-2024-07-18": _GPT_4O_MINI,
    # Strong model — scaffolded for issue #26 refinement but not in
    # production use as of 2026-05-27.
    "gpt-4o": _GPT_4O,
    "gpt-4o-2024-08-06": _GPT_4O,
    "gpt-4o-2024-05-13": _GPT_4O,
    # Embeddings (insight retrieval). No separate cached rate, so
    # cached_input mirrors input; embeddings produce no output tokens.
    "text-embedding-3-small": ModelPrice(input=0.02, cached_input=0.02, output=0.0),
    "text-embedding-3-large": ModelPrice(input=0.13, cached_input=0.13, output=0.0),
}


class UnknownModelError(KeyError):
    """Raised when a model name is not in MODEL_PRICES."""

    def __str__(self) -> str:
        # KeyError.__str__ repr()-quotes a lone argument, which would wrap
        # the whole explanatory message in quotes wherever the error is
        # formatted into a log line or another exception's message.
        if len(self.args) == 1:
            return str(self.args[0])
        return super().__str__()


def estimate_cost(
    model: str,
    input_tokens: int,
    output_tokens: int,
    cached_input_tokens: int = 0,
) -> float:
    """Estimate USD cost of an LLM call.

    Args:
        model: Identifier matching a key in MODEL_PRICES.
        input_tokens: Total input tokens (including any cached portion).
        output_tokens: Output tokens generated.
        cached_input_tokens: Subset of input_tokens that hit the prompt
            cache; must be <= input_tokens. The non-cached remainder is
            billed at the full input rate.

    Returns:
        Estimated cost in USD.

    Raises:
        UnknownModelError: If ``model`` is not in MODEL_PRICES — refresh
            this module when adding a new model to a stage config.
        ValueError: If cached_input_tokens > input_tokens or any token
            count is negative.
    """
    if input_tokens < 0 or output_tokens < 0 or cached_input_tokens < 0:
        raise ValueError("Token counts must be non-negative")
    if cached_input_tokens > input_tokens:
        raise ValueError(
            f"cached_input_tokens ({cached_input_tokens}) exceeds "
            f"input_tokens ({input_tokens})"
        )
    try:
        price = MODEL_PRICES[model]
    except KeyError as exc:
        raise UnknownModelError(
            f"Model {model!r} is not in MODEL_PRICES; refresh "
            f"bioscancast/llm/pricing.py with current rates."
        ) from exc

    # Only the uncached remainder is billed at the full input rate.
    fresh_input = input_tokens - cached_input_tokens
    return (
        fresh_input * price.input
        + cached_input_tokens * price.cached_input
        + output_tokens * price.output
    ) / 1_000_000.0


def _token_count(counts: dict, key: str) -> int:
    """Return ``counts[key]`` as an int, treating a missing or None value as 0."""
    return int(counts.get(key) or 0)


def estimate_cost_from_summary(summary: dict) -> float:
    """Estimate USD cost from an InsightRunResult.budget_summary-style dict.

    Expects a dict with a ``per_model`` key whose value is
    ``{model: {input_tokens, output_tokens, [cached_input_tokens]}}``.
    A missing or empty ``per_model`` yields ``0.0``; a token count that is
    absent or ``None`` is treated as 0.

    Unknown models are NOT skipped: the first model missing from
    MODEL_PRICES aborts the whole estimate with UnknownModelError (a
    KeyError subclass), so callers decide whether to suppress or surface it.

    Args:
        summary: Usage summary in the shape described above.

    Returns:
        Total estimated cost in USD across all models in the summary.

    Raises:
        UnknownModelError: If any model in ``per_model`` has no price entry.
        ValueError: If a model's token counts are negative or its cached
            count exceeds its input count.
    """
    per_model = summary.get("per_model") or {}
    return sum(
        (
            estimate_cost(
                model,
                input_tokens=_token_count(counts, "input_tokens"),
                output_tokens=_token_count(counts, "output_tokens"),
                cached_input_tokens=_token_count(counts, "cached_input_tokens"),
            )
            for model, counts in per_model.items()
        ),
        0.0,  # float start keeps the empty-summary result 0.0, not int 0
    )
-# search_results = run_search_stage(question, config) +Runs all four stages against a single forecast question, persisting each +stage's output and a running manifest under ``data/runs/{qid}/{run_id}/`` +so a crashed run still has partial artifacts for debugging. -# filtered_docs = run_filtering_stage(question, search_results, config) +Usage: -# downstream -# extracted_texts = run_extraction_stage(filtered_docs, config) -# insights_df = run_insight_stage(question, extracted_texts, config) -# forecast = run_forecasting_stage(question, insights_df, config) \ No newline at end of file + python -m bioscancast.main QUESTION_ID [--csv PATH] [--as-of-date Y-M-D] ... + +See ``--help`` for the full flag list. +""" + +from __future__ import annotations + +import argparse +import logging +import os +import sys +import time +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import asdict +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Iterator, Optional + +try: + from dotenv import load_dotenv + load_dotenv() +except ImportError: + pass + +from bioscancast.extraction.config import ExtractionConfig +from bioscancast.extraction.pipeline import ExtractionPipeline +from bioscancast.filtering.config import FILTER_CONFIG +from bioscancast.filtering.models import ForecastQuestion +from bioscancast.filtering.pipeline import FilteringPipeline +from bioscancast.insight.config import InsightConfig +from bioscancast.insight.pipeline import InsightPipeline, InsightRunResult +from bioscancast.llm.base import LLMResponse +from bioscancast.llm.openai_client import OpenAILLMClient +from bioscancast.llm.pricing import ( + UnknownModelError, + estimate_cost, +) +from bioscancast.orchestration import persistence +from bioscancast.stages.eval_stage.loaders import load_question_by_id +from bioscancast.stages.search_stage.backends.tavily_backend import TavilyBackend +from bioscancast.stages.search_stage.cache 
import SearchCache +from bioscancast.stages.search_stage.pipeline import SearchStagePipeline + + +DEFAULT_CSV = "bioscancast/stages/eval_stage/bioscancast_questions.csv" +DEFAULT_OUT_ROOT = "data/runs" + +logger = logging.getLogger("bioscancast.main") + + +class PipelineError(RuntimeError): + """Raised when a stage fails; wraps the underlying exception with stage name.""" + + def __init__(self, stage: str, original: BaseException) -> None: + super().__init__(f"Pipeline failed in stage {stage!r}: {original}") + self.stage = stage + self.original = original + + +class _UsageTrackingClient: + """Wraps an LLMClient and accumulates per-model token usage. + + Used for the search and filter stages so the orchestrator can include + their cost in the final estimate. The insight stage already tracks its + own usage via BudgetTracker — passing the raw client to insight avoids + double-counting. + """ + + def __init__(self, inner) -> None: + self._inner = inner + self.per_model: dict[str, dict[str, int]] = defaultdict( + lambda: {"input_tokens": 0, "output_tokens": 0, "calls": 0} + ) + + def generate_json(self, **kwargs) -> LLMResponse: + response = self._inner.generate_json(**kwargs) + bucket = self.per_model[response.model] + bucket["input_tokens"] += response.input_tokens + bucket["output_tokens"] += response.output_tokens + bucket["calls"] += 1 + return response + + def embed(self, texts: list[str], *, model: str) -> list[list[float]]: + # The OpenAI embed() call doesn't expose usage today; the insight + # pipeline doesn't currently track embedding cost either. If + # embeddings become a material cost line, hook tokenizer-based + # estimation in here. 
+ return self._inner.embed(texts, model=model) + + def snapshot(self) -> dict[str, dict[str, int]]: + """Return a plain-dict copy of cumulative usage for delta-ing + between stages.""" + return {model: dict(counts) for model, counts in self.per_model.items()} + + +# ---------------------------------------------------------------------------- +# CLI parsing and question construction +# ---------------------------------------------------------------------------- + + +def _parse_args(argv: list[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser( + prog="bioscancast.main", + description="Run the BioScanCast pipeline end-to-end for one question.", + ) + parser.add_argument( + "question_id", + help="ID of the question in the CSV (e.g. q7).", + ) + parser.add_argument( + "--csv", + default=DEFAULT_CSV, + help=f"Path to the question CSV. Default: {DEFAULT_CSV}", + ) + parser.add_argument( + "--out-root", + default=DEFAULT_OUT_ROOT, + help=f"Run-artifact root directory. Default: {DEFAULT_OUT_ROOT}", + ) + parser.add_argument( + "--run-id", + default=None, + help="Override the UTC-timestamp run directory name.", + ) + parser.add_argument( + "--as-of-date", + default=None, + help="Historical-replay cutoff (YYYY-MM-DD). Omit for live mode.", + ) + parser.add_argument( + "--target-date", + default=None, + help="Override CSV-derived target_date (YYYY-MM-DD).", + ) + parser.add_argument("--region", default=None, help="Override region field.") + parser.add_argument("--pathogen", default=None, help="Override pathogen field.") + parser.add_argument( + "--event-type", default=None, help="Override event_type field." 
+ ) + parser.add_argument( + "--no-cache", + action="store_true", + help="Disable the search-stage cache.", + ) + parser.add_argument( + "--max-input-tokens", + type=int, + default=None, + help="Override InsightConfig.max_input_tokens_per_run.", + ) + parser.add_argument( + "-v", "--verbose", + action="store_true", + help="Set log level to INFO.", + ) + return parser.parse_args(argv) + + +def _parse_date(arg: Optional[str]) -> Optional[datetime]: + if arg is None: + return None + return datetime.strptime(arg, "%Y-%m-%d").replace(tzinfo=timezone.utc) + + +def _apply_overrides(q: ForecastQuestion, args: argparse.Namespace) -> ForecastQuestion: + target_date = _parse_date(args.target_date) + return ForecastQuestion( + id=q.id, + text=q.text, + created_at=q.created_at, + target_date=target_date if target_date is not None else q.target_date, + region=args.region or q.region, + pathogen=args.pathogen or q.pathogen, + event_type=args.event_type or q.event_type, + resolution_criteria=q.resolution_criteria, + as_of_date=q.as_of_date, + ) + + +# ---------------------------------------------------------------------------- +# Stage timing and console output +# ---------------------------------------------------------------------------- + + +@contextmanager +def _stage_timer(manifest: dict, stage: str) -> Iterator[None]: + manifest["stage_timings"].setdefault(stage, None) + manifest["current_stage"] = stage + print(f"[{stage}] starting...", flush=True) + t0 = time.perf_counter() + try: + yield + finally: + elapsed = time.perf_counter() - t0 + manifest["stage_timings"][stage] = round(elapsed, 3) + + +def _log_summary(stage: str, summary: str, elapsed: Optional[float]) -> None: + tail = f"({elapsed:.1f}s)" if elapsed is not None else "" + print(f"[{stage}] {summary} {tail}", flush=True) + + +# ---------------------------------------------------------------------------- +# Cost estimation +# ---------------------------------------------------------------------------- + + +def 
_merge_usage(*usage_dicts: dict[str, dict[str, int]]) -> dict[str, dict[str, int]]: + """Combine per-model usage dicts from different stages.""" + merged: dict[str, dict[str, int]] = defaultdict( + lambda: {"input_tokens": 0, "output_tokens": 0, "calls": 0} + ) + for usage in usage_dicts: + for model, counts in (usage or {}).items(): + for k in ("input_tokens", "output_tokens", "calls"): + merged[model][k] += int(counts.get(k, 0)) + return dict(merged) + + +def _usage_delta( + before: dict[str, dict[str, int]], + after: dict[str, dict[str, int]], +) -> dict[str, dict[str, int]]: + """Per-model usage that accrued between two snapshots of the same + tracker. Models with no change are omitted.""" + delta: dict[str, dict[str, int]] = {} + for model, counts in after.items(): + b = before.get(model, {}) + d = { + k: int(counts.get(k, 0)) - int(b.get(k, 0)) + for k in ("input_tokens", "output_tokens", "calls") + } + if any(d.values()): + delta[model] = d + return delta + + +def _estimate_total_cost(per_model: dict[str, dict[str, int]]) -> tuple[float, list[str]]: + """Return (usd_total, list_of_warnings).""" + total = 0.0 + warnings: list[str] = [] + for model, counts in per_model.items(): + try: + total += estimate_cost( + model, + input_tokens=int(counts.get("input_tokens", 0)), + output_tokens=int(counts.get("output_tokens", 0)), + ) + except UnknownModelError as exc: + warnings.append(str(exc)) + return total, warnings + + +# ---------------------------------------------------------------------------- +# Main pipeline +# ---------------------------------------------------------------------------- + + +def run_pipeline(args: argparse.Namespace) -> InsightRunResult: + csv_path = Path(args.csv) + if not csv_path.exists(): + raise FileNotFoundError(f"Question CSV not found: {csv_path}") + + for var in ("OPENAI_API_KEY", "TAVILY_API_KEY"): + if not os.environ.get(var): + raise RuntimeError( + f"Missing required environment variable {var}. 
" + f"Set it in your shell or in a .env file." + ) + + question = load_question_by_id( + csv_path, + args.question_id, + as_of_date=_parse_date(args.as_of_date), + ) + question = _apply_overrides(question, args) + + run_id = args.run_id or datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + out_root = Path(args.out_root) + run_dir = persistence.make_run_dir(out_root, question.id, run_id) + + print(f"=== Pipeline run: {question.id} ({run_id}) ===") + print( + f"question: {question.text}\n" + f"pathogen: {question.pathogen!r}\n" + f"region: {question.region!r}\n" + f"target: {question.target_date.date() if question.target_date else None}\n" + f"as_of: {question.as_of_date.date() if question.as_of_date else None}\n" + f"artifacts: {run_dir}\n", + ) + + persistence.save_question(run_dir, question) + + insight_config = InsightConfig() + if args.max_input_tokens: + insight_config.max_input_tokens_per_run = args.max_input_tokens + + manifest: dict[str, Any] = { + "run_id": run_id, + "question_id": question.id, + "csv_path": str(csv_path), + "started_at": datetime.now(timezone.utc).isoformat(), + "completed_at": None, + "stage_timings": {}, + "current_stage": None, + "errored_stage": None, + "error_message": None, + "config": { + "filter": dict(FILTER_CONFIG), + "extraction": asdict(ExtractionConfig()), + "insight": asdict(insight_config), + }, + } + persistence.save_manifest(run_dir, manifest) + + shared_llm_raw = OpenAILLMClient() + shared_llm = _UsageTrackingClient(shared_llm_raw) + + # Per-stage usage is captured by snapshotting the shared tracker + # before/after each stage that uses it. Search and filter share one + # client; insight reports its own budget_summary. 
+ stage_usage: dict[str, dict[str, dict[str, int]]] = {} + + try: + with _stage_timer(manifest, "search"): + backend = TavilyBackend() + cache = None if args.no_cache else SearchCache() + search_pipeline = SearchStagePipeline( + search_backend=backend, + llm_client=shared_llm, + cache=cache, + backend_name="tavily", + ) + search_results = search_pipeline.run(question) + persistence.save_search(run_dir, search_results) + if cache: + cache.close() + usage_after_search = shared_llm.snapshot() + stage_usage["search"] = usage_after_search + _log_summary( + "search", f"{len(search_results)} results", + manifest["stage_timings"]["search"], + ) + persistence.save_manifest(run_dir, manifest) + + with _stage_timer(manifest, "filter"): + filter_pipeline = FilteringPipeline(llm_client=shared_llm) + filtered_docs = filter_pipeline.run(question, search_results) + persistence.save_filtered(run_dir, filtered_docs) + stage_usage["filter"] = _usage_delta(usage_after_search, shared_llm.snapshot()) + _log_summary( + "filter", f"{len(filtered_docs)} docs", + manifest["stage_timings"]["filter"], + ) + persistence.save_manifest(run_dir, manifest) + + with _stage_timer(manifest, "extract"): + extraction_pipeline = ExtractionPipeline(as_of_date=question.as_of_date) + documents = extraction_pipeline.run(filtered_docs) + persistence.save_documents(run_dir, documents) + n_ok = sum(1 for d in documents if d.status == "success") + _log_summary( + "extract", f"{n_ok}/{len(documents)} success", + manifest["stage_timings"]["extract"], + ) + persistence.save_manifest(run_dir, manifest) + + with _stage_timer(manifest, "insight"): + insight_pipeline = InsightPipeline( + llm_client=shared_llm_raw, + config=insight_config, + ) + insight_result = insight_pipeline.run(question, documents) + persistence.save_insight(run_dir, insight_result) + stage_usage["insight"] = insight_result.budget_summary.get("per_model") or {} + budget = insight_result.budget_summary + _log_summary( + "insight", + 
f"{len(insight_result.records)} records | " + f"in={budget.get('total_input_tokens', 0):,} " + f"out={budget.get('total_output_tokens', 0):,}", + manifest["stage_timings"]["insight"], + ) + + except Exception as exc: + manifest["errored_stage"] = manifest.get("current_stage") + manifest["error_message"] = f"{type(exc).__name__}: {exc}" + persistence.save_manifest(run_dir, manifest) + raise PipelineError(manifest["current_stage"] or "unknown", exc) from exc + + # Per-stage cost from the captured usage snapshots. + stage_costs: dict[str, float] = {} + cost_warnings: list[str] = [] + for stage in ("search", "filter", "insight"): + usage = stage_usage.get(stage) or {} + cost, warns = _estimate_total_cost(usage) + stage_costs[stage] = round(cost, 6) + cost_warnings.extend(warns) + + # Combine usage across stages and estimate total cost. + combined_usage = _merge_usage( + dict(shared_llm.per_model), + insight_result.budget_summary.get("per_model") or {}, + ) + cost_usd, _ = _estimate_total_cost(combined_usage) + # Dedup warnings (same unknown model can surface in multiple stages). 
+ cost_warnings = list(dict.fromkeys(cost_warnings)) + + manifest["current_stage"] = None + manifest["completed_at"] = datetime.now(timezone.utc).isoformat() + manifest["combined_usage"] = combined_usage + manifest["stage_usage"] = stage_usage + manifest["stage_costs_usd"] = stage_costs + manifest["estimated_cost_usd"] = round(cost_usd, 6) + if cost_warnings: + manifest["cost_estimate_warnings"] = cost_warnings + persistence.save_manifest(run_dir, manifest) + + print() + print(f"=== Pipeline complete: {question.id} ===") + for stage in ("search", "filter", "extract", "insight"): + elapsed = manifest["stage_timings"].get(stage) + if elapsed is None: + continue + stage_cost = stage_costs.get(stage) + cost_str = f" ${stage_cost:.4f}" if stage_cost is not None else "" + print(f" {stage:<8} {elapsed:>7.2f}s{cost_str}") + print(f" total cost: ${cost_usd:.4f}") + if cost_warnings: + for w in cost_warnings: + print(f" ! {w}") + print(f" artifacts: {run_dir}") + + return insight_result + + +def main(argv: Optional[list[str]] = None) -> int: + args = _parse_args(argv if argv is not None else sys.argv[1:]) + logging.basicConfig( + level=logging.INFO if args.verbose else logging.WARNING, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + try: + run_pipeline(args) + except PipelineError as exc: + print(f"ERROR: {exc}", file=sys.stderr) + return 1 + except (FileNotFoundError, KeyError, RuntimeError) as exc: + print(f"ERROR: {exc}", file=sys.stderr) + return 2 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/bioscancast/orchestration/__init__.py b/bioscancast/orchestration/__init__.py new file mode 100644 index 0000000..23cc5aa --- /dev/null +++ b/bioscancast/orchestration/__init__.py @@ -0,0 +1,7 @@ +"""End-to-end pipeline orchestration helpers. + +The main orchestrator lives in :mod:`bioscancast.main`. 
This package +holds the persistence layer (run-directory layout, per-stage JSON dumps, +manifest write/append) so :mod:`bioscancast.main` stays focused on +stage composition. +""" diff --git a/bioscancast/orchestration/persistence.py b/bioscancast/orchestration/persistence.py new file mode 100644 index 0000000..6118e7a --- /dev/null +++ b/bioscancast/orchestration/persistence.py @@ -0,0 +1,85 @@ +"""Run-directory layout and per-stage JSON dump helpers. + +The orchestrator writes each stage's output and a running manifest to +``data/runs/{question_id}/{run_id}/`` so a crashed or interrupted run +still has partial artifacts for debugging. +""" + +from __future__ import annotations + +import json +from dataclasses import asdict, is_dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Iterable + +from bioscancast.filtering.models import ForecastQuestion +from bioscancast.insight.pipeline import InsightRunResult + + +def _json_default(obj: Any) -> Any: + """Default JSON encoder for datetimes, dataclasses, and sets. + + Lifted from scripts/eval_insight_on_real_docs.py and extended to + handle the set/frozenset values that live in FILTER_CONFIG. 
+ """ + if isinstance(obj, datetime): + return obj.isoformat() + if is_dataclass(obj): + return asdict(obj) + if isinstance(obj, (set, frozenset)): + return sorted(obj) + raise TypeError(f"not serializable: {type(obj).__name__}") + + +def _dump(path: Path, payload: Any) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, default=_json_default) + + +def make_run_dir(out_root: Path, question_id: str, run_id: str) -> Path: + """Create and return ``out_root/question_id/run_id/``.""" + run_dir = Path(out_root) / question_id / run_id + run_dir.mkdir(parents=True, exist_ok=True) + return run_dir + + +def save_question(run_dir: Path, question: ForecastQuestion) -> None: + _dump(run_dir / "question.json", asdict(question)) + + +def save_search(run_dir: Path, results: Iterable[Any]) -> None: + _dump(run_dir / "search.json", [asdict(r) for r in results]) + + +def save_filtered(run_dir: Path, docs: Iterable[Any]) -> None: + _dump(run_dir / "filtered.json", [asdict(d) for d in docs]) + + +def save_documents(run_dir: Path, docs: Iterable[Any]) -> None: + _dump(run_dir / "documents.json", [asdict(d) for d in docs]) + + +def save_insight(run_dir: Path, result: InsightRunResult) -> None: + """Serialize an InsightRunResult to insight.json. + + The dataclass nests InsightRecord and ChunkReference dataclasses; + asdict handles both. budget_summary is already a dict. + """ + payload = { + "records": [asdict(r) for r in result.records], + "budget_summary": result.budget_summary, + "documents_processed": result.documents_processed, + "documents_skipped": result.documents_skipped, + "notes": list(result.notes), + } + _dump(run_dir / "insight.json", payload) + + +def save_manifest(run_dir: Path, manifest: dict) -> None: + """Write the manifest. Called repeatedly — once before each stage and + again after every stage completes — so a crashed run keeps the + partial timings/config that did make it through. 
+ """ + _dump(run_dir / "manifest.json", manifest) diff --git a/bioscancast/orchestration/test_questions.csv b/bioscancast/orchestration/test_questions.csv new file mode 100644 index 0000000..84e3ff5 --- /dev/null +++ b/bioscancast/orchestration/test_questions.csv @@ -0,0 +1,3 @@ +question_id;topic;question_text;question_type;resolution_criteria;created_date;question_status;resolved_option;comparison_to_outcome;takeaways;relevant_links +q7;Mpox (World);How many confirmed cumulative Mpox cases will be reported globally by February 28, 2025?;range;The question resolves based on cumulative confirmed Mpox cases.;45712;resolved;126,001-128,500;The final case count was 126,441.;Community adjustment over time demonstrated responsiveness to new data.;https://ourworldindata.org/mpox +q12;Ebola (East Africa);How many cumulative confirmed Ebola cases will be reported in the current East Africa outbreak by June 30, 2026?;range;The number of cumulative confirmed cases reported by Africa CDC and/or the WHO Disease Outbreak News for the current East Africa Ebola outbreak as of June 30, 2026.;46169;unresolved;TBD;TBD;TBD;https://africacdc.org/ diff --git a/bioscancast/stages/eval_stage/loaders.py b/bioscancast/stages/eval_stage/loaders.py index cf7dbb5..ee1f5bb 100644 --- a/bioscancast/stages/eval_stage/loaders.py +++ b/bioscancast/stages/eval_stage/loaders.py @@ -1,13 +1,107 @@ from __future__ import annotations +import logging +import re +from datetime import datetime, timezone from pathlib import Path -from typing import Union +from typing import Optional, Union import pandas as pd +from bioscancast.filtering.models import ForecastQuestion + PathLike = Union[str, Path] +logger = logging.getLogger(__name__) + + +_MONTH_NUMBERS: dict[str, int] = { + "january": 1, "february": 2, "march": 3, "april": 4, "may": 5, "june": 6, + "july": 7, "august": 8, "september": 9, "october": 10, "november": 11, + "december": 12, +} + +# "by February 28, 2025" / "by Feb 28th 2025" / "by February 
2025" +_TARGET_DATE_RE = re.compile( + r"by\s+(?P<month>" + + "|".join(_MONTH_NUMBERS.keys()) + + r"|" + + "|".join(m[:3] for m in _MONTH_NUMBERS.keys()) + + r")" + + r"(?:\s+(?P<day>\d{1,2})(?:st|nd|rd|th)?)?" + + r"(?:,)?\s+(?P<year>\d{4})", + re.IGNORECASE, +) + + +def _parse_target_date(text: str) -> Optional[datetime]: + """Extract a target_date from a question's natural-language text. + + Matches phrasings like "by February 28, 2025", "by Feb 28 2025", and the + month-only fallback "by February 2025" (resolves to the 1st of the + following month, conservatively interpreting "by" as inclusive of the + named month). Returns None if no clear pattern matches. + """ + m = _TARGET_DATE_RE.search(text) + if not m: + return None + month_name = m.group("month").lower() + if len(month_name) == 3: + # Map 3-letter abbreviation back to canonical month + for full, num in _MONTH_NUMBERS.items(): + if full.startswith(month_name): + month = num + break + else: + return None + else: + month = _MONTH_NUMBERS[month_name] + year = int(m.group("year")) + day_str = m.group("day") + if day_str: + day = int(day_str) + try: + return datetime(year, month, day, tzinfo=timezone.utc) + except ValueError: + return None + # Month-only fallback: anchor to the last day of the month would require + # a calendar lookup; the simpler "1st of next month" gives the same + # cutoff semantics for a "by <month>" question. + next_month = month + 1 + next_year = year + if next_month == 13: + next_month = 1 + next_year += 1 + return datetime(next_year, next_month, 1, tzinfo=timezone.utc) + + +def _split_topic(topic: str) -> tuple[Optional[str], Optional[str]]: + """Split a topic like "Mpox (World)" → ("mpox", "world"). + + Returns (pathogen, region). If no parenthetical exists, the whole topic + becomes the pathogen and region is None. 
+ """ + if not topic or pd.isna(topic): + return None, None + topic = topic.strip() + if "(" in topic and topic.endswith(")"): + head, _, tail = topic.rpartition(" (") + return head.strip().lower() or None, tail.rstrip(")").strip().lower() or None + return topic.lower() or None, None + + +def _infer_event_type(question_type: str, question_text: str) -> Optional[str]: + """Map the CSV's question_type plus keyword hints to an event_type.""" + text = (question_text or "").lower() + if "deaths" in text or "death" in text: + return "death_count" + if "cases" in text or "case " in text: + return "case_count" + if "outbreak" in text: + return "outbreak_declared" + return None + def _read_csv(path: PathLike) -> pd.DataFrame: """ @@ -66,7 +160,13 @@ def load_questions(path: PathLike) -> pd.DataFrame: df = _clean_text_columns(df) if "created_date" in df.columns: - df["created_date"] = pd.to_datetime(df["created_date"], errors="coerce") + # The CSV stores created_date as an Excel serial day (e.g. 45712 → + # 2025-02-19). Without unit="D" + origin="1899-12-30", pandas treats + # the integer as nanoseconds past 1970 and produces garbage dates + # like 1970-01-01 00:00:00.000045712. + df["created_date"] = pd.to_datetime( + df["created_date"], unit="D", origin="1899-12-30", errors="coerce", + ) if "question_status" in df.columns: df["question_status"] = df["question_status"].str.lower() @@ -105,4 +205,80 @@ def load_forecasts(path: PathLike) -> pd.DataFrame: if "question_id" in df.columns: df["question_id"] = df["question_id"].str.strip() - return df \ No newline at end of file + return df + + +def build_forecast_question( + row: pd.Series, + *, + as_of_date: Optional[datetime] = None, +) -> ForecastQuestion: + """Convert one row of the question CSV into a ForecastQuestion. + + Used by the orchestrator (`bioscancast.main`) to turn a CSV row into the + typed object the search/filter/insight stages expect. 
``as_of_date`` is + passed through verbatim and is the historical-replay cutoff; ``None`` + means live mode. + """ + qid = str(row["question_id"]).strip() + text = str(row["question_text"]).strip() + + created_value = row.get("created_date") + if isinstance(created_value, pd.Timestamp) and not pd.isna(created_value): + created_at = created_value.to_pydatetime() + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=timezone.utc) + elif isinstance(created_value, datetime): + created_at = ( + created_value if created_value.tzinfo + else created_value.replace(tzinfo=timezone.utc) + ) + else: + logger.warning( + "Question %s has unparseable created_date %r; defaulting to now()", + qid, created_value, + ) + created_at = datetime.now(timezone.utc) + + topic = row.get("topic", "") + pathogen, region = _split_topic(str(topic) if not pd.isna(topic) else "") + + target_date = _parse_target_date(text) + event_type = _infer_event_type(str(row.get("question_type", "")), text) + + resolution_criteria_val = row.get("resolution_criteria") + resolution_criteria = ( + str(resolution_criteria_val).strip() + if resolution_criteria_val is not None and not pd.isna(resolution_criteria_val) + else None + ) + + return ForecastQuestion( + id=qid, + text=text, + created_at=created_at, + target_date=target_date, + region=region, + pathogen=pathogen, + event_type=event_type, + resolution_criteria=resolution_criteria, + as_of_date=as_of_date, + ) + + +def load_question_by_id( + path: PathLike, + question_id: str, + *, + as_of_date: Optional[datetime] = None, +) -> ForecastQuestion: + """Load a single question from the CSV by its question_id.""" + df = load_questions(path) + matches = df[df["question_id"].astype(str).str.strip() == question_id.strip()] + if matches.empty: + available = sorted(df["question_id"].astype(str).str.strip().tolist()) + raise KeyError( + f"question_id {question_id!r} not found in {path}. 
" + f"Available: {available}" + ) + return build_forecast_question(matches.iloc[0], as_of_date=as_of_date) \ No newline at end of file diff --git a/bioscancast/stages/search_stage/dashboard_lookup.py b/bioscancast/stages/search_stage/dashboard_lookup.py index e3784c3..2d3048c 100644 --- a/bioscancast/stages/search_stage/dashboard_lookup.py +++ b/bioscancast/stages/search_stage/dashboard_lookup.py @@ -31,6 +31,46 @@ logger = logging.getLogger(__name__) +# Common name variants that should route to a canonical DASHBOARD_LOOKUP key. +# The canonical-key substring fallback in ``_resolve_pathogen_key`` already +# handles suffixes like "marburg virus disease" -> "marburg"; this map covers +# synonyms where the canonical key is NOT a substring of the alias. +_PATHOGEN_ALIASES: dict[str, str] = { + "monkeypox": "mpox", + "sars-cov-2": "covid-19", + "sars-cov2": "covid-19", + "covid": "covid-19", + "covid19": "covid-19", + "coronavirus": "covid-19", + "bird flu": "h5n1", + "avian flu": "h5n1", +} + + +def _resolve_pathogen_key(pathogen: str) -> str | None: + """Map a free-text pathogen string to a DASHBOARD_LOOKUP key, tolerantly. + + Resolution order: exact key, exact alias, alias-substring, then + canonical-key substring (longest match wins, so "ebola virus disease" + resolves to "ebola" and "marburg virus disease" to "marburg"). Returns + None if nothing matches. + """ + key = pathogen.strip().lower() + if not key: + return None + if key in DASHBOARD_LOOKUP: + return key + if key in _PATHOGEN_ALIASES and _PATHOGEN_ALIASES[key] in DASHBOARD_LOOKUP: + return _PATHOGEN_ALIASES[key] + for alias, canon in _PATHOGEN_ALIASES.items(): + if alias in key and canon in DASHBOARD_LOOKUP: + return canon + matches = [k for k in DASHBOARD_LOOKUP if k in key] + if matches: + return max(matches, key=len) + return None + + def lookup_dashboards(question: ForecastQuestion) -> List[SearchResult]: """Generate synthetic SearchResult entries for known pathogen dashboards. 
@@ -47,22 +87,22 @@ def lookup_dashboards(question: ForecastQuestion) -> List[SearchResult]: if not question.pathogen: return [] - pathogen_key = question.pathogen.strip().lower() - urls = DASHBOARD_LOOKUP.get(pathogen_key, []) - if not urls: + pathogen_key = _resolve_pathogen_key(question.pathogen) + if not pathogen_key: return [] + entries = DASHBOARD_LOOKUP[pathogen_key] as_of = question.as_of_date results: list[SearchResult] = [] now = datetime.now(timezone.utc) - for url in urls: + for entry in entries: if as_of is not None: - snapshot = closest_snapshot_before(url, as_of) + snapshot = closest_snapshot_before(entry.url, as_of) if snapshot is None: logger.info( "Suppressing dashboard %s — no Wayback snapshot at-or-before %s", - url, as_of.isoformat(), + entry.url, as_of.isoformat(), ) continue snapshot_dt, snapshot_url = snapshot @@ -71,12 +111,12 @@ def lookup_dashboards(question: ForecastQuestion) -> List[SearchResult]: published_date_source = "wayback_snapshot" # Keep ``domain`` as the original publisher for tier scoring; # the URL itself points at archive.org for fetching. 
- domain = extract_domain(url) + domain = extract_domain(entry.url) else: - effective_url = url + effective_url = entry.url published_date = None published_date_source = None - domain = extract_domain(url) + domain = extract_domain(entry.url) tier_num, domain_score, source_tier = resolve_tier(domain) @@ -89,8 +129,8 @@ def lookup_dashboards(question: ForecastQuestion) -> List[SearchResult]: url=effective_url, canonical_url=normalize_url(effective_url), domain=domain, - title=f"Dashboard: {domain}", - snippet=f"Known {pathogen_key} monitoring dashboard", + title=entry.title, + snippet=entry.snippet, rank=0, retrieved_at=now, published_date=published_date, diff --git a/bioscancast/stages/search_stage/pipeline.py b/bioscancast/stages/search_stage/pipeline.py index edc573f..1477634 100644 --- a/bioscancast/stages/search_stage/pipeline.py +++ b/bioscancast/stages/search_stage/pipeline.py @@ -13,7 +13,9 @@ from typing import List, Optional from bioscancast.filtering.config import FILTER_CONFIG +from bioscancast.filtering.heuristics import build_query_terms from bioscancast.filtering.models import ForecastQuestion, SearchResult +from bioscancast.filtering.utils import keyword_overlap_score from bioscancast.llm.base import LLMClient from bioscancast.stages.search_stage.backends.base import RawSearchResult, SearchBackend from bioscancast.stages.search_stage.cache import SearchCache @@ -83,10 +85,39 @@ def _compute_freshness( return max(0.0, min(1.0, 1.0 - (days_old / 365.0))) -def _compute_search_stage_score(domain_score: float, freshness_score: float, rank: int) -> float: - """search_stage_score = 0.5 * domain_score + 0.3 * freshness_score + 0.2 * (1/rank)""" +# search_stage_score weights (sum to 1.0). 
Relevance (keyword overlap of +# title/snippet/domain against the question terms) is the dominant term: +# domain/freshness/rank alone rank off-topic high-authority content too highly, +# because freshness is ~uniform in live mode and domain score is too coarse to +# separate on-topic from off-topic within a tier. Freshness is kept low for that +# reason. See data/investigations/findings-issues-3-4-13.md (#4). +_SCORE_W_RELEVANCE = 0.45 +_SCORE_W_DOMAIN = 0.30 +_SCORE_W_FRESHNESS = 0.10 +_SCORE_W_RANK = 0.15 + + +def _compute_relevance(result: SearchResult, question: ForecastQuestion) -> float: + """Keyword overlap of the result against the question terms. + + Mirrors ``bioscancast.filtering.heuristics.compute_heuristic_relevance`` so + the search stage and the filter stage use the same relevance signal. + """ + text = f"{result.title} {result.snippet} {result.domain}" + return keyword_overlap_score(text, build_query_terms(question)) + + +def _compute_search_stage_score( + relevance: float, domain_score: float, freshness_score: float, rank: int +) -> float: + """search_stage_score = 0.45*relevance + 0.30*domain + 0.10*freshness + 0.15*(1/rank)""" rank_score = 1.0 / max(rank, 1) - raw = 0.5 * domain_score + 0.3 * freshness_score + 0.2 * rank_score + raw = ( + _SCORE_W_RELEVANCE * relevance + + _SCORE_W_DOMAIN * domain_score + + _SCORE_W_FRESHNESS * freshness_score + + _SCORE_W_RANK * rank_score + ) return max(0.0, min(1.0, raw)) @@ -260,7 +291,10 @@ def run(self, question: ForecastQuestion) -> List[SearchResult]: r.published_date, reference_date=as_of ) r.search_stage_score = _compute_search_stage_score( - r.domain_score, r.freshness_score, r.rank + _compute_relevance(r, question), + r.domain_score, + r.freshness_score, + r.rank, ) # 8. 
Sort and cap diff --git a/bioscancast/tests/test_dashboard_lookup.py b/bioscancast/tests/test_dashboard_lookup.py index c244e49..73163b0 100644 --- a/bioscancast/tests/test_dashboard_lookup.py +++ b/bioscancast/tests/test_dashboard_lookup.py @@ -46,6 +46,23 @@ def test_case_insensitive(self): results = lookup_dashboards(q) assert len(results) > 0 + def test_multiword_pathogen_routes_via_substring(self): + # CSV-natural "Marburg Virus Disease" -> pathogen "marburg virus disease" + # must still route to the "marburg" dashboard key. + canonical = lookup_dashboards(_make_question(pathogen="marburg")) + multiword = lookup_dashboards(_make_question(pathogen="marburg virus disease")) + assert len(multiword) > 0 + assert [r.url for r in multiword] == [r.url for r in canonical] + + def test_alias_routes_to_canonical(self): + # "monkeypox" -> "mpox"; "bird flu" -> "h5n1". + assert len(lookup_dashboards(_make_question(pathogen="monkeypox"))) > 0 + assert ( + [r.url for r in lookup_dashboards(_make_question(pathogen="monkeypox"))] + == [r.url for r in lookup_dashboards(_make_question(pathogen="mpox"))] + ) + assert len(lookup_dashboards(_make_question(pathogen="bird flu"))) > 0 + def test_results_have_required_fields(self): q = _make_question(pathogen="ebola") results = lookup_dashboards(q) diff --git a/bioscancast/tests/test_insight_chunk_extractor.py b/bioscancast/tests/test_insight_chunk_extractor.py index 2ba0421..4a112b9 100644 --- a/bioscancast/tests/test_insight_chunk_extractor.py +++ b/bioscancast/tests/test_insight_chunk_extractor.py @@ -367,6 +367,25 @@ def test_response_returned_for_budget_tracking(): ] +_LAYER4_CASE_INSENSITIVE_CASES = [ + ( + # Real q12 finding: model lowercased the leading "T" of a sentence + # it quoted from mid-paragraph; otherwise verbatim. 
+ "leading letter lowercased by model", + "There are now 750 suspected cases and 177 suspected deaths, though more are expected.", + "there are now 750 suspected cases and 177 suspected deaths", + True, + ), + ( + # Real q12 finding: same drift on a longer attribution clause. + "leading 'The' lowercased mid-paragraph quote", + "The Congolese Ministry of Communication, in a post to X on Sunday, said that there were 904 suspected cases and 119 suspected deaths.", + "the Congolese Ministry of Communication, in a post to X on Sunday, said that there were 904 suspected cases and 119 suspected deaths", + True, + ), +] + + _HALLUCINATION_CASES = [ ( "fabricated word inserted into list", @@ -374,6 +393,17 @@ def test_response_returned_for_budget_tracking(): "Ghana, Atlantis, and Liberia have reported human mpox due to clade IIa MPXV.", False, ), + ( + # Real q12 finding: model bolted a real prefix ("a total of 105 + # confirmed cases (including 10 deaths)") onto a fabricated + # continuation. The source actually continues "...and 906 + # suspected cases". Must stay rejected even with the new + # case-insensitive layer 4. 
+ "real prefix bolted onto fabricated continuation (q12)", + "According to the Ministry of Health of DRC on 25 May, a total of 105 confirmed cases (including 10 deaths) and 906 suspected cases.", + "a total of 105 confirmed cases (including 10 deaths) have been reported in Ituri, North Kivu, and South Kivu", + False, + ), ( "wholesale fabrication", "Some real chunk content about measles cases in Utah.", @@ -410,7 +440,10 @@ def test_response_returned_for_budget_tracking(): @pytest.mark.parametrize( "label,chunk_text,quote,should_match", - _LAYER1_NFKC_CASES + _LAYER2_TERMINAL_PUNCTUATION_CASES + _LAYER3_WRAPPING_PUNCTUATION_CASES, + _LAYER1_NFKC_CASES + + _LAYER2_TERMINAL_PUNCTUATION_CASES + + _LAYER3_WRAPPING_PUNCTUATION_CASES + + _LAYER4_CASE_INSENSITIVE_CASES, ) def test_quote_matches_accepts_real_quotes_with_normalisation_drift( label, chunk_text, quote, should_match diff --git a/bioscancast/tests/test_insight_pipeline.py b/bioscancast/tests/test_insight_pipeline.py index 1d6b9ba..679924a 100644 --- a/bioscancast/tests/test_insight_pipeline.py +++ b/bioscancast/tests/test_insight_pipeline.py @@ -59,7 +59,9 @@ def test_pipeline_single_document(): RISK_ASSESSMENT_RESPONSE, # chunk p4 (no facts) ]) - config = InsightConfig(retrieval_top_k=5, max_chunks_per_document=5) + config = InsightConfig( + retrieval_top_k=5, max_chunks_per_document=5, low_survival_top_k=5, + ) pipeline = InsightPipeline(llm_client=client, config=config) result = pipeline.run(QUESTION_SUDAN, [DOC_WHO_SUDAN]) @@ -91,7 +93,9 @@ def test_pipeline_skips_failed_documents(): EMPTY_RESPONSE, # For the one chunk that gets extracted ]) - config = InsightConfig(retrieval_top_k=1, max_chunks_per_document=1) + config = InsightConfig( + retrieval_top_k=1, max_chunks_per_document=1, low_survival_top_k=1, + ) pipeline = InsightPipeline(llm_client=client, config=config) # Include a failed document alongside a successful one @@ -114,7 +118,9 @@ def test_pipeline_budget_tracking(): SUDAN_TABLE_RESPONSE, ]) - 
config = InsightConfig(retrieval_top_k=2, max_chunks_per_document=2) + config = InsightConfig( + retrieval_top_k=2, max_chunks_per_document=2, low_survival_top_k=2, + ) pipeline = InsightPipeline(llm_client=client, config=config) result = pipeline.run(QUESTION_SUDAN, [DOC_WHO_SUDAN]) @@ -137,6 +143,7 @@ def test_pipeline_stops_on_budget_exceeded(): config = InsightConfig( retrieval_top_k=2, max_chunks_per_document=2, + low_survival_top_k=2, max_input_tokens_per_run=1, # Absurdly low -> triggers immediately ) pipeline = InsightPipeline(llm_client=client, config=config) @@ -170,7 +177,9 @@ def test_pipeline_deduplication(): DUPLICATE_SUDAN_CASE_COUNT, # doc 2 -> 1 fact (duplicate case) ]) - config = InsightConfig(retrieval_top_k=1, max_chunks_per_document=1) + config = InsightConfig( + retrieval_top_k=1, max_chunks_per_document=1, low_survival_top_k=1, + ) pipeline = InsightPipeline(llm_client=client, config=config) result = pipeline.run(QUESTION_SUDAN, [DOC_WHO_SUDAN, doc2]) @@ -580,6 +589,7 @@ def test_pipeline_parallel_chunk_extraction_produces_all_records(): config = InsightConfig( retrieval_top_k=4, max_chunks_per_document=4, + low_survival_top_k=4, chunk_workers=4, ) pipeline = InsightPipeline(llm_client=fake, config=config) @@ -603,10 +613,12 @@ def test_pipeline_sequential_and_parallel_produce_same_record_count(): of records when the fake LLM is content-keyed (so result depends on chunk content, not worker order).""" config_seq = InsightConfig( - retrieval_top_k=4, max_chunks_per_document=4, chunk_workers=1, + retrieval_top_k=4, max_chunks_per_document=4, low_survival_top_k=4, + chunk_workers=1, ) config_par = InsightConfig( - retrieval_top_k=4, max_chunks_per_document=4, chunk_workers=4, + retrieval_top_k=4, max_chunks_per_document=4, low_survival_top_k=4, + chunk_workers=4, ) seq_pipeline = InsightPipeline( @@ -653,7 +665,8 @@ def embed(self, texts, *, model): fake = _IntermittentFake() config = InsightConfig( - retrieval_top_k=4, max_chunks_per_document=4, 
chunk_workers=4, + retrieval_top_k=4, max_chunks_per_document=4, low_survival_top_k=4, + chunk_workers=4, ) pipeline = InsightPipeline(llm_client=fake, config=config) # Must not raise — failed chunk is logged and skipped @@ -684,7 +697,9 @@ def test_pipeline_multi_document(): H5N1_TABLE_RESPONSE, ]) - config = InsightConfig(retrieval_top_k=2, max_chunks_per_document=2) + config = InsightConfig( + retrieval_top_k=2, max_chunks_per_document=2, low_survival_top_k=2, + ) pipeline = InsightPipeline(llm_client=client, config=config) result = pipeline.run(QUESTION_H5N1, [DOC_WHO_SUDAN, DOC_CDC_H5N1]) @@ -709,7 +724,9 @@ def test_pipeline_output_records_valid(): SUDAN_TABLE_RESPONSE, ]) - config = InsightConfig(retrieval_top_k=2, max_chunks_per_document=2) + config = InsightConfig( + retrieval_top_k=2, max_chunks_per_document=2, low_survival_top_k=2, + ) pipeline = InsightPipeline(llm_client=client, config=config) result = pipeline.run(QUESTION_SUDAN, [DOC_WHO_SUDAN]) diff --git a/bioscancast/tests/test_pipeline.py b/bioscancast/tests/test_pipeline.py index f65bceb..f914225 100644 --- a/bioscancast/tests/test_pipeline.py +++ b/bioscancast/tests/test_pipeline.py @@ -1,5 +1,6 @@ from datetime import datetime +from bioscancast.filtering.config import FILTER_CONFIG from bioscancast.filtering.models import ForecastQuestion, SearchResult from bioscancast.filtering.pipeline import FilteringPipeline @@ -36,4 +37,59 @@ def test_pipeline_keeps_official_result(): docs = pipeline.run(question, [result]) assert len(docs) == 1 - assert docs[0].domain == "who.int" \ No newline at end of file + assert docs[0].domain == "who.int" + + +def _borderline_question(): + return ForecastQuestion( + id="q1", + text="How many confirmed Ebola cases in the DRC outbreak?", + created_at=datetime(2026, 5, 1), + pathogen="ebola", + region="DRC", + ) + + +def _borderline_result(): + # trusted_media (domain_score 0.6, non-official) with partial term overlap → + # lands in the heuristic borderline band, then 
the no-LLM "llm_needed" band. + return SearchResult( + id="r-border", + question_id="q1", + query_id="sq1", + engine="google", + url="https://www.cnn.com/ebola-drc", + canonical_url="https://www.cnn.com/ebola-drc", + domain="cnn.com", + title="Ebola cases climb in the DRC outbreak", + snippet="Confirmed Ebola cases reported in the Democratic Republic of the Congo outbreak.", + rank=2, + retrieved_at=datetime(2026, 5, 1), + source_tier="trusted_media", + is_official_domain=False, + domain_score=0.6, + freshness_score=1.0, + search_stage_score=0.6, + ) + + +def test_no_llm_soft_fallback_flag_changes_borderline_outcome(): + question = _borderline_question() + result = _borderline_result() + + saved = dict(FILTER_CONFIG) + try: + # Flag OFF (default): fail closed → borderline candidate dropped. + FILTER_CONFIG["no_llm_soft_fallback"] = False + docs_off = FilteringPipeline(llm_client=None).run(question, [result]) + + # Flag ON: relevant borderline candidate kept without an LLM call. + FILTER_CONFIG["no_llm_soft_fallback"] = True + FILTER_CONFIG["no_llm_fallback_relevance_threshold"] = 0.0 + docs_on = FilteringPipeline(llm_client=None).run(question, [result]) + finally: + FILTER_CONFIG.clear() + FILTER_CONFIG.update(saved) + + assert {d.result_id for d in docs_off} == set() + assert {d.result_id for d in docs_on} == {"r-border"} \ No newline at end of file diff --git a/bioscancast/tests/test_search_pipeline.py b/bioscancast/tests/test_search_pipeline.py index 907d6a4..d140fab 100644 --- a/bioscancast/tests/test_search_pipeline.py +++ b/bioscancast/tests/test_search_pipeline.py @@ -160,9 +160,14 @@ def test_total_cap_enforced(self): def test_scoring_formula(self): """Verify the search_stage_score formula for a known result.""" + from bioscancast.stages.search_stage.pipeline import _compute_relevance + + question = _make_question() results = self._run_pipeline() for r in results: - expected = 0.5 * r.domain_score + 0.3 * r.freshness_score + 0.2 * (1.0 / max(r.rank, 1)) + 
rel = _compute_relevance(r, question) + rank_score = 1.0 / max(r.rank, 1) + expected = 0.45 * rel + 0.30 * r.domain_score + 0.10 * r.freshness_score + 0.15 * rank_score expected = max(0.0, min(1.0, expected)) assert abs(r.search_stage_score - expected) < 1e-9, ( f"Score mismatch for {r.url}: {r.search_stage_score} != {expected}" diff --git a/bioscancast/tests/test_tier_resolution.py b/bioscancast/tests/test_tier_resolution.py index d8aa72f..d291303 100644 --- a/bioscancast/tests/test_tier_resolution.py +++ b/bioscancast/tests/test_tier_resolution.py @@ -54,6 +54,18 @@ def test_unknown_domain(self): assert score == 0.2 assert label == "unknown" + def test_national_news_is_trusted_media(self): + for domain in ("cnn.com", "nbcnews.com", "forbes.com", "latimes.com", "npr.org"): + tier, score, label = resolve_tier(domain) + assert tier == 3, domain + assert score == 0.6, domain + assert label == "trusted_media", domain + + def test_national_news_subdomain_match(self): + # edition.cnn.com / africa.businessinsider.com resolve via SLD. + assert resolve_tier("edition.cnn.com")[2] == "trusted_media" + assert resolve_tier("africa.businessinsider.com")[2] == "trusted_media" + def test_subdomain_match(self): """wwwnc.cdc.gov should match cdc.gov via second-level domain.""" tier, score, label = resolve_tier("wwwnc.cdc.gov") diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py new file mode 100644 index 0000000..d76fecd --- /dev/null +++ b/scripts/run_pipeline.py @@ -0,0 +1,21 @@ +"""Thin wrapper around ``python -m bioscancast.main`` so the orchestrator +matches the existing per-stage runner convention (scripts/run_*.py). + +Usage: + python scripts/run_pipeline.py q7 --as-of-date 2025-02-28 -v +""" + +from __future__ import annotations + +import os +import sys + +# Add project root to path so `bioscancast` imports work when run from +# anywhere. 
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from bioscancast.main import main # noqa: E402 + + +if __name__ == "__main__": + sys.exit(main())