Skip to content

Commit 1192ffd

Browse files
authored
Merge pull request #470 from PSLmodels/pr1-solver-robustness
Improve area weight solver: robustness, memory, and testing
2 parents 3a671e3 + 13025ec commit 1192ffd

10 files changed

Lines changed: 1062 additions & 61 deletions

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ tmd_files: tmd/storage/output/tmd.csv.gz \
3232

3333
.PHONY=test
3434
test: tmd_files
35-
pytest . -v -n4 --ignore=tests/national_targets_pipeline
35+
pytest . -v -n4 --ignore=tests/national_targets_pipeline --ignore=tests/test_fingerprint.py
3636

3737
.PHONY=data
3838
data: install tmd_files test

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
np.seterr(all="raise")
1515

1616

17+
def pytest_addoption(parser):
    """Register the --update-fingerprint command-line flag.

    When the flag is passed, fingerprint tests save the current
    results as the new reference instead of comparing against a
    previously stored one.
    """
    option_kwargs = {
        "action": "store_true",
        "default": False,
        "help": "Save current results as reference fingerprint",
    }
    parser.addoption("--update-fingerprint", **option_kwargs)
24+
25+
1726
def create_tmd_records(
1827
data_path, weights_path, growfactors_path, exact_calculations=True
1928
):

tests/test_fingerprint.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
"""
2+
On-demand fingerprint test for area weight results.
3+
4+
NOT run by `make test` (excluded in Makefile).
5+
Run manually after a full solve:
6+
7+
pytest tests/test_fingerprint.py -v
8+
pytest tests/test_fingerprint.py -v --update-fingerprint
9+
10+
The first run saves a reference fingerprint.
11+
Subsequent runs compare against it.
12+
13+
Fingerprint method: for each area, round weights to nearest integer,
14+
sum them, and hash the per-area sums. This is simple, fast, and
15+
catches any meaningful change in results.
16+
"""
17+
18+
import hashlib
19+
import json
20+
21+
import numpy as np
22+
import pandas as pd
23+
import pytest
24+
25+
from tmd.areas import AREAS_FOLDER
26+
27+
# Directory where reference fingerprint JSON files are stored.
FINGERPRINT_DIR = AREAS_FOLDER / "fingerprints"
# Directory containing solved state weight files
# (<state>_tmd_weights.csv.gz, lowercase state code).
STATE_WEIGHT_DIR = AREAS_FOLDER / "weights" / "states"

# Two-letter codes for the 50 states plus DC (51 areas total).
ALL_STATES = [
    "AL",
    "AK",
    "AZ",
    "AR",
    "CA",
    "CO",
    "CT",
    "DC",
    "DE",
    "FL",
    "GA",
    "HI",
    "ID",
    "IL",
    "IN",
    "IA",
    "KS",
    "KY",
    "LA",
    "ME",
    "MD",
    "MA",
    "MI",
    "MN",
    "MS",
    "MO",
    "MT",
    "NE",
    "NV",
    "NH",
    "NJ",
    "NM",
    "NY",
    "NC",
    "ND",
    "OH",
    "OK",
    "OR",
    "PA",
    "RI",
    "SC",
    "SD",
    "TN",
    "TX",
    "UT",
    "VT",
    "VA",
    "WA",
    "WV",
    "WI",
    "WY",
]
83+
84+
85+
def _compute_fingerprint(areas, weight_dir):
86+
"""Compute fingerprint from weight files.
87+
88+
For each area, reads the first WT column, rounds weights to
89+
nearest integer, and records the sum. The collection of integer
90+
sums is hashed for a single comparison value.
91+
"""
92+
per_area = {}
93+
for area in areas:
94+
code = area.lower()
95+
wpath = weight_dir / f"{code}_tmd_weights.csv.gz"
96+
if not wpath.exists():
97+
continue
98+
wdf = pd.read_csv(wpath)
99+
wt_cols = [c for c in wdf.columns if c.startswith("WT")]
100+
wt = wdf[wt_cols[0]].values
101+
int_sum = int(np.round(wt).sum())
102+
per_area[area] = int_sum
103+
104+
# Hash of all per-area integer sums
105+
hash_str = "|".join(f"{a}:{per_area[a]}" for a in sorted(per_area.keys()))
106+
hash_val = hashlib.sha256(hash_str.encode()).hexdigest()[:16]
107+
108+
return {
109+
"n_areas": len(per_area),
110+
"weight_hash": hash_val,
111+
"per_area_int_sums": per_area,
112+
}
113+
114+
115+
def _fingerprint_path(scope):
    """Return the reference-fingerprint JSON path for *scope* (e.g. "states")."""
    filename = f"{scope}_fingerprint.json"
    return FINGERPRINT_DIR / filename
