Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 99 additions & 48 deletions tests/test_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
NOT run by `make test` (excluded in Makefile).
Run manually after a full solve:

pytest tests/test_fingerprint.py -v
pytest tests/test_fingerprint.py -v --update-fingerprint
pytest tests/test_fingerprint.py -v
pytest tests/test_fingerprint.py -v -k states
pytest tests/test_fingerprint.py -v -k cds

The first run saves a reference fingerprint.
The first run saves a reference fingerprint per scope.
Subsequent runs compare against it.

Fingerprint method: for each area, round weights to nearest integer,
Expand All @@ -26,6 +28,7 @@

FINGERPRINT_DIR = AREAS_FOLDER / "fingerprints"
STATE_WEIGHT_DIR = AREAS_FOLDER / "weights" / "states"
CD_WEIGHT_DIR = AREAS_FOLDER / "weights" / "cds"

ALL_STATES = [
"AL",
Expand Down Expand Up @@ -82,6 +85,17 @@
]


def _discover_cd_areas():
"""Discover CD area codes from existing weight files."""
if not CD_WEIGHT_DIR.exists():
return []
codes = sorted(
p.name.replace("_tmd_weights.csv.gz", "").upper()
for p in CD_WEIGHT_DIR.glob("*_tmd_weights.csv.gz")
)
return codes


def _compute_fingerprint(areas, weight_dir):
"""Compute fingerprint from weight files.

Expand Down Expand Up @@ -145,58 +159,95 @@ def _has_weight_files(weight_dir, areas):
return False


def _run_fingerprint_test(scope, areas, weight_dir, update):
"""Shared logic for fingerprint comparison."""
current = _compute_fingerprint(areas, weight_dir)

if update:
path = _save_fingerprint(scope, current)
pytest.skip(f"Saved to {path} — re-run to test")

reference = _load_fingerprint(scope)
if reference is None:
path = _save_fingerprint(scope, current)
pytest.skip(f"No reference found. Saved to {path} — re-run")

ref_n = reference["n_areas"]
cur_n = current["n_areas"]
assert cur_n == ref_n, f"Area count: {ref_n} -> {cur_n}"

assert (
current["weight_hash"] == reference["weight_hash"]
), "Weight hash mismatch — results changed"


def _run_detail_test(scope, areas, weight_dir, update):
"""Shared logic for per-area sum comparison."""
if update:
pytest.skip("Update mode")

reference = _load_fingerprint(scope)
if reference is None:
pytest.skip("No reference fingerprint")

current = _compute_fingerprint(areas, weight_dir)
ref_sums = reference.get("per_area_int_sums", {})
cur_sums = current.get("per_area_int_sums", {})

mismatches = []
for area in sorted(ref_sums.keys()):
if area not in cur_sums:
mismatches.append(f"{area}: missing")
continue
if ref_sums[area] != cur_sums[area]:
mismatches.append(
f"{area}: {ref_sums[area]}" f" -> {cur_sums[area]}"
)

assert not mismatches, f"{len(mismatches)} areas changed:\n" + "\n".join(
mismatches
)


# --- State tests ---


@pytest.mark.skipif(
not _has_weight_files(STATE_WEIGHT_DIR, ALL_STATES),
reason="No state weight files — run solve_weights first",
)
class TestStateFingerprint:
class TestStatesFingerprint:
"""Fingerprint tests for state weights."""

# pylint: disable=redefined-outer-name
def test_state_weights_match_reference(self, update_mode):
"""Compare weight integer sums against saved reference."""
current = _compute_fingerprint(ALL_STATES, STATE_WEIGHT_DIR)

if update_mode:
path = _save_fingerprint("states", current)
pytest.skip(f"Saved to {path} — re-run to test")

reference = _load_fingerprint("states")
if reference is None:
path = _save_fingerprint("states", current)
pytest.skip(f"No reference found. Saved to {path} — re-run")

ref_n = reference["n_areas"]
cur_n = current["n_areas"]
assert cur_n == ref_n, f"Area count: {ref_n} -> {cur_n}"

assert (
current["weight_hash"] == reference["weight_hash"]
), "Weight hash mismatch — results changed"

def test_per_area_sums_match(self, update_mode):
"""Identify which areas changed."""
if update_mode:
pytest.skip("Update mode")

reference = _load_fingerprint("states")
if reference is None:
pytest.skip("No reference fingerprint")

current = _compute_fingerprint(ALL_STATES, STATE_WEIGHT_DIR)
ref_sums = reference.get("per_area_int_sums", {})
cur_sums = current.get("per_area_int_sums", {})

mismatches = []
for area in sorted(ref_sums.keys()):
if area not in cur_sums:
mismatches.append(f"{area}: missing")
continue
if ref_sums[area] != cur_sums[area]:
mismatches.append(
f"{area}: {ref_sums[area]}" f" -> {cur_sums[area]}"
)

assert (
not mismatches
), f"{len(mismatches)} areas changed:\n" + "\n".join(mismatches)
_run_fingerprint_test(
"states", ALL_STATES, STATE_WEIGHT_DIR, update_mode
)

def test_state_per_area_sums_match(self, update_mode):
"""Identify which states changed."""
_run_detail_test("states", ALL_STATES, STATE_WEIGHT_DIR, update_mode)


# --- CD tests ---

_CD_AREAS = _discover_cd_areas()


@pytest.mark.skipif(
not _has_weight_files(CD_WEIGHT_DIR, _CD_AREAS),
reason="No CD weight files — run solve_weights --scope cds first",
)
class TestCdsFingerprint:
"""Fingerprint tests for congressional district weights."""

# pylint: disable=redefined-outer-name
def test_cds_weights_match_reference(self, update_mode):
"""Compare weight integer sums against saved reference."""
_run_fingerprint_test("cds", _CD_AREAS, CD_WEIGHT_DIR, update_mode)

def test_cds_per_area_sums_match(self, update_mode):
"""Identify which CDs changed."""
_run_detail_test("cds", _CD_AREAS, CD_WEIGHT_DIR, update_mode)
17 changes: 17 additions & 0 deletions tests/test_prepare_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,23 @@ def test_mn_agi_share_reasonable(self, shares_data):
share = mn_agi["soi_share"].values[0]
assert 0.01 < share < 0.03

def test_no_duplicate_shares(self, shares_data):
"""Each (area, var, count, fstatus, agistub) has one share."""
state_shares = shares_data[~shares_data["stabbr"].isin(_EXCLUDE)]
group_cols = [
"stabbr",
"basesoivname",
"count",
"fstatus",
"agistub",
]
counts = state_shares.groupby(group_cols).size()
dupes = counts[counts > 1]
assert len(dupes) == 0, (
f"Found {len(dupes)} duplicate share groups. "
f"First few: {dupes.head(5).to_dict()}"
)

def test_xtot_equals_us_population(self):
"""XTOT 51-state sum equals US Census population."""
pop_df = get_state_population(2022)
Expand Down
Loading
Loading