Skip to content

Commit f4ab939

Browse files
authored
Merge pull request #160 from WDGPH/fix/ty-check
type checker fixes
2 parents cc7af66 + 5715d09 commit f4ab939

3 files changed

Lines changed: 35 additions & 19 deletions

File tree

pipeline/bundle_pdfs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def build_client_lookup(
316316
clients_obj = artifact.get("clients", [])
317317
clients = clients_obj if isinstance(clients_obj, list) else []
318318
lookup: Dict[tuple[str, str], dict] = {}
319-
for client in clients: # type: ignore[var-annotated]
319+
for client in clients:
320320
sequence = client.get("sequence") # type: ignore[attr-defined]
321321
client_id = client.get("client_id") # type: ignore[attr-defined]
322322
lookup[(sequence, client_id)] = client # type: ignore[typeddict-item]

pipeline/preprocess.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from hashlib import sha1
4848
from pathlib import Path
4949
from string import Formatter
50-
from typing import Any, Dict, List, Optional
50+
from typing import Any, Dict, List, Optional, overload
5151
import pandas as pd
5252
import yaml
5353
from babel.dates import format_date
@@ -95,6 +95,7 @@
9595

9696
THRESHOLD = 80
9797

98+
9899
def convert_date_string(
99100
date_str: str | datetime | pd.Timestamp, locale: str = "en"
100101
) -> str | None:
@@ -172,6 +173,7 @@ def format_iso_date_for_language(iso_date: str, language: str) -> str:
172173

173174
return format_date(date_obj, format="long", locale=locale)
174175

176+
175177
def check_addresses_complete(df: pd.DataFrame) -> pd.DataFrame:
176178
"""
177179
Check if address fields are complete in the DataFrame.
@@ -192,17 +194,13 @@ def check_addresses_complete(df: pd.DataFrame) -> pd.DataFrame:
192194
]
193195

194196
for col in address_cols:
195-
df[col] = (
196-
df[col]
197-
.astype(str)
198-
.str.strip()
199-
.replace({"": pd.NA, "nan": pd.NA})
200-
)
197+
df[col] = df[col].astype(str).str.strip().replace({"": pd.NA, "nan": pd.NA})
201198

202199
# Build combined address line
203200
df["ADDRESS"] = (
204-
df["STREET_ADDRESS_LINE_1"].fillna("") + " " +
205-
df["STREET_ADDRESS_LINE_2"].fillna("")
201+
df["STREET_ADDRESS_LINE_1"].fillna("")
202+
+ " "
203+
+ df["STREET_ADDRESS_LINE_2"].fillna("")
206204
).str.strip()
207205

208206
df["ADDRESS"] = df["ADDRESS"].replace({"": pd.NA})
@@ -282,6 +280,7 @@ def over_16_check(date_of_birth, date_notice_delivery):
282280

283281
return age >= 16
284282

283+
285284
def configure_logging(output_dir: Path, run_id: str) -> Path:
286285
"""Configure file logging for the preprocessing step.
287286
@@ -387,6 +386,7 @@ def read_input(file_path: Path) -> pd.DataFrame:
387386
LOG.error("Failed to read %s: %s", file_path, exc)
388387
raise
389388

389+
390390
def normalize(col: str) -> str:
391391
"""Normalize formatting prior to matching."""
392392

@@ -443,11 +443,10 @@ def map_columns(df: pd.DataFrame, required_columns=REQUIRED_COLUMNS):
443443

444444
# Check each input column against required columns
445445
for input_col in normalized_input_cols:
446-
447446
col_name, score, index = process.extractOne(
448447
query=input_col,
449448
choices=[normalize(req) for req in required_columns],
450-
scorer=fuzz.partial_ratio
449+
scorer=fuzz.partial_ratio,
451450
)
452451

453452
# Remove column if it has a score of 0
@@ -460,18 +459,32 @@ def map_columns(df: pd.DataFrame, required_columns=REQUIRED_COLUMNS):
460459

461460
# print colname and score for debugging
462461
print(f"Matching '{input_col}' to '{best_match}' with score {score}")
463-
462+
464463
return df.rename(columns=col_map), col_map
465464

465+
466+
@overload
466467
def filter_columns(
467468
df: pd.DataFrame, required_columns: list[str] = REQUIRED_COLUMNS
468-
) -> pd.DataFrame:
469+
) -> pd.DataFrame: ...
470+
471+
472+
@overload
473+
def filter_columns(
474+
df: None, required_columns: list[str] = REQUIRED_COLUMNS
475+
) -> None: ...
476+
477+
478+
def filter_columns(
479+
df: pd.DataFrame | None, required_columns: list[str] = REQUIRED_COLUMNS
480+
) -> pd.DataFrame | None:
469481
"""Filter dataframe to only include required columns."""
470482
if df is None or df.empty:
471483
return df
472484

