Skip to content

Commit 9303ae8

Browse files
committed
Improve parity harness diagnostics
1 parent 9d3c109 commit 9303ae8

1 file changed

Lines changed: 179 additions & 39 deletions

File tree

scripts/validate_against_policyengine_uk.py

Lines changed: 179 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ class Metric:
2525
tolerance: float = 2.0
2626

2727

28+
@dataclass(frozen=True)
29+
class CaseMetric:
30+
name: str
31+
reducer: str = "sum"
32+
33+
2834
METRICS: dict[str, Metric] = {
2935
"income_tax": Metric("person_results", "income_tax", "income_tax"),
3036
"national_insurance": Metric(
@@ -60,7 +66,10 @@ def build_case(
6066
benunit_flags: dict[str, Any] | None = None,
6167
housing_costs: float = 0.0,
6268
country: str | None = None,
63-
metrics: list[str],
69+
metrics: list[str | CaseMetric],
70+
rust_reform: dict[str, Any] | None = None,
71+
policyengine_scenario: dict[str, Any] | None = None,
72+
known_failure: str | None = None,
6473
) -> dict[str, Any]:
6574
benunit_flags = benunit_flags or {}
6675

@@ -179,17 +188,28 @@ def build_case(
179188
return {
180189
"name": name,
181190
"year": year,
182-
"metrics": metrics,
191+
"known_failure": known_failure,
192+
"metrics": [
193+
metric if isinstance(metric, CaseMetric) else CaseMetric(metric)
194+
for metric in metrics
195+
],
196+
"entity_labels": {
197+
"person_results": person_names,
198+
"benunit_results": ["benunit"],
199+
"household_results": ["household"],
200+
},
183201
"rust_input": {
184202
"people": rust_people,
185203
"benunits": [rust_benunit],
186204
"households": [rust_household],
187205
},
206+
"rust_reform": rust_reform,
188207
"policyengine_situation": {
189208
"people": pe_people,
190209
"benunits": {"benunit": pe_benunit},
191210
"households": {"household": pe_household},
192211
},
212+
"policyengine_scenario": policyengine_scenario,
193213
}
194214

195215

@@ -217,7 +237,10 @@ def build_case(
217237
"would_claim_marriage_allowance": True,
218238
},
219239
],
220-
metrics=["income_tax", "national_insurance"],
240+
metrics=[
241+
CaseMetric("income_tax", reducer="sequence"),
242+
CaseMetric("national_insurance", reducer="sequence"),
243+
],
221244
),
222245
build_case(
223246
name="universal_credit_single_2025",
@@ -230,6 +253,22 @@ def build_case(
230253
},
231254
metrics=["universal_credit"],
232255
),
256+
build_case(
257+
name="income_tax_basic_rate_reform_2025",
258+
year=2025,
259+
people=[{"name": "person", "age": 30, "employment_income": 30_000}],
260+
rust_reform={
261+
"income_tax": {
262+
"uk_brackets": [
263+
{"rate": 0.25, "threshold": 0.0},
264+
{"rate": 0.40, "threshold": 37700.0},
265+
{"rate": 0.45, "threshold": 125140.0},
266+
]
267+
}
268+
},
269+
policyengine_scenario={"gov.hmrc.income_tax.rates.uk[0].rate": 0.25},
270+
metrics=["income_tax"],
271+
),
233272
build_case(
234273
name="child_benefit_two_children_2025",
235274
year=2025,
@@ -296,6 +335,22 @@ def build_case(
296335
],
297336
metrics=["state_pension"],
298337
),
338+
build_case(
339+
name="pension_credit_single_2025",
340+
year=2025,
341+
people=[
342+
{"name": "adult", "age": 75, "state_pension_reported": 5_000},
343+
],
344+
benunit_flags={
345+
"reported_pc": True,
346+
"would_claim_pc": True,
347+
},
348+
metrics=["pension_credit"],
349+
known_failure=(
350+
"Rust pension credit remains above policyengine-uk for this low-income "
351+
"single pensioner scenario."
352+
),
353+
),
299354
]
300355

301356

@@ -305,6 +360,18 @@ def _add_policyengine_uk_to_path(explicit_path: str | None) -> None:
305360
sys.path.insert(0, candidate)
306361

307362