117+
118+
119+
def _save_fingerprint(scope, fp):
    """Write fingerprint dict *fp* as pretty-printed JSON; return the path."""
    FINGERPRINT_DIR.mkdir(parents=True, exist_ok=True)
    path = _fingerprint_path(scope)
    # sort_keys keeps the saved file stable across runs for diffing.
    text = json.dumps(fp, indent=2, sort_keys=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
    return path
125+
126+
127+
def _load_fingerprint(scope):
    """Return the saved fingerprint dict for *scope*, or None if absent."""
    path = _fingerprint_path(scope)
    if not path.exists():
        return None
    return json.loads(path.read_text(encoding="utf-8"))
133+
134+
135+
@pytest.fixture
def update_mode(request):
    """Whether --update-fingerprint was passed on the command line.

    The option itself is registered by pytest_addoption in
    tests/conftest.py.
    """
    return request.config.getoption("--update-fingerprint")
138+
139+
140+
def _has_weight_files(weight_dir, areas):
141+
for a in areas:
142+
wpath = weight_dir / f"{a.lower()}_tmd_weights.csv.gz"
143+
if wpath.exists():
144+
return True
145+
return False
146+
147+
148+
@pytest.mark.skipif(
    not _has_weight_files(STATE_WEIGHT_DIR, ALL_STATES),
    reason="No state weight files — run solve_weights first",
)
class TestStateFingerprint:
    """Fingerprint tests for state weights."""

    # pylint: disable=redefined-outer-name
    def test_state_weights_match_reference(self, update_mode):
        """Compare weight integer sums against saved reference."""
        current = _compute_fingerprint(ALL_STATES, STATE_WEIGHT_DIR)

        # In update mode, overwrite the reference and stop here.
        if update_mode:
            saved_to = _save_fingerprint("states", current)
            pytest.skip(f"Saved to {saved_to} — re-run to test")

        reference = _load_fingerprint("states")
        if reference is None:
            # First run: bootstrap a reference instead of failing.
            saved_to = _save_fingerprint("states", current)
            pytest.skip(f"No reference found. Saved to {saved_to} — re-run")

        ref_n = reference["n_areas"]
        cur_n = current["n_areas"]
        assert cur_n == ref_n, f"Area count: {ref_n} -> {cur_n}"

        hashes_equal = current["weight_hash"] == reference["weight_hash"]
        assert hashes_equal, "Weight hash mismatch — results changed"

    def test_per_area_sums_match(self, update_mode):
        """Identify which areas changed."""
        if update_mode:
            pytest.skip("Update mode")

        reference = _load_fingerprint("states")
        if reference is None:
            pytest.skip("No reference fingerprint")

        current = _compute_fingerprint(ALL_STATES, STATE_WEIGHT_DIR)
        ref_sums = reference.get("per_area_int_sums", {})
        cur_sums = current.get("per_area_int_sums", {})

        # Collect a human-readable line per changed/missing area.
        mismatches = []
        for area in sorted(ref_sums):
            if area not in cur_sums:
                mismatches.append(f"{area}: missing")
            elif ref_sums[area] != cur_sums[area]:
                mismatches.append(
                    f"{area}: {ref_sums[area]} -> {cur_sums[area]}"
                )

        msg = f"{len(mismatches)} areas changed:\n" + "\n".join(mismatches)
        assert not mismatches, msg

tests/test_state_weight_results.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
"""
2+
Post-solve validation of state weight files.
3+
4+
These tests verify that the actual state weight outputs are valid.
5+
They are skipped if weight files have not been generated yet
6+
(i.e., solve_weights --scope states has not been run).
7+
8+
Run after:
9+
python -m tmd.areas.prepare_targets --scope states
10+
python -m tmd.areas.solve_weights --scope states --workers 8
11+
"""
12+
13+
import io
14+
from pathlib import Path
15+
16+
import numpy as np
17+
import pandas as pd
18+
import pytest
19+
20+
from tmd.areas.create_area_weights import (
21+
AREA_CONSTRAINT_TOL,
22+
STATE_TARGET_DIR,
23+
STATE_WEIGHT_DIR,
24+
_build_constraint_matrix,
25+
_drop_impossible_targets,
26+
_load_taxcalc_data,
27+
)
28+
from tmd.areas.prepare.constants import ALL_STATES
29+
from tmd.imputation_assumptions import TAXYEAR
30+
31+
# Skip entire module if weight files haven't been generated
# (51 expected files: 50 states + DC; fewer means an incomplete solve).
_WEIGHT_FILES = list(STATE_WEIGHT_DIR.glob("*_tmd_weights.csv.gz"))
pytestmark = pytest.mark.skipif(
    len(_WEIGHT_FILES) < 51,
    reason="State weight files not generated yet",
)

# Also need cached data files for target accuracy checks
# Both the TMD input file and the cached Tax-Calculator AGI array
# must exist for TestStateTargetAccuracy to run.
_CACHED = Path(__file__).parent.parent / "tmd" / "storage" / "output"
_HAS_CACHED = (_CACHED / "tmd.csv.gz").exists() and (
    _CACHED / "cached_c00100.npy"
).exists()
43+
44+
45+
class TestStateWeightFiles:
    """Basic validity checks on all 51 state weight files."""

    def test_all_states_have_weight_files(self):
        """Every state has a weight file."""
        for state in ALL_STATES:
            path = STATE_WEIGHT_DIR / f"{state.lower()}_tmd_weights.csv.gz"
            assert path.exists(), f"Missing weight file for {state}"

    def test_all_states_have_log_files(self):
        """Every state has a solver log."""
        for state in ALL_STATES:
            path = STATE_WEIGHT_DIR / f"{state.lower()}.log"
            assert path.exists(), f"Missing log file for {state}"

    def test_weight_columns(self):
        """Weight files have expected year columns."""
        frame = pd.read_csv(STATE_WEIGHT_DIR / "mn_tmd_weights.csv.gz")
        expected = [f"WT{year}" for year in range(TAXYEAR, 2035)]
        assert list(frame.columns) == expected

    def test_weight_row_count(self):
        """Weight files have one row per TMD record."""
        frame = pd.read_csv(STATE_WEIGHT_DIR / "mn_tmd_weights.csv.gz")
        # Should match TMD record count (215,494 for 2022)
        assert len(frame) > 200_000

    @pytest.mark.parametrize("state", [s.lower() for s in ALL_STATES])
    def test_weights_nonnegative(self, state):
        """All weights are non-negative."""
        frame = pd.read_csv(STATE_WEIGHT_DIR / f"{state}_tmd_weights.csv.gz")
        assert (frame >= 0).all().all(), f"{state}: negative weights found"

    @pytest.mark.parametrize("state", [s.lower() for s in ALL_STATES])
    def test_weights_no_nan(self, state):
        """No NaN or inf values in weights."""
        frame = pd.read_csv(STATE_WEIGHT_DIR / f"{state}_tmd_weights.csv.gz")
        assert not frame.isna().any().any(), f"{state}: NaN values found"
        assert np.isfinite(frame.values).all(), f"{state}: inf values found"

    @pytest.mark.parametrize("state", [s.lower() for s in ALL_STATES])
    def test_solver_status_solved(self, state):
        """Solver log reports Solved status."""
        log_text = (STATE_WEIGHT_DIR / f"{state}.log").read_text()
        solved = "Solver status: Solved" in log_text
        assert solved, f"{state}: solver did not report Solved"
106+
107+
108+
@pytest.mark.skipif(
    not _HAS_CACHED,
    reason="Cached TMD data files not available",
)
class TestStateTargetAccuracy:
    """Verify weighted sums hit targets within tolerance."""

    @pytest.fixture(scope="class")
    def vdf(self):
        """Load TMD data once for all accuracy tests."""
        return _load_taxcalc_data()

    @pytest.mark.parametrize(
        "state",
        ["al", "ca", "mn", "ny", "tx"],
    )
    def test_targets_hit(self, vdf, state):
        """Weighted sums match targets within constraint tolerance."""
        # Rebuild the constraint matrix the same way the solver did;
        # helper diagnostic output is discarded via the StringIO sink.
        out = io.StringIO()
        B_csc, targets, labels, pop_share = _build_constraint_matrix(
            state,
            vdf,
            out,
            target_dir=STATE_TARGET_DIR,
        )
        # NOTE(review): labels is unused after this call — kept only to
        # unpack the helper's return tuple.
        B_csc, targets, labels = _drop_impossible_targets(
            B_csc,
            targets,
            labels,
            out,
        )

        # Load weights and compute multipliers
        # x is the per-record ratio of the solved area weight to the
        # population-share-scaled national weight w0 — presumably the
        # solver's decision variable; confirm against create_area_weights.
        wpath = STATE_WEIGHT_DIR / f"{state}_tmd_weights.csv.gz"
        wdf = pd.read_csv(wpath)
        area_weights = wdf[f"WT{TAXYEAR}"].values
        w0 = pop_share * vdf["s006"].values
        # Avoid division by zero for zero-weight records
        safe_w0 = np.where(w0 > 0, w0, 1.0)
        x = area_weights / safe_w0
        x = np.where(w0 > 0, x, 0.0)

        # Check target accuracy
        # Relative error uses max(|target|, 1) to avoid dividing by
        # near-zero targets.
        achieved = np.asarray(B_csc @ x).ravel()
        rel_errors = np.abs(achieved - targets) / np.maximum(
            np.abs(targets), 1.0
        )
        # Allow small margin above solver tolerance for floating-point
        # differences between solver internals and weight-file roundtrip
        eps = 1e-4
        n_violated = int((rel_errors > AREA_CONSTRAINT_TOL + eps).sum())
        max_err = rel_errors.max()
        assert n_violated == 0, (
            f"{state}: {n_violated} targets violated, "
            f"max error = {max_err * 100:.3f}%"
        )

0 commit comments

Comments (0)