From 5715d09946a700fb7d85182eb281257dbb4a26d0 Mon Sep 17 00:00:00 2001
From: Justin Angevaare
Date: Tue, 20 Jan 2026 20:15:39 +0000
Subject: [PATCH] type checker fixes

---
 pipeline/bundle_pdfs.py       |  2 +-
 pipeline/preprocess.py        | 41 +++++++++++++++++++++++------------
 tests/unit/test_preprocess.py | 11 ++++++----
 3 files changed, 35 insertions(+), 19 deletions(-)

diff --git a/pipeline/bundle_pdfs.py b/pipeline/bundle_pdfs.py
index 011eec6..b7add16 100644
--- a/pipeline/bundle_pdfs.py
+++ b/pipeline/bundle_pdfs.py
@@ -316,7 +316,7 @@ def build_client_lookup(
     clients_obj = artifact.get("clients", [])
     clients = clients_obj if isinstance(clients_obj, list) else []
     lookup: Dict[tuple[str, str], dict] = {}
-    for client in clients:  # type: ignore[var-annotated]
+    for client in clients:
         sequence = client.get("sequence")  # type: ignore[attr-defined]
         client_id = client.get("client_id")  # type: ignore[attr-defined]
         lookup[(sequence, client_id)] = client  # type: ignore[typeddict-item]
diff --git a/pipeline/preprocess.py b/pipeline/preprocess.py
index 6209d54..cdad5c4 100644
--- a/pipeline/preprocess.py
+++ b/pipeline/preprocess.py
@@ -47,7 +47,7 @@
 from hashlib import sha1
 from pathlib import Path
 from string import Formatter
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, overload
 import pandas as pd
 import yaml
 from babel.dates import format_date
@@ -95,6 +95,7 @@
 THRESHOLD = 80
 
 
+
 def convert_date_string(
     date_str: str | datetime | pd.Timestamp, locale: str = "en"
 ) -> str | None:
@@ -172,6 +173,7 @@ def format_iso_date_for_language(iso_date: str, language: str) -> str:
     return format_date(date_obj, format="long", locale=locale)
 
 
+
 def check_addresses_complete(df: pd.DataFrame) -> pd.DataFrame:
     """
     Check if address fields are complete in the DataFrame.
@@ -192,17 +194,13 @@
     ]
 
     for col in address_cols:
-        df[col] = (
-            df[col]
-            .astype(str)
-            .str.strip()
-            .replace({"": pd.NA, "nan": pd.NA})
-        )
+        df[col] = df[col].astype(str).str.strip().replace({"": pd.NA, "nan": pd.NA})
 
     # Build combined address line
     df["ADDRESS"] = (
-        df["STREET_ADDRESS_LINE_1"].fillna("") + " " +
-        df["STREET_ADDRESS_LINE_2"].fillna("")
+        df["STREET_ADDRESS_LINE_1"].fillna("")
+        + " "
+        + df["STREET_ADDRESS_LINE_2"].fillna("")
     ).str.strip()
     df["ADDRESS"] = df["ADDRESS"].replace({"": pd.NA})
@@ -282,6 +280,7 @@ def over_16_check(date_of_birth, date_notice_delivery):
     return age >= 16
 
 
+
 def configure_logging(output_dir: Path, run_id: str) -> Path:
     """Configure file logging for the preprocessing step.
@@ -387,6 +386,7 @@ def read_input(file_path: Path) -> pd.DataFrame:
         LOG.error("Failed to read %s: %s", file_path, exc)
         raise
 
 
+
 def normalize(col: str) -> str:
     """Normalize formatting prior to matching."""
@@ -443,11 +443,10 @@ def map_columns(df: pd.DataFrame, required_columns=REQUIRED_COLUMNS):
 
     # Check each input column against required columns
     for input_col in normalized_input_cols:
-
         col_name, score, index = process.extractOne(
             query=input_col,
             choices=[normalize(req) for req in required_columns],
-            scorer=fuzz.partial_ratio
+            scorer=fuzz.partial_ratio,
         )
 
         # Remove column if it has a score of 0
@@ -460,18 +459,32 @@
         # print colname and score for debugging
         print(f"Matching '{input_col}' to '{best_match}' with score {score}")
-
+
     return df.rename(columns=col_map), col_map
 
+
+@overload
 def filter_columns(
     df: pd.DataFrame, required_columns: list[str] = REQUIRED_COLUMNS
-) -> pd.DataFrame:
+) -> pd.DataFrame: ...
+
+
+@overload
+def filter_columns(
+    df: None, required_columns: list[str] = REQUIRED_COLUMNS
+) -> None: ...
+
+
+def filter_columns(
+    df: pd.DataFrame | None, required_columns: list[str] = REQUIRED_COLUMNS
+) -> pd.DataFrame | None:
     """Filter dataframe to only include required columns."""
     if df is None or df.empty:
         return df
     return df[[col for col in df.columns if col in required_columns]]
 