473485
return df[[col for col in df.columns if col in required_columns]]
474486

487+
475488
def ensure_required_columns(df: pd.DataFrame) -> pd.DataFrame:
476489
"""Normalize column names and validate that all required columns are present.
477490
@@ -767,7 +780,7 @@ def build_preprocess_result(
767780
sorted_df["SEQUENCE"] = [f"{idx + 1:05d}" for idx in range(len(sorted_df))]
768781

769782
clients: List[ClientRecord] = []
770-
for row in sorted_df.itertuples(index=False): # type: ignore[attr-defined]
783+
for row in sorted_df.itertuples(index=False):
771784
client_id = str(row.CLIENT_ID) # type: ignore[attr-defined]
772785
sequence = row.SEQUENCE # type: ignore[attr-defined]
773786
dob_iso = (

tests/unit/test_preprocess.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def test_handles_non_alphabetic_characters(self):
112112
"""Verify that non-letter characters are preserved."""
113113
assert preprocess.normalize("123 Name!") == "123 name!"
114114

115+
115116
@pytest.mark.unit
116117
class TestFilterColumns:
117118
"""Unit tests for filter_columns() column filtering utility."""
@@ -143,7 +144,7 @@ def test_returns_empty_dataframe_when_no_required_columns_present(self):
143144

144145
def test_handles_empty_dataframe(self):
145146
"""Verify that an empty DataFrame is returned unchanged."""
146-
df = pd.DataFrame(columns=["child_first_name", "child_last_name"])
147+
df = pd.DataFrame(columns=pd.Index(["child_first_name", "child_last_name"]))
147148
result = preprocess.filter_columns(df, ["child_first_name"])
148149
assert result.empty
149150

@@ -164,7 +165,10 @@ def test_order_of_columns_is_preserved(self):
164165
required = ["dob", "child_first_name"]
165166
result = preprocess.filter_columns(df, required)
166167

167-
assert list(result.columns) == ["child_first_name", "dob"] or list(result.columns) == required
168+
assert (
169+
list(result.columns) == ["child_first_name", "dob"]
170+
or list(result.columns) == required
171+
)
168172
# Either column order can appear depending on implementation; both are acceptable
169173

170174
def test_ignores_required_columns_not_in_df(self):
@@ -176,6 +180,7 @@ def test_ignores_required_columns_not_in_df(self):
176180
assert "child_first_name" in result.columns
177181
assert "missing_column" not in result.columns
178182

183+
179184
@pytest.mark.unit
180185
class TestReadInput:
181186
"""Unit tests for read_input function."""
@@ -785,5 +790,3 @@ def test_build_result_no_warning_for_unique_client_ids(
785790
# Should have NO warnings about duplicates
786791
duplicate_warnings = [w for w in result.warnings if "Duplicate client ID" in w]
787792
assert len(duplicate_warnings) == 0
788-
789-

0 commit comments

Comments (0)