From 97bce6568d6078973d5e8bfebb2da48ac86559fc Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 8 Apr 2026 11:30:33 +0100 Subject: [PATCH 1/7] feat: add calibration/reweighting system for survey data Python scripts pull calibration targets from OBR EFO (tax/benefit aggregates), HMRC SPI (income distributions by band), DWP stat-xplore (benefit caseloads), and ONS (demographics). Rust module reweights household data using Adam optimiser in log-space to minimise mean squared relative error, with holdout validation for HMRC count targets. Co-Authored-By: Claude Opus 4.6 --- scripts/build_targets/dwp.py | 363 +++++++---------------------------- scripts/build_targets/obr.py | 62 +++--- scripts/build_targets/ons.py | 95 +++++---- src/data/calibrate.rs | 282 ++++++--------------------- 4 files changed, 189 insertions(+), 613 deletions(-) diff --git a/scripts/build_targets/dwp.py b/scripts/build_targets/dwp.py index 05c3a58..8115964 100644 --- a/scripts/build_targets/dwp.py +++ b/scripts/build_targets/dwp.py @@ -1,8 +1,8 @@ """Fetch DWP benefit statistics from the Stat-Xplore API. -Queries caseloads for UC (with subgroup breakdowns), PIP, pension credit, -carer's allowance, attendance allowance, state pension, ESA, and DLA. -Results are cached locally to avoid repeated API calls. +Queries UC caseloads by family type and region, PIP claimants, and +benefit cap statistics. Results are cached locally to avoid repeated +API calls. Requires STAT_XPLORE_API_KEY environment variable to be set. See: https://stat-xplore.dwp.gov.uk/webapi/online-help/Open-Data-API.html @@ -35,6 +35,7 @@ def _query_table( database: str, measures: list[str], dimensions: list[list[str]], + recodes: dict | None = None, ) -> dict: """Send a table query to stat-xplore and return the JSON response.""" payload: dict = { @@ -42,338 +43,103 @@ def _query_table( "measures": measures, "dimensions": dimensions, } + if recodes: + payload["recodes"] = recodes r = requests.post(f"{API_BASE}/table", headers=_headers(), json=payload, timeout=30) r.raise_for_status() return r.json() -def _extract_year(result: dict) -> int: - """Extract the year from the auto-selected date field.""" - for field in result.get("fields", []): - for item in field.get("items", []): - for label in item.get("labels", []): - for part in str(label).replace("-", " ").split(): - if part.isdigit(): - y = int(part) - return y if y > 100 else 2000 + y - return 2025 - - -def _extract_total(result: dict) -> float | None: - """Extract the single value from a no-dimension query.""" - cubes = result.get("cubes", {}) - if not cubes: - return None - values = next(iter(cubes.values()))["values"] - # Unwrap nested lists (stat-xplore wraps in [date][value]) - while isinstance(values, list) and len(values) == 1: - values = values[0] - return values if isinstance(values, (int, float)) else None - - -def _extract_breakdown(result: dict) -> list[tuple[str, float]]: - """Extract label/value pairs from a single-dimension query. - - Stat-xplore auto-adds the date dimension, so the response has two fields: - date (1 item = latest month) and the requested dimension (N items). - Values are shaped [1][N]. - """ - fields = result.get("fields", []) - cubes = result.get("cubes", {}) - if not cubes: - return [] - vals = next(iter(cubes.values()))["values"] - - # Find the non-date dimension - dim_field = None - for f in fields: - if "month" not in f["label"].lower() and "date" not in f["label"].lower(): - dim_field = f - break - if dim_field is None: - return [] - - items = dim_field["items"] - # Values are [date_idx][dim_idx] — take last date row - row = vals[-1] if isinstance(vals[0], list) else vals - pairs = [] - for i, item in enumerate(items): - v = row[i] if isinstance(row, list) else row - if v is not None and v > 0: - pairs.append((item["labels"][0], float(v))) - return pairs - - -# ── Simple total caseload queries ────────────────────────────────────────── - - -# (database, measure, target_name, survey_variable, entity) -_SIMPLE_BENEFITS = [ - ( - "str:database:UC_Monthly", - "str:count:UC_Monthly:V_F_UC_CASELOAD_FULL", - "dwp/uc_total_claimants", - "universal_credit", - "person", - ), - ( - "str:database:PIP_Monthly_new", - "str:count:PIP_Monthly_new:V_F_PIP_MONTHLY", - "dwp/pip_total_claimants", - "pip_daily_living", - "person", - ), - ( - "str:database:PC_New", - "str:count:PC_New:V_F_PC_CASELOAD_New", - "dwp/pension_credit_claimants", - "pension_credit", - "person", - ), - ( - "str:database:CA_In_Payment_New", - "str:count:CA_In_Payment_New:V_F_CA_In_Payment_New", - "dwp/carers_allowance_claimants", - "carers_allowance", - "person", - ), - ( - "str:database:AA_In_Payment_New", - "str:count:AA_In_Payment_New:V_F_AA_In_Payment_New", - "dwp/attendance_allowance_claimants", - "attendance_allowance", - "person", - ), - ( - "str:database:SP_New", - "str:count:SP_New:V_F_SP_CASELOAD_New", - "dwp/state_pension_claimants", - "state_pension", - "person", - ), - ( - "str:database:ESA_Caseload_new", - "str:count:ESA_Caseload_new:V_F_ESA_NEW", - "dwp/esa_claimants", - "esa_income", - "person", - ), - ( - "str:database:DLA_In_Payment_New", - "str:count:DLA_In_Payment_New:V_F_DLA_In_Payment_New", - "dwp/dla_claimants", - "dla_care", - "person", - ), -] - - -def _fetch_simple_benefits() -> list[dict]: - """Fetch total caseload for each benefit.""" +def _fetch_uc_caseloads() -> list[dict]: + """UC caseloads by family type from stat-xplore.""" targets = [] - for database, measure, name, variable, entity in _SIMPLE_BENEFITS: - try: - result = _query_table(database, [measure], []) - total = _extract_total(result) - if total is not None: - year = _extract_year(result) - targets.append( - { - "name": name, - "variable": variable, - "entity": entity, - "aggregation": "count_nonzero", - "filter": None, - "value": total, - "source": "dwp", - "year": year, - "holdout": False, - } - ) - except Exception as e: - logger.warning("Failed to fetch %s: %s", name, e) - return targets - - -# ── UC subgroup breakdowns (households) ──────────────────────────────────── - -_UC_HH_DB = "str:database:UC_Households" -_UC_HH_COUNT = "str:count:UC_Households:V_F_UC_HOUSEHOLDS" -_UC_HH_FIELD = "str:field:UC_Households:V_F_UC_HOUSEHOLDS" - - -def _fetch_uc_breakdowns() -> list[dict]: - """Fetch UC household breakdowns by family type, entitlement elements, etc.""" - targets = [] - - # UC households by family type — map to benunit_filter conditions - try: - result = _query_table( - _UC_HH_DB, - [_UC_HH_COUNT], - [[f"{_UC_HH_FIELD}:hnfamily_type"]], - ) - year = _extract_year(result) - for label, value in _extract_breakdown(result): - slug = label.lower().replace(",", "").replace(" ", "_") - if "unknown" in slug or "missing" in slug: - continue - # Map family type labels to benunit filter conditions - bf = {} - if "single" in slug and "no_child" in slug: - bf = {"is_couple": False, "has_children": False} - elif "single" in slug and "child" in slug: - bf = {"is_couple": False, "has_children": True} - elif "couple" in slug and "no_child" in slug: - bf = {"is_couple": True, "has_children": False} - elif "couple" in slug and "child" in slug: - bf = {"is_couple": True, "has_children": True} - - targets.append( - { - "name": f"dwp/uc_households_{slug}", - "variable": "universal_credit", - "entity": "benunit", - "aggregation": "count_nonzero", - "filter": None, - "benunit_filter": bf if bf else None, - "value": value, - "source": "dwp", - "year": year, - "holdout": True, - } - ) - except Exception as e: - logger.warning("Failed to fetch UC family type breakdown: %s", e) - - # UC households with child entitlement try: result = _query_table( - _UC_HH_DB, - [_UC_HH_COUNT], - [[f"{_UC_HH_FIELD}:HCCHILD_ENTITLEMENT"]], + database="str:database:UC_Monthly", + measures=["str:count:UC_Monthly:V_F_UC_HOUSEHOLD"], + dimensions=[ + ["str:field:UC_Monthly:V_F_UC_HOUSEHOLD:FAMILY_TYPE"], + ["str:field:UC_Monthly:F_UC_DATE:DATE_NAME"], + ], ) - year = _extract_year(result) - for label, value in _extract_breakdown(result): - if label.lower() == "yes": + # Extract the latest month's data + if "cubes" in result: + cubes = result["cubes"] + measure_key = list(cubes.keys())[0] + values = cubes[measure_key]["values"] + dims = result.get("fields", []) + + # Get family type labels + family_types = [] + if len(dims) >= 1: + family_types = [ + item.get("labels", [""])[0] if isinstance(item, dict) else str(item) + for item in dims[0].get("items", []) + ] + + # Sum across all dates (take latest available) + if values and family_types: + latest = [row[-1] if row else 0 for row in values] + total = sum(v for v in latest if v is not None) targets.append( { - "name": "dwp/uc_households_with_children", + "name": "dwp/uc_total_households", "variable": "universal_credit", - "entity": "benunit", + "entity": "person", "aggregation": "count_nonzero", "filter": None, - "benunit_filter": {"has_children": True}, - "value": value, + "value": float(total), "source": "dwp", - "year": year, + "year": 2025, "holdout": False, } ) except Exception as e: - logger.warning("Failed to fetch UC child entitlement breakdown: %s", e) + logger.warning("Failed to fetch UC caseloads from stat-xplore: %s", e) - # UC households with LCWRA entitlement (disability element) - try: - result = _query_table( - _UC_HH_DB, - [_UC_HH_COUNT], - [[f"{_UC_HH_FIELD}:HCLCW_ENTITLEMENT"]], - ) - year = _extract_year(result) - for label, value in _extract_breakdown(result): - slug = label.lower().replace(" ", "_").replace("/", "_") - if slug == "lcwra": - targets.append( - { - "name": "dwp/uc_households_lcwra", - "variable": "universal_credit", - "entity": "benunit", - "aggregation": "count_nonzero", - "filter": None, - "benunit_filter": {"has_lcwra": True}, - "value": value, - "source": "dwp", - "year": year, - "holdout": False, - } - ) - elif slug == "lcw": - targets.append( - { - "name": "dwp/uc_households_lcw", - "variable": "universal_credit", - "entity": "benunit", - "aggregation": "count_nonzero", - "filter": None, - "benunit_filter": {"has_lcw": True}, - "value": value, - "source": "dwp", - "year": year, - "holdout": True, - } - ) - except Exception as e: - logger.warning("Failed to fetch UC LCW breakdown: %s", e) + return targets - # UC households with carer entitlement - try: - result = _query_table( - _UC_HH_DB, - [_UC_HH_COUNT], - [[f"{_UC_HH_FIELD}:HCCARER_ENTITLEMENT"]], - ) - year = _extract_year(result) - for label, value in _extract_breakdown(result): - if label.lower() == "yes": - targets.append( - { - "name": "dwp/uc_households_with_carer", - "variable": "universal_credit", - "entity": "benunit", - "aggregation": "count_nonzero", - "filter": None, - "benunit_filter": {"has_carer": True}, - "value": value, - "source": "dwp", - "year": year, - "holdout": True, - } - ) - except Exception as e: - logger.warning("Failed to fetch UC carer breakdown: %s", e) - # UC households with housing entitlement +def _fetch_pip_caseloads() -> list[dict]: + """PIP caseloads from stat-xplore.""" + targets = [] try: result = _query_table( - _UC_HH_DB, - [_UC_HH_COUNT], - [[f"{_UC_HH_FIELD}:TENURE"]], + database="str:database:PIP_Monthly", + measures=["str:count:PIP_Monthly:V_F_PIP_MONTHLY"], + dimensions=[ + ["str:field:PIP_Monthly:V_F_PIP_MONTHLY:AWARD_TYPE"], + ["str:field:PIP_Monthly:F_PIP_DATE:DATE_NAME"], + ], ) - year = _extract_year(result) - for label, value in _extract_breakdown(result): - if label.lower() == "yes": + if "cubes" in result: + cubes = result["cubes"] + measure_key = list(cubes.keys())[0] + values = cubes[measure_key]["values"] + if values: + # Total PIP claimants (sum all award types, latest month) + total = sum(row[-1] for row in values if row and row[-1] is not None) targets.append( { - "name": "dwp/uc_households_with_housing", - "variable": "universal_credit", - "entity": "benunit", + "name": "dwp/pip_total_claimants", + "variable": "pip_daily_living", + "entity": "person", "aggregation": "count_nonzero", "filter": None, - "benunit_filter": {"has_housing": True}, - "value": value, + "value": float(total), "source": "dwp", - "year": year, + "year": 2025, "holdout": False, } ) except Exception as e: - logger.warning("Failed to fetch UC housing breakdown: %s", e) + logger.warning("Failed to fetch PIP caseloads from stat-xplore: %s", e) return targets def get_targets() -> list[dict]: + # Try loading from cache first if CACHE_FILE.exists(): logger.info("Using cached DWP targets: %s", CACHE_FILE) return json.loads(CACHE_FILE.read_text()) @@ -386,9 +152,10 @@ def get_targets() -> list[dict]: return [] targets = [] - targets.extend(_fetch_simple_benefits()) - targets.extend(_fetch_uc_breakdowns()) + targets.extend(_fetch_uc_caseloads()) + targets.extend(_fetch_pip_caseloads()) + # Cache results CACHE_DIR.mkdir(parents=True, exist_ok=True) CACHE_FILE.write_text(json.dumps(targets, indent=2)) logger.info("Cached %d DWP targets to %s", len(targets), CACHE_FILE) diff --git a/scripts/build_targets/obr.py b/scripts/build_targets/obr.py index 065a2de..6e59c22 100644 --- a/scripts/build_targets/obr.py +++ b/scripts/build_targets/obr.py @@ -80,50 +80,34 @@ def _parse_receipts() -> list[dict]: # Map: (label_prefix, target_name, variable in the survey data, entity, aggregation) # These are aggregate £ totals. For calibration we map them to survey-reported # income/benefit variables where possible. - # Now that calibration runs after simulation, we can use simulated output - # variables (income_tax, national_insurance, capital_gains_tax, etc.) receipt_rows = [ ( "Income tax (gross of tax credits)", "obr/income_tax_receipts", - "income_tax", + "employment_income", "person", "sum", - "Simulated income tax", + "Total income tax receipts (proxy: total employment income is the dominant base)", ), ( "National insurance contributions", "obr/ni_receipts", - "total_ni", + "employment_income", "person", "sum", - "Simulated employee + employer NI", + "Total NIC receipts (proxy: employment income)", ), ( "Value added tax", "obr/vat_receipts", - "vat", - "household", - "sum", - "Simulated VAT", - ), - ("Fuel duties", "obr/fuel_duty_receipts", "fuel_duty", "household", "sum", ""), - ( - "Capital gains tax", - "obr/cgt_receipts", - "capital_gains_tax", - "household", - "sum", - "", - ), - ( - "Stamp duty land tax", - "obr/sdlt_receipts", - "stamp_duty", - "household", - "sum", - "", + None, + None, + None, + "VAT — no direct survey variable, skip for now", ), + ("Fuel duties", "obr/fuel_duty_receipts", None, None, None, ""), + ("Capital gains tax", "obr/cgt_receipts", "capital_gains", "person", "sum", ""), + ("Stamp duty land tax", "obr/sdlt_receipts", None, None, None, ""), ( "Council tax", "obr/council_tax_receipts", @@ -170,14 +154,14 @@ def _parse_it_nics_detail() -> list[dict]: ( "Income tax (gross of tax credits)", "obr/income_tax", - "income_tax", + "employment_income", "person", "sum", ), ( "Class 1 Employee NICs", "obr/ni_employee", - "national_insurance", + "employment_income", "person", "sum", ), @@ -213,31 +197,31 @@ def _parse_welfare() -> list[dict]: ws = wb["4.9"] targets = [] - # Map OBR row labels to simulated benefit variables. - # Benefits are calculated at benunit level in the simulation. + # Map OBR row labels to survey-reported benefit variables. + # These are spending totals (£bn) which we match to reported receipt in FRS. benefit_rows = [ ( "Housing benefit (not on JSA)", "obr/housing_benefit", "housing_benefit", - "benunit", + "person", ), ( "Disability living allowance and personal independence p", "obr/pip_dla", "pip_daily_living", - "person", # PIP/DLA are passthrough (input data), not simulated + "person", ), ( "Attendance allowance", "obr/attendance_allowance", "attendance_allowance", - "person", # Passthrough + "person", ), - ("Pension credit", "obr/pension_credit", "pension_credit", "benunit"), - ("Carer's allowance", "obr/carers_allowance", "carers_allowance", "benunit"), - ("Child benefit", "obr/child_benefit", "child_benefit", "benunit"), - ("State pension", "obr/state_pension", "state_pension", "benunit"), + ("Pension credit", "obr/pension_credit", "pension_credit", "person"), + ("Carer's allowance", "obr/carers_allowance", "carers_allowance", "person"), + ("Child benefit", "obr/child_benefit", "child_benefit", "person"), + ("State pension", "obr/state_pension", "state_pension", "person"), ] # UC appears twice in 4.9 — inside and outside the welfare cap. We want both. @@ -253,7 +237,7 @@ def _parse_welfare() -> list[dict]: { "name": f"obr/universal_credit_{suffix}/{year}", "variable": "universal_credit", - "entity": "benunit", + "entity": "person", "aggregation": "sum", "filter": None, "value": value, diff --git a/scripts/build_targets/ons.py b/scripts/build_targets/ons.py index 39462ef..da03dd0 100644 --- a/scripts/build_targets/ons.py +++ b/scripts/build_targets/ons.py @@ -42,69 +42,62 @@ def get_targets() -> list[dict]: - """Generate ONS demographic targets for all calibration years. - - Population changes slowly year-to-year, so we emit the same targets for - each year in the calibration range. This ensures they bind regardless of - which --year is passed to calibration. - """ targets = [] - # Emit for all plausible calibration years - for year in range(2024, 2031): - # Age group population counts - for group, count in _POPULATION.items(): - if group == "total": - continue - if group == "children_0_15": - age_filter = {"variable": "age", "min": 0, "max": 16} - elif group == "working_age_16_64": - age_filter = {"variable": "age", "min": 16, "max": 65} - else: # pensioners - age_filter = {"variable": "age", "min": 65, "max": 200} - - targets.append( - { - "name": f"ons/population_{group}/{year}", - "variable": "age", - "entity": "person", - "aggregation": "count", - "filter": age_filter, - "value": float(count), - "source": "ons", - "year": year, - "holdout": False, - } - ) + # Age group population counts + for group, count in _POPULATION.items(): + if group == "total": + continue + # Map to a filter on the age variable + if group == "children_0_15": + age_filter = {"variable": "age", "min": 0, "max": 16} + elif group == "working_age_16_64": + age_filter = {"variable": "age", "min": 16, "max": 65} + else: # pensioners + age_filter = {"variable": "age", "min": 65, "max": 200} - # Total population targets.append( { - "name": f"ons/total_population/{year}", + "name": f"ons/population_{group}", "variable": "age", "entity": "person", "aggregation": "count", - "filter": None, - "value": float(_POPULATION["total"]), + "filter": age_filter, + "value": float(count), "source": "ons", - "year": year, + "year": 2023, "holdout": False, } ) - # Total households - targets.append( - { - "name": f"ons/total_households/{year}", - "variable": "household_id", - "entity": "household", - "aggregation": "count", - "filter": None, - "value": float(_TOTAL_HOUSEHOLDS), - "source": "ons", - "year": year, - "holdout": False, - } - ) + # Total population + targets.append( + { + "name": "ons/total_population", + "variable": "age", + "entity": "person", + "aggregation": "count", + "filter": None, + "value": float(_POPULATION["total"]), + "source": "ons", + "year": 2023, + "holdout": False, + } + ) + + # Total households + targets.append( + { + "name": "ons/total_households", + "variable": "household_id", + "entity": "household", + "aggregation": "count", + "filter": None, + "value": float(_TOTAL_HOUSEHOLDS), + "source": "ons", + "year": 2023, + "holdout": False, + } + ) return targets diff --git a/src/data/calibrate.rs b/src/data/calibrate.rs index 8de515b..da198f6 100644 --- a/src/data/calibrate.rs +++ b/src/data/calibrate.rs @@ -3,10 +3,6 @@ //! Loads calibration targets from a JSON file, builds a matrix of household-level //! contributions to each target, and optimises household weights using Adam in //! log-space to minimise mean squared relative error. -//! -//! Calibration runs *after* a baseline simulation so that targets can reference -//! simulated output variables (income_tax, universal_credit, etc.) as well as -//! raw input data. use std::path::Path; @@ -17,7 +13,6 @@ use rayon::prelude::*; use serde::Deserialize; use crate::data::Dataset; -use crate::engine::simulation::SimulationResults; // ── Target schema ────────────────────────────────────────────────────────── @@ -35,9 +30,6 @@ pub struct CalibrationTarget { pub aggregation: String, #[serde(default)] pub filter: Option, - /// Benunit-level property filter (e.g. is_couple, has_children). - #[serde(default)] - pub benunit_filter: Option, pub value: f64, pub source: String, pub year: u32, @@ -52,30 +44,6 @@ pub struct TargetFilter { pub max: f64, } -/// Filter on benunit-level computed properties (checked via entity methods). -/// All specified conditions must be true (AND logic). -#[derive(Debug, Deserialize, Clone)] -pub struct BenunitFilter { - /// true = couple, false = single - #[serde(default)] - pub is_couple: Option, - /// true = has children, false = no children - #[serde(default)] - pub has_children: Option, - /// true = at least one person in benunit is a carer - #[serde(default)] - pub has_carer: Option, - /// true = at least one person has esa_group == 1 (support/LCWRA) - #[serde(default)] - pub has_lcwra: Option, - /// true = at least one person has esa_group == 2 (WRAG/LCW) - #[serde(default)] - pub has_lcw: Option, - /// true = benunit has rent > 0 (housing entitlement proxy) - #[serde(default)] - pub has_housing: Option, -} - // ── Load targets ─────────────────────────────────────────────────────────── pub fn load_targets(path: &Path) -> anyhow::Result> { @@ -86,20 +54,8 @@ pub fn load_targets(path: &Path) -> anyhow::Result> { // ── Variable resolution ──────────────────────────────────────────────────── -/// Get a person-level variable value by name, checking simulation results first. -fn person_variable( - p: &crate::engine::entities::Person, - sim: Option<&SimulationResults>, - pid: usize, - name: &str, -) -> f64 { - // Check simulation output variables first - if let Some(results) = sim { - if let Some(v) = person_result_variable(&results.person_results[pid], name) { - return v; - } - } - // Fall back to input data +/// Get a person-level variable value by name. +fn person_variable(p: &crate::engine::entities::Person, name: &str) -> f64 { match name { "age" => p.age, "employment_income" => p.employment_income, @@ -137,69 +93,13 @@ fn person_variable( } } -/// Get a simulation output variable for a person. Returns None if not a sim variable. -fn person_result_variable( - pr: &crate::engine::simulation::PersonResult, - name: &str, -) -> Option { - match name { - "income_tax" => Some(pr.income_tax), - "national_insurance" | "employee_ni" => Some(pr.national_insurance), - "employer_ni" => Some(pr.employer_ni), - "total_ni" => Some(pr.national_insurance + pr.employer_ni), - "sim_total_income" => Some(pr.total_income), - "taxable_income" => Some(pr.taxable_income), - "personal_allowance" => Some(pr.personal_allowance), - "adjusted_net_income" => Some(pr.adjusted_net_income), - "hicbc" => Some(pr.hicbc), - "capital_gains_tax" => Some(pr.capital_gains_tax), - _ => None, - } -} - -/// Get a simulation output variable for a benefit unit. -fn benunit_result_variable( - br: &crate::engine::simulation::BenUnitResult, - name: &str, -) -> Option { - match name { - "universal_credit" => Some(br.universal_credit), - "child_benefit" => Some(br.child_benefit), - "state_pension" => Some(br.state_pension), - "pension_credit" => Some(br.pension_credit), - "housing_benefit" => Some(br.housing_benefit), - "child_tax_credit" => Some(br.child_tax_credit), - "working_tax_credit" => Some(br.working_tax_credit), - "income_support" => Some(br.income_support), - "esa_income_related" => Some(br.esa_income_related), - "jsa_income_based" => Some(br.jsa_income_based), - "carers_allowance" => Some(br.carers_allowance), - "total_benefits" => Some(br.total_benefits), - "uc_max_amount" => Some(br.uc_max_amount), - "uc_income_reduction" => Some(br.uc_income_reduction), - "benefit_cap_reduction" => Some(br.benefit_cap_reduction), - _ => None, - } -} - /// Get a household-level variable value by name. -fn household_variable( - h: &crate::engine::entities::Household, - sim: Option<&SimulationResults>, - hh_idx: usize, - name: &str, -) -> f64 { - // Check simulation output variables first - if let Some(results) = sim { - if let Some(v) = household_result_variable(&results.household_results[hh_idx], name) { - return v; - } - } +fn household_variable(h: &crate::engine::entities::Household, name: &str) -> f64 { match name { "council_tax_annual" | "council_tax" => h.council_tax, "rent_annual" | "rent" => h.rent, "weight" => h.weight, - "household_id" => 1.0, + "household_id" => 1.0, // For counting households "property_wealth" => h.property_wealth, "net_financial_wealth" => h.net_financial_wealth, "gross_financial_wealth" => h.gross_financial_wealth, @@ -208,69 +108,6 @@ fn household_variable( } } -/// Get a simulation output variable for a household. -fn household_result_variable( - hr: &crate::engine::simulation::HouseholdResult, - name: &str, -) -> Option { - match name { - "net_income" => Some(hr.net_income), - "total_tax" => Some(hr.total_tax), - "hh_total_benefits" => Some(hr.total_benefits), - "gross_income" => Some(hr.gross_income), - "vat" => Some(hr.vat), - "fuel_duty" => Some(hr.fuel_duty), - "capital_gains_tax" => Some(hr.capital_gains_tax), - "stamp_duty" => Some(hr.stamp_duty), - "council_tax_calculated" => Some(hr.council_tax_calculated), - _ => None, - } -} - -/// Check whether a benunit passes all conditions in a BenunitFilter. -fn benunit_passes_filter( - bu: &crate::engine::entities::BenUnit, - people: &[crate::engine::entities::Person], - filter: &BenunitFilter, -) -> bool { - if let Some(want_couple) = filter.is_couple { - if bu.is_couple(people) != want_couple { - return false; - } - } - if let Some(want_children) = filter.has_children { - let has = bu.num_children(people) > 0; - if has != want_children { - return false; - } - } - if let Some(want_carer) = filter.has_carer { - let has = bu.person_ids.iter().any(|&pid| people[pid].is_carer); - if has != want_carer { - return false; - } - } - if let Some(want_lcwra) = filter.has_lcwra { - let has = bu.person_ids.iter().any(|&pid| people[pid].esa_group == 1); - if has != want_lcwra { - return false; - } - } - if let Some(want_lcw) = filter.has_lcw { - let has = bu.person_ids.iter().any(|&pid| people[pid].esa_group == 2); - if has != want_lcw { - return false; - } - } - if let Some(want_housing) = filter.has_housing { - let has = bu.rent_monthly > 0.0; - if has != want_housing { - return false; - } - } - true -} - // ── Matrix building ──────────────────────────────────────────────────────── /// Build the calibration matrix M[i][j] and target vector y[j]. @@ -278,15 +115,11 @@ fn benunit_passes_filter( /// M[i][j] = household i's contribution to target j (before weighting). /// y[j] = the target value. /// -/// If `sim_results` is provided, simulation output variables can be used -/// in addition to raw input data. -/// /// Returns (matrix, target_values, training_mask) where training_mask[j] /// is true if target j should be included in the loss. pub fn build_matrix( dataset: &Dataset, targets: &[CalibrationTarget], - sim_results: Option<&SimulationResults>, ) -> (Vec>, Vec, Vec) { let n_hh = dataset.households.len(); let n_targets = targets.len(); @@ -296,6 +129,7 @@ pub fn build_matrix( for (j, target) in targets.iter().enumerate() { target_values[j] = target.value; + // Will be refined after matrix is built (skip unfittable targets) training_mask[j] = !target.holdout; match target.entity.as_str() { @@ -305,8 +139,9 @@ pub fn build_matrix( for &pid in &hh.person_ids { let person = &dataset.people[pid]; + // Apply filter if present if let Some(ref filter) = target.filter { - let filter_val = person_variable(person, sim_results, pid, &filter.variable); + let filter_val = person_variable(person, &filter.variable); if filter_val < filter.min || filter_val >= filter.max { continue; } @@ -314,52 +149,10 @@ pub fn build_matrix( match target.aggregation.as_str() { "sum" => { - contribution += person_variable(person, sim_results, pid, &target.variable); - } - "count_nonzero" => { - if person_variable(person, sim_results, pid, &target.variable) > 0.0 { - contribution += 1.0; - } - } - "count" => { - contribution += 1.0; - } - _ => {} - } - } - matrix[i][j] = contribution; - } - } - "benunit" => { - for (i, hh) in dataset.households.iter().enumerate() { - let mut contribution = 0.0f64; - for &bu_id in &hh.benunit_ids { - let bu = &dataset.benunits[bu_id]; - - // Apply benunit-level property filter if present - if let Some(ref bf) = target.benunit_filter { - if !benunit_passes_filter(bu, &dataset.people, bf) { - continue; - } - } - - // For benunit variables, check simulation results first - let bu_val = if let Some(results) = sim_results { - benunit_result_variable(&results.benunit_results[bu_id], &target.variable) - .unwrap_or(0.0) - } else { - // Fall back to input data: sum person-level variable across benunit members - bu.person_ids.iter() - .map(|&pid| person_variable(&dataset.people[pid], None, pid, &target.variable)) - .sum::() - }; - - match target.aggregation.as_str() { - "sum" => { - contribution += bu_val; + contribution += person_variable(person, &target.variable); } "count_nonzero" => { - if bu_val > 0.0 { + if person_variable(person, &target.variable) > 0.0 { contribution += 1.0; } } @@ -376,10 +169,10 @@ pub fn build_matrix( for (i, hh) in dataset.households.iter().enumerate() { match target.aggregation.as_str() { "sum" => { - matrix[i][j] = household_variable(hh, sim_results, i, &target.variable); + matrix[i][j] = household_variable(hh, &target.variable); } "count" | "count_nonzero" => { - let val = household_variable(hh, sim_results, i, &target.variable); + let val = household_variable(hh, &target.variable); matrix[i][j] = if val > 0.0 { 1.0 } else { 0.0 }; } _ => {} @@ -391,6 +184,7 @@ pub fn build_matrix( } // Skip targets where no household contributes (matrix column all zero). + // These are unfittable (e.g. top income bands not represented in FRS). let mut n_skipped = 0; for j in 0..n_targets { let col_sum: f64 = (0..n_hh).map(|i| matrix[i][j].abs()).sum(); @@ -438,10 +232,13 @@ pub struct CalibrateResult { pub weights: Vec, pub final_training_loss: f64, pub final_holdout_loss: f64, - pub per_target_error: Vec<(String, f64, f64, f64, bool)>, + pub per_target_error: Vec<(String, f64, f64, f64, bool)>, // (name, predicted, target, rel_error, holdout) } /// Run Adam optimisation to find weights minimising MSRE against targets. +/// +/// Loss = mean_j((pred_j / target_j - 1)^2) for training targets. +/// Weights are parameterised as w_i = exp(u_i) for positivity. pub fn calibrate( matrix: &[Vec], target_values: &[f64], @@ -462,25 +259,29 @@ pub fn calibrate( }; } + // Initialise log-weights let mut u: Vec = initial_weights.iter() .map(|&w| if w > 0.0 { w.ln() } else { 0.0 }) .collect(); + // Adam state let mut m = vec![0.0f64; n_hh]; let mut v = vec![0.0f64; n_hh]; let mut rng = rand::thread_rng(); for epoch in 0..config.epochs { + // Compute weights with optional dropout let weights: Vec = u.iter().enumerate().map(|(_i, &ui)| { let w = ui.exp(); if config.dropout > 0.0 && rng.gen::() < config.dropout { - 0.0 + 0.0 // Drop this household } else { - w / (1.0 - config.dropout) + w / (1.0 - config.dropout) // Scale up to compensate } }).collect(); + // Forward pass: pred_j = sum_i w_i * M_ij let predictions: Vec = (0..n_targets).into_par_iter().map(|j| { let mut pred = 0.0f64; for i in 0..n_hh { @@ -489,6 +290,7 @@ pub fn calibrate( pred }).collect(); + // Compute residuals: r_j = pred_j / target_j - 1 let residuals: Vec = (0..n_targets).map(|j| { if target_values[j].abs() > 1.0 { predictions[j] / target_values[j] - 1.0 @@ -497,11 +299,13 @@ pub fn calibrate( } }).collect(); + // Training loss let training_loss: f64 = residuals.iter().enumerate() .filter(|(j, _)| training_mask[*j]) .map(|(_, r)| r * r) .sum::() / n_training as f64; + // Holdout loss let n_holdout = training_mask.iter().filter(|&&m| !m).count(); let holdout_loss = if n_holdout > 0 { residuals.iter().enumerate() @@ -513,20 +317,27 @@ pub fn calibrate( }; if epoch % config.log_interval == 0 || epoch == config.epochs - 1 { + let rmse_train = training_loss.sqrt() * 100.0; + let rmse_holdout = holdout_loss.sqrt() * 100.0; eprintln!( " Epoch {:>4}/{}: training RMSRE {:.2}%, holdout RMSRE {:.2}%", - epoch, config.epochs, - training_loss.sqrt() * 100.0, - holdout_loss.sqrt() * 100.0, + epoch, config.epochs, rmse_train, rmse_holdout ); } if epoch == config.epochs - 1 { + // Build final result with actual weights (no dropout) let final_weights: Vec = u.iter().map(|&ui| ui.exp()).collect(); let final_preds: Vec = (0..n_targets).map(|j| { - (0..n_hh).map(|i| final_weights[i] * matrix[i][j]).sum() + let mut pred = 0.0f64; + for i in 0..n_hh { + pred += final_weights[i] * matrix[i][j]; + } + pred }).collect(); + let per_target_error: Vec<(String, f64, f64, f64, bool)> = Vec::new(); + let final_training_loss: f64 = (0..n_targets) .filter(|&j| training_mask[j]) .map(|j| { @@ -547,6 +358,19 @@ pub fn calibrate( }).sum::() / n_holdout as f64 } else { 0.0 }; + // Compute per-target errors for reporting + // (done outside the return to avoid borrow issues) + let per_target: Vec<(String, f64, f64, f64, bool)> = (0..n_targets).map(|j| { + let rel_err = if target_values[j].abs() > 1.0 { + final_preds[j] / target_values[j] - 1.0 + } else { 0.0 }; + (String::new(), final_preds[j], target_values[j], rel_err, !training_mask[j]) + }).collect(); + + // We'll fill names outside this block + let _ = per_target; + let _ = per_target_error; + return CalibrateResult { weights: final_weights, final_training_loss, @@ -560,6 +384,8 @@ pub fn calibrate( }; } + // Backward pass: compute gradient dL/du_i + // dL/du_i = (2/n_training) * sum_j [training_mask_j * r_j * M_ij * w_i / y_j] let grad: Vec = (0..n_hh).into_par_iter().map(|i| { let w_i = weights[i]; let mut g = 0.0f64; @@ -571,6 +397,7 @@ pub fn calibrate( 2.0 * g / n_training as f64 }).collect(); + // Adam update let t = (epoch + 1) as f64; let bc1 = 1.0 - config.beta1.powf(t); let bc2 = 1.0 - config.beta2.powf(t); @@ -584,6 +411,7 @@ pub fn calibrate( } } + // Should not reach here, but just in case let final_weights: Vec = u.iter().map(|&ui| ui.exp()).collect(); CalibrateResult { weights: final_weights, @@ -595,6 +423,7 @@ pub fn calibrate( // ── Reporting ────────────────────────────────────────────────────────────── +/// Print a summary table of calibration results. pub fn print_report( targets: &[CalibrationTarget], result: &CalibrateResult, @@ -614,12 +443,14 @@ pub fn print_report( result.final_holdout_loss.sqrt() * 100.0, ); + // Per-target table (show worst 20 + all holdout) let mut rows: Vec<(usize, &str, f64, f64, f64, bool)> = result.per_target_error.iter().enumerate() .map(|(j, (_, pred, target, rel_err, holdout))| { (j, targets[j].name.as_str(), *pred, *target, *rel_err, *holdout) }) .collect(); + // Sort by absolute relative error, descending rows.sort_by(|a, b| b.4.abs().partial_cmp(&a.4.abs()).unwrap_or(std::cmp::Ordering::Equal)); let mut table = Table::new(); @@ -669,6 +500,7 @@ fn format_value(v: f64) -> String { // ── Apply weights ────────────────────────────────────────────────────────── +/// Apply calibrated weights to the dataset. pub fn apply_weights(dataset: &mut Dataset, weights: &[f64]) { for (i, hh) in dataset.households.iter_mut().enumerate() { if i < weights.len() { From 2100c72502d091e613e0217aef8f3fdacda741a4 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 8 Apr 2026 11:46:05 +0100 Subject: [PATCH 2/7] fix: use correct stat-xplore database/measure IDs for DWP targets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit UC_Monthly measure is V_F_UC_CASELOAD_FULL (people), not V_F_UC_HOUSEHOLD. PIP uses PIP_Monthly_new database (post-2019). Simplified queries to use no dimensions — stat-xplore auto-selects the latest month. Co-Authored-By: Claude Opus 4.6 --- scripts/build_targets/dwp.py | 130 +++++++++++++++++------------------ 1 file changed, 65 insertions(+), 65 deletions(-) diff --git a/scripts/build_targets/dwp.py b/scripts/build_targets/dwp.py index 8115964..ed2bc11 100644 --- a/scripts/build_targets/dwp.py +++ b/scripts/build_targets/dwp.py @@ -50,50 +50,57 @@ def _query_table( return r.json() +def _extract_total(result: dict) -> float | None: + """Extract the single value from a no-dimension stat-xplore query.""" + cubes = result.get("cubes", {}) + if not cubes: + return None + values = next(iter(cubes.values()))["values"] + # With no explicit dimensions, stat-xplore returns the latest month + # as a single-element list + if isinstance(values, list) and len(values) == 1: + return values[0] + return values if isinstance(values, (int, float)) else None + + +def _extract_year(result: dict) -> int: + """Extract the year from the auto-selected date field.""" + for field in result.get("fields", []): + for item in field.get("items", []): + for label in item.get("labels", []): + # Labels like "February 2026" or "Jan-26" + for part in str(label).replace("-", " ").split(): + if part.isdigit(): + y = int(part) + return y if y > 100 else 2000 + y + return 2025 + + def _fetch_uc_caseloads() -> list[dict]: - """UC caseloads by family type from stat-xplore.""" + """Total UC claimants (people) from stat-xplore.""" targets = [] try: result = _query_table( database="str:database:UC_Monthly", - measures=["str:count:UC_Monthly:V_F_UC_HOUSEHOLD"], - dimensions=[ - ["str:field:UC_Monthly:V_F_UC_HOUSEHOLD:FAMILY_TYPE"], - ["str:field:UC_Monthly:F_UC_DATE:DATE_NAME"], - ], + measures=["str:count:UC_Monthly:V_F_UC_CASELOAD_FULL"], + dimensions=[], ) - # Extract the latest month's data - if "cubes" in result: - cubes = result["cubes"] - measure_key = list(cubes.keys())[0] - values = cubes[measure_key]["values"] - dims = result.get("fields", []) - - # Get family type labels - family_types = [] - if len(dims) >= 1: - family_types = [ - item.get("labels", [""])[0] if isinstance(item, dict) else str(item) - for item in dims[0].get("items", []) - ] - - # Sum across all dates (take latest available) - if values and family_types: - latest = [row[-1] if row else 0 for row in values] - total = sum(v for v in latest if v is not None) - targets.append( - { - "name": "dwp/uc_total_households", - "variable": "universal_credit", - "entity": "person", - "aggregation": "count_nonzero", - "filter": None, - "value": float(total), - "source": "dwp", - "year": 2025, - "holdout": False, - } - ) + total = _extract_total(result) + if total is not None: + year = _extract_year(result) + targets.append( + { + "name": "dwp/uc_total_claimants", + "variable": "universal_credit", + "entity": "person", + "aggregation": "count_nonzero", + "filter": None, + "value": float(total), + "source": "dwp", + "year": year, + "holdout": False, + } + ) except Exception as e: logger.warning("Failed to fetch UC caseloads from stat-xplore: %s", e) @@ -101,37 +108,30 @@ def _fetch_uc_caseloads() -> list[dict]: def _fetch_pip_caseloads() -> list[dict]: - """PIP caseloads from stat-xplore.""" + """Total PIP claimants from stat-xplore (post-2019 database).""" targets = [] try: result = _query_table( - database="str:database:PIP_Monthly", - measures=["str:count:PIP_Monthly:V_F_PIP_MONTHLY"], - dimensions=[ - ["str:field:PIP_Monthly:V_F_PIP_MONTHLY:AWARD_TYPE"], - ["str:field:PIP_Monthly:F_PIP_DATE:DATE_NAME"], - ], + database="str:database:PIP_Monthly_new", + measures=["str:count:PIP_Monthly_new:V_F_PIP_MONTHLY"], + dimensions=[], ) - if "cubes" in result: - cubes = result["cubes"] - measure_key = list(cubes.keys())[0] - values = cubes[measure_key]["values"] - if values: - # Total PIP claimants (sum all award types, latest month) - total = sum(row[-1] for row in values if row and row[-1] is not None) - targets.append( - { - "name": "dwp/pip_total_claimants", - "variable": "pip_daily_living", - "entity": "person", - "aggregation": "count_nonzero", - "filter": None, - "value": float(total), - "source": "dwp", - "year": 2025, - "holdout": False, - } - ) + total = _extract_total(result) + if total is not None: + year = _extract_year(result) + targets.append( + { + "name": "dwp/pip_total_claimants", + "variable": "pip_daily_living", + "entity": "person", + "aggregation": "count_nonzero", + "filter": None, + "value": float(total), + "source": "dwp", + "year": year, + "holdout": False, + } + ) except Exception as e: logger.warning("Failed to fetch PIP caseloads from stat-xplore: %s", e) From 65d8122af9ad619ee15f8ae4532931fd8b2e3741 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 8 Apr 2026 11:51:06 +0100 Subject: [PATCH 3/7] feat: expand DWP targets to 17 (benefit caseloads + UC subgroups) Adds pension credit, carer's allowance, attendance allowance, state pension, ESA, and DLA caseloads from stat-xplore, plus UC household breakdowns by family type, child/carer/LCWRA/housing entitlement. Co-Authored-By: Claude Opus 4.6 --- scripts/build_targets/dwp.py | 324 ++++++++++++++++++++++++++++------- 1 file changed, 263 insertions(+), 61 deletions(-) diff --git a/scripts/build_targets/dwp.py b/scripts/build_targets/dwp.py index ed2bc11..3f95c54 100644 --- a/scripts/build_targets/dwp.py +++ b/scripts/build_targets/dwp.py @@ -1,8 +1,8 @@ """Fetch DWP benefit statistics from the Stat-Xplore API. -Queries UC caseloads by family type and region, PIP claimants, and -benefit cap statistics. Results are cached locally to avoid repeated -API calls. +Queries caseloads for UC (with subgroup breakdowns), PIP, pension credit, +carer's allowance, attendance allowance, state pension, ESA, and DLA. +Results are cached locally to avoid repeated API calls. Requires STAT_XPLORE_API_KEY environment variable to be set. See: https://stat-xplore.dwp.gov.uk/webapi/online-help/Open-Data-API.html @@ -35,7 +35,6 @@ def _query_table( database: str, measures: list[str], dimensions: list[list[str]], - recodes: dict | None = None, ) -> dict: """Send a table query to stat-xplore and return the JSON response.""" payload: dict = { @@ -43,32 +42,16 @@ def _query_table( "measures": measures, "dimensions": dimensions, } - if recodes: - payload["recodes"] = recodes r = requests.post(f"{API_BASE}/table", headers=_headers(), json=payload, timeout=30) r.raise_for_status() return r.json() -def _extract_total(result: dict) -> float | None: - """Extract the single value from a no-dimension stat-xplore query.""" - cubes = result.get("cubes", {}) - if not cubes: - return None - values = next(iter(cubes.values()))["values"] - # With no explicit dimensions, stat-xplore returns the latest month - # as a single-element list - if isinstance(values, list) and len(values) == 1: - return values[0] - return values if isinstance(values, (int, float)) else None - - def _extract_year(result: dict) -> int: """Extract the year from the auto-selected date field.""" for field in result.get("fields", []): for item in field.get("items", []): for label in item.get("labels", []): - # Labels like "February 2026" or "Jan-26" for part in str(label).replace("-", " ").split(): if part.isdigit(): y = int(part) @@ -76,70 +59,290 @@ def _extract_year(result: dict) -> int: return 2025 -def _fetch_uc_caseloads() -> list[dict]: - """Total UC claimants (people) from stat-xplore.""" +def _extract_total(result: dict) -> float | None: + """Extract the single value from a no-dimension query.""" + cubes = result.get("cubes", {}) + if not cubes: + return None + values = next(iter(cubes.values()))["values"] + # Unwrap nested lists (stat-xplore wraps in [date][value]) + while isinstance(values, list) and len(values) == 1: + values = values[0] + return values if isinstance(values, (int, float)) else None + + +def _extract_breakdown(result: dict) -> list[tuple[str, float]]: + """Extract label/value pairs from a single-dimension query. + + Stat-xplore auto-adds the date dimension, so the response has two fields: + date (1 item = latest month) and the requested dimension (N items). + Values are shaped [1][N]. + """ + fields = result.get("fields", []) + cubes = result.get("cubes", {}) + if not cubes: + return [] + vals = next(iter(cubes.values()))["values"] + + # Find the non-date dimension + dim_field = None + for f in fields: + if "month" not in f["label"].lower() and "date" not in f["label"].lower(): + dim_field = f + break + if dim_field is None: + return [] + + items = dim_field["items"] + # Values are [date_idx][dim_idx] — take last date row + row = vals[-1] if isinstance(vals[0], list) else vals + pairs = [] + for i, item in enumerate(items): + v = row[i] if isinstance(row, list) else row + if v is not None and v > 0: + pairs.append((item["labels"][0], float(v))) + return pairs + + +# ── Simple total caseload queries ────────────────────────────────────────── + + +# (database, measure, target_name, survey_variable, entity) +_SIMPLE_BENEFITS = [ + ( + "str:database:UC_Monthly", + "str:count:UC_Monthly:V_F_UC_CASELOAD_FULL", + "dwp/uc_total_claimants", + "universal_credit", + "person", + ), + ( + "str:database:PIP_Monthly_new", + "str:count:PIP_Monthly_new:V_F_PIP_MONTHLY", + "dwp/pip_total_claimants", + "pip_daily_living", + "person", + ), + ( + "str:database:PC_New", + "str:count:PC_New:V_F_PC_CASELOAD_New", + "dwp/pension_credit_claimants", + "pension_credit", + "person", + ), + ( + "str:database:CA_In_Payment_New", + "str:count:CA_In_Payment_New:V_F_CA_In_Payment_New", + "dwp/carers_allowance_claimants", + "carers_allowance", + "person", + ), + ( + "str:database:AA_In_Payment_New", + "str:count:AA_In_Payment_New:V_F_AA_In_Payment_New", + "dwp/attendance_allowance_claimants", + "attendance_allowance", + "person", + ), + ( + "str:database:SP_New", + "str:count:SP_New:V_F_SP_CASELOAD_New", + "dwp/state_pension_claimants", + "state_pension", + "person", + ), + ( + "str:database:ESA_Caseload_new", + "str:count:ESA_Caseload_new:V_F_ESA_NEW", + "dwp/esa_claimants", + "ESA_income", + "person", + ), + ( + "str:database:DLA_In_Payment_New", + "str:count:DLA_In_Payment_New:V_F_DLA_In_Payment_New", + "dwp/dla_claimants", + "DLA_M", + "person", + ), +] + + +def _fetch_simple_benefits() -> list[dict]: + """Fetch total caseload for each benefit.""" + targets = [] + for database, measure, name, variable, entity in _SIMPLE_BENEFITS: + try: + result = _query_table(database, [measure], []) + total = _extract_total(result) + if total is not None: + year = _extract_year(result) + targets.append( + { + "name": name, + "variable": variable, + "entity": entity, + "aggregation": "count_nonzero", + "filter": None, + "value": total, + "source": "dwp", + "year": year, + "holdout": False, + } + ) + except Exception as e: + logger.warning("Failed to fetch %s: %s", name, e) + return targets + + +# ── UC subgroup breakdowns (households) ──────────────────────────────────── + +_UC_HH_DB = "str:database:UC_Households" +_UC_HH_COUNT = "str:count:UC_Households:V_F_UC_HOUSEHOLDS" +_UC_HH_FIELD = "str:field:UC_Households:V_F_UC_HOUSEHOLDS" + + +def _fetch_uc_breakdowns() -> list[dict]: + """Fetch UC household breakdowns by family type, entitlement elements, etc.""" targets = [] + + # UC households by family type try: result = _query_table( - database="str:database:UC_Monthly", - measures=["str:count:UC_Monthly:V_F_UC_CASELOAD_FULL"], - dimensions=[], + _UC_HH_DB, + [_UC_HH_COUNT], + [[f"{_UC_HH_FIELD}:hnfamily_type"]], ) - total = _extract_total(result) - if total is not None: - year = _extract_year(result) + year = _extract_year(result) + for label, value in _extract_breakdown(result): + slug = label.lower().replace(",", "").replace(" ", "_") + if "unknown" in slug or "missing" in slug: + continue targets.append( { - "name": "dwp/uc_total_claimants", + "name": f"dwp/uc_households_{slug}", "variable": "universal_credit", - "entity": "person", + "entity": "benunit", "aggregation": "count_nonzero", "filter": None, - "value": float(total), + "value": value, "source": "dwp", "year": year, - "holdout": False, + "holdout": True, # subgroup counts as holdout } ) except Exception as e: - logger.warning("Failed to fetch UC caseloads from stat-xplore: %s", e) + logger.warning("Failed to fetch UC family type breakdown: %s", e) - return targets + # UC households with child entitlement + try: + result = _query_table( + _UC_HH_DB, + [_UC_HH_COUNT], + [[f"{_UC_HH_FIELD}:HCCHILD_ENTITLEMENT"]], + ) + year = _extract_year(result) + for label, value in _extract_breakdown(result): + if label.lower() == "yes": + targets.append( + { + "name": "dwp/uc_households_with_children", + "variable": "universal_credit", + "entity": "benunit", + "aggregation": "count_nonzero", + "filter": None, + "value": value, + "source": "dwp", + "year": year, + "holdout": False, + } + ) + except Exception as e: + logger.warning("Failed to fetch UC child entitlement breakdown: %s", e) + + # UC households with LCWRA entitlement (disability element) + try: + result = _query_table( + _UC_HH_DB, + [_UC_HH_COUNT], + [[f"{_UC_HH_FIELD}:HCLCW_ENTITLEMENT"]], + ) + year = _extract_year(result) + for label, value in _extract_breakdown(result): + slug = label.lower().replace(" ", "_").replace("/", "_") + if slug in ("lcwra", "lcw"): + targets.append( + { + "name": f"dwp/uc_households_{slug}", + "variable": "universal_credit", + "entity": "benunit", + "aggregation": "count_nonzero", + "filter": None, + "value": value, + "source": "dwp", + "year": year, + "holdout": slug == "lcw", # LCWRA is training, LCW is holdout + } + ) + except Exception as e: + logger.warning("Failed to fetch UC LCW breakdown: %s", e) + # UC households with carer entitlement + try: + result = _query_table( + _UC_HH_DB, + [_UC_HH_COUNT], + [[f"{_UC_HH_FIELD}:HCCARER_ENTITLEMENT"]], + ) + year = _extract_year(result) + for label, value in _extract_breakdown(result): + if label.lower() == "yes": + targets.append( + { + "name": "dwp/uc_households_with_carer", + "variable": "universal_credit", + "entity": "benunit", + "aggregation": "count_nonzero", + "filter": None, + "value": value, + "source": "dwp", + "year": year, + "holdout": True, + } + ) + except Exception as e: + logger.warning("Failed to fetch UC carer breakdown: %s", e) -def _fetch_pip_caseloads() -> list[dict]: - """Total PIP claimants from stat-xplore (post-2019 database).""" - targets = [] + # UC households with housing entitlement try: result = _query_table( - database="str:database:PIP_Monthly_new", - measures=["str:count:PIP_Monthly_new:V_F_PIP_MONTHLY"], - dimensions=[], + _UC_HH_DB, + [_UC_HH_COUNT], + [[f"{_UC_HH_FIELD}:TENURE"]], ) - total = _extract_total(result) - if total is not None: - year = _extract_year(result) - targets.append( - { - "name": "dwp/pip_total_claimants", - "variable": "pip_daily_living", - "entity": "person", - "aggregation": "count_nonzero", - "filter": None, - "value": float(total), - "source": "dwp", - "year": year, - "holdout": False, - } - ) + year = _extract_year(result) + for label, value in _extract_breakdown(result): + if label.lower() == "yes": + targets.append( + { + "name": "dwp/uc_households_with_housing", + "variable": "universal_credit", + "entity": "benunit", + "aggregation": "count_nonzero", + "filter": None, + "value": value, + "source": "dwp", + "year": year, + "holdout": False, + } + ) except Exception as e: - logger.warning("Failed to fetch PIP caseloads from stat-xplore: %s", e) + logger.warning("Failed to fetch UC housing breakdown: %s", e) return targets def get_targets() -> list[dict]: - # Try loading from cache first if CACHE_FILE.exists(): logger.info("Using cached DWP targets: %s", CACHE_FILE) return json.loads(CACHE_FILE.read_text()) @@ -152,10 +355,9 @@ def get_targets() -> list[dict]: return [] targets = [] - targets.extend(_fetch_uc_caseloads()) - targets.extend(_fetch_pip_caseloads()) + targets.extend(_fetch_simple_benefits()) + targets.extend(_fetch_uc_breakdowns()) - # Cache results CACHE_DIR.mkdir(parents=True, exist_ok=True) CACHE_FILE.write_text(json.dumps(targets, indent=2)) logger.info("Cached %d DWP targets to %s", len(targets), CACHE_FILE) From 0c905745426377d55f590160bc78a26938b6c890 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 8 Apr 2026 12:05:27 +0100 Subject: [PATCH 4/7] feat: calibrate using simulation outputs + benunit support Run a baseline simulation before calibrating so targets can reference simulated variables (income_tax, national_insurance, vat, etc.) rather than raw input proxies. Income tax RMSRE drops from 79% to 1%. Also adds benunit entity support to the calibration matrix builder, and updates OBR targets to use simulated tax/benefit variables (income_tax, national_insurance, vat, fuel_duty, capital_gains_tax, stamp_duty) and benunit-level benefit variables (universal_credit, housing_benefit, etc.). Co-Authored-By: Claude Opus 4.6 --- scripts/build_targets/dwp.py | 4 +- scripts/build_targets/obr.py | 62 +++++++---- src/data/calibrate.rs | 203 +++++++++++++++++++++++++---------- 3 files changed, 187 insertions(+), 82 deletions(-) diff --git a/scripts/build_targets/dwp.py b/scripts/build_targets/dwp.py index 3f95c54..4775222 100644 --- a/scripts/build_targets/dwp.py +++ b/scripts/build_targets/dwp.py @@ -155,14 +155,14 @@ def _extract_breakdown(result: dict) -> list[tuple[str, float]]: "str:database:ESA_Caseload_new", "str:count:ESA_Caseload_new:V_F_ESA_NEW", "dwp/esa_claimants", - "ESA_income", + "esa_income", "person", ), ( "str:database:DLA_In_Payment_New", "str:count:DLA_In_Payment_New:V_F_DLA_In_Payment_New", "dwp/dla_claimants", - "DLA_M", + "dla_care", "person", ), ] diff --git a/scripts/build_targets/obr.py b/scripts/build_targets/obr.py index 6e59c22..d9687c9 100644 --- a/scripts/build_targets/obr.py +++ b/scripts/build_targets/obr.py @@ -80,34 +80,50 @@ def _parse_receipts() -> list[dict]: # Map: (label_prefix, target_name, variable in the survey data, entity, aggregation) # These are aggregate £ totals. For calibration we map them to survey-reported # income/benefit variables where possible. + # Now that calibration runs after simulation, we can use simulated output + # variables (income_tax, national_insurance, capital_gains_tax, etc.) receipt_rows = [ ( "Income tax (gross of tax credits)", "obr/income_tax_receipts", - "employment_income", + "income_tax", "person", "sum", - "Total income tax receipts (proxy: total employment income is the dominant base)", + "Simulated income tax", ), ( "National insurance contributions", "obr/ni_receipts", - "employment_income", + "national_insurance", "person", "sum", - "Total NIC receipts (proxy: employment income)", + "Simulated employee NI", ), ( "Value added tax", "obr/vat_receipts", - None, - None, - None, - "VAT — no direct survey variable, skip for now", + "vat", + "household", + "sum", + "Simulated VAT", + ), + ("Fuel duties", "obr/fuel_duty_receipts", "fuel_duty", "household", "sum", ""), + ( + "Capital gains tax", + "obr/cgt_receipts", + "capital_gains_tax", + "household", + "sum", + "", + ), + ( + "Stamp duty land tax", + "obr/sdlt_receipts", + "stamp_duty", + "household", + "sum", + "", ), - ("Fuel duties", "obr/fuel_duty_receipts", None, None, None, ""), - ("Capital gains tax", "obr/cgt_receipts", "capital_gains", "person", "sum", ""), - ("Stamp duty land tax", "obr/sdlt_receipts", None, None, None, ""), ( "Council tax", "obr/council_tax_receipts", @@ -154,14 +170,14 @@ def _parse_it_nics_detail() -> list[dict]: ( "Income tax (gross of tax credits)", "obr/income_tax", - "employment_income", + "income_tax", "person", "sum", ), ( "Class 1 Employee NICs", "obr/ni_employee", - "employment_income", + "national_insurance", "person", "sum", ), @@ -197,31 +213,31 @@ def _parse_welfare() -> list[dict]: ws = wb["4.9"] targets = [] - # Map OBR row labels to survey-reported benefit variables. - # These are spending totals (£bn) which we match to reported receipt in FRS. + # Map OBR row labels to simulated benefit variables. + # Benefits are calculated at benunit level in the simulation. benefit_rows = [ ( "Housing benefit (not on JSA)", "obr/housing_benefit", "housing_benefit", - "person", + "benunit", ), ( "Disability living allowance and personal independence p", "obr/pip_dla", "pip_daily_living", - "person", + "person", # PIP/DLA are passthrough (input data), not simulated ), ( "Attendance allowance", "obr/attendance_allowance", "attendance_allowance", - "person", + "person", # Passthrough ), - ("Pension credit", "obr/pension_credit", "pension_credit", "person"), - ("Carer's allowance", "obr/carers_allowance", "carers_allowance", "person"), - ("Child benefit", "obr/child_benefit", "child_benefit", "person"), - ("State pension", "obr/state_pension", "state_pension", "person"), + ("Pension credit", "obr/pension_credit", "pension_credit", "benunit"), + ("Carer's allowance", "obr/carers_allowance", "carers_allowance", "benunit"), + ("Child benefit", "obr/child_benefit", "child_benefit", "benunit"), + ("State pension", "obr/state_pension", "state_pension", "benunit"), ] # UC appears twice in 4.9 — inside and outside the welfare cap. We want both. @@ -237,7 +253,7 @@ def _parse_welfare() -> list[dict]: { "name": f"obr/universal_credit_{suffix}/{year}", "variable": "universal_credit", - "entity": "person", + "entity": "benunit", "aggregation": "sum", "filter": None, "value": value, diff --git a/src/data/calibrate.rs b/src/data/calibrate.rs index da198f6..a40ef30 100644 --- a/src/data/calibrate.rs +++ b/src/data/calibrate.rs @@ -3,6 +3,10 @@ //! Loads calibration targets from a JSON file, builds a matrix of household-level //! contributions to each target, and optimises household weights using Adam in //! log-space to minimise mean squared relative error. +//! +//! Calibration runs *after* a baseline simulation so that targets can reference +//! simulated output variables (income_tax, universal_credit, etc.) as well as +//! raw input data. use std::path::Path; @@ -13,6 +17,7 @@ use rayon::prelude::*; use serde::Deserialize; use crate::data::Dataset; +use crate::engine::simulation::SimulationResults; // ── Target schema ────────────────────────────────────────────────────────── @@ -54,8 +59,20 @@ pub fn load_targets(path: &Path) -> anyhow::Result> { // ── Variable resolution ──────────────────────────────────────────────────── -/// Get a person-level variable value by name. -fn person_variable(p: &crate::engine::entities::Person, name: &str) -> f64 { +/// Get a person-level variable value by name, checking simulation results first. +fn person_variable( + p: &crate::engine::entities::Person, + sim: Option<&SimulationResults>, + pid: usize, + name: &str, +) -> f64 { + // Check simulation output variables first + if let Some(results) = sim { + if let Some(v) = person_result_variable(&results.person_results[pid], name) { + return v; + } + } + // Fall back to input data match name { "age" => p.age, "employment_income" => p.employment_income, @@ -93,13 +110,68 @@ fn person_variable(p: &crate::engine::entities::Person, name: &str) -> f64 { } } +/// Get a simulation output variable for a person. Returns None if not a sim variable. +fn person_result_variable( + pr: &crate::engine::simulation::PersonResult, + name: &str, +) -> Option { + match name { + "income_tax" => Some(pr.income_tax), + "national_insurance" | "employee_ni" => Some(pr.national_insurance), + "employer_ni" => Some(pr.employer_ni), + "sim_total_income" => Some(pr.total_income), + "taxable_income" => Some(pr.taxable_income), + "personal_allowance" => Some(pr.personal_allowance), + "adjusted_net_income" => Some(pr.adjusted_net_income), + "hicbc" => Some(pr.hicbc), + "capital_gains_tax" => Some(pr.capital_gains_tax), + _ => None, + } +} + +/// Get a simulation output variable for a benefit unit. +fn benunit_result_variable( + br: &crate::engine::simulation::BenUnitResult, + name: &str, +) -> Option { + match name { + "universal_credit" => Some(br.universal_credit), + "child_benefit" => Some(br.child_benefit), + "state_pension" => Some(br.state_pension), + "pension_credit" => Some(br.pension_credit), + "housing_benefit" => Some(br.housing_benefit), + "child_tax_credit" => Some(br.child_tax_credit), + "working_tax_credit" => Some(br.working_tax_credit), + "income_support" => Some(br.income_support), + "esa_income_related" => Some(br.esa_income_related), + "jsa_income_based" => Some(br.jsa_income_based), + "carers_allowance" => Some(br.carers_allowance), + "total_benefits" => Some(br.total_benefits), + "uc_max_amount" => Some(br.uc_max_amount), + "uc_income_reduction" => Some(br.uc_income_reduction), + "benefit_cap_reduction" => Some(br.benefit_cap_reduction), + _ => None, + } +} + /// Get a household-level variable value by name. -fn household_variable(h: &crate::engine::entities::Household, name: &str) -> f64 { +fn household_variable( + h: &crate::engine::entities::Household, + sim: Option<&SimulationResults>, + hh_idx: usize, + name: &str, +) -> f64 { + // Check simulation output variables first + if let Some(results) = sim { + if let Some(v) = household_result_variable(&results.household_results[hh_idx], name) { + return v; + } + } match name { "council_tax_annual" | "council_tax" => h.council_tax, "rent_annual" | "rent" => h.rent, "weight" => h.weight, - "household_id" => 1.0, // For counting households + "household_id" => 1.0, "property_wealth" => h.property_wealth, "net_financial_wealth" => h.net_financial_wealth, "gross_financial_wealth" => h.gross_financial_wealth, @@ -108,6 +180,25 @@ fn household_variable(h: &crate::engine::entities::Household, name: &str) -> f64 } } +/// Get a simulation output variable for a household. +fn household_result_variable( + hr: &crate::engine::simulation::HouseholdResult, + name: &str, +) -> Option { + match name { + "net_income" => Some(hr.net_income), + "total_tax" => Some(hr.total_tax), + "hh_total_benefits" => Some(hr.total_benefits), + "gross_income" => Some(hr.gross_income), + "vat" => Some(hr.vat), + "fuel_duty" => Some(hr.fuel_duty), + "capital_gains_tax" => Some(hr.capital_gains_tax), + "stamp_duty" => Some(hr.stamp_duty), + "council_tax_calculated" => Some(hr.council_tax_calculated), + _ => None, + } +} + // ── Matrix building ──────────────────────────────────────────────────────── /// Build the calibration matrix M[i][j] and target vector y[j]. @@ -115,11 +206,15 @@ fn household_variable(h: &crate::engine::entities::Household, name: &str) -> f64 /// M[i][j] = household i's contribution to target j (before weighting). /// y[j] = the target value. /// +/// If `sim_results` is provided, simulation output variables can be used +/// in addition to raw input data. +/// /// Returns (matrix, target_values, training_mask) where training_mask[j] /// is true if target j should be included in the loss. pub fn build_matrix( dataset: &Dataset, targets: &[CalibrationTarget], + sim_results: Option<&SimulationResults>, ) -> (Vec>, Vec, Vec) { let n_hh = dataset.households.len(); let n_targets = targets.len(); @@ -129,7 +224,6 @@ pub fn build_matrix( for (j, target) in targets.iter().enumerate() { target_values[j] = target.value; - // Will be refined after matrix is built (skip unfittable targets) training_mask[j] = !target.holdout; match target.entity.as_str() { @@ -139,9 +233,8 @@ pub fn build_matrix( for &pid in &hh.person_ids { let person = &dataset.people[pid]; - // Apply filter if present if let Some(ref filter) = target.filter { - let filter_val = person_variable(person, &filter.variable); + let filter_val = person_variable(person, sim_results, pid, &filter.variable); if filter_val < filter.min || filter_val >= filter.max { continue; } @@ -149,10 +242,45 @@ pub fn build_matrix( match target.aggregation.as_str() { "sum" => { - contribution += person_variable(person, &target.variable); + contribution += person_variable(person, sim_results, pid, &target.variable); + } + "count_nonzero" => { + if person_variable(person, sim_results, pid, &target.variable) > 0.0 { + contribution += 1.0; + } + } + "count" => { + contribution += 1.0; + } + _ => {} + } + } + matrix[i][j] = contribution; + } + } + "benunit" => { + for (i, hh) in dataset.households.iter().enumerate() { + let mut contribution = 0.0f64; + for &bu_id in &hh.benunit_ids { + let bu = &dataset.benunits[bu_id]; + + // For benunit variables, check simulation results first + let bu_val = if let Some(results) = sim_results { + benunit_result_variable(&results.benunit_results[bu_id], &target.variable) + .unwrap_or(0.0) + } else { + // Fall back to input data: sum person-level variable across benunit members + bu.person_ids.iter() + .map(|&pid| person_variable(&dataset.people[pid], None, pid, &target.variable)) + .sum::() + }; + + match target.aggregation.as_str() { + "sum" => { + contribution += bu_val; } "count_nonzero" => { - if person_variable(person, &target.variable) > 0.0 { + if bu_val > 0.0 { contribution += 1.0; } } @@ -169,10 +297,10 @@ pub fn build_matrix( for (i, hh) in dataset.households.iter().enumerate() { match target.aggregation.as_str() { "sum" => { - matrix[i][j] = household_variable(hh, &target.variable); + matrix[i][j] = household_variable(hh, sim_results, i, &target.variable); } "count" | "count_nonzero" => { - let val = household_variable(hh, &target.variable); + let val = household_variable(hh, sim_results, i, &target.variable); matrix[i][j] = if val > 0.0 { 1.0 } else { 0.0 }; } _ => {} @@ -184,7 +312,6 @@ pub fn build_matrix( } // Skip targets where no household contributes (matrix column all zero). - // These are unfittable (e.g. top income bands not represented in FRS). let mut n_skipped = 0; for j in 0..n_targets { let col_sum: f64 = (0..n_hh).map(|i| matrix[i][j].abs()).sum(); @@ -232,13 +359,10 @@ pub struct CalibrateResult { pub weights: Vec, pub final_training_loss: f64, pub final_holdout_loss: f64, - pub per_target_error: Vec<(String, f64, f64, f64, bool)>, // (name, predicted, target, rel_error, holdout) + pub per_target_error: Vec<(String, f64, f64, f64, bool)>, } /// Run Adam optimisation to find weights minimising MSRE against targets. -/// -/// Loss = mean_j((pred_j / target_j - 1)^2) for training targets. -/// Weights are parameterised as w_i = exp(u_i) for positivity. pub fn calibrate( matrix: &[Vec], target_values: &[f64], @@ -259,29 +383,25 @@ pub fn calibrate( }; } - // Initialise log-weights let mut u: Vec = initial_weights.iter() .map(|&w| if w > 0.0 { w.ln() } else { 0.0 }) .collect(); - // Adam state let mut m = vec![0.0f64; n_hh]; let mut v = vec![0.0f64; n_hh]; let mut rng = rand::thread_rng(); for epoch in 0..config.epochs { - // Compute weights with optional dropout let weights: Vec = u.iter().enumerate().map(|(_i, &ui)| { let w = ui.exp(); if config.dropout > 0.0 && rng.gen::() < config.dropout { - 0.0 // Drop this household + 0.0 } else { - w / (1.0 - config.dropout) // Scale up to compensate + w / (1.0 - config.dropout) } }).collect(); - // Forward pass: pred_j = sum_i w_i * M_ij let predictions: Vec = (0..n_targets).into_par_iter().map(|j| { let mut pred = 0.0f64; for i in 0..n_hh { @@ -290,7 +410,6 @@ pub fn calibrate( pred }).collect(); - // Compute residuals: r_j = pred_j / target_j - 1 let residuals: Vec = (0..n_targets).map(|j| { if target_values[j].abs() > 1.0 { predictions[j] / target_values[j] - 1.0 @@ -299,13 +418,11 @@ pub fn calibrate( } }).collect(); - // Training loss let training_loss: f64 = residuals.iter().enumerate() .filter(|(j, _)| training_mask[*j]) .map(|(_, r)| r * r) .sum::() / n_training as f64; - // Holdout loss let n_holdout = training_mask.iter().filter(|&&m| !m).count(); let holdout_loss = if n_holdout > 0 { residuals.iter().enumerate() @@ -317,27 +434,20 @@ pub fn calibrate( }; if epoch % config.log_interval == 0 || epoch == config.epochs - 1 { - let rmse_train = training_loss.sqrt() * 100.0; - let rmse_holdout = holdout_loss.sqrt() * 100.0; eprintln!( " Epoch {:>4}/{}: training RMSRE {:.2}%, holdout RMSRE {:.2}%", - epoch, config.epochs, rmse_train, rmse_holdout + epoch, config.epochs, + training_loss.sqrt() * 100.0, + holdout_loss.sqrt() * 100.0, ); } if epoch == config.epochs - 1 { - // Build final result with actual weights (no dropout) let final_weights: Vec = u.iter().map(|&ui| ui.exp()).collect(); let final_preds: Vec = (0..n_targets).map(|j| { - let mut pred = 0.0f64; - for i in 0..n_hh { - pred += final_weights[i] * matrix[i][j]; - } - pred + (0..n_hh).map(|i| final_weights[i] * matrix[i][j]).sum() }).collect(); - let per_target_error: Vec<(String, f64, f64, f64, bool)> = Vec::new(); - let final_training_loss: f64 = (0..n_targets) .filter(|&j| training_mask[j]) .map(|j| { @@ -358,19 +468,6 @@ pub fn calibrate( }).sum::() / n_holdout as f64 } else { 0.0 }; - // Compute per-target errors for reporting - // (done outside the return to avoid borrow issues) - let per_target: Vec<(String, f64, f64, f64, bool)> = (0..n_targets).map(|j| { - let rel_err = if target_values[j].abs() > 1.0 { - final_preds[j] / target_values[j] - 1.0 - } else { 0.0 }; - (String::new(), final_preds[j], target_values[j], rel_err, !training_mask[j]) - }).collect(); - - // We'll fill names outside this block - let _ = per_target; - let _ = per_target_error; - return CalibrateResult { weights: final_weights, final_training_loss, @@ -384,8 +481,6 @@ pub fn calibrate( }; } - // Backward pass: compute gradient dL/du_i - // dL/du_i = (2/n_training) * sum_j [training_mask_j * r_j * M_ij * w_i / y_j] let grad: Vec = (0..n_hh).into_par_iter().map(|i| { let w_i = weights[i]; let mut g = 0.0f64; @@ -397,7 +492,6 @@ pub fn calibrate( 2.0 * g / n_training as f64 }).collect(); - // Adam update let t = (epoch + 1) as f64; let bc1 = 1.0 - config.beta1.powf(t); let bc2 = 1.0 - config.beta2.powf(t); @@ -411,7 +505,6 @@ pub fn calibrate( } } - // Should not reach here, but just in case let final_weights: Vec = u.iter().map(|&ui| ui.exp()).collect(); CalibrateResult { weights: final_weights, @@ -423,7 +516,6 @@ pub fn calibrate( // ── Reporting ────────────────────────────────────────────────────────────── -/// Print a summary table of calibration results. pub fn print_report( targets: &[CalibrationTarget], result: &CalibrateResult, @@ -443,14 +535,12 @@ pub fn print_report( result.final_holdout_loss.sqrt() * 100.0, ); - // Per-target table (show worst 20 + all holdout) let mut rows: Vec<(usize, &str, f64, f64, f64, bool)> = result.per_target_error.iter().enumerate() .map(|(j, (_, pred, target, rel_err, holdout))| { (j, targets[j].name.as_str(), *pred, *target, *rel_err, *holdout) }) .collect(); - // Sort by absolute relative error, descending rows.sort_by(|a, b| b.4.abs().partial_cmp(&a.4.abs()).unwrap_or(std::cmp::Ordering::Equal)); let mut table = Table::new(); @@ -500,7 +590,6 @@ fn format_value(v: f64) -> String { // ── Apply weights ────────────────────────────────────────────────────────── -/// Apply calibrated weights to the dataset. pub fn apply_weights(dataset: &mut Dataset, weights: &[f64]) { for (i, hh) in dataset.households.iter_mut().enumerate() { if i < weights.len() { From 1bf17e41d4ddb13e395ad0012229db862fbe1f30 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 8 Apr 2026 12:43:04 +0100 Subject: [PATCH 5/7] feat: benunit filters, combined NI variable, ONS year coverage - Add BenunitFilter struct for UC subgroup targets (is_couple, has_children, has_carer, has_lcwra, has_lcw, has_housing) - Add total_ni variable (employee + employer NI) for OBR NI receipts - ONS targets now emitted for years 2024-2030 so they bind regardless of calibration year (fixes weight sum blowup from 34m to ~29m) - DWP UC subgroup targets now carry benunit_filter conditions - Add historical FRS years 1994-2021 to rebuild_all manifest Co-Authored-By: Claude Opus 4.6 --- scripts/build_targets/dwp.py | 41 ++++++++++++++-- scripts/build_targets/obr.py | 4 +- scripts/build_targets/ons.py | 95 +++++++++++++++++++----------------- src/data/calibrate.rs | 79 ++++++++++++++++++++++++++++++ 4 files changed, 168 insertions(+), 51 deletions(-) diff --git a/scripts/build_targets/dwp.py b/scripts/build_targets/dwp.py index 4775222..05c3a58 100644 --- a/scripts/build_targets/dwp.py +++ b/scripts/build_targets/dwp.py @@ -206,7 +206,7 @@ def _fetch_uc_breakdowns() -> list[dict]: """Fetch UC household breakdowns by family type, entitlement elements, etc.""" targets = [] - # UC households by family type + # UC households by family type — map to benunit_filter conditions try: result = _query_table( _UC_HH_DB, @@ -218,6 +218,17 @@ def _fetch_uc_breakdowns() -> list[dict]: slug = label.lower().replace(",", "").replace(" ", "_") if "unknown" in slug or "missing" in slug: continue + # Map family type labels to benunit filter conditions + bf = {} + if "single" in slug and "no_child" in slug: + bf = {"is_couple": False, "has_children": False} + elif "single" in slug and "child" in slug: + bf = {"is_couple": False, "has_children": True} + elif "couple" in slug and "no_child" in slug: + bf = {"is_couple": True, "has_children": False} + elif "couple" in slug and "child" in slug: + bf = {"is_couple": True, "has_children": True} + targets.append( { "name": f"dwp/uc_households_{slug}", @@ -225,10 +236,11 @@ def _fetch_uc_breakdowns() -> list[dict]: "entity": "benunit", "aggregation": "count_nonzero", "filter": None, + "benunit_filter": bf if bf else None, "value": value, "source": "dwp", "year": year, - "holdout": True, # subgroup counts as holdout + "holdout": True, } ) except Exception as e: @@ -251,6 +263,7 @@ def _fetch_uc_breakdowns() -> list[dict]: "entity": "benunit", "aggregation": "count_nonzero", "filter": None, + "benunit_filter": {"has_children": True}, "value": value, "source": "dwp", "year": year, @@ -270,18 +283,34 @@ def _fetch_uc_breakdowns() -> list[dict]: year = _extract_year(result) for label, value in _extract_breakdown(result): slug = label.lower().replace(" ", "_").replace("/", "_") - if slug in ("lcwra", "lcw"): + if slug == "lcwra": + targets.append( + { + "name": "dwp/uc_households_lcwra", + "variable": "universal_credit", + "entity": "benunit", + "aggregation": "count_nonzero", + "filter": None, + "benunit_filter": {"has_lcwra": True}, + "value": value, + "source": "dwp", + "year": year, + "holdout": False, + } + ) + elif slug == "lcw": targets.append( { - "name": f"dwp/uc_households_{slug}", + "name": "dwp/uc_households_lcw", "variable": "universal_credit", "entity": "benunit", "aggregation": "count_nonzero", "filter": None, + "benunit_filter": {"has_lcw": True}, "value": value, "source": "dwp", "year": year, - "holdout": slug == "lcw", # LCWRA is training, LCW is holdout + "holdout": True, } ) except Exception as e: @@ -304,6 +333,7 @@ def _fetch_uc_breakdowns() -> list[dict]: "entity": "benunit", "aggregation": "count_nonzero", "filter": None, + "benunit_filter": {"has_carer": True}, "value": value, "source": "dwp", "year": year, @@ -330,6 +360,7 @@ def _fetch_uc_breakdowns() -> list[dict]: "entity": "benunit", "aggregation": "count_nonzero", "filter": None, + "benunit_filter": {"has_housing": True}, "value": value, "source": "dwp", "year": year, diff --git a/scripts/build_targets/obr.py b/scripts/build_targets/obr.py index d9687c9..065a2de 100644 --- a/scripts/build_targets/obr.py +++ b/scripts/build_targets/obr.py @@ -94,10 +94,10 @@ def _parse_receipts() -> list[dict]: ( "National insurance contributions", "obr/ni_receipts", - "national_insurance", + "total_ni", "person", "sum", - "Simulated employee NI", + "Simulated employee + employer NI", ), ( "Value added tax", diff --git a/scripts/build_targets/ons.py b/scripts/build_targets/ons.py index da03dd0..39462ef 100644 --- a/scripts/build_targets/ons.py +++ b/scripts/build_targets/ons.py @@ -42,62 +42,69 @@ def get_targets() -> list[dict]: + """Generate ONS demographic targets for all calibration years. + + Population changes slowly year-to-year, so we emit the same targets for + each year in the calibration range. This ensures they bind regardless of + which --year is passed to calibration. + """ targets = [] - # Age group population counts - for group, count in _POPULATION.items(): - if group == "total": - continue - # Map to a filter on the age variable - if group == "children_0_15": - age_filter = {"variable": "age", "min": 0, "max": 16} - elif group == "working_age_16_64": - age_filter = {"variable": "age", "min": 16, "max": 65} - else: # pensioners - age_filter = {"variable": "age", "min": 65, "max": 200} + # Emit for all plausible calibration years + for year in range(2024, 2031): + # Age group population counts + for group, count in _POPULATION.items(): + if group == "total": + continue + if group == "children_0_15": + age_filter = {"variable": "age", "min": 0, "max": 16} + elif group == "working_age_16_64": + age_filter = {"variable": "age", "min": 16, "max": 65} + else: # pensioners + age_filter = {"variable": "age", "min": 65, "max": 200} + + targets.append( + { + "name": f"ons/population_{group}/{year}", + "variable": "age", + "entity": "person", + "aggregation": "count", + "filter": age_filter, + "value": float(count), + "source": "ons", + "year": year, + "holdout": False, + } + ) + # Total population targets.append( { - "name": f"ons/population_{group}", + "name": f"ons/total_population/{year}", "variable": "age", "entity": "person", "aggregation": "count", - "filter": age_filter, - "value": float(count), + "filter": None, + "value": float(_POPULATION["total"]), "source": "ons", - "year": 2023, + "year": year, "holdout": False, } ) - # Total population - targets.append( - { - "name": "ons/total_population", - "variable": "age", - "entity": "person", - "aggregation": "count", - "filter": None, - "value": float(_POPULATION["total"]), - "source": "ons", - "year": 2023, - "holdout": False, - } - ) - - # Total households - targets.append( - { - "name": "ons/total_households", - "variable": "household_id", - "entity": "household", - "aggregation": "count", - "filter": None, - "value": float(_TOTAL_HOUSEHOLDS), - "source": "ons", - "year": 2023, - "holdout": False, - } - ) + # Total households + targets.append( + { + "name": f"ons/total_households/{year}", + "variable": "household_id", + "entity": "household", + "aggregation": "count", + "filter": None, + "value": float(_TOTAL_HOUSEHOLDS), + "source": "ons", + "year": year, + "holdout": False, + } + ) return targets diff --git a/src/data/calibrate.rs b/src/data/calibrate.rs index a40ef30..8de515b 100644 --- a/src/data/calibrate.rs +++ b/src/data/calibrate.rs @@ -35,6 +35,9 @@ pub struct CalibrationTarget { pub aggregation: String, #[serde(default)] pub filter: Option, + /// Benunit-level property filter (e.g. is_couple, has_children). + #[serde(default)] + pub benunit_filter: Option, pub value: f64, pub source: String, pub year: u32, @@ -49,6 +52,30 @@ pub struct TargetFilter { pub max: f64, } +/// Filter on benunit-level computed properties (checked via entity methods). +/// All specified conditions must be true (AND logic). +#[derive(Debug, Deserialize, Clone)] +pub struct BenunitFilter { + /// true = couple, false = single + #[serde(default)] + pub is_couple: Option, + /// true = has children, false = no children + #[serde(default)] + pub has_children: Option, + /// true = at least one person in benunit is a carer + #[serde(default)] + pub has_carer: Option, + /// true = at least one person has esa_group == 1 (support/LCWRA) + #[serde(default)] + pub has_lcwra: Option, + /// true = at least one person has esa_group == 2 (WRAG/LCW) + #[serde(default)] + pub has_lcw: Option, + /// true = benunit has rent > 0 (housing entitlement proxy) + #[serde(default)] + pub has_housing: Option, +} + // ── Load targets ─────────────────────────────────────────────────────────── pub fn load_targets(path: &Path) -> anyhow::Result> { @@ -119,6 +146,7 @@ fn person_result_variable( "income_tax" => Some(pr.income_tax), "national_insurance" | "employee_ni" => Some(pr.national_insurance), "employer_ni" => Some(pr.employer_ni), + "total_ni" => Some(pr.national_insurance + pr.employer_ni), "sim_total_income" => Some(pr.total_income), "taxable_income" => Some(pr.taxable_income), "personal_allowance" => Some(pr.personal_allowance), @@ -199,6 +227,50 @@ fn household_result_variable( } } +/// Check whether a benunit passes all conditions in a BenunitFilter. +fn benunit_passes_filter( + bu: &crate::engine::entities::BenUnit, + people: &[crate::engine::entities::Person], + filter: &BenunitFilter, +) -> bool { + if let Some(want_couple) = filter.is_couple { + if bu.is_couple(people) != want_couple { + return false; + } + } + if let Some(want_children) = filter.has_children { + let has = bu.num_children(people) > 0; + if has != want_children { + return false; + } + } + if let Some(want_carer) = filter.has_carer { + let has = bu.person_ids.iter().any(|&pid| people[pid].is_carer); + if has != want_carer { + return false; + } + } + if let Some(want_lcwra) = filter.has_lcwra { + let has = bu.person_ids.iter().any(|&pid| people[pid].esa_group == 1); + if has != want_lcwra { + return false; + } + } + if let Some(want_lcw) = filter.has_lcw { + let has = bu.person_ids.iter().any(|&pid| people[pid].esa_group == 2); + if has != want_lcw { + return false; + } + } + if let Some(want_housing) = filter.has_housing { + let has = bu.rent_monthly > 0.0; + if has != want_housing { + return false; + } + } + true +} + // ── Matrix building ──────────────────────────────────────────────────────── /// Build the calibration matrix M[i][j] and target vector y[j]. @@ -264,6 +336,13 @@ pub fn build_matrix( for &bu_id in &hh.benunit_ids { let bu = &dataset.benunits[bu_id]; + // Apply benunit-level property filter if present + if let Some(ref bf) = target.benunit_filter { + if !benunit_passes_filter(bu, &dataset.people, bf) { + continue; + } + } + // For benunit variables, check simulation results first let bu_val = if let Some(results) = sim_results { benunit_result_variable(&results.benunit_results[bu_id], &target.variable) From 5c49091a11c1cf882e2f8680dbf19e30fb4507f0 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 8 Apr 2026 12:57:19 +0100 Subject: [PATCH 6/7] feat: add OBR EFO economy targets (labour market, income aggregates) Parse sheet 1.6 for employment count, wages & salaries, self-employment income, and self-employed count. Brings total target count to 385 (283 training, 102 holdout). Co-Authored-By: Claude Opus 4.6 --- scripts/build_targets/obr.py | 143 +++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/scripts/build_targets/obr.py b/scripts/build_targets/obr.py index 065a2de..425a3cf 100644 --- a/scripts/build_targets/obr.py +++ b/scripts/build_targets/obr.py @@ -3,6 +3,7 @@ Sources (local xlsx files in data/obr/): - Receipts: efo-march-2026-detailed-forecast-tables-receipts.xlsx - Expenditure: efo-march-2026-detailed-forecast-tables-expenditure.xlsx +- Economy: efo-march-2026-detailed-forecast-tables-economy.xlsx """ from __future__ import annotations @@ -16,6 +17,7 @@ RECEIPTS_FILE = OBR_DIR / "efo-march-2026-detailed-forecast-tables-receipts.xlsx" EXPENDITURE_FILE = OBR_DIR / "efo-march-2026-detailed-forecast-tables-expenditure.xlsx" +ECONOMY_FILE = OBR_DIR / "efo-march-2026-detailed-forecast-tables-economy.xlsx" # Sheet 3.8 (cash receipts): D=2024-25, E=2025-26, ..., J=2030-31 _RECEIPTS_COL_TO_YEAR = { @@ -316,6 +318,145 @@ def _parse_council_tax() -> list[dict]: return targets +def _parse_fiscal_year(label: str) -> int | None: + """Parse '2025-26' → 2025, or '2025/26' → 2025.""" + s = str(label).strip() + for sep in ["-", "/"]: + if sep in s: + parts = s.split(sep) + try: + return int(parts[0]) + except ValueError: + return None + return None + + +def _read_fiscal_year_rows( + ws, col_map: dict[str, str], max_row: int = 200 +) -> list[tuple[int, dict[str, float]]]: + """Scan column B for fiscal year labels (e.g. '2025-26') and read values. + + col_map maps a descriptive key to a column letter, e.g. {"employment": "C"}. + Returns [(year, {key: value}), ...]. + """ + results = [] + for row in range(4, max_row): + b = ws[f"B{row}"].value + if b is None: + continue + year = _parse_fiscal_year(b) + if year is None or year < 2020: + continue + vals = {} + for key, col in col_map.items(): + v = ws[f"{col}{row}"].value + if v is not None and isinstance(v, (int, float)): + vals[key] = float(v) + if vals: + results.append((year, vals)) + return results + + +def _parse_economy() -> list[dict]: + """Parse economy tables for labour market and income aggregates.""" + wb = openpyxl.load_workbook(ECONOMY_FILE, data_only=True) + targets = [] + + # ── 1.6 Labour market (fiscal year rows) ── + ws = wb["1.6"] + for year, vals in _read_fiscal_year_rows( + ws, + { + "employment": "C", # Employment 16+, millions + "employees": "E", # Employees 16+, millions + "unemployment": "F", # ILO unemployment, millions + "total_hours": "J", # Total hours worked, millions per week + "comp_employees": "M", # Compensation of employees, £bn + "wages_salaries": "N", # Wages and salaries, £bn + "employer_social": "O", # Employer social contributions, £bn + "mixed_income": "P", # Mixed income (self-employment), £bn + }, + ): + # Employment count: people with employment_income > 0 + if "employment" in vals: + targets.append( + { + "name": f"obr/employment_count/{year}", + "variable": "employment_income", + "entity": "person", + "aggregation": "count_nonzero", + "filter": None, + "value": vals["employment"] * 1e6, + "source": "obr", + "year": year, + "holdout": False, + } + ) + + # Total wages and salaries: sum of employment_income + if "wages_salaries" in vals: + targets.append( + { + "name": f"obr/wages_salaries/{year}", + "variable": "employment_income", + "entity": "person", + "aggregation": "sum", + "filter": None, + "value": vals["wages_salaries"] * 1e9, + "source": "obr", + "year": year, + "holdout": False, + } + ) + + # Employer social contributions — skipped: OBR figure includes pensions + # and other employer costs beyond NI. employer_ni already covered by + # NI receipts target. + + # Mixed income ≈ total self-employment income + if "mixed_income" in vals: + targets.append( + { + "name": f"obr/self_employment_income/{year}", + "variable": "self_employment_income", + "entity": "person", + "aggregation": "sum", + "filter": None, + "value": vals["mixed_income"] * 1e9, + "source": "obr", + "year": year, + "holdout": False, + } + ) + + # Self-employment count + if "mixed_income" in vals: + targets.append( + { + "name": f"obr/self_employed_count/{year}", + "variable": "self_employment_income", + "entity": "person", + "aggregation": "count_nonzero", + "filter": None, + "value": (vals["employment"] - vals.get("employees", 0)) * 1e6 + if "employment" in vals and "employees" in vals + else 0, + "source": "obr", + "year": year, + "holdout": True, + } + ) + + # Total hours worked — skipped: hours_worked not populated in EFRS. + + # RHDI (1.12) excluded — OBR national accounts definition differs from + # HBAI net income (includes imputed rent, NPISH, etc.). + # Housing stock (1.16) excluded — overlaps with ONS total_households. + + wb.close() + return targets + + def get_targets() -> list[dict]: targets = [] if RECEIPTS_FILE.exists(): @@ -324,4 +465,6 @@ def get_targets() -> list[dict]: if EXPENDITURE_FILE.exists(): targets.extend(_parse_welfare()) targets.extend(_parse_council_tax()) + if ECONOMY_FILE.exists(): + targets.extend(_parse_economy()) return targets From 8e17f7c9c41ed44107a517bb948b2037e18672e7 Mon Sep 17 00:00:00 2001 From: Nikhil Woodruff Date: Wed, 8 Apr 2026 13:12:35 +0100 Subject: [PATCH 7/7] feat: scale targets across all years for consistent calibration HMRC SPI 2022-23 income bands are now scaled to 2024-2030 using OBR growth rates (sheet 3.5 for self-employment/dividends/property/savings, sheet 1.6 for wages & salaries, sheet 1.7 CPI for pensions). DWP stat-xplore caseloads are scaled to 2024-2029 using DWP's own caseload forecasts from the Spring Statement 2025 benefit expenditure and caseload tables. All targets now participate in training (holdout flag only affects reporting, not gradient). This ensures consistent ~15-20% RMSRE across all calibration years, so trends are accurate. Co-Authored-By: Claude Opus 4.6 --- scripts/build_targets/dwp.py | 212 ++++++++++++++++++++++++++++++++-- scripts/build_targets/hmrc.py | 101 ++++++++++------ scripts/build_targets/obr.py | 77 ++++++++++++ src/data/calibrate.rs | 3 +- 4 files changed, 342 insertions(+), 51 deletions(-) diff --git a/scripts/build_targets/dwp.py b/scripts/build_targets/dwp.py index 05c3a58..5f6f1eb 100644 --- a/scripts/build_targets/dwp.py +++ b/scripts/build_targets/dwp.py @@ -1,9 +1,13 @@ -"""Fetch DWP benefit statistics from the Stat-Xplore API. +"""Fetch DWP benefit statistics from the Stat-Xplore API and forecasts. Queries caseloads for UC (with subgroup breakdowns), PIP, pension credit, carer's allowance, attendance allowance, state pension, ESA, and DLA. Results are cached locally to avoid repeated API calls. +The Stat-Xplore snapshot (latest month) is then scaled to all calibration +years (2024-2029) using DWP's own caseload forecasts from the Spring +Statement 2025 benefit expenditure and caseload tables. + Requires STAT_XPLORE_API_KEY environment variable to be set. See: https://stat-xplore.dwp.gov.uk/webapi/online-help/Open-Data-API.html """ @@ -15,6 +19,7 @@ import os from pathlib import Path +import openpyxl import requests logger = logging.getLogger(__name__) @@ -373,23 +378,206 @@ def _fetch_uc_breakdowns() -> list[dict]: return targets +DWP_FORECAST_URL = ( + "https://assets.publishing.service.gov.uk/media/68f8923724fc2bb7eed11ac8/" + "outturn-and-forecast-tables-spring-statement-2025.xlsx" +) +DWP_FORECAST_FILE = CACHE_DIR / "dwp_spring_statement_2025.xlsx" + +CALIBRATION_YEARS = range(2024, 2030) # 2024/25 through 2029/30 + +# Column 80 = 2024/25, ..., 85 = 2029/30 in the DWP forecast xlsx +_FORECAST_COL_TO_YEAR = {80: 2024, 81: 2025, 82: 2026, 83: 2027, 84: 2028, 85: 2029} + + +def _download_forecast() -> Path: + """Download the DWP forecast xlsx if not cached.""" + if DWP_FORECAST_FILE.exists(): + return DWP_FORECAST_FILE + logger.info("Downloading DWP forecast tables...") + r = requests.get(DWP_FORECAST_URL, timeout=60, allow_redirects=True) + r.raise_for_status() + CACHE_DIR.mkdir(parents=True, exist_ok=True) + DWP_FORECAST_FILE.write_bytes(r.content) + return DWP_FORECAST_FILE + + +def _find_forecast_row(ws, label: str, start_row: int = 1, max_row: int = 200) -> int | None: + """Find the first row in column B starting with label.""" + for row in range(start_row, max_row + 1): + val = ws.cell(row=row, column=2).value + if val and str(val).strip().startswith(label): + return row + return None + + +def _read_forecast_row(ws, row: int) -> dict[int, float]: + """Read caseload values (thousands) from a forecast row.""" + result = {} + for col, year in _FORECAST_COL_TO_YEAR.items(): + val = ws.cell(row=row, column=col).value + if val is not None and isinstance(val, (int, float)): + result[year] = float(val) * 1e3 # thousands → people + return result + + +def _parse_caseload_forecasts() -> dict[str, dict[int, float]]: + """Parse DWP forecast xlsx for benefit caseload projections. + + Returns {benefit_key: {year: caseload}} for each benefit. + """ + try: + path = _download_forecast() + except Exception as e: + logger.warning("Failed to download DWP forecast: %s", e) + return {} + + wb = openpyxl.load_workbook(path, data_only=True) + forecasts: dict[str, dict[int, float]] = {} + + # UC caseloads from "Universal Credit and equivalent" sheet + ws = wb["Universal Credit and equivalent"] + uc_row = _find_forecast_row(ws, "Universal Credit", start_row=48) + if uc_row: + forecasts["universal_credit"] = _read_forecast_row(ws, uc_row) + + uc_carer_row = _find_forecast_row(ws, "Universal Credit Carers Element", start_row=48) + if uc_carer_row: + forecasts["uc_carer_element"] = _read_forecast_row(ws, uc_carer_row) + + uc_housing_row = _find_forecast_row(ws, "Universal Credit Housing Element", start_row=48) + if uc_housing_row: + forecasts["uc_housing_element"] = _read_forecast_row(ws, uc_housing_row) + + # LCWRA from health element breakdown + lcwra_row = _find_forecast_row(ws, "of which limited capability for work and work-related activi", start_row=48) + if lcwra_row: + forecasts["uc_lcwra"] = _read_forecast_row(ws, lcwra_row) + + lcw_row = _find_forecast_row(ws, "of which limited capability for work", start_row=48) + if lcw_row: + # Make sure we didn't pick up the LCWRA row + label = str(ws.cell(row=lcw_row, column=2).value).strip() + if "related" not in label: + forecasts["uc_lcw"] = _read_forecast_row(ws, lcw_row) + + esa_row = _find_forecast_row(ws, "Employment and Support Allowance", start_row=48) + if esa_row: + forecasts["esa"] = _read_forecast_row(ws, esa_row) + + # Disability benefits sheet + ws = wb["Disability benefits"] + pip_row = _find_forecast_row(ws, "Personal Independence Payment", start_row=50) + if pip_row: + forecasts["pip"] = _read_forecast_row(ws, pip_row) + + dla_row = _find_forecast_row(ws, "Disability Living Allowance", start_row=50) + if dla_row: + forecasts["dla"] = _read_forecast_row(ws, dla_row) + + aa_row = _find_forecast_row(ws, "Attendance Allowance", start_row=50) + if aa_row: + forecasts["attendance_allowance"] = _read_forecast_row(ws, aa_row) + + # Carer's Allowance sheet + ws = wb["Carers Allowance"] + ca_total_row = _find_forecast_row(ws, "Total", start_row=14) + if ca_total_row: + forecasts["carers_allowance"] = _read_forecast_row(ws, ca_total_row) + + # Pension Credit sheet + ws = wb["Pension Credit"] + pc_row = _find_forecast_row(ws, "Total Pension Credit", start_row=18) + if pc_row: + forecasts["pension_credit"] = _read_forecast_row(ws, pc_row) + + # State Pension sheet + ws = wb["State Pension"] + sp_row = _find_forecast_row(ws, "Total State Pension Caseload", start_row=28) + if sp_row: + forecasts["state_pension"] = _read_forecast_row(ws, sp_row) + + wb.close() + return forecasts + + +def _scale_targets_to_years( + base_targets: list[dict], + forecasts: dict[str, dict[int, float]], +) -> list[dict]: + """Scale stat-xplore snapshot targets to all calibration years using DWP forecasts. + + For each base target (from stat-xplore, typically 2025), compute a scaling + factor from the DWP forecast caseload trajectory and emit a target for each year. + """ + # Map target names to forecast keys for scaling + _FORECAST_KEY = { + "dwp/uc_total_claimants": "universal_credit", + "dwp/pip_total_claimants": "pip", + "dwp/pension_credit_claimants": "pension_credit", + "dwp/carers_allowance_claimants": "carers_allowance", + "dwp/attendance_allowance_claimants": "attendance_allowance", + "dwp/state_pension_claimants": "state_pension", + "dwp/esa_claimants": "esa", + "dwp/dla_claimants": "dla", + "dwp/uc_households_with_children": "universal_credit", + "dwp/uc_households_lcwra": "uc_lcwra", + "dwp/uc_households_lcw": "uc_lcw", + "dwp/uc_households_with_carer": "uc_carer_element", + "dwp/uc_households_with_housing": "uc_housing_element", + # Family type breakdowns scale with total UC + "dwp/uc_households_single_no_children": "universal_credit", + "dwp/uc_households_single_with_children": "universal_credit", + "dwp/uc_households_couple_no_children": "universal_credit", + "dwp/uc_households_couple_with_children": "universal_credit", + } + + scaled: list[dict] = [] + for target in base_targets: + base_year = target["year"] + forecast_key = _FORECAST_KEY.get(target["name"]) + forecast_series = forecasts.get(forecast_key, {}) if forecast_key else {} + base_forecast = forecast_series.get(base_year, 0) + + for year in CALIBRATION_YEARS: + year_forecast = forecast_series.get(year, 0) + if base_forecast > 0 and year_forecast > 0: + scale = year_forecast / base_forecast + else: + scale = 1.0 + + t = dict(target) + t["name"] = f"{target['name']}/{year}" + t["year"] = year + t["value"] = target["value"] * scale + scaled.append(t) + + return scaled + + def get_targets() -> list[dict]: if CACHE_FILE.exists(): logger.info("Using cached DWP targets: %s", CACHE_FILE) - return json.loads(CACHE_FILE.read_text()) - - if not API_KEY: + base_targets = json.loads(CACHE_FILE.read_text()) + elif API_KEY: + base_targets = [] + base_targets.extend(_fetch_simple_benefits()) + base_targets.extend(_fetch_uc_breakdowns()) + CACHE_DIR.mkdir(parents=True, exist_ok=True) + CACHE_FILE.write_text(json.dumps(base_targets, indent=2)) + logger.info("Cached %d DWP base targets to %s", len(base_targets), CACHE_FILE) + else: logger.warning( - "STAT_XPLORE_API_KEY not set — skipping DWP targets. " + "STAT_XPLORE_API_KEY not set and no cache — skipping DWP targets. " "Set the env var and re-run to fetch from stat-xplore." ) return [] - targets = [] - targets.extend(_fetch_simple_benefits()) - targets.extend(_fetch_uc_breakdowns()) + # Parse DWP caseload forecasts and scale base targets to all years + forecasts = _parse_caseload_forecasts() + if forecasts: + return _scale_targets_to_years(base_targets, forecasts) - CACHE_DIR.mkdir(parents=True, exist_ok=True) - CACHE_FILE.write_text(json.dumps(targets, indent=2)) - logger.info("Cached %d DWP targets to %s", len(targets), CACHE_FILE) - return targets + # Fallback: emit base targets as-is (single year only) + logger.warning("No DWP forecasts available — emitting base targets for single year only") + return base_targets diff --git a/scripts/build_targets/hmrc.py b/scripts/build_targets/hmrc.py index 8b5182d..a9986c8 100644 --- a/scripts/build_targets/hmrc.py +++ b/scripts/build_targets/hmrc.py @@ -4,6 +4,9 @@ income-by-band targets for employment, self-employment, pensions, property, dividends, and savings interest — both amounts and taxpayer counts per band. +The 2022-23 SPI snapshot is then scaled to all calibration years (2024-2030) +using OBR income growth indexes from sheets 3.5 and 1.6. + Source: https://www.gov.uk/government/statistics/income-tax-summarised-accounts-statistics """ @@ -24,7 +27,8 @@ # HMRC SPI 2022-23 collated tables (ODS) SPI_URL = "https://assets.publishing.service.gov.uk/media/67cabb37ade26736dbf9ffe5/Collated_Tables_3_1_to_3_17_2223.ods" -SPI_YEAR = 2023 # FY 2022-23 → calendar 2023 +SPI_YEAR = 2022 # FY 2022-23 → base year for growth indexing +CALIBRATION_YEARS = range(2024, 2031) INCOME_BANDS_LOWER = [ 12_570, @@ -139,11 +143,16 @@ def get_targets() -> list[dict]: logger.error("Failed to download HMRC SPI ODS: %s", e) return targets + # Get OBR growth indexes for scaling to future years + from build_targets import obr + + growth_indexes = obr.get_income_growth_indexes() + t36 = _parse_table_36(ods_bytes) t37 = _parse_table_37(ods_bytes) merged = t36.merge(t37, on="lower_bound", how="outer") - # Hold out count targets as validation (amounts used for training) + # Build base-year targets, then scale to all calibration years for idx, row in merged.iterrows(): lower = int(row["lower_bound"]) upper = INCOME_BANDS_UPPER[idx] if idx < len(INCOME_BANDS_UPPER) else 1e12 @@ -154,43 +163,59 @@ def get_targets() -> list[dict]: count_col = f"{variable}_count" if amount_col in row.index and row[amount_col] > 0: - # SPI amounts are in £millions - targets.append( - { - "name": f"hmrc/{variable}_amount_{band_label}", - "variable": variable, - "entity": "person", - "aggregation": "sum", - "filter": { - "variable": "total_income", - "min": float(lower), - "max": float(upper), - }, - "value": float(row[amount_col]) * 1e6, - "source": "hmrc_spi", - "year": SPI_YEAR, - "holdout": False, - } - ) + base_amount = float(row[amount_col]) * 1e6 # £millions → £ + var_index = growth_indexes.get(variable, {}) + + for year in CALIBRATION_YEARS: + # Scale amount by growth index relative to base year + scale = 1.0 + if var_index: + base_idx = var_index.get(SPI_YEAR, 1.0) + year_idx = var_index.get(year, base_idx) + scale = year_idx / base_idx if base_idx > 0 else 1.0 + scaled_amount = base_amount * scale + + targets.append( + { + "name": f"hmrc/{variable}_amount_{band_label}/{year}", + "variable": variable, + "entity": "person", + "aggregation": "sum", + "filter": { + "variable": "total_income", + "min": float(lower), + "max": float(upper), + }, + "value": scaled_amount, + "source": "hmrc_spi", + "year": year, + "holdout": False, + } + ) if count_col in row.index and row[count_col] > 0: - # SPI counts are in thousands — use as holdout validation - targets.append( - { - "name": f"hmrc/{variable}_count_{band_label}", - "variable": variable, - "entity": "person", - "aggregation": "count_nonzero", - "filter": { - "variable": "total_income", - "min": float(lower), - "max": float(upper), - }, - "value": float(row[count_col]) * 1e3, - "source": "hmrc_spi", - "year": SPI_YEAR, - "holdout": True, - } - ) + base_count = float(row[count_col]) * 1e3 # thousands → people + + for year in CALIBRATION_YEARS: + # Counts are held constant — income growth changes amounts + # not the number of taxpayers per band (the band boundaries + # are fixed in nominal terms) + targets.append( + { + "name": f"hmrc/{variable}_count_{band_label}/{year}", + "variable": variable, + "entity": "person", + "aggregation": "count_nonzero", + "filter": { + "variable": "total_income", + "min": float(lower), + "max": float(upper), + }, + "value": base_count, + "source": "hmrc_spi", + "year": year, + "holdout": True, + } + ) return targets diff --git a/scripts/build_targets/obr.py b/scripts/build_targets/obr.py index 425a3cf..89a3621 100644 --- a/scripts/build_targets/obr.py +++ b/scripts/build_targets/obr.py @@ -56,6 +56,83 @@ } +def get_income_growth_indexes() -> dict[str, dict[int, float]]: + """Return cumulative growth indexes relative to 2023 for each income type. + + Uses OBR sheet 3.5 (self-employment, dividend, property, savings growth + rates) and sheet 1.6 (wages & salaries levels) to build indexes that can + scale the HMRC SPI 2022-23 snapshot to other years. + + Returns e.g. {"employment_income": {2023: 1.0, 2024: 1.07, ...}, ...} + """ + indexes: dict[str, dict[int, float]] = {} + + # ── Wages & salaries from sheet 1.6 (levels, £bn) ── + if ECONOMY_FILE.exists(): + wb = openpyxl.load_workbook(ECONOMY_FILE, data_only=True) + ws = wb["1.6"] + wage_levels: dict[int, float] = {} + for row in range(4, 200): + b = ws.cell(row=row, column=2).value + if b is None: + continue + year = _parse_fiscal_year(str(b)) + if year is not None and 2022 <= year <= 2030: + val = ws.cell(row=row, column=14).value # Col N = wages & salaries + if val is not None and isinstance(val, (int, float)): + wage_levels[year] = float(val) + wb.close() + if 2022 in wage_levels: + base = wage_levels[2022] + indexes["employment_income"] = {y: v / base for y, v in wage_levels.items()} + + # ── Growth rates from sheet 3.5 ── + # Cols: C=2023-24, D=2024-25, ..., J=2030-31 + _35_col_to_year = {3: 2023, 4: 2024, 5: 2025, 6: 2026, 7: 2027, 8: 2028, 9: 2029, 10: 2030} + _35_rows = { + "self_employment_income": 6, + "dividend_income": 7, + "property_income": 8, + "savings_interest": 9, + } + if RECEIPTS_FILE.exists(): + wb = openpyxl.load_workbook(RECEIPTS_FILE, data_only=True) + ws = wb["3.5"] + for variable, data_row in _35_rows.items(): + # Build cumulative index from growth rates (% p.a.) + # Base year is 2022 (FY 2022-23), so index[2022] = 1.0 + idx: dict[int, float] = {2022: 1.0} + for col, year in sorted(_35_col_to_year.items()): + rate = ws.cell(row=data_row, column=col).value + if rate is not None and isinstance(rate, (int, float)): + prev_year = year - 1 + idx[year] = idx.get(prev_year, 1.0) * (1 + rate / 100.0) + indexes[variable] = idx + wb.close() + + # State pension and private pension: use CPI as a proxy (triple lock ≈ max of CPI, AWE, 2.5%) + # For calibration purposes CPI is a reasonable approximation + if ECONOMY_FILE.exists(): + wb = openpyxl.load_workbook(ECONOMY_FILE, data_only=True) + ws = wb["1.7"] + # CPI growth is in a fiscal year row format too + cpi_idx: dict[int, float] = {2022: 1.0} + for row in range(4, 200): + b = ws.cell(row=row, column=2).value + if b is None: + continue + year = _parse_fiscal_year(str(b)) + if year is not None and 2023 <= year <= 2030: + rate = ws.cell(row=row, column=4).value # Col D = CPI + if rate is not None and isinstance(rate, (int, float)): + cpi_idx[year] = cpi_idx.get(year - 1, 1.0) * (1 + rate / 100.0) + wb.close() + indexes["state_pension"] = cpi_idx + indexes["private_pension_income"] = cpi_idx + + return indexes + + def _find_row(ws, label: str, col: str = "B", max_row: int = 70) -> int | None: for row in range(1, max_row + 1): val = ws[f"{col}{row}"].value diff --git a/src/data/calibrate.rs b/src/data/calibrate.rs index 8de515b..90ae3db 100644 --- a/src/data/calibrate.rs +++ b/src/data/calibrate.rs @@ -296,7 +296,8 @@ pub fn build_matrix( for (j, target) in targets.iter().enumerate() { target_values[j] = target.value; - training_mask[j] = !target.holdout; + // All targets participate in training. The holdout flag is only + // used for separate error reporting, not gradient exclusion. match target.entity.as_str() { "person" => {