@@ -25,6 +25,12 @@ class Metric:
2525 tolerance : float = 2.0
2626
2727
28+ @dataclass (frozen = True )
29+ class CaseMetric :
30+ name : str
31+ reducer : str = "sum"
32+
33+
2834METRICS : dict [str , Metric ] = {
2935 "income_tax" : Metric ("person_results" , "income_tax" , "income_tax" ),
3036 "national_insurance" : Metric (
@@ -60,7 +66,10 @@ def build_case(
6066 benunit_flags : dict [str , Any ] | None = None ,
6167 housing_costs : float = 0.0 ,
6268 country : str | None = None ,
63- metrics : list [str ],
69+ metrics : list [str | CaseMetric ],
70+ rust_reform : dict [str , Any ] | None = None ,
71+ policyengine_scenario : dict [str , Any ] | None = None ,
72+ known_failure : str | None = None ,
6473) -> dict [str , Any ]:
6574 benunit_flags = benunit_flags or {}
6675
@@ -179,17 +188,28 @@ def build_case(
179188 return {
180189 "name" : name ,
181190 "year" : year ,
182- "metrics" : metrics ,
191+ "known_failure" : known_failure ,
192+ "metrics" : [
193+ metric if isinstance (metric , CaseMetric ) else CaseMetric (metric )
194+ for metric in metrics
195+ ],
196+ "entity_labels" : {
197+ "person_results" : person_names ,
198+ "benunit_results" : ["benunit" ],
199+ "household_results" : ["household" ],
200+ },
183201 "rust_input" : {
184202 "people" : rust_people ,
185203 "benunits" : [rust_benunit ],
186204 "households" : [rust_household ],
187205 },
206+ "rust_reform" : rust_reform ,
188207 "policyengine_situation" : {
189208 "people" : pe_people ,
190209 "benunits" : {"benunit" : pe_benunit },
191210 "households" : {"household" : pe_household },
192211 },
212+ "policyengine_scenario" : policyengine_scenario ,
193213 }
194214
195215
@@ -217,7 +237,10 @@ def build_case(
217237 "would_claim_marriage_allowance" : True ,
218238 },
219239 ],
220- metrics = ["income_tax" , "national_insurance" ],
240+ metrics = [
241+ CaseMetric ("income_tax" , reducer = "sequence" ),
242+ CaseMetric ("national_insurance" , reducer = "sequence" ),
243+ ],
221244 ),
222245 build_case (
223246 name = "universal_credit_single_2025" ,
@@ -230,6 +253,22 @@ def build_case(
230253 },
231254 metrics = ["universal_credit" ],
232255 ),
256+ build_case (
257+ name = "income_tax_basic_rate_reform_2025" ,
258+ year = 2025 ,
259+ people = [{"name" : "person" , "age" : 30 , "employment_income" : 30_000 }],
260+ rust_reform = {
261+ "income_tax" : {
262+ "uk_brackets" : [
263+ {"rate" : 0.25 , "threshold" : 0.0 },
264+ {"rate" : 0.40 , "threshold" : 37700.0 },
265+ {"rate" : 0.45 , "threshold" : 125140.0 },
266+ ]
267+ }
268+ },
269+ policyengine_scenario = {"gov.hmrc.income_tax.rates.uk[0].rate" : 0.25 },
270+ metrics = ["income_tax" ],
271+ ),
233272 build_case (
234273 name = "child_benefit_two_children_2025" ,
235274 year = 2025 ,
@@ -296,6 +335,22 @@ def build_case(
296335 ],
297336 metrics = ["state_pension" ],
298337 ),
338+ build_case (
339+ name = "pension_credit_single_2025" ,
340+ year = 2025 ,
341+ people = [
342+ {"name" : "adult" , "age" : 75 , "state_pension_reported" : 5_000 },
343+ ],
344+ benunit_flags = {
345+ "reported_pc" : True ,
346+ "would_claim_pc" : True ,
347+ },
348+ metrics = ["pension_credit" ],
349+ known_failure = (
350+ "Rust pension credit remains above policyengine-uk for this low-income "
351+ "single pensioner scenario."
352+ ),
353+ ),
299354]
300355
301356
@@ -305,6 +360,18 @@ def _add_policyengine_uk_to_path(explicit_path: str | None) -> None:
305360 sys .path .insert (0 , candidate )
306361
307362
363+ def _sequence_labels (case : dict [str , Any ], metric : Metric , length : int ) -> list [str ]:
364+ labels = case ["entity_labels" ].get (metric .rust_collection , [])
365+ if len (labels ) == length :
366+ return labels
367+ return [f"item_{ index } " for index in range (length )]
368+
369+
370+ def _format_sequence (values : list [float ], labels : list [str ]) -> str :
371+ pairs = [f"{ label } ={ value :.2f} " for label , value in zip (labels , values )]
372+ return "[" + ", " .join (pairs ) + "]"
373+
374+
308375def run_rust_case (case : dict [str , Any ], rust_binary : Path ) -> dict [str , float ]:
309376 with tempfile .NamedTemporaryFile (
310377 mode = "w" , suffix = ".json" , delete = False , encoding = "utf-8"
@@ -313,16 +380,20 @@ def run_rust_case(case: dict[str, Any], rust_binary: Path) -> dict[str, float]:
313380 scenario_path = Path (handle .name )
314381
315382 try :
383+ cmd = [
384+ str (rust_binary ),
385+ "--year" ,
386+ str (case ["year" ]),
387+ "--scenario-json" ,
388+ str (scenario_path ),
389+ "--output" ,
390+ "json" ,
391+ ]
392+ if case .get ("rust_reform" ) is not None :
393+ cmd .extend (["--reform-json" , json .dumps (case ["rust_reform" ])])
394+
316395 result = subprocess .run (
317- [
318- str (rust_binary ),
319- "--year" ,
320- str (case ["year" ]),
321- "--scenario-json" ,
322- str (scenario_path ),
323- "--output" ,
324- "json" ,
325- ],
396+ cmd ,
326397 check = True ,
327398 capture_output = True ,
328399 text = True ,
@@ -332,21 +403,35 @@ def run_rust_case(case: dict[str, Any], rust_binary: Path) -> dict[str, float]:
332403 scenario_path .unlink (missing_ok = True )
333404
334405 payload = json .loads (result .stdout )
335- values : dict [str , float ] = {}
336- for metric_name in case ["metrics" ]:
337- metric = METRICS [metric_name ]
338- values [metric_name ] = float (
339- sum (item [metric .rust_field ] for item in payload [metric .rust_collection ])
340- )
406+ values : dict [str , float | list [float ]] = {}
407+ for case_metric in case ["metrics" ]:
408+ metric = METRICS [case_metric .name ]
409+ raw_values = [
410+ float (item [metric .rust_field ]) for item in payload [metric .rust_collection ]
411+ ]
412+ if case_metric .reducer == "sequence" :
413+ values [case_metric .name ] = raw_values
414+ else :
415+ values [case_metric .name ] = float (sum (raw_values ))
341416 return values
342417
343418
344- def run_policyengine_case (case : dict [str , Any ], simulation_cls : Any ) -> dict [str , float ]:
345- sim = simulation_cls (situation = case ["policyengine_situation" ])
346- values : dict [str , float ] = {}
347- for metric_name in case ["metrics" ]:
348- metric = METRICS [metric_name ]
349- values [metric_name ] = float (sim .calculate (metric .policyengine_variable , case ["year" ]).sum ())
419+ def run_policyengine_case (
420+ case : dict [str , Any ], simulation_cls : Any , scenario_cls : Any
421+ ) -> dict [str , float | list [float ]]:
422+ scenario = None
423+ if case .get ("policyengine_scenario" ) is not None :
424+ scenario = scenario_cls (parameter_changes = case ["policyengine_scenario" ])
425+
426+ sim = simulation_cls (situation = case ["policyengine_situation" ], scenario = scenario )
427+ values : dict [str , float | list [float ]] = {}
428+ for case_metric in case ["metrics" ]:
429+ metric = METRICS [case_metric .name ]
430+ result = sim .calculate (metric .policyengine_variable , case ["year" ])
431+ if case_metric .reducer == "sequence" :
432+ values [case_metric .name ] = [float (value ) for value in result ]
433+ else :
434+ values [case_metric .name ] = float (result .sum ())
350435 return values
351436
352437
@@ -368,6 +453,11 @@ def parse_args() -> argparse.Namespace:
368453 dest = "cases" ,
369454 help = "Run only the named validation case. Can be supplied multiple times." ,
370455 )
456+ parser .add_argument (
457+ "--list-cases" ,
458+ action = "store_true" ,
459+ help = "List available validation cases and exit." ,
460+ )
371461 return parser .parse_args ()
372462
373463
@@ -377,8 +467,14 @@ def main() -> int:
377467 print (f"Rust binary not found at { args .rust_binary } " , file = sys .stderr )
378468 return 1
379469
470+ if args .list_cases :
471+ for case in CASES :
472+ suffix = " [expected-failure]" if case .get ("known_failure" ) else ""
473+ print (f"{ case ['name' ]} { suffix } " )
474+ return 0
475+
380476 _add_policyengine_uk_to_path (args .policyengine_uk_path )
381- from policyengine_uk import Simulation # pylint: disable=import-error
477+ from policyengine_uk import Scenario , Simulation # pylint: disable=import-error
382478
383479 selected_cases = CASES
384480 if args .cases :
@@ -390,30 +486,74 @@ def main() -> int:
390486 return 1
391487
392488 failures : list [str ] = []
489+ expected_failures : list [str ] = []
490+ unexpected_passes : list [str ] = []
393491 for case in selected_cases :
394492 rust_values = run_rust_case (case , args .rust_binary )
395- policyengine_values = run_policyengine_case (case , Simulation )
493+ policyengine_values = run_policyengine_case (case , Simulation , Scenario )
396494 print (f"[{ case ['name' ]} ]" )
397- for metric_name in case ["metrics" ]:
398- metric = METRICS [metric_name ]
399- rust_value = rust_values [metric_name ]
400- policy_value = policyengine_values [metric_name ]
401- diff = abs (rust_value - policy_value )
402- print (
403- f" { metric_name } : rust={ rust_value :.2f} policyengine={ policy_value :.2f} diff={ diff :.2f} "
404- )
405- if diff > metric .tolerance :
406- failures .append (
407- f"{ case ['name' ]} { metric_name } diff { diff :.2f} exceeds tolerance { metric .tolerance :.2f} "
495+ case_failures : list [str ] = []
496+ for case_metric in case ["metrics" ]:
497+ metric = METRICS [case_metric .name ]
498+ rust_value = rust_values [case_metric .name ]
499+ policy_value = policyengine_values [case_metric .name ]
500+ if isinstance (rust_value , list ):
501+ labels = _sequence_labels (case , metric , len (rust_value ))
502+ if len (rust_value ) != len (policy_value ):
503+ case_failures .append (
504+ f"{ case ['name' ]} { case_metric .name } length mismatch { len (rust_value )} != { len (policy_value )} "
505+ )
506+ print (
507+ f" { case_metric .name } : rust={ rust_value } policyengine={ policy_value } diff=length-mismatch"
508+ )
509+ continue
510+ diffs = [abs (a - b ) for a , b in zip (rust_value , policy_value )]
511+ print (
512+ " "
513+ f"{ case_metric .name } : "
514+ f"rust={ _format_sequence (rust_value , labels )} "
515+ f"policyengine={ _format_sequence (policy_value , labels )} "
516+ f"diffs={ [round (diff , 2 ) for diff in diffs ]} "
408517 )
518+ for index , diff in enumerate (diffs ):
519+ if diff > metric .tolerance :
520+ case_failures .append (
521+ f"{ case ['name' ]} { case_metric .name } [{ labels [index ]} ] diff { diff :.2f} exceeds tolerance { metric .tolerance :.2f} "
522+ )
523+ else :
524+ diff = abs (rust_value - policy_value )
525+ print (
526+ f" { case_metric .name } : rust={ rust_value :.2f} policyengine={ policy_value :.2f} diff={ diff :.2f} "
527+ )
528+ if diff > metric .tolerance :
529+ case_failures .append (
530+ f"{ case ['name' ]} { case_metric .name } diff { diff :.2f} exceeds tolerance { metric .tolerance :.2f} "
531+ )
532+
533+ if case .get ("known_failure" ):
534+ if case_failures :
535+ expected_failures .append (case ["name" ])
536+ print (f" expected failure: { case ['known_failure' ]} " )
537+ else :
538+ unexpected_passes .append (
539+ f"{ case ['name' ]} unexpectedly passed; review and remove known_failure"
540+ )
541+ else :
542+ failures .extend (case_failures )
409543
410- if failures :
544+ if failures or unexpected_passes :
411545 print ("\n Validation failed:" , file = sys .stderr )
412546 for failure in failures :
413547 print (f" - { failure } " , file = sys .stderr )
548+ for unexpected_pass in unexpected_passes :
549+ print (f" - { unexpected_pass } " , file = sys .stderr )
414550 return 1
415551
416- print (f"\n Validated { len (selected_cases )} case(s) against policyengine-uk." )
552+ print (
553+ "\n Validated "
554+ f"{ len (selected_cases ) - len (expected_failures )} passing case(s) "
555+ f"against policyengine-uk with { len (expected_failures )} expected failure(s)."
556+ )
417557 return 0
418558
419559
0 commit comments