+
 def ensure_required_columns(df: pd.DataFrame) -> pd.DataFrame:
     """Normalize column names and validate that all required columns are present.
@@ -767,7 +780,7 @@ def build_preprocess_result(
     sorted_df["SEQUENCE"] = [f"{idx + 1:05d}" for idx in range(len(sorted_df))]
 
     clients: List[ClientRecord] = []
-    for row in sorted_df.itertuples(index=False):  # type: ignore[attr-defined]
+    for row in sorted_df.itertuples(index=False):
         client_id = str(row.CLIENT_ID)  # type: ignore[attr-defined]
         sequence = row.SEQUENCE  # type: ignore[attr-defined]
         dob_iso = (
diff --git a/tests/unit/test_preprocess.py b/tests/unit/test_preprocess.py
index 471ff23..2d7cc77 100644
--- a/tests/unit/test_preprocess.py
+++ b/tests/unit/test_preprocess.py
@@ -112,6 +112,7 @@ def test_handles_non_alphabetic_characters(self):
         """Verify that non-letter characters are preserved."""
         assert preprocess.normalize("123 Name!") == "123 name!"
 
+
 @pytest.mark.unit
 class TestFilterColumns:
     """Unit tests for filter_columns() column filtering utility."""
@@ -143,7 +144,7 @@ def test_returns_empty_dataframe_when_no_required_columns_present(self):
 
     def test_handles_empty_dataframe(self):
         """Verify that an empty DataFrame is returned unchanged."""
-        df = pd.DataFrame(columns=["child_first_name", "child_last_name"])
+        df = pd.DataFrame(columns=pd.Index(["child_first_name", "child_last_name"]))
         result = preprocess.filter_columns(df, ["child_first_name"])
 
         assert result.empty
@@ -164,7 +165,10 @@ def test_order_of_columns_is_preserved(self):
         required = ["dob", "child_first_name"]
         result = preprocess.filter_columns(df, required)
 
-        assert list(result.columns) == ["child_first_name", "dob"] or list(result.columns) == required
+        assert (
+            list(result.columns) == ["child_first_name", "dob"]
+            or list(result.columns) == required
+        )
         # Either column order can appear depending on implementation; both are acceptable
 
     def test_ignores_required_columns_not_in_df(self):
@@ -176,6 +180,7 @@ def test_ignores_required_columns_not_in_df(self):
         assert "child_first_name" in result.columns
         assert "missing_column" not in result.columns
 
+
 @pytest.mark.unit
 class TestReadInput:
     """Unit tests for read_input function."""
@@ -785,5 +790,3 @@ def test_build_result_no_warning_for_unique_client_ids(
         # Should have NO warnings about duplicates
         duplicate_warnings = [w for w in result.warnings if "Duplicate client ID" in w]
         assert len(duplicate_warnings) == 0
-
-