Skip to content

Commit 86c9ba0

Browse files
authored
Merge pull request #474 from PSLmodels/pr4-cd-pipeline
Add Congressional District weight solving pipeline
2 parents 31de12a + 89ba786 commit 86c9ba0

11 files changed

Lines changed: 2906 additions & 92 deletions

tests/test_fingerprint.py

Lines changed: 99 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
NOT run by `make test` (excluded in Makefile).
55
Run manually after a full solve:
66
7-
pytest tests/test_fingerprint.py -v
87
pytest tests/test_fingerprint.py -v --update-fingerprint
8+
pytest tests/test_fingerprint.py -v
9+
pytest tests/test_fingerprint.py -v -k states
10+
pytest tests/test_fingerprint.py -v -k cds
911
10-
The first run saves a reference fingerprint.
12+
The first run saves a reference fingerprint per scope.
1113
Subsequent runs compare against it.
1214
1315
Fingerprint method: for each area, round weights to nearest integer,
@@ -26,6 +28,7 @@
2628

2729
# Folder for saved reference fingerprints (presumably what
# _save_fingerprint/_load_fingerprint read and write — confirm).
FINGERPRINT_DIR = AREAS_FOLDER / "fingerprints"
2830
# Folder of state weight files; passed as weight_dir by the state tests.
STATE_WEIGHT_DIR = AREAS_FOLDER / "weights" / "states"
31+
# Folder of Congressional District weight files; scanned for
# "*_tmd_weights.csv.gz" to discover which CD areas exist.
CD_WEIGHT_DIR = AREAS_FOLDER / "weights" / "cds"
2932

3033
ALL_STATES = [
3134
"AL",
@@ -82,6 +85,17 @@
8285
]
8386

8487

88+
def _discover_cd_areas():
    """Return the sorted, upper-cased CD area codes found on disk.

    An area code is the weight file name with the
    ``_tmd_weights.csv.gz`` part removed.  When ``CD_WEIGHT_DIR`` does
    not exist, an empty list is returned.
    """
    suffix = "_tmd_weights.csv.gz"
    if CD_WEIGHT_DIR.exists():
        found = [
            weight_file.name.replace(suffix, "").upper()
            for weight_file in CD_WEIGHT_DIR.glob("*_tmd_weights.csv.gz")
        ]
        found.sort()
        return found
    return []
