Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/structural-reform-hooks.added
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `StructuralReform(pre=..., post=...)` to the Python wrapper, enabling reforms that can't be expressed as parameter overlays. Both hooks take `(year, persons, benunits, households)` and return the modified triple, so multi-year reforms can branch by year.
2 changes: 2 additions & 0 deletions interfaces/python/policyengine_uk_compiled/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@ def print_guide():
BENUNIT_DEFAULTS,
HOUSEHOLD_DEFAULTS,
)
from policyengine_uk_compiled.structural import StructuralReform
from policyengine_uk_compiled.data import download_all, ensure_year, ensure_dataset, DATASETS

__all__ = [
"Simulation",
"StructuralReform",
"PERSON_DEFAULTS",
"BENUNIT_DEFAULTS",
"HOUSEHOLD_DEFAULTS",
Expand Down
171 changes: 158 additions & 13 deletions interfaces/python/policyengine_uk_compiled/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
HAS_PANDAS = False

from policyengine_uk_compiled.models import MicrodataResult, Parameters, SimulationResult, HbaiIncomes, PovertyHeadcounts
from policyengine_uk_compiled.structural import StructuralReform, aggregate_microdata

# The binary and parameters/ dir are bundled inside the package at build time.
_PKG_DIR = Path(__file__).resolve().parent
Expand Down Expand Up @@ -95,6 +96,30 @@ def _build_stdin_payload(persons_csv: str, benunits_csv: str, households_csv: st
)


def _parse_stdin_payload(payload: str):
"""Parse a stdin protocol payload back into three DataFrames."""
import io
import pandas as pd
sections: dict[str, str] = {}
current_name = None
current_lines: list[str] = []
for line in payload.split("\n"):
if line.startswith("===") and line.endswith("==="):
if current_name is not None:
sections[current_name] = "\n".join(current_lines)
current_name = line.strip("=").lower()
current_lines = []
else:
current_lines.append(line)
if current_name is not None:
sections[current_name] = "\n".join(current_lines)
return (
pd.read_csv(io.StringIO(sections.get("persons", ""))),
pd.read_csv(io.StringIO(sections.get("benunits", ""))),
pd.read_csv(io.StringIO(sections.get("households", ""))),
)


def _parse_microdata_stdout(raw: str) -> MicrodataResult:
"""Parse the concatenated CSV protocol output into a MicrodataResult."""
sections = {}
Expand Down Expand Up @@ -308,9 +333,29 @@ class Simulation:
sim = Simulation(year=2025, data_dir="data/frs/2023")
result = sim.run()

# With a reform
# With a parametric reform
reform = Parameters(income_tax=IncomeTaxParams(personal_allowance=20000))
result = sim.run(policy=reform)

# With a structural reform (pre-hook: mutate inputs before simulation)
from policyengine_uk_compiled import StructuralReform

def cap_wages(year, persons, benunits, households):
persons["employment_income"] = persons["employment_income"].clip(upper=100_000)
return persons, benunits, households

result = sim.run(structural=StructuralReform(pre=cap_wages))

# With a structural reform (post-hook: adjust outputs after simulation)
def add_ubi(year, persons, benunits, households):
ubi = 50 * 52 # £50/wk per adult
adults = persons["age"] >= 18
adult_counts = persons[adults].groupby("household_id").size()
households["reform_net_income"] += households["household_id"].map(adult_counts).fillna(0) * ubi
households["reform_total_tax"] = households["baseline_total_tax"] # unchanged
return persons, benunits, households

result = sim.run(structural=StructuralReform(post=add_ubi))
"""

def __init__(
Expand Down Expand Up @@ -340,10 +385,17 @@ def __init__(
self._frs_raw = frs_raw
self._dataset = dataset
self._persons_only = dataset in ("spi",)
# Store DataFrames when passed directly so pre-hooks can use them
self._persons_df = None
self._benunits_df = None
self._households_df = None

if persons is not None and benunits is not None and households is not None:
# DataFrame or CSV string mode
if HAS_PANDAS and hasattr(persons, "to_csv"):
self._persons_df = persons
self._benunits_df = benunits
self._households_df = households
persons_csv = _df_to_csv(persons)
benunits_csv = _df_to_csv(benunits)
households_csv = _df_to_csv(households)
Expand All @@ -361,10 +413,70 @@ def __init__(
elif data_dir is not None:
self._data_dir = str(data_dir)

def _build_cmd(self, policy: Optional[Parameters] = None, extra_args: Optional[list[str]] = None) -> list[str]:
def _apply_pre_hook(self, structural: Optional[StructuralReform]) -> Optional[str]:
"""Apply the pre-hook if present and return a stdin payload string.

For file-based data sources, loads the CSVs into DataFrames first so
the hook can mutate them, then re-serialises to the stdin protocol.
Returns None if there is no pre-hook (caller uses the original payload).
"""
if structural is None or structural.pre is None:
return self._stdin_payload # unchanged

if not HAS_PANDAS:
raise ImportError("pandas is required for structural pre-hooks")

import io
import pandas as pd

# Obtain DataFrames — either already stored or loaded from files
if self._persons_df is not None:
persons = self._persons_df.copy()
benunits = self._benunits_df.copy()
households = self._households_df.copy()
elif self._stdin_payload is not None:
# Parse the existing stdin payload back into DataFrames
parsed = _parse_stdin_payload(self._stdin_payload)
persons = parsed[0]
benunits = parsed[1]
households = parsed[2]
else:
# File-based source: load the CSVs from disk
data_path = self._resolve_data_path()
import os
year_dir = os.path.join(data_path, str(self.year))
if not os.path.isdir(year_dir):
# Try direct path (data_dir may already include year)
year_dir = data_path
persons = pd.read_csv(os.path.join(year_dir, "persons.csv"))
benunits = pd.read_csv(os.path.join(year_dir, "benunits.csv"))
households = pd.read_csv(os.path.join(year_dir, "households.csv"))

persons, benunits, households = structural.pre(
self.year, persons, benunits, households
)
return _build_stdin_payload(
_df_to_csv(persons), _df_to_csv(benunits), _df_to_csv(households)
)

def _resolve_data_path(self) -> str:
"""Return the base data directory for the current configuration."""
if self._data_dir:
return self._data_dir
if self._clean_frs_base:
return self._clean_frs_base
if self._clean_frs:
return self._clean_frs
if self._dataset is not None:
from policyengine_uk_compiled.data import ensure_dataset
return ensure_dataset(self._dataset, self.year)
from policyengine_uk_compiled.data import ensure_frs
return ensure_frs(self.year)

def _build_cmd(self, policy: Optional[Parameters] = None, extra_args: Optional[list[str]] = None, stdin_override: bool = False) -> list[str]:
cmd = [self.binary_path, "--year", str(self.year)]

if self._stdin_payload is not None:
if self._stdin_payload is not None or stdin_override:
cmd.append("--stdin-data")
elif self._data_dir:
cmd += ["--data", self._data_dir]
Expand Down Expand Up @@ -397,22 +509,38 @@ def _build_cmd(self, policy: Optional[Parameters] = None, extra_args: Optional[l

return cmd

def run(self, policy: Optional[Parameters] = None, timeout: int = 120) -> SimulationResult:
def run(
self,
policy: Optional[Parameters] = None,
structural: Optional[StructuralReform] = None,
timeout: int = 120,
) -> SimulationResult:
"""Run the simulation and return typed results.

Args:
policy: Reform parameters (overlay on baseline). None = baseline only.
policy: Parametric reform overlay (changes parameter values).
structural: Structural reform with optional pre/post hooks.
pre(year, persons, benunits, households) mutates inputs before
the binary runs. post(year, persons, benunits, households)
mutates microdata outputs; aggregation is then done in Python.
timeout: Maximum seconds to wait for the binary.

Returns:
SimulationResult with budgetary impact, program breakdown, decile impacts, etc.
For persons-only datasets (e.g. SPI), household/benefit fields are zeroed.
"""
cmd = self._build_cmd(policy, extra_args=["--output", "json"])
# If a post-hook is present we must go through microdata and re-aggregate
if structural is not None and structural.post is not None:
microdata = self.run_microdata(policy=policy, structural=structural, timeout=timeout)
return aggregate_microdata(
microdata.persons, microdata.benunits, microdata.households, self.year
)

stdin_payload = self._apply_pre_hook(structural)
cmd = self._build_cmd(policy, extra_args=["--output", "json"], stdin_override=stdin_payload is not None)
cwd = _find_cwd(self.binary_path)
result = subprocess.run(
cmd,
input=self._stdin_payload,
input=stdin_payload,
capture_output=True,
text=True,
timeout=timeout,
Expand All @@ -428,16 +556,24 @@ def run(self, policy: Optional[Parameters] = None, timeout: int = 120) -> Simula
return SimulationResult(**data)

def run_microdata(
self, policy: Optional[Parameters] = None, timeout: int = 120
self,
policy: Optional[Parameters] = None,
structural: Optional[StructuralReform] = None,
timeout: int = 120,
) -> MicrodataResult:
"""Run the simulation and return per-entity microdata as DataFrames."""
"""Run the simulation and return per-entity microdata as DataFrames.

If a structural post-hook is provided it is applied to the DataFrames
after the binary produces its output.
"""
if not HAS_PANDAS:
raise ImportError("pandas is required for run_microdata")
cmd = self._build_cmd(policy, extra_args=["--output-microdata-stdout"])
stdin_payload = self._apply_pre_hook(structural)
cmd = self._build_cmd(policy, extra_args=["--output-microdata-stdout"], stdin_override=stdin_payload is not None)
cwd = _find_cwd(self.binary_path)
result = subprocess.run(
cmd,
input=self._stdin_payload,
input=stdin_payload,
capture_output=True,
text=True,
timeout=timeout,
Expand All @@ -447,7 +583,16 @@ def run_microdata(
raise RuntimeError(
f"Simulation failed (exit {result.returncode}):\n{result.stderr}"
)
return _parse_microdata_stdout(result.stdout)
microdata = _parse_microdata_stdout(result.stdout)
if structural is not None and structural.post is not None:
persons, benunits, households = structural.post(
self.year,
microdata.persons.copy(),
microdata.benunits.copy(),
microdata.households.copy(),
)
return MicrodataResult(persons=persons, benunits=benunits, households=households)
return microdata

def get_baseline_params(self, timeout: int = 10) -> dict:
"""Export the baseline parameters for the configured year as a dict."""
Expand Down
Loading
Loading