diff --git a/Dockerfile b/Dockerfile index f2617a2..2258ffc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -19,3 +19,8 @@ COPY . /app RUN uv sync --frozen CMD [ "python", "smart_meter_analysis/foo.py" ] + +# R installation +RUN apt-get update \ + && apt-get install -y r-base r-base-dev \ + && rm -rf /var/lib/apt/lists/* diff --git a/analysis/clustering/clustering_validation.py b/analysis/clustering/clustering_validation.py deleted file mode 100644 index c45cbd2..0000000 --- a/analysis/clustering/clustering_validation.py +++ /dev/null @@ -1,428 +0,0 @@ -""" -Validation utilities for clustering pipeline. - -Validates data quality at each stage of the DTW clustering analysis. Designed -to work with both enriched data (with census variables) and raw processed data -(without census variables) to support flexible pipeline configurations. -""" - -from __future__ import annotations - -import logging -from typing import Any - -import polars as pl - -logger = logging.getLogger(__name__) - - -class ClusteringDataValidator: - """Validates data at each stage of clustering pipeline.""" - - def __init__(self): - self.errors: list[str] = [] - self.warnings: list[str] = [] - - def reset(self): - """Reset errors and warnings for a new validation run.""" - self.errors = [] - self.warnings = [] - - def validate_enriched_data(self, df: pl.DataFrame) -> dict[str, Any]: - """ - Validate interval-level data before aggregation. - - Works with both enriched data (with census variables) and raw processed - data (without census variables). Geographic validation is skipped if - census columns are not present. - - Args: - df: Interval-level energy data - - Returns: - Validation results with status, errors, warnings, and statistics - """ - self.reset() - logger.info("Validating interval data...") - - # Required energy columns (must be present) - required_energy_cols = ["zip_code", "account_identifier", "datetime", "kwh"] - - # Required time columns (must be present) - required_time_cols = ["date", "hour", "weekday", "is_weekend"] - - # Geographic columns (optional - only validated if present) - geo_cols = ["block_group_geoid"] - - missing_energy = [c for c in required_energy_cols if c not in df.columns] - missing_time = [c for c in required_time_cols if c not in df.columns] - - if missing_energy: - self.errors.append(f"Missing energy columns: {missing_energy}") - if missing_time: - self.errors.append(f"Missing time columns: {missing_time}") - - # Check data completeness for critical columns - critical_cols = ["zip_code", "account_identifier", "datetime", "kwh"] - for col in critical_cols: - if col not in df.columns: - continue - null_count = df[col].null_count() - null_pct = (null_count / len(df)) * 100 - if null_pct > 5: - self.errors.append(f"{col}: {null_pct:.1f}% null values (>5%)") - elif null_pct > 0: - self.warnings.append(f"{col}: {null_pct:.1f}% null values") - - # Geographic coverage - only check if column exists - match_rate = None - if "block_group_geoid" in df.columns: - total_rows = len(df) - matched_rows = df.filter(pl.col("block_group_geoid").is_not_null()).height - match_rate = (matched_rows / total_rows) * 100 - - if match_rate < 90: - self.errors.append(f"Low geographic match rate: {match_rate:.1f}%") - elif match_rate < 95: - self.warnings.append(f"Geographic match rate below 95%: {match_rate:.1f}%") - else: - self.warnings.append("No geographic columns - running without census enrichment") - - # Check time features - if "hour" in df.columns: - hour_min, hour_max = df["hour"].min(), df["hour"].max() - 
if hour_min < 0 or hour_max > 23: - self.errors.append(f"Invalid hour values: {hour_min} to {hour_max}") - - if "weekday" in df.columns: - weekday_min, weekday_max = df["weekday"].min(), df["weekday"].max() - if weekday_min < 1 or weekday_max > 7: - self.errors.append(f"Invalid weekday values: {weekday_min} to {weekday_max}") - - # Check energy values - if "kwh" in df.columns: - kwh_stats = df.select([ - pl.col("kwh").min().alias("min"), - pl.col("kwh").max().alias("max"), - pl.col("kwh").mean().alias("mean"), - ]).to_dicts()[0] - - if kwh_stats["min"] is not None and kwh_stats["min"] < 0: - self.warnings.append(f"Negative kWh values: min={kwh_stats['min']:.4f}") - - if kwh_stats["max"] is not None and kwh_stats["max"] > 100: - self.warnings.append(f"Very high kWh values: max={kwh_stats['max']:.2f}") - - self._print_results("INTERVAL DATA VALIDATION") - - return { - "status": "PASS" if not self.errors else "FAIL", - "errors": self.errors, - "warnings": self.warnings, - "stats": { - "n_rows": len(df), - "n_accounts": df["account_identifier"].n_unique() if "account_identifier" in df.columns else None, - "n_zip4s": df["zip_code"].n_unique() if "zip_code" in df.columns else None, - "geographic_match_rate": match_rate, - "has_census_data": "block_group_geoid" in df.columns, - }, - } - - def validate_daily_profiles(self, df: pl.DataFrame, expected_intervals: int = 48) -> dict[str, Any]: - """ - Validate daily load profiles after aggregation. - - Ensures profiles have the expected 48-point structure and reasonable values. - - Args: - df: Daily profiles with 'profile' list column - expected_intervals: Expected intervals per profile (default: 48) - - Returns: - Validation results with status, errors, warnings, and statistics - """ - self.reset() - logger.info("Validating daily profiles...") - - # Check required columns - required_cols = ["zip_code", "date", "profile"] - missing = [c for c in required_cols if c not in df.columns] - if missing: - self.errors.append(f"Missing required columns: {missing}") - self._print_results("PROFILE VALIDATION") - return {"status": "FAIL", "errors": self.errors, "warnings": self.warnings} - - # Check profile completeness - if "num_intervals" in df.columns: - incomplete = df.filter( - (pl.col("num_intervals") < expected_intervals - 1) | (pl.col("num_intervals") > expected_intervals) - ) - if incomplete.height > 0: - pct = (incomplete.height / len(df)) * 100 - self.warnings.append(f"{incomplete.height} profiles ({pct:.1f}%) have incorrect interval count") - - # Check profile array lengths - profile_lengths = df.select(pl.col("profile").list.len().alias("len")).unique() - unique_lengths = profile_lengths["len"].to_list() - - if len(unique_lengths) > 1: - self.warnings.append(f"Inconsistent profile lengths: {unique_lengths}") - elif unique_lengths[0] != expected_intervals: - self.errors.append(f"Profile length {unique_lengths[0]} != expected {expected_intervals}") - - # Check for null profiles - null_profiles = df.filter(pl.col("profile").is_null()).height - if null_profiles > 0: - self.errors.append(f"{null_profiles} null profiles found") - - # Check date coverage per ZIP+4 - dates_per_zip = df.group_by("zip_code").agg([ - pl.col("date").n_unique().alias("n_dates"), - pl.col("date").min().alias("min_date"), - pl.col("date").max().alias("max_date"), - ]) - - # Flag ZIP+4s with very few days - sparse_zips = dates_per_zip.filter(pl.col("n_dates") < 5) - if sparse_zips.height > 0: - self.warnings.append(f"{sparse_zips.height} ZIP+4s have fewer than 5 days of data") - - # 
Check for reasonable values in profiles - if df.height > 0: - value_stats = ( - df.select(pl.col("profile").list.explode().alias("value")) - .select([ - pl.col("value").min().alias("min"), - pl.col("value").max().alias("max"), - ]) - .to_dicts()[0] - ) - - if value_stats["min"] is not None and value_stats["min"] < 0: - self.warnings.append(f"Negative values in profiles: min={value_stats['min']:.2f}") - - if value_stats["max"] is not None and value_stats["max"] > 10000: - self.warnings.append(f"Very high values in profiles: max={value_stats['max']:.2f}") - - self._print_results("DAILY PROFILES VALIDATION") - - return { - "status": "PASS" if not self.errors else "FAIL", - "errors": self.errors, - "warnings": self.warnings, - "stats": { - "n_profiles": len(df), - "n_zip4s": df["zip_code"].n_unique(), - "n_dates": df["date"].n_unique(), - "profile_length": unique_lengths[0] if unique_lengths else None, - }, - } - - def validate_demographics(self, df: pl.DataFrame, required_zip4s: set[str] | None = None) -> dict[str, Any]: - """ - Validate census demographics data. - - Args: - df: Demographics data with ZIP+4 codes - required_zip4s: Set of ZIP+4 codes that must have demographics - - Returns: - Validation results with status, errors, warnings, and statistics - """ - self.reset() - logger.info("Validating demographics data...") - - # Handle empty demographics (valid when running without census data) - if df is None or df.height == 0: - self.warnings.append("No demographics data - running without census enrichment") - self._print_results("DEMOGRAPHICS VALIDATION") - return { - "status": "PASS", - "errors": [], - "warnings": self.warnings, - "stats": {"n_zip4s": 0, "n_demo_vars": 0}, - } - - # Check required columns - required_cols = ["zip_code"] - missing = [c for c in required_cols if c not in df.columns] - if missing: - self.errors.append(f"Missing required columns: {missing}") - - # Check GEOID format if present - if "block_group_geoid" in df.columns: - non_null_geoids = df.filter(pl.col("block_group_geoid").is_not_null()) - - if non_null_geoids.height > 0: - geoid_lengths = non_null_geoids["block_group_geoid"].str.len_chars().unique().to_list() - if geoid_lengths and 12 not in geoid_lengths: - self.errors.append(f"Block Group GEOIDs should be 12 digits, got: {geoid_lengths}") - - # Check Illinois state FIPS - non_il = non_null_geoids.filter(~pl.col("block_group_geoid").str.starts_with("17")).height - if non_il > 0: - self.warnings.append(f"{non_il} GEOIDs don't start with '17' (Illinois)") - - # Check coverage of required ZIP+4s - if required_zip4s: - present_zips = set(df["zip_code"].to_list()) - missing_zips = required_zip4s - present_zips - if missing_zips: - self.warnings.append(f"{len(missing_zips)} required ZIP+4s missing demographics") - - # Count demographic columns - demo_cols = [c for c in df.columns if c not in ["zip_code", "block_group_geoid", "Urban_Rural_Classification"]] - - # Check for excessive nulls in demographic columns - high_null_cols = [] - for col in demo_cols: - null_pct = (df[col].null_count() / len(df)) * 100 - if null_pct > 50: - high_null_cols.append(f"{col} ({null_pct:.1f}%)") - - if high_null_cols: - self.warnings.append(f"{len(high_null_cols)} columns with >50% nulls") - - self._print_results("DEMOGRAPHICS VALIDATION") - - return { - "status": "PASS" if not self.errors else "FAIL", - "errors": self.errors, - "warnings": self.warnings, - "stats": { - "n_zip4s": len(df), - "n_demo_vars": len(demo_cols), - "high_null_vars": len(high_null_cols), - }, - } - - def 
_print_results(self, title: str): - """Print validation results summary.""" - print(f"\n{'=' * 80}") - print(f"{title}") - print("=" * 80) - - if self.errors: - print(f"\n❌ FAILED with {len(self.errors)} error(s):") - for err in self.errors: - print(f" - {err}") - else: - print("\n✅ PASSED all critical checks") - - if self.warnings: - print(f"\n⚠️ {len(self.warnings)} warning(s):") - for warn in self.warnings: - print(f" - {warn}") - - -def validate_interval_completeness( - df: pl.DataFrame, account_col: str = "account_identifier", date_col: str = "date", expected_intervals: int = 48 -) -> pl.DataFrame: - """ - Check interval completeness for each account-date combination. - - Args: - df: Interval-level data - account_col: Account identifier column name - date_col: Date column name - expected_intervals: Expected intervals per day (default: 48) - - Returns: - DataFrame with completeness statistics per account-date - """ - completeness = ( - df.group_by([account_col, date_col]) - .agg([ - pl.len().alias("n_intervals"), - pl.col("kwh").is_null().sum().alias("n_null_kwh"), - ]) - .with_columns([ - (pl.col("n_intervals") == expected_intervals).alias("is_complete"), - ((pl.col("n_intervals") - pl.col("n_null_kwh")) / expected_intervals * 100).alias("completeness_pct"), - ]) - ) - - return completeness - - -def check_for_duplicates(df: pl.DataFrame, key_cols: list[str]) -> tuple[int, pl.DataFrame | None]: - """ - Check for duplicate records based on key columns. - - Args: - df: DataFrame to check - key_cols: Columns that should be unique together - - Returns: - Tuple of (duplicate_count, duplicate_records_df) - """ - duplicates = df.group_by(key_cols).agg(pl.len().alias("count")).filter(pl.col("count") > 1) - - n_dups = duplicates.height - - if n_dups > 0: - dup_keys = duplicates.select(key_cols) - dup_records = df.join(dup_keys, on=key_cols, how="inner") - return n_dups, dup_records - - return 0, None - - -def validate_time_series_array( - profiles: list[list[float]], expected_length: int = 48, max_profiles: int = 5000 -) -> dict[str, Any]: - """ - Validate time series arrays for clustering. 
- - Args: - profiles: List of time series profiles - expected_length: Expected length of each profile - max_profiles: Maximum profiles to validate (samples if exceeded) - - Returns: - Validation results dictionary - """ - import numpy as np - - issues = [] - warnings = [] - - if len(profiles) > max_profiles: - logger.info(f"Sampling {max_profiles} of {len(profiles)} profiles for validation") - profiles = profiles[:max_profiles] - - # Check lengths - lengths = [len(p) for p in profiles] - unique_lengths = set(lengths) - - if len(unique_lengths) > 1: - issues.append(f"Inconsistent lengths: {unique_lengths}") - elif list(unique_lengths)[0] != expected_length: - issues.append(f"Expected length {expected_length}, got {list(unique_lengths)[0]}") - - # Check for NaN/inf values - arr = np.array(profiles, dtype=np.float32) - if np.any(np.isnan(arr)): - issues.append("NaN values detected in profiles") - if np.any(np.isinf(arr)): - issues.append("Infinite values detected in profiles") - - # Check value ranges - if np.any(arr < 0): - warnings.append(f"Negative values detected: min={arr.min():.2f}") - - if arr.max() > 10000: - warnings.append(f"Very high values detected: max={arr.max():.2f}") - - return { - "status": "PASS" if not issues else "FAIL", - "issues": issues, - "warnings": warnings, - "stats": { - "n_profiles": len(profiles), - "profile_length": list(unique_lengths)[0] if len(unique_lengths) == 1 else None, - "min_value": float(arr.min()), - "max_value": float(arr.max()), - "mean_value": float(arr.mean()), - }, - } diff --git a/analysis/clustering/dtw_clustering.py b/analysis/clustering/dtw_clustering.py deleted file mode 100644 index fcc9ba2..0000000 --- a/analysis/clustering/dtw_clustering.py +++ /dev/null @@ -1,661 +0,0 @@ -""" -Phase 2: DTW K-Means Clustering for Load Profile Analysis. - -Clusters daily electricity usage profiles using Dynamic Time Warping (DTW) -distance metric to identify consumption patterns. - -Performance Optimization: - - K evaluation uses a subsample (default 2000 profiles) for speed - - Final clustering runs on full dataset with optimal k - - Configurable max_iter and n_init for speed vs quality tradeoff - -Pipeline: - 1. Load daily profiles from Phase 1 - 2. Normalize profiles (optional) - 3. Evaluate k values on subsample to find optimal k - 4. Run final clustering on full dataset with optimal k - 5. Output assignments, centroids, and visualizations - -Usage: - # Standard run (evaluates k=3-6, uses subsample for evaluation) - python dtw_clustering.py \\ - --input data/clustering/sampled_profiles.parquet \\ - --output-dir data/clustering/results \\ - --k-range 3 6 \\ - --find-optimal-k \\ - --normalize - - # Fast validation run - python dtw_clustering.py \\ - --input data/clustering/sampled_profiles.parquet \\ - --output-dir data/clustering/results \\ - --k-range 3 4 \\ - --max-eval-samples 1000 \\ - --eval-max-iter 5 -""" - -from __future__ import annotations - -import argparse -import json -import logging -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import polars as pl - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", -) -logger = logging.getLogger(__name__) - - -def load_profiles(path: Path) -> tuple[np.ndarray, pl.DataFrame]: - """ - Load profiles from parquet file. 
- - Args: - path: Path to sampled_profiles.parquet - - Returns: - Tuple of (profile_array, metadata_df) - """ - logger.info(f"Loading profiles from {path}") - - df = pl.read_parquet(path) - - # Extract profiles as numpy array - profiles = np.array(df["profile"].to_list(), dtype=np.float64) - - logger.info(f" Loaded {len(profiles):,} profiles with {profiles.shape[1]} time points each") - logger.info(f" Data shape: {profiles.shape}") - logger.info(f" Data range: [{profiles.min():.2f}, {profiles.max():.2f}]") - - return profiles, df - - -def normalize_profiles( - X: np.ndarray, - method: str = "zscore", -) -> np.ndarray: - """ - Normalize profiles for clustering. - - Args: - X: Profile array of shape (n_samples, n_timepoints) - method: Normalization method ('zscore', 'minmax', 'none') - - Returns: - Normalized array - """ - if method == "none": - return X - - logger.info(f"Normalizing profiles using {method} method...") - - if method == "zscore": - # Per-profile z-score normalization - means = X.mean(axis=1, keepdims=True) - stds = X.std(axis=1, keepdims=True) - stds[stds == 0] = 1 # Avoid division by zero - X_norm = (X - means) / stds - elif method == "minmax": - # Per-profile min-max normalization - mins = X.min(axis=1, keepdims=True) - maxs = X.max(axis=1, keepdims=True) - ranges = maxs - mins - ranges[ranges == 0] = 1 - X_norm = (X - mins) / ranges - else: - raise ValueError(f"Unknown normalization method: {method}") - - logger.info(f" Normalized data range: [{X_norm.min():.2f}, {X_norm.max():.2f}]") - - return X_norm - - -def evaluate_clustering( - X: np.ndarray, - k_range: range, - max_iter: int = 10, - n_init: int = 3, - random_state: int = 42, - max_eval_samples: int = 2000, -) -> dict: - """ - Evaluate clustering for different values of k using a subsample. - - Uses a random subsample for evaluation to reduce runtime while still - providing reliable estimates of optimal k. 
- - Args: - X: Profile array of shape (n_samples, n_timepoints) - k_range: Range of k values to test - max_iter: Maximum iterations per k-means run - n_init: Number of random initializations - random_state: Random seed for reproducibility - max_eval_samples: Maximum profiles to use for evaluation - - Returns: - Dictionary with k_values, inertia, and silhouette scores - """ - from sklearn.metrics import silhouette_score - from tslearn.clustering import TimeSeriesKMeans - - logger.info(f"Evaluating clustering for k in {list(k_range)}...") - - # Subsample for evaluation if dataset is large - if X.shape[0] > max_eval_samples: - rng = np.random.default_rng(random_state) - idx = rng.choice(X.shape[0], size=max_eval_samples, replace=False) - X_eval = X[idx] - logger.info(f" Using subsample of {max_eval_samples:,} profiles for k evaluation") - logger.info(f" (Full dataset: {X.shape[0]:,} profiles will be used for final clustering)") - else: - X_eval = X - logger.info(f" Using all {X_eval.shape[0]:,} profiles for evaluation") - - results = { - "k_values": [], - "inertia": [], - "silhouette": [], - } - - # Reshape for tslearn (n_samples, n_timepoints, n_features) - X_reshaped = X_eval.reshape(X_eval.shape[0], X_eval.shape[1], 1) - - for k in k_range: - logger.info(f"\n Testing k={k}...") - - model = TimeSeriesKMeans( - n_clusters=k, - metric="dtw", - max_iter=max_iter, - n_init=n_init, - random_state=random_state, - n_jobs=-1, - verbose=0, - ) - - labels = model.fit_predict(X_reshaped) - - # Use Euclidean distance for silhouette (faster, still informative) - sil_score = silhouette_score(X_eval, labels, metric="euclidean") - - results["k_values"].append(k) - results["inertia"].append(float(model.inertia_)) - results["silhouette"].append(float(sil_score)) - - logger.info(f" Inertia: {model.inertia_:.2f}") - logger.info(f" Silhouette: {sil_score:.3f}") - - return results - - -def find_optimal_k(eval_results: dict) -> int: - """ - Find optimal k based on silhouette score. - - Args: - eval_results: Results from evaluate_clustering - - Returns: - Optimal k value - """ - k_values = eval_results["k_values"] - silhouettes = eval_results["silhouette"] - - best_idx = np.argmax(silhouettes) - best_k = k_values[best_idx] - - logger.info(f"\nOptimal k={best_k} (silhouette={silhouettes[best_idx]:.3f})") - - return best_k - - -def run_final_clustering( - X: np.ndarray, - k: int, - max_iter: int = 10, - n_init: int = 3, - random_state: int = 42, -) -> tuple[np.ndarray, np.ndarray, float]: - """ - Run final clustering on full dataset with chosen k. 
- - Args: - X: Full profile array - k: Number of clusters - max_iter: Maximum iterations - n_init: Number of random initializations - random_state: Random seed - - Returns: - Tuple of (labels, centroids, inertia) - """ - from tslearn.clustering import TimeSeriesKMeans - - logger.info(f"\nRunning final clustering with k={k} on {X.shape[0]:,} profiles...") - - X_reshaped = X.reshape(X.shape[0], X.shape[1], 1) - - model = TimeSeriesKMeans( - n_clusters=k, - metric="dtw", - max_iter=max_iter, - n_init=n_init, - random_state=random_state, - n_jobs=-1, - verbose=1, - ) - - labels = model.fit_predict(X_reshaped) - centroids = model.cluster_centers_.squeeze() # Remove extra dimension - - logger.info(f" Final inertia: {model.inertia_:.2f}") - - # Log cluster distribution - unique, counts = np.unique(labels, return_counts=True) - for cluster, count in zip(unique, counts): - pct = count / len(labels) * 100 - logger.info(f" Cluster {cluster}: {count:,} profiles ({pct:.1f}%)") - - return labels, centroids, float(model.inertia_) - - -def plot_centroids( - centroids: np.ndarray, - output_path: Path, - title: str = "Cluster Centroids (Average Load Profiles)", -) -> None: - """ - Plot cluster centroids showing typical daily patterns. - - Args: - centroids: Centroid array of shape (k, n_timepoints) - output_path: Path to save plot - title: Plot title - """ - fig, ax = plt.subplots(figsize=(12, 6)) - - hours = np.arange(0, 24, 0.5) # 48 half-hour intervals - - colors = plt.cm.tab10(np.linspace(0, 1, len(centroids))) - - for i, (centroid, color) in enumerate(zip(centroids, colors)): - ax.plot(hours, centroid, label=f"Cluster {i}", color=color, linewidth=2) - - ax.set_xlabel("Hour of Day", fontsize=12) - ax.set_ylabel("Normalized Energy Usage", fontsize=12) - ax.set_title(title, fontsize=14) - ax.legend(loc="upper right") - ax.set_xticks(range(0, 25, 3)) - ax.grid(True, alpha=0.3) - - plt.tight_layout() - plt.savefig(output_path, dpi=150) - plt.close() - - logger.info(f" Saved centroid plot: {output_path}") - - -def plot_cluster_samples( - X: np.ndarray, - labels: np.ndarray, - centroids: np.ndarray, - output_path: Path, - n_samples: int = 50, -) -> None: - """ - Plot sample profiles from each cluster with centroid overlay. 
- - Args: - X: Profile array - labels: Cluster assignments - centroids: Cluster centroids - output_path: Path to save plot - n_samples: Number of sample profiles per cluster - """ - k = len(centroids) - fig, axes = plt.subplots(1, k, figsize=(5 * k, 4), sharey=True) - - if k == 1: - axes = [axes] - - hours = np.arange(0, 24, 0.5) - - for i, ax in enumerate(axes): - cluster_mask = labels == i - cluster_profiles = X[cluster_mask] - - # Sample profiles to plot - n_to_plot = min(n_samples, len(cluster_profiles)) - if n_to_plot > 0: - idx = np.random.choice(len(cluster_profiles), n_to_plot, replace=False) - for profile in cluster_profiles[idx]: - ax.plot(hours, profile, alpha=0.2, color="gray", linewidth=0.5) - - # Plot centroid - ax.plot(hours, centroids[i], color="red", linewidth=2, label="Centroid") - - ax.set_title(f"Cluster {i} (n={cluster_mask.sum():,})") - ax.set_xlabel("Hour of Day") - if i == 0: - ax.set_ylabel("Normalized Usage") - ax.set_xticks(range(0, 25, 6)) - ax.grid(True, alpha=0.3) - ax.legend(loc="upper right") - - plt.tight_layout() - plt.savefig(output_path, dpi=150) - plt.close() - - logger.info(f" Saved cluster samples plot: {output_path}") - - -def plot_elbow_curve( - eval_results: dict, - output_path: Path, -) -> None: - """ - Plot elbow curve showing inertia and silhouette vs k. - - Args: - eval_results: Results from evaluate_clustering - output_path: Path to save plot - """ - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) - - k_values = eval_results["k_values"] - - # Inertia plot - ax1.plot(k_values, eval_results["inertia"], "bo-", linewidth=2, markersize=8) - ax1.set_xlabel("Number of Clusters (k)") - ax1.set_ylabel("Inertia") - ax1.set_title("Elbow Curve") - ax1.grid(True, alpha=0.3) - - # Silhouette plot - ax2.plot(k_values, eval_results["silhouette"], "go-", linewidth=2, markersize=8) - ax2.set_xlabel("Number of Clusters (k)") - ax2.set_ylabel("Silhouette Score") - ax2.set_title("Silhouette Score vs k") - ax2.grid(True, alpha=0.3) - - # Mark best k - best_idx = np.argmax(eval_results["silhouette"]) - ax2.axvline(x=k_values[best_idx], color="red", linestyle="--", alpha=0.7) - ax2.annotate( - f"Best k={k_values[best_idx]}", - xy=(k_values[best_idx], eval_results["silhouette"][best_idx]), - xytext=(10, 10), - textcoords="offset points", - fontsize=10, - color="red", - ) - - plt.tight_layout() - plt.savefig(output_path, dpi=150) - plt.close() - - logger.info(f" Saved elbow curve: {output_path}") - - -def save_results( - df: pl.DataFrame, - labels: np.ndarray, - centroids: np.ndarray, - eval_results: dict | None, - metadata: dict, - output_dir: Path, -) -> None: - """ - Save all clustering results to output directory. 
- - Args: - df: Original profile DataFrame with metadata - labels: Cluster assignments - centroids: Cluster centroids - eval_results: K evaluation results (if any) - metadata: Clustering metadata - output_dir: Output directory - """ - output_dir.mkdir(parents=True, exist_ok=True) - - # Save cluster assignments - assignments = df.select(["zip_code", "date", "is_weekend", "weekday"]).with_columns(pl.Series("cluster", labels)) - assignments_path = output_dir / "cluster_assignments.parquet" - assignments.write_parquet(assignments_path) - logger.info(f" Saved assignments: {assignments_path}") - - # Save centroids as parquet - centroids_df = pl.DataFrame({ - "cluster": list(range(len(centroids))), - "centroid": [c.tolist() for c in centroids], - }) - centroids_path = output_dir / "cluster_centroids.parquet" - centroids_df.write_parquet(centroids_path) - logger.info(f" Saved centroids: {centroids_path}") - - # Save k evaluation results - if eval_results: - eval_path = output_dir / "k_evaluation.json" - with open(eval_path, "w") as f: - json.dump(eval_results, f, indent=2) - logger.info(f" Saved k evaluation: {eval_path}") - - # Save metadata - metadata_path = output_dir / "clustering_metadata.json" - with open(metadata_path, "w") as f: - json.dump(metadata, f, indent=2) - logger.info(f" Saved metadata: {metadata_path}") - - -def main() -> None: - parser = argparse.ArgumentParser( - description="DTW K-Means Clustering for Load Profiles", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Standard run with k evaluation - python dtw_clustering.py \\ - --input data/clustering/sampled_profiles.parquet \\ - --output-dir data/clustering/results \\ - --k-range 3 6 --find-optimal-k --normalize - - # Fast run for testing - python dtw_clustering.py \\ - --input data/clustering/sampled_profiles.parquet \\ - --output-dir data/clustering/results \\ - --k-range 3 4 --max-eval-samples 1000 --eval-max-iter 5 - - # Fixed k (no evaluation) - python dtw_clustering.py \\ - --input data/clustering/sampled_profiles.parquet \\ - --output-dir data/clustering/results \\ - --k 4 --normalize - """, - ) - - parser.add_argument( - "--input", - type=Path, - required=True, - help="Path to sampled_profiles.parquet", - ) - parser.add_argument( - "--output-dir", - type=Path, - default=Path("data/clustering/results"), - help="Output directory for results", - ) - - # K selection - k_group = parser.add_argument_group("Cluster Selection") - k_group.add_argument( - "--k", - type=int, - default=None, - help="Fixed number of clusters (skip evaluation)", - ) - k_group.add_argument( - "--k-range", - type=int, - nargs=2, - metavar=("MIN", "MAX"), - default=[3, 6], - help="Range of k values to evaluate (default: 3 6)", - ) - k_group.add_argument( - "--find-optimal-k", - action="store_true", - help="Evaluate k range and use optimal k", - ) - - # Performance tuning - perf_group = parser.add_argument_group("Performance Tuning") - perf_group.add_argument( - "--max-eval-samples", - type=int, - default=2000, - help="Max profiles for k evaluation subsample (default: 2000)", - ) - perf_group.add_argument( - "--eval-max-iter", - type=int, - default=10, - help="Max iterations for k evaluation runs (default: 10)", - ) - perf_group.add_argument( - "--eval-n-init", - type=int, - default=3, - help="Number of initializations for k evaluation (default: 3)", - ) - perf_group.add_argument( - "--final-max-iter", - type=int, - default=10, - help="Max iterations for final clustering (default: 10)", - ) - perf_group.add_argument( 
- "--final-n-init", - type=int, - default=3, - help="Number of initializations for final clustering (default: 3)", - ) - - # Preprocessing - parser.add_argument( - "--normalize", - action="store_true", - help="Apply z-score normalization to profiles", - ) - parser.add_argument( - "--normalize-method", - choices=["zscore", "minmax", "none"], - default="zscore", - help="Normalization method (default: zscore)", - ) - - parser.add_argument( - "--random-state", - type=int, - default=42, - help="Random seed for reproducibility", - ) - - args = parser.parse_args() - - print("=" * 70) - print("PHASE 2: DTW K-MEANS CLUSTERING") - print("=" * 70) - - # Load profiles - X, df = load_profiles(args.input) - - # Normalize if requested - if args.normalize: - X = normalize_profiles(X, method=args.normalize_method) - - # Determine k - eval_results = None - - if args.k is not None: - # Fixed k - k = args.k - logger.info(f"\nUsing fixed k={k}") - elif args.find_optimal_k: - # Evaluate k range on subsample - k_range = range(args.k_range[0], args.k_range[1] + 1) - - eval_results = evaluate_clustering( - X, - k_range=k_range, - max_iter=args.eval_max_iter, - n_init=args.eval_n_init, - random_state=args.random_state, - max_eval_samples=args.max_eval_samples, - ) - - # Save elbow curve - args.output_dir.mkdir(parents=True, exist_ok=True) - plot_elbow_curve(eval_results, args.output_dir / "elbow_curve.png") - - k = find_optimal_k(eval_results) - else: - # Default to min of k_range - k = args.k_range[0] - logger.info(f"\nUsing default k={k}") - - # Run final clustering on full dataset - labels, centroids, inertia = run_final_clustering( - X, - k=k, - max_iter=args.final_max_iter, - n_init=args.final_n_init, - random_state=args.random_state, - ) - - # Create visualizations - logger.info("\nGenerating visualizations...") - args.output_dir.mkdir(parents=True, exist_ok=True) - - plot_centroids(centroids, args.output_dir / "cluster_centroids.png") - plot_cluster_samples(X, labels, centroids, args.output_dir / "cluster_samples.png") - - # Save results - logger.info("\nSaving results...") - - metadata = { - "k": k, - "n_profiles": len(X), - "n_timepoints": X.shape[1], - "normalized": args.normalize, - "normalize_method": args.normalize_method if args.normalize else None, - "max_iter": args.final_max_iter, - "n_init": args.final_n_init, - "random_state": args.random_state, - "inertia": inertia, - "eval_max_samples": args.max_eval_samples if args.find_optimal_k else None, - } - - save_results(df, labels, centroids, eval_results, metadata, args.output_dir) - - # Summary - print("\n" + "=" * 70) - print("CLUSTERING COMPLETE") - print("=" * 70) - print(f"\nResults saved to: {args.output_dir}") - print(f" • {len(X):,} profiles clustered into {k} groups") - print(f" • Inertia: {inertia:.2f}") - if eval_results: - best_sil = max(eval_results["silhouette"]) - print(f" • Best silhouette score: {best_sil:.3f}") - print("=" * 70) - - -if __name__ == "__main__": - main() diff --git a/analysis/clustering/euclidean_clustering.py b/analysis/clustering/euclidean_clustering.py deleted file mode 100644 index 9ce1e2a..0000000 --- a/analysis/clustering/euclidean_clustering.py +++ /dev/null @@ -1,823 +0,0 @@ -#!/usr/bin/env python3 -""" -Phase 2: K-Means Clustering for Load Profile Analysis. - -Clusters daily electricity usage profiles using standard Euclidean distance -to identify distinct consumption patterns. - -Pipeline: - 1. Load daily profiles from Phase 1 - 2. Normalize profiles (optional) - 3. 
Evaluate k values to find optimal k (via silhouette score on a sample) - 4. Run final clustering with optimal k - 5. Output assignments, centroids, and visualizations - -Usage: - # Standard run (evaluates k=3-6 using silhouette on up to 10k samples) - python euclidean_clustering_fixed.py \\ - --input data/clustering/sampled_profiles.parquet \\ - --output-dir data/clustering/results \\ - --k-range 3 6 \\ - --find-optimal-k \\ - --normalize \\ - --silhouette-sample-size 10000 - - # Fixed k (no evaluation) - python euclidean_clustering.py \\ - --input data/clustering/sampled_profiles.parquet \\ - --output-dir data/clustering/results \\ - --k 4 --normalize -""" - -from __future__ import annotations - -import argparse -import json -import logging -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import polars as pl -from sklearn.cluster import KMeans -from sklearn.metrics import silhouette_score - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", -) -logger = logging.getLogger(__name__) - -DEFAULT_NORMALIZATION: str = "minmax" -DEFAULT_NORMALIZE: bool = True - - -def load_profiles(path: Path) -> tuple[np.ndarray, pl.DataFrame]: - """ - Load profiles from parquet file. - - Args: - path: Path to sampled_profiles.parquet - - Returns: - Tuple of (profile_array, metadata_df) - """ - logger.info(f"Loading profiles from {path}") - - df = pl.read_parquet(path) - - # Extract profiles as numpy array - profiles = np.array(df["profile"].to_list(), dtype=np.float64) - - logger.info( - " Loaded %s profiles with %s time points each", - f"{len(profiles):,}", - profiles.shape[1], - ) - logger.info(" Data shape: %s", (profiles.shape[0], profiles.shape[1])) - logger.info(" Data range: [%.2f, %.2f]", profiles.min(), profiles.max()) - - return profiles, df - - -def normalize_profiles( - df: pl.DataFrame, - method: str = "minmax", - profile_col: str = "profile", - out_col: str | None = None, -) -> pl.DataFrame: - """ - Normalize per-household-day profiles for clustering. - - Parameters - ---------- - df : pl.DataFrame - Must contain a list column with the profile values (e.g. 48-dim vector). - method : {"none", "zscore", "minmax"} - - "none": return df unchanged - - "zscore": per-profile z-score: (x - mean) / std - - "minmax": per-profile min-max: (x - min) / (max - min) - profile_col : str - Name of the list column holding the raw profile. - out_col : str | None - If provided, write normalized profile to this column; otherwise overwrite - `profile_col` in-place. - - Notes - ----- - - Normalization is done per profile (per row), not globally. - - For degenerate profiles where max == min, we fall back to all zeros. 
- """ - - if method == "none": - return df - - if profile_col not in df.columns: - raise ValueError(f"normalize_profiles: column '{profile_col}' not found in DataFrame") - - target_col = out_col or profile_col - - expr = pl.col(profile_col) - - if method == "zscore": - mean_expr = expr.list.mean() - std_expr = expr.list.std(ddof=0) - - normalized = (expr - mean_expr) / std_expr - - # If std == 0 (flat profile), fall back to zeros - normalized = pl.when(std_expr != 0).then(normalized).otherwise(expr * 0.0) - - elif method == "minmax": - min_expr = expr.list.min() - max_expr = expr.list.max() - range_expr = max_expr - min_expr - - normalized = (expr - min_expr) / range_expr - - # If range == 0 (flat profile), fall back to zeros - normalized = pl.when(range_expr != 0).then(normalized).otherwise(expr * 0.0) - - else: - raise ValueError(f"Unknown normalization method: {method!r}") - - return df.with_columns(normalized.alias(target_col)) - - -def evaluate_clustering( - X: np.ndarray, - k_range: range, - n_init: int = 10, - random_state: int = 42, - silhouette_sample_size: int | None = 10_000, -) -> dict: - """ - Evaluate clustering for different values of k. - - Uses inertia on the full dataset and silhouette score computed on a - subsample (to avoid O(n^2) cost when n is large). - - Args: - X: Profile array of shape (n_samples, n_timepoints) - k_range: Range of k values to test - n_init: Number of random initializations - random_state: Random seed for reproducibility - silhouette_sample_size: Max number of samples for silhouette. - If None, uses full dataset (NOT recommended for very large n). - - Returns: - Dictionary with k_values, inertia, and silhouette scores - """ - n_samples = X.shape[0] - logger.info("Evaluating clustering for k in %s...", list(k_range)) - logger.info(" Dataset size: %s profiles", f"{n_samples:,}") - - if silhouette_sample_size is None: - logger.info(" Silhouette: using FULL dataset (may be very slow).") - elif n_samples > silhouette_sample_size: - logger.info( - " Silhouette: using a random subsample of %s profiles.", - f"{silhouette_sample_size:,}", - ) - else: - logger.info( - " Silhouette: using all %s profiles (<= sample size).", - f"{n_samples:,}", - ) - - results = { - "k_values": [], - "inertia": [], - "silhouette": [], - } - - for k in k_range: - logger.info("") - logger.info(" Testing k=%d...", k) - - model = KMeans( - n_clusters=k, - n_init=n_init, - random_state=random_state, - ) - - labels = model.fit_predict(X) - - # Inertia on full data - inertia = float(model.inertia_) - - # Silhouette on sample (or full data if silhouette_sample_size is None) - sil_kwargs: dict = {"metric": "euclidean"} - if silhouette_sample_size is not None and n_samples > silhouette_sample_size: - sil_kwargs["sample_size"] = silhouette_sample_size - sil_kwargs["random_state"] = random_state - - sil_score = silhouette_score(X, labels, **sil_kwargs) - - results["k_values"].append(k) - results["inertia"].append(inertia) - results["silhouette"].append(float(sil_score)) - - logger.info(" Inertia: %s", f"{inertia:,.2f}") - logger.info(" Silhouette: %.3f", sil_score) - - return results - - -def find_optimal_k(eval_results: dict) -> int: - """ - Find optimal k based on silhouette score. 
- - Args: - eval_results: Results from evaluate_clustering - - Returns: - Optimal k value - """ - k_values = eval_results["k_values"] - silhouettes = eval_results["silhouette"] - - best_idx = int(np.argmax(silhouettes)) - best_k = int(k_values[best_idx]) - - logger.info("") - logger.info( - "Optimal k=%d (silhouette=%.3f)", - best_k, - silhouettes[best_idx], - ) - - return best_k - - -def run_clustering( - X: np.ndarray, - k: int, - n_init: int = 10, - random_state: int = 42, -) -> tuple[np.ndarray, np.ndarray, float]: - """ - Run k-means clustering. - - Args: - X: Profile array - k: Number of clusters - n_init: Number of random initializations - random_state: Random seed - - Returns: - Tuple of (labels, centroids, inertia) - """ - logger.info("") - logger.info( - "Running k-means with k=%d on %s profiles...", - k, - f"{X.shape[0]:,}", - ) - - model = KMeans( - n_clusters=k, - n_init=n_init, - random_state=random_state, - ) - - labels = model.fit_predict(X) - centroids = model.cluster_centers_ - inertia = float(model.inertia_) - - logger.info(" Inertia: %s", f"{inertia:,.2f}") - - # Log cluster distribution - unique, counts = np.unique(labels, return_counts=True) - for cluster, count in zip(unique, counts): - pct = count / len(labels) * 100 - logger.info( - " Cluster %d: %s profiles (%.1f%%)", - cluster, - f"{count:,}", - pct, - ) - - return labels, centroids, inertia - - -def plot_centroids( - centroids: np.ndarray, - output_path: Path, -) -> None: - """ - Plot cluster centroids (average load profiles). - - Args: - centroids: Array of shape (k, n_timepoints) - output_path: Path to save plot - """ - k = len(centroids) - n_timepoints = centroids.shape[1] - - # Create hour labels (assuming 48 half-hourly intervals) - if n_timepoints == 48: - hours = np.arange(0.5, 24.5, 0.5) - xlabel = "Hour of Day" - elif n_timepoints == 24: - hours = np.arange(1, 25) - xlabel = "Hour of Day" - else: - hours = np.arange(n_timepoints) - xlabel = "Time Interval" - - fig, ax = plt.subplots(figsize=(12, 6)) - - colors = plt.cm.tab10(np.linspace(0, 1, k)) - - for i, (centroid, color) in enumerate(zip(centroids, colors)): - ax.plot(hours, centroid, label=f"Cluster {i}", color=color, linewidth=2) - - ax.set_xlabel(xlabel, fontsize=12) - ax.set_ylabel("Normalized Usage", fontsize=12) - ax.set_title("Cluster Centroids (Average Load Profiles)", fontsize=14) - ax.legend(loc="upper right") - ax.grid(True, alpha=0.3) - - if n_timepoints == 48: - ax.set_xticks(range(0, 25, 4)) - ax.set_xlim(0, 24) - - plt.tight_layout() - plt.savefig(output_path, dpi=150) - plt.close() - - logger.info(" Saved centroids plot: %s", output_path) - - -def plot_cluster_samples( - X: np.ndarray, - labels: np.ndarray, - centroids: np.ndarray, - output_path: Path, - n_samples: int = 50, - random_state: int = 42, -) -> None: - """ - Plot sample profiles from each cluster with centroid overlay. 
- - Args: - X: Profile array - labels: Cluster assignments - centroids: Cluster centroids - output_path: Path to save plot - n_samples: Number of sample profiles per cluster - random_state: Random seed - """ - k = len(centroids) - n_timepoints = X.shape[1] - - # Create hour labels - if n_timepoints == 48: - hours = np.arange(0.5, 24.5, 0.5) - elif n_timepoints == 24: - hours = np.arange(1, 25) - else: - hours = np.arange(n_timepoints) - - fig, axes = plt.subplots(1, k, figsize=(5 * k, 4), sharey=True) - if k == 1: - axes = [axes] - - rng = np.random.default_rng(random_state) - colors = plt.cm.tab10(np.linspace(0, 1, k)) - - for i, (ax, color) in enumerate(zip(axes, colors)): - cluster_mask = labels == i - cluster_profiles = X[cluster_mask] - - n_available = len(cluster_profiles) - if n_available == 0: - ax.set_title(f"Cluster {i} (n=0)") - ax.grid(True, alpha=0.3) - continue - - n_plot = min(n_samples, n_available) - idx = rng.choice(n_available, size=n_plot, replace=False) - - # Plot samples with transparency - for profile in cluster_profiles[idx]: - ax.plot(hours, profile, color=color, alpha=0.1, linewidth=0.5) - - # Plot centroid - ax.plot(hours, centroids[i], color="black", linewidth=2, label="Centroid") - - ax.set_title(f"Cluster {i} (n={n_available:,})") - ax.set_xlabel("Hour") - if i == 0: - ax.set_ylabel("Normalized Usage") - ax.grid(True, alpha=0.3) - - if n_timepoints == 48: - ax.set_xticks(range(0, 25, 6)) - ax.set_xlim(0, 24) - - plt.tight_layout() - plt.savefig(output_path, dpi=150) - plt.close() - - logger.info(" Saved cluster samples plot: %s", output_path) - - -def plot_elbow_curve( - eval_results: dict, - output_path: Path, -) -> None: - """ - Plot elbow curve (inertia and silhouette vs k). - - Args: - eval_results: Results from evaluate_clustering - output_path: Path to save plot - """ - k_values = eval_results["k_values"] - inertia = eval_results["inertia"] - silhouette = eval_results["silhouette"] - - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) - - # Inertia (elbow curve) - ax1.plot(k_values, inertia, "b-o", linewidth=2, markersize=8) - ax1.set_xlabel("Number of Clusters (k)", fontsize=12) - ax1.set_ylabel("Inertia", fontsize=12) - ax1.set_title("Elbow Curve", fontsize=14) - ax1.grid(True, alpha=0.3) - ax1.set_xticks(k_values) - - # Silhouette score - ax2.plot(k_values, silhouette, "g-o", linewidth=2, markersize=8) - ax2.set_xlabel("Number of Clusters (k)", fontsize=12) - ax2.set_ylabel("Silhouette Score", fontsize=12) - ax2.set_title("Silhouette Score", fontsize=14) - ax2.grid(True, alpha=0.3) - ax2.set_xticks(k_values) - - # Mark optimal k - best_idx = int(np.argmax(silhouette)) - ax2.axvline(x=k_values[best_idx], color="red", linestyle="--", alpha=0.7) - ax2.scatter( - [k_values[best_idx]], - [silhouette[best_idx]], - s=200, - facecolors="none", - edgecolors="red", - linewidths=2, - zorder=5, - ) - - plt.tight_layout() - plt.savefig(output_path, dpi=150) - plt.close() - - logger.info(" Saved elbow curve: %s", output_path) - - -def analyze_weekday_weekend_distribution( - df: pl.DataFrame, - labels: np.ndarray, -) -> dict: - """ - Analyze weekday vs weekend distribution across clusters. - - This diagnostic checks if certain clusters are dominated by weekdays - or weekends, which would suggest usage patterns are day-type dependent. 
- - Args: - df: Original profile DataFrame with 'is_weekend' column - labels: Cluster assignments - - Returns: - Dictionary with distribution statistics - """ - if "is_weekend" not in df.columns: - logger.warning(" No 'is_weekend' column found - skipping weekday/weekend analysis") - return {} - - # Add cluster labels to dataframe - df_with_clusters = df.with_columns(pl.Series("cluster", labels)) - - # Calculate distribution - dist = ( - df_with_clusters.group_by(["cluster", "is_weekend"]) - .agg(pl.len().alias("count")) - .sort(["cluster", "is_weekend"]) - ) - - # Calculate percentages per cluster - dist = dist.with_columns([(pl.col("count") / pl.col("count").sum().over("cluster") * 100).alias("pct")]) - - # Summary: % weekend by cluster - summary = ( - df_with_clusters.group_by("cluster") - .agg([pl.len().alias("total"), (pl.col("is_weekend").sum() / pl.len() * 100).alias("pct_weekend")]) - .sort("cluster") - ) - - logger.info("") - logger.info("=" * 70) - logger.info("WEEKDAY/WEEKEND DISTRIBUTION BY CLUSTER") - logger.info("=" * 70) - - for row in summary.iter_rows(named=True): - cluster = row["cluster"] - total = row["total"] - pct_weekend = row["pct_weekend"] - pct_weekday = 100 - pct_weekend - - logger.info( - " Cluster %d: %.1f%% weekday, %.1f%% weekend (n=%s)", cluster, pct_weekday, pct_weekend, f"{total:,}" - ) - - # Flag significant imbalances (>70% one type) - if pct_weekend > 70: - logger.warning(" ⚠️ Weekend-dominated cluster") - elif pct_weekday > 70: - logger.warning(" ⚠️ Weekday-dominated cluster") - - # Overall distribution - overall_weekend_pct = float(df_with_clusters["is_weekend"].mean() * 100) - logger.info("") - logger.info(" Overall dataset: %.1f%% weekend, %.1f%% weekday", overall_weekend_pct, 100 - overall_weekend_pct) - - # Chi-square test would go here if needed for formal significance testing - logger.info("=" * 70) - - return { - "cluster_distribution": summary.to_dicts(), - "detailed_distribution": dist.to_dicts(), - "overall_weekend_pct": overall_weekend_pct, - } - - -def save_results( - df: pl.DataFrame, - labels: np.ndarray, - centroids: np.ndarray, - eval_results: dict | None, - metadata: dict, - output_dir: Path, -) -> None: - """ - Save all clustering results to output directory. 
- - Args: - df: Original profile DataFrame with metadata - labels: Cluster assignments - centroids: Cluster centroids - eval_results: K evaluation results (if any) - metadata: Clustering metadata - output_dir: Output directory - """ - output_dir.mkdir(parents=True, exist_ok=True) - - # Determine which ID columns are present (household vs ZIP+4 level) - id_cols: list[str] = [] - if "account_identifier" in df.columns: - id_cols.append("account_identifier") - if "zip_code" in df.columns: - id_cols.append("zip_code") - id_cols.extend(["date", "is_weekend", "weekday"]) - - # Only include columns that exist - available_cols = [c for c in id_cols if c in df.columns] - - # Save cluster assignments - assignments = df.select(available_cols).with_columns( - pl.Series("cluster", labels), - ) - assignments_path = output_dir / "cluster_assignments.parquet" - assignments.write_parquet(assignments_path) - logger.info(" Saved assignments: %s", assignments_path) - - # Save centroids as parquet - centroids_df = pl.DataFrame({ - "cluster": list(range(len(centroids))), - "centroid": [c.tolist() for c in centroids], - }) - centroids_path = output_dir / "cluster_centroids.parquet" - centroids_df.write_parquet(centroids_path) - logger.info(" Saved centroids: %s", centroids_path) - - # Save k evaluation results - if eval_results: - eval_path = output_dir / "k_evaluation.json" - with open(eval_path, "w") as f: - json.dump(eval_results, f, indent=2) - logger.info(" Saved k evaluation: %s", eval_path) - - # Save metadata - metadata_path = output_dir / "clustering_metadata.json" - with open(metadata_path, "w") as f: - json.dump(metadata, f, indent=2) - logger.info(" Saved metadata: %s", metadata_path) - - -def main() -> None: - parser = argparse.ArgumentParser( - description="K-Means Clustering for Load Profiles (Euclidean Distance)", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Standard run with k evaluation (silhouette on sample) - python euclidean_clustering.py \\ - --input data/clustering/sampled_profiles.parquet \\ - --output-dir data/clustering/results \\ - --k-range 3 6 --find-optimal-k --normalize \\ - --silhouette-sample-size 10000 - - # Fixed k (no evaluation) - python euclidean_clustering.py \\ - --input data/clustering/sampled_profiles.parquet \\ - --output-dir data/clustering/results \\ - --k 4 --normalize - """, - ) - - parser.add_argument( - "--input", - type=Path, - required=True, - help="Path to sampled_profiles.parquet", - ) - parser.add_argument( - "--output-dir", - type=Path, - default=Path("data/clustering/results"), - help="Output directory for results", - ) - - # K selection - k_group = parser.add_argument_group("Cluster Selection") - k_group.add_argument( - "--k", - type=int, - default=None, - help="Fixed number of clusters (skip evaluation)", - ) - k_group.add_argument( - "--k-range", - type=int, - nargs=2, - metavar=("MIN", "MAX"), - default=[3, 6], - help="Range of k values to evaluate (default: 3 6)", - ) - k_group.add_argument( - "--find-optimal-k", - action="store_true", - help="Evaluate k range and use optimal k", - ) - k_group.add_argument( - "--silhouette-sample-size", - type=int, - default=10_000, - help=( - "Max number of samples for silhouette evaluation " - "(default: 10000; use -1 to use full dataset, not recommended for large n)." 
- ), - ) - - # Clustering parameters - parser.add_argument( - "--n-init", - type=int, - default=10, - help="Number of k-means initializations (default: 10)", - ) - - # Preprocessing - parser.add_argument( - "--normalize", - action="store_true", - default=DEFAULT_NORMALIZE, - help="Apply normalization to profiles (default: True)", - ) - parser.add_argument( - "--normalize-method", - choices=["zscore", "minmax", "none"], - default=DEFAULT_NORMALIZATION, - help="Normalization method (default: minmax)", - ) - - parser.add_argument( - "--random-state", - type=int, - default=42, - help="Random seed for reproducibility", - ) - - args = parser.parse_args() - - print("=" * 70) - print("PHASE 2: K-MEANS CLUSTERING (EUCLIDEAN DISTANCE)") - print("=" * 70) - - # Load profiles - X, df = load_profiles(args.input) - - logger.info("Loaded %s sampled profiles", f"{len(df):,}") - - # Normalize if requested - if args.normalize: - logger.info("Normalizing profiles per household-day (method=%s)", args.normalize_method) - df = normalize_profiles( - df, - method=args.normalize_method, # ✅ FIXED: was args.normalization_method - profile_col="profile", - out_col=None, - ) - # ✅ CRITICAL FIX: Re-extract normalized profiles as numpy array - X = np.array(df["profile"].to_list(), dtype=np.float64) - logger.info(" Normalized data range: [%.2f, %.2f]", X.min(), X.max()) - else: - logger.info("Profile normalization disabled (using raw kWh values).") - - # Determine k - eval_results = None - - if args.k is not None: - # Fixed k - k = args.k - logger.info("") - logger.info("Using fixed k=%d", k) - elif args.find_optimal_k: - # Evaluate k range - k_range = range(args.k_range[0], args.k_range[1] + 1) - - silhouette_sample_size: int | None - silhouette_sample_size = None if args.silhouette_sample_size < 0 else args.silhouette_sample_size - - eval_results = evaluate_clustering( - X, - k_range=k_range, - n_init=args.n_init, - random_state=args.random_state, - silhouette_sample_size=silhouette_sample_size, - ) - - # Save elbow curve - args.output_dir.mkdir(parents=True, exist_ok=True) - plot_elbow_curve(eval_results, args.output_dir / "elbow_curve.png") - - k = find_optimal_k(eval_results) - else: - # Default to min of k_range - k = args.k_range[0] - logger.info("") - logger.info("Using default k=%d", k) - - # Run final clustering - labels, centroids, inertia = run_clustering( - X, - k=k, - n_init=args.n_init, - random_state=args.random_state, - ) - - # Create visualizations - logger.info("") - logger.info("Generating visualizations...") - args.output_dir.mkdir(parents=True, exist_ok=True) - - plot_centroids(centroids, args.output_dir / "cluster_centroids.png") - plot_cluster_samples(X, labels, centroids, args.output_dir / "cluster_samples.png") - - # Save results - logger.info("") - logger.info("Saving results...") - - metadata = { - "k": int(k), - "n_profiles": int(X.shape[0]), - "n_timepoints": int(X.shape[1]), - "normalized": bool(args.normalize), - "normalize_method": args.normalize_method if args.normalize else None, - "n_init": int(args.n_init), - "random_state": int(args.random_state), - "inertia": float(inertia), - "distance_metric": "euclidean", - "silhouette_sample_size": (None if args.silhouette_sample_size < 0 else int(args.silhouette_sample_size)), - } - - save_results(df, labels, centroids, eval_results, metadata, args.output_dir) - - # Summary - print("\n" + "=" * 70) - print("CLUSTERING COMPLETE") - print("=" * 70) - print(f"\nResults saved to: {args.output_dir}") - print(f" • {X.shape[0]:,} profiles clustered 
into {k} groups") - print(f" • Inertia: {inertia:,.2f}") - if eval_results: - best_sil = max(eval_results["silhouette"]) - print(f" • Best silhouette score: {best_sil:.3f}") - print("=" * 70) - - -if __name__ == "__main__": - main() diff --git a/analysis/clustering/euclidean_clustering_k_search.py b/analysis/clustering/euclidean_clustering_k_search.py new file mode 100644 index 0000000..b655065 --- /dev/null +++ b/analysis/clustering/euclidean_clustering_k_search.py @@ -0,0 +1,552 @@ +#!/usr/bin/env python3 +""" +MiniBatch K-Means clustering for large load-profile datasets (PyArrow-batched). + +This script is the Stage 1 clustering step in the ComEd pipeline. It reads +`sampled_profiles.parquet` in streaming batches via PyArrow, optionally normalizes +profiles row-wise, fits MiniBatchKMeans with partial_fit, then predicts and writes +cluster assignments back to Parquet incrementally. + +Modes: +- k-range evaluation (default): evaluate k in [k_min, k_max], compute silhouette on a + deterministic sample, select best k, then write final artifacts for the selected k. + +Outputs (written to --output-dir): +- cluster_assignments.parquet (for selected/best k) +- cluster_centroids.parquet +- cluster_centroids.png +- clustering_metadata.json +- k_evaluation.json (per-k summary metrics) +""" + +from __future__ import annotations + +import argparse +import json +import logging +from collections.abc import Iterable +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import polars as pl +import pyarrow as pa +import pyarrow.parquet as pq +from sklearn.cluster import MiniBatchKMeans +from sklearn.metrics import silhouette_score + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +# ============================================================================= +# IO + batch utilities +# ============================================================================= + + +def parquet_num_rows(path: Path) -> int: + """Return Parquet row count from file metadata (no full scan).""" + pf = pq.ParquetFile(path) + md = pf.metadata + if md is None: + # Extremely rare; ParquetFile.metadata is normally present. 
+ total = 0 + for i in range(pf.num_row_groups): + rg = pf.metadata.row_group(i) # type: ignore[union-attr] + total += int(rg.num_rows) + return total + return int(md.num_rows) + + +def iter_profile_batches( + path: Path, + batch_size: int, + columns: list[str] | None = None, +) -> Iterable[pa.RecordBatch]: + """Yield RecordBatches from a Parquet file.""" + pf = pq.ParquetFile(path) + yield from pf.iter_batches(batch_size=batch_size, columns=columns) + + +def recordbatch_profiles_to_numpy(rb: pa.RecordBatch, profile_col: str = "profile") -> np.ndarray: + """Convert RecordBatch `profile` list column into a 2D float64 NumPy array.""" + idx = rb.schema.get_field_index(profile_col) + if idx < 0: + raise ValueError(f"RecordBatch missing required column '{profile_col}'") + profiles = rb.column(idx).to_pylist() + X = np.asarray(profiles, dtype=np.float64) + if X.ndim != 2: + raise ValueError(f"Expected 2D profile array; got shape={X.shape}") + return X + + +# ============================================================================= +# Normalization +# ============================================================================= + + +def normalize_batch(X: np.ndarray, method: str) -> np.ndarray: + """Row-wise normalization: minmax, zscore, or none. Constant rows -> zeros.""" + if method in ("none", "", None): + Xn = X + elif method == "minmax": + mins = np.min(X, axis=1, keepdims=True) + maxs = np.max(X, axis=1, keepdims=True) + denom = maxs - mins + denom_safe = np.where(denom == 0, 1.0, denom) + Xn = (X - mins) / denom_safe + Xn = np.where(denom == 0, 0.0, Xn) + elif method == "zscore": + means = np.mean(X, axis=1, keepdims=True) + stds = np.std(X, axis=1, keepdims=True) + std_safe = np.where(stds == 0, 1.0, stds) + Xn = (X - means) / std_safe + Xn = np.where(stds == 0, 0.0, Xn) + else: + raise ValueError(f"Unknown normalize method: {method}") + + if not np.isfinite(Xn).all(): + Xn = np.nan_to_num(Xn, nan=0.0, posinf=0.0, neginf=0.0) + return Xn + + +# ============================================================================= +# Clustering primitives +# ============================================================================= + + +def fit_minibatch_kmeans( + input_path: Path, + k: int, + batch_size: int, + n_init: int, + random_state: int, + normalize: bool, + normalize_method: str, +) -> MiniBatchKMeans: + """Fit MiniBatchKMeans by streaming over batches and calling partial_fit().""" + logger.info("Fitting MiniBatchKMeans (k=%d, batch_size=%s, n_init=%d)...", k, f"{batch_size:,}", n_init) + + model = MiniBatchKMeans( + n_clusters=k, + batch_size=batch_size, + n_init=n_init, + random_state=random_state, + verbose=0, + ) + + total_rows = parquet_num_rows(input_path) + n_batches = (total_rows + batch_size - 1) // batch_size + logger.info(" Training on %s profiles in %d batches", f"{total_rows:,}", n_batches) + + seen = 0 + for bi, rb in enumerate(iter_profile_batches(input_path, batch_size=batch_size, columns=["profile"]), start=1): + X = recordbatch_profiles_to_numpy(rb, profile_col="profile") + if normalize and normalize_method != "none": + X = normalize_batch(X, normalize_method) + model.partial_fit(X) + seen += int(X.shape[0]) + if bi % 10 == 0 or bi == n_batches: + logger.info(" Trained batch %d/%d (seen=%s)", bi, n_batches, f"{seen:,}") + + logger.info(" Training complete. 
Inertia: %s", f"{float(model.inertia_):,.2f}") + return model + + +def compute_silhouette_on_sample( # noqa: C901 + model: MiniBatchKMeans, + input_path: Path, + batch_size: int, + normalize: bool, + normalize_method: str, + sample_idx: np.ndarray, +) -> float | None: + """ + Compute silhouette on a deterministic set of global row indices, streaming through the file. + + Returns None if sample size < 2 or if all sampled points end up in one cluster. + """ + if sample_idx.size < 2: + return None + + sample_X: list[np.ndarray] = [] + sample_y: list[int] = [] + + global_row = 0 + sample_pos = 0 + for rb in iter_profile_batches(input_path, batch_size=batch_size, columns=["profile"]): + X = recordbatch_profiles_to_numpy(rb, profile_col="profile") + if normalize and normalize_method != "none": + X = normalize_batch(X, normalize_method) + + labels = model.predict(X).astype(np.int32) + + n = int(X.shape[0]) + while sample_pos < sample_idx.size and int(sample_idx[sample_pos]) < global_row: + sample_pos += 1 + + start = sample_pos + while sample_pos < sample_idx.size and int(sample_idx[sample_pos]) < global_row + n: + sample_pos += 1 + + if sample_pos > start: + idx_in_batch = sample_idx[start:sample_pos] - global_row + for j in idx_in_batch: + jj = int(j) + sample_X.append(X[jj]) + sample_y.append(int(labels[jj])) + + global_row += n + if sample_pos >= sample_idx.size: + break + + if len(sample_y) < 2: + return None + + ys = np.asarray(sample_y, dtype=np.int32) + if np.unique(ys).size < 2: + return None + + Xs = np.asarray(sample_X, dtype=np.float64) + logger.info("Computing silhouette on sample_size=%s ...", f"{len(ys):,}") + return float(silhouette_score(Xs, ys, metric="euclidean")) + + +def predict_and_write_assignments_streaming( + model: MiniBatchKMeans, + input_path: Path, + output_path: Path, + batch_size: int, + normalize: bool, + normalize_method: str, +) -> np.ndarray: + """Predict labels in batches and write output Parquet incrementally (input columns + `cluster`).""" + logger.info("Predicting labels + writing assignments streaming: %s", output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + pf = pq.ParquetFile(input_path) + out_schema = pf.schema_arrow.append(pa.field("cluster", pa.int32())) + writer = pq.ParquetWriter(output_path, out_schema, compression="zstd") + + k = int(model.n_clusters) + counts = np.zeros(k, dtype=np.int64) + + try: + for rb in iter_profile_batches(input_path, batch_size=batch_size, columns=None): + X = recordbatch_profiles_to_numpy(rb, profile_col="profile") + if normalize and normalize_method != "none": + X = normalize_batch(X, normalize_method) + + labels = model.predict(X).astype(np.int32) + counts += np.bincount(labels, minlength=k) + + out_rb = rb.append_column("cluster", pa.array(labels, type=pa.int32())) + writer.write_batch(out_rb) + finally: + writer.close() + + total = int(counts.sum()) + logger.info(" Cluster distribution (total=%s):", f"{total:,}") + for c, n in enumerate(counts.tolist()): + pct = (n / total * 100.0) if total > 0 else 0.0 + logger.info(" Cluster %d: %s profiles (%.1f%%)", c, f"{n:,}", pct) + + return counts + + +# ============================================================================= +# Plotting + outputs +# ============================================================================= + + +def plot_centroids(centroids: np.ndarray, output_path: Path) -> None: + """Save a line plot of cluster centroids.""" + k = int(centroids.shape[0]) + n_timepoints = int(centroids.shape[1]) + + if n_timepoints == 48: + x = 
np.arange(0.5, 24.5, 0.5) + xlabel = "Hour of Day" + else: + x = np.arange(n_timepoints) + xlabel = "Time Interval" + + fig, ax = plt.subplots(figsize=(12, 6)) + for i in range(k): + ax.plot(x, centroids[i], label=f"Cluster {i}", linewidth=2) + + ax.set_xlabel(xlabel, fontsize=12) + ax.set_ylabel("Usage (normalized)" if n_timepoints else "Usage", fontsize=12) + ax.set_title("Cluster Centroids (MiniBatch K-Means)", fontsize=14) + ax.legend() + ax.grid(True, alpha=0.3) + + if n_timepoints == 48: + ax.set_xticks(range(0, 25, 4)) + ax.set_xlim(0, 24) + + output_path.parent.mkdir(parents=True, exist_ok=True) + plt.tight_layout() + plt.savefig(output_path, dpi=150) + plt.close() + logger.info(" Saved centroids plot: %s", output_path) + + +def save_centroids_parquet(centroids: np.ndarray, output_path: Path) -> None: + """Write centroids to Parquet as (cluster, centroid[list[float]]).""" + centroids_df = pl.DataFrame({ + "cluster": list(range(int(centroids.shape[0]))), + "centroid": [c.tolist() for c in centroids], + }) + centroids_df.write_parquet(output_path) + logger.info(" Saved centroids parquet: %s", output_path) + + +def save_metadata(metadata: dict[str, Any], output_path: Path) -> None: + """Write run metadata and summary metrics to JSON.""" + with open(output_path, "w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2, sort_keys=True) + f.write("\n") + logger.info(" Saved metadata: %s", output_path) + + +# ============================================================================= +# k-range evaluation +# ============================================================================= + + +@dataclass(frozen=True) +class KEvalResult: + k: int + inertia: float + silhouette_score_sample: float | None + n_profiles: int + normalized: bool + normalize_method: str + batch_size: int + n_init: int + random_state: int + + +def choose_best_k(results: list[KEvalResult]) -> int: + """ + Choose best k. Primary criterion: max silhouette_score_sample when available. + Fallback: min inertia. + """ + with_sil = [r for r in results if r.silhouette_score_sample is not None] + if with_sil: + # Tie-breakers: higher silhouette, then higher k (to be deterministic). + with_sil_sorted = sorted(with_sil, key=lambda r: (r.silhouette_score_sample, r.k), reverse=True) + return int(with_sil_sorted[0].k) + + # If silhouette unavailable for all, pick minimum inertia (still deterministic). 
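+ # Worked example (hypothetical numbers): for k=3 (sil=0.41), k=4 (sil=0.46),
+ # k=5 (sil=0.46), the silhouette branch above sorts by (silhouette, k)
+ # descending and returns k=5. Only when every silhouette is None does the
+ # inertia fallback below apply, returning the k with the smallest inertia.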
+ by_inertia = sorted(results, key=lambda r: (r.inertia, r.k)) + return int(by_inertia[0].k) + + +def make_silhouette_sample_idx(n_profiles: int, sample_size: int, seed: int) -> np.ndarray: + """Deterministic global row index sample.""" + if sample_size <= 0 or n_profiles <= 0: + return np.array([], dtype=np.int64) + if sample_size >= n_profiles: + return np.arange(n_profiles, dtype=np.int64) + rng = np.random.default_rng(int(seed)) + idx = rng.choice(n_profiles, size=int(sample_size), replace=False) + idx.sort() + return idx.astype(np.int64) + + +# ============================================================================= +# Main +# ============================================================================= + + +def build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + description="MiniBatch K-Means Clustering (k-range, PyArrow-batched)", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + p.add_argument("--input", type=Path, required=True, help="Path to sampled_profiles.parquet") + p.add_argument("--output-dir", type=Path, required=True, help="Output directory (run_dir/clustering)") + + # Orchestrator-required interface + p.add_argument("--k-min", type=int, default=3, help="Minimum k (inclusive)") + p.add_argument("--k-max", type=int, default=6, help="Maximum k (inclusive)") + p.add_argument("--seed", type=int, default=42, help="Random seed (model + sampling)") + + # Tunables + p.add_argument("--batch-size", type=int, default=50_000, help="Batch size (default: 50k)") + p.add_argument("--n-init", type=int, default=3, help="MiniBatchKMeans n_init (default: 3)") + p.add_argument("--normalize", action="store_true", help="Normalize profiles per row") + p.add_argument("--normalize-method", choices=["minmax", "zscore", "none"], default="minmax") + p.add_argument( + "--silhouette-sample-size", + type=int, + default=5_000, + help="Sample size for silhouette (default: 5000; set 0 to skip)", + ) + + return p + + +def main() -> int: + args = build_parser().parse_args() + + if args.k_min <= 1 or args.k_max <= 1: + raise ValueError("k-min and k-max must be >= 2") + if args.k_min > args.k_max: + raise ValueError(f"Invalid k range: k-min ({args.k_min}) > k-max ({args.k_max})") + if args.batch_size <= 0: + raise ValueError("--batch-size must be > 0") + + args.output_dir.mkdir(parents=True, exist_ok=True) + + logger.info("=" * 70) + logger.info("MINIBATCH K-MEANS CLUSTERING (K-RANGE, PYARROW-BATCHED)") + logger.info("=" * 70) + logger.info("Input: %s", args.input) + logger.info("Output: %s", args.output_dir) + logger.info("k range: %d..%d", args.k_min, args.k_max) + logger.info("batch_size: %s", f"{args.batch_size:,}") + logger.info("n_init: %d", args.n_init) + logger.info("seed: %d", args.seed) + + eff_norm = bool(args.normalize and args.normalize_method != "none") + logger.info("Normalize: %s (method=%s)", eff_norm, args.normalize_method if eff_norm else "none") + + n_profiles = parquet_num_rows(args.input) + logger.info("Profiles (from parquet metadata): %s", f"{n_profiles:,}") + + sample_idx = make_silhouette_sample_idx(int(n_profiles), int(args.silhouette_sample_size), int(args.seed)) + if sample_idx.size > 0: + logger.info("Silhouette sample size: %s", f"{sample_idx.size:,}") + else: + logger.info("Silhouette: disabled") + + results: list[KEvalResult] = [] + + # Evaluate each k + for k in range(int(args.k_min), int(args.k_max) + 1): + logger.info("-" * 70) + logger.info("EVALUATING k=%d", k) + + model = fit_minibatch_kmeans( + input_path=args.input, + k=int(k), + 
batch_size=int(args.batch_size), + n_init=int(args.n_init), + random_state=int(args.seed), + normalize=eff_norm, + normalize_method=str(args.normalize_method), + ) + + sil = None + if sample_idx.size > 0: + sil = compute_silhouette_on_sample( + model=model, + input_path=args.input, + batch_size=int(args.batch_size), + normalize=eff_norm, + normalize_method=str(args.normalize_method), + sample_idx=sample_idx, + ) + if sil is not None: + logger.info("k=%d silhouette(sample)=%.3f", k, sil) + else: + logger.info("k=%d silhouette(sample)=None (insufficient clusters or sample)", k) + + res = KEvalResult( + k=int(k), + inertia=float(model.inertia_), + silhouette_score_sample=float(sil) if sil is not None else None, + n_profiles=int(n_profiles), + normalized=eff_norm, + normalize_method=str(args.normalize_method) if eff_norm else "none", + batch_size=int(args.batch_size), + n_init=int(args.n_init), + random_state=int(args.seed), + ) + results.append(res) + + # Choose best k + best_k = choose_best_k(results) + logger.info("=" * 70) + logger.info("SELECTED k=%d", best_k) + logger.info("=" * 70) + + # Save k evaluation summary + k_eval_path = args.output_dir / "k_evaluation.json" + k_eval_payload: dict[str, Any] = { + "k_min": int(args.k_min), + "k_max": int(args.k_max), + "selected_k": int(best_k), + "selection_rule": "max silhouette (sample) if available else min inertia", + "results": [asdict(r) for r in results], + } + save_metadata(k_eval_payload, k_eval_path) + + # Refit best model (deterministic given same seed + same stream order) + best_model = fit_minibatch_kmeans( + input_path=args.input, + k=int(best_k), + batch_size=int(args.batch_size), + n_init=int(args.n_init), + random_state=int(args.seed), + normalize=eff_norm, + normalize_method=str(args.normalize_method), + ) + + # Write assignments + assignments_path = args.output_dir / "cluster_assignments.parquet" + counts = predict_and_write_assignments_streaming( + model=best_model, + input_path=args.input, + output_path=assignments_path, + batch_size=int(args.batch_size), + normalize=eff_norm, + normalize_method=str(args.normalize_method), + ) + + # Centroids + plot + metadata + centroids = best_model.cluster_centers_ + plot_centroids(centroids, args.output_dir / "cluster_centroids.png") + save_centroids_parquet(centroids, args.output_dir / "cluster_centroids.parquet") + + best_res = next(r for r in results if r.k == best_k) + metadata = { + "k_selected": int(best_k), + "n_profiles": int(n_profiles), + "n_timepoints": int(centroids.shape[1]), + "normalized": bool(eff_norm), + "normalize_method": str(args.normalize_method) if eff_norm else "none", + "batch_size": int(args.batch_size), + "n_init": int(args.n_init), + "seed": int(args.seed), + "algorithm": "MiniBatchKMeans", + "inertia": float(best_model.inertia_), + "silhouette_score_sample": best_res.silhouette_score_sample, + "cluster_counts": {str(i): int(c) for i, c in enumerate(counts.tolist())}, + "artifacts": { + "assignments": str(assignments_path), + "centroids_parquet": str(args.output_dir / "cluster_centroids.parquet"), + "centroids_plot": str(args.output_dir / "cluster_centroids.png"), + "k_evaluation": str(k_eval_path), + }, + } + save_metadata(metadata, args.output_dir / "clustering_metadata.json") + + logger.info("=" * 70) + logger.info("CLUSTERING COMPLETE") + logger.info("=" * 70) + logger.info("Selected k: %d", best_k) + if best_res.silhouette_score_sample is not None: + logger.info("Silhouette (sample): %.3f", float(best_res.silhouette_score_sample)) + logger.info("Output: 
%s", args.output_dir) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/analysis/clustering/euclidean_clustering_minibatch.py b/analysis/clustering/euclidean_clustering_minibatch.py index b036e4b..49a690d 100644 --- a/analysis/clustering/euclidean_clustering_minibatch.py +++ b/analysis/clustering/euclidean_clustering_minibatch.py @@ -4,9 +4,9 @@ Reads an input Parquet file in streaming batches via PyArrow, optionally normalizes each profile row on-the-fly, fits MiniBatchKMeans with partial_fit, then predicts and -writes cluster assignments back to Parquet incrementally (original columns + `cluster`). +writes cluster assignments back to Parquet incrementally (selected columns + `cluster`). -Outputs: +Outputs (written to --output-dir): - cluster_assignments.parquet - cluster_centroids.parquet - cluster_centroids.png @@ -33,9 +33,9 @@ logger = logging.getLogger(__name__) -# ---------------------------- +# ============================================================================= # IO + batch utilities -# ---------------------------- +# ============================================================================= def parquet_num_rows(path: Path) -> int: @@ -43,15 +43,15 @@ def parquet_num_rows(path: Path) -> int: pf = pq.ParquetFile(path) md = pf.metadata if md is None: - return sum(pf.metadata.row_group(i).num_rows for i in range(pf.num_row_groups)) + total = 0 + for i in range(pf.num_row_groups): + rg = pf.metadata.row_group(i) # type: ignore[union-attr] + total += int(rg.num_rows) + return total return int(md.num_rows) -def iter_profile_batches( - path: Path, - batch_size: int, - columns: list[str] | None = None, -) -> Iterable[pa.RecordBatch]: +def iter_profile_batches(path: Path, batch_size: int, columns: list[str] | None) -> Iterable[pa.RecordBatch]: """Yield record batches from a Parquet file.""" pf = pq.ParquetFile(path) yield from pf.iter_batches(batch_size=batch_size, columns=columns) @@ -69,9 +69,33 @@ def recordbatch_profiles_to_numpy(rb: pa.RecordBatch, profile_col: str = "profil return X -# ---------------------------- +def parse_output_columns(spec: str) -> list[str] | None: + """ + Parse --output-columns. + + Returns: + - None if spec is empty (meaning "all columns"). + - Otherwise a de-duplicated list of column names, in order. 
+ """ + s = (spec or "").strip() + if not s: + return None + + cols: list[str] = [] + seen: set[str] = set() + for part in s.split(","): + c = part.strip() + if not c: + continue + if c not in seen: + cols.append(c) + seen.add(c) + return cols + + +# ============================================================================= # Normalization -# ---------------------------- +# ============================================================================= def normalize_batch(X: np.ndarray, method: str) -> np.ndarray: @@ -99,9 +123,9 @@ def normalize_batch(X: np.ndarray, method: str) -> np.ndarray: return Xn -# ---------------------------- +# ============================================================================= # Clustering -# ---------------------------- +# ============================================================================= def fit_minibatch_kmeans( @@ -134,7 +158,7 @@ def fit_minibatch_kmeans( if normalize and normalize_method != "none": X = normalize_batch(X, normalize_method) model.partial_fit(X) - seen += X.shape[0] + seen += int(X.shape[0]) if bi % 10 == 0 or bi == n_batches: logger.info(" Trained batch %d/%d (seen=%s)", bi, n_batches, f"{seen:,}") @@ -142,53 +166,74 @@ def fit_minibatch_kmeans( return model -# ruff: noqa: C901 -def predict_and_write_assignments_streaming( +def make_silhouette_sample_idx(n_profiles: int, sample_size: int, seed: int) -> np.ndarray: + """Deterministic global row index sample (sorted).""" + if sample_size <= 0 or n_profiles <= 0: + return np.array([], dtype=np.int64) + if sample_size >= n_profiles: + return np.arange(n_profiles, dtype=np.int64) + rng = np.random.default_rng(int(seed)) + idx = rng.choice(n_profiles, size=int(sample_size), replace=False) + idx.sort() + return idx.astype(np.int64) + + +def predict_write_and_optional_silhouette( # noqa: C901 model: MiniBatchKMeans, input_path: Path, output_path: Path, batch_size: int, normalize: bool, normalize_method: str, + read_columns: list[str] | None, + write_columns: list[str] | None, silhouette_sample_idx: np.ndarray, -) -> tuple[np.ndarray, float]: +) -> tuple[np.ndarray, float | None]: """ - Predict labels in batches, write output Parquet incrementally, and optionally - compute silhouette on a sampled set of global row indices. + Predict labels in batches, write output Parquet incrementally, and optionally compute + silhouette on a sampled set of global row indices. + + Notes: + - `read_columns` controls what we read from input (must include `profile`). + - `write_columns` controls what we keep in output (may exclude `profile`). + If None, we write all read columns (including `profile`). 
""" - logger.info("Predicting labels + writing assignments streaming...") + logger.info("Predicting labels + writing assignments streaming: %s", output_path) output_path.parent.mkdir(parents=True, exist_ok=True) - pf = pq.ParquetFile(input_path) - out_schema = pf.schema_arrow.append(pa.field("cluster", pa.int32())) - writer = pq.ParquetWriter(output_path, out_schema, compression="zstd") - k = int(model.n_clusters) counts = np.zeros(k, dtype=np.int64) - sample_X: list[np.ndarray] = None - sample_y: list[int] = None + use_sil = silhouette_sample_idx.size > 0 + sample_X: list[np.ndarray] = [] + sample_y: list[int] = [] sample_pos = 0 - if silhouette_sample_idx is not None and len(silhouette_sample_idx) > 0: - sample_X, sample_y = [], [] - global_row = 0 + + writer: pq.ParquetWriter | None = None + out_schema: pa.Schema | None = None + try: - for rb in iter_profile_batches(input_path, batch_size=batch_size, columns=None): + for rb in iter_profile_batches(input_path, batch_size=batch_size, columns=read_columns): + # Compute labels (requires profile) X = recordbatch_profiles_to_numpy(rb, profile_col="profile") if normalize and normalize_method != "none": X = normalize_batch(X, normalize_method) - labels = model.predict(X).astype(np.int32) counts += np.bincount(labels, minlength=k) - if sample_X is not None and sample_y is not None: - n = X.shape[0] - while sample_pos < len(silhouette_sample_idx) and silhouette_sample_idx[sample_pos] < global_row: + # Silhouette sampling (uses labels on this pass; avoids a second scan) + if use_sil: + n = int(X.shape[0]) + + while sample_pos < silhouette_sample_idx.size and int(silhouette_sample_idx[sample_pos]) < global_row: sample_pos += 1 start = sample_pos - while sample_pos < len(silhouette_sample_idx) and silhouette_sample_idx[sample_pos] < global_row + n: + while ( + sample_pos < silhouette_sample_idx.size and int(silhouette_sample_idx[sample_pos]) < global_row + n + ): sample_pos += 1 + if sample_pos > start: idx_in_batch = silhouette_sample_idx[start:sample_pos] - global_row for j in idx_in_batch: @@ -196,32 +241,57 @@ def predict_and_write_assignments_streaming( sample_X.append(X[jj]) sample_y.append(int(labels[jj])) - out_rb = rb.append_column("cluster", pa.array(labels, type=pa.int32())) + # Build output batch: select write columns (or all columns), then append cluster + if write_columns is None: + out_rb = rb + else: + indices: list[int] = [] + for name in write_columns: + i = rb.schema.get_field_index(name) + if i < 0: + raise ValueError(f"Input batch missing requested output column '{name}'") + indices.append(i) + out_rb = rb.select(indices) + + out_rb = out_rb.append_column("cluster", pa.array(labels, type=pa.int32())) + + # Initialize writer lazily from first out_rb schema + if writer is None: + out_schema = out_rb.schema + writer = pq.ParquetWriter(output_path, out_schema, compression="zstd") writer.write_batch(out_rb) - global_row += labels.shape[0] + + global_row += int(labels.shape[0]) + finally: - writer.close() + if writer is not None: + writer.close() - logger.info(" Cluster distribution:") total = int(counts.sum()) + logger.info(" Cluster distribution (total=%s):", f"{total:,}") for c, n in enumerate(counts.tolist()): pct = (n / total * 100.0) if total > 0 else 0.0 logger.info(" Cluster %d: %s profiles (%.1f%%)", c, f"{n:,}", pct) - sil = None - if sample_X is not None and sample_y is not None and len(sample_y) >= 2: - Xs = np.asarray(sample_X, dtype=np.float64) + sil: float | None = None + if use_sil and len(sample_y) >= 2: ys = 
np.asarray(sample_y, dtype=np.int32) - logger.info("Computing silhouette on sample_size=%s ...", f"{len(ys):,}") - sil = float(silhouette_score(Xs, ys, metric="euclidean")) - logger.info(" Silhouette score (sample): %.3f", sil) + if np.unique(ys).size >= 2: + Xs = np.asarray(sample_X, dtype=np.float64) + logger.info("Computing silhouette on sample_size=%s ...", f"{len(ys):,}") + sil = float(silhouette_score(Xs, ys, metric="euclidean")) + logger.info(" Silhouette score (sample): %.3f", sil) + else: + logger.info("Silhouette: skipped (sample fell into a single cluster)") + elif use_sil: + logger.info("Silhouette: skipped (insufficient sample)") return counts, sil -# ---------------------------- +# ============================================================================= # Plotting + outputs -# ---------------------------- +# ============================================================================= def plot_centroids(centroids: np.ndarray, output_path: Path) -> None: @@ -267,112 +337,171 @@ def save_centroids_parquet(centroids: np.ndarray, output_path: Path) -> None: logger.info(" Saved centroids parquet: %s", output_path) -def save_metadata(metadata: dict, output_path: Path) -> None: +def save_metadata(metadata: dict[str, object], output_path: Path) -> None: """Write run metadata and summary metrics to JSON.""" with open(output_path, "w", encoding="utf-8") as f: - json.dump(metadata, f, indent=2) + json.dump(metadata, f, indent=2, sort_keys=True) + f.write("\n") logger.info(" Saved metadata: %s", output_path) -# ---------------------------- +# ============================================================================= # Main -# ---------------------------- +# ============================================================================= -def main() -> int: - """CLI entrypoint.""" - parser = argparse.ArgumentParser( +def build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( description="Memory-Efficient K-Means Clustering (MiniBatch, PyArrow-batched)", formatter_class=argparse.RawDescriptionHelpFormatter, ) - parser.add_argument("--input", type=Path, required=True, help="sampled_profiles.parquet") - parser.add_argument("--output-dir", type=Path, required=True, help="Output directory") - parser.add_argument("--k", type=int, required=True, help="Number of clusters") - parser.add_argument("--normalize", action="store_true", help="Normalize profiles per row") - parser.add_argument("--normalize-method", choices=["minmax", "zscore", "none"], default="minmax") - parser.add_argument("--batch-size", type=int, default=50_000, help="Batch size (default: 50k)") - parser.add_argument("--n-init", type=int, default=3, help="Number of initializations (default: 3)") - parser.add_argument( + p.add_argument("--input", type=Path, required=True, help="Path to sampled_profiles.parquet") + p.add_argument("--output-dir", type=Path, required=True, help="Output directory") + p.add_argument("--k", type=int, required=True, help="Number of clusters") + + p.add_argument("--normalize", action="store_true", help="Normalize profiles per row") + p.add_argument("--normalize-method", choices=["minmax", "zscore", "none"], default="minmax") + + p.add_argument("--batch-size", type=int, default=50_000, help="Batch size (default: 50k)") + p.add_argument("--n-init", type=int, default=3, help="MiniBatchKMeans n_init (default: 3)") + p.add_argument("--seed", type=int, default=42, help="Random seed (model + sampling)") + + p.add_argument( "--silhouette-sample-size", type=int, default=5_000, help="Sample size 
for silhouette (default: 5000; set 0 to skip)", ) - parser.add_argument("--random-state", type=int, default=42, help="Random seed") - args = parser.parse_args() + p.add_argument( + "--output-columns", + type=str, + default="", + help=( + "Comma-separated columns to carry through to cluster_assignments.parquet " + "(default: all input columns). `cluster` is always added." + ), + ) + + return p + + +def main() -> int: + args = build_parser().parse_args() args.output_dir.mkdir(parents=True, exist_ok=True) + if args.k <= 1: + raise ValueError("--k must be >= 2") + if args.batch_size <= 0: + raise ValueError("--batch-size must be > 0") + if args.n_init <= 0: + raise ValueError("--n-init must be > 0") + logger.info("=" * 70) logger.info("MINIBATCH K-MEANS CLUSTERING (PYARROW-BATCHED)") logger.info("=" * 70) logger.info("Input: %s", args.input) + logger.info("Output: %s", args.output_dir) logger.info("k: %d", args.k) - logger.info("Batch size: %s", f"{args.batch_size:,}") + logger.info("batch_size: %s", f"{args.batch_size:,}") + logger.info("n_init: %d", args.n_init) + logger.info("seed: %d", args.seed) - eff_norm = args.normalize and args.normalize_method != "none" + eff_norm = bool(args.normalize and args.normalize_method != "none") logger.info("Normalize: %s (method=%s)", eff_norm, args.normalize_method if eff_norm else "none") n_profiles = parquet_num_rows(args.input) logger.info("Profiles (from parquet metadata): %s", f"{n_profiles:,}") + # Output column handling + requested_out_cols = parse_output_columns(args.output_columns) + + # Read columns must include profile for prediction; if requested_out_cols is None -> read all columns. + # If a subset was requested, read only: requested_out_cols + profile (deduped). + if requested_out_cols is None: + read_columns = None + write_columns = None # write everything we read (which is everything) + logger.info("Assignments output columns: ALL (default)") + else: + read_columns = [] + seen = set() + for c in [*requested_out_cols, "profile"]: + if c not in seen: + read_columns.append(c) + seen.add(c) + write_columns = requested_out_cols + logger.info("Assignments output columns: %s (+cluster)", ",".join(write_columns)) + + # Fit model = fit_minibatch_kmeans( input_path=args.input, - k=args.k, - batch_size=args.batch_size, - n_init=args.n_init, - random_state=args.random_state, - normalize=bool(eff_norm), - normalize_method=args.normalize_method, + k=int(args.k), + batch_size=int(args.batch_size), + n_init=int(args.n_init), + random_state=int(args.seed), + normalize=eff_norm, + normalize_method=str(args.normalize_method), ) centroids = model.cluster_centers_ - silhouette_sample_idx = None - if args.silhouette_sample_size and args.silhouette_sample_size > 0: - if args.silhouette_sample_size >= n_profiles: - silhouette_sample_idx = np.arange(n_profiles, dtype=np.int64) - else: - rng = np.random.default_rng(args.random_state) - silhouette_sample_idx = rng.choice(n_profiles, size=args.silhouette_sample_size, replace=False) - silhouette_sample_idx.sort() + # Deterministic silhouette sample indices (optional) + sample_idx = make_silhouette_sample_idx(int(n_profiles), int(args.silhouette_sample_size), int(args.seed)) + if sample_idx.size > 0: + logger.info("Silhouette sample size: %s", f"{sample_idx.size:,}") + else: + logger.info("Silhouette: disabled") + # Predict + write assignments (and optional silhouette) assignments_path = args.output_dir / "cluster_assignments.parquet" - counts, sil_score = predict_and_write_assignments_streaming( + counts, sil_score = 
predict_write_and_optional_silhouette( model=model, input_path=args.input, output_path=assignments_path, - batch_size=args.batch_size, - normalize=bool(eff_norm), - normalize_method=args.normalize_method, - silhouette_sample_idx=silhouette_sample_idx, + batch_size=int(args.batch_size), + normalize=eff_norm, + normalize_method=str(args.normalize_method), + read_columns=read_columns, + write_columns=write_columns, + silhouette_sample_idx=sample_idx, ) - plot_centroids(centroids, args.output_dir / "cluster_centroids.png") - save_centroids_parquet(centroids, args.output_dir / "cluster_centroids.parquet") + # Save artifacts + centroids_plot_path = args.output_dir / "cluster_centroids.png" + centroids_parquet_path = args.output_dir / "cluster_centroids.parquet" + metadata_path = args.output_dir / "clustering_metadata.json" + + plot_centroids(centroids, centroids_plot_path) + save_centroids_parquet(centroids, centroids_parquet_path) - metadata = { + metadata: dict[str, object] = { "k": int(args.k), "n_profiles": int(n_profiles), "n_timepoints": int(centroids.shape[1]), "normalized": bool(eff_norm), - "normalize_method": args.normalize_method if eff_norm else "none", + "normalize_method": str(args.normalize_method) if eff_norm else "none", "batch_size": int(args.batch_size), "n_init": int(args.n_init), - "random_state": int(args.random_state), + "seed": int(args.seed), "algorithm": "MiniBatchKMeans", "inertia": float(model.inertia_), "silhouette_score_sample": float(sil_score) if sil_score is not None else None, "cluster_counts": {str(i): int(c) for i, c in enumerate(counts.tolist())}, + "assignments_output_columns": write_columns if write_columns is not None else "ALL", + "artifacts": { + "assignments": str(assignments_path), + "centroids_parquet": str(centroids_parquet_path), + "centroids_plot": str(centroids_plot_path), + }, } - save_metadata(metadata, args.output_dir / "clustering_metadata.json") + save_metadata(metadata, metadata_path) logger.info("=" * 70) logger.info("CLUSTERING COMPLETE") logger.info("=" * 70) logger.info("Profiles: %s", f"{n_profiles:,}") - logger.info("Clusters: %d", args.k) + logger.info("Clusters: %d", int(args.k)) if sil_score is not None: - logger.info("Silhouette (sample): %.3f", sil_score) + logger.info("Silhouette (sample): %.3f", float(sil_score)) logger.info("Output: %s", args.output_dir) return 0 diff --git a/analysis/clustering/stage2_logratio_regression.py b/analysis/clustering/stage2_logratio_regression.py deleted file mode 100644 index 5319b1c..0000000 --- a/analysis/clustering/stage2_logratio_regression.py +++ /dev/null @@ -1,1049 +0,0 @@ -#!/usr/bin/env python3 -""" -Stage 2: Block-Group-Level Log-Ratio Regression of Cluster Composition (HOUSEHOLD-DAY UNITS) - -Goal ------ -Model how Census block-group demographics are associated with the *composition* -of household-day observations across load-profile clusters, using log-ratio -regression (ALR / additive log-ratio). - -Unit of Analysis ----------------- -One row per Census BLOCK GROUP (not one row per block_group x cluster). - -Data Flow ---------- -1. Load household-day cluster assignments from Stage 1 (one row per household-day) -2. Join to Census block groups via ZIP+4 → block group crosswalk (1-to-1 enforced) -3. Aggregate to block-group-level cluster composition (wide format) -4. Join block groups to Census demographics -5. Create smoothed proportions and log-ratios vs a baseline cluster: - y_k = log(p_k / p_base) -6. 
Fit separate WLS regressions for each non-baseline cluster: - y_k ~ demographics with weights = total_obs (household-day count) -7. Fit OLS models (robustness check, unweighted) - -Outputs -------- -- regression_data_blockgroups_wide.parquet -- regression_results_logratio_blockgroups.json -- statsmodels_summaries_wls.txt -- statsmodels_summaries_ols.txt -- regression_report_logratio_blockgroups.txt - -Usage ------ - python stage2_logratio_regression.py \ - --clusters data/clustering/results/cluster_assignments.parquet \ - --crosswalk data/reference/2023_comed_zip4_census_crosswalk.txt \ - --census-cache data/reference/census_17_2023.parquet \ - --output-dir data/clustering/results/stage2_blockgroups_logratio \ - --baseline-cluster 1 -""" - -from __future__ import annotations - -import argparse -import json -import logging -import sys -from pathlib import Path - -import numpy as np -import polars as pl -import statsmodels.api as sm -from sklearn.preprocessing import StandardScaler - -from smart_meter_analysis.census import fetch_census_data -from smart_meter_analysis.census_specs import STAGE2_PREDICTORS_47 -from smart_meter_analysis.run_manifest import write_stage2_manifest - -logger = logging.getLogger(__name__) -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", -) - - -def load_cluster_assignments_household_day(path: Path) -> tuple[pl.DataFrame, dict]: - """ - Load household-day cluster assignments. - - Returns the raw Stage 1 output: one row per (household, day) with cluster label. - - Returns - ------- - df : pl.DataFrame - One row per household-day with columns: - - account_identifier - - zip_code - - date (if present) - - cluster - - dominance_stats : dict - Summary statistics on how consistently households stay in one cluster - (for reporting/interpretation, not used in regression) - """ - logger.info("Loading cluster assignments from %s", path) - - # For large files, compute stats without loading everything - logger.info(" Computing statistics in streaming mode...") - lf = pl.scan_parquet(path) - - # Quick validation - schema = lf.collect_schema() - required = ["account_identifier", "zip_code", "cluster"] - missing = [c for c in required if c not in schema.names()] - if missing: - raise ValueError(f"cluster_assignments missing required columns: {missing}") - - # Compute basic stats without loading full data - stats_df = lf.select([ - pl.len().alias("n_household_days"), - pl.col("account_identifier").n_unique().alias("n_households"), - pl.col("cluster").n_unique().alias("n_clusters"), - ]).collect() - - stats = stats_df.to_dicts()[0] - - logger.info( - " Found: %s household-day observations, %s households, %s clusters", - f"{stats['n_household_days']:,}", - f"{stats['n_households']:,}", - stats["n_clusters"], - ) - - # Compute dominance stats with streaming - logger.info(" Computing dominance statistics...") - dominance_stats = _compute_dominance_stats_streaming(path) - - logger.info( - " Dominance stats: mean=%.1f%%, median=%.1f%%, >50%%: %.1f%% of households", - dominance_stats["dominance_mean"] * 100, - dominance_stats["dominance_median"] * 100, - dominance_stats["pct_above_50"], - ) - - # Now load only what we need for regression: just the columns - logger.info(" Loading data for regression (selecting needed columns only)...") - raw = lf.select(required).collect(streaming=True) - - return raw, dominance_stats - - -def _compute_dominance_stats_streaming(path: Path) -> dict: - """ - Compute dominance stats using streaming aggregation 
to avoid OOM. - """ - lf = pl.scan_parquet(path) - - # Compute per-household cluster counts - counts = ( - lf.group_by(["account_identifier", "cluster"]).agg(pl.len().alias("days_in_cluster")).collect(streaming=True) - ) - - # Get max days per household - max_days = counts.group_by("account_identifier").agg(pl.col("days_in_cluster").max().alias("max_days_in_cluster")) - - # Get total days per household - totals = counts.group_by("account_identifier").agg(pl.col("days_in_cluster").sum().alias("n_days")) - - # Join and compute dominance - dominance = max_days.join(totals, on="account_identifier") - dominance = dominance.with_columns((pl.col("max_days_in_cluster") / pl.col("n_days")).alias("dominance")) - - dom_series = dominance["dominance"] - - return { - "n_households": int(dominance.height), - "dominance_mean": float(dom_series.mean()), - "dominance_median": float(dom_series.median()), - "dominance_std": float(dom_series.std()), - "dominance_min": float(dom_series.min()), - "dominance_max": float(dom_series.max()), - "pct_above_50": float((dom_series > 0.5).mean() * 100), - "pct_above_67": float((dom_series > 0.67).mean() * 100), - "pct_above_80": float((dom_series > 0.80).mean() * 100), - } - - -def load_crosswalk_one_to_one(crosswalk_path: Path, zip_codes: list[str]) -> pl.DataFrame: - """ - Load ZIP+4 → Census block-group crosswalk with deterministic 1-to-1 mapping. - - When fan-out exists (ZIP+4 maps to multiple block groups), - choose smallest GEOID per ZIP+4 to avoid double-counting household-day observations. - - This is the only valid approach when crosswalk weights are unavailable. - """ - logger.info("Loading crosswalk from %s", crosswalk_path) - - crosswalk = ( - pl.scan_csv(crosswalk_path, separator="\t", infer_schema_length=10000) - .with_columns([ - (pl.col("Zip").cast(pl.Utf8).str.zfill(5) + "-" + pl.col("Zip4").cast(pl.Utf8).str.zfill(4)).alias( - "zip_code" - ), - pl.col("CensusKey2023").cast(pl.Utf8).str.zfill(15).str.slice(0, 12).alias("block_group_geoid"), - ]) - .filter(pl.col("zip_code").is_in(zip_codes)) - .select(["zip_code", "block_group_geoid"]) - .collect() - ) - - logger.info( - " Matched %s of %s ZIP+4 codes", - f"{crosswalk['zip_code'].n_unique():,}", - f"{len(set(zip_codes)):,}", - ) - - if crosswalk.is_empty(): - logger.warning(" Crosswalk is empty after filtering for sample ZIP+4s.") - return crosswalk - - # Check for fan-out - fanout = crosswalk.group_by("zip_code").agg(pl.n_unique("block_group_geoid").alias("n_block_groups")) - max_fanout = int(fanout["n_block_groups"].max()) - - if max_fanout > 1: - n_fanout = fanout.filter(pl.col("n_block_groups") > 1).height - pct_fanout = (n_fanout / len(fanout)) * 100 - - logger.warning( - " ZIP+4 fan-out detected: %s ZIP+4s (%.1f%%) map to multiple block groups (max=%d per ZIP+4)", - f"{n_fanout:,}", - pct_fanout, - max_fanout, - ) - logger.warning(" Applying deterministic 1-to-1 mapping: selecting smallest GEOID per ZIP+4") - logger.warning(" This prevents double-counting household-day observations") - - # Deterministic resolution: smallest GEOID per ZIP+4 - crosswalk = ( - crosswalk.sort(["zip_code", "block_group_geoid"]) - .group_by("zip_code") - .agg(pl.col("block_group_geoid").first()) - ) - - logger.info( - " After 1-to-1 resolution: %s ZIP+4 codes → %s unique mappings", - f"{len(crosswalk):,}", - f"{len(crosswalk):,}", - ) - else: - logger.info(" Crosswalk is already 1-to-1: each ZIP+4 maps to exactly one block group.") - - return crosswalk - - -def attach_block_groups_to_household_days( - household_days: 
pl.DataFrame, - crosswalk: pl.DataFrame, -) -> pl.DataFrame: - """ - Attach block-group GEOIDs to household-day observations via ZIP+4. - - Input: one row per household-day - Output: one row per household-day with block_group_geoid attached - """ - logger.info("Joining household-day observations to block groups...") - - df = household_days.join(crosswalk, on="zip_code", how="left") - - n_before = len(df) - missing = df.filter(pl.col("block_group_geoid").is_null()).height - - if missing > 0: - pct = missing / n_before * 100 - logger.warning(" %s (%.1f%%) observations missing block_group - dropping", f"{missing:,}", pct) - df = df.filter(pl.col("block_group_geoid").is_not_null()) - - logger.info( - " %s household-day observations across %s block groups", - f"{len(df):,}", - f"{df['block_group_geoid'].n_unique():,}", - ) - - return df - - -def aggregate_blockgroup_cluster_composition(df: pl.DataFrame) -> pl.DataFrame: - """ - Aggregate household-day observations to block-group-level cluster composition (wide). - - Output: one row per block_group_geoid with: - - total_obs - - total_households - - n_cluster_ - - p_cluster_ - """ - logger.info("Aggregating to block-group cluster composition (wide; household-day units)...") - - totals = df.group_by("block_group_geoid").agg([ - pl.len().alias("total_obs"), - pl.col("account_identifier").n_unique().alias("total_households"), - ]) - - counts_long = df.group_by(["block_group_geoid", "cluster"]).agg(pl.len().alias("n_obs")) - - counts_wide = ( - counts_long.with_columns(pl.col("cluster").cast(pl.Utf8)) - .pivot( - values="n_obs", - index="block_group_geoid", - columns="cluster", - aggregate_function="first", - ) - .fill_null(0) - ) - - cluster_cols = [c for c in counts_wide.columns if c != "block_group_geoid"] - counts_wide = counts_wide.rename({c: f"n_cluster_{c}" for c in cluster_cols}) - - out = totals.join(counts_wide, on="block_group_geoid", how="left").fill_null(0) - - n_cols = [c for c in out.columns if c.startswith("n_cluster_")] - out = out.with_columns([ - (pl.col(c) / pl.col("total_obs")).alias(c.replace("n_cluster_", "p_cluster_")) for c in n_cols - ]) - - logger.info( - " Created %s block-group rows; total obs=%s; total households=%s", - f"{len(out):,}", - f"{int(out['total_obs'].sum()):,}", - f"{int(out['total_households'].sum()):,}", - ) - return out - - -def fetch_or_load_census( - cache_path: Path, - state_fips: str = "17", - acs_year: int = 2023, - force_fetch: bool = False, -) -> pl.DataFrame: - """Fetch Census data from API or load from cache.""" - if cache_path.exists() and not force_fetch: - logger.info("Loading Census data from cache: %s", cache_path) - return pl.read_parquet(cache_path) - - logger.info("Fetching Census data from API (state=%s, year=%s)...", state_fips, acs_year) - - census_df = fetch_census_data(state_fips=state_fips, acs_year=acs_year) - - cache_path.parent.mkdir(parents=True, exist_ok=True) - census_df.write_parquet(cache_path) - logger.info(" Cached Census data to %s", cache_path) - - return census_df - - -def create_derived_variables(census_df: pl.DataFrame) -> pl.DataFrame: - """Create derived percentage variables from raw Census counts.""" - logger.info("Creating derived variables...") - - df = census_df.with_columns([ - (pl.col("Owner_Occupied") / pl.col("Occupied_Housing_Units") * 100).alias("Owner_Occupied_Pct"), - (pl.col("Heat_Electric") / pl.col("Total_Households") * 100).alias("Heat_Electric_Pct"), - ( - ( - pl.col("Built_1960_1969") - + pl.col("Built_1950_1959") - + pl.col("Built_1940_1949") - + 
pl.col("Built_1939_Earlier") - ) - / pl.col("Total_Housing_Units") - * 100 - ).alias("Old_Building_Pct"), - ]) - - df = df.with_columns([ - pl.when(pl.col("Owner_Occupied_Pct").is_nan()) - .then(None) - .otherwise(pl.col("Owner_Occupied_Pct")) - .alias("Owner_Occupied_Pct"), - pl.when(pl.col("Heat_Electric_Pct").is_nan()) - .then(None) - .otherwise(pl.col("Heat_Electric_Pct")) - .alias("Heat_Electric_Pct"), - pl.when(pl.col("Old_Building_Pct").is_nan()) - .then(None) - .otherwise(pl.col("Old_Building_Pct")) - .alias("Old_Building_Pct"), - ]) - - return df - - -def attach_census_to_blockgroups(bg_comp: pl.DataFrame, census_df: pl.DataFrame) -> pl.DataFrame: - """Attach Census demographics to block-group composition (wide).""" - logger.info("Joining Census data to block-group composition...") - - census_df = census_df.with_columns(pl.col("GEOID").cast(pl.Utf8).str.zfill(12).alias("block_group_geoid")) - - demo = bg_comp.join(census_df, on="block_group_geoid", how="left") - - n_before = len(demo) - missing = demo.filter(pl.col("GEOID").is_null()).height - - if missing > 0: - pct = missing / n_before * 100 - logger.warning(" %s (%.1f%%) rows missing Census data - dropping", f"{missing:,}", pct) - demo = demo.filter(pl.col("GEOID").is_not_null()) - - logger.info(" Demographics attached for %s block groups", f"{demo['block_group_geoid'].n_unique():,}") - - return demo - - -def detect_predictors_from_census(demo_df: pl.DataFrame) -> list[str]: - """ - Automatically detect predictor columns from census data. - - Excludes identifiers and non-predictors: - - GEOID, block_group_geoid - - total_obs, total_households - - Cluster counts (n_cluster_*) - - Cluster proportions (p_cluster_*) - - Log-ratio columns (log_ratio_*) - - NAME (if present) - - Returns all other columns as predictors (should be census-derived features). - """ - exclude_patterns = [ - "GEOID", - "block_group_geoid", - "total_obs", - "total_households", - "NAME", - ] - exclude_prefixes = [ - "n_cluster_", - "p_cluster_", - "log_ratio_", - ] - - all_cols = demo_df.columns - predictors = [] - - for col in all_cols: - # Skip if matches exact exclusion - if col in exclude_patterns: - continue - # Skip if starts with exclusion prefix - if any(col.startswith(prefix) for prefix in exclude_prefixes): - continue - predictors.append(col) - - return sorted(predictors) - - -def prepare_regression_dataset_wide( - demo_df: pl.DataFrame, - predictors: list[str], - min_obs_per_bg: int = 50, -) -> tuple[pl.DataFrame, list[str]]: - """ - Prepare block-group (wide) dataset for log-ratio regression. 
- - Filters: - - Block groups with fewer than min_obs_per_bg household-day observations - - Drops predictors with too many nulls - - Drops rows with any null predictor values (conservative / statsmodels-friendly) - """ - logger.info("Preparing regression dataset (wide)...") - - df = demo_df.filter(pl.col("total_obs") >= min_obs_per_bg) - logger.info( - " After min_obs filter (>=%d): %s block groups", - min_obs_per_bg, - f"{df['block_group_geoid'].n_unique():,}", - ) - - available_predictors: list[str] = [] - for p in predictors: - if p not in df.columns: - logger.warning(" Predictor not found: %s", p) - continue - null_rate = df[p].null_count() / len(df) - if null_rate > 0.5: - logger.warning(" Predictor %s has %.0f%% nulls - excluding", p, null_rate * 100) - continue - available_predictors.append(p) - - if not available_predictors: - raise ValueError("No valid predictors available") - - # Drop rows with any null predictor values - df = df.filter(~pl.any_horizontal(pl.col(available_predictors).is_null())) - bg_after_dropnull = df["block_group_geoid"].n_unique() - logger.info( - " After dropping rows with null predictors: %s block groups", - f"{bg_after_dropnull:,}", - ) - - logger.info(" Using %d predictors: %s", len(available_predictors), available_predictors) - - return df, available_predictors - - -def choose_baseline_cluster_from_household_days(household_days: pl.DataFrame) -> str: - """ - Choose baseline cluster as the most frequent cluster by household-day observations. - Returns as string (to match pivot-derived cluster column suffixes). - """ - dist = household_days.group_by("cluster").agg(pl.len().alias("n")).sort("n", descending=True) - baseline = dist["cluster"][0] - logger.info(" Auto-selected baseline: cluster %s (most frequent by household-days)", baseline) - return str(baseline) - - -def add_smoothed_proportions_and_logratios( - df: pl.DataFrame, - baseline_cluster: str, - alpha: float = 0.5, -) -> tuple[pl.DataFrame, list[str], list[str]]: - """ - Add smoothed proportions and log-ratios vs baseline. - - Smoothing is applied at the count level: - n_s_k = n_k + alpha - total_s = total_obs + alpha*K - p_s_k = n_s_k / total_s - - Then outcomes: - log_ratio_k = log(p_s_k / p_s_base) - """ - n_cols = sorted([c for c in df.columns if c.startswith("n_cluster_")]) - if not n_cols: - raise ValueError("No n_cluster_ columns found. 
Did you run wide aggregation?") - - clusters = [c.replace("n_cluster_", "") for c in n_cols] - if baseline_cluster not in clusters: - raise ValueError(f"Baseline cluster {baseline_cluster} not found in clusters={clusters}") - - K = len(clusters) - nonbase = [k for k in clusters if k != baseline_cluster] - - logger.info("Adding smoothed proportions and log-ratios (alpha=%.2f)...", alpha) - logger.info(" Clusters: %s (K=%d)", clusters, K) - logger.info(" Baseline: %s", baseline_cluster) - logger.info(" Non-baseline: %s", nonbase) - - df2 = df.with_columns([(pl.col(f"n_cluster_{k}") + alpha).alias(f"n_s_{k}") for k in clusters]).with_columns( - (pl.col("total_obs") + alpha * K).alias("total_obs_s") - ) - - df2 = df2.with_columns([(pl.col(f"n_s_{k}") / pl.col("total_obs_s")).alias(f"p_s_{k}") for k in clusters]) - - df2 = df2.with_columns([ - (pl.col(f"p_s_{k}") / pl.col(f"p_s_{baseline_cluster}")).log().alias(f"log_ratio_{k}") for k in nonbase - ]) - - # Diagnostic: check for extreme log-ratios - for k in nonbase: - extreme_pos = df2.filter(pl.col(f"log_ratio_{k}") > 5).height - extreme_neg = df2.filter(pl.col(f"log_ratio_{k}") < -5).height - - if extreme_pos > 0 or extreme_neg > 0: - logger.warning( - " Cluster %s: %d block groups with log_ratio > 5, %d with log_ratio < -5", - k, - extreme_pos, - extreme_neg, - ) - logger.warning(" This suggests very imbalanced cluster distributions in some block groups") - - return df2, clusters, nonbase - - -def run_logratio_regressions( - reg_df: pl.DataFrame, - predictors: list[str], - baseline_cluster: str, - weight_col: str = "total_obs", - standardize: bool = False, - include_ols: bool = True, -) -> dict[str, object]: - """ - Fit separate WLS models for each non-baseline cluster: - log(p_k / p_base) ~ predictors - with weights = total_obs (household-day count in block group). - - Also fits OLS models (unweighted) as robustness check. - - Interpretation: - exp(beta) = multiplicative effect on the proportion ratio p_k/p_base - for a 1-unit increase in the predictor. - """ - logger.info("Running log-ratio regressions...") - logger.info(" Final predictor count: %d", len(predictors)) - logger.info(" Baseline cluster: %s", baseline_cluster) - logger.info(" Weighting by: %s", weight_col) - logger.info(" OLS robustness check: %s", include_ols) - - logratio_cols = sorted([c for c in reg_df.columns if c.startswith("log_ratio_")]) - if not logratio_cols: - raise ValueError("No log_ratio_ columns found. 
Did you call add_smoothed_proportions_and_logratios()?") - - X = reg_df.select(predictors).to_numpy().astype(np.float64) - w = reg_df.get_column(weight_col).to_numpy().astype(np.float64) - - # Drop invalid rows (NaNs / infs / nonpositive weights) - valid = np.isfinite(X).all(axis=1) & np.isfinite(w) & (w > 0) - if valid.sum() == 0: - raise ValueError("No valid rows after filtering missing predictors / invalid weights.") - - X = X[valid] - w = w[valid] - - scaler = None - if standardize: - logger.info(" Standardizing predictors...") - scaler = StandardScaler() - X = scaler.fit_transform(X) - else: - logger.info(" Using raw predictor units (no standardization).") - - X = sm.add_constant(X) - param_names = ["const", *predictors] - - results: dict[str, object] = { - "n_block_groups": int(reg_df["block_group_geoid"].n_unique()), - "n_rows": len(reg_df), - "n_valid_rows": int(valid.sum()), - "weight_col": weight_col, - "baseline_cluster": baseline_cluster, - "predictors": predictors, - "standardize": bool(standardize), - "models_wls": {}, - "models_ols": {}, - } - - if scaler is not None: - results["standardization"] = { - "type": "zscore", - "means": {p: float(m) for p, m in zip(predictors, scaler.mean_)}, - "scales": {p: float(s) for p, s in zip(predictors, scaler.scale_)}, - } - - summaries_wls = [] - summaries_ols = [] - - for col in logratio_cols: - k = col.replace("log_ratio_", "") - y = reg_df.get_column(col).to_numpy().astype(np.float64)[valid] - - if not np.isfinite(y).all(): - raise ValueError(f"Non-finite values in outcome {col}. Check smoothing / inputs.") - - # WLS model - model_wls = sm.WLS(y, X, weights=w) - fit_wls = model_wls.fit() - - coef_wls = {name: float(fit_wls.params[i]) for i, name in enumerate(param_names)} - se_wls = {name: float(fit_wls.bse[i]) for i, name in enumerate(param_names)} - pvals_wls = {name: float(fit_wls.pvalues[i]) for i, name in enumerate(param_names)} - mult_wls = {name: float(np.exp(fit_wls.params[i])) for i, name in enumerate(param_names)} - - key = f"cluster_{k}_vs_{baseline_cluster}" - results["models_wls"][key] = { - "outcome": f"log(p_{k}/p_{baseline_cluster})", - "nobs": int(fit_wls.nobs), - "r2": float(fit_wls.rsquared), - "adj_r2": float(fit_wls.rsquared_adj), - "coefficients": coef_wls, - "std_errors": se_wls, - "p_values": pvals_wls, - "multiplicative_effects": mult_wls, - } - - summaries_wls.append(f"\n{'=' * 80}\nWLS: {key}\n{'=' * 80}\n{fit_wls.summary().as_text()}") - logger.info(" WLS %s: R²=%.4f", key, float(fit_wls.rsquared)) - - # OLS model (robustness check) - if include_ols: - model_ols = sm.OLS(y, X) - fit_ols = model_ols.fit() - - coef_ols = {name: float(fit_ols.params[i]) for i, name in enumerate(param_names)} - se_ols = {name: float(fit_ols.bse[i]) for i, name in enumerate(param_names)} - pvals_ols = {name: float(fit_ols.pvalues[i]) for i, name in enumerate(param_names)} - mult_ols = {name: float(np.exp(fit_ols.params[i])) for i, name in enumerate(param_names)} - - results["models_ols"][key] = { - "outcome": f"log(p_{k}/p_{baseline_cluster})", - "nobs": int(fit_ols.nobs), - "r2": float(fit_ols.rsquared), - "adj_r2": float(fit_ols.rsquared_adj), - "coefficients": coef_ols, - "std_errors": se_ols, - "p_values": pvals_ols, - "multiplicative_effects": mult_ols, - } - - summaries_ols.append(f"\n{'=' * 80}\nOLS: {key}\n{'=' * 80}\n{fit_ols.summary().as_text()}") - logger.info(" OLS %s: R²=%.4f", key, float(fit_ols.rsquared)) - - results["all_model_summaries_wls"] = "\n".join(summaries_wls) - if include_ols: - 
results["all_model_summaries_ols"] = "\n".join(summaries_ols) - - return results - - -def generate_report_logratio( - results: dict[str, object], - dominance_stats: dict, - cluster_dist: pl.DataFrame, - output_path: Path, -) -> None: - """Generate a human-readable report emphasizing log-ratio interpretation.""" - lines = [ - "=" * 80, - "STAGE 2: BLOCK-GROUP LOG-RATIO REGRESSION RESULTS", - "=" * 80, - "", - "ANALYSIS UNIT: BLOCK GROUPS (HOUSEHOLD-DAY COMPOSITION)", - "-" * 80, - "Each row is a block group, with household-day counts aggregated into", - "cluster composition proportions. Outcomes are log-ratios vs a baseline:", - " y_k = log(p_k / p_baseline)", - "", - "Models are separate WLS regressions per non-baseline cluster, weighted by total_obs.", - "OLS models (unweighted) are included as robustness checks.", - "", - "MODEL OVERVIEW", - "-" * 80, - f"Block groups (total): {results['n_block_groups']:,}", - f"Block groups (valid): {results['n_valid_rows']:,}", - f"Rows: {results['n_rows']:,}", - f"Predictors: {len(results['predictors'])}", - f"Weight column: {results['weight_col']}", - f"Baseline cluster: {results['baseline_cluster']}", - f"Standardized predictors: {results.get('standardize', False)}", - "", - "HOUSEHOLD CLUSTER CONSISTENCY (interpretation context)", - "-" * 80, - "How consistently do households stay in one cluster across sampled days?", - "(This doesn't affect the regression - just useful context.)", - "", - f" Households: {dominance_stats['n_households']:,}", - f" Mean dominance: {dominance_stats['dominance_mean'] * 100:.1f}%", - f" Median dominance: {dominance_stats['dominance_median'] * 100:.1f}%", - f" Households >50% in one cluster: {dominance_stats['pct_above_50']:.1f}%", - f" Households >67% in one cluster: {dominance_stats['pct_above_67']:.1f}%", - f" Households >80% in one cluster: {dominance_stats['pct_above_80']:.1f}%", - "", - "CLUSTER DISTRIBUTION (by household-day observations, overall)", - "-" * 80, - ] - - for row in cluster_dist.iter_rows(named=True): - lines.append(f" Cluster {row['cluster']}: {row['n_obs']:,} obs ({row['pct']:.1f}%)") - - lines.extend([ - "", - "TOP PREDICTORS BY MODEL (WLS; by |coef|; *=p<0.05)", - "-" * 80, - "Interpretation: exp(coef) multiplies the proportion ratio p_k/p_baseline", - "for a 1-unit increase in the predictor (holding others constant).", - "", - ]) - - models_wls = results["models_wls"] - models_ols = results.get("models_ols", {}) - predictors = results["predictors"] - - for model_key in sorted(models_wls.keys()): - m_wls = models_wls[model_key] - coefs_wls = m_wls["coefficients"] - pvals_wls = m_wls["p_values"] - mult_wls = m_wls["multiplicative_effects"] - - lines.append(f"\n{model_key}") - lines.append("-" * 80) - lines.append(f"WLS R²={m_wls['r2']:.4f}, Adj R²={m_wls['adj_r2']:.4f}") - - if model_key in models_ols: - m_ols = models_ols[model_key] - lines.append(f"OLS R²={m_ols['r2']:.4f}, Adj R²={m_ols['adj_r2']:.4f} (robustness check)") - - lines.append("") - - sorted_preds = sorted( - [(p, coefs_wls[p]) for p in predictors], - key=lambda x: abs(x[1]), - reverse=True, - )[:5] - - for pred, coef in sorted_preds: - star = "*" if pvals_wls[pred] < 0.05 else "" - direction = "↑" if coef > 0 else "↓" - - line = f" {direction} {pred:<30} WLS: mult={mult_wls[pred]:.3f}, coef={coef:>7.4f}, p={pvals_wls[pred]:.3g}{star}" - - # Show OLS comparison if available - if model_key in models_ols: - coef_ols = models_ols[model_key]["coefficients"][pred] - diff = coef - coef_ols - line += f" | OLS: coef={coef_ols:>7.4f} 
(Δ={diff:>6.3f})" - - lines.append(line) - - lines.append("\n" + "=" * 80) - lines.append("") - lines.append("NOTES:") - lines.append("- WLS models weight by total household-day observations per block group") - lines.append("- OLS models are unweighted (robustness check)") - lines.append("- Large WLS-OLS differences suggest results driven by large block groups") - lines.append("- Multiplicative effect > 1.0 means predictor increases p_k/p_baseline") - lines.append("- Multiplicative effect < 1.0 means predictor decreases p_k/p_baseline") - lines.append("=" * 80) - - text = "\n".join(lines) - output_path.write_text(text, encoding="utf-8") - logger.info("Report saved to %s", output_path) - print("\n" + text) - - -def main() -> int: - parser = argparse.ArgumentParser( - description="Stage 2: Block-group-level log-ratio regression using household-day units.", - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - parser.add_argument("--clusters", type=Path, required=True, help="cluster_assignments.parquet") - parser.add_argument("--crosswalk", type=Path, required=True, help="ZIP+4 → block-group crosswalk") - parser.add_argument( - "--census-cache", - type=Path, - default=Path("data/reference/census_17_2023.parquet"), - ) - parser.add_argument("--fetch-census", action="store_true", help="Force re-fetch Census data") - parser.add_argument("--state-fips", default="17") - parser.add_argument("--acs-year", type=int, default=2023) - parser.add_argument( - "--min-obs-per-bg", - type=int, - default=50, - help="Minimum household-day observations per block group (default: 50)", - ) - parser.add_argument( - "--predictors", - nargs="+", - default=None, - help="Predictor columns (default: auto-detect from census data; this argument is deprecated)", - ) - parser.add_argument( - "--output-dir", - type=Path, - default=Path("data/clustering/results/stage2_blockgroups_logratio"), - ) - parser.add_argument( - "--standardize", - action="store_true", - help="Standardize predictors before regression (default: use raw units).", - ) - parser.add_argument( - "--alpha", - type=float, - default=0.5, - help="Pseudocount smoothing parameter for proportions (default: 0.5)", - ) - parser.add_argument( - "--baseline-cluster", - type=str, - default=None, - help="Baseline cluster label (default: most frequent cluster by household-day observations)", - ) - parser.add_argument( - "--no-ols", - action="store_true", - help="Skip OLS robustness check (only run WLS)", - ) - - parser.add_argument( - "--predictors-from", - type=str, - default=None, - help="Optional: path to a predictors_used.txt file to force an exact predictor list.", - ) - - args = parser.parse_args() - - if not args.clusters.exists(): - logger.error("Cluster assignments not found: %s", args.clusters) - return 1 - if not args.crosswalk.exists(): - logger.error("Crosswalk not found: %s", args.crosswalk) - return 1 - - args.output_dir.mkdir(parents=True, exist_ok=True) - - print("=" * 80) - print("STAGE 2: BLOCK-GROUP LOG-RATIO REGRESSION (HOUSEHOLD-DAY UNITS)") - print("=" * 80) - - household_days, dominance_stats = load_cluster_assignments_household_day(args.clusters) - - # Baseline cluster (string form to match wide column naming) - baseline_cluster = args.baseline_cluster or choose_baseline_cluster_from_household_days(household_days) - baseline_cluster = str(baseline_cluster) - logger.info("Using baseline cluster: %s", baseline_cluster) - - zip_codes = household_days["zip_code"].unique().to_list() - crosswalk = load_crosswalk_one_to_one(args.crosswalk, 
zip_codes) - household_days_bg = attach_block_groups_to_household_days(household_days, crosswalk) - - bg_comp = aggregate_blockgroup_cluster_composition(household_days_bg) - - census_df = fetch_or_load_census( - cache_path=args.census_cache, - state_fips=args.state_fips, - acs_year=args.acs_year, - force_fetch=args.fetch_census, - ) - logger.info(" Census: %s block groups, %s columns", f"{len(census_df):,}", len(census_df.columns)) - - demo_df = attach_census_to_blockgroups(bg_comp, census_df) - - # Track initial block group count - bg_total = demo_df["block_group_geoid"].n_unique() - - # If user supplied predictors-from, prefer that; else use stable 47 list. - if args.predictors_from is not None: - # Load predictors from a prior run's predictors_used.txt file - predictors_path = Path(args.predictors_from) - if not predictors_path.exists(): - logger.error("Predictors file not found: %s", predictors_path) - return 1 - predictors = [line.strip() for line in predictors_path.read_text().strip().split("\n") if line.strip()] - logger.info("Using predictors from prior run: %s (%d predictors)", args.predictors_from, len(predictors)) - else: - predictors = list(STAGE2_PREDICTORS_47) - logger.info("Using stable predictor list from census_specs: %d predictors", len(predictors)) - - # Track predictors before filtering - initial_predictors = set(predictors) - - reg_df, predictors = prepare_regression_dataset_wide( - demo_df=demo_df, - predictors=predictors, - min_obs_per_bg=args.min_obs_per_bg, - ) - - # Track block group counts and excluded predictors - bg_after_minobs = reg_df["block_group_geoid"].n_unique() if not reg_df.is_empty() else 0 - excluded_all_null_predictors = sorted(initial_predictors - set(predictors)) - predictors_detected = len(initial_predictors) - bg_after_dropnull = reg_df["block_group_geoid"].n_unique() if not reg_df.is_empty() else 0 - - if reg_df.is_empty(): - logger.error("No data after filtering") - return 1 - - # Add smoothed proportions + log-ratios - reg_df2, clusters, nonbase = add_smoothed_proportions_and_logratios( - reg_df, - baseline_cluster=baseline_cluster, - alpha=args.alpha, - ) - logger.info("Clusters detected: %s (baseline=%s, non-baseline=%s)", clusters, baseline_cluster, nonbase) - - # Save regression dataset - reg_df2.write_parquet(args.output_dir / "regression_data_blockgroups_wide.parquet") - logger.info("Saved regression data to %s", args.output_dir / "regression_data_blockgroups_wide.parquet") - - # Validate required predictor columns exist before modeling - missing = [c for c in predictors if c not in reg_df2.columns] - if missing: - available = sorted([ - c - for c in reg_df2.columns - if not c.startswith(("n_cluster_", "p_cluster_", "log_ratio_", "total_obs", "block_group_geoid", "GEOID")) - ]) - raise ValueError( - f"Missing required predictor columns: {missing}\n" - f"Available columns: {available}\n" - f"Please check that census.py returns the expected engineered features." 
- ) - - # Log final predictor count before modeling - logger.info("Final predictor count: %d", len(predictors)) - logger.info("Predictors: %s", predictors) - - # Fit models - results = run_logratio_regressions( - reg_df=reg_df2, - predictors=predictors, - baseline_cluster=baseline_cluster, - weight_col="total_obs", - standardize=args.standardize, - include_ols=not args.no_ols, - ) - - results["dominance_stats"] = dominance_stats - results["alpha"] = float(args.alpha) - results["k"] = len(clusters) - results["clusters"] = clusters - results["nonbaseline_clusters"] = nonbase - - # Write outputs - all_summaries_wls = results.pop("all_model_summaries_wls") - all_summaries_ols = results.pop("all_model_summaries_ols", None) - - with open(args.output_dir / "regression_results_logratio_blockgroups.json", "w") as f: - json.dump(results, f, indent=2) - - (args.output_dir / "statsmodels_summaries_wls.txt").write_text(all_summaries_wls, encoding="utf-8") - - if all_summaries_ols: - (args.output_dir / "statsmodels_summaries_ols.txt").write_text(all_summaries_ols, encoding="utf-8") - - # Cluster distribution overall (by household-day observations) - cluster_dist = ( - household_days.group_by("cluster") - .agg(pl.len().alias("n_obs")) - .sort("cluster") - .with_columns((pl.col("n_obs") / pl.col("n_obs").sum() * 100).alias("pct")) - ) - - generate_report_logratio( - results=results, - dominance_stats=dominance_stats, - cluster_dist=cluster_dist, - output_path=args.output_dir / "regression_report_logratio_blockgroups.txt", - ) - - print(f"\nOutputs saved to: {args.output_dir}") - - write_stage2_manifest( - output_dir=args.output_dir, - command=" ".join(sys.argv), - repo_root=".", - clusters_path=args.clusters, - crosswalk_path=args.crosswalk, - census_cache_path=args.census_cache, - baseline_cluster=baseline_cluster, - min_obs_per_bg=args.min_obs_per_bg, - alpha=args.alpha, - weight_column="total_obs", - predictors_detected=predictors_detected, - predictors_used=predictors, - predictors_excluded_all_null=excluded_all_null_predictors, - block_groups_total=int(bg_total), - block_groups_after_min_obs=int(bg_after_minobs), - block_groups_after_drop_null_predictors=int(bg_after_dropnull), - regression_data_path=args.output_dir / "regression_data_blockgroups_wide.parquet", - regression_report_path=args.output_dir / "regression_report_logratio_blockgroups.txt", - run_log_path=args.output_dir / "run.log", - ) - logger.info("Wrote Stage 2 manifest and predictor lists to %s", args.output_dir) - - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/analysis/stage2/stage2_multinom_blockgroup_weighted.R b/analysis/stage2/stage2_multinom_blockgroup_weighted.R new file mode 100644 index 0000000..3efb410 --- /dev/null +++ b/analysis/stage2/stage2_multinom_blockgroup_weighted.R @@ -0,0 +1,1054 @@ +#!/usr/bin/env Rscript + +# ============================================================================= +# Stage 2 Multinomial Logit (Block Group): Cluster Composition ~ Census Predictors +# ============================================================================= +# +# What this script does (high-level) +# ---------------------------------- +# 1) Reads household-day cluster assignments (ZIP+4 resolution) from a Parquet file. +# 2) Aggregates these household-days into ZIP+4 × cluster counts using Arrow compute +# (i.e., without loading the full row-level data into memory). 
+# 3) Joins ZIP+4 to Census Block Group using a crosswalk, then aggregates to +# Block Group × cluster counts and total household-days per Block Group. +# 4) Joins Block Group-level census predictors (Parquet output from upstream census pipeline). +# 5) Fits a multinomial logit model with COUNT response: +# Y_bg = (n_bg,clusterA, n_bg,clusterB, ..., n_bg,clusterBaseline) +# where cluster probabilities are modeled as a function of BG predictors. +# +# Why this structure +# ------------------ +# - The regression is formulated at the Block Group level to avoid per-row modeling at the +# household-day level while still leveraging the full count information (multinomial likelihood). +# - Arrow Dataset aggregation avoids memory blowups when the input Parquet is very large. +# - Predictors are inferred from the census parquet to reduce coupling / hardcoding. +# - VGAM is used by default because it is typically more numerically stable at large scale than +# nnet::multinom, but VGAM requires a full-rank design matrix; rank-deficient terms are dropped +# deterministically to satisfy that constraint. +# +# Outputs +# ------- +# - regression_results.parquet: coefficient table with cluster, predictor, estimate, SE, z, p, q. +# - regression_diagnostics.json: model fit diagnostics (LL, deviance, pseudo-R2, AIC/BIC, etc.) +# - stage2_input_qc.json: data lineage + drop counts + inferred/used predictor lists +# - regression_data_blockgroups_wide.parquet: the modeled BG dataset (counts + predictors) +# - stage2_manifest.json: paths of all outputs +# - stage2_metadata.json: runtime + package versions (provenance) +# +# ============================================================================= + + +# ----------------------------- +# CLI args (no external deps) +# ----------------------------- +# Notes: +# - Avoids argparse-style dependencies in R, keeping the script self-contained. +# - Supports both "--flag value" and "--flag=value". 
+print_help_and_exit <- function(exit_code = 0) { + cat( + paste0( + "\nStage 2 Multinomial Logit (Block Group)\n\n", + "Usage:\n", + " Rscript stage2_multinom_blockgroup_weighted.R [options]\n\n", + "Required:\n", + " --clusters PATH Cluster assignments parquet (ZIP+4 household-day rows)\n", + " --crosswalk PATH ZIP+4 -> Block Group crosswalk (tab-delimited txt)\n", + " --census PATH Census predictors parquet (Block Group level)\n", + " --out-dir PATH Output directory\n\n", + "Optional:\n", + " --baseline-cluster K Baseline cluster label (default: choose most frequent)\n", + " --min-obs-per-bg N Drop BGs with total household-days < N (default: 50)\n", + " --allow-missing-predictors 0|1 If 0, abort if predictor NA would drop any BGs (default: 0)\n", + " --standardize 0|1 Z-score standardize predictors (default: 0)\n", + " --use-vgam 0|1 Use VGAM::vglm() (IRLS) instead of nnet::multinom (default: 1)\n", + " --verbose 0|1 Verbose logging (default: 1)\n", + " --no-emoji 0|1 Disable unicode icons (default: 0)\n", + " --help Print this help and exit\n\n", + "Notes:\n", + " - Predictors are inferred from the census parquet columns.\n", + " - Model uses COUNT response: cbind(count_clusterA, count_clusterB, ...)\n", + " - Zeros are handled naturally by the multinomial likelihood; no smoothing is applied.\n", + " - Standardization (--standardize=1) is STRONGLY RECOMMENDED for numerical stability.\n", + " - VGAM requires a full-rank design matrix; this script drops rank-deficient terms deterministically.\n", + " - Outputs written under out-dir:\n", + " regression_results.parquet\n", + " regression_diagnostics.json\n", + " stage2_input_qc.json\n", + " regression_data_blockgroups_wide.parquet\n", + " stage2_manifest.json\n", + " stage2_metadata.json\n\n" + ) + ) + quit(status = exit_code) +} + +args <- commandArgs(trailingOnly = TRUE) +if (any(args %in% c("--help", "-h"))) print_help_and_exit(0) + +get_arg <- function(flag, default = NULL) { + hit <- grep(paste0("^", flag, "="), args) + if (length(hit) > 0) return(sub(paste0("^", flag, "="), "", args[hit[1]])) + hit2 <- which(args == flag) + if (length(hit2) > 0 && hit2[1] < length(args)) return(args[hit2[1] + 1]) + default +} + +parse_bool01 <- function(x, default = 0L) { + if (is.null(x) || is.na(x) || x == "") return(as.integer(default)) + s <- tolower(trimws(as.character(x))) + if (s %in% c("1", "true", "t", "yes", "y")) return(1L) + if (s %in% c("0", "false", "f", "no", "n")) return(0L) + as.integer(default) +} + +parse_int <- function(x, default = NA_integer_) { + if (is.null(x) || is.na(x) || x == "") return(default) + suppressWarnings(v <- as.integer(x)) + if (is.na(v)) default else v +} + +stopf <- function(fmt, ...) stop(sprintf(fmt, ...), call. 
= FALSE) + +# Required I/O paths +CLUSTERS_PATH <- get_arg("--clusters", default = NULL) +CROSSWALK_PATH <- get_arg("--crosswalk", default = NULL) +CENSUS_PATH <- get_arg("--census", default = NULL) +OUT_DIR <- get_arg("--out-dir", default = NULL) + +# Model/config knobs +BASELINE_CLUSTER_ARG <- get_arg("--baseline-cluster", default = NULL) +MIN_OBS_PER_BG <- parse_int(get_arg("--min-obs-per-bg", default = "50"), default = 50L) + +ALLOW_MISSING_PREDICTORS <- parse_bool01(get_arg("--allow-missing-predictors", default = "0"), default = 0L) +STANDARDIZE <- parse_bool01(get_arg("--standardize", default = "0"), default = 0L) +USE_VGAM <- parse_bool01(get_arg("--use-vgam", default = "1"), default = 1L) + +VERBOSE <- parse_bool01(get_arg("--verbose", default = "1"), default = 1L) +NO_EMOJI <- parse_bool01(get_arg("--no-emoji", default = "0"), default = 0L) + +if (is.null(CLUSTERS_PATH) || is.null(CROSSWALK_PATH) || is.null(CENSUS_PATH) || is.null(OUT_DIR)) { + cat("Missing required argument(s). Use --help.\n") + quit(status = 2) +} + +# ----------------------------- +# Deps +# ----------------------------- +# The approach uses: +# - arrow: efficient parquet I/O and dataset aggregation +# - dplyr/tibble: data manipulation +# - jsonlite: writing QC/diagnostics +# - nnet: fallback multinomial logit solver +# - VGAM: default multinomial logit solver for stability (full-rank requirement) +require_pkg <- function(pkg) requireNamespace(pkg, quietly = TRUE) + +if (!require_pkg("arrow")) stopf("Missing R package 'arrow'. Install with: install.packages('arrow')") +if (!require_pkg("jsonlite")) stopf("Missing R package 'jsonlite'. Install with: install.packages('jsonlite')") +if (!require_pkg("dplyr")) stopf("Missing R package 'dplyr'. Install with: install.packages('dplyr')") +if (!require_pkg("tibble")) stopf("Missing R package 'tibble'. Install with: install.packages('tibble')") +if (!require_pkg("nnet")) stopf("Missing R package 'nnet'. Install with: install.packages('nnet')") +if (USE_VGAM == 1L && !require_pkg("VGAM")) stopf("Missing R package 'VGAM'. Install with: install.packages('VGAM')") + +suppressPackageStartupMessages({ + library(arrow) + library(jsonlite) + library(dplyr) + library(tibble) + library(nnet) + if (USE_VGAM == 1L) library(VGAM) +}) + +dir.create(OUT_DIR, recursive = TRUE, showWarnings = FALSE) + +icon_for <- function() { + if (NO_EMOJI == 1L) return(list(ok = "OK", warn = "[WARN]", crit = "[CRIT]")) + list(ok = "\u2705", warn = "\U0001F7E1", crit = "\U0001F534") +} +IC <- icon_for() + +logi <- function(...) if (VERBOSE == 1L) cat(sprintf(...), "\n") + +safe_write_json <- function(obj, path) { + jsonlite::write_json(obj, path, pretty = TRUE, auto_unbox = TRUE) +} + +logi("%s Config: standardize=%d use_vgam=%d", IC$ok, STANDARDIZE, USE_VGAM) +t_total_start <- Sys.time() + + +# ----------------------------- +# Helpers: keys + inference +# ----------------------------- +# ZIP+4 normalization: +# - Cluster parquet uses "zip_code" which can appear as: +# - "60601-1234" or "606011234" or other string forms. +# - Crosswalk expects Zip + Zip4 columns; we standardize to "#####-####". +normalize_zip4 <- function(x) { + s <- as.character(x) + s <- trimws(s) + s[s == ""] <- NA_character_ + out <- s + is9 <- !is.na(s) & grepl("^[0-9]{9}$", s) + out[is9] <- paste0(substr(s[is9], 1, 5), "-", substr(s[is9], 6, 9)) + out +} + +# Ensure Zip4 is exactly 4 digits, leading zeros preserved. 
+zfill4 <- function(x) { + s <- as.character(x) + s <- trimws(s) + s[s == ""] <- NA_character_ + s <- gsub("[^0-9]", "", s) + s <- ifelse(is.na(s), NA_character_, sprintf("%04d", as.integer(s))) + s +} + +# Census GEOID column inference: +# - Upstream data may name the key GEOID, CensusKey2023, CensusKey2020, etc. +infer_geoid_col <- function(df) { + nms <- names(df) + low <- tolower(nms) + if ("geoid" %in% low) return(nms[which(low == "geoid")[1]]) + if ("censuskey2023" %in% low) return(nms[which(low == "censuskey2023")[1]]) + if ("censuskey2020" %in% low) return(nms[which(low == "censuskey2020")[1]]) + NULL +} + +# Predictor inference: +# - Uses numeric/integer/logical columns as candidate predictors. +# - Excludes id-like columns (GEOID/NAME + the inferred geoid key). +# - Drops columns that are entirely NA. +infer_predictors <- function(census_df) { + geoid_col <- infer_geoid_col(census_df) + if (is.null(geoid_col)) { + stopf("Census predictors must include a GEOID-like column. Found: %s", paste(names(census_df), collapse = ", ")) + } + + id_like <- unique(c(geoid_col, "GEOID", "NAME")) + candidates <- setdiff(names(census_df), id_like) + + is_numish <- vapply( + census_df[candidates], + function(x) is.numeric(x) || is.integer(x) || is.logical(x), + logical(1) + ) + preds <- candidates[is_numish] + + all_na <- preds[vapply(census_df[preds], function(x) all(is.na(x)), logical(1))] + preds <- setdiff(preds, all_na) + + list(geoid_col = geoid_col, predictors = preds, dropped_all_na = all_na) +} + +# Rank-deficiency handling for VGAM: +# - VGAM::vglm(multinomial()) requires a full-rank model matrix. +# - We build a standard model.matrix with intercept and remove: +# (a) constant columns (excluding intercept) +# (b) columns beyond QR rank using pivot ordering +# +# Note: This is deterministic and reproducible, but may drop terms that are +# substantively meaningful if predictors are highly collinear. Treat as a +# pragmatic requirement for full-rank MLE and revisit for a more principled +# approach if needed. +drop_rank_deficient_terms <- function(model_df, predictors) { + ftmp <- stats::as.formula(paste0("~ ", paste(predictors, collapse = " + "))) + Xmm <- stats::model.matrix(ftmp, data = model_df) + + is_const <- apply(Xmm, 2, function(v) { + if (!is.numeric(v)) return(FALSE) + rng <- range(v, na.rm = TRUE) + is.finite(rng[1]) && is.finite(rng[2]) && abs(rng[2] - rng[1]) < 1e-12 + }) + is_const[colnames(Xmm) == "(Intercept)"] <- FALSE + const_cols <- colnames(Xmm)[is_const] + + if (length(const_cols) > 0) { + logi("%s Dropping %d constant design columns before rank check.", IC$warn, length(const_cols)) + keep <- setdiff(colnames(Xmm), const_cols) + Xmm <- Xmm[, keep, drop = FALSE] + } + + qrX <- qr(Xmm) + rk <- qrX$rank + full_cols <- colnames(Xmm) + dropped_predictors <- character(0) + + if (rk < ncol(Xmm)) { + piv <- qrX$pivot + keep_cols <- full_cols[piv[seq_len(rk)]] + drop_cols <- setdiff(full_cols, keep_cols) + + if (!"(Intercept)" %in% keep_cols) stopf("Internal error: intercept was dropped by rank procedure.") + + dropped_predictors <- intersect(drop_cols, predictors) + kept_predictors <- setdiff(predictors, dropped_predictors) + + logi( + "%s Rank-deficient design detected: rank=%d of %d columns. 
Dropping %d term(s).", + IC$warn, rk, ncol(Xmm), length(dropped_predictors) + ) + if (length(dropped_predictors) > 0) { + logi("%s Dropped predictors (rank-deficient): %s", IC$warn, paste(dropped_predictors, collapse = ", ")) + } + + return(list( + predictors = kept_predictors, + dropped_predictors = dropped_predictors, + rank = rk, + ncol_design = ncol(Xmm) + )) + } + + return(list( + predictors = predictors, + dropped_predictors = character(0), + rank = rk, + ncol_design = ncol(Xmm) + )) +} + + +# ----------------------------- +# Read + aggregate clusters (memory-safe) +# ----------------------------- +# Key point: +# - We do not read the full household-day dataset into memory. +# - Instead we open it as an Arrow dataset and aggregate "zip_code × cluster" counts. +if (!file.exists(CLUSTERS_PATH)) stopf("Clusters parquet not found: %s", CLUSTERS_PATH) +if (!file.exists(CROSSWALK_PATH)) stopf("Crosswalk file not found: %s", CROSSWALK_PATH) +if (!file.exists(CENSUS_PATH)) stopf("Census predictors parquet not found: %s", CENSUS_PATH) + +logi("%s Aggregating clusters from parquet (Arrow Dataset compute): %s", IC$ok, CLUSTERS_PATH) + +clusters_ds <- tryCatch( + arrow::open_dataset(sources = CLUSTERS_PATH, format = "parquet"), + error = function(e) NULL +) +if (is.null(clusters_ds)) stopf("Failed to open clusters parquet as Arrow Dataset: %s", CLUSTERS_PATH) + +zip_cluster_counts <- clusters_ds %>% + dplyr::select(zip_code, cluster) %>% + dplyr::filter(!is.na(zip_code), !is.na(cluster)) %>% + dplyr::group_by(zip_code, cluster) %>% + dplyr::summarise(n = dplyr::n(), .groups = "drop") %>% + dplyr::collect() + +if (!("zip_code" %in% names(zip_cluster_counts))) stopf("Clusters parquet must include column 'zip_code'") +if (!("cluster" %in% names(zip_cluster_counts))) stopf("Clusters parquet must include column 'cluster'") +if (!("n" %in% names(zip_cluster_counts))) stopf("Internal error: missing 'n' after aggregation") + +# Normalize and sanitize types +zip_cluster_counts <- zip_cluster_counts %>% + dplyr::mutate( + zip4 = normalize_zip4(zip_code), + cluster = suppressWarnings(as.integer(as.character(cluster))), + n = suppressWarnings(as.integer(n)) + ) %>% + dplyr::filter(!is.na(zip4), !is.na(cluster), !is.na(n), n > 0) + +if (nrow(zip_cluster_counts) == 0) stopf("No usable ZIP+4×cluster counts after basic filtering.") +household_day_rows_total <- sum(zip_cluster_counts$n, na.rm = TRUE) + + +# ----------------------------- +# Read crosswalk +# ----------------------------- +# Crosswalk requirements: +# - Must include Zip, Zip4, and CensusKey2023 (or equivalent). +# - We standardize to "zip4" = "#####-####" and "block_group_geoid" = 12-char BG GEOID. +# +# Design choice: +# - If multiple BGs map to the same ZIP+4, we deterministically pick the first after sorting. +# This enforces a single mapping and avoids row inflation, but it should be reviewed for whether +# a probabilistic or fractional allocation is preferable. 
+logi("%s Reading crosswalk: %s", IC$ok, CROSSWALK_PATH) + +cw_tbl <- tryCatch( + arrow::read_csv_arrow(CROSSWALK_PATH, delim = "\t"), + error = function(e) NULL +) + +cw <- NULL +if (!is.null(cw_tbl)) { + cw <- as.data.frame(cw_tbl) +} else { + cw <- tryCatch( + read.delim(CROSSWALK_PATH, header = TRUE, sep = "\t", stringsAsFactors = FALSE, check.names = FALSE), + error = function(e) NULL + ) +} +if (is.null(cw)) stopf("Failed to read crosswalk: %s", CROSSWALK_PATH) + +low_cw <- tolower(names(cw)) +zip_col <- if ("zip" %in% low_cw) names(cw)[which(low_cw == "zip")[1]] else NA_character_ +zip4_col <- if ("zip4" %in% low_cw) names(cw)[which(low_cw == "zip4")[1]] else NA_character_ +geoid_col_cw <- if ("censuskey2023" %in% low_cw) names(cw)[which(low_cw == "censuskey2023")[1]] else NA_character_ + +if (is.na(zip_col) || is.na(zip4_col)) stopf("Crosswalk must include columns Zip and Zip4. Found: %s", paste(names(cw), collapse = ", ")) +if (is.na(geoid_col_cw)) stopf("Crosswalk must include column CensusKey2023. Found: %s", paste(names(cw), collapse = ", ")) + +zip4_present <- unique(zip_cluster_counts$zip4) + +cw <- cw %>% + dplyr::transmute( + Zip = as.character(.data[[zip_col]]), + Zip4 = zfill4(.data[[zip4_col]]), + zip4 = ifelse(!is.na(Zip) & !is.na(Zip4), paste0(Zip, "-", Zip4), NA_character_), + block_group_geoid = as.character(.data[[geoid_col_cw]]) + ) %>% + dplyr::filter(!is.na(zip4), !is.na(block_group_geoid)) %>% + dplyr::filter(zip4 %in% zip4_present) %>% + dplyr::mutate(block_group_geoid = substr(block_group_geoid, 1, 12)) %>% + dplyr::filter(!is.na(block_group_geoid), nchar(block_group_geoid) == 12) + +if (nrow(cw) == 0) stopf("Crosswalk produced 0 usable rows after cleaning/filtering.") + +# Deterministic one-to-one ZIP+4 → BG mapping +cw <- cw %>% + dplyr::arrange(zip4, block_group_geoid) %>% + dplyr::group_by(zip4) %>% + dplyr::slice(1) %>% + dplyr::ungroup() + +dup_zip4 <- cw %>% dplyr::count(zip4) %>% dplyr::filter(n > 1) +if (nrow(dup_zip4) > 0) stopf("Crosswalk still has non-unique zip4 after deterministic resolution. Found %d duplicates.", nrow(dup_zip4)) + + +# ----------------------------- +# Join + aggregate to BG counts +# ----------------------------- +# After joining the crosswalk, we can compute: +# - BG×cluster counts +# - BG total household-days +# and optionally drop BGs with too few observations (min_obs_per_bg). +logi("%s Joining ZIP+4×cluster counts to crosswalk...", IC$ok) + +clusters2 <- zip_cluster_counts %>% + dplyr::select(zip4, cluster, n) %>% + dplyr::inner_join(cw, by = "zip4") + +household_day_rows_after_crosswalk <- sum(clusters2$n, na.rm = TRUE) +dropped_missing_crosswalk <- household_day_rows_total - household_day_rows_after_crosswalk + +if (nrow(clusters2) == 0) stopf("All cluster counts dropped after crosswalk join. 
Check zip4 normalization and crosswalk keying.") + +logi("%s Aggregating to block group counts...", IC$ok) + +bg_counts <- clusters2 %>% + dplyr::group_by(block_group_geoid, cluster) %>% + dplyr::summarize(n = sum(n, na.rm = TRUE), .groups = "drop") + +total_by_bg <- clusters2 %>% + dplyr::group_by(block_group_geoid) %>% + dplyr::summarize(total_household_days = sum(n, na.rm = TRUE), .groups = "drop") + +clusters_observed <- sort(unique(bg_counts$cluster)) +if (length(clusters_observed) < 2) stopf("Need at least 2 clusters observed after aggregation; found: %s", paste(clusters_observed, collapse = ",")) + +# Wide BG frame: GEOID + total + one column per cluster count +bg_wide <- total_by_bg %>% dplyr::rename(GEOID = block_group_geoid) + +for (k in clusters_observed) { + colname <- paste0("cluster_", k) + tmp <- bg_counts %>% + dplyr::filter(cluster == k) %>% + dplyr::transmute(GEOID = block_group_geoid, n = as.integer(n)) + + bg_wide <- bg_wide %>% + dplyr::left_join(tmp, by = "GEOID") %>% + dplyr::mutate(!!colname := ifelse(is.na(n), 0L, as.integer(n))) %>% + dplyr::select(-n) +} + +# Observation floor: avoids fragile inference for tiny BG totals. +bg_wide <- bg_wide %>% dplyr::filter(total_household_days >= as.integer(MIN_OBS_PER_BG)) +if (nrow(bg_wide) == 0) stopf("No block groups remain after --min-obs-per-bg filtering (N=%d).", MIN_OBS_PER_BG) + +# Zero count diagnostic: proportion of BGs with zero count in each cluster column. +zero_stats <- list() +for (k in clusters_observed) { + colname <- paste0("cluster_", k) + if (colname %in% names(bg_wide)) { + n_zeros <- sum(bg_wide[[colname]] == 0, na.rm = TRUE) + pct_zeros <- 100 * n_zeros / nrow(bg_wide) + zero_stats[[as.character(k)]] <- list(n_zeros = as.integer(n_zeros), pct_zeros = as.numeric(pct_zeros)) + } +} + + +# ----------------------------- +# Read census predictors + infer predictors +# ----------------------------- +# This step establishes the modeling covariate set at the BG level. +# Predictors are inferred (numeric/integer/logical columns) rather than hard-coded. +logi("%s Reading census predictors: %s", IC$ok, CENSUS_PATH) + +census <- arrow::read_parquet(CENSUS_PATH, as_data_frame = TRUE) +inf <- infer_predictors(census) +CENSUS_GEOID_COL <- inf$geoid_col +PREDICTORS <- inf$predictors +DROPPED_ALL_NA <- inf$dropped_all_na + +if (length(PREDICTORS) == 0) stopf("No usable numeric predictors inferred from census parquet.") + +# Deduplicate by GEOID defensively; keep only inferred predictors. +census <- census %>% + dplyr::mutate(GEOID = as.character(.data[[CENSUS_GEOID_COL]])) %>% + dplyr::select(GEOID, dplyr::any_of(PREDICTORS)) %>% + dplyr::distinct(GEOID, .keep_all = TRUE) + +logi("%s Joining census predictors to BG counts...", IC$ok) +bg_model <- bg_wide %>% dplyr::inner_join(census, by = "GEOID") + +# Missing predictor handling: +# - If allow_missing_predictors=0: fail fast if any BG would be dropped by NA. +# - If allow_missing_predictors=1: drop incomplete BGs (complete-case analysis). +pred_mat <- bg_model %>% dplyr::select(dplyr::any_of(PREDICTORS)) +any_na <- apply(is.na(pred_mat), 1, any) + +drop_missing_pred_bg <- sum(any_na) +drop_missing_pred_hhday <- if (drop_missing_pred_bg > 0) sum(bg_model$total_household_days[any_na], na.rm = TRUE) else 0 + +if (drop_missing_pred_bg > 0 && ALLOW_MISSING_PREDICTORS == 0L) { + stopf( + "Predictor missingness would drop %d block groups (%.0f household-days). Refuse to proceed. 
Set --allow-missing-predictors=1 to override.", + drop_missing_pred_bg, drop_missing_pred_hhday + ) +} +if (drop_missing_pred_bg > 0 && ALLOW_MISSING_PREDICTORS == 1L) { + bg_model <- bg_model[!any_na, , drop = FALSE] +} +if (nrow(bg_model) == 0) stopf("No block groups remain after predictor missingness filtering.") + + +# ----------------------------- +# Optional: standardize predictors (z-score) +# ----------------------------- +# Rationale: +# - Improves numerical conditioning for optimization / IRLS. +# - Makes coefficients more comparable (effect per 1 SD change). +# - Retains original scale parameters in scaling_info for provenance. +scaling_info <- NULL +zero_var_predictors <- character(0) + +if (STANDARDIZE == 1L) { + logi("%s Standardizing %d predictors (z-score: mean=0, sd=1)...", IC$ok, length(PREDICTORS)) + scaling_info <- list() + + for (col in PREDICTORS) { + if (!col %in% names(bg_model)) next + x <- bg_model[[col]] + if (is.logical(x)) x <- as.numeric(x) + + mu <- mean(x, na.rm = TRUE) + sigma <- stats::sd(x, na.rm = TRUE) + scaling_info[[col]] <- list(mean = as.numeric(mu), sd = as.numeric(sigma)) + + if (is.finite(sigma) && sigma > 1e-10) { + bg_model[[col]] <- (x - mu) / sigma + } else { + # Retain original scale if SD ~ 0; keep a list for QC outputs. + zero_var_predictors <- c(zero_var_predictors, col) + cat(sprintf("%s Predictor '%s' has ~zero variance; not standardized.\n", IC$warn, col)) + bg_model[[col]] <- x + } + } +} + + +# ----------------------------- +# Choose baseline + response matrix +# ----------------------------- +# The response is a matrix of counts per BG across clusters: +# - Columns are ordered so the baseline cluster is last, matching the refLevel config for VGAM. +# - Baseline selection: +# - If user specifies --baseline-cluster, use it. +# - Otherwise choose the most frequent cluster by total count. +resp_cols <- paste0("cluster_", clusters_observed) +resp_cols <- resp_cols[resp_cols %in% names(bg_model)] +if (length(resp_cols) < 2) stopf("Need >=2 cluster count columns; found: %s", paste(resp_cols, collapse = ",")) + +cluster_totals_for_baseline <- sapply(resp_cols, function(cn) sum(bg_model[[cn]], na.rm = TRUE)) +names(cluster_totals_for_baseline) <- resp_cols + +baseline_cluster <- NA_integer_ +if (!is.null(BASELINE_CLUSTER_ARG) && nchar(BASELINE_CLUSTER_ARG) > 0) { + baseline_cluster <- suppressWarnings(as.integer(BASELINE_CLUSTER_ARG)) +} else { + max_col <- names(cluster_totals_for_baseline)[which.max(cluster_totals_for_baseline)] + baseline_cluster <- as.integer(sub("^cluster_", "", max_col)) +} + +if (!paste0("cluster_", baseline_cluster) %in% resp_cols) { + stopf("Baseline cluster %d not present among observed clusters: %s", baseline_cluster, paste(resp_cols, collapse = ", ")) +} + +resp_cols_ordered <- c(setdiff(resp_cols, paste0("cluster_", baseline_cluster)), paste0("cluster_", baseline_cluster)) +Y <- as.matrix(bg_model[, resp_cols_ordered, drop = FALSE]) + +# Basic response validation +if (any(is.na(Y))) stopf("NA values detected in response matrix.") +if (any(Y < 0, na.rm = TRUE)) stopf("Negative counts detected in response matrix.") +rs <- rowSums(Y) +if (any(rs <= 0, na.rm = TRUE)) { + bad_rows <- which(rs <= 0) + stopf("Block groups with non-positive total counts. 
Example row indices: %s", paste(head(bad_rows, 10), collapse = ", ")) +} + +# Model frame for fitting +model_df <- bg_model[, c("GEOID", "total_household_days", PREDICTORS), drop = FALSE] +model_df$Y <- I(Y) + +# Map “equations” to cluster labels: +# - Multinomial logit has one linear predictor per non-baseline outcome. +# - Here, eq index i corresponds to nonbase_clusters[i]. +nonbase_cols <- resp_cols_ordered[resp_cols_ordered != paste0("cluster_", baseline_cluster)] +nonbase_clusters <- as.integer(sub("^cluster_", "", nonbase_cols)) + + +# ----------------------------- +# Rank check / drop terms (VGAM full-rank requirement) +# ----------------------------- +dropped_predictors_rank <- character(0) +design_rank <- NA_integer_ +design_ncol <- NA_integer_ + +if (USE_VGAM == 1L) { + rk <- drop_rank_deficient_terms(model_df, PREDICTORS) + PREDICTORS <- rk$predictors + dropped_predictors_rank <- rk$dropped_predictors + design_rank <- as.integer(rk$rank) + design_ncol <- as.integer(rk$ncol_design) + + # Rebuild with kept predictors only + model_df <- bg_model[, c("GEOID", "total_household_days", PREDICTORS), drop = FALSE] + model_df$Y <- I(Y) +} + + +# ----------------------------- +# Fit model +# ----------------------------- +# Model form: +# Y ~ X1 + X2 + ... + Xp +# +# For VGAM::vglm: +# - family = multinomial(refLevel = K) where K is the last column (baseline). +# - IRLS typically handles scale better than nnet::multinom in large-count settings. +logi( + "%s Fitting multinomial logit (counts) with %d BGs, %d predictors, %d clusters (baseline=%d)...", + IC$ok, nrow(model_df), length(PREDICTORS), length(resp_cols_ordered), baseline_cluster +) + +rhs <- paste(PREDICTORS, collapse = " + ") +form <- stats::as.formula(paste0("Y ~ ", rhs)) + +fit_start <- Sys.time() +fit_warnings <- character(0) +fit0_warnings <- character(0) + +if (USE_VGAM == 1L) { + logi("%s Using VGAM::vglm() (IRLS) for full-rank MLE...", IC$ok) + + fit <- tryCatch( + withCallingHandlers( + VGAM::vglm(form, family = VGAM::multinomial(refLevel = length(resp_cols_ordered)), data = model_df), + warning = function(w) { fit_warnings <<- c(fit_warnings, w$message); invokeRestart("muffleWarning") } + ), + error = function(e) stopf("Model fit failed (VGAM::vglm): %s", e$message) + ) + + # Null model (intercept-only): provides baseline for pseudo-R2 and deviance comparisons + fit0 <- tryCatch( + withCallingHandlers( + VGAM::vglm(Y ~ 1, family = VGAM::multinomial(refLevel = length(resp_cols_ordered)), data = model_df), + warning = function(w) { fit0_warnings <<- c(fit0_warnings, w$message); invokeRestart("muffleWarning") } + ), + error = function(e) stopf("Null model fit failed (VGAM::vglm): %s", e$message) + ) +} else { + fit <- tryCatch( + withCallingHandlers( + nnet::multinom(form, data = model_df, trace = FALSE, maxit = 500), + warning = function(w) { fit_warnings <<- c(fit_warnings, w$message); invokeRestart("muffleWarning") } + ), + error = function(e) stopf("Model fit failed (nnet::multinom): %s", e$message) + ) + + fit0 <- tryCatch( + withCallingHandlers( + nnet::multinom(Y ~ 1, data = model_df, trace = FALSE, maxit = 500), + warning = function(w) { fit0_warnings <<- c(fit0_warnings, w$message); invokeRestart("muffleWarning") } + ), + error = function(e) stopf("Null model fit failed (nnet::multinom): %s", e$message) + ) +} + +fit_end <- Sys.time() +fit_duration <- as.numeric(difftime(fit_end, fit_start, units = "secs")) +logi("%s Model fit completed in %.1f seconds", IC$ok, fit_duration) + + +# ----------------------------- +# 
Convergence heuristics +# ----------------------------- +# Approach: +# - Primary: check warnings that suggest iteration/step issues. +# - Secondary: compare deviance_full vs deviance_null; if too close, signal weak fit or instability. +# Note: VGAM is S4; we avoid `$iter` or similar direct slot assumptions here. +convergence_ok <- TRUE +convergence_message <- "Model converged successfully" + +all_warns <- c(fit_warnings, fit0_warnings) +if (length(all_warns) > 0) { + conv_warnings <- grep("converg|iteration|maxit|step", all_warns, ignore.case = TRUE, value = TRUE) + if (length(conv_warnings) > 0) { + convergence_ok <- FALSE + convergence_message <- paste0("CONVERGENCE WARNING: ", paste(unique(conv_warnings), collapse = "; ")) + cat(sprintf("\n%s %s\n", IC$warn, convergence_message)) + } +} + +iter_used <- NA_integer_ +if (USE_VGAM == 0L) { + if (!is.null(fit$iter)) iter_used <- suppressWarnings(as.integer(fit$iter)) + if (is.na(iter_used) && !is.null(fit$niter)) iter_used <- suppressWarnings(as.integer(fit$niter)) + if (!is.na(iter_used) && iter_used >= 500) { + convergence_ok <- FALSE + convergence_message <- paste0(convergence_message, "; Reached maxit=500 (iter_used=", iter_used, ")") + cat(sprintf("\n%s Reached maxit=500 (iter_used=%d). Treating as potential non-convergence.\n", IC$warn, iter_used)) + } +} + +deviance_full <- as.numeric(stats::deviance(fit)) +deviance_null <- as.numeric(stats::deviance(fit0)) +deviance_ratio <- deviance_full / deviance_null +if (is.finite(deviance_ratio) && deviance_ratio > 0.95) { + cat(sprintf( + "\n%s Model deviance (%.2f) is very close to null deviance (%.2f). Low explanatory power and/or convergence issues possible.\n", + IC$warn, deviance_full, deviance_null + )) + convergence_message <- paste0(convergence_message, "; High deviance_ratio=", sprintf("%.3f", deviance_ratio)) +} + + +# ----------------------------- +# Correlations (diagnostic) +# ----------------------------- +# This is a quick heuristic to identify predictors correlated with observed cluster proportions. +# It is not used for modeling; it is logged for reviewer sanity-checking and interpretation guidance. +if (all(c("cluster_0", "cluster_1", "cluster_3") %in% names(bg_model))) { + bg_tmp <- bg_model %>% + dplyr::mutate( + prop_cluster_0 = cluster_0 / total_household_days, + prop_cluster_1 = cluster_1 / total_household_days, + prop_cluster_3 = cluster_3 / total_household_days + ) + + cor_matrix <- tryCatch( + stats::cor( + bg_tmp[, PREDICTORS, drop = FALSE], + bg_tmp[, c("prop_cluster_0", "prop_cluster_1", "prop_cluster_3"), drop = FALSE], + use = "complete.obs" + ), + error = function(e) NULL + ) + + if (!is.null(cor_matrix)) { + strong_cors <- apply(abs(cor_matrix), 1, max) + cat("\nPredictors with |correlation| > 0.1 to any cluster proportion:\n") + print(sort(strong_cors[strong_cors > 0.1], decreasing = TRUE)) + } +} + + +# ----------------------------- +# Core stats +# ----------------------------- +# McFadden pseudo-R2: +# 1 - (LL_full / LL_null) +# Interpreted as improvement over intercept-only baseline on the log-likelihood scale. +# Not directly comparable to OLS R^2; treat as a relative fit measure. 
+ll_full <- as.numeric(stats::logLik(fit)) +ll_null <- as.numeric(stats::logLik(fit0)) +pseudo_r2 <- 1.0 - (ll_full / ll_null) + + +# ----------------------------- +# Coefficients table (nnet + VGAM) -> ALWAYS returns 'cluster' +# ----------------------------- +# This function normalizes output format across engines: +# - Each row: (cluster, predictor, coefficient, std_err, z_stat, p_value) +# - cluster indicates the non-baseline outcome corresponding to the equation. +# +# For VGAM: +# - We use vcov(fit) for SEs and align by exact coefficient names. +# - We robustly parse equation indices from coefficient naming conventions: +# A) "term:1" style +# B) "log(mu[,1]/mu[,K]):term" style +extract_coef_table <- function(fit_obj, nonbase_clusters) { + if (inherits(fit_obj, "multinom")) { + coefs <- summary(fit_obj)$coefficients + ses <- summary(fit_obj)$standard.errors + pred_names <- colnames(coefs) + + rows <- vector("list", length = nrow(coefs) * length(pred_names)) + idx <- 0L + for (i in seq_len(nrow(coefs))) { + for (j in seq_along(pred_names)) { + idx <- idx + 1L + b <- as.numeric(coefs[i, j]) + se <- as.numeric(ses[i, j]) + z <- b / se + p <- 2 * pnorm(-abs(z)) + rows[[idx]] <- list( + eq = as.integer(i), + cluster = as.integer(nonbase_clusters[i]), + predictor = ifelse(pred_names[j] == "(Intercept)", "Intercept", pred_names[j]), + coefficient = b, + std_err = se, + z_stat = z, + p_value = p + ) + } + } + return(tibble::as_tibble(do.call(rbind.data.frame, rows))) + } + + if (inherits(fit_obj, "vglm")) { + b <- stats::coef(fit_obj) + + V <- tryCatch(stats::vcov(fit_obj), error = function(e) NULL) + if (is.null(V)) stopf("Failed to compute vcov() for VGAM fit; cannot produce standard errors.") + if (is.null(colnames(V))) stopf("vcov(VGAM) returned no colnames; cannot align SEs to coefficients.") + + se_all <- sqrt(diag(V)) + names(se_all) <- colnames(V) + + se <- se_all[names(b)] + if (any(is.na(se))) { + missing <- names(b)[is.na(se)] + stopf( + "Internal error: missing SE for %d VGAM coefficient(s). Example: %s", + length(missing), + paste(head(missing, 5), collapse = ", ") + ) + } + + nm <- names(b) + + parse_vgam_eq_and_term <- function(nm_vec) { + eq <- rep(NA_integer_, length(nm_vec)) + term <- rep(NA_character_, length(nm_vec)) + + # A) "term:1" + m_a <- regexec("^(.*):([0-9]+)$", nm_vec) + r_a <- regmatches(nm_vec, m_a) + has_a <- lengths(r_a) == 3 + if (any(has_a)) { + term[has_a] <- vapply(r_a[has_a], function(x) x[[2]], character(1)) + eq[has_a] <- suppressWarnings(as.integer(vapply(r_a[has_a], function(x) x[[3]], character(1)))) + } + + # B) "log(mu[,k]/mu[,K]):term" + need_b <- !has_a + if (any(need_b)) { + nm_b <- nm_vec[need_b] + eq_b <- suppressWarnings(as.integer(sub("^.*mu\\[,([0-9]+)\\].*$", "\\1", nm_b))) + term_b <- sub("^.*\\):", "", nm_b) + eq[need_b] <- eq_b + term[need_b] <- term_b + } + + term <- ifelse(term == "(Intercept)", "Intercept", term) + list(eq = eq, term = term) + } + + parsed <- parse_vgam_eq_and_term(nm) + eq <- parsed$eq + term <- parsed$term + + if (any(is.na(eq))) { + stopf( + "Failed to parse VGAM equation indices from coefficient names. 
Example names: %s", + paste(head(nm, 10), collapse = " | ") + ) + } + if (max(eq, na.rm = TRUE) > length(nonbase_clusters)) { + stopf("Internal error: VGAM eq index exceeds nonbaseline cluster count.") + } + + cluster <- as.integer(nonbase_clusters[eq]) + z <- as.numeric(b) / as.numeric(se) + p <- 2 * pnorm(-abs(z)) + + return(tibble::tibble( + eq = as.integer(eq), + cluster = as.integer(cluster), + predictor = as.character(term), + coefficient = as.numeric(b), + std_err = as.numeric(se), + z_stat = as.numeric(z), + p_value = as.numeric(p) + )) + } + + stopf("Unknown fit object type for coefficient extraction.") +} + +res_tbl <- extract_coef_table(fit, nonbase_clusters) + +# Multiple testing control: +# - BH q-values computed within each cluster equation (i.e., within each non-baseline outcome), +# which is a sensible default for “per-equation” screening. +res_tbl <- res_tbl %>% + dplyr::group_by(cluster) %>% + dplyr::mutate(q_value = p.adjust(p_value, method = "BH")) %>% + dplyr::ungroup() %>% + dplyr::mutate( + r_squared = as.numeric(pseudo_r2), + nobs = as.integer(nrow(bg_model)), + baseline_cluster = as.integer(baseline_cluster) + ) %>% + dplyr::select( + cluster, predictor, coefficient, std_err, z_stat, p_value, q_value, r_squared, nobs, baseline_cluster + ) + + +# ----------------------------- +# QC + metadata outputs +# ----------------------------- +# This section writes: +# - input_qc: what was dropped and why; predictor inference outcomes; configuration +# - diag: fit diagnostics and cluster marginals +# - manifest: file path registry for downstream automation +household_day_rows_modeled <- sum(bg_model$total_household_days, na.rm = TRUE) + +input_qc <- list( + inputs = list( + clusters = CLUSTERS_PATH, + crosswalk = CROSSWALK_PATH, + census_predictors = CENSUS_PATH + ), + notes = list( + "Predictors are inferred from the census parquet columns (numeric/logical), excluding GEOID/NAME.", + "Counts are computed memory-safely using Arrow Dataset aggregation; the full household-day parquet is not read into RAM.", + "Zero counts are expected in BG×cluster composition; multinomial likelihood handles zeros naturally (no smoothing/alpha).", + "If VGAM is used, rank-deficient predictors are dropped to satisfy full-rank design requirement." 
+ ), + counts = list( + household_day_rows_total_after_basic_filter = as.integer(household_day_rows_total), + household_day_rows_dropped_missing_crosswalk = as.integer(dropped_missing_crosswalk), + household_day_rows_after_crosswalk = as.integer(household_day_rows_after_crosswalk), + blockgroups_after_min_obs = as.integer(nrow(bg_wide)), + blockgroups_dropped_missing_predictors = as.integer(drop_missing_pred_bg), + household_day_rows_dropped_missing_predictors = as.integer(drop_missing_pred_hhday), + blockgroups_final_complete_case = as.integer(nrow(bg_model)), + household_days_modeled = as.integer(household_day_rows_modeled) + ), + inferred_predictors = list( + geoid_column = CENSUS_GEOID_COL, + predictors_inferred = inf$predictors, + predictors_used = PREDICTORS, + predictors_dropped_all_na = DROPPED_ALL_NA, + predictors_dropped_rank_deficient = dropped_predictors_rank, + predictors_zero_variance_not_standardized = zero_var_predictors + ), + model = list( + clusters_observed = as.integer(clusters_observed), + baseline_cluster = as.integer(baseline_cluster), + min_obs_per_bg = as.integer(MIN_OBS_PER_BG), + allow_missing_predictors = as.integer(ALLOW_MISSING_PREDICTORS), + standardized = as.logical(STANDARDIZE == 1L), + scaling_parameters = if (STANDARDIZE == 1L) scaling_info else NULL, + use_vgam = as.logical(USE_VGAM == 1L), + design_rank = if (USE_VGAM == 1L) as.integer(design_rank) else NULL, + design_ncol_checked = if (USE_VGAM == 1L) as.integer(design_ncol) else NULL, + zero_count_statistics = zero_stats + ) +) + +n_bg <- nrow(bg_model) +n_hhday <- sum(bg_model$total_household_days, na.rm = TRUE) +n_params <- length(stats::coef(fit)) +df_residual <- n_bg - n_params +aic_full <- tryCatch(as.numeric(stats::AIC(fit)), error = function(e) NA_real_) +bic_full <- tryCatch(as.numeric(stats::BIC(fit)), error = function(e) NA_real_) + +cluster_totals <- sapply(resp_cols_ordered, function(cn) sum(bg_model[[cn]], na.rm = TRUE)) +cluster_props <- cluster_totals / sum(cluster_totals) + +diag <- list( + fit = list( + converged = convergence_ok, + convergence_message = convergence_message, + iter_used = if (is.na(iter_used)) NULL else as.integer(iter_used), + deviance_full = as.numeric(deviance_full), + deviance_null = as.numeric(deviance_null), + deviance_ratio = as.numeric(deviance_ratio), + logLik_full = ll_full, + logLik_null = ll_null, + pseudo_r2_mcfadden = pseudo_r2, + aic = aic_full, + bic = bic_full, + n_parameters = as.integer(n_params), + df_residual = as.integer(df_residual), + nobs_blockgroups = as.integer(n_bg), + nobs_household_days = as.integer(n_hhday), + clusters_observed = as.integer(clusters_observed), + baseline_cluster = as.integer(baseline_cluster) + ), + cluster_marginal_distributions = lapply(names(cluster_totals), function(cn) { + k <- as.integer(sub("^cluster_", "", cn)) + list(cluster = k, household_days = as.integer(cluster_totals[[cn]]), proportion = as.numeric(cluster_props[[cn]])) + }) +) + +manifest <- list( + timestamp_utc = format(Sys.time(), tz = "UTC", usetz = TRUE), + outputs = list( + regression_results = file.path(OUT_DIR, "regression_results.parquet"), + regression_diagnostics = file.path(OUT_DIR, "regression_diagnostics.json"), + stage2_input_qc = file.path(OUT_DIR, "stage2_input_qc.json"), + regression_data_blockgroups_wide = file.path(OUT_DIR, "regression_data_blockgroups_wide.parquet"), + stage2_manifest = file.path(OUT_DIR, "stage2_manifest.json"), + stage2_metadata = file.path(OUT_DIR, "stage2_metadata.json") + ) +) + +logi("%s Writing outputs to: %s", 
IC$ok, OUT_DIR) + +arrow::write_parquet(res_tbl, file.path(OUT_DIR, "regression_results.parquet")) +arrow::write_parquet(bg_model, file.path(OUT_DIR, "regression_data_blockgroups_wide.parquet")) +safe_write_json(diag, file.path(OUT_DIR, "regression_diagnostics.json")) +safe_write_json(input_qc, file.path(OUT_DIR, "stage2_input_qc.json")) +safe_write_json(manifest, file.path(OUT_DIR, "stage2_manifest.json")) + +t_total_end <- Sys.time() +runtime_total_seconds <- as.numeric(difftime(t_total_end, t_total_start, units = "secs")) + +metadata <- list( + timestamp_utc = manifest$timestamp_utc, + runtime_seconds_total = as.numeric(runtime_total_seconds), + runtime_seconds_fit = as.numeric(fit_duration), + package_versions = list( + R = as.character(getRversion()), + arrow = as.character(utils::packageVersion("arrow")), + dplyr = as.character(utils::packageVersion("dplyr")), + nnet = as.character(utils::packageVersion("nnet")), + jsonlite = as.character(utils::packageVersion("jsonlite")), + VGAM = if (USE_VGAM == 1L) as.character(utils::packageVersion("VGAM")) else NULL + ) +) +safe_write_json(metadata, file.path(OUT_DIR, "stage2_metadata.json")) + +cat("\n") +cat(sprintf("%s Stage 2 multinomial logit complete.\n", IC$ok)) +cat(sprintf(" Block Groups (modeled): %d\n", nrow(bg_model))) +cat(sprintf(" Household-days (modeled): %s\n", format(household_day_rows_modeled, big.mark = ","))) +cat(sprintf(" Predictors (used): %d\n", length(PREDICTORS))) +cat(sprintf(" Pseudo R^2 (McFadden): %.4f\n", pseudo_r2)) +cat(sprintf(" Converged: %s\n", ifelse(convergence_ok, "true", "false"))) +cat("\n") + +quit(status = 0) diff --git a/docs/index.html b/docs/index.html index e7a3cf4..84fecc5 100644 --- a/docs/index.html +++ b/docs/index.html @@ -83,38 +83,101 @@

TK Title

Data

Load profiles

-

Our ComEd and Ameren Illinois data consists of kWh load profiles in 30-minute increments for the complete set of households across the entire Chicago metro area. To preserve anonymity, TK.

-

As a result, we only get consistent household identifiers for a given calendar month. Since Chicago’s grid is summer-peaking, we choose load profiles for the month of TK. We denote load profiles as \(L_i\) (for each \(i\) household). Each \(L_i\) is a TK point time series, for every 30 minutes in the month.

+

Our ComEd Illinois data consists of kWh load profiles in 30-minute increments for the complete set of households across the entire ComEd service area. To preserve anonymity under rules set by the Illinois Commerce Commission, customer data can only be included in a data release from a utility company if it passes a screening process: a customer's individual data cannot be released if there are 15 or fewer customers in the given geographic area, or if any single customer represents more than 15% of that area's load. In our case, the geographic area of interest is the nine-digit ZIP+4 postal code.
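
As a rough illustration only (the screening is applied by the utility before release, and the column names here are hypothetical), a 15/15-style screen over interval data could be expressed as:

import polars as pl

# Hypothetical example data; this is illustrative, not the utility's implementation.
usage = pl.DataFrame({
    "zip_code": ["60601-1234"] * 3 + ["60601-5678"] * 2,
    "account_identifier": ["a", "b", "c", "d", "e"],
    "kwh": [10.0, 20.0, 30.0, 5.0, 95.0],
})

screen = (
    usage.group_by("zip_code", "account_identifier")
    .agg(pl.col("kwh").sum().alias("customer_kwh"))
    .group_by("zip_code")
    .agg(
        pl.len().alias("n_customers"),
        (pl.col("customer_kwh").max() / pl.col("customer_kwh").sum()).alias("max_customer_share"),
    )
    # Release only if more than 15 customers AND no customer exceeds 15% of the area's load.
    .with_columns(
        ((pl.col("n_customers") > 15) & (pl.col("max_customer_share") <= 0.15)).alias("passes_screen")
    )
)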

+

Note: Ameren data was also obtained for a hypothetical second project with CUB. It is not part of the scope of this round of work; CUB explicitly asked for the ComEd data first and separately.

Furthermore, even when individual customers’ usage data is provided, the identification number associated with each customer is not retained month to month. In other words, while the usage data may feature the same customers in September as they did in August, ComEd assigns each customer a new account ID for each month. As a result, we only get consistent household identifiers for a given calendar month.

+

Since demand on the Illinois grid peaked in July 2023, we use load profiles from that month for our initial test analysis. We denote a household's monthly 30-minute interval usage series as \(L_i\) (for household \(i\)). For July 2023, each \(L_i\) is a 1,488-point time series observed at 30-minute intervals (48 half-hour readings per day over 31 days).

+

For clustering, however, we work with household-day observations derived from these monthly series. Specifically, we partition each \(L_i\) into daily 48-point vectors \(L_{id}\), where \(d\in\{1,\ldots,31\}\) indexes days in July. Each household-day vector \(L_{id}\) is then normalized to a daily load-shape vector \(S_{id}\) prior to clustering. In our clustering implementation, this normalization is performed row-wise and can be specified as min–max scaling (primary specification), z-score scaling (robustness check), or no additional scaling.

+

These normalized household-day observations \(\{S_{id}\}\) are the inputs to k-means, and all subsequent aggregation and regression in Stage 2 is performed at the household-day level.
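
As a minimal sketch of this construction (NumPy, with placeholder data rather than actual meter readings): a household's 1,488-reading July series is reshaped into 31 daily 48-point vectors, each of which is min-max scaled row-wise into a load-shape vector, the primary specification above.

import numpy as np

# Placeholder series standing in for one household's July 2023 readings:
# 48 half-hourly kWh values per day for 31 days = 1,488 points.
monthly = np.random.default_rng(0).random(48 * 31)

daily = monthly.reshape(31, 48)                 # L_id: one row per day d
lo = daily.min(axis=1, keepdims=True)
hi = daily.max(axis=1, keepdims=True)
span = np.where(hi - lo > 0, hi - lo, 1.0)      # guard against flat days
shapes = (daily - lo) / span                    # S_id: min-max scaled load shapes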

+

Note: I am still unsure about the scale of the final analysis. Will we show a full year of data, a particular month or set of months, or 12 months as a series of sub-analyses?
Demographic information

-

The highest spatial resolution demographic data available for our analysis is the 5-year TK (2024?) US Census Bureau American Community Survey (2024 ACS) at the block group level. From this, we derive a number of block-group level demographic features:

+

The highest spatial resolution demographic data available for our analysis is the 5-year 2023 US Census Bureau American Community Survey (2023 ACS) and the decennial 2020 Census Supplemental Demographic and Housing Characteristics File (2020 DHS) at the block group level (DHS 2020).

+

From these sources, we derive 47 block-group-level demographic features across five categories (all sourced from the ACS unless otherwise noted):

    -
  • TK
+
  • Spatial (1 variable): urban_percent.

+
  • Economic (7 variables): median_household_income, unemployment_rate, pct_in_civilian_labor_force, pct_not_in_labor_force, pct_income_under_25k, pct_income_25k_to_75k, pct_income_75k_plus.

+
  • Housing (24 variables): pct_owner_occupied, pct_renter_occupied, pct_heat_utility_gas, pct_heat_electric, pct_housing_built_2000_plus, pct_housing_built_1980_1999, old_building_pct, pct_structure_single_family_detached, pct_structure_single_family_attached, pct_structure_multifamily_2_to_4, pct_structure_multifamily_5_to_19, pct_structure_multifamily_10_plus, pct_structure_multifamily_20_plus, pct_structure_mobile_home, pct_vacant_housing_units, pct_home_value_under_150k, pct_home_value_150k_to_299k, pct_home_value_300k_plus, pct_rent_burden_30_plus, pct_rent_burden_50_plus, pct_owner_cost_burden_30_plus_mortgage, pct_owner_cost_burden_50_plus_mortgage, pct_owner_overcrowded_2plus_per_room, pct_renter_overcrowded_2plus_per_room.

+
  • Household (3 variables): avg_household_size, avg_family_size, pct_single_parent_households.

+
  • Demographic (12 variables): median_age, pct_white_alone, pct_black_alone, pct_asian_alone, pct_two_or_more_races, pct_population_under_5, pct_population_5_to_17, pct_population_18_to_24, pct_population_25_to_44, pct_population_45_to_64, pct_population_65_plus, pct_female.

+

Note: Variable list to be cleaned up and made more readable.

The full variable list, including Census ACS table references and calculation methods, is specified in census_specs.py.

Geographical crosswalk

-

Load profiles are identified by ZIP+4. Our demographic information from the TK (2024?) ACS is available at the Census block group level. To join the two, we use a crosswalk from Melissa. TK details of crosswalk.

-

TK: ZIP+4 and Census block group spatial overlap, to show that it is a reasonable mapping.

+

Load profiles are identified by ZIP+4. Our demographic information from the 2023 ACS and 2020 DHS is available at the Census block group level. To join the two, we use a crosswalk from the commercial data firm Melissa, which matches every ZIP+4 postal code in Illinois to the Census block it was associated with in 2023. From there, we aggregate Census blocks to their block groups, allowing them to be linked to our demographic information. We are thus able to characterize block groups both demographically and by the usage data of their residents.

+

When a ZIP+4 maps to multiple block groups and crosswalk weights are unavailable, we enforce a deterministic 1-to-1 linkage by assigning each ZIP+4 to a single block group (selecting the smallest GEOID) to avoid double-counting household-day observations; this introduces potential geographic misclassification that we treat as a limitation.
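
A minimal polars sketch of this linkage, assuming a tab-delimited crosswalk whose (hypothetical) columns zip4 and census_block_geoid hold ZIP+4 codes and 15-digit block GEOIDs; the block-group GEOID is the first 12 digits, and ties are broken deterministically by taking the smallest GEOID:

```python
import polars as pl

# Column names and delimiter are assumptions about the Melissa crosswalk layout.
crosswalk = pl.scan_csv(
    "data/reference/2023_comed_zip4_census_crosswalk.txt",
    separator="\t",
    schema_overrides={"zip4": pl.Utf8, "census_block_geoid": pl.Utf8},
).collect()

zip4_to_block_group = (
    crosswalk
    # Block-group GEOID = first 12 characters of the 15-digit block GEOID.
    .with_columns(pl.col("census_block_geoid").str.slice(0, 12).alias("block_group_geoid"))
    # Deterministic 1-to-1 linkage: smallest block-group GEOID per ZIP+4.
    .group_by("zip4")
    .agg(pl.col("block_group_geoid").min())
)
```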

Clusters

-

We cluster the load profiles into clusters via k-means. We use Euclidean distance on normalized load profiles, to focus on load shape rather than overall levels. We are aiming for a small number of clusters (4-10) to aid in interpretation. We chose \(k=\textrm{TK}\) via the gap statistic [@tibshiraniEstimatingNumberClusters2001], using the GapStatistics package(TK citation).

+

We cluster the load profiles via k-means, using Euclidean distance on normalized load profiles to focus on load shape rather than overall level. We aim for a small number of clusters (4–10) to aid interpretation. We selected \(k=4\) using a combination of quantitative diagnostics and interpretability considerations. Quantitatively, we compared candidate values of \(k\) using (i) the silhouette score, to assess cluster separation and cohesion, and (ii) the within-cluster sum of squares (WCSS) “elbow” curve, to identify diminishing returns in fit as \(k\) increases. Substantively, we also inspected the average normalized load shape of each cluster to confirm that the resulting patterns were distinct and interpretable and that no cluster was degenerate (e.g., extremely small or poorly separated).
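
A sketch of these diagnostics, assuming the normalized household-day profiles have been stacked into an (n_obs × 48) array X (scikit-learn API; the actual scripts in the repository may differ):

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score


def evaluate_k(X: np.ndarray, k_values=range(2, 11), random_state: int = 42) -> dict:
    """Fit k-means for each candidate k; record WCSS (inertia) and silhouette."""
    results = {}
    for k in k_values:
        km = KMeans(n_clusters=k, n_init=10, random_state=random_state).fit(X)
        results[k] = {
            "wcss": km.inertia_,  # input to the elbow curve
            # Subsample for tractability on large household-day samples.
            "silhouette": silhouette_score(X, km.labels_, sample_size=10_000, random_state=random_state),
        }
    return results
```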

TK details on each cluster, and brief description of each shape.

-

TK selection of cluster 1 as baseline.

+

We selected cluster 1 as the baseline (reference category) because it contains the largest share of household-day observations and exhibits the most stable and representative load shape among the clusters, making it a natural comparison point for the log-ratio regressions.
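
For reference, the baseline can be read directly off the Stage 1 assignments; a sketch, assuming the cluster_assignments.parquet output with a cluster column:

```python
import polars as pl

assignments = pl.read_parquet("cluster_assignments.parquet")

cluster_sizes = (
    assignments.group_by("cluster")
    .agg(pl.len().alias("n_household_days"))
    .sort("n_household_days", descending=True)
)
baseline_cluster = cluster_sizes["cluster"][0]  # largest cluster serves as the reference
```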

Demographic predictors of load profile mix

-

We begin by aggregating load profile clusters at the block group level. If \(j\) is the block group, and \(q\) is the cluster \((q\in\{1,\ldots,k\})\), then

-

\[C_{jq} = \sum_{i\in \textrm{bg(j)}}\textrm{I}(c_i = q)\]

-

is the count of households per cluster in each block group. We further normalize these to \(\pi_{jq} = \frac{C_{jq}}{\sum_qC_{jq}}\), the proportion of cluster assignments in each block group.

-

Our aim is to understand how block group level demographics predict the load cluster mix. To that end, we fit a series of multinomial logistic regression models, estimating the log probability of a cluster assignment relative to our baseline cluster 1,

+

We begin by aggregating load profile clusters at the block group level. If \(j\) is the block group, and \(q\) is the cluster (\(q\in\{1,\ldots,k\}\)), then

+

\[
C_{jq} = \sum_{i\in \mathrm{bg}(j)} \mathrm{I}(c_i = q)
\]

+

is the count of household-day observations assigned to cluster \(q\) in block group \(j\), where \(c_i\) denotes the cluster assignment for household-day observation \(i\). Under this aggregation, a single household with multiple sampled days contributes multiple observations to the block group totals.

+

We further normalize these to \(\pi_{jq} = \frac{C_{jq}}{\sum_q C_{jq}}\), the proportion of cluster assignments in each block group.
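
A polars sketch of this aggregation, using toy household-day observations (in the pipeline, each row would already carry its block_group_geoid and cluster assignment):

```python
import polars as pl

# Toy household-day observations for illustration only.
obs = pl.DataFrame({
    "block_group_geoid": ["170310001001"] * 3 + ["170310001002"] * 2,
    "cluster": [1, 1, 2, 1, 3],
})

# C_jq: household-day observations per (block group, cluster)
counts = obs.group_by(["block_group_geoid", "cluster"]).agg(pl.len().alias("C_jq"))

# pi_jq: each cluster's share of its block group's observations
proportions = counts.with_columns(
    (pl.col("C_jq") / pl.col("C_jq").sum().over("block_group_geoid")).alias("pi_jq")
)
```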

+

Our aim is to understand how block group level demographics predict the load cluster mix. Because \((\pi_{j1},\ldots,\pi_{jk})\) are compositional proportions that must sum to 1, we model log-ratios of cluster proportions. To prevent numerical instability from zero proportions, we apply Laplace smoothing (\(\alpha = 0.5\) pseudocount) before computing log-ratios. Specifically, we define smoothed proportions

+

\[
\tilde{\pi}_{jq} = \frac{C_{jq} + \alpha}{\sum_{q'=1}^{k} C_{jq'} + k\alpha},
\qquad \alpha = 0.5,
\]

+

which ensures \(\tilde{\pi}_{jq} > 0\) for all \(q\) and prevents \(\log(0) = -\infty\).
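
Continuing the sketch above, and assuming counts contains a row for every (block group, cluster) pair (zero counts filled in) with clusters labeled 1..k and cluster 1 as baseline, the smoothed proportions and log-ratios can be computed as:

```python
import polars as pl

ALPHA = 0.5
k = counts["cluster"].n_unique()

smoothed = counts.with_columns(
    ((pl.col("C_jq") + ALPHA) / (pl.col("C_jq").sum().over("block_group_geoid") + k * ALPHA))
    .alias("pi_tilde")
)

# Baseline cluster's smoothed share per block group.
baseline_share = smoothed.filter(pl.col("cluster") == 1).select(
    "block_group_geoid", pl.col("pi_tilde").alias("pi_tilde_baseline")
)

# Log-ratio of each non-baseline cluster's share vs. the baseline share.
log_ratios = (
    smoothed.filter(pl.col("cluster") != 1)
    .join(baseline_share, on="block_group_geoid")
    .with_columns((pl.col("pi_tilde") / pl.col("pi_tilde_baseline")).log().alias("log_ratio"))
)
```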

+

We fit one model per cluster outcome (except the baseline), modeling the log-ratio of each cluster’s share relative to baseline cluster 1. Specifically, for each \(q\in\{2,\ldots,k\}\), we fit:

\[
\log\left(\frac{\tilde{\pi}_{jq}}{\tilde{\pi}_{j1}}\right)
= \beta_{0q} + \beta_{1q} X_{1j} + \beta_{2q} X_{2j} + \cdots + \varepsilon_{jq}.
\]

-

\[
\log\left(\frac{\pi_{jq}}{\pi_{jK}}\right) = \beta_{0q} + \beta_{1q} \cdot \textrm{median\_income}_q + \beta_{2q}\,\textrm{pct\_hispanic}_q + \ldots
\]

for each \(q\in\{1,\ldots,k\}\). We use the MNLogit function in the Python package statsmodels to fit these models.

+

We estimate these as separate weighted least squares (WLS) regressions, weighting each block group \(j\) by \(\mathrm{total\_obs}_j\), the total number of household-day observations available for that block group. This weighting gives greater influence to block groups with more observed data. As a robustness check, we also estimate unweighted OLS versions of the same log-ratio regressions; large differences between WLS and OLS estimates would indicate results are disproportionately driven by high-observation block groups.
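
A statsmodels sketch for a single non-baseline cluster \(q\); the block-group frame, column names, and predictors below are toy placeholders for the real Stage 2 inputs:

```python
import numpy as np
import pandas as pd
import statsmodels.api as sm

rng = np.random.default_rng(0)

# Toy one-row-per-block-group frame for illustration only.
bg = pd.DataFrame({
    "log_ratio_q": rng.normal(size=50),
    "median_household_income": rng.normal(70_000, 15_000, size=50),
    "pct_owner_occupied": rng.uniform(0, 100, size=50),
    "total_obs": rng.integers(20, 500, size=50),
})
predictor_cols = ["median_household_income", "pct_owner_occupied"]

X = sm.add_constant(bg[predictor_cols])  # demographic predictors X_1j, X_2j, ...
y = bg["log_ratio_q"]                    # log(pi_tilde_jq / pi_tilde_j1)

wls_fit = sm.WLS(y, X, weights=bg["total_obs"]).fit()  # weight by household-day count
ols_fit = sm.OLS(y, X).fit()                           # unweighted robustness check
print(wls_fit.params, ols_fit.params, sep="\n")
```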

+

Coefficients are interpreted on the log-ratio scale: \(\beta_{pq}\) is the expected change in \(\log(\tilde{\pi}_{jq}/\tilde{\pi}_{j1})\) for a one-unit increase in predictor \(X_{pj}\), holding other predictors constant. Equivalently, \(\exp(\beta_{pq})\) is the multiplicative effect on the proportion ratio \(\tilde{\pi}_{jq}/\tilde{\pi}_{j1}\) for a one-unit increase in \(X_{pj}\).
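
For example, a coefficient of \(\beta_{pq} = 0.10\) implies \(\exp(0.10) \approx 1.11\): a one-unit increase in \(X_{pj}\) is associated with roughly an 11% higher ratio of cluster \(q\)’s smoothed share to the baseline’s, all else equal.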

diff --git a/pyproject.toml b/pyproject.toml index 902af39..c7297bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,8 +57,9 @@ DEP002 = [ "pyarrow", "memory-profiler", "snakeviz", + "tslearn", ] -DEP003 = ["botocore", "analysis"] +DEP003 = ["botocore", "analysis", "smart_meter_analysis", "pandas", "scipy"] DEP004 = ["botocore"] [dependency-groups] @@ -100,6 +101,14 @@ show_error_codes = true module = "boto3.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["statsmodels.*", "sklearn.*"] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = ["pandas.*", "scipy.*"] +ignore_missing_imports = true + [tool.pytest.ini_options] testpaths = ["tests"] markers = [ @@ -125,13 +134,18 @@ ignore = ["E501", "E731"] extend-ignore = ["TRY003", "TRY300", "TRY400"] [tool.ruff.lint.per-file-ignores] -"scripts/run_comed_pipeline.py" = ["C901", "S603"] +"scripts/run_comed_pipeline.py" = ["C901", "S603", "S607"] "scripts/diagnostics/*.py" = ["C901"] "smart_meter_analysis/pipeline_validator.py" = ["C901", "PGH003"] +"smart_meter_analysis/run_manifest.py" = ["S603", "S607"] +"smart_meter_analysis/aws_loader.py" = ["C901", "TRY301"] +"smart_meter_analysis/census.py" = ["C901"] "tests/test_aws_transform.py" = ["E402"] "tests/test_census.py" = ["E402"] "scripts/data_collection/*" = ["C901"] "analysis/clustering/clustering_validation.py" = ["C901", "F841", "RUF015"] +"analysis/clustering/stage2_logratio_regression.py" = ["C901"] +"analysis/clustering/stage2_multinomial.py" = ["C901"] "tests/validate_total_comed_pipeline.py" = ["C901", "S603", "RUF001"] "scripts/testing/generate_sample_data.py" = ["UP035", "UP006", "UP007", "S311"] "tests/*" = ["S101", "RUF001"] diff --git a/scripts/process_csvs_batched_optimized.py b/scripts/process_csvs_batched_optimized.py index 6e589f7..1d89043 100644 --- a/scripts/process_csvs_batched_optimized.py +++ b/scripts/process_csvs_batched_optimized.py @@ -1,120 +1,283 @@ #!/usr/bin/env python3 +"""Memory-optimized CSV ingestion for large file counts. + +Responsibilities (and ONLY these): +- Read local ComEd CSV files from an input directory +- Convert wide-format interval columns to long format +- Add time columns (date/hour/weekday/is_weekend) eagerly +- Write a canonical, analysis-ready interval parquet used downstream +- Write a JSONL processing manifest with per-file outcomes """ -Memory-optimized CSV processing for large file counts. -Processes CSV files in batches and sub-batches to avoid OOM / huge lazy plans. 
- -Usage: - python process_csvs_batched_optimized.py \ - --input-dir data/validation_runs/202308_50k/samples \ - --output data/validation_runs/202308_50k/processed_combined.parquet \ - --batch-size 5000 \ - --sub-batch-size 250 -""" +from __future__ import annotations import argparse +import json import logging +import shutil +from collections.abc import Iterable +from datetime import datetime, timezone from pathlib import Path +from typing import Any import polars as pl -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +from smart_meter_analysis.transformation import COMED_INTERVAL_COLUMNS, add_time_columns, transform_wide_to_long_lf + logger = logging.getLogger(__name__) +# --- Canonical output columns required downstream (Stage 1 profile builder expects these) --- +REQUIRED_ENERGY_COLS = ["zip_code", "account_identifier", "datetime", "kwh"] +REQUIRED_TIME_COLS = ["date", "hour", "weekday", "is_weekend"] + +# --- ComEd input schema overrides (keep parsing strict; ignore_errors defaults False) --- +DType = Any + +COMED_SCHEMA_OVERRIDES: dict[str, DType] = { + "ZIP_CODE": pl.Utf8, + "DELIVERY_SERVICE_CLASS": pl.Utf8, + "DELIVERY_SERVICE_NAME": pl.Utf8, + "ACCOUNT_IDENTIFIER": pl.Utf8, + "INTERVAL_READING_DATE": pl.Utf8, + "INTERVAL_LENGTH": pl.Utf8, + "TOTAL_REGISTERED_ENERGY": pl.Float64, + "PLC_VALUE": pl.Utf8, + "NSPL_VALUE": pl.Utf8, +} + +_INTERVAL_SCHEMA: dict[str, DType] = dict.fromkeys(COMED_INTERVAL_COLUMNS, pl.Float64) +COMED_SCHEMA: dict[str, DType] = {**COMED_SCHEMA_OVERRIDES, **_INTERVAL_SCHEMA} + + +def _utc_now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat() + + +def _chunked(seq: list[Path], size: int) -> Iterable[list[Path]]: + for i in range(0, len(seq), size): + yield seq[i : i + size] + + +def _write_manifest_line(fp: Any, record: dict[str, Any]) -> None: + fp.write(json.dumps(record, sort_keys=True) + "\n") + fp.flush() + + +def _ensure_required_cols(df: pl.DataFrame, *, context: str) -> None: + missing_energy = [c for c in REQUIRED_ENERGY_COLS if c not in df.columns] + missing_time = [c for c in REQUIRED_TIME_COLS if c not in df.columns] + if missing_energy or missing_time: + raise ValueError( + f"{context}: missing required columns. missing_energy={missing_energy} missing_time={missing_time}", + ) + + +def _canonicalize_df(df: pl.DataFrame, canonical_cols: list[str]) -> pl.DataFrame: + """Enforce deterministic column set + order across files/sub-batches. + If a column is missing, create it as null. + """ + exprs: list[pl.Expr] = [] + existing = set(df.columns) + for c in canonical_cols: + if c in existing: + exprs.append(pl.col(c)) + else: + exprs.append(pl.lit(None).alias(c)) + return df.select(exprs) + + +def _process_one_csv_to_df( + csv_path: Path, + *, + ignore_errors: bool, + day_mode: str, +) -> pl.DataFrame: + lf = pl.scan_csv( + str(csv_path), + schema_overrides=COMED_SCHEMA, + ignore_errors=bool(ignore_errors), + ) + + lf_long = transform_wide_to_long_lf(lf) + df_long = lf_long.collect(engine="streaming") + df_long = add_time_columns(df_long, day_mode=day_mode) + + _ensure_required_cols(df_long, context=f"{csv_path.name}") + + # Keep only what we promise downstream (canonical interval parquet) + out = df_long.select(REQUIRED_ENERGY_COLS + REQUIRED_TIME_COLS) + return out + def process_csv_subbatch_to_parquet( csv_files: list[Path], + *, batch_num: int, sub_num: int, temp_dir: Path, -) -> Path: - """ - Process a sub-batch of CSV files and write to a temporary parquet file. 
- """ - from smart_meter_analysis.aws_loader import ( - COMED_SCHEMA, - add_time_columns_lazy, - transform_wide_to_long_lazy, - ) - + canonical_cols: list[str] | None, + processing_manifest_fp: Any, + ignore_errors: bool, + max_errors: int, + day_mode: str, + log_every: int, + errors_so_far: int, +) -> tuple[Path, list[str], int]: logger.info(" Sub-batch %d.%d: %d files", batch_num, sub_num, len(csv_files)) - lazy_frames: list[pl.LazyFrame] = [] + dfs: list[pl.DataFrame] = [] + sub_errors = 0 + for i, csv_path in enumerate(csv_files, 1): - if i % 200 == 0: - logger.info(" Scanned %d/%d files in sub-batch %d.%d", i, len(csv_files), batch_num, sub_num) + if log_every > 0 and (i == 1 or i % log_every == 0 or i == len(csv_files)): + logger.info(" Processing %d/%d in sub-batch %d.%d", i, len(csv_files), batch_num, sub_num) + + ts = _utc_now_iso() try: - lf = pl.scan_csv( - str(csv_path), - schema_overrides=COMED_SCHEMA, - ignore_errors=True, + df = _process_one_csv_to_df( + csv_path, + ignore_errors=ignore_errors, + day_mode=day_mode, ) - lf = transform_wide_to_long_lazy(lf) - # IMPORTANT: updated signature (no day_mode) - lf = add_time_columns_lazy(lf) + if canonical_cols is None: + canonical_cols = df.columns + + df = _canonicalize_df(df, canonical_cols) + dfs.append(df) + + _write_manifest_line( + processing_manifest_fp, + {"file": csv_path.name, "status": "success", "rows": int(df.height), "timestamp": ts}, + ) - lazy_frames.append(lf) except Exception as exc: - logger.warning("Failed to scan %s: %s", csv_path.name, exc) + sub_errors += 1 + total_errors = errors_so_far + sub_errors + + _write_manifest_line( + processing_manifest_fp, + {"file": csv_path.name, "status": "error", "error": f"{type(exc).__name__}: {exc}", "timestamp": ts}, + ) - if not lazy_frames: - raise ValueError(f"No files successfully scanned in sub-batch {batch_num}.{sub_num}") + msg = f"Failed to process {csv_path.name}: {type(exc).__name__}: {exc}" + + # Fail-fast unless user explicitly opted into ignore-errors mode. + if not ignore_errors: + raise RuntimeError(msg) from exc + + logger.warning("%s", msg) + + if total_errors > max_errors: + raise RuntimeError( + f"Exceeded --max-errors={max_errors}. " + f"batch={batch_num} sub_batch={sub_num} total_errors={total_errors}. " + f"Last error: {type(exc).__name__}: {exc}", + ) from exc + + if not dfs: + raise RuntimeError(f"No files successfully processed in sub-batch {batch_num}.{sub_num}") sub_output = temp_dir / f"batch_{batch_num:04d}_sub_{sub_num:04d}.parquet" - # Combine this sub-batch and write immediately - pl.concat(lazy_frames, how="diagonal_relaxed").sink_parquet(sub_output) + # In-memory concat, then write. No diagonal_relaxed: schemas are already canonicalized. + pl.concat(dfs, how="vertical").write_parquet(sub_output) logger.info(" Sub-batch %d.%d complete: %s", batch_num, sub_num, sub_output) - return sub_output + return sub_output, (canonical_cols or []), sub_errors def process_csv_batch_to_parquet( csv_files: list[Path], + *, batch_num: int, temp_dir: Path, sub_batch_size: int, -) -> Path: - """ - Process a batch of CSV files by splitting into sub-batches and writing a single - batch parquet composed from the sub-batch parquets. 
- """ + canonical_cols: list[str] | None, + processing_manifest_fp: Any, + ignore_errors: bool, + max_errors: int, + day_mode: str, + log_every: int, + errors_so_far: int, +) -> tuple[Path, list[str], int]: logger.info("Processing batch %d: %d files", batch_num, len(csv_files)) sub_files: list[Path] = [] - for sub_num, i in enumerate(range(0, len(csv_files), sub_batch_size), 1): - sub = csv_files[i : i + sub_batch_size] - sub_file = process_csv_subbatch_to_parquet(sub, batch_num, sub_num, temp_dir) + batch_errors = 0 + + for sub_num, sub in enumerate(_chunked(csv_files, sub_batch_size), 1): + sub_file, canonical_cols, sub_errors = process_csv_subbatch_to_parquet( + csv_files=sub, + batch_num=batch_num, + sub_num=sub_num, + temp_dir=temp_dir, + canonical_cols=canonical_cols, + processing_manifest_fp=processing_manifest_fp, + ignore_errors=ignore_errors, + max_errors=max_errors, + day_mode=day_mode, + log_every=log_every, + errors_so_far=errors_so_far + batch_errors, + ) sub_files.append(sub_file) + batch_errors += sub_errors batch_output = temp_dir / f"batch_{batch_num:04d}.parquet" logger.info(" Concatenating %d sub-batches into %s", len(sub_files), batch_output) - pl.concat([pl.scan_parquet(str(f)) for f in sub_files], how="diagonal_relaxed").sink_parquet(batch_output) + pl.concat([pl.scan_parquet(str(f)) for f in sub_files], how="vertical").sink_parquet(batch_output) - # Clean up sub-batch files for f in sub_files: f.unlink(missing_ok=True) logger.info("Batch %d complete: %s", batch_num, batch_output) - return batch_output + return batch_output, (canonical_cols or []), batch_errors def main() -> int: - parser = argparse.ArgumentParser(description="Process CSV files in memory-safe batches") + parser = argparse.ArgumentParser(description="Process ComEd CSV files in memory-safe batches (local only).") + parser.add_argument("--input-dir", type=Path, required=True, help="Directory containing CSV files") - parser.add_argument("--output", type=Path, required=True, help="Output parquet file path") + parser.add_argument("--output", type=Path, required=True, help="Output parquet file path (interval-level long)") + parser.add_argument( + "--processing-manifest", + type=Path, + required=True, + help="Path to write processing_manifest.jsonl (required)", + ) + parser.add_argument("--batch-size", type=int, default=5000, help="Files per batch (default: 5000)") + parser.add_argument("--sub-batch-size", type=int, default=100, help="Files per sub-batch (default: 100)") + + parser.add_argument( + "--ignore-errors", + action="store_true", + help="Continue on malformed CSVs up to --max-errors (not recommended). Default: fail-fast.", + ) parser.add_argument( - "--sub-batch-size", + "--max-errors", type=int, - default=250, - help="Files per sub-batch within each batch (default: 250).", + default=10, + help="Maximum errors before aborting (only meaningful with --ignore-errors). 
Default: 10", ) + parser.add_argument( + "--day-mode", + type=str, + choices=["calendar", "billing"], + default="calendar", + help="Day attribution mode for time columns (default: calendar)", + ) + + parser.add_argument("--keep-temp", action="store_true", help="Keep temporary parquet batches for debugging") + parser.add_argument("--no-row-count", action="store_true", help="Skip final output row count scan") + parser.add_argument("--log-every", type=int, default=200, help="Log progress every N files (default: 200)") + args = parser.parse_args() + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + csv_files = sorted(args.input_dir.glob("*.csv")) logger.info("Found %d CSV files", len(csv_files)) @@ -122,34 +285,55 @@ def main() -> int: logger.error("No CSV files found in %s", args.input_dir) return 1 + args.output.parent.mkdir(parents=True, exist_ok=True) + args.processing_manifest.parent.mkdir(parents=True, exist_ok=True) + temp_dir = args.output.parent / "temp_batches" temp_dir.mkdir(parents=True, exist_ok=True) + canonical_cols: list[str] | None = None batch_files: list[Path] = [] - for batch_num, i in enumerate(range(0, len(csv_files), args.batch_size), 1): - batch = csv_files[i : i + args.batch_size] - batch_file = process_csv_batch_to_parquet( - csv_files=batch, - batch_num=batch_num, - temp_dir=temp_dir, - sub_batch_size=args.sub_batch_size, - ) - batch_files.append(batch_file) - - logger.info("Concatenating %d batch files into final output...", len(batch_files)) - args.output.parent.mkdir(parents=True, exist_ok=True) - - pl.concat([pl.scan_parquet(str(f)) for f in batch_files], how="diagonal_relaxed").sink_parquet(args.output) + total_errors = 0 - row_count = pl.scan_parquet(args.output).select(pl.len()).collect()[0, 0] - logger.info("Success! Wrote %s records to %s", f"{row_count:,}", args.output) - - logger.info("Cleaning up temporary batch files...") - for f in batch_files: - f.unlink(missing_ok=True) - temp_dir.rmdir() - - logger.info("Done!") + # Overwrite for determinism (byte-identical reruns given identical outcomes) + with args.processing_manifest.open("w", encoding="utf-8") as mf: + try: + for batch_num, batch in enumerate(_chunked(csv_files, args.batch_size), 1): + batch_file, canonical_cols, batch_errors = process_csv_batch_to_parquet( + csv_files=batch, + batch_num=batch_num, + temp_dir=temp_dir, + sub_batch_size=args.sub_batch_size, + canonical_cols=canonical_cols, + processing_manifest_fp=mf, + ignore_errors=bool(args.ignore_errors), + max_errors=int(args.max_errors), + day_mode=str(args.day_mode), + log_every=int(args.log_every), + errors_so_far=total_errors, + ) + batch_files.append(batch_file) + total_errors += batch_errors + + logger.info("Concatenating %d batch files into final output: %s", len(batch_files), args.output) + pl.concat([pl.scan_parquet(str(f)) for f in batch_files], how="vertical").sink_parquet(args.output) + + if not args.no_row_count: + row_count = pl.scan_parquet(args.output).select(pl.len()).collect(streaming=True)[0, 0] # type: ignore[call-overload] + logger.info("Success! Wrote %s records to %s", f"{row_count:,}", args.output) + else: + logger.info("Success! 
Wrote output to %s (row count skipped)", args.output) + + logger.info("File-level errors encountered: %d", total_errors) + + finally: + if not args.keep_temp: + for f in batch_files: + f.unlink(missing_ok=True) + shutil.rmtree(temp_dir, ignore_errors=True) + + logger.info("Wrote processing manifest: %s", args.processing_manifest) + logger.info("Done.") return 0 diff --git a/scripts/run_comed_pipeline.py b/scripts/run_comed_pipeline.py index f802b7e..9501ef7 100644 --- a/scripts/run_comed_pipeline.py +++ b/scripts/run_comed_pipeline.py @@ -1,1029 +1,482 @@ #!/usr/bin/env python3 -""" -ComEd Smart Meter Analysis Pipeline - -Main entry point for the ComEd smart meter clustering analysis. This script -handles the complete workflow from raw S3 data to clustered household-day -load profiles and (optionally) Stage 2 block-group regression. - -================================================================================ -PIPELINE OVERVIEW -================================================================================ - -Stage 1: Usage Pattern Clustering (this script) - 1. Download - Fetch CSV files from S3 - 2. Process - Transform wide→long, add time features, write parquet - 3. Prepare - Create daily load profiles per HOUSEHOLD-DAY - (one 48-point profile per (household, date)) - 4. Cluster - K-means (Euclidean) on household-day profiles - 5. Validate - Check data quality at each step - -Stage 2: Block-Group Demographic Regression (stage2_blockgroup_regression.py) - - Aggregate household-day observations to (block_group x cluster) counts - - Join with Census demographics at block-group level - - Run multinomial logistic regression (statsmodels for proper inference) - - Identify demographic predictors of the mix of clusters across household-days - - Unit of analysis: ONE ROW PER (block_group, cluster), with counts over - household-day observations. 
- -================================================================================ -USAGE MODES -================================================================================ - -Typical workflow from the project root (script lives in scripts/): - -FULL PIPELINE (download from S3, process, cluster, validate, run Stage 2): - python scripts/run_comed_pipeline.py \ - --from-s3 \ - --year-month 202308 \ - --num-files 10000 \ - --sample-households 20000 \ - --sample-days 31 \ - --k-range 3 6 \ - --run-stage2 - -QUICK LOCAL TEST (fewer files, fewer households/days): - python scripts/run_comed_pipeline.py \ - --from-s3 \ - --year-month 202308 \ - --num-files 100 \ - --sample-households 2000 \ - --sample-days 10 \ - --k-range 3 4 - -VALIDATE ONLY (check existing files for a given run): - python scripts/run_comed_pipeline.py \ - --validate-only \ - --run-name 202308_10000 - -SPECIFIC STAGE VALIDATION: - python scripts/run_comed_pipeline.py \ - --validate-only \ - --run-name 202308_10000 \ - --stage processed - - python scripts/run_comed_pipeline.py \ - --validate-only \ - --run-name 202308_10000 \ - --stage clustering - -================================================================================ -OUTPUT STRUCTURE -================================================================================ - -data/validation_runs/{run_name}/ -├── samples/ # Raw CSV files from S3 -├── processed/ -│ └── comed_{year_month}.parquet # Interval-level data (long format) -├── clustering/ -│ ├── sampled_profiles.parquet # Household-day profiles for clustering -│ ├── household_zip4_map.parquet # Map of households to ZIP+4s -│ └── results/ -│ ├── cluster_assignments.parquet -│ ├── cluster_centroids.parquet -│ ├── k_evaluation.json -│ ├── clustering_metadata.json -│ ├── elbow_curve.png -│ ├── cluster_centroids.png -│ ├── cluster_samples.png -│ └── stage2_blockgroups/ # (optional) Stage 2 outputs when run -│ ├── ... block-group cluster counts, regression results, etc. 
- +"""ComEd Smart Meter Analysis Pipeline Orchestrator (Phase 3) + +Directory structure (required): +data/runs/{run_name}/ +├── raw/ # Downloaded CSVs +├── processed/ # Canonical interval parquet +├── clustering/ # Stage 1 outputs + clustering outputs +├── stage2/ # Stage 2 outputs +├── logs/ # All logs +├── run_manifest.json # Pipeline metadata +├── download_manifest.jsonl # S3 download record +└── processing_manifest.jsonl # CSV processing record """ from __future__ import annotations import argparse +import json import logging +import os +import subprocess import sys +from datetime import datetime, timezone from pathlib import Path +from typing import Any import polars as pl -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", -) +from smart_meter_analysis.aws_loader import download_s3_batch, list_s3_files + logger = logging.getLogger(__name__) +DEFAULT_RUNS_DIR = Path("data/runs") +DEFAULT_CROSSWALK_PATH = Path("data/reference/2023_comed_zip4_census_crosswalk.txt") +DEFAULT_STATE_FIPS = "17" +DEFAULT_ACS_YEAR = 2023 -# ============================================================================= -# DEFAULT CONFIGURATION -# ============================================================================= - -DEFAULT_PATHS = { - "processed": Path("data/processed/comed_202308.parquet"), - "clustering_dir": Path("data/clustering"), - "crosswalk": Path("data/reference/2023_comed_zip4_census_crosswalk.txt"), -} - -DEFAULT_S3_CONFIG = { - "bucket": "smart-meter-data-sb", - "prefix": "sharepoint-files/Zip4/", -} - -DEFAULT_CLUSTERING_CONFIG = { - # Household/day sampling for clustering - "sample_days": 31, - "sample_households": 20_000, # None = all; 20k is the tested high-volume default - "day_strategy": "stratified", - # K-means hyperparameters - "k_min": 3, - "k_max": 6, - "n_init": 10, - # Profile construction (streaming is always on) - "chunk_size": 500, # households per chunk when building profiles -} - - -# ============================================================================= -# PIPELINE EXECUTOR CLASS -# ============================================================================= - - -class ComedPipeline: - """ - Orchestrates the ComEd smart meter analysis pipeline. - - This class manages the complete workflow from raw S3 data to clustered - household-day load profiles, with validation at each step and optional - Stage 2 regression. - - Attributes: - base_dir: Project root directory - run_name: Identifier for this pipeline run (e.g., "202308_10000") - run_dir: Output directory for this run - paths: Dictionary of file paths for this run - """ - - def __init__(self, base_dir: Path, run_name: str | None = None): - """ - Initialize pipeline with project directory and optional run name. - - Args: - base_dir: Root directory of the smart-meter-analysis project - run_name: Identifier for this run. If provided, outputs go to - data/validation_runs/{run_name}/. If None, uses default paths. 
- """ - self.base_dir = base_dir - self.run_name = run_name - self.results: dict[str, dict] = {} - - if run_name: - self.run_dir = base_dir / "data" / "validation_runs" / run_name - self.year_month = run_name.split("_")[0] if "_" in run_name else run_name - self.paths = { - "samples": self.run_dir / "samples", - "processed": self.run_dir / "processed" / f"comed_{self.year_month}.parquet", - "clustering_dir": self.run_dir / "clustering", - } - else: - self.run_dir = None - self.paths = DEFAULT_PATHS.copy() - - # ========================================================================= - # PIPELINE STEPS - # ========================================================================= - - def setup_directories(self) -> None: - """Create directory structure for pipeline outputs.""" - if not self.run_dir: - return - - for subdir in ["samples", "processed", "clustering/results"]: - (self.run_dir / subdir).mkdir(parents=True, exist_ok=True) - - logger.info("Created output directory: %s", self.run_dir) - - def download_from_s3( - self, - year_month: str, - num_files: int, - bucket: str = DEFAULT_S3_CONFIG["bucket"], - prefix: str = DEFAULT_S3_CONFIG["prefix"], - ) -> bool: - """ - Download CSV files from S3 for processing. - - Args: - year_month: Target month in YYYYMM format (e.g., "202308") - num_files: Number of CSV files to download - bucket: S3 bucket name - prefix: S3 key prefix for ComEd data - - Returns: - True if download successful, False otherwise. - """ - try: - import boto3 - except ImportError: - logger.error("boto3 not installed. Run: pip install boto3") - return False - - logger.info("Connecting to S3: s3://%s/%s%s/", bucket, prefix, year_month) - - try: - s3 = boto3.client("s3") - full_prefix = f"{prefix}{year_month}/" - - # List files - paginator = s3.get_paginator("list_objects_v2") - csv_keys: list[str] = [] - - for page in paginator.paginate(Bucket=bucket, Prefix=full_prefix): - if "Contents" not in page: - continue - for obj in page["Contents"]: - if obj["Key"].endswith(".csv"): - csv_keys.append(obj["Key"]) - if len(csv_keys) >= num_files: - break - if len(csv_keys) >= num_files: - break - - if not csv_keys: - logger.error("No CSV files found in s3://%s/%s", bucket, full_prefix) - return False - - logger.info("Downloading %d files to %s", len(csv_keys), self.paths["samples"]) - - for i, key in enumerate(csv_keys, 1): - filename = Path(key).name - local_path = self.paths["samples"] / filename - s3.download_file(bucket, key, str(local_path)) - - if i % 100 == 0 or i == len(csv_keys): - logger.info(" Downloaded %d/%d files", i, len(csv_keys)) - - logger.info("Download complete: %d files", len(csv_keys)) - return True - - except Exception as exc: - logger.error("S3 download failed: %s", exc) - return False - - def process_raw_data(self, year_month: str) -> bool: - """ - Process raw CSV files into analysis-ready parquet format. - - Transforms wide-format CSVs to long format with time features. - Uses lazy evaluation for memory efficiency. - - Args: - year_month: Month identifier for output file naming. - - Returns: - True if processing successful, False otherwise. 
- """ - csv_files = sorted(self.paths["samples"].glob("*.csv")) - if not csv_files: - logger.error("No CSV files found in %s", self.paths["samples"]) - return False - - logger.info("Processing %d CSV files", len(csv_files)) - - from smart_meter_analysis.aws_loader import ( - COMED_SCHEMA, - add_time_columns_lazy, - transform_wide_to_long_lazy, - ) - lazy_frames: list[pl.LazyFrame] = [] - for i, csv_path in enumerate(csv_files, 1): - if i % 200 == 0 or i == len(csv_files): - logger.info(" Scanned %d/%d files", i, len(csv_files)) - - try: - lf = pl.scan_csv(str(csv_path), schema_overrides=COMED_SCHEMA, ignore_errors=True) - lf = transform_wide_to_long_lazy(lf) - lf = add_time_columns_lazy(lf, day_mode="calendar") - lazy_frames.append(lf) - except Exception as exc: - logger.warning("Failed to scan %s: %s", csv_path.name, exc) - - if not lazy_frames: - logger.error("No files successfully scanned") - return False - - logger.info("Writing combined parquet file...") - self.paths["processed"].parent.mkdir(parents=True, exist_ok=True) - - lf_combined = pl.concat(lazy_frames, how="diagonal_relaxed") - lf_combined.sink_parquet(self.paths["processed"]) - - row_count = pl.scan_parquet(self.paths["processed"]).select(pl.len()).collect()[0, 0] - logger.info("Wrote %s records to %s", f"{row_count:,}", self.paths["processed"]) - - return True - - def prepare_clustering_data( - self, - sample_days: int = DEFAULT_CLUSTERING_CONFIG["sample_days"], - sample_households: int | None = DEFAULT_CLUSTERING_CONFIG.get("sample_households"), - day_strategy: str = DEFAULT_CLUSTERING_CONFIG["day_strategy"], - chunk_size: int = DEFAULT_CLUSTERING_CONFIG["chunk_size"], - ) -> bool: - """ - Prepare daily household-day load profiles for clustering. - - Creates 48-interval profiles for individual household-day combinations - using the manifest-based, chunked streaming pipeline. - - Args: - sample_days: Number of days to sample per run. - sample_households: Number of households to sample (None = all). - day_strategy: "stratified" (70/30 weekday/weekend) or "random". - chunk_size: Households per chunk for the streaming profile builder. - - Returns: - True if preparation successful, False otherwise. 
- """ - import subprocess - - input_path = self.paths["processed"] - output_dir = self.paths["clustering_dir"] - - if not input_path.exists(): - logger.error("Processed data not found: %s", input_path) - return False - - cmd = [ - sys.executable, - str(self.base_dir / "analysis" / "clustering" / "prepare_clustering_data_households.py"), - "--input", - str(input_path), - "--output-dir", - str(output_dir), - "--day-strategy", - day_strategy, - "--sample-days", - str(sample_days), - "--streaming", # streaming is always enabled - "--chunk-size", - str(chunk_size), - ] - - if sample_households: - cmd.extend(["--sample-households", str(sample_households)]) - - logger.info( - "Preparing household-day clustering data (%s households x %d days; streaming, chunk_size=%d)", - sample_households or "all", - sample_days, - chunk_size, - ) +def _utc_now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat() - result = subprocess.run(cmd, capture_output=True, text=True) - - if result.returncode != 0: - logger.error("Clustering prep failed: %s", result.stderr) - return False - - logger.info("Clustering data prepared") - return True - - def run_clustering( - self, - k_min: int = DEFAULT_CLUSTERING_CONFIG["k_min"], - k_max: int = DEFAULT_CLUSTERING_CONFIG["k_max"], - n_init: int = DEFAULT_CLUSTERING_CONFIG["n_init"], - ) -> bool: - """ - Run k-means clustering on prepared household-day profiles. - - Uses Euclidean distance since all profiles are aligned to the same - time grid (no time warping needed). - - Args: - k_min: Minimum number of clusters to test. - k_max: Maximum number of clusters to test. - n_init: Number of k-means initializations. - - Returns: - True if clustering successful, False otherwise. - """ - import subprocess - - profiles_path = self.paths["clustering_dir"] / "sampled_profiles.parquet" - results_dir = self.paths["clustering_dir"] / "results" - results_dir.mkdir(parents=True, exist_ok=True) - - if not profiles_path.exists(): - logger.error("Profiles not found: %s", profiles_path) - return False - - cmd = [ - sys.executable, - str(self.base_dir / "analysis" / "clustering" / "euclidean_clustering.py"), - "--input", - str(profiles_path), - "--output-dir", - str(results_dir), - "--k-range", - str(k_min), - str(k_max), - "--find-optimal-k", - "--normalize", - "--n-init", - str(n_init), - ] - - logger.info("Running k-means clustering (k=%d-%d)...", k_min, k_max) - result = subprocess.run(cmd, capture_output=True, text=True) - - if result.returncode != 0: - logger.error("Clustering failed: %s", result.stderr) - if result.stdout: - logger.error("stdout: %s", result.stdout) - return False - - logger.info("Clustering complete") - return True - - def run_stage2_regression( - self, - crosswalk_path: Path | None = None, - census_cache_path: Path | None = None, - ) -> bool: - """ - Run Stage 2: Block-group-level regression of cluster composition. - - Models how Census block-group demographics are associated with the - composition of household-day profiles across clusters. - - Unit of analysis: ONE ROW PER (block_group, cluster), with counts over - household-day observations. - - stage2_blockgroup_regression.py handles census fetching internally - if cache does not exist. - - Args: - crosswalk_path: Path to ZIP+4 → block-group crosswalk file. - census_cache_path: Path to cached census data. - - Returns: - True if regression successful, False otherwise. 
- """ - import subprocess - - clusters_path = self.paths["clustering_dir"] / "results" / "cluster_assignments.parquet" - output_dir = self.paths["clustering_dir"] / "results" / "stage2_blockgroups" - - if not clusters_path.exists(): - logger.error("Cluster assignments not found: %s", clusters_path) - logger.error("Run Stage 1 clustering first") - return False - - # Default paths - if crosswalk_path is None: - crosswalk_path = self.base_dir / "data" / "reference" / "2023_comed_zip4_census_crosswalk.txt" - if census_cache_path is None: - census_cache_path = self.base_dir / "data" / "reference" / "census_17_2023.parquet" - - if not crosswalk_path.exists(): - logger.error("Crosswalk not found: %s", crosswalk_path) - return False - - cmd = [ - sys.executable, - str(self.base_dir / "analysis" / "clustering" / "stage2_blockgroup_regression.py"), - "--clusters", - str(clusters_path), - "--crosswalk", - str(crosswalk_path), - "--census-cache", - str(census_cache_path), - "--output-dir", - str(output_dir), - ] - - logger.info("Running Stage 2 block-group regression...") - result = subprocess.run(cmd, capture_output=True, text=True) - - # Print report (stage2 script writes a human-readable summary to stdout) - if result.stdout: - print(result.stdout) - - if result.returncode != 0: - logger.error("Stage 2 regression failed: %s", result.stderr) - return False - - logger.info("Stage 2 complete: %s", output_dir) - return True - - # ========================================================================= - # VALIDATION METHODS - # ========================================================================= - - def validate_processed_data(self) -> dict: - """Validate processed interval-level data using lazy evaluation.""" - path = self.paths["processed"] - - if not path.exists(): - return self._fail("processed", f"File not found: {path}") - - logger.info("Validating processed data: %s", path) - - errors: list[str] = [] - warnings: list[str] = [] - - try: - lf = pl.scan_parquet(path) - schema = lf.collect_schema() - except Exception as exc: - return self._fail("processed", f"Failed to read: {exc}") - - # Check required columns - required = ["zip_code", "account_identifier", "datetime", "kwh", "date", "hour"] - missing = [c for c in required if c not in schema.names()] - if missing: - errors.append(f"Missing columns: {missing}") - - # Get stats using lazy evaluation (no full load) - try: - stats_df = lf.select( - [ - pl.len().alias("rows"), - pl.col("zip_code").n_unique().alias("zip_codes"), - pl.col("account_identifier").n_unique().alias("accounts"), - pl.col("kwh").min().alias("kwh_min"), - pl.col("kwh").null_count().alias("kwh_nulls"), - ], - ).collect() - - stats_dict = stats_df.to_dicts()[0] - - # Check row count - if stats_dict["rows"] == 0: - errors.append("No data rows") - - # Check for nulls - if stats_dict["kwh_nulls"] > 0: - null_pct = stats_dict["kwh_nulls"] / stats_dict["rows"] * 100 - if null_pct > 5: - errors.append(f"kwh: {null_pct:.1f}% null") - elif null_pct > 0: - warnings.append(f"kwh: {null_pct:.1f}% null") - - # Check kWh range - if stats_dict["kwh_min"] is not None and stats_dict["kwh_min"] < 0: - warnings.append(f"Negative kWh values: min={stats_dict['kwh_min']}") - - stats = { - "rows": stats_dict["rows"], - "zip_codes": stats_dict["zip_codes"], - "accounts": stats_dict["accounts"], - "file_size_mb": path.stat().st_size / 1024 / 1024, - } - except Exception as exc: - return self._fail("processed", f"Failed to compute stats: {exc}") - - return self._result("processed", errors, warnings, 
stats) - - def validate_clustering_inputs(self) -> dict: - """Validate clustering input files (household-day profiles).""" - profiles_path = self.paths["clustering_dir"] / "sampled_profiles.parquet" - - if not profiles_path.exists(): - return self._fail("clustering_inputs", f"Profiles not found: {profiles_path}") - - logger.info("Validating clustering inputs: %s", profiles_path) - - try: - df = pl.read_parquet(profiles_path) - except Exception as exc: - return self._fail("clustering_inputs", f"Failed to read: {exc}") - - errors: list[str] = [] - warnings: list[str] = [] - - # Check required columns - required = ["zip_code", "date", "profile"] - missing = [c for c in required if c not in df.columns] - if missing: - errors.append(f"Missing columns: {missing}") - - # Check profile lengths - if "profile" in df.columns: - lengths = df.select(pl.col("profile").list.len()).unique()["profile"].to_list() - if len(lengths) > 1: - errors.append(f"Inconsistent profile lengths: {lengths}") - elif lengths[0] != 48: - errors.append(f"Expected 48-point profiles, got {lengths[0]}") - - stats = { - "profiles": len(df), - "zip_codes": df["zip_code"].n_unique() if "zip_code" in df.columns else 0, - "dates": df["date"].n_unique() if "date" in df.columns else 0, - } - - return self._result("clustering_inputs", errors, warnings, stats) - - def validate_clustering_outputs(self) -> dict: - """Validate clustering output files.""" - results_dir = self.paths["clustering_dir"] / "results" - assignments_path = results_dir / "cluster_assignments.parquet" - - if not assignments_path.exists(): - return self._skip("clustering_outputs", "No clustering results yet") - - logger.info("Validating clustering outputs: %s", results_dir) - - try: - assignments = pl.read_parquet(assignments_path) - except Exception as exc: - return self._fail("clustering_outputs", f"Failed to read: {exc}") - - errors: list[str] = [] - warnings: list[str] = [] - - # Check required columns - if "cluster" not in assignments.columns: - errors.append("Missing 'cluster' column") - - # Check cluster distribution - if "cluster" in assignments.columns: - cluster_counts = assignments["cluster"].value_counts() - if cluster_counts["count"].min() == 0: - warnings.append("Some clusters have no assignments") - - stats = { - "n_assigned": len(assignments), - "k": assignments["cluster"].n_unique() if "cluster" in assignments.columns else 0, - } - - # Load metrics if available - metrics_path = results_dir / "clustering_metrics.json" - if metrics_path.exists(): - import json - - with open(metrics_path) as f: - metrics = json.load(f) - stats["silhouette"] = metrics.get("silhouette_score") - stats["inertia"] = metrics.get("inertia") - - return self._result("clustering_outputs", errors, warnings, stats) - - # ========================================================================= - # ORCHESTRATION METHODS - # ========================================================================= - - def run_full_pipeline( - self, - year_month: str, - num_files: int, - sample_days: int = DEFAULT_CLUSTERING_CONFIG["sample_days"], - sample_households: int | None = DEFAULT_CLUSTERING_CONFIG["sample_households"], - day_strategy: str = DEFAULT_CLUSTERING_CONFIG["day_strategy"], - k_min: int = DEFAULT_CLUSTERING_CONFIG["k_min"], - k_max: int = DEFAULT_CLUSTERING_CONFIG["k_max"], - n_init: int = DEFAULT_CLUSTERING_CONFIG["n_init"], - chunk_size: int = DEFAULT_CLUSTERING_CONFIG["chunk_size"], - skip_clustering: bool = False, - run_stage2: bool = False, - ) -> bool: - """ - Execute the 
complete pipeline. - - Args: - year_month: Target month (YYYYMM format). - num_files: Number of S3 files to download. - sample_days: Days to sample for clustering. - sample_households: Households to sample (None = all). - day_strategy: Day sampling strategy ("stratified" or "random"). - k_min: Minimum clusters to test. - k_max: Maximum clusters to test. - n_init: Number of k-means initializations. - chunk_size: Households per chunk for streaming profile builder. - skip_clustering: If True, stop after preparing data. - run_stage2: If True, run demographic regression after clustering. - - Returns: - True if all steps succeed, False otherwise. - """ - self._print_header("COMED PIPELINE EXECUTION") - print(f"Year-Month: {year_month}") - print(f"Files: {num_files}") - print(f"Output: {self.run_dir}") - print(f"Clustering: {'Skipped' if skip_clustering else f'k={k_min}-{k_max}'}") - print(f"Stage 2: {'Yes' if run_stage2 else 'No'}") - print(f"Profiles: streaming (chunk_size={chunk_size})") - - self.setup_directories() - - # Step 1: Download - self._print_step("DOWNLOADING FROM S3") - if not self.download_from_s3(year_month, num_files): - return False - - # Step 2: Process - self._print_step("PROCESSING RAW DATA") - if not self.process_raw_data(year_month): - return False - - # Step 3: Prepare clustering data - self._print_step("PREPARING CLUSTERING DATA") - if not self.prepare_clustering_data( - sample_days=sample_days, - sample_households=sample_households, - day_strategy=day_strategy, - chunk_size=chunk_size, - ): - return False - - # Step 4: Cluster (optional) - if not skip_clustering: - self._print_step("RUNNING K-MEANS CLUSTERING") - if not self.run_clustering( - k_min=k_min, - k_max=k_max, - n_init=n_init, - ): - return False - - # Step 5: Stage 2 regression (optional) - if run_stage2 and not skip_clustering: - self._print_step("RUNNING STAGE 2: DEMOGRAPHIC REGRESSION") - if not self.run_stage2_regression(): - logger.warning("Stage 2 regression failed, but Stage 1 completed successfully") - - logger.info("Pipeline execution complete") - return True - - def validate_all(self) -> bool: - """ - Run all validation checks. - - Returns: - True if all critical validations pass, False otherwise. 
- """ - self._print_header("VALIDATION") - - self.results["processed"] = self.validate_processed_data() - self.results["clustering_inputs"] = self.validate_clustering_inputs() - self.results["clustering_outputs"] = self.validate_clustering_outputs() - - return self._print_summary() - - def validate_stage(self, stage: str) -> bool: - """Validate a specific pipeline stage.""" - if stage == "processed": - self.results["processed"] = self.validate_processed_data() - elif stage == "clustering": - self.results["clustering_inputs"] = self.validate_clustering_inputs() - self.results["clustering_outputs"] = self.validate_clustering_outputs() - else: - logger.error("Unknown stage: %s", stage) - return False - - return self._print_summary() - - # ========================================================================= - # HELPER METHODS - # ========================================================================= - - def _print_header(self, title: str) -> None: - print(f"\n{'=' * 70}") - print(title) - print(f"{'=' * 70}") - - def _print_step(self, title: str) -> None: - print(f"\n{'─' * 70}") - print(title) - print(f"{'─' * 70}") - - def _result(self, stage: str, errors: list[str], warnings: list[str], stats: dict) -> dict: - status = "PASS" if not errors else "FAIL" - icon = "✅" if status == "PASS" else "❌" - print(f"\n{icon} {stage.upper()}: {status}") - - for err in errors: - print(f" Error: {err}") - for warn in warnings: - print(f" ⚠️ {warn}") - - return {"status": status, "errors": errors, "warnings": warnings, "stats": stats} - - def _fail(self, stage: str, message: str) -> dict: - print(f"\n❌ {stage.upper()}: FAILED - {message}") - return {"status": "FAIL", "errors": [message], "warnings": [], "stats": {}} - - def _skip(self, stage: str, message: str) -> dict: - print(f"\n⏭️ {stage.upper()}: SKIPPED - {message}") - return {"status": "SKIP", "errors": [], "warnings": [], "stats": {}} - - def _print_summary(self) -> bool: - self._print_header("SUMMARY") - - all_passed = True - for stage, result in self.results.items(): - status = result.get("status", "UNKNOWN") - icon = {"PASS": "✅", "FAIL": "❌", "SKIP": "⏭️"}.get(status, "❓") - print(f"{icon} {stage}: {status}") - - if status == "FAIL": - all_passed = False - - # Show clustering summary if available - if "clustering_outputs" in self.results: - stats = self.results["clustering_outputs"].get("stats", {}) - if stats.get("k"): - print("\nClustering Results:") - print(f" • {stats.get('n_assigned', '?')} profiles → {stats.get('k')} clusters") - if stats.get("silhouette") is not None: - print(f" • Silhouette score: {stats['silhouette']:.3f}") - - print() - return all_passed - - -# ============================================================================= -# COMMAND LINE INTERFACE -# ============================================================================= +def _configure_logging(log_path: Path) -> None: + log_path.parent.mkdir(parents=True, exist_ok=True) -def main() -> None: - parser = argparse.ArgumentParser( - description="ComEd Smart Meter Analysis Pipeline", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Quick local test (100 files, fewer households/days, no Stage 2) - python scripts/run_comed_pipeline.py \\ - --from-s3 \\ - --year-month 202308 \\ - --num-files 100 \\ - --sample-households 2000 \\ - --sample-days 10 \\ - --k-range 3 4 - - # High-volume analysis (tested configuration, with Stage 2) - python scripts/run_comed_pipeline.py \\ - --from-s3 \\ - --year-month 202308 \\ - --num-files 10000 \\ 
- --sample-households 20000 \\ - --sample-days 31 \\ - --k-range 3 6 \\ - --run-stage2 - - # Validate existing results for a specific run - python scripts/run_comed_pipeline.py \\ - --validate-only \\ - --run-name 202308_10000 - """, - ) + handlers: list[logging.Handler] = [ + logging.StreamHandler(sys.stdout), + logging.FileHandler(str(log_path), encoding="utf-8"), + ] - # Mode selection - mode_group = parser.add_argument_group("Mode") - mode_group.add_argument( - "--from-s3", - action="store_true", - help="Run full pipeline: download from S3, process, cluster, validate", + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=handlers, + force=True, ) - mode_group.add_argument( - "--validate-only", - action="store_true", - help="Only validate existing files (no processing)", + + +def _get_git_sha(repo_root: Path) -> str | None: + try: + r = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=str(repo_root), + check=False, + capture_output=True, + text=True, + ) + if r.returncode != 0: + return None + sha = (r.stdout or "").strip() + return sha or None + except Exception: + return None + + +def create_directories(run_dir: Path) -> dict[str, Path]: + paths = { + "run_dir": run_dir, + "raw": run_dir / "raw", + "processed": run_dir / "processed", + "clustering": run_dir / "clustering", + "stage2": run_dir / "stage2", + "logs": run_dir / "logs", + } + for p in paths.values(): + p.mkdir(parents=True, exist_ok=True) + return paths + + +def _run_subprocess(cmd: list[str], *, cwd: Path | None = None, env: dict[str, str] | None = None) -> None: + logger.info("Command: %s", " ".join(cmd)) + r = subprocess.run(cmd, cwd=str(cwd) if cwd else None, env=env, check=False) + if r.returncode != 0: + raise RuntimeError(f"Command failed with return code {r.returncode}: {' '.join(cmd)}") + + +def download_from_s3(*, year_month: str, num_files: int, run_dir: Path) -> dict[str, Any]: + raw_dir = run_dir / "raw" + raw_dir.mkdir(parents=True, exist_ok=True) + + manifest_path = run_dir / "download_manifest.jsonl" + + # list_s3_files MUST return keys (not URIs). 
+ s3_keys = list_s3_files(year_month=year_month, max_files=num_files) + + logger.info("Downloading %d files for %s into %s", len(s3_keys), year_month, raw_dir) + result = download_s3_batch( + s3_keys=s3_keys, + output_dir=raw_dir, + manifest_path=manifest_path, + max_files=num_files, + fail_fast=True, + max_errors=10, + retries=3, + backoff_factor=2.0, ) - mode_group.add_argument( - "--skip-clustering", - action="store_true", - help="Run pipeline but skip clustering step (useful for testing)", + logger.info( + "Download complete: downloaded=%s failed=%s manifest=%s", + result.get("downloaded"), + result.get("failed"), + result.get("manifest_path"), ) - mode_group.add_argument( - "--run-stage2", - action="store_true", - help="Run Stage 2 regression after clustering (requires census data)", + return result + + +def process_csvs(*, run_dir: Path, year_month: str, day_mode: str = "calendar") -> Path: + raw_dir = run_dir / "raw" + processed_dir = run_dir / "processed" + processed_dir.mkdir(parents=True, exist_ok=True) + + output_file = processed_dir / f"comed_{year_month}.parquet" + manifest_file = run_dir / "processing_manifest.jsonl" + + cmd = [ + sys.executable, + "scripts/process_csvs_batched_optimized.py", + "--input-dir", + str(raw_dir), + "--output", + str(output_file), + "--processing-manifest", + str(manifest_file), + "--day-mode", + str(day_mode), + ] + + logger.info("Ingesting CSVs into canonical interval parquet: %s", output_file) + _run_subprocess(cmd) + + if not output_file.exists(): + raise FileNotFoundError(f"Expected processed parquet missing: {output_file}") + + return output_file + + +def prepare_clustering( + *, + run_dir: Path, + processed_parquet: Path, + sample_days: int, + sample_households: int | None, + seed: int, +) -> Path: + clustering_dir = run_dir / "clustering" + clustering_dir.mkdir(parents=True, exist_ok=True) + + cmd = [ + sys.executable, + "analysis/clustering/prepare_clustering_data_households.py", + "--input", + str(processed_parquet), + "--output-dir", + str(clustering_dir), + "--sample-days", + str(int(sample_days)), + "--seed", + str(int(seed)), + "--streaming", + ] + + if sample_households is not None: + cmd.extend(["--sample-households", str(int(sample_households))]) + + logger.info("Preparing Stage 1 clustering data in %s", clustering_dir) + _run_subprocess(cmd) + + profiles = clustering_dir / "sampled_profiles.parquet" + if not profiles.exists(): + raise FileNotFoundError(f"Expected clustering profiles missing: {profiles}") + return profiles + + +def run_clustering( + *, + run_dir: Path, + profiles_path: Path, + k: int, + clustering_seed: int, +) -> None: + clustering_dir = run_dir / "clustering" + clustering_dir.mkdir(parents=True, exist_ok=True) + + cmd = [ + sys.executable, + "analysis/clustering/euclidean_clustering_minibatch.py", + "--input", + str(profiles_path), + "--output-dir", + str(clustering_dir), + "--k", + str(int(k)), + "--random-state", + str(int(clustering_seed)), + "--normalize", + "--normalize-method", + "minmax", + ] + + logger.info("Running clustering (MiniBatchKMeans) k=%d output=%s", k, clustering_dir) + _run_subprocess(cmd) + + expected = clustering_dir / "cluster_assignments.parquet" + if not expected.exists(): + raise FileNotFoundError(f"Expected clustering output missing: {expected}") + + +def _default_stage2_census_cache(*, stage2_dir: Path, state_fips: str, acs_year: int) -> Path: + # Option 1: cache lives under the run directory + return stage2_dir / f"census_cache_{state_fips}_{acs_year}.parquet" + + +def 
run_stage2_logratio( + *, + run_dir: Path, + crosswalk_path: Path, + state_fips: str, + acs_year: int, + min_obs_per_bg: int, + alpha: float, + standardize: bool, + fetch_census: bool, + no_ols: bool, + baseline_cluster: str | None, + predictors_from: Path | None, + census_cache_path: Path | None, +) -> None: + stage2_dir = run_dir / "stage2" + stage2_dir.mkdir(parents=True, exist_ok=True) + + clusters_path = run_dir / "clustering" / "cluster_assignments.parquet" + if not clusters_path.exists(): + raise FileNotFoundError(f"Missing cluster assignments for Stage 2: {clusters_path}") + + if not crosswalk_path.exists(): + raise FileNotFoundError(f"Missing crosswalk for Stage 2: {crosswalk_path}") + + script_path = Path("analysis/clustering/stage2_logratio_regression.py") + if not script_path.exists(): + raise FileNotFoundError(f"Missing Stage 2 script: {script_path}") + + resolved_cache = census_cache_path or _default_stage2_census_cache( + stage2_dir=stage2_dir, + state_fips=state_fips, + acs_year=acs_year, ) - # Data selection - data_group = parser.add_argument_group("Data Selection") - data_group.add_argument( - "--year-month", - default="202308", - help="Target month in YYYYMM format (default: 202308)", + cmd = [ + sys.executable, + str(script_path), + "--clusters", + str(clusters_path), + "--crosswalk", + str(crosswalk_path), + "--output-dir", + str(stage2_dir), + "--census-cache", + str(resolved_cache), + "--state-fips", + str(state_fips), + "--acs-year", + str(int(acs_year)), + "--min-obs-per-bg", + str(int(min_obs_per_bg)), + "--alpha", + str(float(alpha)), + ] + + if fetch_census: + cmd.append("--fetch-census") + + if standardize: + cmd.append("--standardize") + + if no_ols: + cmd.append("--no-ols") + + if baseline_cluster is not None: + cmd.extend(["--baseline-cluster", str(baseline_cluster)]) + + if predictors_from is not None: + cmd.extend(["--predictors-from", str(predictors_from)]) + + logger.info("Running Stage 2 log-ratio regression output=%s", stage2_dir) + _run_subprocess(cmd) + + +def write_run_manifest(*, run_dir: Path, args: argparse.Namespace) -> None: + repo_root = Path().resolve() + + manifest: dict[str, Any] = { + "run_name": args.run_name, + "timestamp_utc": _utc_now_iso(), + "args": vars(args), + "git_sha": _get_git_sha(repo_root), + "python_version": sys.version, + "polars_version": pl.__version__, + "seeds": { + "sampling": args.seed, + "clustering": args.clustering_seed, + }, + "directory_structure": "data/runs/{run_name}/", + "pipeline_version": "v1.0-restored-baseline", + "cwd": str(Path.cwd()), + "platform": { + "os_name": os.name, + "sys_platform": sys.platform, + }, + } + + out = run_dir / "run_manifest.json" + out.write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n", encoding="utf-8") + logger.info("Wrote run manifest: %s", out) + + +def run_pipeline(args: argparse.Namespace) -> None: + run_dir = DEFAULT_RUNS_DIR / args.run_name + paths = create_directories(run_dir) + + _configure_logging(paths["logs"] / "pipeline.log") + + logger.info("Run directory: %s", run_dir) + logger.info("Year-month: %s", args.year_month) + + if args.from_s3: + download_from_s3(year_month=args.year_month, num_files=int(args.num_files), run_dir=run_dir) + else: + logger.info("Skipping S3 download (--from-s3 not set). 
Expecting CSVs in: %s", paths["raw"]) + + processed_parquet = process_csvs(run_dir=run_dir, year_month=args.year_month, day_mode=args.day_mode) + + profiles_path = prepare_clustering( + run_dir=run_dir, + processed_parquet=processed_parquet, + sample_days=int(args.sample_days), + sample_households=(int(args.sample_households) if args.sample_households is not None else None), + seed=int(args.seed), ) - data_group.add_argument( - "--num-files", - type=int, - default=1000, - help="Number of S3 files to download (default: 1000)", + + run_clustering( + run_dir=run_dir, + profiles_path=profiles_path, + k=int(args.k), + clustering_seed=int(args.clustering_seed), ) - # Clustering parameters - cluster_group = parser.add_argument_group("Clustering Parameters") - cluster_group.add_argument( + if bool(args.run_stage2): + run_stage2_logratio( + run_dir=run_dir, + crosswalk_path=Path(args.stage2_crosswalk), + state_fips=str(args.stage2_state_fips), + acs_year=int(args.stage2_acs_year), + min_obs_per_bg=int(args.stage2_min_obs_per_bg), + alpha=float(args.stage2_alpha), + standardize=bool(args.stage2_standardize), + no_ols=bool(args.stage2_no_ols), + fetch_census=bool(args.stage2_fetch_census), + baseline_cluster=(str(args.stage2_baseline_cluster) if args.stage2_baseline_cluster is not None else None), + predictors_from=(Path(args.stage2_predictors_from) if args.stage2_predictors_from is not None else None), + census_cache_path=(Path(args.stage2_census_cache) if args.stage2_census_cache is not None else None), + ) + + write_run_manifest(run_dir=run_dir, args=args) + + logger.info("Pipeline completed successfully.") + + +def build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="ComEd Smart Meter Analysis Pipeline Orchestrator (Phase 3)") + + p.add_argument("--run-name", required=True, help="Run name (directory under data/runs/)") + p.add_argument("--year-month", required=True, help="Target month in YYYYMM format (e.g., 202307)") + p.add_argument("--from-s3", action="store_true", help="Download CSVs from S3 into run_dir/raw/") + + p.add_argument("--num-files", type=int, default=10, help="Number of S3 files to download (default: 10)") + + p.add_argument("--sample-days", type=int, default=31, help="Days to sample for clustering (default: 31)") + p.add_argument( "--sample-households", type=int, - default=DEFAULT_CLUSTERING_CONFIG["sample_households"], - help=(f"Households to sample (default: {DEFAULT_CLUSTERING_CONFIG['sample_households']}, use 0 for all)"), - ) - cluster_group.add_argument( - "--sample-days", - type=int, - default=DEFAULT_CLUSTERING_CONFIG["sample_days"], - help=f"Days to sample (default: {DEFAULT_CLUSTERING_CONFIG['sample_days']})", + default=None, + help="Households to sample (default: all). 
Provide an integer to limit.", ) - cluster_group.add_argument( - "--k-range", - type=int, - nargs=2, - metavar=("MIN", "MAX"), - default=[DEFAULT_CLUSTERING_CONFIG["k_min"], DEFAULT_CLUSTERING_CONFIG["k_max"]], - help=( - "Cluster range to test " - f"(default: {DEFAULT_CLUSTERING_CONFIG['k_min']} {DEFAULT_CLUSTERING_CONFIG['k_max']})" - ), + + p.add_argument("--seed", type=int, default=42, help="Random seed for sampling (default: 42)") + p.add_argument("--clustering-seed", type=int, default=42, help="Random seed for clustering (default: 42)") + + p.add_argument("--k", type=int, required=True, help="Number of clusters (single k per run)") + + p.add_argument("--day-mode", choices=["calendar", "billing"], default="calendar", help="Day attribution mode") + + p.add_argument("--run-stage2", action="store_true", help="Run Stage 2 log-ratio regression (optional)") + + # ---------------------------- + # Stage 2 options (user-facing) + # ---------------------------- + p.add_argument( + "--stage2-crosswalk", + default=str(DEFAULT_CROSSWALK_PATH), + help=f"ZIP+4 → block-group crosswalk path (default: {DEFAULT_CROSSWALK_PATH})", ) - cluster_group.add_argument( - "--day-strategy", - choices=["stratified", "random"], - default=DEFAULT_CLUSTERING_CONFIG["day_strategy"], - help="Day sampling strategy (default: stratified = 70% weekday, 30% weekend)", + p.add_argument( + "--stage2-state-fips", + default=DEFAULT_STATE_FIPS, + help=f"State FIPS (default: {DEFAULT_STATE_FIPS})", ) - cluster_group.add_argument( - "--n-init", + p.add_argument( + "--stage2-acs-year", type=int, - default=DEFAULT_CLUSTERING_CONFIG["n_init"], - help=f"Number of k-means initializations (default: {DEFAULT_CLUSTERING_CONFIG['n_init']})", + default=DEFAULT_ACS_YEAR, + help=f"ACS year (default: {DEFAULT_ACS_YEAR})", ) - cluster_group.add_argument( - "--fast", + p.add_argument( + "--stage2-fetch-census", action="store_true", - help="Fast mode: k=3-4 (for testing)", + help="Stage 2: force re-fetch Census data (ignore cache)", ) - cluster_group.add_argument( - "--chunk-size", - type=int, - default=DEFAULT_CLUSTERING_CONFIG["chunk_size"], + p.add_argument( + "--stage2-census-cache", + default=None, help=( - f"Households per chunk for streaming profile builder (default: {DEFAULT_CLUSTERING_CONFIG['chunk_size']})" + "Optional census cache parquet path. If omitted, defaults to " + "data/runs/{run_name}/stage2/census_cache_{state_fips}_{acs_year}.parquet (Option 1)." 
), ) - # Output options - output_group = parser.add_argument_group("Output Options") - output_group.add_argument( - "--run-name", - help="Name for this run (default: {year_month}_{num_files})", + p.add_argument( + "--stage2-min-obs-per-bg", + type=int, + default=50, + help="Stage 2 minimum household-day observations per block group (default: 50)", ) - output_group.add_argument( - "--base-dir", - type=Path, - default=Path("."), - help="Project root directory (default: current directory)", + p.add_argument("--stage2-alpha", type=float, default=0.5, help="Stage 2 Laplace smoothing alpha (default: 0.5)") + p.add_argument("--stage2-standardize", action="store_true", help="Stage 2: standardize predictors") + p.add_argument("--stage2-no-ols", action="store_true", help="Stage 2: skip OLS robustness check") + p.add_argument( + "--stage2-baseline-cluster", + default=None, + help="Stage 2: optional baseline cluster label (default: most frequent cluster)", ) - output_group.add_argument( - "--stage", - choices=["processed", "clustering", "all"], - default="all", - help="Stage to validate (default: all)", + p.add_argument( + "--stage2-predictors-from", + default=None, + help="Stage 2: optional path to predictors list (one per line) to force exact predictors", ) - args = parser.parse_args() + return p - # Handle --fast mode - if args.fast: - args.k_range = [3, 4] - logger.info("Fast mode enabled: k=3-4") - # Determine run name - run_name = args.run_name or (f"{args.year_month}_{args.num_files}" if args.from_s3 else args.run_name) - - # Create pipeline - pipeline = ComedPipeline(args.base_dir, run_name) - - # Handle sample_households = 0 as None (all households) - sample_households = args.sample_households if args.sample_households > 0 else None +def main() -> None: + args = build_parser().parse_args() - # Execute based on mode - if args.from_s3: - success = pipeline.run_full_pipeline( - year_month=args.year_month, - num_files=args.num_files, - sample_days=args.sample_days, - sample_households=sample_households, - day_strategy=args.day_strategy, - k_min=args.k_range[0], - k_max=args.k_range[1], - n_init=args.n_init, - chunk_size=args.chunk_size, - skip_clustering=args.skip_clustering, - run_stage2=args.run_stage2, - ) + if int(args.num_files) <= 0 and args.from_s3: + raise ValueError("--num-files must be > 0 when using --from-s3") - if success: - pipeline.validate_all() - elif args.validate_only: - success = pipeline.validate_all() if args.stage == "all" else pipeline.validate_stage(args.stage) - else: - parser.print_help() - print("\n⚠️ Specify --from-s3 to run pipeline or --validate-only to check existing files") - sys.exit(1) + if int(args.k) <= 1: + raise ValueError("--k must be >= 2") - sys.exit(0 if success else 1) + run_pipeline(args) if __name__ == "__main__": diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py deleted file mode 100755 index 41357b8..0000000 --- a/scripts/run_pipeline.py +++ /dev/null @@ -1,275 +0,0 @@ -#!/usr/bin/env python3 -""" -Main pipeline script for monthly smart meter analysis. - -This script parameterizes the pipeline to process any month by changing -a single --month parameter (1-12). - -Usage: - python scripts/run_pipeline.py --month 7 --input path/to/input.parquet - python scripts/run_pipeline.py --month 1 --year 2023 --input path/to/input.parquet - python scripts/run_pipeline.py --month 7 --config config/custom.yaml --input path/to/input.parquet - -The script: -1. Loads configuration from config/monthly_run.yaml (or custom config) -2. 
Overrides month/year if provided via CLI -3. Runs the clustering pipeline with month-specific filtering -""" - -from __future__ import annotations - -import argparse -import logging -import sys -from dataclasses import dataclass -from pathlib import Path -from typing import Any - -from analysis.clustering.euclidean_clustering_minibatch import main as clustering_main -from analysis.clustering.prepare_clustering_data_households import prepare_clustering_data -from smart_meter_analysis.config import get_year_month_str, load_config - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", -) -logger = logging.getLogger(__name__) - - -@dataclass(frozen=True) -class PipelineArgs: - """Parsed CLI arguments for the monthly pipeline.""" - - month: int - year: int | None - config: Path | None - input: Path - output_dir: Path | None - skip_clustering: bool - - -def _parse_args(argv: list[str] | None = None) -> PipelineArgs: - parser = argparse.ArgumentParser( - description="Run smart meter analysis pipeline for a specific month", - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - - parser.add_argument( - "--month", - type=int, - required=True, - choices=range(1, 13), - metavar="MONTH", - help="Month to process (1-12, e.g., 7 for July)", - ) - parser.add_argument( - "--year", - type=int, - default=None, - help="Year to process (default: from config file, typically 2023)", - ) - parser.add_argument( - "--config", - type=Path, - default=None, - help="Path to config file (default: config/monthly_run.yaml)", - ) - parser.add_argument( - "--input", - type=Path, - required=True, - help="Input parquet file path (processed interval data)", - ) - parser.add_argument( - "--output-dir", - type=Path, - default=None, - help="Output directory for clustering results (default: data/clustering)", - ) - parser.add_argument( - "--skip-clustering", - action="store_true", - help="Only prepare clustering data, skip actual clustering", - ) - - ns = parser.parse_args(argv) - return PipelineArgs( - month=ns.month, - year=ns.year, - config=ns.config, - input=ns.input, - output_dir=ns.output_dir, - skip_clustering=ns.skip_clustering, - ) - - -def _load_and_override_config(args: PipelineArgs) -> dict[str, Any]: - try: - config: dict[str, Any] = load_config(args.config) - except FileNotFoundError as e: - logger.error("Config file not found: %s", e) - raise - - # Override month/year from CLI - config["month"] = args.month - if args.year is not None: - config["year"] = args.year - - return config - - -def _resolve_output_dir(config: dict[str, Any], override: Path | None) -> Path: - if override is not None: - return override - - default_dir = config.get("output", {}).get("clustering_dir", "data/clustering") - return Path(default_dir) - - -def _log_run_header(*, year: int, month: int, year_month_str: str, input_path: Path) -> None: - logger.info("=" * 70) - logger.info("MONTHLY PIPELINE EXECUTION") - logger.info("=" * 70) - logger.info("Year: %d", year) - logger.info("Month: %d", month) - logger.info("Year-Month: %s", year_month_str) - logger.info("Input: %s", input_path) - - -def _prepare_data(*, config: dict[str, Any], input_path: Path, output_dir: Path, year: int, month: int) -> None: - sampling_config = config.get("sampling", {}) - sample_households = sampling_config.get("sample_households") - sample_days = sampling_config.get("sample_days", 20) - day_strategy = sampling_config.get("day_strategy", "stratified") - streaming = sampling_config.get("streaming", True) - chunk_size = 
sampling_config.get("chunk_size", 5000) - seed = sampling_config.get("seed", 42) - - logger.info("") - logger.info("=" * 70) - logger.info("STEP 1: PREPARING CLUSTERING DATA") - logger.info("=" * 70) - - stats = prepare_clustering_data( - input_paths=[input_path], - output_dir=output_dir, - sample_households=sample_households, - sample_days=sample_days, - day_strategy=day_strategy, - streaming=streaming, - chunk_size=chunk_size, - seed=seed, - year=year, - month=month, - ) - - logger.info("✓ Clustering data preparation complete") - logger.info(" Profiles: %s", f"{stats['n_profiles']:,}") - logger.info(" Households: %s", f"{stats['n_households']:,}") - - -def _run_clustering(*, config: dict[str, Any], output_dir: Path) -> int: - logger.info("") - logger.info("=" * 70) - logger.info("STEP 2: RUNNING CLUSTERING") - logger.info("=" * 70) - - clustering_config = config.get("clustering", {}) - n_clusters = clustering_config.get("n_clusters", 4) - batch_size = clustering_config.get("batch_size", 10000) - n_init = clustering_config.get("n_init", 3) - random_state = clustering_config.get("random_state", 42) - normalize = clustering_config.get("normalize", True) - normalize_method = clustering_config.get("normalize_method", "minmax") - silhouette_sample_size = clustering_config.get("silhouette_sample_size", 5000) - - input_profiles = output_dir / "sampled_profiles.parquet" - clustering_output_dir = output_dir / "results" - clustering_output_dir.mkdir(parents=True, exist_ok=True) - - clustering_args: list[str] = [ - "--input", - str(input_profiles), - "--output-dir", - str(clustering_output_dir), - "--k", - str(n_clusters), - "--batch-size", - str(batch_size), - "--n-init", - str(n_init), - "--random-state", - str(random_state), - "--silhouette-sample-size", - str(silhouette_sample_size), - ] - - if normalize: - clustering_args.extend(["--normalize", "--normalize-method", normalize_method]) - else: - clustering_args.extend(["--normalize-method", "none"]) - - old_argv = sys.argv - try: - sys.argv = ["euclidean_clustering_minibatch.py", *clustering_args] - result = clustering_main() - if result != 0: - logger.error("Clustering failed") - return int(result) - logger.info("✓ Clustering complete") - return 0 - finally: - sys.argv = old_argv - - -def main(argv: list[str] | None = None) -> int: - """Main entry point for monthly pipeline.""" - args = _parse_args(argv) - - try: - config = _load_and_override_config(args) - except FileNotFoundError: - return 1 - - year = int(config.get("year", 2023)) - month = int(config.get("month", args.month)) - year_month_str = get_year_month_str(config) - - _log_run_header(year=year, month=month, year_month_str=year_month_str, input_path=args.input) - - if not args.input.exists(): - logger.error("Input file not found: %s", args.input) - return 1 - - output_dir = _resolve_output_dir(config, args.output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - logger.info("Output directory: %s", output_dir) - - try: - _prepare_data(config=config, input_path=args.input, output_dir=output_dir, year=year, month=month) - except Exception as e: - logger.error("Failed to prepare clustering data: %s", e, exc_info=True) - return 1 - - if args.skip_clustering: - logger.info("Skipping clustering (--skip-clustering specified)") - else: - try: - result = _run_clustering(config=config, output_dir=output_dir) - except Exception as e: - logger.error("Failed to run clustering: %s", e, exc_info=True) - return 1 - if result != 0: - return result - - logger.info("") - logger.info("=" * 70) - 
logger.info("PIPELINE COMPLETE") - logger.info("=" * 70) - logger.info("Output: %s", output_dir) - - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/smart_meter_analysis/aws_loader.py b/smart_meter_analysis/aws_loader.py index 0c7ea36..0e2abb8 100644 --- a/smart_meter_analysis/aws_loader.py +++ b/smart_meter_analysis/aws_loader.py @@ -1,297 +1,383 @@ # smart_meter_analysis/aws_loader.py -""" -AWS S3 utilities for batch processing ComEd smart meter data. +"""AWS S3 utilities for deterministic batch download of ComEd smart meter CSV files. + +Contract A (S3 Download only) +---------------------------- +This module is responsible for: +- Deterministically listing S3 CSV keys for a given year-month (YYYYMM) +- Downloading those keys to local disk with retry/backoff +- Writing a JSONL download manifest for provenance and QA + +Downstream processing belongs to scripts/process_csvs_batched_optimized.py and +smart_meter_analysis/transformation.py. + +Operational scaling note +------------------------ +For full-scale runs (e.g., ~700k files), downloads must be resumable. +Accordingly, this module supports: +- Appending to an existing manifest (default) +- Skipping downloads for files that already exist on disk (default) +- Optional manifest overwrite for clean-room re-runs """ from __future__ import annotations +import argparse +import json import logging +import re +import time +from datetime import datetime, timezone from pathlib import Path from typing import Any -import polars as pl - -# Local dtype alias for mypy -DType = Any +import boto3 +from botocore.exceptions import BotoCoreError, ClientError logger = logging.getLogger(__name__) -# S3 Configuration S3_BUCKET = "smart-meter-data-sb" S3_PREFIX = "sharepoint-files/Zip4/" -# Error messages +ERR_BAD_YEAR_MONTH = "year_month must be YYYYMM (e.g., 202308); got: {}" ERR_NO_FILES_FOUND = "No files found for month: {}" -ERR_NO_SUCCESSFUL_PROCESS = "No files were successfully processed" -# ID columns with explicit dtypes -COMED_SCHEMA_OVERRIDES: dict[str, DType] = { - "ZIP_CODE": pl.Utf8, - "DELIVERY_SERVICE_CLASS": pl.Utf8, - "DELIVERY_SERVICE_NAME": pl.Utf8, - "ACCOUNT_IDENTIFIER": pl.Utf8, - "INTERVAL_READING_DATE": pl.Utf8, # parsed later - "INTERVAL_LENGTH": pl.Utf8, - "TOTAL_REGISTERED_ENERGY": pl.Float64, - "PLC_VALUE": pl.Utf8, - "NSPL_VALUE": pl.Utf8, -} -COMED_INTERVAL_COLUMNS: list[str] = [ - f"INTERVAL_HR{m // 60:02d}{m % 60:02d}_ENERGY_QTY" for m in range(30, 24 * 60 + 1, 30) -] +def _utc_now_iso() -> str: + return datetime.now(timezone.utc).replace(microsecond=0).isoformat() -_INTERVAL_SCHEMA: dict[str, DType] = dict.fromkeys(COMED_INTERVAL_COLUMNS, pl.Float64) -COMED_SCHEMA: dict[str, DType] = {**COMED_SCHEMA_OVERRIDES, **_INTERVAL_SCHEMA} +def _validate_year_month(year_month: str) -> None: + if not re.fullmatch(r"\d{6}", year_month): + raise ValueError(ERR_BAD_YEAR_MONTH.format(year_month)) def list_s3_files( year_month: str, + *, bucket: str = S3_BUCKET, prefix: str = S3_PREFIX, max_files: int | None = None, ) -> list[str]: - """ - List CSV files in S3 for a given year-month. - - Args: - year_month: 'YYYYMM' (e.g., '202308') - max_files: optional limit for testing + """List CSV files in S3 for a given year-month. Returns: - S3 URIs as s3://bucket/key + Deterministically sorted S3 *keys* (e.g., "sharepoint-files/Zip4/202307/file.csv"), + NOT URIs. These keys are intended to be passed directly to download_s3_batch(). 
+ + Determinism: + - Collect ALL keys for the prefix + - Sort the full list + - Then apply max_files slicing """ - import boto3 + _validate_year_month(year_month) s3 = boto3.client("s3") full_prefix = f"{prefix}{year_month}/" - logger.info(f"Listing files from s3://{bucket}/{full_prefix}") + logger.info("Listing files from s3://%s/%s", bucket, full_prefix) paginator = s3.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket, Prefix=full_prefix) - s3_uris: list[str] = [] + keys: list[str] = [] for page in pages: - if "Contents" not in page: - continue - for obj in page["Contents"]: - key = obj["Key"] - if key.endswith(".csv"): - s3_uri = f"s3://{bucket}/{key}" - s3_uris.append(s3_uri) - if max_files and len(s3_uris) >= max_files: - logger.info(f"Limited to {max_files} files for testing") - return s3_uris - - logger.info(f"Found {len(s3_uris)} CSV files") - return s3_uris - - -def scan_single_csv_lazy( - s3_uri: str, schema: dict[str, DType] | None = None, day_mode: str = "calendar" -) -> pl.LazyFrame: - """ - Lazily scan a single CSV from S3 and apply transformations using a fixed schema. - """ - schema = COMED_SCHEMA if schema is None else schema - lf = pl.scan_csv(s3_uri, schema_overrides=schema, ignore_errors=True) - lf_long = transform_wide_to_long_lazy(lf) - return add_time_columns_lazy(lf_long, day_mode=day_mode) + for obj in page.get("Contents", []): + key = obj.get("Key") + if key and key.endswith(".csv"): + keys.append(key) + keys = sorted(keys) -def transform_wide_to_long_lazy( - lf: pl.LazyFrame, - date_col: str = "INTERVAL_READING_DATE", -) -> pl.LazyFrame: - """ - Transform wide ComEd interval file to long format with timestamps. - """ - id_cols = [ - "ZIP_CODE", - "DELIVERY_SERVICE_CLASS", - "DELIVERY_SERVICE_NAME", - "ACCOUNT_IDENTIFIER", - "INTERVAL_READING_DATE", - "INTERVAL_LENGTH", - "TOTAL_REGISTERED_ENERGY", - "PLC_VALUE", - "NSPL_VALUE", - ] - - requested_cols = id_cols + COMED_INTERVAL_COLUMNS - - lf_long = ( - lf.select(requested_cols) - .unpivot( - index=id_cols, - on=COMED_INTERVAL_COLUMNS, - variable_name="interval_col", - value_name="kwh", - ) - .filter(pl.col("kwh").is_not_null()) - .with_columns(pl.col("interval_col").str.extract(r"HR(\d{4})", 1).alias("time_str")) - .with_columns([ - pl.col(date_col).str.strptime(pl.Date, format="%m/%d/%Y", strict=False).alias("service_date"), - pl.col("time_str").str.slice(0, 2).cast(pl.Int16).alias("hour_raw"), - pl.col("time_str").str.slice(2, 2).cast(pl.Int16).alias("minute"), - ]) - .with_columns([ - (pl.col("hour_raw") // 24).alias("days_offset"), - (pl.col("hour_raw") % 24).alias("hour"), - ]) - .with_columns([ - ( - pl.col("service_date").cast(pl.Datetime) - + pl.duration(days=pl.col("days_offset"), hours=pl.col("hour"), minutes=pl.col("minute")) - ).alias("datetime") - ]) - .select([ - pl.col("ZIP_CODE").alias("zip_code"), - pl.col("DELIVERY_SERVICE_CLASS").alias("delivery_service_class"), - pl.col("DELIVERY_SERVICE_NAME").alias("delivery_service_name"), - pl.col("ACCOUNT_IDENTIFIER").alias("account_identifier"), - pl.col("datetime"), - pl.col("kwh").cast(pl.Float64), - ]) - ) - return lf_long + if max_files is not None: + keys = keys[: int(max_files)] + if not keys: + raise ValueError(ERR_NO_FILES_FOUND.format(year_month)) -def add_time_columns_lazy(lf: pl.LazyFrame, day_mode: str = "calendar") -> pl.LazyFrame: - """ - Add derived time columns and day-attribution flags. 
- - Day attribution modes: - - "calendar": date == datetime.date() - - "billing": date == datetime.date(), except 00:00 rows are shifted to the previous date - (so HR2400 is attributed to the prior day). - """ - from datetime import date as _date - - if day_mode not in {"calendar", "billing"}: - raise ValueError("day_mode must be 'calendar' or 'billing'") - - DST_SPRING_2023 = _date(2023, 3, 12) - DST_FALL_2023 = _date(2023, 11, 5) + logger.info("Found %d CSV files (after limit)", len(keys)) + return keys - dt = pl.col("datetime") - if day_mode == "calendar": - date_expr = dt.dt.date() - else: - date_expr = ( - pl.when((dt.dt.hour() == 0) & (dt.dt.minute() == 0)) - .then((dt - pl.duration(days=1)).dt.date()) - .otherwise(dt.dt.date()) - ) - - weekday_expr = pl.col("date").dt.weekday() # Polars: Mon=1 ... Sun=7 - - return ( - lf.with_columns([ - date_expr.alias("date"), - dt.dt.hour().alias("hour"), - ]) - .with_columns([ - weekday_expr.alias("weekday"), - (weekday_expr >= 6).alias("is_weekend"), - ]) - .with_columns([ - (pl.col("date") == DST_SPRING_2023).alias("is_spring_forward_day"), - (pl.col("date") == DST_FALL_2023).alias("is_fall_back_day"), - ((pl.col("date") == DST_SPRING_2023) | (pl.col("date") == DST_FALL_2023)).alias("is_dst_day"), - ]) - ) +def _write_manifest_line(fp: Any, record: dict[str, Any]) -> None: + fp.write(json.dumps(record, sort_keys=True) + "\n") + fp.flush() -def process_month_batch( - year_month: str, - output_path: Path, - max_files: int | None = None, +def _validate_manifest(manifest_path: Path) -> bool: + """Best-effort validation that the manifest is valid JSONL and contains required fields. + Non-fatal: returns False on any issue. + """ + required_fields = {"s3_key", "status", "timestamp"} + try: + with manifest_path.open("r", encoding="utf-8") as f: + for i, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + rec = json.loads(line) + if not required_fields.issubset(rec.keys()): + raise ValueError(f"Line {i} missing required fields: {required_fields - set(rec.keys())}") + return True + except Exception as exc: + logger.error("Invalid manifest JSONL at %s: %s", manifest_path, exc) + return False + + +def _local_path_for_key(*, key: str, output_dir: Path, prefix: str) -> Path: + """Map an S3 key to a deterministic local path under output_dir. + + We preserve the key's structure beneath the configured prefix to avoid filename collisions. + Example: + key: sharepoint-files/Zip4/202307/foo.csv + prefix: sharepoint-files/Zip4/ + local: /202307/foo.csv + """ + key_path = Path(key) + prefix_path = Path(prefix.rstrip("/")) + + try: + rel = key_path.relative_to(prefix_path) + except ValueError: + # Fallback if key does not start with prefix as a Path. + rel = key_path + + return output_dir / rel + + +def _is_existing_nonempty(path: Path) -> tuple[bool, int | None]: + """Return (exists_and_nonempty, size_bytes_if_known).""" + try: + if path.exists(): + size = path.stat().st_size + return (size > 0, size) + return (False, None) + except OSError: + return (False, None) + + +def download_s3_batch( + *, + s3_keys: list[str], + output_dir: Path, + manifest_path: Path, bucket: str = S3_BUCKET, prefix: str = S3_PREFIX, - sort_output: bool = False, - day_mode: str = "calendar", -) -> None: - """ - Process all CSVs for a month and save as a single Parquet file. 
+ max_files: int | None = None, + fail_fast: bool = True, + max_errors: int = 10, + retries: int = 3, + backoff_factor: float = 2.0, + log_every: int = 100, + overwrite_manifest: bool = False, + skip_existing: bool = True, +) -> dict[str, Any]: + """Download S3 keys to local directory with manifest tracking. Args: - year_month: Month in 'YYYYMM' format. - output_path: Path to the Parquet file to write. - max_files: Optional limit for testing. - bucket: S3 bucket name. - prefix: S3 prefix path. - sort_output: Whether to sort by datetime before writing. - day_mode: Day attribution mode ('calendar' or 'billing'). - """ - logger.info(f"Processing month: {year_month}") + s3_keys: S3 keys (NOT URIs). Example: "sharepoint-files/Zip4/202307/file.csv" + output_dir: Local directory to write downloaded CSVs + manifest_path: JSONL manifest path (append by default; optionally overwritten) + bucket: S3 bucket + prefix: S3 prefix used to compute deterministic local paths + max_files: Optional cap (applied after s3_keys is already deterministic/sorted upstream) + fail_fast: Stop immediately on first error + max_errors: Allowed errors before aborting when fail_fast=False + retries: Number of retry attempts per file (in addition to the initial attempt) + backoff_factor: Exponential backoff multiplier (seconds: 1, 2, 4, ...) + log_every: Progress logging interval + overwrite_manifest: If True, truncate and rewrite manifest (clean-room rerun) + skip_existing: If True, skip download when local_path exists and is non-empty - s3_uris = list_s3_files(year_month, bucket, prefix, max_files) - if not s3_uris: - raise ValueError(ERR_NO_FILES_FOUND.format(year_month)) - - logger.info(f"Scanning {len(s3_uris)} files lazily...") + Returns: + dict with keys: downloaded, failed, skipped, manifest_path + """ + if max_files is not None: + s3_keys = s3_keys[: int(max_files)] - lazy_frames: list[pl.LazyFrame] = [] - for i, s3_uri in enumerate(s3_uris, 1): - filename = s3_uri.split("/")[-1] - logger.debug(f"Scanning {i}/{len(s3_uris)}: {filename}") - try: - lf = scan_single_csv_lazy(s3_uri, day_mode=day_mode) - lazy_frames.append(lf) - except Exception: - logger.exception(f"Failed to scan {s3_uri}") - continue + output_dir.mkdir(parents=True, exist_ok=True) + manifest_path.parent.mkdir(parents=True, exist_ok=True) - if not lazy_frames: - raise ValueError(ERR_NO_SUCCESSFUL_PROCESS) + s3 = boto3.client("s3") - logger.info(f"Concatenating {len(lazy_frames)} lazy frames...") - lf_combined = pl.concat(lazy_frames, how="diagonal_relaxed") + downloaded = 0 + failed = 0 + skipped = 0 - if sort_output: - logger.info("Sorting by datetime (this will materialize data)...") - lf_combined = lf_combined.sort("datetime") + mode = "w" if overwrite_manifest else "a" + if overwrite_manifest: + logger.info("Overwriting manifest (clean run): %s", manifest_path) + else: + if manifest_path.exists(): + logger.info("Appending to existing manifest (resume mode): %s", manifest_path) + else: + logger.info("Creating new manifest: %s", manifest_path) + + with manifest_path.open(mode, encoding="utf-8") as mf: + for i, key in enumerate(s3_keys, 1): + if log_every > 0 and (i == 1 or i % log_every == 0 or i == len(s3_keys)): + logger.info("Downloading %d/%d", i, len(s3_keys)) + + local_path = _local_path_for_key(key=key, output_dir=output_dir, prefix=prefix) + local_path.parent.mkdir(parents=True, exist_ok=True) + + s3_uri = f"s3://{bucket}/{key}" + ts = _utc_now_iso() + + if skip_existing: + exists_nonempty, size_bytes = _is_existing_nonempty(local_path) + if 
exists_nonempty: + _write_manifest_line( + mf, + { + "s3_key": key, + "s3_uri": s3_uri, + "local_path": str(local_path), + "status": "skipped_exists", + "size_bytes": size_bytes, + "timestamp": ts, + "attempt": 0, + }, + ) + skipped += 1 + continue + + attempt = 0 + last_exc: Exception | None = None + while attempt <= retries: + try: + attempt += 1 + s3.download_file(bucket, key, str(local_path)) + + size_bytes = None + try: + size_bytes = local_path.stat().st_size + except OSError: + size_bytes = None + + _write_manifest_line( + mf, + { + "s3_key": key, + "s3_uri": s3_uri, + "local_path": str(local_path), + "status": "success", + "size_bytes": size_bytes, + "timestamp": ts, + "attempt": attempt, + }, + ) + downloaded += 1 + last_exc = None + break + + except (ClientError, BotoCoreError, OSError) as exc: + last_exc = exc + if attempt > retries: + break + + sleep_s = float(backoff_factor) ** float(attempt - 1) + logger.warning( + "Download failed (attempt %d/%d) for %s: %s; backing off %.1fs", + attempt, + retries + 1, + key, + type(exc).__name__, + sleep_s, + ) + time.sleep(sleep_s) + + if last_exc is not None: + failed += 1 + _write_manifest_line( + mf, + { + "s3_key": key, + "s3_uri": s3_uri, + "local_path": str(local_path), + "status": "error", + "error": f"{type(last_exc).__name__}: {last_exc}", + "timestamp": ts, + "attempt": attempt, + }, + ) + + msg = f"Failed to download {key} after {attempt} attempt(s): {type(last_exc).__name__}: {last_exc}" + if fail_fast: + raise RuntimeError(msg) from last_exc + + if failed > max_errors: + raise RuntimeError( + f"Exceeded max_errors={max_errors} during S3 download. " + f"Downloaded={downloaded} Skipped={skipped} Failed={failed}. Last error: {msg}", + ) from last_exc + + logger.info( + "Download complete. 
Success=%d Skipped=%d Failed=%d Manifest=%s", + downloaded, + skipped, + failed, + manifest_path, + ) - output_path.parent.mkdir(parents=True, exist_ok=True) + if not _validate_manifest(manifest_path): + logger.warning("Manifest validation failed (non-fatal): %s", manifest_path) - logger.info("Collecting and writing to Parquet (this is where execution happens)...") - lf_combined.sink_parquet(output_path) - logger.info(f"Successfully wrote data to {output_path}") + return {"downloaded": downloaded, "skipped": skipped, "failed": failed, "manifest_path": str(manifest_path)} def main() -> None: - """Command-line entry point for processing ComEd smart meter data from S3.""" - import sys - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + parser = argparse.ArgumentParser(description="Deterministically list and download ComEd CSVs from S3.") + parser.add_argument("year_month", help="Month in YYYYMM format (e.g., 202307)") + parser.add_argument("--bucket", default=S3_BUCKET) + parser.add_argument("--prefix", default=S3_PREFIX) + parser.add_argument("--max-files", type=int, default=None, help="Limit number of files (testing)") + parser.add_argument("--output-dir", type=Path, default=Path("data/runs/manual/raw")) + parser.add_argument("--manifest", type=Path, default=Path("data/runs/manual/download_manifest.jsonl")) + parser.add_argument( + "--overwrite-manifest", + action="store_true", + help="Overwrite (truncate) the manifest instead of appending (clean-room rerun).", + ) + parser.add_argument( + "--no-skip-existing", + action="store_true", + help="Do not skip downloads for existing non-empty local files.", + ) + parser.add_argument("--no-fail-fast", action="store_true", help="Allow errors up to --max-errors") + parser.add_argument("--max-errors", type=int, default=10) + parser.add_argument("--retries", type=int, default=3) + parser.add_argument("--backoff-factor", type=float, default=2.0) + parser.add_argument("--log-every", type=int, default=100) + + args = parser.parse_args() + _validate_year_month(args.year_month) + + keys = list_s3_files( + args.year_month, + bucket=args.bucket, + prefix=args.prefix, + max_files=args.max_files, ) - if len(sys.argv) < 2: - print("Usage: python -m smart_meter_analysis.aws_loader YYYYMM [max_files] [calendar|billing]") - sys.exit(1) - - year_month = sys.argv[1] - max_files = int(sys.argv[2]) if len(sys.argv) > 2 and sys.argv[2].isdigit() else None - day_mode = sys.argv[3] if len(sys.argv) > 3 else "calendar" - - if day_mode not in {"calendar", "billing"}: - print("Error: day_mode must be either 'calendar' or 'billing'") - sys.exit(1) - - output_path = Path(f"data/processed/comed_{year_month}.parquet") - - process_month_batch( - year_month=year_month, - output_path=output_path, - max_files=max_files, - sort_output=False, - day_mode=day_mode, + download_s3_batch( + s3_keys=keys, + output_dir=args.output_dir, + manifest_path=args.manifest, + bucket=args.bucket, + prefix=args.prefix, + max_files=args.max_files, + fail_fast=not args.no_fail_fast, + max_errors=args.max_errors, + retries=args.retries, + backoff_factor=args.backoff_factor, + log_every=args.log_every, + overwrite_manifest=bool(args.overwrite_manifest), + skip_existing=not bool(args.no_skip_existing), ) - logger.info(f"✓ Successfully processed {year_month} in '{day_mode}' mode") + +if __name__ == "__main__": + main() diff --git a/smart_meter_analysis/census.py 
b/smart_meter_analysis/census.py index dc9ef79..47f1ff2 100644 --- a/smart_meter_analysis/census.py +++ b/smart_meter_analysis/census.py @@ -1,6 +1,5 @@ # smart_meter_analysis/census.py -""" -Census data fetcher for demographic analysis. +"""Census data fetcher for demographic analysis. Fetches: - ACS 5-year (detailed tables) data at Block Group level via the official Census API @@ -42,8 +41,7 @@ def chunk_list(items: list[Any], size: int) -> list[list[Any]]: def build_geoid(df: pl.DataFrame) -> pl.DataFrame: - """ - Build 12-digit Block Group GEOID from components. + """Build 12-digit Block Group GEOID from components. GEOID format: State(2) + County(3) + Tract(6) + BlockGroup(1) = 12 digits """ @@ -53,7 +51,7 @@ def build_geoid(df: pl.DataFrame) -> pl.DataFrame: pl.col("county").cast(pl.Utf8).str.zfill(3), pl.col("tract").cast(pl.Utf8).str.zfill(6), pl.col("block group").cast(pl.Utf8).str.zfill(1), - ]).alias("GEOID") + ]).alias("GEOID"), ) @@ -78,8 +76,7 @@ def clean_census_values(df: pl.DataFrame) -> pl.DataFrame: # Spec-driven helpers # ----------------------------------------------------------------------------- def extract_raw_codes(expr: str | None) -> set[str]: - """ - Extract ACS raw variable codes (e.g., B25040_002E) from a spec string. + """Extract ACS raw variable codes (e.g., B25040_002E) from a spec string. Supports expressions like: "B25070_007E + B25070_008E" @@ -90,8 +87,7 @@ def extract_raw_codes(expr: str | None) -> set[str]: def gather_acs_codes(variable_specs: list[dict[str, Any]]) -> list[str]: - """ - Gather the full set of ACS variable codes required to compute VARIABLE_SPECS. + """Gather the full set of ACS variable codes required to compute VARIABLE_SPECS. We parse: - spec["numerator"] @@ -107,8 +103,7 @@ def gather_acs_codes(variable_specs: list[dict[str, Any]]) -> list[str]: def expr_to_polars(expr: str) -> pl.Expr: - """ - Convert a simple linear expression into a Polars expression. + """Convert a simple linear expression into a Polars expression. Supported grammar: - "A" @@ -117,6 +112,7 @@ def expr_to_polars(expr: str) -> pl.Expr: Notes: - Inputs are expected to be ACS codes present in the DataFrame. - Each term is cast to Float64 to avoid integer division issues. + """ parts = [p.strip() for p in expr.split("+")] out = pl.lit(0.0) @@ -128,8 +124,7 @@ def expr_to_polars(expr: str) -> pl.Expr: def is_acs_only_spec(spec: dict[str, Any]) -> bool: - """ - Check if a spec is ACS-only (excludes Decennial H2_*N variables). + """Check if a spec is ACS-only (excludes Decennial H2_*N variables). Returns True if the spec does not reference any H2_*N variables in its numerator or denominator expressions. @@ -144,9 +139,9 @@ def is_acs_only_spec(spec: dict[str, Any]) -> bool: def build_feature_columns(variable_specs: list[dict[str, Any]]) -> list[pl.Expr]: - """ - Build Polars expressions for engineered features defined by VARIABLE_SPECS. + """Build Polars expressions for engineered features defined by VARIABLE_SPECS. + - If "variable" exists: use it directly (single ACS code) - If denominator exists: compute percent = numerator / denominator * 100 (safe_percent). - If denominator is None: use numerator directly. 
- transformation: @@ -157,20 +152,27 @@ def build_feature_columns(variable_specs: list[dict[str, Any]]) -> list[pl.Expr] for spec in variable_specs: name = spec["name"] + variable = spec.get("variable") # Single variable (no ratio) numer = spec.get("numerator") denom = spec.get("denominator") transform = spec.get("transformation", "none") - if not numer: - raise ValueError(f"Missing numerator for spec: {name}") - - if denom: + # Determine the base value expression + if variable: + # Single variable (e.g., median_household_income) + value = expr_to_polars(variable) + elif numer: + # Ratio (e.g., unemployment_rate, pct_owner_occupied) numer_expr = expr_to_polars(numer) - denom_expr = expr_to_polars(denom) - value = safe_percent(numer_expr, denom_expr) + if denom: + denom_expr = expr_to_polars(denom) + value = safe_percent(numer_expr, denom_expr) + else: + value = numer_expr else: - value = expr_to_polars(numer) + raise ValueError(f"Spec '{name}' must have either 'variable' or 'numerator'") + # Apply transformation if transform == "log": value = pl.when(value.is_not_null() & (value >= 0)).then((value + 1.0).log()).otherwise(None) elif transform == "none": @@ -186,14 +188,13 @@ def build_feature_columns(variable_specs: list[dict[str, Any]]) -> list[pl.Expr] # ----------------------------------------------------------------------------- # Fetchers # ----------------------------------------------------------------------------- -def fetch_acs_data( # noqa: C901 +def fetch_acs_data( state_fips: str = "17", year: int = 2023, county_fips: str | None = None, keep_raw_debug_cols: list[str] | None = None, ) -> pl.DataFrame: - """ - Fetch ACS 5-Year *detailed table* raw variables needed for VARIABLE_SPECS at Block Group level, + """Fetch ACS 5-Year *detailed table* raw variables needed for VARIABLE_SPECS at Block Group level, then compute engineered features. Args: @@ -210,6 +211,7 @@ def fetch_acs_data( # noqa: C901 - GEOID, NAME - engineered ACS features defined by VARIABLE_SPECS (ACS-only specs) - optionally, raw ACS variables specified in keep_raw_debug_cols + """ suffix = f", county {county_fips}" if county_fips else "" logger.info(f"Fetching ACS {year} detailed-table data for state {state_fips}{suffix}") @@ -254,7 +256,7 @@ def fetch_acs_data( # noqa: C901 response_body = resp.text[:500] if resp.text else str(e) raise RuntimeError( f"ACS API chunk {i + 1}/{len(code_chunks)} failed with HTTP {status_code}.\n" - f"Response body (first 500 chars): {response_body}" + f"Response body (first 500 chars): {response_body}", ) from e except Exception as e: raise RuntimeError(f"ACS API chunk {i + 1}/{len(code_chunks)} failed with error: {e}") from e @@ -271,7 +273,7 @@ def fetch_acs_data( # noqa: C901 if missing_in_response: logger.warning( f" Chunk {i + 1}: Variables requested but not in API response: {missing_in_response[:5]}" - + (f" (and {len(missing_in_response) - 5} more)" if len(missing_in_response) > 5 else "") + + (f" (and {len(missing_in_response) - 5} more)" if len(missing_in_response) > 5 else ""), ) chunk_df = pl.DataFrame(rows, schema=header, orient="row") @@ -332,8 +334,7 @@ def fetch_acs_data( # noqa: C901 def fetch_decennial_data(state_fips: str = "17", year: int = 2020) -> pl.DataFrame: - """ - Fetch Decennial Census DHC H2 (Urban and Rural) at Block Group level, then compute urban_percent. + """Fetch Decennial Census DHC H2 (Urban and Rural) at Block Group level, then compute urban_percent. 
Args: state_fips: State FIPS code (default '17' for Illinois) @@ -345,6 +346,7 @@ def fetch_decennial_data(state_fips: str = "17", year: int = 2020) -> pl.DataFra - Urban_Housing_Units - Rural_Housing_Units - urban_percent + """ logger.info(f"Fetching Decennial {year} DHC H2 data for state {state_fips}") @@ -384,7 +386,7 @@ def fetch_decennial_data(state_fips: str = "17", year: int = 2020) -> pl.DataFra safe_percent( pl.col("Urban_Housing_Units"), pl.col("Urban_Housing_Units") + pl.col("Rural_Housing_Units"), - ).alias("urban_percent") + ).alias("urban_percent"), ]) return df.select(["GEOID", "Urban_Housing_Units", "Rural_Housing_Units", "urban_percent"]) @@ -398,8 +400,7 @@ def fetch_census_data( output_path: Path | str | None = None, keep_raw_debug_cols: list[str] | None = None, ) -> pl.DataFrame: - """ - Fetch and combine ACS (engineered features) and Decennial (urban_percent) at Block Group level. + """Fetch and combine ACS (engineered features) and Decennial (urban_percent) at Block Group level. Args: state_fips: State FIPS code (default: '17' for Illinois) @@ -415,6 +416,7 @@ def fetch_census_data( Returns: Combined DataFrame with all engineered ACS features + urban_percent by block group. Optionally includes raw ACS variables specified in keep_raw_debug_cols. + """ acs_df = fetch_acs_data( state_fips=state_fips, @@ -445,11 +447,11 @@ def fetch_census_data( def validate_census_data(df: pl.DataFrame) -> dict[str, Any]: - """ - Validate census data quality. + """Validate census data quality. Returns: Dict with validation metrics. + """ # Percent columns are conventionally named with "pct_" or explicitly "urban_percent" pct_cols = [c for c in df.columns if c.startswith("pct_")] + ( diff --git a/smart_meter_analysis/census_specs.py b/smart_meter_analysis/census_specs.py index 4f76c7a..c617df2 100644 --- a/smart_meter_analysis/census_specs.py +++ b/smart_meter_analysis/census_specs.py @@ -1,13 +1,9 @@ # smart_meter_analysis/census_specs.py -""" -Census variable registry + stable Stage 2 predictor list. +"""Census variable registry + stable Stage 2 predictor list. - VARIABLE_SPECS is the spec-driven registry used by census.py to decide which ACS/Decennial variables to request and how to engineer features. - STAGE2_PREDICTORS_47 is the *stable* final predictor list used by Stage 2 regression. - -Note: Some predictors in STAGE2_PREDICTORS_47 may be engineered composites (e.g., multifamily_10_plus) -and therefore may not correspond 1:1 to a single ACS code. """ from __future__ import annotations @@ -15,7 +11,7 @@ from typing import Any # ----------------------------------------------------------------------------- -# Stable Stage 2 predictor list (exactly what your run log shows as "Using 47 predictors") +# Stable Stage 2 predictor list (47 predictors) # ----------------------------------------------------------------------------- STAGE2_PREDICTORS_47: list[str] = [ "avg_family_size", @@ -67,65 +63,398 @@ "urban_percent", ] - # ----------------------------------------------------------------------------- -# Spec registry used by census.py for fetching/engineering. 
+# Complete spec registry with ACS variable codes # ----------------------------------------------------------------------------- - VARIABLE_SPECS: list[dict[str, Any]] = [ - # --- Spatial --- - {"name": "urban_percent", "category": "spatial", "source": "acs"}, - # --- Economic --- - {"name": "median_household_income", "category": "economic", "source": "acs"}, - {"name": "unemployment_rate", "category": "economic", "source": "acs"}, - {"name": "pct_in_civilian_labor_force", "category": "economic", "source": "acs"}, - {"name": "pct_not_in_labor_force", "category": "economic", "source": "acs"}, - {"name": "pct_income_under_25k", "category": "economic", "source": "acs"}, - {"name": "pct_income_25k_to_75k", "category": "economic", "source": "acs"}, - {"name": "pct_income_75k_plus", "category": "economic", "source": "acs"}, - # --- Housing --- - {"name": "pct_owner_occupied", "category": "housing", "source": "acs"}, - {"name": "pct_renter_occupied", "category": "housing", "source": "acs"}, - {"name": "pct_heat_utility_gas", "category": "housing", "source": "acs"}, - {"name": "pct_heat_electric", "category": "housing", "source": "acs"}, - {"name": "pct_housing_built_2000_plus", "category": "housing", "source": "acs"}, - {"name": "pct_housing_built_1980_1999", "category": "housing", "source": "acs"}, - {"name": "old_building_pct", "category": "housing", "source": "acs"}, - {"name": "pct_structure_single_family_detached", "category": "housing", "source": "acs"}, - {"name": "pct_structure_single_family_attached", "category": "housing", "source": "acs"}, - {"name": "pct_structure_multifamily_2_to_4", "category": "housing", "source": "acs"}, - {"name": "pct_structure_multifamily_5_to_19", "category": "housing", "source": "acs"}, - {"name": "pct_structure_multifamily_20_plus", "category": "housing", "source": "acs"}, - {"name": "pct_structure_multifamily_10_plus", "category": "housing", "source": "acs"}, - {"name": "pct_structure_mobile_home", "category": "housing", "source": "acs"}, - {"name": "pct_vacant_housing_units", "category": "housing", "source": "acs"}, - {"name": "pct_home_value_under_150k", "category": "housing", "source": "acs"}, - {"name": "pct_home_value_150k_to_299k", "category": "housing", "source": "acs"}, - {"name": "pct_home_value_300k_plus", "category": "housing", "source": "acs"}, - {"name": "pct_rent_burden_30_plus", "category": "housing", "source": "acs"}, - {"name": "pct_rent_burden_50_plus", "category": "housing", "source": "acs"}, - {"name": "pct_owner_cost_burden_30_plus_mortgage", "category": "housing", "source": "acs"}, - {"name": "pct_owner_cost_burden_50_plus_mortgage", "category": "housing", "source": "acs"}, - {"name": "pct_owner_overcrowded_2plus_per_room", "category": "housing", "source": "acs"}, - {"name": "pct_renter_overcrowded_2plus_per_room", "category": "housing", "source": "acs"}, - # --- Household --- - {"name": "avg_household_size", "category": "household", "source": "acs"}, - {"name": "avg_family_size", "category": "household", "source": "acs"}, - {"name": "pct_single_parent_households", "category": "household", "source": "acs"}, - # --- Demographic --- - {"name": "median_age", "category": "demographic", "source": "acs"}, - {"name": "pct_white_alone", "category": "demographic", "source": "acs"}, - {"name": "pct_black_alone", "category": "demographic", "source": "acs"}, - {"name": "pct_asian_alone", "category": "demographic", "source": "acs"}, - {"name": "pct_two_or_more_races", "category": "demographic", "source": "acs"}, - {"name": 
"pct_population_under_5", "category": "demographic", "source": "acs"}, - {"name": "pct_population_5_to_17", "category": "demographic", "source": "acs"}, - {"name": "pct_population_18_to_24", "category": "demographic", "source": "acs"}, - {"name": "pct_population_25_to_44", "category": "demographic", "source": "acs"}, - {"name": "pct_population_45_to_64", "category": "demographic", "source": "acs"}, - {"name": "pct_population_65_plus", "category": "demographic", "source": "acs"}, - {"name": "pct_female", "category": "demographic", "source": "acs"}, + # ------------------------------------------------------------------------- + # ECONOMIC (8 variables) + # ------------------------------------------------------------------------- + { + "name": "median_household_income", + "category": "economic", + "source": "acs", + "variable": "B19013_001E", + "transformation": "none", + }, + { + "name": "unemployment_rate", + "category": "economic", + "source": "acs", + "numerator": "B23025_005E", # Unemployed + "denominator": "B23025_003E", # In civilian labor force + "transformation": "none", + }, + { + "name": "pct_in_civilian_labor_force", + "category": "economic", + "source": "acs", + "numerator": "B23025_003E", # In civilian labor force + "denominator": "B23025_002E", # Population 16+ in labor force + "transformation": "none", + }, + { + "name": "pct_not_in_labor_force", + "category": "economic", + "source": "acs", + "numerator": "B23025_007E", # Not in labor force + "denominator": "B23025_001E", # Population 16+ + "transformation": "none", + }, + { + "name": "pct_income_under_25k", + "category": "economic", + "source": "acs", + "numerator": "B19001_002E + B19001_003E + B19001_004E + B19001_005E", + "denominator": "B19001_001E", + "transformation": "none", + }, + { + "name": "pct_income_25k_to_75k", + "category": "economic", + "source": "acs", + "numerator": "B19001_006E + B19001_007E + B19001_008E + B19001_009E + B19001_010E + B19001_011E + B19001_012E", + "denominator": "B19001_001E", + "transformation": "none", + }, + { + "name": "pct_income_75k_plus", + "category": "economic", + "source": "acs", + "numerator": "B19001_013E + B19001_014E + B19001_015E + B19001_016E + B19001_017E", + "denominator": "B19001_001E", + "transformation": "none", + }, + # ------------------------------------------------------------------------- + # HOUSING (25 variables) + # ------------------------------------------------------------------------- + { + "name": "pct_owner_occupied", + "category": "housing", + "source": "acs", + "numerator": "B25003_002E", # Owner occupied + "denominator": "B25003_001E", # Total occupied + "transformation": "none", + }, + { + "name": "pct_renter_occupied", + "category": "housing", + "source": "acs", + "numerator": "B25003_003E", # Renter occupied + "denominator": "B25003_001E", # Total occupied + "transformation": "none", + }, + { + "name": "pct_vacant_housing_units", + "category": "housing", + "source": "acs", + "numerator": "B25002_003E", # Vacant + "denominator": "B25002_001E", # Total housing units + "transformation": "none", + }, + { + "name": "pct_heat_utility_gas", + "category": "housing", + "source": "acs", + "numerator": "B25040_002E", # Utility gas + "denominator": "B25040_001E", # Total occupied units + "transformation": "none", + }, + { + "name": "pct_heat_electric", + "category": "housing", + "source": "acs", + "numerator": "B25040_005E", # Electricity + "denominator": "B25040_001E", # Total occupied units + "transformation": "none", + }, + { + "name": 
"pct_housing_built_2000_plus", + "category": "housing", + "source": "acs", + "numerator": "B25034_002E + B25034_003E + B25034_004E + B25034_005E", # 2000-2009, 2010-2013, 2014-2017, 2018-2020+ + "denominator": "B25034_001E", + "transformation": "none", + }, + { + "name": "pct_housing_built_1980_1999", + "category": "housing", + "source": "acs", + "numerator": "B25034_006E + B25034_007E", # 1990-1999, 1980-1989 + "denominator": "B25034_001E", + "transformation": "none", + }, + { + "name": "old_building_pct", + "category": "housing", + "source": "acs", + "numerator": "B25034_008E + B25034_009E + B25034_010E + B25034_011E", # Pre-1980 + "denominator": "B25034_001E", + "transformation": "none", + }, + { + "name": "pct_structure_single_family_detached", + "category": "housing", + "source": "acs", + "numerator": "B25024_002E", # 1-unit detached + "denominator": "B25024_001E", + "transformation": "none", + }, + { + "name": "pct_structure_single_family_attached", + "category": "housing", + "source": "acs", + "numerator": "B25024_003E", # 1-unit attached + "denominator": "B25024_001E", + "transformation": "none", + }, + { + "name": "pct_structure_multifamily_2_to_4", + "category": "housing", + "source": "acs", + "numerator": "B25024_004E + B25024_005E", # 2 units + 3-4 units + "denominator": "B25024_001E", + "transformation": "none", + }, + { + "name": "pct_structure_multifamily_5_to_19", + "category": "housing", + "source": "acs", + "numerator": "B25024_006E + B25024_007E", # 5-9 units + 10-19 units + "denominator": "B25024_001E", + "transformation": "none", + }, + { + "name": "pct_structure_multifamily_20_plus", + "category": "housing", + "source": "acs", + "numerator": "B25024_008E + B25024_009E", # 20-49 units + 50+ units + "denominator": "B25024_001E", + "transformation": "none", + }, + { + "name": "pct_structure_multifamily_10_plus", + "category": "housing", + "source": "acs", + "numerator": "B25024_007E + B25024_008E + B25024_009E", # 10-19, 20-49, 50+ + "denominator": "B25024_001E", + "transformation": "none", + }, + { + "name": "pct_structure_mobile_home", + "category": "housing", + "source": "acs", + "numerator": "B25024_010E", # Mobile home + "denominator": "B25024_001E", + "transformation": "none", + }, + { + "name": "pct_home_value_under_150k", + "category": "housing", + "source": "acs", + "numerator": "B25075_002E + B25075_003E + B25075_004E + B25075_005E + B25075_006E + B25075_007E + B25075_008E + B25075_009E + B25075_010E + B25075_011E + B25075_012E + B25075_013E + B25075_014E", + "denominator": "B25075_001E", + "transformation": "none", + }, + { + "name": "pct_home_value_150k_to_299k", + "category": "housing", + "source": "acs", + "numerator": "B25075_015E + B25075_016E + B25075_017E + B25075_018E + B25075_019E", + "denominator": "B25075_001E", + "transformation": "none", + }, + { + "name": "pct_home_value_300k_plus", + "category": "housing", + "source": "acs", + "numerator": "B25075_020E + B25075_021E + B25075_022E + B25075_023E + B25075_024E + B25075_025E + B25075_026E + B25075_027E", + "denominator": "B25075_001E", + "transformation": "none", + }, + { + "name": "pct_rent_burden_30_plus", + "category": "housing", + "source": "acs", + "numerator": "B25070_007E + B25070_008E + B25070_009E + B25070_010E", # 30-34.9%, 35-39.9%, 40-49.9%, 50%+ + "denominator": "B25070_001E", + "transformation": "none", + }, + { + "name": "pct_rent_burden_50_plus", + "category": "housing", + "source": "acs", + "numerator": "B25070_010E", # 50%+ + "denominator": "B25070_001E", + "transformation": 
"none", + }, + { + "name": "pct_owner_cost_burden_30_plus_mortgage", + "category": "housing", + "source": "acs", + "numerator": "B25091_008E + B25091_009E + B25091_010E + B25091_011E", # With mortgage: 30-34.9%, 35-39.9%, 40-49.9%, 50%+ + "denominator": "B25091_002E", # With mortgage total + "transformation": "none", + }, + { + "name": "pct_owner_cost_burden_50_plus_mortgage", + "category": "housing", + "source": "acs", + "numerator": "B25091_011E", # With mortgage: 50%+ + "denominator": "B25091_002E", + "transformation": "none", + }, + { + "name": "pct_owner_overcrowded_2plus_per_room", + "category": "housing", + "source": "acs", + "numerator": "B25014_007E", # Owner: 2+ persons per room + "denominator": "B25014_002E", # Owner total + "transformation": "none", + }, + { + "name": "pct_renter_overcrowded_2plus_per_room", + "category": "housing", + "source": "acs", + "numerator": "B25014_013E", # Renter: 2+ persons per room + "denominator": "B25014_008E", # Renter total + "transformation": "none", + }, + # ------------------------------------------------------------------------- + # HOUSEHOLD (3 variables) + # ------------------------------------------------------------------------- + { + "name": "avg_household_size", + "category": "household", + "source": "acs", + "variable": "B25010_001E", + "transformation": "none", + }, + { + "name": "avg_family_size", + "category": "household", + "source": "acs", + "variable": "B25010_002E", + "transformation": "none", + }, + { + "name": "pct_single_parent_households", + "category": "household", + "source": "acs", + "numerator": "B11001_006E + B11001_007E", # Male householder + Female householder, no spouse + "denominator": "B11001_001E", # Total households + "transformation": "none", + }, + # ------------------------------------------------------------------------- + # DEMOGRAPHIC (10 variables) + # ------------------------------------------------------------------------- + { + "name": "median_age", + "category": "demographic", + "source": "acs", + "variable": "B01002_001E", + "transformation": "none", + }, + { + "name": "pct_female", + "category": "demographic", + "source": "acs", + "numerator": "B01001_026E", # Female total + "denominator": "B01001_001E", # Total population + "transformation": "none", + }, + { + "name": "pct_white_alone", + "category": "demographic", + "source": "acs", + "numerator": "B03002_003E", # White alone, not Hispanic/Latino + "denominator": "B03002_001E", + "transformation": "none", + }, + { + "name": "pct_black_alone", + "category": "demographic", + "source": "acs", + "numerator": "B03002_004E", # Black alone, not Hispanic/Latino + "denominator": "B03002_001E", + "transformation": "none", + }, + { + "name": "pct_asian_alone", + "category": "demographic", + "source": "acs", + "numerator": "B03002_006E", # Asian alone, not Hispanic/Latino + "denominator": "B03002_001E", + "transformation": "none", + }, + { + "name": "pct_two_or_more_races", + "category": "demographic", + "source": "acs", + "numerator": "B03002_009E", # Two or more races, not Hispanic/Latino + "denominator": "B03002_001E", + "transformation": "none", + }, + { + "name": "pct_population_under_5", + "category": "demographic", + "source": "acs", + "numerator": "B01001_003E + B01001_027E", # Male + Female under 5 + "denominator": "B01001_001E", + "transformation": "none", + }, + { + "name": "pct_population_5_to_17", + "category": "demographic", + "source": "acs", + "numerator": "B01001_004E + B01001_005E + B01001_006E + B01001_028E + B01001_029E + B01001_030E", + 
"denominator": "B01001_001E", + "transformation": "none", + }, + { + "name": "pct_population_18_to_24", + "category": "demographic", + "source": "acs", + "numerator": "B01001_007E + B01001_008E + B01001_009E + B01001_010E + B01001_031E + B01001_032E + B01001_033E + B01001_034E", + "denominator": "B01001_001E", + "transformation": "none", + }, + { + "name": "pct_population_25_to_44", + "category": "demographic", + "source": "acs", + "numerator": "B01001_011E + B01001_012E + B01001_013E + B01001_014E + B01001_035E + B01001_036E + B01001_037E + B01001_038E", + "denominator": "B01001_001E", + "transformation": "none", + }, + { + "name": "pct_population_45_to_64", + "category": "demographic", + "source": "acs", + "numerator": "B01001_015E + B01001_016E + B01001_017E + B01001_018E + B01001_019E + B01001_039E + B01001_040E + B01001_041E + B01001_042E + B01001_043E", + "denominator": "B01001_001E", + "transformation": "none", + }, + { + "name": "pct_population_65_plus", + "category": "demographic", + "source": "acs", + "numerator": "B01001_020E + B01001_021E + B01001_022E + B01001_023E + B01001_024E + B01001_025E + B01001_044E + B01001_045E + B01001_046E + B01001_047E + B01001_048E + B01001_049E", + "denominator": "B01001_001E", + "transformation": "none", + }, + # ------------------------------------------------------------------------- + # SPATIAL (1 variable - from Decennial, not ACS) + # ------------------------------------------------------------------------- + { + "name": "urban_percent", + "category": "spatial", + "source": "decennial", + "numerator": "H2_002N", # Urban housing units + "denominator": "H2_002N + H2_003N", # Total (urban + rural) + "transformation": "none", + }, ] -# Optional convenience: expose the expected count (useful for asserts/logging) +# Optional convenience: expose the expected count EXPECTED_STAGE2_PREDICTOR_COUNT = len(STAGE2_PREDICTORS_47) diff --git a/smart_meter_analysis/run_manifest.py b/smart_meter_analysis/run_manifest.py new file mode 100644 index 0000000..018400f --- /dev/null +++ b/smart_meter_analysis/run_manifest.py @@ -0,0 +1,202 @@ +# smart_meter_analysis/run_manifest.py +from __future__ import annotations + +import json +import platform +import sys +from collections.abc import Iterable +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path + + +@dataclass(frozen=True) +class Stage2RunManifest: + """Lightweight, reproducible metadata for a Stage 2 run.""" + + created_utc: str + python: str + platform: str + git_commit: str | None + command: str + output_dir: str + + # Inputs + clusters_path: str + crosswalk_path: str + census_cache_path: str + + # Key parameters + baseline_cluster: str | int | None + min_obs_per_bg: int + alpha: float + weight_column: str | None + + # Predictor handling + predictors_total_detected: int | None + predictors_used: list[str] + predictors_excluded_all_null: list[str] + + # Dataset sizes + block_groups_total: int | None + block_groups_after_min_obs: int | None + block_groups_after_drop_null_predictors: int | None + + # Outputs + regression_data_path: str | None + regression_report_path: str | None + run_log_path: str | None + + +def _safe_git_commit(repo_root: Path) -> str | None: + """Best-effort git commit retrieval without depending on GitPython.""" + try: + import subprocess + + r = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=str(repo_root), + check=False, + capture_output=True, + text=True, + ) + if r.returncode == 0: + return r.stdout.strip() or None + return 
None + except Exception: + return None + + +def write_stage2_manifest( + *, + output_dir: str | Path, + command: str, + repo_root: str | Path | None = None, + clusters_path: str | Path, + crosswalk_path: str | Path, + census_cache_path: str | Path, + baseline_cluster: str | int | None, + min_obs_per_bg: int, + alpha: float, + weight_column: str | None, + predictors_detected: int | None, + predictors_used: Iterable[str], + predictors_excluded_all_null: Iterable[str], + block_groups_total: int | None, + block_groups_after_min_obs: int | None, + block_groups_after_drop_null_predictors: int | None, + regression_data_path: str | Path | None, + regression_report_path: str | Path | None, + run_log_path: str | Path | None, +) -> Path: + """Writes: + - stage2_manifest.json : run metadata + - predictors_used.txt : final predictor list (stable across runs) + - predictors_excluded_all_null.txt : excluded predictors with 100% nulls + """ + out = Path(output_dir) + out.mkdir(parents=True, exist_ok=True) + + predictors_used_list = sorted(dict.fromkeys(list(predictors_used))) + excluded_list = sorted(dict.fromkeys(list(predictors_excluded_all_null))) + + # Persist predictor lists (the key stability artifact) + (out / "predictors_used.txt").write_text("\n".join(predictors_used_list) + "\n", encoding="utf-8") + (out / "predictors_excluded_all_null.txt").write_text("\n".join(excluded_list) + "\n", encoding="utf-8") + + # Build manifest + created_utc = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + repo_root_path = Path(repo_root) if repo_root is not None else None + git_commit = _safe_git_commit(repo_root_path) if repo_root_path else None + + manifest = Stage2RunManifest( + created_utc=created_utc, + python=sys.version.replace("\n", " "), + platform=f"{platform.system()} {platform.release()} ({platform.machine()})", + git_commit=git_commit, + command=command, + output_dir=str(out), + clusters_path=str(clusters_path), + crosswalk_path=str(crosswalk_path), + census_cache_path=str(census_cache_path), + baseline_cluster=baseline_cluster, + min_obs_per_bg=min_obs_per_bg, + alpha=alpha, + weight_column=weight_column, + predictors_total_detected=predictors_detected, + predictors_used=predictors_used_list, + predictors_excluded_all_null=excluded_list, + block_groups_total=block_groups_total, + block_groups_after_min_obs=block_groups_after_min_obs, + block_groups_after_drop_null_predictors=block_groups_after_drop_null_predictors, + regression_data_path=str(regression_data_path) if regression_data_path else None, + regression_report_path=str(regression_report_path) if regression_report_path else None, + run_log_path=str(run_log_path) if run_log_path else None, + ) + + manifest_path = out / "stage2_manifest.json" + manifest_path.write_text(json.dumps(manifest.__dict__, indent=2, sort_keys=True) + "\n", encoding="utf-8") + return manifest_path + + +def write_run_manifest( + *, + output_dir: str | Path, + command: str, + repo_root: str | Path | None = None, + run_name: str, + year_month: str, + num_files: int, + sample_days: int, + sample_households: int | None, + day_strategy: str, + k_min: int, + k_max: int, + n_init: int, +) -> Path: + """Write a manifest file for a pipeline run. + + Records all parameters and metadata for reproducibility. + Similar to Stage2RunManifest but for the full pipeline orchestrator. 
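+
+    Illustrative call (paths and parameter values below are placeholders, not defaults):
+        write_run_manifest(
+            output_dir="data/validation_runs/202308_100",
+            command="python run_pipeline.py --year-month 202308 --num-files 100",
+            run_name="202308_100",
+            year_month="202308",
+            num_files=100,
+            sample_days=20,
+            sample_households=None,
+            day_strategy="stratified",
+            k_min=3,
+            k_max=6,
+            n_init=5,
+        )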
+ """ + out = Path(output_dir) + out.mkdir(parents=True, exist_ok=True) + + created_utc = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + repo_root_path = Path(repo_root) if repo_root is not None else None + git_commit = _safe_git_commit(repo_root_path) if repo_root_path else None + + manifest = { + "created_utc": created_utc, + "python": sys.version.replace("\n", " "), + "platform": f"{platform.system()} {platform.release()} ({platform.machine()})", + "git_commit": git_commit, + "command": command, + "run_name": run_name, + "output_dir": str(out), + "parameters": { + "year_month": year_month, + "num_files": num_files, + "sample_days": sample_days, + "sample_households": sample_households, + "day_strategy": day_strategy, + "k_min": k_min, + "k_max": k_max, + "n_init": n_init, + }, + } + + manifest_path = out / "run_manifest.json" + manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n", encoding="utf-8") + return manifest_path + + +def load_persisted_predictors(output_dir: str | Path) -> list[str] | None: + """If `predictors_used.txt` exists, return that list to reuse exactly on a new run. + This is the "stable across runs" option. + """ + p = Path(output_dir) / "predictors_used.txt" + if not p.exists(): + return None + lines = [ln.strip() for ln in p.read_text(encoding="utf-8").splitlines()] + return [ln for ln in lines if ln] diff --git a/smart_meter_analysis/transformation.py b/smart_meter_analysis/transformation.py index 72042f4..97a324b 100644 --- a/smart_meter_analysis/transformation.py +++ b/smart_meter_analysis/transformation.py @@ -1,10 +1,13 @@ # smart_meter_analysis/transformation.py -""" -Data transformation utilities for ComEd smart meter data. +"""Data transformation utilities for ComEd smart meter data. + +Supports both eager (DataFrame) and lazy (LazyFrame) processing paths. -This module supports both eager (local) and lazy (S3) processing paths. -It converts wide-format ComEd CSVs into a normalized long format, -adds timestamps, and applies optional day-attribution modes. +Core behaviors: +- Converts wide-format ComEd interval columns into long format. +- Builds a timestamp per interval using the service date + interval "hour/minute" encoding. +- Preserves ComEd DST anomalies: some days include 24:30 and 25:00 columns (2430, 2500). +- First interval column is expected to be 00:30 (0030), not 00:00. Day attribution modes: - "calendar": 00:00 readings belong to the new day (default) @@ -13,28 +16,35 @@ from __future__ import annotations +from collections.abc import Sequence from datetime import date as _date import polars as pl __all__ = [ + "COMED_INTERVAL_COLUMNS", "add_time_columns", "transform_wide_to_long", + "transform_wide_to_long_lf", ] -# Interval columns pattern — same as in aws_loader -COMED_INTERVAL_COLUMNS = [f"INTERVAL_HR{m // 60:02d}{m % 60:02d}_ENERGY_QTY" for m in range(30, 24 * 60 + 1, 30)] +# Canonical interval columns: +# - Standard: 0030..2400 (48 columns) +# - DST fall-back may add: 2430, 2500 +# We include the DST extensions explicitly; for non-DST days these columns are simply absent. +COMED_INTERVAL_COLUMNS: list[str] = [ + f"INTERVAL_HR{m // 60:02d}{m % 60:02d}_ENERGY_QTY" for m in range(30, 24 * 60 + 1, 30) +] + [ + "INTERVAL_HR2430_ENERGY_QTY", + "INTERVAL_HR2500_ENERGY_QTY", +] -def transform_wide_to_long( - df: pl.DataFrame, - date_col: str = "INTERVAL_READING_DATE", -) -> pl.DataFrame: - """ - Convert wide-format ComEd data into a long-format DataFrame. 
+def _present_interval_cols(cols: Sequence[str]) -> list[str]: + return [c for c in COMED_INTERVAL_COLUMNS if c in cols] - Each row represents a 30-minute interval with a timestamp and energy usage. - """ + +def _present_id_cols(cols: Sequence[str]) -> list[str]: id_cols = [ "ZIP_CODE", "DELIVERY_SERVICE_CLASS", @@ -46,60 +56,132 @@ def transform_wide_to_long( "PLC_VALUE", "NSPL_VALUE", ] + return [c for c in id_cols if c in cols] + + +def _validate_required_cols(cols: Sequence[str], *, date_col: str) -> None: + required = ["ZIP_CODE", "ACCOUNT_IDENTIFIER", date_col] + missing = [c for c in required if c not in cols] + if missing: + raise ValueError(f"Missing required columns for transform_wide_to_long: {missing}") + - requested_cols = [c for c in id_cols + COMED_INTERVAL_COLUMNS if c in df.columns] +def _service_date_expr(date_col: str) -> pl.Expr: + """Normalize INTERVAL_READING_DATE to pl.Date. + + Handles common cases: + - Utf8 like '07/31/2023' + - Date (already correct type) + - Datetime (cast to date) + + Strategy: Try string parsing first, fall back to direct casting. + Works in both eager and lazy contexts. + """ + col = pl.col(date_col) + + # Try parsing as string first (most common case for ComEd CSVs) + # If it fails, the column is likely already Date or Datetime + return ( + col.str.strptime(pl.Date, format="%m/%d/%Y", strict=False) + .fill_null(col.cast(pl.Date, strict=False)) + .alias("service_date") + ) + + +def transform_wide_to_long(df: pl.DataFrame, date_col: str = "INTERVAL_READING_DATE") -> pl.DataFrame: + """Eager wrapper for converting wide-format ComEd data into long format. + + Returns one row per account x interval with: + zip_code, delivery_service_class, delivery_service_name, + account_identifier, datetime, kwh + """ + return transform_wide_to_long_lf(df.lazy(), date_col=date_col).collect() - df_long = ( - df.select(requested_cols) + +def transform_wide_to_long_lf(lf: pl.LazyFrame, date_col: str = "INTERVAL_READING_DATE") -> pl.LazyFrame: + """Convert wide-format ComEd data into long format (lazy). + + Notes: + - Interval columns are detected from the input schema and may include DST extensions + (INTERVAL_HR2430_ENERGY_QTY, INTERVAL_HR2500_ENERGY_QTY). + - First interval is expected to be 0030. 
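+    - Interval-code arithmetic (see hour_raw // 24 below): "0030" maps to 00:30 on the
+      service date, "2400" to 00:00 the next day, "2430" to 00:30 the next day, and
+      "2500" to 01:00 the next day.
+
+    Illustrative lazy usage (path is a placeholder; real callers also pass schema overrides):
+        long_lf = transform_wide_to_long_lf(pl.scan_csv("zip4_sample.csv"))
+        df = long_lf.collect()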
+ + """ + schema = lf.collect_schema() + cols = schema.names() + + _validate_required_cols(cols, date_col=date_col) + + id_cols = _present_id_cols(cols) + interval_cols = _present_interval_cols(cols) + if not interval_cols: + raise ValueError("No ComEd interval columns found (expected INTERVAL_HR####_ENERGY_QTY).") + + # Step 1: Unpivot to long format + out = ( + lf.select(id_cols + interval_cols) .unpivot( index=id_cols, - on=[c for c in COMED_INTERVAL_COLUMNS if c in df.columns], + on=interval_cols, variable_name="interval_col", value_name="kwh", ) .filter(pl.col("kwh").is_not_null()) - .with_columns(pl.col("interval_col").str.extract(r"HR(\d{4})", 1).alias("time_str")) - .with_columns([ - pl.col(date_col).str.strptime(pl.Date, format="%m/%d/%Y", strict=False).alias("service_date"), - pl.col("time_str").str.slice(0, 2).cast(pl.Int16).alias("hour_raw"), - pl.col("time_str").str.slice(2, 2).cast(pl.Int16).alias("minute"), - ]) - .with_columns([ - (pl.col("hour_raw") // 24).alias("days_offset"), - (pl.col("hour_raw") % 24).alias("hour"), - ]) - .with_columns([ - ( - pl.col("service_date").cast(pl.Datetime) - + pl.duration(days=pl.col("days_offset"), hours=pl.col("hour"), minutes=pl.col("minute")) - ).alias("datetime") - ]) - .select([ - pl.col("ZIP_CODE").alias("zip_code"), - pl.col("DELIVERY_SERVICE_CLASS").alias("delivery_service_class"), - pl.col("DELIVERY_SERVICE_NAME").alias("delivery_service_name"), - pl.col("ACCOUNT_IDENTIFIER").alias("account_identifier"), - pl.col("datetime"), - pl.col("kwh").cast(pl.Float64), - ]) ) - return df_long + # Step 2: Add service_date (parsed from date column) + out = out.with_columns([_service_date_expr(date_col).alias("service_date")]) + + # Step 3: Extract time components from interval column name + out = out.with_columns([pl.col("interval_col").str.extract(r"HR(\d{4})", 1).alias("time_str")]) + + # Step 4: Parse hour and minute from time_str + out = out.with_columns([ + pl.col("time_str").str.slice(0, 2).cast(pl.Int16).alias("hour_raw"), + pl.col("time_str").str.slice(2, 2).cast(pl.Int16).alias("minute"), + ]) + + # Step 5: Handle DST extensions (hours 24, 25 -> days_offset) + out = out.with_columns([ + (pl.col("hour_raw") // 24).alias("days_offset"), + (pl.col("hour_raw") % 24).alias("hour"), + ]) + + # Step 6: Build datetime + out = out.with_columns([ + ( + pl.col("service_date").cast(pl.Datetime) + + pl.duration(days=pl.col("days_offset"), hours=pl.col("hour"), minutes=pl.col("minute")) + ).alias("datetime"), + ]) + + # Step 7: Select final columns with proper names + out = out.select([ + pl.col("ZIP_CODE").alias("zip_code"), + pl.col("DELIVERY_SERVICE_CLASS").alias("delivery_service_class"), + pl.col("DELIVERY_SERVICE_NAME").alias("delivery_service_name"), + pl.col("ACCOUNT_IDENTIFIER").alias("account_identifier"), + pl.col("datetime"), + pl.col("kwh").cast(pl.Float64), + ]) + + return out def add_time_columns(df: pl.DataFrame, day_mode: str = "calendar") -> pl.DataFrame: - """ - Add derived time columns and day-attribution flags. + """Add derived time columns and day-attribution flags. Args: df: Polars DataFrame with a 'datetime' column. day_mode: 'calendar' (default) or 'billing' - "calendar": 00:00 belongs to the new day. - - "billing": 00:00 readings assigned to the previous day. + - "billing": 00:00 readings assigned to the previous date. 
+ """ if day_mode not in {"calendar", "billing"}: raise ValueError("day_mode must be 'calendar' or 'billing'") + # 2023-only flags retained for continuity; DST handling is otherwise implicit in interval columns. DST_SPRING_2023 = _date(2023, 3, 12) DST_FALL_2023 = _date(2023, 11, 5) @@ -108,7 +190,6 @@ def add_time_columns(df: pl.DataFrame, day_mode: str = "calendar") -> pl.DataFra if day_mode == "calendar": date_expr = dt.dt.date() else: - # Assign midnight (00:00) readings to previous date date_expr = ( pl.when((dt.dt.hour() == 0) & (dt.dt.minute() == 0)) .then((dt - pl.duration(days=1)).dt.date()) @@ -125,9 +206,9 @@ def add_time_columns(df: pl.DataFrame, day_mode: str = "calendar") -> pl.DataFra (pl.col("date").dt.weekday() >= 5).alias("is_weekend"), ]) .with_columns([ - (pl.col("date") == DST_SPRING_2023).alias("is_spring_forward_day"), - (pl.col("date") == DST_FALL_2023).alias("is_fall_back_day"), - ((pl.col("date") == DST_SPRING_2023) | (pl.col("date") == DST_FALL_2023)).alias("is_dst_day"), + (pl.col("date") == DST_SPRING_2023).alias("is_spring_forward_day_2023"), + (pl.col("date") == DST_FALL_2023).alias("is_fall_back_day_2023"), + ((pl.col("date") == DST_SPRING_2023) | (pl.col("date") == DST_FALL_2023)).alias("is_dst_day_2023"), ]) ) diff --git a/tests/validate_total_comed_pipeline.py b/tests/validate_total_comed_pipeline.py deleted file mode 100644 index b10bc1f..0000000 --- a/tests/validate_total_comed_pipeline.py +++ /dev/null @@ -1,986 +0,0 @@ -#!/usr/bin/env python3 -""" -ComEd Smart Meter Analysis Pipeline Validation - -Validates data integrity and correctness across all stages of the ComEd smart meter -clustering analysis. Can operate on existing local files or pull fresh test data -from AWS S3 to perform end-to-end validation. - -Validation Stages: - 1. Processed Data - Interval-level energy data after wide-to-long transformation - 2. Enriched Data - Energy data joined with Census demographics - 3. Clustering Inputs - Daily load profiles and ZIP+4 demographics - 4. Clustering Outputs - Cluster assignments, centroids, and evaluation metrics - -Usage: - # Validate existing local files - python validate_total_comed_pipeline.py - - # Pull fresh data from S3 and run full validation - python validate_total_comed_pipeline.py --from-s3 --num-files 1000 - - # Validate specific stage only - python validate_total_comed_pipeline.py --stage clustering -""" - -from __future__ import annotations - -import argparse -import json -import logging -import sys -from pathlib import Path - -import polars as pl - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", -) -logger = logging.getLogger(__name__) - - -# Default paths for production data -DEFAULT_PATHS = { - "processed": Path("data/processed/comed_202308.parquet"), - "enriched": Path("data/enriched_test/enriched.parquet"), - "clustering_dir": Path("data/clustering"), - "crosswalk": Path("data/reference/2023_comed_zip4_census_crosswalk.txt"), -} - - -class PipelineValidator: - """ - Validates data quality and integrity across the ComEd analysis pipeline. - - This validator performs comprehensive checks at each pipeline stage to ensure - data transformations preserve integrity and outputs meet quality standards - required for reliable clustering analysis. - """ - - def __init__(self, base_dir: Path, run_name: str | None = None): - """ - Initialize validator with project base directory. 
- - Args: - base_dir: Root directory of the smart-meter-analysis project - run_name: Optional name for this validation run (used for S3 test data) - """ - self.base_dir = base_dir - self.run_name = run_name - self.results: dict[str, dict] = {} - - # If run_name provided, use validation_runs directory structure - if run_name: - self.run_dir = base_dir / "data" / "validation_runs" / run_name - # Extract year_month from run_name (e.g., "202308_1000" -> "202308") - self.year_month = run_name.split("_")[0] if "_" in run_name else run_name - self.paths = { - "samples": self.run_dir / "samples", - "processed": self.run_dir / "processed" / f"comed_{self.year_month}.parquet", - "enriched": self.run_dir / "enriched" / "enriched.parquet", - "clustering_dir": self.run_dir / "clustering", - } - else: - self.run_dir = None - self.paths = DEFAULT_PATHS.copy() - - def setup_run_directories(self) -> None: - """Create directory structure for a validation run.""" - if not self.run_dir: - return - - for subdir in ["samples", "processed", "enriched", "clustering/results"]: - (self.run_dir / subdir).mkdir(parents=True, exist_ok=True) - - logger.info(f"Created validation run directory: {self.run_dir}") - - def download_from_s3( - self, - year_month: str, - num_files: int, - bucket: str = "smart-meter-data-sb", - prefix: str = "sharepoint-files/Zip4/", - ) -> bool: - """ - Download sample CSV files from S3 for validation testing. - - Args: - year_month: Target month in YYYYMM format (e.g., '202308') - num_files: Number of CSV files to download - bucket: S3 bucket name - prefix: S3 key prefix for ComEd data - - Returns: - True if download successful, False otherwise - """ - try: - import boto3 - except ImportError: - logger.error("boto3 not installed. Run: pip install boto3") - return False - - logger.info(f"Connecting to S3 bucket: {bucket}") - - try: - s3 = boto3.client("s3") - full_prefix = f"{prefix}{year_month}/" - - # List available files - paginator = s3.get_paginator("list_objects_v2") - pages = paginator.paginate(Bucket=bucket, Prefix=full_prefix) - - csv_keys = [] - for page in pages: - if "Contents" not in page: - continue - for obj in page["Contents"]: - if obj["Key"].endswith(".csv"): - csv_keys.append(obj["Key"]) - if len(csv_keys) >= num_files: - break - if len(csv_keys) >= num_files: - break - - if not csv_keys: - logger.error(f"No CSV files found in s3://{bucket}/{full_prefix}") - return False - - logger.info(f"Found {len(csv_keys)} files, downloading to {self.paths['samples']}") - - # Download files - for i, key in enumerate(csv_keys, 1): - filename = Path(key).name - local_path = self.paths["samples"] / filename - - if i % 100 == 0 or i == len(csv_keys): - logger.info(f" Downloaded {i}/{len(csv_keys)} files") - - s3.download_file(bucket, key, str(local_path)) - - logger.info(f"Successfully downloaded {len(csv_keys)} files") - return True - - except Exception as e: - logger.error(f"S3 download failed: {e}") - logger.error("Verify AWS credentials are configured (aws configure or environment variables)") - return False - - def run_processing_stage(self, year_month: str) -> bool: - """ - Execute the data processing stage on downloaded samples. - - Uses lazy evaluation via Polars scan_csv and sink_parquet to process - files without loading all data into memory simultaneously. 
- - Args: - year_month: Month identifier for output file naming - - Returns: - True if processing successful, False otherwise - """ - csv_files = sorted(self.paths["samples"].glob("*.csv")) - if not csv_files: - logger.error(f"No CSV files found in {self.paths['samples']}") - return False - - logger.info(f"Processing {len(csv_files)} CSV files using lazy evaluation") - - # Import the schema and transformation functions from aws_loader - from smart_meter_analysis.aws_loader import ( - COMED_SCHEMA, - add_time_columns_lazy, - transform_wide_to_long_lazy, - ) - - # Build lazy frames for each CSV - lazy_frames = [] - for i, csv_path in enumerate(csv_files, 1): - if i % 200 == 0 or i == len(csv_files): - logger.info(f" Scanned {i}/{len(csv_files)} files") - - try: - lf = pl.scan_csv(str(csv_path), schema_overrides=COMED_SCHEMA, ignore_errors=True) - lf_long = transform_wide_to_long_lazy(lf) - lf_time = add_time_columns_lazy(lf_long, day_mode="calendar") - lazy_frames.append(lf_time) - except Exception as e: - logger.warning(f"Failed to scan {csv_path.name}: {e}") - continue - - if not lazy_frames: - logger.error("No files were successfully scanned") - return False - - logger.info(f"Concatenating {len(lazy_frames)} lazy frames and writing to parquet") - - # Concatenate lazily and sink to parquet (memory-efficient) - lf_combined = pl.concat(lazy_frames, how="diagonal_relaxed") - - # Ensure output directory exists - self.paths["processed"].parent.mkdir(parents=True, exist_ok=True) - - # sink_parquet executes the lazy query and writes directly to disk - lf_combined.sink_parquet(self.paths["processed"]) - - # Read back row count for logging - row_count = pl.scan_parquet(self.paths["processed"]).select(pl.len()).collect()[0, 0] - logger.info(f"Wrote {row_count:,} records to {self.paths['processed']}") - - return True - - def run_enrichment_stage(self, year_month: str) -> bool: - """ - Execute the census enrichment stage using lazy evaluation. - - Joins processed energy data with Census demographics via ZIP+4 crosswalk. - Uses streaming joins to handle large datasets without excessive memory. 
- - Args: - year_month: Month identifier for locating input file - - Returns: - True if enrichment successful, False otherwise - """ - from smart_meter_analysis.census import fetch_census_data - - crosswalk_path = self.base_dir / DEFAULT_PATHS["crosswalk"] - if not crosswalk_path.exists(): - logger.error(f"Crosswalk file not found: {crosswalk_path}") - return False - - # Ensure output directory exists - self.paths["enriched"].parent.mkdir(parents=True, exist_ok=True) - - # Census cache location - cache_dir = self.base_dir / "data" / "reference" - cache_dir.mkdir(parents=True, exist_ok=True) - census_cache = cache_dir / "census_17_2023.parquet" - - try: - # Load or fetch census data - if census_cache.exists(): - logger.info(f"Loading cached census data from {census_cache}") - census_df = pl.read_parquet(census_cache) - else: - logger.info("Fetching census data from API") - census_df = fetch_census_data(state_fips="17", acs_year=2023) - census_df.write_parquet(census_cache) - logger.info(f"Cached census data to {census_cache}") - - logger.info(f"Census data: {len(census_df):,} block groups") - - # Load crosswalk and create enriched mapping - logger.info(f"Loading crosswalk from {crosswalk_path}") - crosswalk = pl.read_csv(crosswalk_path, separator="\t", infer_schema_length=10000) - logger.info(f" Loaded {len(crosswalk):,} ZIP+4 mappings") - - # Create standardized join keys - crosswalk = crosswalk.with_columns([ - (pl.col("Zip").cast(pl.Utf8).str.zfill(5) + "-" + pl.col("Zip4").cast(pl.Utf8).str.zfill(4)).alias( - "zip4" - ), - pl.col("CensusKey2023").cast(pl.Utf8).str.zfill(15).str.slice(0, 12).alias("block_group_geoid"), - ]).select(["zip4", "block_group_geoid"]) - - # Prepare census data for join - census_df = census_df.with_columns(pl.col("GEOID").cast(pl.Utf8).str.zfill(12).alias("block_group_geoid")) - census_cols = [ - c for c in census_df.columns if c not in ["GEOID", "NAME", "state", "county", "tract", "block group"] - ] - census_for_join = census_df.select( - ["block_group_geoid"] + [c for c in census_cols if c != "block_group_geoid"] - ) - - # Join crosswalk with census - logger.info("Creating enriched crosswalk") - enriched_crosswalk = crosswalk.join(census_for_join, on="block_group_geoid", how="left") - logger.info(f" Enriched crosswalk: {len(enriched_crosswalk):,} rows") - - # Lazy join with energy data - logger.info("Joining energy data with demographics (lazy)") - energy_lf = pl.scan_parquet(self.paths["processed"]) - crosswalk_lf = enriched_crosswalk.lazy() - - enriched_lf = energy_lf.join( - crosswalk_lf, - left_on="zip_code", - right_on="zip4", - how="left", - ) - - # Sink to parquet - logger.info(f"Writing enriched data to {self.paths['enriched']}") - enriched_lf.sink_parquet(self.paths["enriched"]) - - # Get row count for logging - row_count = pl.scan_parquet(self.paths["enriched"]).select(pl.len()).collect()[0, 0] - logger.info(f"Enrichment complete: {row_count:,} records") - - return True - - except Exception as e: - logger.error(f"Enrichment failed: {e}") - import traceback - - traceback.print_exc() - return False - - def run_clustering_prep_stage(self) -> bool: - """ - Execute the clustering data preparation stage. - - Aggregates interval data to daily ZIP+4 profiles for clustering. - Works directly from processed data - demographic enrichment is - handled separately in Stage 2. 
- - Returns: - True if preparation successful, False otherwise - """ - import subprocess - - clustering_dir = self.paths["clustering_dir"] - clustering_dir.mkdir(parents=True, exist_ok=True) - - # Use processed data directly (not enriched) - input_path = self.paths["processed"] - - if not input_path.exists(): - logger.error(f"Processed data not found: {input_path}") - return False - - cmd = [ - sys.executable, - str(self.base_dir / "analysis" / "clustering" / "prepare_clustering_data.py"), - "--input", - str(input_path), - "--output-dir", - str(clustering_dir), - "--day-strategy", - "stratified", - "--sample-days", - "20", - "--sample-zips", - "500", - ] - - logger.info(f"Preparing clustering data from {input_path}") - result = subprocess.run(cmd, capture_output=True, text=True) - - if result.returncode != 0: - logger.error(f"Clustering prep failed: {result.stderr}") - # Print stdout too for debugging - if result.stdout: - logger.error(f"stdout: {result.stdout}") - return False - - logger.info(f"Clustering data prepared: {clustering_dir}") - return True - - def run_clustering_stage(self) -> bool: - """ - Execute the DTW clustering stage. - - Performs k-means clustering with DTW distance on daily load profiles. - - Returns: - True if clustering successful, False otherwise - """ - import subprocess - - clustering_dir = self.paths["clustering_dir"] - results_dir = clustering_dir / "results" - results_dir.mkdir(parents=True, exist_ok=True) - - profiles_path = clustering_dir / "sampled_profiles.parquet" - if not profiles_path.exists(): - logger.error(f"Profiles not found: {profiles_path}") - return False - - cmd = [ - sys.executable, - str(self.base_dir / "analysis" / "clustering" / "dtw_clustering.py"), - "--input", - str(profiles_path), - "--output-dir", - str(results_dir), - "--k-range", - "3", - "6", - "--find-optimal-k", - "--normalize", - ] - - logger.info("Running DTW clustering") - result = subprocess.run(cmd, capture_output=True, text=True) - - if result.returncode != 0: - logger.error(f"Clustering failed: {result.stderr}") - return False - - logger.info(f"Clustering complete: {results_dir}") - return True - - def validate_processed_data(self, path: Path | None = None) -> dict: - """ - Validate processed energy data quality. - - Verifies the wide-to-long transformation produced valid interval-level - data with expected schema, reasonable value ranges, and complete coverage. 
- - Args: - path: Path to processed parquet file (uses default if not specified) - - Returns: - Validation result dictionary with status, errors, warnings, and statistics - """ - stage = "processed" - print(f"\n{'=' * 70}") - print("STAGE 1: PROCESSED DATA VALIDATION") - print(f"{'=' * 70}") - - path = path or self.paths.get("processed") or DEFAULT_PATHS["processed"] - - if not path.exists(): - return self._fail(stage, f"File not found: {path}") - - print(f"File: {path}") - - errors = [] - warnings = [] - stats = {} - - try: - df = pl.read_parquet(path) - stats["rows"] = len(df) - stats["columns"] = len(df.columns) - print(f"Shape: {stats['rows']:,} rows × {stats['columns']} columns") - - # Schema validation - required = ["zip_code", "account_identifier", "datetime", "kwh", "date", "hour"] - missing = [c for c in required if c not in df.columns] - if missing: - errors.append(f"Missing required columns: {missing}") - else: - print("✓ Required columns present") - - # Record counts - stats["accounts"] = df["account_identifier"].n_unique() - stats["zip_codes"] = df["zip_code"].n_unique() - print(f"✓ Unique accounts: {stats['accounts']:,}") - print(f"✓ Unique ZIP+4 codes: {stats['zip_codes']:,}") - - # Temporal coverage - stats["min_date"] = str(df["date"].min()) - stats["max_date"] = str(df["date"].max()) - stats["unique_dates"] = df["date"].n_unique() - print(f"✓ Date range: {stats['min_date']} to {stats['max_date']} ({stats['unique_dates']} days)") - - # Energy value validation - kwh_min = df["kwh"].min() - kwh_max = df["kwh"].max() - kwh_mean = df["kwh"].mean() - stats["kwh_min"] = float(kwh_min) if kwh_min is not None else None - stats["kwh_max"] = float(kwh_max) if kwh_max is not None else None - stats["kwh_mean"] = float(kwh_mean) if kwh_mean is not None else None - - if kwh_min is not None and kwh_min < 0: - warnings.append(f"Negative kWh values detected: min={kwh_min:.4f}") - print(f"✓ kWh range: {kwh_min:.4f} to {kwh_max:.2f} (mean: {kwh_mean:.4f})") - - # Null value check - null_cols = [] - for col in required: - if col in df.columns: - null_pct = df[col].null_count() / len(df) * 100 - if null_pct > 0: - null_cols.append(f"{col}: {null_pct:.2f}%") - - if null_cols: - warnings.append(f"Null values found: {', '.join(null_cols)}") - else: - print("✓ No null values in required columns") - - except Exception as e: - return self._fail(stage, f"Error reading file: {e}") - - return self._result(stage, errors, warnings, stats) - - def validate_enriched_data(self, path: Path | None = None) -> dict: - """ - Validate census-enriched energy data. - - Verifies geographic join success rate and presence of demographic variables - required for the clustering analysis. 
- - Args: - path: Path to enriched parquet file (uses default if not specified) - - Returns: - Validation result dictionary with status, errors, warnings, and statistics - """ - stage = "enriched" - print(f"\n{'=' * 70}") - print("STAGE 2: ENRICHED DATA VALIDATION") - print(f"{'=' * 70}") - - path = path or self.paths.get("enriched") or DEFAULT_PATHS["enriched"] - - if not path.exists(): - return self._skip(stage, f"File not found: {path}") - - print(f"File: {path}") - - errors = [] - warnings = [] - stats = {} - - try: - df = pl.read_parquet(path) - stats["rows"] = len(df) - stats["columns"] = len(df.columns) - print(f"Shape: {stats['rows']:,} rows × {stats['columns']} columns") - - # Geographic enrichment validation - if "block_group_geoid" not in df.columns: - errors.append("Missing block_group_geoid column - census join may have failed") - else: - matched = df.filter(pl.col("block_group_geoid").is_not_null()).height - match_rate = matched / len(df) * 100 - stats["geographic_match_rate"] = match_rate - stats["block_groups"] = df["block_group_geoid"].n_unique() - - if match_rate < 90: - errors.append(f"Geographic match rate below 90%: {match_rate:.1f}%") - elif match_rate < 95: - warnings.append(f"Geographic match rate below 95%: {match_rate:.1f}%") - else: - print(f"✓ Geographic match rate: {match_rate:.1f}%") - - print(f"✓ Unique block groups: {stats['block_groups']:,}") - - # Demographic variable validation - census_indicators = ["Total_Households", "Median_Household_Income", "Owner_Occupied"] - found_census = [c for c in census_indicators if c in df.columns] - - if not found_census: - errors.append("No census demographic variables found") - else: - excluded = { - "zip_code", - "account_identifier", - "datetime", - "kwh", - "date", - "hour", - "weekday", - "is_weekend", - "block_group_geoid", - "delivery_service_class", - "delivery_service_name", - "is_spring_forward_day", - "is_fall_back_day", - "is_dst_day", - } - census_cols = [c for c in df.columns if c not in excluded] - stats["census_variables"] = len(census_cols) - print(f"✓ Census variables: {stats['census_variables']}") - - except Exception as e: - return self._fail(stage, f"Error reading file: {e}") - - return self._result(stage, errors, warnings, stats) - - def validate_clustering_inputs(self) -> dict: - """ - Validate clustering input data structures. - - Ensures daily load profiles have the expected 48-point structure and - demographic data provides complete coverage of profiled ZIP+4 codes. 
- - Returns: - Validation result dictionary with status, errors, warnings, and statistics - """ - stage = "clustering_inputs" - print(f"\n{'=' * 70}") - print("STAGE 3: CLUSTERING INPUTS VALIDATION") - print(f"{'=' * 70}") - - clustering_dir = self.paths.get("clustering_dir") or DEFAULT_PATHS["clustering_dir"] - profiles_path = clustering_dir / "sampled_profiles.parquet" - demos_path = clustering_dir / "zip4_demographics.parquet" - - errors = [] - warnings = [] - stats = {} - - # Profile validation - if not profiles_path.exists(): - return self._skip(stage, f"Profiles not found: {profiles_path}") - - print(f"Profiles: {profiles_path}") - - try: - profiles = pl.read_parquet(profiles_path) - stats["n_profiles"] = len(profiles) - stats["n_zip_codes"] = profiles["zip_code"].n_unique() - stats["n_dates"] = profiles["date"].n_unique() - print(f"✓ Profiles: {stats['n_profiles']} ({stats['n_zip_codes']} ZIP codes × {stats['n_dates']} dates)") - - # Profile length validation (must be 48 for 30-minute intervals) - profile_lengths = profiles.select(pl.col("profile").list.len()).unique()["profile"].to_list() - stats["profile_lengths"] = profile_lengths - - if profile_lengths != [48]: - if all(length in [47, 48] for length in profile_lengths): - warnings.append(f"Some profiles have 47 intervals (likely DST days): {profile_lengths}") - else: - errors.append(f"Invalid profile lengths detected: {profile_lengths}") - else: - print("✓ All profiles have 48 timepoints") - - # Null profile check - null_profiles = profiles.filter(pl.col("profile").is_null()).height - if null_profiles > 0: - errors.append(f"{null_profiles} null profiles detected") - - except Exception as e: - return self._fail(stage, f"Error reading profiles: {e}") - - # Demographics validation - if not demos_path.exists(): - warnings.append(f"Demographics file not found: {demos_path}") - else: - print(f"Demographics: {demos_path}") - try: - demos = pl.read_parquet(demos_path) - stats["n_demo_zips"] = len(demos) - stats["n_demo_vars"] = len(demos.columns) - 2 # Exclude zip_code and block_group_geoid - print(f"✓ Demographics: {stats['n_demo_zips']} ZIP codes, {stats['n_demo_vars']} variables") - - # Coverage validation - profile_zips = set(profiles["zip_code"].unique().to_list()) - demo_zips = set(demos["zip_code"].unique().to_list()) - missing = profile_zips - demo_zips - - if missing: - warnings.append(f"{len(missing)} profile ZIP codes missing demographics") - else: - print("✓ Demographics cover all profile ZIP codes") - - except Exception as e: - warnings.append(f"Error reading demographics: {e}") - - return self._result(stage, errors, warnings, stats) - - def validate_clustering_outputs(self) -> dict: - """ - Validate clustering analysis outputs. - - Verifies cluster assignments, metadata, and visualizations were generated - correctly and that cluster distribution is reasonable. 
- - Returns: - Validation result dictionary with status, errors, warnings, and statistics - """ - stage = "clustering_outputs" - print(f"\n{'=' * 70}") - print("STAGE 4: CLUSTERING OUTPUTS VALIDATION") - print(f"{'=' * 70}") - - clustering_dir = self.paths.get("clustering_dir") or DEFAULT_PATHS["clustering_dir"] - results_dir = clustering_dir / "results" - - if not results_dir.exists(): - return self._skip(stage, f"Results directory not found: {results_dir}") - - errors = [] - warnings = [] - stats = {} - - # Cluster assignments validation - assignments_path = results_dir / "cluster_assignments.parquet" - if not assignments_path.exists(): - errors.append("cluster_assignments.parquet not found") - else: - print(f"Assignments: {assignments_path}") - try: - assignments = pl.read_parquet(assignments_path) - stats["n_assigned"] = len(assignments) - - if "cluster" not in assignments.columns: - errors.append("cluster column missing from assignments") - else: - clusters = assignments["cluster"].unique().sort().to_list() - stats["clusters"] = clusters - stats["k"] = len(clusters) - print(f"✓ Clusters: {stats['k']} (labels: {clusters})") - - # Distribution analysis - dist = assignments.group_by("cluster").agg(pl.len().alias("count")).sort("cluster") - print("✓ Cluster distribution:") - for row in dist.iter_rows(named=True): - pct = row["count"] / len(assignments) * 100 - print(f" Cluster {row['cluster']}: {row['count']} ({pct:.1f}%)") - - # Flag highly imbalanced clusters - min_cluster_size = dist["count"].min() - if min_cluster_size < len(assignments) * 0.05: - warnings.append(f"Smallest cluster has only {min_cluster_size} profiles (<5%)") - - except Exception as e: - errors.append(f"Error reading assignments: {e}") - - # Metadata validation - metadata_path = results_dir / "clustering_metadata.json" - if not metadata_path.exists(): - errors.append("clustering_metadata.json not found") - else: - try: - with open(metadata_path) as f: - meta = json.load(f) - stats["metadata"] = meta - print(f"✓ Metadata: k={meta.get('k')}, inertia={meta.get('inertia', 0):.2f}") - except Exception as e: - errors.append(f"Error reading metadata: {e}") - - # K evaluation validation (optional - only present if --find-optimal-k was used) - k_eval_path = results_dir / "k_evaluation.json" - if k_eval_path.exists(): - try: - with open(k_eval_path) as f: - k_eval = json.load(f) - best_k_idx = k_eval["silhouette"].index(max(k_eval["silhouette"])) - best_k = k_eval["k_values"][best_k_idx] - best_sil = k_eval["silhouette"][best_k_idx] - stats["best_k"] = best_k - stats["best_silhouette"] = best_sil - print(f"✓ K evaluation: best k={best_k} (silhouette={best_sil:.3f})") - except Exception as e: - warnings.append(f"Error reading k_evaluation: {e}") - - # Visualization validation - viz_files = ["elbow_curve.png", "cluster_centroids.png", "cluster_samples.png"] - missing_viz = [f for f in viz_files if not (results_dir / f).exists()] - - if missing_viz: - warnings.append(f"Missing visualizations: {missing_viz}") - else: - print("✓ All visualizations generated") - - return self._result(stage, errors, warnings, stats) - - def run_full_pipeline(self, year_month: str, num_files: int) -> bool: - """ - Execute and validate the complete Stage 1 pipeline from S3 through clustering. - - Stage 1 focuses on usage pattern clustering only. Demographic enrichment - is deferred to Stage 2 (multinomial regression) to reduce memory requirements - and maintain separation of concerns. 
- - Args: - year_month: Target month in YYYYMM format - num_files: Number of S3 files to download - - Returns: - True if all stages complete successfully, False otherwise - """ - print(f"\n{'=' * 70}") - print("STAGE 1 PIPELINE EXECUTION") - print(f"{'=' * 70}") - print(f"Year-Month: {year_month}") - print(f"Files: {num_files}") - print(f"Output: {self.run_dir}") - - self.setup_run_directories() - - # Step 1: Download from S3 - print(f"\n{'─' * 70}") - print("DOWNLOADING FROM S3") - print(f"{'─' * 70}") - if not self.download_from_s3(year_month, num_files): - return False - - # Step 2: Process raw data - print(f"\n{'─' * 70}") - print("PROCESSING RAW DATA") - print(f"{'─' * 70}") - if not self.run_processing_stage(year_month): - return False - - # Step 3: Prepare clustering data (from processed, not enriched) - print(f"\n{'─' * 70}") - print("PREPARING CLUSTERING DATA") - print(f"{'─' * 70}") - if not self.run_clustering_prep_stage(): - return False - - # Step 4: Run clustering - print(f"\n{'─' * 70}") - print("RUNNING DTW CLUSTERING") - print(f"{'─' * 70}") - if not self.run_clustering_stage(): - return False - - logger.info("Stage 1 pipeline execution complete") - return True - - def validate_all(self) -> bool: - """ - Run validation checks on all Stage 1 pipeline outputs. - - Note: Enrichment validation is skipped for Stage 1. Demographic - enrichment occurs in Stage 2 (multinomial regression). - - Returns: - True if all critical validations pass, False otherwise - """ - self.results["processed"] = self.validate_processed_data() - # Skip enriched validation for Stage 1 - demographics added in Stage 2 - self.results["clustering_inputs"] = self.validate_clustering_inputs() - self.results["clustering_outputs"] = self.validate_clustering_outputs() - - return self._print_summary() - - def validate_stage(self, stage: str) -> bool: - """ - Run validation for a specific pipeline stage. 
- - Args: - stage: One of 'processed', 'enriched', 'clustering' - - Returns: - True if validation passes, False otherwise - """ - if stage == "processed": - self.results["processed"] = self.validate_processed_data() - elif stage == "enriched": - self.results["enriched"] = self.validate_enriched_data() - elif stage == "clustering": - self.results["clustering_inputs"] = self.validate_clustering_inputs() - self.results["clustering_outputs"] = self.validate_clustering_outputs() - else: - print(f"Unknown stage: {stage}") - return False - - return self._print_summary() - - def _result(self, stage: str, errors: list, warnings: list, stats: dict) -> dict: - """Format validation results for a stage.""" - status = "PASS" if not errors else "FAIL" - if status == "PASS": - print(f"\n✅ {stage.upper()}: PASSED") - else: - print(f"\n❌ {stage.upper()}: FAILED") - for e in errors: - print(f" Error: {e}") - - if warnings: - for w in warnings: - print(f" ⚠️ {w}") - - return {"status": status, "errors": errors, "warnings": warnings, "stats": stats} - - def _fail(self, stage: str, message: str) -> dict: - """Create a failed validation result.""" - print(f"\n❌ {stage.upper()}: FAILED - {message}") - return {"status": "FAIL", "errors": [message], "warnings": [], "stats": {}} - - def _skip(self, stage: str, message: str) -> dict: - """Create a skipped validation result.""" - print(f"\n⏭️ {stage.upper()}: SKIPPED - {message}") - return {"status": "SKIP", "errors": [], "warnings": [], "stats": {}} - - def _print_summary(self) -> bool: - """Print validation summary and return overall success status.""" - print(f"\n{'=' * 70}") - print("VALIDATION SUMMARY") - print(f"{'=' * 70}") - - all_passed = True - for stage, result in self.results.items(): - status = result.get("status", "UNKNOWN") - icon = {"PASS": "✅", "FAIL": "❌", "SKIP": "⏭️"}.get(status, "❓") - print(f"{icon} {stage}: {status}") - - if status == "FAIL": - all_passed = False - for e in result.get("errors", []): - print(f" Error: {e}") - - print() - if all_passed: - print("✓ All validations passed") - - # Report Stage 2 readiness - if "clustering_outputs" in self.results: - out_stats = self.results["clustering_outputs"].get("stats", {}) - if out_stats.get("k"): - print("\nStage 1 Analysis Complete:") - print(f" • {out_stats.get('n_assigned', '?')} profiles clustered into {out_stats.get('k')} groups") - print(" • Ready for Stage 2: Multinomial logistic regression") - else: - print("⚠️ Some validations failed. 
Review errors above.") - - return all_passed - - -def main(): - parser = argparse.ArgumentParser( - description="Validate ComEd smart meter analysis pipeline", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Validate existing local files - python validate_total_comed_pipeline.py - - # Download from S3 and run full pipeline validation - python validate_total_comed_pipeline.py --from-s3 --num-files 1000 - - # Validate specific stage only - python validate_total_comed_pipeline.py --stage clustering - """, - ) - parser.add_argument( - "--stage", - choices=["processed", "enriched", "clustering", "all"], - default="all", - help="Pipeline stage to validate (default: all)", - ) - parser.add_argument("--from-s3", action="store_true", help="Download fresh test data from S3 and run full pipeline") - parser.add_argument("--num-files", type=int, default=1000, help="Number of S3 files to download (default: 1000)") - parser.add_argument("--year-month", default="202308", help="Target month in YYYYMM format (default: 202308)") - parser.add_argument( - "--base-dir", type=Path, default=Path("."), help="Project root directory (default: current directory)" - ) - parser.add_argument( - "--run-name", help="Name for this validation run (default: auto-generated from year-month and num-files)" - ) - - args = parser.parse_args() - - # Generate run name if pulling from S3 - run_name = None - if args.from_s3: - run_name = args.run_name or f"{args.year_month}_{args.num_files}" - - validator = PipelineValidator(args.base_dir, run_name) - - if args.from_s3: - # Full pipeline: download, process, enrich, cluster, validate - if not validator.run_full_pipeline(args.year_month, args.num_files): - print("\n❌ Pipeline execution failed") - sys.exit(1) - - # Validate all outputs - success = validator.validate_all() - elif args.stage == "all": - success = validator.validate_all() - else: - success = validator.validate_stage(args.stage) - - sys.exit(0 if success else 1) - - -if __name__ == "__main__": - main()
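
Editor's note on the Stage 2 predictor specs: the diff defines each ratio predictor as numerator/denominator strings of ACS variable codes (e.g. "B25070_007E + B25070_008E + ..."), but the code that evaluates those strings is not part of this patch. The sketch below is one way such an entry could be turned into a block-group-level column with Polars; the helper names (acs_sum_expr, compute_ratio_predictor), the sample GEOIDs, and the estimate values are illustrative, not taken from the repository. It covers only the ratio-style entries; entries with a single "variable" key (e.g. median_age) or a transformation other than "none" would need separate handling.

import polars as pl


def acs_sum_expr(spec: str) -> pl.Expr:
    """Turn a spec string like 'B25070_007E + B25070_008E' into a summed Polars expression."""
    codes = [c.strip() for c in spec.split("+")]
    expr = pl.col(codes[0])
    for code in codes[1:]:
        expr = expr + pl.col(code)
    return expr


def compute_ratio_predictor(df: pl.DataFrame, predictor: dict) -> pl.DataFrame:
    """Add one ratio column named predictor['name']; null where the denominator is zero."""
    num = acs_sum_expr(predictor["numerator"])
    den = acs_sum_expr(predictor["denominator"])
    return df.with_columns(
        pl.when(den != 0).then(num / den).otherwise(None).alias(predictor["name"])
    )


# Two block groups with made-up B25070 (gross rent as a share of income) estimates.
acs = pl.DataFrame({
    "block_group_geoid": ["170310101001", "170310101002"],
    "B25070_007E": [10, 5],
    "B25070_008E": [8, 2],
    "B25070_009E": [6, 1],
    "B25070_010E": [12, 0],
    "B25070_001E": [120, 0],
})
spec = {
    "name": "pct_rent_burden_30_plus",
    "numerator": "B25070_007E + B25070_008E + B25070_009E + B25070_010E",
    "denominator": "B25070_001E",
}
# First row yields 36/120 = 0.30; second row has a zero denominator and yields null.
print(compute_ratio_predictor(acs, spec).select("block_group_geoid", "pct_rent_burden_30_plus"))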