363+
def _sequence_labels(case: dict[str, Any], metric: Metric, length: int) -> list[str]:
364+
labels = case["entity_labels"].get(metric.rust_collection, [])
365+
if len(labels) == length:
366+
return labels
367+
return [f"item_{index}" for index in range(length)]
368+
369+
370+
def _format_sequence(values: list[float], labels: list[str]) -> str:
371+
pairs = [f"{label}={value:.2f}" for label, value in zip(labels, values)]
372+
return "[" + ", ".join(pairs) + "]"
373+
374+
308375
def run_rust_case(case: dict[str, Any], rust_binary: Path) -> dict[str, float]:
309376
with tempfile.NamedTemporaryFile(
310377
mode="w", suffix=".json", delete=False, encoding="utf-8"
@@ -313,16 +380,20 @@ def run_rust_case(case: dict[str, Any], rust_binary: Path) -> dict[str, float]:
313380
scenario_path = Path(handle.name)
314381

315382
try:
383+
cmd = [
384+
str(rust_binary),
385+
"--year",
386+
str(case["year"]),
387+
"--scenario-json",
388+
str(scenario_path),
389+
"--output",
390+
"json",
391+
]
392+
if case.get("rust_reform") is not None:
393+
cmd.extend(["--reform-json", json.dumps(case["rust_reform"])])
394+
316395
result = subprocess.run(
317-
[
318-
str(rust_binary),
319-
"--year",
320-
str(case["year"]),
321-
"--scenario-json",
322-
str(scenario_path),
323-
"--output",
324-
"json",
325-
],
396+
cmd,
326397
check=True,
327398
capture_output=True,
328399
text=True,
@@ -332,21 +403,35 @@ def run_rust_case(case: dict[str, Any], rust_binary: Path) -> dict[str, float]:
332403
scenario_path.unlink(missing_ok=True)
333404

334405
payload = json.loads(result.stdout)
335-
values: dict[str, float] = {}
336-
for metric_name in case["metrics"]:
337-
metric = METRICS[metric_name]
338-
values[metric_name] = float(
339-
sum(item[metric.rust_field] for item in payload[metric.rust_collection])
340-
)
406+
values: dict[str, float | list[float]] = {}
407+
for case_metric in case["metrics"]:
408+
metric = METRICS[case_metric.name]
409+
raw_values = [
410+
float(item[metric.rust_field]) for item in payload[metric.rust_collection]
411+
]
412+
if case_metric.reducer == "sequence":
413+
values[case_metric.name] = raw_values
414+
else:
415+
values[case_metric.name] = float(sum(raw_values))
341416
return values
342417

343418

344-
def run_policyengine_case(case: dict[str, Any], simulation_cls: Any) -> dict[str, float]:
345-
sim = simulation_cls(situation=case["policyengine_situation"])
346-
values: dict[str, float] = {}
347-
for metric_name in case["metrics"]:
348-
metric = METRICS[metric_name]
349-
values[metric_name] = float(sim.calculate(metric.policyengine_variable, case["year"]).sum())
419+
def run_policyengine_case(
420+
case: dict[str, Any], simulation_cls: Any, scenario_cls: Any
421+
) -> dict[str, float | list[float]]:
422+
scenario = None
423+
if case.get("policyengine_scenario") is not None:
424+
scenario = scenario_cls(parameter_changes=case["policyengine_scenario"])
425+
426+
sim = simulation_cls(situation=case["policyengine_situation"], scenario=scenario)
427+
values: dict[str, float | list[float]] = {}
428+
for case_metric in case["metrics"]:
429+
metric = METRICS[case_metric.name]
430+
result = sim.calculate(metric.policyengine_variable, case["year"])
431+
if case_metric.reducer == "sequence":
432+
values[case_metric.name] = [float(value) for value in result]
433+
else:
434+
values[case_metric.name] = float(result.sum())
350435
return values
351436

352437

@@ -368,6 +453,11 @@ def parse_args() -> argparse.Namespace:
368453
dest="cases",
369454
help="Run only the named validation case. Can be supplied multiple times.",
370455
)
456+
parser.add_argument(
457+
"--list-cases",
458+
action="store_true",
459+
help="List available validation cases and exit.",
460+
)
371461
return parser.parse_args()
372462

373463

@@ -377,8 +467,14 @@ def main() -> int:
377467
print(f"Rust binary not found at {args.rust_binary}", file=sys.stderr)
378468
return 1
379469

470+
if args.list_cases:
471+
for case in CASES:
472+
suffix = " [expected-failure]" if case.get("known_failure") else ""
473+
print(f"{case['name']}{suffix}")
474+
return 0
475+
380476
_add_policyengine_uk_to_path(args.policyengine_uk_path)
381-
from policyengine_uk import Simulation # pylint: disable=import-error
477+
from policyengine_uk import Scenario, Simulation # pylint: disable=import-error
382478

383479
selected_cases = CASES
384480
if args.cases:
@@ -390,30 +486,74 @@ def main() -> int:
390486
return 1
391487

392488
failures: list[str] = []
489+
expected_failures: list[str] = []
490+
unexpected_passes: list[str] = []
393491
for case in selected_cases:
394492
rust_values = run_rust_case(case, args.rust_binary)
395-
policyengine_values = run_policyengine_case(case, Simulation)
493+
policyengine_values = run_policyengine_case(case, Simulation, Scenario)
396494
print(f"[{case['name']}]")
397-
for metric_name in case["metrics"]:
398-
metric = METRICS[metric_name]
399-
rust_value = rust_values[metric_name]
400-
policy_value = policyengine_values[metric_name]
401-
diff = abs(rust_value - policy_value)
402-
print(
403-
f" {metric_name}: rust={rust_value:.2f} policyengine={policy_value:.2f} diff={diff:.2f}"
404-
)
405-
if diff > metric.tolerance:
406-
failures.append(
407-
f"{case['name']} {metric_name} diff {diff:.2f} exceeds tolerance {metric.tolerance:.2f}"
495+
case_failures: list[str] = []
496+
for case_metric in case["metrics"]:
497+
metric = METRICS[case_metric.name]
498+
rust_value = rust_values[case_metric.name]
499+
policy_value = policyengine_values[case_metric.name]
500+
if isinstance(rust_value, list):
501+
labels = _sequence_labels(case, metric, len(rust_value))
502+
if len(rust_value) != len(policy_value):
503+
case_failures.append(
504+
f"{case['name']} {case_metric.name} length mismatch {len(rust_value)} != {len(policy_value)}"
505+
)
506+
print(
507+
f" {case_metric.name}: rust={rust_value} policyengine={policy_value} diff=length-mismatch"
508+
)
509+
continue
510+
diffs = [abs(a - b) for a, b in zip(rust_value, policy_value)]
511+
print(
512+
" "
513+
f"{case_metric.name}: "
514+
f"rust={_format_sequence(rust_value, labels)} "
515+
f"policyengine={_format_sequence(policy_value, labels)} "
516+
f"diffs={[round(diff, 2) for diff in diffs]}"
408517
)
518+
for index, diff in enumerate(diffs):
519+
if diff > metric.tolerance:
520+
case_failures.append(
521+
f"{case['name']} {case_metric.name}[{labels[index]}] diff {diff:.2f} exceeds tolerance {metric.tolerance:.2f}"
522+
)
523+
else:
524+
diff = abs(rust_value - policy_value)
525+
print(
526+
f" {case_metric.name}: rust={rust_value:.2f} policyengine={policy_value:.2f} diff={diff:.2f}"
527+
)
528+
if diff > metric.tolerance:
529+
case_failures.append(
530+
f"{case['name']} {case_metric.name} diff {diff:.2f} exceeds tolerance {metric.tolerance:.2f}"
531+
)
532+
533+
if case.get("known_failure"):
534+
if case_failures:
535+
expected_failures.append(case["name"])
536+
print(f" expected failure: {case['known_failure']}")
537+
else:
538+
unexpected_passes.append(
539+
f"{case['name']} unexpectedly passed; review and remove known_failure"
540+
)
541+
else:
542+
failures.extend(case_failures)
409543

410-
if failures:
544+
if failures or unexpected_passes:
411545
print("\nValidation failed:", file=sys.stderr)
412546
for failure in failures:
413547
print(f" - {failure}", file=sys.stderr)
548+
for unexpected_pass in unexpected_passes:
549+
print(f" - {unexpected_pass}", file=sys.stderr)
414550
return 1
415551

416-
print(f"\nValidated {len(selected_cases)} case(s) against policyengine-uk.")
552+
print(
553+
"\nValidated "
554+
f"{len(selected_cases) - len(expected_failures)} passing case(s) "
555+
f"against policyengine-uk with {len(expected_failures)} expected failure(s)."
556+
)
417557
return 0
418558

419559

0 commit comments

Comments
 (0)