97+
98+
8599
def _compute_fingerprint(areas, weight_dir):
86100
"""Compute fingerprint from weight files.
87101
@@ -145,58 +159,95 @@ def _has_weight_files(weight_dir, areas):
145159
return False
146160

147161

162+
def _run_fingerprint_test(scope, areas, weight_dir, update):
    """Check the freshly computed fingerprint against the saved one.

    Args:
        scope: Name handed to ``_save_fingerprint`` / ``_load_fingerprint``
            ("states" or "cds" for the callers in this file).
        areas: Area codes forwarded to ``_compute_fingerprint``.
        weight_dir: Weight directory forwarded to ``_compute_fingerprint``.
        update: When truthy, store the fresh fingerprint as the new
            reference and skip instead of comparing.

    The test is also skipped (after saving) when no reference exists yet.
    Otherwise it asserts that ``n_areas`` and ``weight_hash`` are unchanged.
    """
    fresh = _compute_fingerprint(areas, weight_dir)

    if update:
        saved_to = _save_fingerprint(scope, fresh)
        pytest.skip(f"Saved to {saved_to} — re-run to test")

    baseline = _load_fingerprint(scope)
    if baseline is None:
        # First run for this scope: record a baseline rather than fail.
        saved_to = _save_fingerprint(scope, fresh)
        pytest.skip(f"No reference found. Saved to {saved_to} — re-run")

    expected_count = baseline["n_areas"]
    actual_count = fresh["n_areas"]
    assert (
        actual_count == expected_count
    ), f"Area count: {expected_count} -> {actual_count}"

    hashes_agree = fresh["weight_hash"] == baseline["weight_hash"]
    assert hashes_agree, "Weight hash mismatch — results changed"
182+
183+
184+
def _run_detail_test(scope, areas, weight_dir, update):
    """Report which areas' integer weight sums differ from the reference.

    Args:
        scope: Name handed to ``_load_fingerprint`` ("states" or "cds"
            for the callers in this file).
        areas: Area codes forwarded to ``_compute_fingerprint``.
        weight_dir: Weight directory forwarded to ``_compute_fingerprint``.
        update: When truthy the test is skipped; nothing is compared.

    The test is also skipped when no reference fingerprint exists.
    Otherwise it fails with one line per area whose sum changed, that is
    missing from the current results, or that appears in the current
    results but not in the reference.
    """
    if update:
        pytest.skip("Update mode")

    reference = _load_fingerprint(scope)
    if reference is None:
        pytest.skip("No reference fingerprint")

    current = _compute_fingerprint(areas, weight_dir)
    ref_sums = reference.get("per_area_int_sums", {})
    cur_sums = current.get("per_area_int_sums", {})

    mismatches = []
    for area in sorted(ref_sums.keys()):
        if area not in cur_sums:
            mismatches.append(f"{area}: missing")
            continue
        if ref_sums[area] != cur_sums[area]:
            mismatches.append(
                f"{area}: {ref_sums[area]}" f" -> {cur_sums[area]}"
            )

    # Areas that exist now but were never recorded would otherwise go
    # unreported.  Only check when the reference actually carries per-area
    # sums, so a reference without them keeps passing as it did before.
    if ref_sums:
        for area in sorted(cur_sums.keys() - ref_sums.keys()):
            mismatches.append(f"{area}: not in reference")

    assert not mismatches, f"{len(mismatches)} areas changed:\n" + "\n".join(
        mismatches
    )
210+
211+
212+
# --- State tests ---
213+
214+
148215
@pytest.mark.skipif(
    not _has_weight_files(STATE_WEIGHT_DIR, ALL_STATES),
    reason="No state weight files — run solve_weights first",
)
class TestStatesFingerprint:
    """Regression checks for state weights against a saved fingerprint."""

    # pylint: disable=redefined-outer-name
    def test_state_weights_match_reference(self, update_mode):
        """Area count and weight hash must equal the saved reference."""
        _run_fingerprint_test(
            scope="states",
            areas=ALL_STATES,
            weight_dir=STATE_WEIGHT_DIR,
            update=update_mode,
        )

    def test_state_per_area_sums_match(self, update_mode):
        """List the individual states whose integer sums changed."""
        _run_detail_test(
            scope="states",
            areas=ALL_STATES,
            weight_dir=STATE_WEIGHT_DIR,
            update=update_mode,
        )
232+
233+
234+
# --- CD tests ---
235+
236+
# Computed once at import time so the skipif decorator on the CD test
# class can check whether any CD weight files are present.
_CD_AREAS = _discover_cd_areas()
237+
238+
239+
@pytest.mark.skipif(
    not _has_weight_files(CD_WEIGHT_DIR, _CD_AREAS),
    reason="No CD weight files — run solve_weights --scope cds first",
)
class TestCdsFingerprint:
    """Regression checks for CD weights against a saved fingerprint."""

    # pylint: disable=redefined-outer-name
    def test_cds_weights_match_reference(self, update_mode):
        """Area count and weight hash must equal the saved reference."""
        _run_fingerprint_test(
            scope="cds",
            areas=_CD_AREAS,
            weight_dir=CD_WEIGHT_DIR,
            update=update_mode,
        )

    def test_cds_per_area_sums_match(self, update_mode):
        """List the individual districts whose integer sums changed."""
        _run_detail_test(
            scope="cds",
            areas=_CD_AREAS,
            weight_dir=CD_WEIGHT_DIR,
            update=update_mode,
        )

tests/test_prepare_targets.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,23 @@ def test_mn_agi_share_reasonable(self, shares_data):
9898
share = mn_agi["soi_share"].values[0]
9999
assert 0.01 < share < 0.03
100100

101+
def test_no_duplicate_shares(self, shares_data):
    """No (stabbr, basesoivname, count, fstatus, agistub) key repeats.

    Rows whose ``stabbr`` is in ``_EXCLUDE`` are ignored.  ``shares_data``
    is presumably a fixture yielding the shares DataFrame — confirm
    against the fixture definition.
    """
    key_columns = ["stabbr", "basesoivname", "count", "fstatus", "agistub"]
    is_excluded = shares_data["stabbr"].isin(_EXCLUDE)
    group_sizes = shares_data.loc[~is_excluded].groupby(key_columns).size()
    dupes = group_sizes[group_sizes > 1]
    assert dupes.empty, (
        f"Found {len(dupes)} duplicate share groups. "
        f"First few: {dupes.head(5).to_dict()}"
    )
117+
101118
def test_xtot_equals_us_population(self):
102119
"""XTOT 51-state sum equals US Census population."""
103120
pop_df = get_state_population(2022)

0 commit comments

Comments
 (0)