Merged
pipeline/bundle_pdfs.py (2 changes: 1 addition & 1 deletion)

@@ -316,7 +316,7 @@ def build_client_lookup(
     clients_obj = artifact.get("clients", [])
     clients = clients_obj if isinstance(clients_obj, list) else []
     lookup: Dict[tuple[str, str], dict] = {}
-    for client in clients:  # type: ignore[var-annotated]
+    for client in clients:
         sequence = client.get("sequence")  # type: ignore[attr-defined]
         client_id = client.get("client_id")  # type: ignore[attr-defined]
         lookup[(sequence, client_id)] = client  # type: ignore[typeddict-item]
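Review note: the dropped `var-annotated` ignore was stale, since `clients` is already narrowed to a `list` on the previous line; only the per-attribute ignores on the `dict`-like records remain. A minimal sketch of the composite-key lookup this hunk builds, with invented sample data for illustration:

    # Invented sample artifact, for illustration only.
    artifact = {
        "clients": [
            {"sequence": "00001", "client_id": "A-17", "language": "en"},
        ]
    }

    clients_obj = artifact.get("clients", [])
    clients = clients_obj if isinstance(clients_obj, list) else []

    # Records are addressed by the composite (sequence, client_id) key.
    lookup = {(c["sequence"], c["client_id"]): c for c in clients}
    assert lookup[("00001", "A-17")]["language"] == "en"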
pipeline/preprocess.py (41 changes: 27 additions & 14 deletions)

@@ -47,7 +47,7 @@
 from hashlib import sha1
 from pathlib import Path
 from string import Formatter
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, overload
 import pandas as pd
 import yaml
 from babel.dates import format_date
@@ -95,6 +95,7 @@

 THRESHOLD = 80
 
+
 def convert_date_string(
     date_str: str | datetime | pd.Timestamp, locale: str = "en"
 ) -> str | None:
@@ -172,6 +173,7 @@ def format_iso_date_for_language(iso_date: str, language: str) -> str:

     return format_date(date_obj, format="long", locale=locale)
 
+
 def check_addresses_complete(df: pd.DataFrame) -> pd.DataFrame:
     """
     Check if address fields are complete in the DataFrame.
@@ -192,17 +194,13 @@ def check_addresses_complete(df: pd.DataFrame) -> pd.DataFrame:
     ]
 
     for col in address_cols:
-        df[col] = (
-            df[col]
-            .astype(str)
-            .str.strip()
-            .replace({"": pd.NA, "nan": pd.NA})
-        )
+        df[col] = df[col].astype(str).str.strip().replace({"": pd.NA, "nan": pd.NA})
 
     # Build combined address line
     df["ADDRESS"] = (
-        df["STREET_ADDRESS_LINE_1"].fillna("") + " " +
-        df["STREET_ADDRESS_LINE_2"].fillna("")
+        df["STREET_ADDRESS_LINE_1"].fillna("")
+        + " "
+        + df["STREET_ADDRESS_LINE_2"].fillna("")
     ).str.strip()
 
     df["ADDRESS"] = df["ADDRESS"].replace({"": pd.NA})
@@ -282,6 +280,7 @@ def over_16_check(date_of_birth, date_notice_delivery):

     return age >= 16
 
+
 def configure_logging(output_dir: Path, run_id: str) -> Path:
     """Configure file logging for the preprocessing step.

@@ -387,6 +386,7 @@ def read_input(file_path: Path) -> pd.DataFrame:
         LOG.error("Failed to read %s: %s", file_path, exc)
         raise
 
+
 def normalize(col: str) -> str:
     """Normalize formatting prior to matching."""

@@ -443,11 +443,10 @@ def map_columns(df: pd.DataFrame, required_columns=REQUIRED_COLUMNS):

     # Check each input column against required columns
     for input_col in normalized_input_cols:
-
         col_name, score, index = process.extractOne(
             query=input_col,
             choices=[normalize(req) for req in required_columns],
-            scorer=fuzz.partial_ratio
+            scorer=fuzz.partial_ratio,
         )
 
         # Remove column if it has a score of 0
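A note on the trailing-comma hunk: the three-way unpack (`col_name, score, index`) matches RapidFuzz's `process.extractOne`, which returns a `(match, score, index)` triple (fuzzywuzzy's returns a pair). A self-contained sketch of the call, with an invented input header:

    from rapidfuzz import fuzz, process

    required = ["child_first_name", "child_last_name", "dob"]
    input_col = "child first name"  # invented, assumed already normalized

    match, score, index = process.extractOne(
        query=input_col,
        choices=required,
        scorer=fuzz.partial_ratio,  # substring-tolerant similarity, scored 0-100
    )
    print(match, score)  # best candidate and its score, here child_first_name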
@@ -460,18 +459,32 @@ def map_columns(df: pd.DataFrame, required_columns=REQUIRED_COLUMNS):

         # print colname and score for debugging
         print(f"Matching '{input_col}' to '{best_match}' with score {score}")
 
     return df.rename(columns=col_map), col_map
 
 
+@overload
 def filter_columns(
     df: pd.DataFrame, required_columns: list[str] = REQUIRED_COLUMNS
-) -> pd.DataFrame:
+) -> pd.DataFrame: ...
+
+
+@overload
+def filter_columns(
+    df: None, required_columns: list[str] = REQUIRED_COLUMNS
+) -> None: ...
+
+
+def filter_columns(
+    df: pd.DataFrame | None, required_columns: list[str] = REQUIRED_COLUMNS
+) -> pd.DataFrame | None:
     """Filter dataframe to only include required columns."""
+    if df is None or df.empty:
+        return df
+
     return df[[col for col in df.columns if col in required_columns]]
 
 
 def ensure_required_columns(df: pd.DataFrame) -> pd.DataFrame:
     """Normalize column names and validate that all required columns are present.

@@ -767,7 +780,7 @@ def build_preprocess_result(
     sorted_df["SEQUENCE"] = [f"{idx + 1:05d}" for idx in range(len(sorted_df))]
 
     clients: List[ClientRecord] = []
-    for row in sorted_df.itertuples(index=False):  # type: ignore[attr-defined]
+    for row in sorted_df.itertuples(index=False):
         client_id = str(row.CLIENT_ID)  # type: ignore[attr-defined]
         sequence = row.SEQUENCE  # type: ignore[attr-defined]
         dob_iso = (
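Same cleanup pattern as in bundle_pdfs.py: the pandas stubs type `itertuples` itself, so the loop line needs no ignore, while attribute access still does because the namedtuple fields only exist at runtime. A minimal illustration with invented columns:

    import pandas as pd

    df = pd.DataFrame({"CLIENT_ID": [7], "SEQUENCE": ["00001"]})

    for row in df.itertuples(index=False):
        # `row` is a dynamically built namedtuple; static checkers cannot
        # see CLIENT_ID or SEQUENCE, hence the per-attribute ignores above.
        print(row.CLIENT_ID, row.SEQUENCE)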
tests/unit/test_preprocess.py (11 changes: 7 additions & 4 deletions)

@@ -112,6 +112,7 @@ def test_handles_non_alphabetic_characters(self):
"""Verify that non-letter characters are preserved."""
assert preprocess.normalize("123 Name!") == "123 name!"


@pytest.mark.unit
class TestFilterColumns:
"""Unit tests for filter_columns() column filtering utility."""
@@ -143,7 +144,7 @@ def test_returns_empty_dataframe_when_no_required_columns_present(self):

     def test_handles_empty_dataframe(self):
         """Verify that an empty DataFrame is returned unchanged."""
-        df = pd.DataFrame(columns=["child_first_name", "child_last_name"])
+        df = pd.DataFrame(columns=pd.Index(["child_first_name", "child_last_name"]))
         result = preprocess.filter_columns(df, ["child_first_name"])
         assert result.empty
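The `pd.Index(...)` wrapper is presumably there to satisfy the stricter typing of the `columns` parameter in pandas-stubs; at runtime both spellings build the same empty frame:

    import pandas as pd

    a = pd.DataFrame(columns=["child_first_name", "child_last_name"])
    b = pd.DataFrame(columns=pd.Index(["child_first_name", "child_last_name"]))

    # Identical empty frames; the wrapper only helps the type checker.
    assert a.equals(b) and a.empty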

@@ -164,7 +165,10 @@ def test_order_of_columns_is_preserved(self):
         required = ["dob", "child_first_name"]
         result = preprocess.filter_columns(df, required)
 
-        assert list(result.columns) == ["child_first_name", "dob"] or list(result.columns) == required
+        assert (
+            list(result.columns) == ["child_first_name", "dob"]
+            or list(result.columns) == required
+        )
         # Either column order can appear depending on implementation; both are acceptable

def test_ignores_required_columns_not_in_df(self):
Expand All @@ -176,6 +180,7 @@ def test_ignores_required_columns_not_in_df(self):
assert "child_first_name" in result.columns
assert "missing_column" not in result.columns


@pytest.mark.unit
class TestReadInput:
"""Unit tests for read_input function."""
@@ -785,5 +790,3 @@ def test_build_result_no_warning_for_unique_client_ids(
         # Should have NO warnings about duplicates
         duplicate_warnings = [w for w in result.warnings if "Duplicate client ID" in w]
         assert len(duplicate_warnings) == 0
-
-