From b30869d313af740e8c1e4090e3bf1b901157139e Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Fri, 1 May 2026 00:34:47 -0500 Subject: [PATCH 01/14] ADD: Ruff is working locally --- .github/workflows/ci.yml | 6 + docs/conf.py | 2 +- pyproject.toml | 14 ++ src/adapt/api/client.py | 91 +++++++------ src/adapt/cli.py | 2 - src/adapt/configuration/schemas/cli.py | 18 +-- .../configuration/schemas/directories.py | 15 +-- .../configuration/schemas/initialization.py | 17 ++- src/adapt/configuration/schemas/internal.py | 17 +-- src/adapt/configuration/schemas/param.py | 15 ++- src/adapt/configuration/schemas/resolve.py | 11 +- src/adapt/configuration/schemas/user.py | 125 +++++++++--------- src/adapt/execution/graph/__init__.py | 2 +- src/adapt/execution/graph/builder.py | 10 +- src/adapt/execution/graph/executor.py | 5 +- src/adapt/execution/graph/node.py | 10 +- src/adapt/execution/module_registry.py | 12 +- src/adapt/execution/pipeline_builder.py | 6 +- src/adapt/gui/__init__.py | 2 +- src/adapt/gui/dashboard.py | 7 +- src/adapt/modules/acquisition/module.py | 12 +- src/adapt/modules/analysis/contracts.py | 2 +- src/adapt/modules/analysis/module.py | 17 +-- src/adapt/modules/base.py | 11 +- src/adapt/modules/detection/contracts.py | 3 +- src/adapt/modules/detection/module.py | 10 +- src/adapt/modules/ingest/contracts.py | 1 + src/adapt/modules/ingest/module.py | 22 +-- src/adapt/modules/projection/contracts.py | 1 + src/adapt/modules/projection/module.py | 9 +- src/adapt/modules/tracking/contracts.py | 1 + src/adapt/modules/tracking/module.py | 63 ++++----- src/adapt/persistence/catalog.py | 71 +++++----- src/adapt/persistence/registry.py | 42 +++--- src/adapt/persistence/repository.py | 70 +++++----- src/adapt/persistence/track_store.py | 21 ++- src/adapt/persistence/writer.py | 11 +- src/adapt/runtime/__init__.py | 2 +- src/adapt/runtime/file_tracker.py | 41 +++--- src/adapt/runtime/orchestrator.py | 12 +- src/adapt/runtime/processor.py | 18 +-- src/adapt/visualization/__init__.py | 2 +- src/adapt/visualization/plotter.py | 37 +++--- .../test_tracker_scan_local_outputs.py | 10 +- tests/persistence/test_track_store.py | 6 +- 45 files changed, 454 insertions(+), 428 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d6de6c6..3688203 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -44,11 +44,17 @@ jobs: - name: Install adapt run: pip install -e . --no-deps --force-reinstall + - name: Install ruff + run: pip install ruff + - name: Show environment info run: | python --version pip list | head -n 30 + - name: Lint with ruff + run: ruff check src tests + - name: Run tests run: | pytest -m "not integration" \ diff --git a/docs/conf.py b/docs/conf.py index 5a356d7..48053d0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,6 +1,6 @@ import os -import sys import re +import sys from importlib.metadata import version as _pkg_version sys.path.insert(0, os.path.abspath("../src")) diff --git a/pyproject.toml b/pyproject.toml index 1e9d4e8..56d6d0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,3 +102,17 @@ source = ["src/adapt"] [tool.coverage.report] ignore_errors = true precision = 2 + +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.ruff.lint] +select = [ + "E", # syntax / style errors + "F", # logic bugs (undefined vars, etc.) + "I", # import hygiene + "B", # bug patterns + "UP", # modern Python upgrades + "SIM", # simplifications (often removes subtle mistakes) +] diff --git a/src/adapt/api/client.py b/src/adapt/api/client.py index 0bd642c..4690cba 100644 --- a/src/adapt/api/client.py +++ b/src/adapt/api/client.py @@ -39,19 +39,18 @@ print(f"Got {len(batch)} new rows") """ -import json import logging import time -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any import duckdb import pandas as pd import xarray as xr -from adapt.persistence.registry import RepositoryRegistry from adapt.persistence.catalog import RadarCatalog +from adapt.persistence.registry import RepositoryRegistry from adapt.persistence.track_store import TrackStore __all__ = ['DataClient'] @@ -77,7 +76,7 @@ class DataClient: >>> df = client.latest("analysis2d", radar="KHTX") """ - def __init__(self, repository_root: Union[str, Path]): + def __init__(self, repository_root: str | Path): """Initialize DataClient from repository root. Parameters @@ -94,10 +93,10 @@ def __init__(self, repository_root: Union[str, Path]): self.registry = RepositoryRegistry.get_instance(self.root_dir) # DuckDB connection for SQL queries - self._duckdb_conn: Optional[duckdb.DuckDBPyConnection] = None + self._duckdb_conn: duckdb.DuckDBPyConnection | None = None # Cache of radar catalogs - self._radar_catalogs: Dict[str, RadarCatalog] = {} + self._radar_catalogs: dict[str, RadarCatalog] = {} logger.info(f"DataClient initialized at {self.root_dir}") @@ -132,7 +131,7 @@ def is_initialized(self) -> bool: registry_path = self.root_dir / "adapt_registry.db" return registry_path.exists() - def get_repository_info(self) -> Dict[str, Any]: + def get_repository_info(self) -> dict[str, Any]: """Get repository summary information. Returns @@ -169,7 +168,7 @@ def get_repository_info(self) -> Dict[str, Any]: # Discovery Methods # ========================================================================= - def list_runs(self, radar: Optional[str] = None) -> pd.DataFrame: + def list_runs(self, radar: str | None = None) -> pd.DataFrame: """List all runs, optionally filtered by radar. Parameters @@ -184,7 +183,7 @@ def list_runs(self, radar: Optional[str] = None) -> pd.DataFrame: """ return self.registry.list_runs(radar=radar) - def list_radars(self) -> List[str]: + def list_radars(self) -> list[str]: """List all registered radars. Returns @@ -195,7 +194,7 @@ def list_radars(self) -> List[str]: radars_df = self.registry.list_radars() return radars_df['radar'].tolist() if not radars_df.empty else [] - def get_radar_info(self, radar: str) -> Dict[str, Any]: + def get_radar_info(self, radar: str) -> dict[str, Any]: """Get detailed information for a specific radar. Parameters @@ -265,7 +264,7 @@ def get_radar_info(self, radar: str) -> Dict[str, Any]: 'num_scans': num_scans, } - def get_run_info(self, run_id: str, radar: Optional[str] = None) -> Dict[str, Any]: + def get_run_info(self, run_id: str, radar: str | None = None) -> dict[str, Any]: """Get detailed information for a specific run. Parameters @@ -336,7 +335,7 @@ def get_run_info(self, run_id: str, radar: Optional[str] = None) -> Dict[str, An 'num_scans': num_scans, } - def item_types(self) -> List[str]: + def item_types(self) -> list[str]: """List registered item types. Returns @@ -346,7 +345,7 @@ def item_types(self) -> List[str]: """ return self.registry.list_item_types() - def fields(self, item_type: str, radar: Optional[str] = None) -> List[str]: + def fields(self, item_type: str, radar: str | None = None) -> list[str]: """Get column names for a Parquet table item type. Parameters @@ -390,7 +389,7 @@ def fields(self, item_type: str, radar: Optional[str] = None) -> List[str]: return [] - def status(self, run_id: Optional[str] = None, radar: Optional[str] = None) -> Dict: + def status(self, run_id: str | None = None, radar: str | None = None) -> dict: """Get processing status/progress. Parameters @@ -432,8 +431,8 @@ def status(self, run_id: Optional[str] = None, radar: Optional[str] = None) -> D def latest( self, item_type: str, - radar: Optional[str] = None - ) -> Union[pd.DataFrame, xr.Dataset]: + radar: str | None = None + ) -> pd.DataFrame | xr.Dataset: """Load the most recent item of a given type. Parameters @@ -481,7 +480,7 @@ def latest( else: raise ValueError(f"Unknown file format for {file_path}") - def query(self, sql: str, radar: Optional[str] = None) -> pd.DataFrame: + def query(self, sql: str, radar: str | None = None) -> pd.DataFrame: """Execute SQL query on Parquet tables. Only SELECT queries are allowed. Dynamically creates DuckDB views @@ -559,7 +558,7 @@ def query(self, sql: str, radar: Optional[str] = None) -> pd.DataFrame: def list_scans( self, item_type: str, - radar: Optional[str] = None, + radar: str | None = None, limit: int = 50 ) -> pd.DataFrame: """List available scans with timestamps. @@ -599,10 +598,10 @@ def list_scans( def get_scan_at( self, - scan_time: Union[str, datetime], + scan_time: str | datetime, item_type: str, - radar: Optional[str] = None - ) -> Union[pd.DataFrame, xr.Dataset]: + radar: str | None = None + ) -> pd.DataFrame | xr.Dataset: """Load a specific scan by timestamp. Parameters @@ -678,7 +677,7 @@ def get_scan_at( # Cell Tracking Methods # ========================================================================= - def _track_store(self, radar: Optional[str] = None) -> TrackStore: + def _track_store(self, radar: str | None = None) -> TrackStore: if not radar: radars = self.list_radars() if not radars: @@ -691,7 +690,7 @@ def cells_by_scan( self, run_id: str, scan_time: datetime, - radar: Optional[str] = None, + radar: str | None = None, ) -> pd.DataFrame: """All tracked cells for a single scan.""" return self._track_store(radar).get_cells_by_scan(run_id, scan_time) @@ -700,7 +699,7 @@ def track_history( self, run_id: str, cell_uid: str, - radar: Optional[str] = None, + radar: str | None = None, ) -> pd.DataFrame: """All scan rows for one track, ordered by scan_time.""" return self._track_store(radar).get_track_history(run_id, cell_uid) @@ -708,8 +707,8 @@ def track_history( def cell_events( self, run_id: str, - cell_uid: Optional[str] = None, - radar: Optional[str] = None, + cell_uid: str | None = None, + radar: str | None = None, ) -> pd.DataFrame: """Lineage events for a run, optionally filtered to one cell_uid.""" return self._track_store(radar).get_cell_events(run_id, cell_uid) @@ -717,7 +716,7 @@ def cell_events( def cell_tracks( self, run_id: str, - radar: Optional[str] = None, + radar: str | None = None, ) -> pd.DataFrame: """Lifecycle summary for all tracks in a run.""" return self._track_store(radar).get_cell_tracks(run_id) @@ -726,7 +725,7 @@ def cell_tracks( # Pipeline Status Methods # ========================================================================= - def is_pipeline_running(self, radar: Optional[str] = None) -> bool: + def is_pipeline_running(self, radar: str | None = None) -> bool: """Check if pipeline is actively processing. Checks for active run status and recent progress updates. @@ -779,7 +778,7 @@ def is_pipeline_running(self, radar: Optional[str] = None) -> bool: last_update = datetime.fromisoformat( progress['last_updated'].replace('Z', '+00:00') ) - age_seconds = (datetime.now(timezone.utc) - last_update).total_seconds() + age_seconds = (datetime.now(UTC) - last_update).total_seconds() return age_seconds < 60 except Exception as e: @@ -789,9 +788,9 @@ def is_pipeline_running(self, radar: Optional[str] = None) -> bool: def get_pipeline_progress( self, - radar: Optional[str] = None, - run_id: Optional[str] = None - ) -> Dict[str, Any]: + radar: str | None = None, + run_id: str | None = None + ) -> dict[str, Any]: """Get detailed pipeline progress. Parameters @@ -864,9 +863,9 @@ def get_pipeline_progress( def get_scan_bundle( self, - scan_time: Union[str, datetime], - radar: Optional[str] = None - ) -> Dict[str, Any]: + scan_time: str | datetime, + radar: str | None = None + ) -> dict[str, Any]: """Get all data for a specific scan in a single call. Returns all linked data products for a scan: segmentation, cells DataFrame, @@ -908,7 +907,7 @@ def get_scan_bundle( catalog = self._get_radar_catalog(radar) - bundle: Dict[str, Any] = { + bundle: dict[str, Any] = { 'scan_time': scan_time_dt.isoformat() if isinstance(scan_time_dt, datetime) else scan_time, 'radar': radar, 'segmentation2d': None, @@ -970,8 +969,8 @@ def _get_scan_bundle_fallback( self, scan_time: datetime, radar: str, - bundle: Dict[str, Any] - ) -> Dict[str, Any]: + bundle: dict[str, Any] + ) -> dict[str, Any]: """Fallback scan bundle using item queries (for legacy data).""" scan_time_str = scan_time.isoformat() catalog = self._get_radar_catalog(radar) @@ -1015,7 +1014,7 @@ def _get_scan_bundle_fallback( return bundle - def _get_item_by_id(self, radar: str, item_id: str) -> Optional[Dict]: + def _get_item_by_id(self, radar: str, item_id: str) -> dict | None: """Get item record by ID.""" catalog = self._get_radar_catalog(radar) conn = catalog._get_connection() @@ -1030,9 +1029,9 @@ def _get_item_by_id(self, radar: str, item_id: str) -> Optional[Dict]: def list_scan_times( self, - radar: Optional[str] = None, - start_time: Optional[Union[str, datetime]] = None, - end_time: Optional[Union[str, datetime]] = None, + radar: str | None = None, + start_time: str | datetime | None = None, + end_time: str | datetime | None = None, limit: int = 100 ) -> pd.DataFrame: """List available scan times from scans table or items fallback. @@ -1095,8 +1094,8 @@ def list_scan_times( def _list_scan_times_from_items( self, radar: str, - start_time: Optional[datetime], - end_time: Optional[datetime], + start_time: datetime | None, + end_time: datetime | None, limit: int ) -> pd.DataFrame: """Fallback: get scan times from items table.""" @@ -1136,7 +1135,7 @@ def stream( self, sql: str, poll_interval: int = 5, - radar: Optional[str] = None + radar: str | None = None ): """Stream new results from a SQL query (generator). diff --git a/src/adapt/cli.py b/src/adapt/cli.py index 9d0719b..1385a09 100644 --- a/src/adapt/cli.py +++ b/src/adapt/cli.py @@ -26,7 +26,6 @@ import time from pathlib import Path - # --------------------------------------------------------------------------- # Single-instance enforcement # --------------------------------------------------------------------------- @@ -281,7 +280,6 @@ def _build_config_parser(sub: argparse.ArgumentParser) -> None: def _config_cmd(args: argparse.Namespace) -> None: """Write a config.yaml template to the specified path.""" from datetime import datetime - import os try: cwd = Path.cwd() diff --git a/src/adapt/configuration/schemas/cli.py b/src/adapt/configuration/schemas/cli.py index 50bf50e..b2190a9 100644 --- a/src/adapt/configuration/schemas/cli.py +++ b/src/adapt/configuration/schemas/cli.py @@ -9,8 +9,10 @@ This schema handles command-line arguments parsed by argparse. """ -from typing import Literal, Optional +from typing import Literal + from pydantic import Field, model_validator + from adapt.configuration.schemas.base import AdaptBaseModel @@ -44,13 +46,13 @@ class CLIConfig(AdaptBaseModel): internal = resolve_config(param_cfg, user_cfg, cli_cfg) """ - mode: Optional[Literal["realtime", "historical"]] = None - radar: Optional[str] = None - base_dir: Optional[str] = None - start_time: Optional[str] = None - end_time: Optional[str] = None - log_level: Optional[Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]] = None - run_id: Optional[str] = Field( + mode: Literal["realtime", "historical"] | None = None + radar: str | None = None + base_dir: str | None = None + start_time: str | None = None + end_time: str | None = None + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | None = None + run_id: str | None = Field( default=None, description="Optional run ID for continuation (format: YYYYMONDD-HHMM-RADAR)" ) diff --git a/src/adapt/configuration/schemas/directories.py b/src/adapt/configuration/schemas/directories.py index 18f993c..5359a08 100644 --- a/src/adapt/configuration/schemas/directories.py +++ b/src/adapt/configuration/schemas/directories.py @@ -9,12 +9,11 @@ Author: Bhupendra Raut """ -from pathlib import Path -from typing import Dict from datetime import datetime +from pathlib import Path -def setup_output_directories(base_dir: str) -> Dict[str, Path]: +def setup_output_directories(base_dir: str) -> dict[str, Path]: """Setup output directory structure. Creates the standard Adapt directory layout under base_dir. @@ -66,7 +65,7 @@ def setup_output_directories(base_dir: str) -> Dict[str, Path]: def get_nexrad_path( - output_dirs: Dict[str, Path], + output_dirs: dict[str, Path], radar: str, filename: str, scan_time: datetime @@ -98,7 +97,7 @@ def get_nexrad_path( def get_netcdf_path( - output_dirs: Dict[str, Path], + output_dirs: dict[str, Path], radar: str, filename: str, scan_time: datetime @@ -130,7 +129,7 @@ def get_netcdf_path( def get_analysis_path( - output_dirs: Dict[str, Path], + output_dirs: dict[str, Path], radar: str, filename: str = None, scan_time: datetime = None, @@ -168,7 +167,7 @@ def get_analysis_path( def get_plot_path( - output_dirs: Dict[str, Path], + output_dirs: dict[str, Path], radar: str, plot_type: str = None, scan_time: datetime = None, @@ -214,7 +213,7 @@ def get_plot_path( def get_log_path( - output_dirs: Dict[str, Path], + output_dirs: dict[str, Path], radar: str = None, log_name: str = None ) -> Path: diff --git a/src/adapt/configuration/schemas/initialization.py b/src/adapt/configuration/schemas/initialization.py index b8163dc..27860fe 100644 --- a/src/adapt/configuration/schemas/initialization.py +++ b/src/adapt/configuration/schemas/initialization.py @@ -14,18 +14,17 @@ """ import importlib.util -import shutil import json import re +import shutil +from datetime import UTC, datetime from pathlib import Path -from typing import Dict -from datetime import datetime, timezone -from adapt.configuration.schemas.resolve import resolve_config -from adapt.configuration.schemas.param import ParamConfig -from adapt.configuration.schemas.user import UserConfig from adapt.configuration.schemas.cli import CLIConfig from adapt.configuration.schemas.internal import InternalConfig +from adapt.configuration.schemas.param import ParamConfig +from adapt.configuration.schemas.resolve import resolve_config +from adapt.configuration.schemas.user import UserConfig from adapt.persistence import DataRepository from adapt.persistence.registry import RepositoryRegistry @@ -69,7 +68,7 @@ def _load_user_config_dict(config_path: str) -> dict: raise ValueError(f"No CONFIG dict found in {path}") -def _setup_output_directories(base_dir: str) -> Dict[str, Path]: +def _setup_output_directories(base_dir: str) -> dict[str, Path]: """Setup output directory structure. Creates the standard Adapt directory layout under base_dir. @@ -116,7 +115,7 @@ def _handle_rerun_cleanup(base_dir: str, radar: str, rerun: bool) -> None: print("Radar output cleaned") -def _persist_runtime_config(config: InternalConfig, run_id: str, output_dirs: Dict[str, Path]) -> None: +def _persist_runtime_config(config: InternalConfig, run_id: str, output_dirs: dict[str, Path]) -> None: """Persist final runtime configuration to output directory with run ID. Saves the complete resolved configuration for reproducibility and debugging. @@ -130,7 +129,7 @@ def _persist_runtime_config(config: InternalConfig, run_id: str, output_dirs: Di # Add run_id to config dict for persistence config_dict = config.model_dump() config_dict["run_id"] = run_id - config_dict["created_at"] = datetime.now(timezone.utc).isoformat() + config_dict["created_at"] = datetime.now(UTC).isoformat() with open(config_file, 'w') as f: json.dump(config_dict, f, indent=2, default=str) diff --git a/src/adapt/configuration/schemas/internal.py b/src/adapt/configuration/schemas/internal.py index 1676d76..0e140e9 100644 --- a/src/adapt/configuration/schemas/internal.py +++ b/src/adapt/configuration/schemas/internal.py @@ -10,10 +10,11 @@ runtime code - everything is explicit here. """ -from typing import Literal, Optional -from pydantic import Field, ConfigDict -from adapt.configuration.schemas.base import AdaptBaseModel +from typing import Literal + +from pydantic import ConfigDict, Field +from adapt.configuration.schemas.base import AdaptBaseModel # ============================================================================= # Nested Configuration Models (Runtime) @@ -32,8 +33,8 @@ class InternalDownloaderConfig(AdaptBaseModel): latest_files: int latest_minutes: int poll_interval_sec: int - start_time: Optional[str] - end_time: Optional[str] + start_time: str | None + end_time: str | None min_file_size: int @@ -52,7 +53,7 @@ class InternalSegmenterConfig(AdaptBaseModel): method: Literal["threshold"] threshold: float min_cellsize_gridpoint: int - max_cellsize_gridpoint: Optional[int] + max_cellsize_gridpoint: int | None closing_kernel: tuple[int, int] filter_by_size: bool h_maxima: float @@ -192,8 +193,8 @@ def __init__(self, config: InternalConfig): mode: Literal["realtime", "historical"] base_dir: str - run_id: Optional[str] = Field(default=None, description="Unique run identifier generated during initialization") - output_dirs: Optional[dict[str, str]] = Field(default=None, description="Output directory paths from initialization") + run_id: str | None = Field(default=None, description="Unique run identifier generated during initialization") + output_dirs: dict[str, str] | None = Field(default=None, description="Output directory paths from initialization") reader: InternalReaderConfig downloader: InternalDownloaderConfig regridder: InternalRegridderConfig diff --git a/src/adapt/configuration/schemas/param.py b/src/adapt/configuration/schemas/param.py index 0495bed..bbd597e 100644 --- a/src/adapt/configuration/schemas/param.py +++ b/src/adapt/configuration/schemas/param.py @@ -10,10 +10,11 @@ Runtime code NEVER reads from ParamConfig directly - it only receives InternalConfig. """ -from typing import Literal, Optional +from typing import Literal + from pydantic import Field, field_validator -from adapt.configuration.schemas.base import AdaptBaseModel +from adapt.configuration.schemas.base import AdaptBaseModel # ============================================================================= # Nested Configuration Models @@ -26,13 +27,13 @@ class ReaderConfig(AdaptBaseModel): class DownloaderConfig(AdaptBaseModel): """NEXRAD data downloader configuration.""" - radar: Optional[str] = None - output_dir: Optional[str] = None + radar: str | None = None + output_dir: str | None = None latest_files: int = Field(5, ge=1, description="Number of latest files to keep") latest_minutes: int = Field(60, ge=1, description="Time window in minutes") poll_interval_sec: int = Field(300, ge=1, description="Polling interval in seconds") - start_time: Optional[str] = None - end_time: Optional[str] = None + start_time: str | None = None + end_time: str | None = None min_file_size: int = Field(1024, ge=1, description="Minimum file size in bytes to consider valid") @@ -55,7 +56,7 @@ class SegmenterConfig(AdaptBaseModel): method: Literal["threshold"] = "threshold" threshold: float = Field(30.0, description="Reflectivity threshold in dBZ") min_cellsize_gridpoint: int = Field(5, ge=1) - max_cellsize_gridpoint: Optional[int] = Field(None, ge=1) + max_cellsize_gridpoint: int | None = Field(None, ge=1) closing_kernel: tuple[int, int] = (1, 1) filter_by_size: bool = True h_maxima: float = Field(5.0, gt=0, description="h-maxima height for cell seeding (dBZ)") diff --git a/src/adapt/configuration/schemas/resolve.py b/src/adapt/configuration/schemas/resolve.py index c930ee9..4880ee6 100644 --- a/src/adapt/configuration/schemas/resolve.py +++ b/src/adapt/configuration/schemas/resolve.py @@ -31,11 +31,10 @@ - Result: {"radar_variables": ["C"], "threshold": 40} # List replaced """ -from typing import Union, Optional, Any -from adapt.configuration.schemas.param import ParamConfig -from adapt.configuration.schemas.user import UserConfig from adapt.configuration.schemas.cli import CLIConfig from adapt.configuration.schemas.internal import InternalConfig +from adapt.configuration.schemas.param import ParamConfig +from adapt.configuration.schemas.user import UserConfig def deep_merge(base: dict, *overrides: dict) -> dict: @@ -78,9 +77,9 @@ def deep_merge(base: dict, *overrides: dict) -> dict: def resolve_config( - param_cfg: Union[dict, ParamConfig], - user_cfg: Optional[Union[dict, UserConfig]] = None, - cli_cfg: Optional[Union[dict, CLIConfig]] = None, + param_cfg: dict | ParamConfig, + user_cfg: dict | UserConfig | None = None, + cli_cfg: dict | CLIConfig | None = None, ) -> InternalConfig: """Resolve final runtime configuration from param, user, and CLI configs. diff --git a/src/adapt/configuration/schemas/user.py b/src/adapt/configuration/schemas/user.py index dd42a1e..12f19e7 100644 --- a/src/adapt/configuration/schemas/user.py +++ b/src/adapt/configuration/schemas/user.py @@ -11,20 +11,22 @@ both uppercase and lowercase keys, integers where floats are expected, etc. """ -from typing import Literal, Optional, Any +from typing import Any, Literal + from pydantic import Field, field_validator, model_validator + from adapt.configuration.schemas.base import AdaptBaseModel class UserSegmenterConfig(AdaptBaseModel): """User-facing segmentation config with aliases.""" - method: Optional[str] = None - threshold: Optional[float] = None - min_cellsize_gridpoint: Optional[int] = None - max_cellsize_gridpoint: Optional[int] = None - closing_kernel: Optional[tuple[int, int]] = None - filter_by_size: Optional[bool] = None - h_maxima: Optional[float] = None + method: str | None = None + threshold: float | None = None + min_cellsize_gridpoint: int | None = None + max_cellsize_gridpoint: int | None = None + closing_kernel: tuple[int, int] | None = None + filter_by_size: bool | None = None + h_maxima: float | None = None @field_validator("threshold", mode="before") @classmethod @@ -45,9 +47,9 @@ def normalize_method(cls, v): class UserGlobalConfig(AdaptBaseModel): """User-facing global config.""" - z_level: Optional[float] = None - var_names: Optional[dict[str, str]] = None - coord_names: Optional[dict[str, str]] = None + z_level: float | None = None + var_names: dict[str, str] | None = None + coord_names: dict[str, str] | None = None @field_validator("z_level", mode="before") @classmethod @@ -60,13 +62,13 @@ def coerce_z_level(cls, v): class UserProjectorConfig(AdaptBaseModel): """User-facing projector config.""" - method: Optional[str] = None - max_time_interval_minutes: Optional[int] = None - max_projection_steps: Optional[int] = None - nan_fill_value: Optional[float] = None - flow_params: Optional[dict[str, Any]] = None - min_motion_threshold: Optional[float] = None - max_flow_magnitude: Optional[float] = None + method: str | None = None + max_time_interval_minutes: int | None = None + max_projection_steps: int | None = None + nan_fill_value: float | None = None + flow_params: dict[str, Any] | None = None + min_motion_threshold: float | None = None + max_flow_magnitude: float | None = None @field_validator("method", mode="before") @classmethod @@ -79,29 +81,29 @@ def normalize_method(cls, v): class UserRegridderConfig(AdaptBaseModel): """User-facing regridder config.""" - grid_shape: Optional[tuple[int, int, int]] = None - grid_limits: Optional[tuple[tuple[float, float], tuple[float, float], tuple[float, float]]] = None - roi_func: Optional[str] = None - min_radius: Optional[float] = None - weighting_function: Optional[str] = None - save_netcdf: Optional[bool] = None + grid_shape: tuple[int, int, int] | None = None + grid_limits: tuple[tuple[float, float], tuple[float, float], tuple[float, float]] | None = None + roi_func: str | None = None + min_radius: float | None = None + weighting_function: str | None = None + save_netcdf: bool | None = None class UserDownloaderConfig(AdaptBaseModel): """User-facing downloader config.""" - radar: Optional[str] = None - output_dir: Optional[str] = None - latest_files: Optional[int] = None - latest_minutes: Optional[int] = None - poll_interval_sec: Optional[int] = None - start_time: Optional[str] = None - end_time: Optional[str] = None + radar: str | None = None + output_dir: str | None = None + latest_files: int | None = None + latest_minutes: int | None = None + poll_interval_sec: int | None = None + start_time: str | None = None + end_time: str | None = None class UserAnalyzerConfig(AdaptBaseModel): """User-facing analyzer config.""" - radar_variables: Optional[list[str]] = None - exclude_fields: Optional[list[str]] = None + radar_variables: list[str] | None = None + exclude_fields: list[str] | None = None class UserConfig(AdaptBaseModel): @@ -126,46 +128,46 @@ class UserConfig(AdaptBaseModel): """ # Top-level operational settings - mode: Optional[Literal["realtime", "historical"]] = Field(None, alias="MODE") - radar: Optional[str] = Field(None, alias="RADAR_ID") - base_dir: Optional[str] = Field(None, alias="BASE_DIR") + mode: Literal["realtime", "historical"] | None = Field(None, alias="MODE") + radar: str | None = Field(None, alias="RADAR_ID") + base_dir: str | None = Field(None, alias="BASE_DIR") # Realtime settings - latest_files: Optional[int] = Field(None, alias="LATEST_FILES") - latest_minutes: Optional[int] = Field(None, alias="LATEST_MINUTES") - poll_interval_sec: Optional[int] = Field(None, alias="POLL_INTERVAL_SEC") + latest_files: int | None = Field(None, alias="LATEST_FILES") + latest_minutes: int | None = Field(None, alias="LATEST_MINUTES") + poll_interval_sec: int | None = Field(None, alias="POLL_INTERVAL_SEC") # Historical settings - start_time: Optional[str] = Field(None, alias="START_TIME") - end_time: Optional[str] = Field(None, alias="END_TIME") + start_time: str | None = Field(None, alias="START_TIME") + end_time: str | None = Field(None, alias="END_TIME") # Grid settings (flat aliases) - grid_shape: Optional[tuple[int, int, int]] = Field(None, alias="GRID_SHAPE") - grid_limits: Optional[tuple[tuple[float, float], tuple[float, float], tuple[float, float]]] = Field(None, alias="GRID_LIMITS") + grid_shape: tuple[int, int, int] | None = Field(None, alias="GRID_SHAPE") + grid_limits: tuple[tuple[float, float], tuple[float, float], tuple[float, float]] | None = Field(None, alias="GRID_LIMITS") # Segmentation settings (flat aliases) - z_level: Optional[float] = Field(None, alias="Z_LEVEL") - reflectivity_var: Optional[str] = Field(None, alias="REFLECTIVITY_VAR") - segmentation_method: Optional[str] = Field(None, alias="SEGMENTATION_METHOD") - threshold: Optional[float] = Field(None, alias="THRESHOLD_DBZ") - min_cellsize_gridpoint: Optional[int] = Field(None, alias="MIN_CELLSIZE_GRIDPOINT") - max_cellsize_gridpoint: Optional[int] = Field(None, alias="MAX_CELLSIZE_GRIDPOINT") + z_level: float | None = Field(None, alias="Z_LEVEL") + reflectivity_var: str | None = Field(None, alias="REFLECTIVITY_VAR") + segmentation_method: str | None = Field(None, alias="SEGMENTATION_METHOD") + threshold: float | None = Field(None, alias="THRESHOLD_DBZ") + min_cellsize_gridpoint: int | None = Field(None, alias="MIN_CELLSIZE_GRIDPOINT") + max_cellsize_gridpoint: int | None = Field(None, alias="MAX_CELLSIZE_GRIDPOINT") # Projection settings (flat aliases) - projection_method: Optional[str] = Field(None, alias="PROJECTION_METHOD") - max_projection_steps: Optional[int] = Field(None, alias="MAX_PROJECTION_STEPS") + projection_method: str | None = Field(None, alias="PROJECTION_METHOD") + max_projection_steps: int | None = Field(None, alias="MAX_PROJECTION_STEPS") # Analyzer settings (flat aliases) - radar_variables: Optional[list[str]] = None - exclude_fields: Optional[list[str]] = None + radar_variables: list[str] | None = None + exclude_fields: list[str] | None = None # Nested overrides (advanced users) - downloader: Optional[UserDownloaderConfig] = None - regridder: Optional[UserRegridderConfig] = None - segmenter: Optional[UserSegmenterConfig] = None - global_: Optional[UserGlobalConfig] = Field(None, alias="global") - projector: Optional[UserProjectorConfig] = None - analyzer: Optional[UserAnalyzerConfig] = None + downloader: UserDownloaderConfig | None = None + regridder: UserRegridderConfig | None = None + segmenter: UserSegmenterConfig | None = None + global_: UserGlobalConfig | None = Field(None, alias="global") + projector: UserProjectorConfig | None = None + analyzer: UserAnalyzerConfig | None = None model_config = AdaptBaseModel.model_config.copy() # Allow forgiving input dictionaries (ignore unknown legacy keys) @@ -180,10 +182,7 @@ def infer_historical_mode_from_times(self): """ if self.mode is None: # Check top-level times - if self.start_time and self.end_time: - self.mode = "historical" - # Check nested downloader times - elif self.downloader and (self.downloader.start_time and self.downloader.end_time): + if self.start_time and self.end_time or self.downloader and (self.downloader.start_time and self.downloader.end_time): self.mode = "historical" return self diff --git a/src/adapt/execution/graph/__init__.py b/src/adapt/execution/graph/__init__.py index 2663529..16eae03 100644 --- a/src/adapt/execution/graph/__init__.py +++ b/src/adapt/execution/graph/__init__.py @@ -3,8 +3,8 @@ """Execution graph: build and run DAGs from module declarations.""" -from adapt.execution.graph.node import Node from adapt.execution.graph.builder import GraphBuilder from adapt.execution.graph.executor import GraphExecutor +from adapt.execution.graph.node import Node __all__ = ['Node', 'GraphBuilder', 'GraphExecutor'] diff --git a/src/adapt/execution/graph/builder.py b/src/adapt/execution/graph/builder.py index 909e687..7206908 100644 --- a/src/adapt/execution/graph/builder.py +++ b/src/adapt/execution/graph/builder.py @@ -8,7 +8,7 @@ nodes that produce its required inputs. """ -from typing import Dict, List, TYPE_CHECKING +from typing import TYPE_CHECKING from adapt.execution.graph.node import Node @@ -36,10 +36,10 @@ class GraphBuilder: nodes = builder.build() """ - def __init__(self, modules: List["BaseModule"]) -> None: + def __init__(self, modules: list["BaseModule"]) -> None: self.modules = modules - def build(self) -> List[Node]: + def build(self) -> list[Node]: """Build and return the list of connected nodes. Returns @@ -49,10 +49,10 @@ def build(self) -> List[Node]: Nodes are returned in insertion order; execution order is determined by the GraphExecutor. """ - nodes: Dict[str, Node] = {m.name: Node(m) for m in self.modules} + nodes: dict[str, Node] = {m.name: Node(m) for m in self.modules} # Map each output key → the node that produces it - output_map: Dict[str, Node] = {} + output_map: dict[str, Node] = {} for node in nodes.values(): for output in node.outputs: if output in output_map: diff --git a/src/adapt/execution/graph/executor.py b/src/adapt/execution/graph/executor.py index 40dd610..1e48608 100644 --- a/src/adapt/execution/graph/executor.py +++ b/src/adapt/execution/graph/executor.py @@ -16,7 +16,6 @@ """ import logging -from typing import List, Set from adapt.execution.graph.node import Node @@ -38,7 +37,7 @@ class GraphExecutor: result_context = executor.run(initial_context={}) """ - def __init__(self, nodes: List[Node]) -> None: + def __init__(self, nodes: list[Node]) -> None: self.nodes = nodes def run(self, context: dict) -> dict: @@ -62,7 +61,7 @@ def run(self, context: dict) -> dict: If the graph contains a cycle (nodes that can never be ready). """ context = dict(context) # shallow copy — don't mutate caller's dict - completed: Set[str] = set() + completed: set[str] = set() max_iterations = len(self.nodes) ** 2 + len(self.nodes) + 1 iteration = 0 diff --git a/src/adapt/execution/graph/node.py b/src/adapt/execution/graph/node.py index 8f599b4..3e6a5c2 100644 --- a/src/adapt/execution/graph/node.py +++ b/src/adapt/execution/graph/node.py @@ -8,7 +8,7 @@ on it (notified when this node completes). """ -from typing import List, TYPE_CHECKING +from typing import TYPE_CHECKING if TYPE_CHECKING: from adapt.modules.base import BaseModule @@ -37,10 +37,10 @@ class Node: def __init__(self, module: "BaseModule") -> None: self.module = module - self.inputs: List[str] = list(module.inputs) - self.outputs: List[str] = list(module.outputs) - self.dependencies: List["Node"] = [] - self.dependents: List["Node"] = [] + self.inputs: list[str] = list(module.inputs) + self.outputs: list[str] = list(module.outputs) + self.dependencies: list[Node] = [] + self.dependents: list[Node] = [] @property def name(self) -> str: diff --git a/src/adapt/execution/module_registry.py b/src/adapt/execution/module_registry.py index 50c91d8..8568262 100644 --- a/src/adapt/execution/module_registry.py +++ b/src/adapt/execution/module_registry.py @@ -26,7 +26,7 @@ class DetectModule(BaseModule): GraphExecutor(nodes).run(context) """ -from typing import Dict, List, Type, TYPE_CHECKING +from typing import TYPE_CHECKING if TYPE_CHECKING: from adapt.modules.base import BaseModule @@ -49,9 +49,9 @@ class ModuleRegistry: """ def __init__(self) -> None: - self._modules: Dict[str, Type["BaseModule"]] = {} + self._modules: dict[str, type[BaseModule]] = {} - def register(self, module_class: Type["BaseModule"]) -> None: + def register(self, module_class: type["BaseModule"]) -> None: """Register a module class by its ``name`` attribute. Parameters @@ -79,7 +79,7 @@ def register(self, module_class: Type["BaseModule"]) -> None: ) self._modules[name] = module_class - def create_modules(self) -> List["BaseModule"]: + def create_modules(self) -> list["BaseModule"]: """Instantiate and return all registered modules. Returns @@ -89,7 +89,7 @@ def create_modules(self) -> List["BaseModule"]: """ return [cls() for cls in self._modules.values()] - def get(self, name: str) -> Type["BaseModule"]: + def get(self, name: str) -> type["BaseModule"]: """Return the module class registered under ``name``. Raises @@ -101,7 +101,7 @@ def get(self, name: str) -> Type["BaseModule"]: raise KeyError(f"Module '{name}' is not registered.") return self._modules[name] - def list_modules(self) -> List[str]: + def list_modules(self) -> list[str]: """Return names of all registered modules.""" return list(self._modules.keys()) diff --git a/src/adapt/execution/pipeline_builder.py b/src/adapt/execution/pipeline_builder.py index 74402d3..53b6838 100644 --- a/src/adapt/execution/pipeline_builder.py +++ b/src/adapt/execution/pipeline_builder.py @@ -20,13 +20,13 @@ import importlib import logging from pathlib import Path -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import yaml -from adapt.execution.module_registry import registry from adapt.execution.graph.builder import GraphBuilder from adapt.execution.graph.executor import GraphExecutor +from adapt.execution.module_registry import registry if TYPE_CHECKING: from adapt.configuration.schemas import InternalConfig @@ -91,7 +91,7 @@ class NexradPipeline: def __init__( self, config: "InternalConfig", - output_dirs: Optional[dict] = None, + output_dirs: dict | None = None, ) -> None: self.config = config self.output_dirs = output_dirs or {} diff --git a/src/adapt/gui/__init__.py b/src/adapt/gui/__init__.py index b7f2b52..b51955b 100644 --- a/src/adapt/gui/__init__.py +++ b/src/adapt/gui/__init__.py @@ -23,6 +23,6 @@ - Hover cell info display """ -from adapt.gui.dashboard import main, AdaptDashboard +from adapt.gui.dashboard import AdaptDashboard, main __all__ = ['main', 'AdaptDashboard'] diff --git a/src/adapt/gui/dashboard.py b/src/adapt/gui/dashboard.py index 7e21026..81ee343 100644 --- a/src/adapt/gui/dashboard.py +++ b/src/adapt/gui/dashboard.py @@ -64,7 +64,7 @@ def _suppress_osx_stderr(): # ── Tkinter ─────────────────────────────────────────────────────────────────── import tkinter as tk -from tkinter import ttk, filedialog, scrolledtext, messagebox +from tkinter import filedialog, messagebox, scrolledtext, ttk # ── Optional deps ───────────────────────────────────────────────────────────── try: @@ -77,8 +77,8 @@ def _suppress_osx_stderr(): import matplotlib matplotlib.use('TkAgg') import cmweather.cm # registers ChaseSpectral and other radar colormaps — must follow use() - import matplotlib.pyplot as plt import matplotlib.dates as mdates + import matplotlib.pyplot as plt from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk HAS_MPL = True except ImportError: @@ -1065,8 +1065,9 @@ def _load_cells_data(self, repo, radar): db_path = Path(repo) / radar / "catalog.db" if db_path.exists(): try: - from adapt.persistence.track_store import TrackStore import sqlite3 + + from adapt.persistence.track_store import TrackStore conn = sqlite3.connect(str(db_path)) conn.row_factory = sqlite3.Row run_row = conn.execute( diff --git a/src/adapt/modules/acquisition/module.py b/src/adapt/modules/acquisition/module.py index b4289de..b9b5fa4 100644 --- a/src/adapt/modules/acquisition/module.py +++ b/src/adapt/modules/acquisition/module.py @@ -7,12 +7,12 @@ in realtime or historical batches. Deduplicates files to avoid re-downloading. """ +import logging import threading import time -import logging -from datetime import datetime, timezone, timedelta +from datetime import UTC, datetime, timedelta from pathlib import Path -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from nexradaws import NexradAwsInterface @@ -140,7 +140,7 @@ def __init__( self.result_queue = result_queue self.conn = conn or NexradAwsInterface() # injectable time helpers for testing - self._clock = clock or (lambda: datetime.now(timezone.utc)) + self._clock = clock or (lambda: datetime.now(UTC)) self._sleep = sleeper or time.sleep self._stop_event = threading.Event() @@ -284,7 +284,7 @@ def run(self): while not self.stopped(): try: self._download_task() - except Exception as e: + except Exception: logger.exception("Download task failed") # Historical: exit after completion @@ -426,7 +426,7 @@ def _check_radar_available(self, start: datetime, end: datetime) -> None: all_checks_failed = True # Track if all checks failed (don't warn in that case) while current <= end_date: - dt = datetime(current.year, current.month, current.day, tzinfo=timezone.utc) + dt = datetime(current.year, current.month, current.day, tzinfo=UTC) y = dt.strftime("%Y") m = dt.strftime("%m") d = dt.strftime("%d") diff --git a/src/adapt/modules/analysis/contracts.py b/src/adapt/modules/analysis/contracts.py index 3afcc51..959236b 100644 --- a/src/adapt/modules/analysis/contracts.py +++ b/src/adapt/modules/analysis/contracts.py @@ -11,7 +11,7 @@ """ import pandas as pd -import numpy as np + from adapt.modules.base import require diff --git a/src/adapt/modules/analysis/module.py b/src/adapt/modules/analysis/module.py index d9bf26e..a48e694 100644 --- a/src/adapt/modules/analysis/module.py +++ b/src/adapt/modules/analysis/module.py @@ -22,12 +22,14 @@ Author: Bhupendra Raut """ -import logging import json +import logging +from datetime import UTC +from typing import TYPE_CHECKING + import numpy as np import pandas as pd import xarray as xr -from typing import TYPE_CHECKING from scipy.ndimage import center_of_mass from skimage.measure import regionprops @@ -364,7 +366,7 @@ def _normalize_time_scalar(time_val): except Exception: pass if getattr(type(tv), "__module__", "").startswith("cftime"): - from datetime import datetime, timezone + from datetime import datetime tv = datetime( int(tv.year), int(tv.month), @@ -373,7 +375,7 @@ def _normalize_time_scalar(time_val): int(tv.minute), int(tv.second), int(getattr(tv, "microsecond", 0) or 0), - tzinfo=timezone.utc, + tzinfo=UTC, ) return tv @@ -383,9 +385,7 @@ def _get_lat_lon_grids(self, ds): Returns lat/lon grids if available, otherwise returns placeholder grids of zeros (valid for in-memory analysis, invalid for geographic output). """ - if "lat" in ds.coords and "lon" in ds.coords: - return ds["lat"].values, ds["lon"].values - elif "lat" in ds.data_vars and "lon" in ds.data_vars: + if "lat" in ds.coords and "lon" in ds.coords or "lat" in ds.data_vars and "lon" in ds.data_vars: return ds["lat"].values, ds["lon"].values else: # No lat/lon available - use placeholder zeros @@ -638,8 +638,9 @@ def get_lat_lon(ix, iy, lat_grid, lon_grid): # BaseModule wrapper — Step 6 # --------------------------------------------------------------------------- -from adapt.modules.base import BaseModule from adapt.execution.module_registry import registry +from adapt.modules.base import BaseModule + from .contracts import assert_analysis_output, assert_cell_adjacency diff --git a/src/adapt/modules/base.py b/src/adapt/modules/base.py index 26a968b..5f64ea1 100644 --- a/src/adapt/modules/base.py +++ b/src/adapt/modules/base.py @@ -13,8 +13,7 @@ """ from abc import ABC, abstractmethod -from typing import ClassVar, Dict, List, Optional - +from typing import ClassVar # ──────────────────────────────────────────────────────────────────────────── # Contract Enforcement Infrastructure @@ -101,10 +100,10 @@ def run(self, context): """ name: ClassVar[str] = "" - inputs: ClassVar[List[str]] = [] - outputs: ClassVar[List[str]] = [] - input_contracts: ClassVar[Dict[str, object]] = {} - output_contracts: ClassVar[Dict[str, object]] = {} + inputs: ClassVar[list[str]] = [] + outputs: ClassVar[list[str]] = [] + input_contracts: ClassVar[dict[str, object]] = {} + output_contracts: ClassVar[dict[str, object]] = {} @abstractmethod def run(self, context: dict) -> dict: diff --git a/src/adapt/modules/detection/contracts.py b/src/adapt/modules/detection/contracts.py index a896e0d..5a87aed 100644 --- a/src/adapt/modules/detection/contracts.py +++ b/src/adapt/modules/detection/contracts.py @@ -10,8 +10,9 @@ properly typed, and in canonical form (largest cells first). """ -import xarray as xr import numpy as np +import xarray as xr + from adapt.modules.base import require diff --git a/src/adapt/modules/detection/module.py b/src/adapt/modules/detection/module.py index fdb6800..6d1131c 100644 --- a/src/adapt/modules/detection/module.py +++ b/src/adapt/modules/detection/module.py @@ -22,11 +22,12 @@ - Metadata preservation (threshold, z-level, configuration) """ -import xarray as xr -import numpy as np import logging from typing import TYPE_CHECKING -from scipy.ndimage import binary_closing, label + +import numpy as np +import xarray as xr +from scipy.ndimage import label from skimage.morphology import h_maxima from skimage.segmentation import watershed @@ -361,9 +362,10 @@ def _relabel_by_size(self, labels: np.ndarray, labels_to_keep: np.ndarray, count # BaseModule wrapper — Step 6 # --------------------------------------------------------------------------- -from adapt.modules.base import BaseModule from adapt.execution.module_registry import registry +from adapt.modules.base import BaseModule from adapt.modules.ingest.contracts import assert_gridded + from .contracts import assert_segmented diff --git a/src/adapt/modules/ingest/contracts.py b/src/adapt/modules/ingest/contracts.py index 3297118..666b9ec 100644 --- a/src/adapt/modules/ingest/contracts.py +++ b/src/adapt/modules/ingest/contracts.py @@ -11,6 +11,7 @@ """ import xarray as xr + from adapt.modules.base import require diff --git a/src/adapt/modules/ingest/module.py b/src/adapt/modules/ingest/module.py index 5e46039..d030a45 100644 --- a/src/adapt/modules/ingest/module.py +++ b/src/adapt/modules/ingest/module.py @@ -17,11 +17,11 @@ Author: Bhupendra Raut """ -from pathlib import Path -from typing import Optional, TYPE_CHECKING import logging -import pyart +from pathlib import Path +from typing import TYPE_CHECKING +import pyart import xarray as xr if TYPE_CHECKING: @@ -161,7 +161,7 @@ def read(self, filepath: Path | str) -> object: return radar def regrid(self, radar: object, grid_kwargs: dict = None, - output_dir: str = None, source_filepath: str = None) -> Optional[xr.Dataset]: + output_dir: str = None, source_filepath: str = None) -> xr.Dataset | None: """Transform a Py-ART Radar object from polar to Cartesian grid. Performs distance-weighted interpolation to convert irregular polar @@ -279,7 +279,7 @@ def _write_netcdf(self, ds, output_dir, source_filepath): def load_and_regrid(self, filepath: Path | str, grid_kwargs: dict = None, - save_netcdf: bool = True, output_dir: str = None) -> Optional[xr.Dataset]: + save_netcdf: bool = True, output_dir: str = None) -> xr.Dataset | None: """Read and regrid a NEXRAD file in one call (convenience method). Combines read() and regrid() operations for simpler usage when @@ -349,13 +349,17 @@ def load_and_regrid(self, filepath: Path | str, grid_kwargs: dict = None, # BaseModule wrapper — Step 6 # --------------------------------------------------------------------------- +from datetime import UTC +from datetime import datetime as _dt + import numpy as np import xarray as _xr -from datetime import datetime as _dt, timezone as _tz -from adapt.modules.base import BaseModule + +from adapt.configuration.schemas.directories import get_netcdf_path from adapt.execution.module_registry import registry +from adapt.modules.base import BaseModule + from .contracts import assert_gridded -from adapt.configuration.schemas.directories import get_netcdf_path def _check_grid_ds_2d(ds): @@ -405,7 +409,7 @@ def run(self, context: dict) -> dict: radar = config.downloader.radar nc_filename = Path(filepath).stem - scan_time = _dt.now(_tz.utc) + scan_time = _dt.now(UTC) try: parts = nc_filename.split("_") dt_str = parts[0][-8:] + parts[1] diff --git a/src/adapt/modules/projection/contracts.py b/src/adapt/modules/projection/contracts.py index 25d81ac..804bbb5 100644 --- a/src/adapt/modules/projection/contracts.py +++ b/src/adapt/modules/projection/contracts.py @@ -11,6 +11,7 @@ """ import xarray as xr + from adapt.modules.base import require diff --git a/src/adapt/modules/projection/module.py b/src/adapt/modules/projection/module.py index 9fb4d6e..c7a8551 100644 --- a/src/adapt/modules/projection/module.py +++ b/src/adapt/modules/projection/module.py @@ -21,12 +21,13 @@ """ import logging +from typing import TYPE_CHECKING + +import cv2 import numpy as np import xarray as xr -import cv2 -from typing import TYPE_CHECKING -from scipy.spatial import Delaunay from scipy.ndimage import binary_dilation +from scipy.spatial import Delaunay if TYPE_CHECKING: from adapt.configuration.schemas import InternalConfig @@ -561,8 +562,8 @@ def _fill_concave_hull(self, label_mask, alpha=0.1): # BaseModule wrapper — Step 6 # --------------------------------------------------------------------------- -from adapt.modules.base import BaseModule from adapt.execution.module_registry import registry +from adapt.modules.base import BaseModule from adapt.modules.detection.contracts import assert_segmented diff --git a/src/adapt/modules/tracking/contracts.py b/src/adapt/modules/tracking/contracts.py index ffad07d..5aba26e 100644 --- a/src/adapt/modules/tracking/contracts.py +++ b/src/adapt/modules/tracking/contracts.py @@ -14,6 +14,7 @@ from __future__ import annotations import pandas as pd + from adapt.modules.base import require diff --git a/src/adapt/modules/tracking/module.py b/src/adapt/modules/tracking/module.py index a259d0d..9280d10 100644 --- a/src/adapt/modules/tracking/module.py +++ b/src/adapt/modules/tracking/module.py @@ -29,14 +29,16 @@ References: Raut, B. A., Jackson, R., Picel, M., Collis, S. M., Bergemann, M., & Jakob, C. (2021). An adaptive tracking algorithm for convection in simulated and remote sensing data. Journal of Applied Meteorology and Climatology, 60(4), 513-526. """ -import logging import hashlib +import logging import string +from datetime import UTC +from typing import TYPE_CHECKING + +import networkx as nx import numpy as np import pandas as pd import xarray as xr -import networkx as nx -from typing import TYPE_CHECKING, Dict, List, Tuple, Optional from scipy.optimize import linear_sum_assignment if TYPE_CHECKING: @@ -205,7 +207,7 @@ def get_node_attr(self, node_id: int, attr: str): """ return self.graph.nodes[node_id].get(attr) - def get_nodes_at_time(self, time) -> List[int]: + def get_nodes_at_time(self, time) -> list[int]: """Get all node IDs for a given timestamp. Parameters @@ -220,14 +222,14 @@ def get_nodes_at_time(self, time) -> List[int]: """ return [n for n, d in self.graph.nodes(data=True) if d.get('time') == time] - def get_track_nodes(self, track_index: int) -> List[int]: + def get_track_nodes(self, track_index: int) -> list[int]: """Get all nodes belonging to a track, sorted by time.""" nodes = [(n, d['time']) for n, d in self.graph.nodes(data=True) if d.get('track_index') == track_index] nodes.sort(key=lambda x: x[1]) return [n for n, _ in nodes] - def get_predecessors(self, node_id: int) -> List[Tuple[int, str]]: + def get_predecessors(self, node_id: int) -> list[tuple[int, str]]: """Get predecessor nodes with their edge types. Parameters @@ -243,7 +245,7 @@ def get_predecessors(self, node_id: int) -> List[Tuple[int, str]]: return [(pred, self.graph.edges[pred, node_id]['edge_type']) for pred in self.graph.predecessors(node_id)] - def get_successors(self, node_id: int) -> List[Tuple[int, str]]: + def get_successors(self, node_id: int) -> list[tuple[int, str]]: """Get successor nodes with their edge types. Parameters @@ -272,10 +274,10 @@ def __init__(self, config: "InternalConfig"): def compute_cost_matrix( self, - prev_node_ids: List[int], + prev_node_ids: list[int], graph: "TrackingGraph", proj_labels: np.ndarray, - curr_cells: List[Dict], + curr_cells: list[dict], dummy_cost: float, ) -> np.ndarray: """Build (n_prev × n_curr) cost matrix. @@ -305,7 +307,7 @@ def _compute_cost( prev_node: int, graph: "TrackingGraph", proj_mask: np.ndarray, - curr_cell: Dict, + curr_cell: dict, ) -> float: """5-term cost: 0.4*Dpos + 0.3*(1-IoU) + 0.15*|log(A2/A1)| + 0.1*|Z2-Z1|/50""" prev_cx = graph.get_node_attr(prev_node, 'centroid_x') @@ -366,7 +368,7 @@ def __init__(self, config: "InternalConfig"): self.graph = TrackingGraph() self.matcher = MatchingEngine(config) - self._previous_scan: Optional[Tuple] = None # (time, ds, node_ids) + self._previous_scan: tuple | None = None # (time, ds, node_ids) self._cell_identity: dict[int, tuple[str, str]] = {} logger.info( @@ -382,7 +384,7 @@ def track( self, ds_projected: xr.Dataset, cell_stats_df: pd.DataFrame, - ) -> Tuple[pd.DataFrame, pd.DataFrame]: + ) -> tuple[pd.DataFrame, pd.DataFrame]: """Process one scan. Returns scan-local outputs: @@ -454,7 +456,7 @@ def _normalize_time_scalar(time_val): # Handle cftime.* objects (pandas cannot convert them directly) if getattr(type(tv), "__module__", "").startswith("cftime"): - from datetime import datetime, timezone + from datetime import datetime tv = datetime( int(tv.year), @@ -464,7 +466,7 @@ def _normalize_time_scalar(time_val): int(tv.minute), int(tv.second), int(getattr(tv, "microsecond", 0) or 0), - tzinfo=timezone.utc, + tzinfo=UTC, ) return tv @@ -477,11 +479,11 @@ def _time_key(time_val) -> str: def _extract_cells_from_analyzer( self, ds: xr.Dataset, cell_stats_df: pd.DataFrame - ) -> List[Dict]: + ) -> list[dict]: """Merge per-cell stats (from AnalysisModule) with segmentation masks.""" labels = ds[self.labels_var].values - cell_props_map: Dict[int, Dict] = {} + cell_props_map: dict[int, dict] = {} for _, row in cell_stats_df.iterrows(): lbl = int(row['cell_label']) cell_props_map[lbl] = { @@ -502,7 +504,7 @@ def _extract_cells_from_analyzer( dy = float(np.abs(ds.y[1] - ds.y[0])) pixel_area_km2 = (dx * dy) / 1e6 - cells: List[Dict] = [] + cells: list[dict] = [] for cell_id in np.unique(labels): if cell_id == 0: continue @@ -529,7 +531,7 @@ def _extract_cells_from_analyzer( }) return cells - def _new_cell_identity(self, cell: Dict) -> tuple[str, str]: + def _new_cell_identity(self, cell: dict) -> tuple[str, str]: cfg = self.config.tracker.cell_uid max_zdr = float(cell['max_zdr']) if max_zdr < 0: @@ -552,7 +554,7 @@ def _new_cell_identity(self, cell: Dict) -> tuple[str, str]: # Track initialisation helpers # ------------------------------------------------------------------ - def _initialize_tracks(self, time, cells: List[Dict]) -> List[int]: + def _initialize_tracks(self, time, cells: list[dict]) -> list[int]: node_ids = [] for cell in cells: track_index = self.graph.get_new_track_index() @@ -565,10 +567,10 @@ def _initialize_tracks(self, time, cells: List[Dict]) -> List[int]: def _add_cell_node( self, time, - cell: Dict, + cell: dict, track_index: int, - cell_uid: Optional[str] = None, - track_signature: Optional[str] = None, + cell_uid: str | None = None, + track_signature: str | None = None, ) -> int: if cell_uid is None or track_signature is None: cell_uid, track_signature = self.get_cell_identity(track_index) @@ -594,10 +596,10 @@ def _track_frame_pair( self, prev_time, ds_prev: xr.Dataset, - prev_node_ids: List[int], + prev_node_ids: list[int], curr_time, ds_curr: xr.Dataset, - curr_cells: List[Dict], + curr_cells: list[dict], ) -> list[dict]: events: list[dict] = [] if "cell_projections" not in ds_curr.data_vars: @@ -649,8 +651,8 @@ def _track_frame_pair( row_ind, col_ind = linear_sum_assignment(square) # ── Step 5: post-filter → CONTINUE / dissipated / born ─────────── - matched_prev: Dict[int, int] = {} # prev_idx → new curr node_id - matched_curr: Dict[int, int] = {} # curr_idx → new curr node_id + matched_prev: dict[int, int] = {} # prev_idx → new curr node_id + matched_curr: dict[int, int] = {} # curr_idx → new curr node_id n_continue = 0 for r, c in zip(row_ind, col_ind): @@ -761,7 +763,7 @@ def _track_frame_pair( # Scan-local builders (no per-track analytics) # ------------------------------------------------------------------ - def _build_tracked_cells_current(self, time, node_ids: List[int]) -> pd.DataFrame: + def _build_tracked_cells_current(self, time, node_ids: list[int]) -> pd.DataFrame: rows: list[dict] = [] for node_id in node_ids: node = self.graph.graph.nodes[node_id] @@ -874,7 +876,7 @@ def _event_initiation(self, time, node_id: int) -> dict: "event_group_id": f"{self._time_key(time)}:INITIATION:{target_uid}", } - def _event_termination(self, time, source_node_id: int, target_node_id: Optional[int]) -> dict: + def _event_termination(self, time, source_node_id: int, target_node_id: int | None) -> dict: source_path = int(self.graph.get_node_attr(source_node_id, "track_index")) target_path = int(self.graph.get_node_attr(target_node_id, "track_index")) if target_node_id is not None else None source_uid = self.get_cell_identity(source_path)[0] @@ -895,10 +897,11 @@ def _event_termination(self, time, source_node_id: int, target_node_id: Optional # BaseModule wrapper (Phase 6 implementation placeholder) # ============================================================================= -from adapt.modules.base import BaseModule from adapt.execution.module_registry import registry +from adapt.modules.base import BaseModule from adapt.modules.projection.contracts import assert_projected -from .contracts import assert_tracked_cells, assert_cell_events + +from .contracts import assert_cell_events, assert_tracked_cells def _check_projected_ds(ds: xr.Dataset) -> None: diff --git a/src/adapt/persistence/catalog.py b/src/adapt/persistence/catalog.py index e909435..9f483ae 100644 --- a/src/adapt/persistence/catalog.py +++ b/src/adapt/persistence/catalog.py @@ -19,9 +19,8 @@ import logging import sqlite3 import threading -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union import pandas as pd @@ -51,7 +50,7 @@ class RadarCatalog: >>> items = catalog.query_items(item_type="analysis2d", limit=10) """ - def __init__(self, radar_dir: Union[str, Path]): + def __init__(self, radar_dir: str | Path): """Initialize radar catalog. Parameters @@ -65,7 +64,7 @@ def __init__(self, radar_dir: Union[str, Path]): # Thread safety self._lock = threading.RLock() - self._conn: Optional[sqlite3.Connection] = None + self._conn: sqlite3.Connection | None = None # Initialize database self._init_database() @@ -177,10 +176,10 @@ def register_item( file_path: str, processing_stage: str = "complete", status: str = "complete", - parent_ids: Optional[List[str]] = None, - metadata: Optional[Dict] = None, - file_size_bytes: Optional[int] = None, - file_hash: Optional[str] = None + parent_ids: list[str] | None = None, + metadata: dict | None = None, + file_size_bytes: int | None = None, + file_hash: str | None = None ) -> None: """Register a data item in the catalog. @@ -209,7 +208,7 @@ def register_item( file_hash : str, optional File hash (SHA256) """ - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(UTC).isoformat() parent_ids_json = json.dumps(parent_ids) if parent_ids else None metadata_json = json.dumps(metadata) if metadata else None @@ -232,7 +231,7 @@ def update_item_status( self, item_id: str, status: str, - error_message: Optional[str] = None + error_message: str | None = None ) -> None: """Update item status. @@ -245,7 +244,7 @@ def update_item_status( error_message : str, optional Error message if status=failed """ - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(UTC).isoformat() conn = self._get_connection() with self._lock: @@ -258,10 +257,10 @@ def update_item_status( def query_items( self, - item_type: Optional[str] = None, - run_id: Optional[str] = None, - status: Optional[str] = None, - limit: Optional[int] = None, + item_type: str | None = None, + run_id: str | None = None, + status: str | None = None, + limit: int | None = None, order_by: str = "scan_time DESC" ) -> pd.DataFrame: """Query items with optional filters. @@ -309,8 +308,8 @@ def query_items( def get_latest_item( self, item_type: str, - run_id: Optional[str] = None - ) -> Optional[Dict]: + run_id: str | None = None + ) -> dict | None: """Get the most recent item of a type. Parameters @@ -344,7 +343,7 @@ def get_latest_item( return dict(row) if row else None - def get_item(self, item_id: str) -> Optional[Dict]: + def get_item(self, item_id: str) -> dict | None: """Get a single item record by ID. Returns None if not found.""" conn = self._get_connection() with self._lock: @@ -371,7 +370,7 @@ def update_progress( **kwargs Progress fields to update (latest_downloaded_time, etc.) """ - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(UTC).isoformat() # Build update query dynamically fields = list(kwargs.keys()) @@ -405,7 +404,7 @@ def update_progress( conn.commit() - def get_progress(self, run_id: str) -> Optional[Dict]: + def get_progress(self, run_id: str) -> dict | None: """Get progress status for a run. Parameters @@ -434,7 +433,7 @@ def get_progress(self, run_id: str) -> Optional[Dict]: def register_schema( self, item_type: str, - columns: List[Dict[str, str]], + columns: list[dict[str, str]], schema_version: int = 1 ) -> None: """Register or update schema for an item type. @@ -448,7 +447,7 @@ def register_schema( schema_version : int Schema version number """ - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(UTC).isoformat() columns_json = json.dumps(columns) conn = self._get_connection() @@ -462,7 +461,7 @@ def register_schema( logger.debug(f"Schema registered for {item_type} (v{schema_version})") - def get_schema(self, item_type: str) -> Optional[List[Dict]]: + def get_schema(self, item_type: str) -> list[dict] | None: """Get schema for an item type. Parameters @@ -494,7 +493,7 @@ def register_scan( self, scan_time: datetime, run_id: str, - nexrad_file_name: Optional[str] = None + nexrad_file_name: str | None = None ) -> str: """Register a new scan. Idempotent on scan_time+run_id. @@ -516,7 +515,7 @@ def register_scan( scan_time_str = scan_time.isoformat() scan_date = scan_time.strftime('%Y%m%d') - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(UTC).isoformat() conn = self._get_connection() with self._lock: @@ -547,9 +546,9 @@ def link_item_to_scan( scan_time: datetime, item_type: str, item_id: str, - num_cells: Optional[int] = None, - max_reflectivity: Optional[float] = None, - has_tracks: Optional[bool] = None + num_cells: int | None = None, + max_reflectivity: float | None = None, + has_tracks: bool | None = None ) -> None: """Link an item to its parent scan. @@ -569,7 +568,7 @@ def link_item_to_scan( Whether tracks exist for this scan """ scan_time_str = scan_time.isoformat() - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(UTC).isoformat() # Map item_type to column name column_map = { @@ -631,7 +630,7 @@ def link_item_to_scan( logger.debug(f"Item {item_id} linked to scan at {scan_time_str}") - def get_scan(self, scan_time: datetime) -> Optional[Dict]: + def get_scan(self, scan_time: datetime) -> dict | None: """Get scan record by time. Parameters @@ -655,7 +654,7 @@ def get_scan(self, scan_time: datetime) -> Optional[Dict]: return dict(row) if row else None - def get_scan_by_id(self, scan_id: str) -> Optional[Dict]: + def get_scan_by_id(self, scan_id: str) -> dict | None: """Get scan by ID. Parameters @@ -679,10 +678,10 @@ def get_scan_by_id(self, scan_id: str) -> Optional[Dict]: def list_scans( self, - start_time: Optional[datetime] = None, - end_time: Optional[datetime] = None, - run_id: Optional[str] = None, - status: Optional[str] = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + run_id: str | None = None, + status: str | None = None, limit: int = 100 ) -> pd.DataFrame: """List scans with optional time range filter. @@ -728,7 +727,7 @@ def list_scans( with self._lock: return pd.read_sql_query(query, conn, params=params) - def get_latest_scan(self, run_id: Optional[str] = None) -> Optional[Dict]: + def get_latest_scan(self, run_id: str | None = None) -> dict | None: """Get the most recent scan. Parameters diff --git a/src/adapt/persistence/registry.py b/src/adapt/persistence/registry.py index c1f9218..90cf1d8 100644 --- a/src/adapt/persistence/registry.py +++ b/src/adapt/persistence/registry.py @@ -15,13 +15,11 @@ Thread-safe for concurrent writer/reader access via SQLite WAL mode. """ -import json import logging import sqlite3 import threading -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union import pandas as pd @@ -30,7 +28,7 @@ logger = logging.getLogger(__name__) # Cache of registry instances per root directory -_registry_cache: Dict[str, 'RepositoryRegistry'] = {} +_registry_cache: dict[str, 'RepositoryRegistry'] = {} _cache_lock = threading.Lock() @@ -50,7 +48,7 @@ class RepositoryRegistry: >>> runs = registry.list_runs() """ - def __init__(self, root_dir: Union[str, Path]): + def __init__(self, root_dir: str | Path): """Initialize registry at root directory. Parameters @@ -63,7 +61,7 @@ def __init__(self, root_dir: Union[str, Path]): # Thread safety self._lock = threading.RLock() - self._conn: Optional[sqlite3.Connection] = None + self._conn: sqlite3.Connection | None = None # Initialize database self._init_database() @@ -71,7 +69,7 @@ def __init__(self, root_dir: Union[str, Path]): logger.debug("RepositoryRegistry initialized at %s", self.db_path) @classmethod - def get_instance(cls, root_dir: Union[str, Path]) -> 'RepositoryRegistry': + def get_instance(cls, root_dir: str | Path) -> 'RepositoryRegistry': """Get singleton instance for a root directory. Parameters @@ -175,7 +173,7 @@ def _create_schema_inline(self) -> None: """) # Prepopulate item types - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(UTC).isoformat() item_types_data = [ ('gridded3d', 'Gridded reflectivity volume', 'netcdf', '3d', now), ('segmentation2d', 'Cell segmentation masks', 'netcdf', '2d', now), @@ -198,8 +196,8 @@ def _create_schema_inline(self) -> None: def register_radar( self, radar: str, - lat: Optional[float] = None, - lon: Optional[float] = None + lat: float | None = None, + lon: float | None = None ) -> None: """Register a radar in the repository. @@ -217,7 +215,7 @@ def register_radar( catalog_path = str(radar_dir / "catalog.db") data_path = str(radar_dir) - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(UTC).isoformat() conn = self._get_connection() with self._lock: @@ -230,7 +228,7 @@ def register_radar( logger.debug("Radar registered: %s at %s", radar, data_path) - def get_radar_location(self, radar: str) -> tuple[Optional[float], Optional[float]]: + def get_radar_location(self, radar: str) -> tuple[float | None, float | None]: """Get stored radar location (lat, lon) from the registry.""" conn = self._get_connection() with self._lock: @@ -259,7 +257,7 @@ def ensure_radar_location(self, radar: str, lat: float, lon: float) -> None: raise ValueError(f"Invalid lat/lon types: {type(lat)} {type(lon)}") from e conn = self._get_connection() - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(UTC).isoformat() with self._lock: row = conn.execute( @@ -280,7 +278,7 @@ def ensure_radar_location(self, radar: str, lat: float, lon: float) -> None: ) conn.commit() - def get_radar_catalog_path(self, radar: str) -> Optional[Path]: + def get_radar_catalog_path(self, radar: str) -> Path | None: """Get path to radar's catalog database. Parameters @@ -325,8 +323,8 @@ def register_run( self, run_id: str, radar: str, - mode: Optional[str] = None, - config_path: Optional[str] = None, + mode: str | None = None, + config_path: str | None = None, repository_version: str = "0.1.0" ) -> None: """Register a new pipeline run. @@ -344,7 +342,7 @@ def register_run( repository_version : str Adapt version """ - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(UTC).isoformat() conn = self._get_connection() with self._lock: @@ -361,7 +359,7 @@ def update_run_status( self, run_id: str, status: str, - end_time: Optional[str] = None + end_time: str | None = None ) -> None: """Update run status. @@ -390,7 +388,7 @@ def update_run_status( logger.debug(f"Run {run_id} status updated to {status}") - def list_runs(self, radar: Optional[str] = None) -> pd.DataFrame: + def list_runs(self, radar: str | None = None) -> pd.DataFrame: """Get list of runs, optionally filtered by radar. Parameters @@ -412,7 +410,7 @@ def list_runs(self, radar: Optional[str] = None) -> pd.DataFrame: query = "SELECT * FROM runs ORDER BY start_time DESC" return pd.read_sql_query(query, conn) - def get_latest_run(self, radar: Optional[str] = None) -> Optional[Dict]: + def get_latest_run(self, radar: str | None = None) -> dict | None: """Get the most recent run. Parameters @@ -443,7 +441,7 @@ def get_latest_run(self, radar: Optional[str] = None) -> Optional[Dict]: # Item Types Management # ========================================================================= - def list_item_types(self) -> List[str]: + def list_item_types(self) -> list[str]: """Get list of registered item types. Returns @@ -457,7 +455,7 @@ def list_item_types(self) -> List[str]: return [row['item_type'] for row in rows] - def get_item_type_info(self, item_type: str) -> Optional[Dict]: + def get_item_type_info(self, item_type: str) -> dict | None: """Get metadata for an item type. Parameters diff --git a/src/adapt/persistence/repository.py b/src/adapt/persistence/repository.py index 8a525f9..c9a0f1a 100644 --- a/src/adapt/persistence/repository.py +++ b/src/adapt/persistence/repository.py @@ -21,15 +21,15 @@ import sqlite3 import tempfile import threading -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import pandas as pd import xarray as xr -from adapt.persistence.registry import RepositoryRegistry from adapt.persistence.catalog import RadarCatalog +from adapt.persistence.registry import RepositoryRegistry if TYPE_CHECKING: from adapt.configuration.schemas import InternalConfig @@ -100,7 +100,7 @@ class DataRepository: def __init__( self, run_id: str, - base_dir: Union[str, Path], + base_dir: str | Path, radar: str, config: Optional["InternalConfig"] = None, ): @@ -200,7 +200,7 @@ def _register_in_new_catalog(self) -> None: def generate_artifact_id( product_type: str, radar: str, - scan_time: Optional[datetime], + scan_time: datetime | None, run_id: str, content_hint: str = "" ) -> str: @@ -235,11 +235,11 @@ def generate_artifact_id( def register_artifact( self, product_type: str, - file_path: Union[str, Path], - scan_time: Optional[datetime] = None, + file_path: str | Path, + scan_time: datetime | None = None, producer: str = "unknown", - parent_ids: Optional[List[str]] = None, - metadata: Optional[Dict] = None + parent_ids: list[str] | None = None, + metadata: dict | None = None ) -> str: """Register an artifact in the RadarCatalog. @@ -341,7 +341,7 @@ def open_dataset(self, artifact_id: str) -> xr.Dataset: def open_table( self, artifact_id: str, - table_name: Optional[str] = None + table_name: str | None = None ) -> pd.DataFrame: """Open SQLite or Parquet artifact as DataFrame. @@ -384,11 +384,11 @@ def open_table( def query( self, - product_type: Optional[str] = None, - time_range: Optional[Tuple[datetime, datetime]] = None, - radar: Optional[str] = None, - limit: Optional[int] = None - ) -> List[Dict]: + product_type: str | None = None, + time_range: tuple[datetime, datetime] | None = None, + radar: str | None = None, + limit: int | None = None + ) -> list[dict]: """Query artifacts by criteria. Parameters @@ -426,8 +426,8 @@ def query( def get_latest( self, product_type: str, - radar: Optional[str] = None - ) -> Optional[Dict]: + radar: str | None = None + ) -> dict | None: """Get the most recent artifact of a given type. Parameters @@ -448,9 +448,9 @@ def get_latest( def get_all_since( self, product_type: str, - since_artifact_id: Optional[str] = None, - radar: Optional[str] = None - ) -> List[Dict]: + since_artifact_id: str | None = None, + radar: str | None = None + ) -> list[dict]: """Get all artifacts of a type created after a given artifact. Parameters @@ -480,7 +480,7 @@ def get_all_since( df = df.sort_values("scan_time", ascending=True) return [self._normalize_item(row) for _, row in df.iterrows()] - def get_artifact(self, artifact_id: str) -> Optional[Dict]: + def get_artifact(self, artifact_id: str) -> dict | None: """Get artifact record by ID. Parameters @@ -495,7 +495,7 @@ def get_artifact(self, artifact_id: str) -> Optional[Dict]: """ return self._get_artifact(artifact_id) - def _normalize_item(self, item) -> Dict: + def _normalize_item(self, item) -> dict: """Convert a RadarCatalog item row to a repository artifact dict. Resolves relative file_path to absolute, aliases item_id → artifact_id, @@ -536,9 +536,9 @@ def write_netcdf( product_type: str, scan_time: datetime, producer: str, - parent_ids: Optional[List[str]] = None, - metadata: Optional[Dict] = None, - filename_stem: Optional[str] = None + parent_ids: list[str] | None = None, + metadata: dict | None = None, + filename_stem: str | None = None ) -> str: """Write xarray Dataset to NetCDF and register artifact. @@ -592,9 +592,9 @@ def write_parquet( product_type: str, scan_time: datetime, producer: str, - parent_ids: Optional[List[str]] = None, - metadata: Optional[Dict] = None, - filename_stem: Optional[str] = None + parent_ids: list[str] | None = None, + metadata: dict | None = None, + filename_stem: str | None = None ) -> str: """Write DataFrame to Parquet and register artifact. @@ -648,8 +648,8 @@ def write_analysis2d_parquet( df: pd.DataFrame, scan_time: datetime, producer: str = "processor", - parent_ids: Optional[List[str]] = None, - metadata: Optional[Dict] = None + parent_ids: list[str] | None = None, + metadata: dict | None = None ) -> str: """Write analysis DataFrame to Parquet file (one per run_id). @@ -893,7 +893,7 @@ def get_or_create_cells_db( self, scan_time: datetime, producer: str, - parent_ids: Optional[List[str]] = None + parent_ids: list[str] | None = None ) -> str: """Get existing cells database or create new one. @@ -940,7 +940,7 @@ def _generate_netcdf_path( self, product_type: str, scan_time: datetime, - filename_stem: Optional[str] = None + filename_stem: str | None = None ) -> Path: """Generate NetCDF output path with run_id suffix. @@ -967,7 +967,7 @@ def _generate_netcdf_path( def _generate_parquet_path( self, scan_time: datetime, - filename_stem: Optional[str] = None + filename_stem: str | None = None ) -> Path: """Generate Parquet output path with run_id suffix. @@ -1078,7 +1078,7 @@ def _init_cells_db(self, db_path: Path) -> None: # Internal: Helpers # ========================================================================= - def _get_artifact(self, artifact_id: str) -> Optional[Dict]: + def _get_artifact(self, artifact_id: str) -> dict | None: """Get artifact record by ID from RadarCatalog.""" row = self.catalog.get_item(artifact_id) return self._normalize_item(row) if row else None @@ -1095,7 +1095,7 @@ def finalize_run(self, status: str = "completed") -> None: status : str Final status (completed, failed, cancelled) """ - end_time = datetime.now(timezone.utc).isoformat() + end_time = datetime.now(UTC).isoformat() if self.registry: self.registry.update_run_status( run_id=self.run_id, diff --git a/src/adapt/persistence/track_store.py b/src/adapt/persistence/track_store.py index de27352..74ab6b5 100644 --- a/src/adapt/persistence/track_store.py +++ b/src/adapt/persistence/track_store.py @@ -17,9 +17,8 @@ import logging import sqlite3 import threading -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path -from typing import Optional import pandas as pd @@ -69,7 +68,7 @@ class TrackStore: def __init__(self, db_path: Path): self._db_path = Path(db_path) self._lock = threading.RLock() - self._conn: Optional[sqlite3.Connection] = None + self._conn: sqlite3.Connection | None = None # ------------------------------------------------------------------ # Connection @@ -197,7 +196,7 @@ def get_track_history(self, run_id: str, cell_uid: str) -> pd.DataFrame: ).fetchall() return pd.DataFrame([dict(r) for r in rows]) - def get_cell_events(self, run_id: str, cell_uid: Optional[str] = None) -> pd.DataFrame: + def get_cell_events(self, run_id: str, cell_uid: str | None = None) -> pd.DataFrame: conn = self._connect() with self._lock: if cell_uid is None: @@ -368,7 +367,7 @@ def _build_cells_rows( # Parse current scan time once for age computation from datetime import datetime as _dt try: - scan_dt = _dt.strptime(scan_iso, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + scan_dt = _dt.strptime(scan_iso, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=UTC) except ValueError: scan_dt = None @@ -382,7 +381,7 @@ def _build_cells_rows( age_seconds = 0.0 if scan_dt is not None and cl not in initiated and first_seen_map and tid in first_seen_map: try: - first_dt = _dt.strptime(first_seen_map[tid], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + first_dt = _dt.strptime(first_seen_map[tid], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=UTC) age_seconds = max(0.0, (scan_dt - first_dt).total_seconds()) except ValueError: pass @@ -424,7 +423,7 @@ def _upsert_cells(self, conn: sqlite3.Connection, rows: list[dict]) -> None: ) conn.executemany(sql, [tuple(r[c] for c in cols) for r in rows]) - def _prev_scan_time(self, conn: sqlite3.Connection, run_id: str, scan_iso: str) -> Optional[str]: + def _prev_scan_time(self, conn: sqlite3.Connection, run_id: str, scan_iso: str) -> str | None: row = conn.execute( "SELECT MAX(scan_time) AS t FROM cells_by_scan WHERE run_id=? AND scan_time None: cols = [ @@ -480,10 +479,10 @@ def _insert_cell_events( placeholders = ", ".join("?" * len(cols)) sql = f"INSERT INTO cell_events ({', '.join(cols)}) VALUES ({placeholders})" - def _src_time(etype: str) -> Optional[str]: + def _src_time(etype: str) -> str | None: return None if etype == "INITIATION" else source_iso - def _tgt_time(etype: str) -> Optional[str]: + def _tgt_time(etype: str) -> str | None: return None if etype == "TERMINATION" else target_iso rows = [] @@ -626,7 +625,7 @@ def _upsert_cell_tracks( def _to_iso(dt: datetime) -> str: if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) + dt = dt.replace(tzinfo=UTC) return dt.strftime("%Y-%m-%dT%H:%M:%SZ") diff --git a/src/adapt/persistence/writer.py b/src/adapt/persistence/writer.py index 30d8266..4efe301 100644 --- a/src/adapt/persistence/writer.py +++ b/src/adapt/persistence/writer.py @@ -10,12 +10,13 @@ from datetime import datetime from pathlib import Path -from typing import List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING import pandas as pd if TYPE_CHECKING: import xarray as xr + from adapt.persistence.repository import DataRepository @@ -36,8 +37,8 @@ def write_analysis( df: pd.DataFrame, scan_time: datetime, producer: str, - parent_ids: Optional[List[str]] = None, - metadata: Optional[dict] = None, + parent_ids: list[str] | None = None, + metadata: dict | None = None, ) -> str: """Persist cell analysis DataFrame as a Parquet artifact. Returns artifact ID.""" return self.repository.write_analysis2d_parquet( @@ -54,8 +55,8 @@ def write_netcdf( path: Path, scan_time: datetime, producer: str, - parent_ids: Optional[List[str]] = None, - metadata: Optional[dict] = None, + parent_ids: list[str] | None = None, + metadata: dict | None = None, ) -> str: """Persist an xarray Dataset as a NetCDF artifact. Returns artifact ID.""" return self.repository.write_netcdf( diff --git a/src/adapt/runtime/__init__.py b/src/adapt/runtime/__init__.py index d494099..4ad17c8 100644 --- a/src/adapt/runtime/__init__.py +++ b/src/adapt/runtime/__init__.py @@ -8,9 +8,9 @@ - file_tracker: SQLite-based file tracking """ +from adapt.runtime.file_tracker import FileProcessingTracker from adapt.runtime.orchestrator import PipelineOrchestrator from adapt.runtime.processor import RadarProcessor -from adapt.runtime.file_tracker import FileProcessingTracker __all__ = [ "PipelineOrchestrator", diff --git a/src/adapt/runtime/file_tracker.py b/src/adapt/runtime/file_tracker.py index 5784c66..c0c642d 100644 --- a/src/adapt/runtime/file_tracker.py +++ b/src/adapt/runtime/file_tracker.py @@ -7,12 +7,11 @@ Enables idempotent processing with stop/restart, progress tracking, and failure recovery. """ -import sqlite3 import logging -from pathlib import Path -from datetime import datetime, timezone -from typing import Optional, Dict, List +import sqlite3 import threading +from datetime import UTC, datetime +from pathlib import Path __all__ = ['FileProcessingTracker'] @@ -160,7 +159,7 @@ def _migrate_database(self): conn.commit() def register_file(self, file_id: str, radar: str, scan_time: datetime, - nexrad_path: Optional[Path] = None) -> bool: + nexrad_path: Path | None = None) -> bool: """Register a new file for tracking. Creates an initial database record for a newly discovered NEXRAD file. @@ -213,7 +212,7 @@ def register_file(self, file_id: str, radar: str, scan_time: datetime, scan_time.isoformat(), str(nexrad_path) if nexrad_path else None, file_size_mb, - datetime.now(timezone.utc).isoformat() + datetime.now(UTC).isoformat() )) conn.commit() @@ -221,10 +220,10 @@ def register_file(self, file_id: str, radar: str, scan_time: datetime, return True def mark_stage_complete(self, file_id: str, stage: str, - path: Optional[Path] = None, - num_cells: Optional[int] = None, - error: Optional[str] = None, - timings: Optional[Dict[str, float]] = None): + path: Path | None = None, + num_cells: int | None = None, + error: str | None = None, + timings: dict[str, float] | None = None): """Mark a pipeline stage as complete or failed for a file. Called by downloader, processor, and plotter threads to record progress. @@ -283,7 +282,7 @@ def mark_stage_complete(self, file_id: str, stage: str, else: new_status = 'processing' - now = datetime.now(timezone.utc).isoformat() + now = datetime.now(UTC).isoformat() # Build SET clause dynamically to include optional timing columns set_parts = [ @@ -324,7 +323,7 @@ def mark_stage_complete(self, file_id: str, stage: str, logger.debug("Marked %s complete: %s", stage, file_id) - def get_file_status(self, file_id: str) -> Optional[Dict]: + def get_file_status(self, file_id: str) -> dict | None: """Get complete processing status for a file. Parameters @@ -358,9 +357,9 @@ def get_file_status(self, file_id: str) -> Optional[Dict]: return dict(row) return None - def get_pending_files(self, stage: Optional[str] = None, - radar: Optional[str] = None, - limit: Optional[int] = None) -> List[Dict]: + def get_pending_files(self, stage: str | None = None, + radar: str | None = None, + limit: int | None = None) -> list[dict]: """Get files awaiting processing at a specific stage. Used by downloader/processor/plotter to find files needing work. @@ -415,7 +414,7 @@ def get_pending_files(self, stage: Optional[str] = None, cursor = conn.execute(query, params) return [dict(row) for row in cursor.fetchall()] - def get_statistics(self, radar: Optional[str] = None) -> Dict: + def get_statistics(self, radar: str | None = None) -> dict: """Get summary statistics for processing progress. Parameters @@ -488,7 +487,7 @@ def should_process(self, file_id: str, stage: str) -> bool: timestamp_col = f"{stage}_at" return status.get(timestamp_col) is None - def reset_failed(self, radar: Optional[str] = None): + def reset_failed(self, radar: str | None = None): """Reset all failed files to pending for retry. Useful for recovery after fixing errors (e.g., config changes, bug fixes). @@ -513,18 +512,18 @@ def reset_failed(self, radar: Optional[str] = None): UPDATE radar_file_processing SET status = 'pending', error_message = NULL, updated_at = ? WHERE status = 'failed' AND radar = ? - """, (datetime.now(timezone.utc).isoformat(), radar)) + """, (datetime.now(UTC).isoformat(), radar)) else: conn.execute(""" UPDATE radar_file_processing SET status = 'pending', error_message = NULL, updated_at = ? WHERE status = 'failed' - """, (datetime.now(timezone.utc).isoformat(),)) + """, (datetime.now(UTC).isoformat(),)) conn.commit() - logger.info(f"Reset failed files to pending") + logger.info("Reset failed files to pending") - def cleanup_deleted_files(self, radar: Optional[str] = None): + def cleanup_deleted_files(self, radar: str | None = None): """Remove database records for files deleted from disk. Useful after clearing output directories. On next run, these files diff --git a/src/adapt/runtime/orchestrator.py b/src/adapt/runtime/orchestrator.py index 60c1730..3e901b9 100644 --- a/src/adapt/runtime/orchestrator.py +++ b/src/adapt/runtime/orchestrator.py @@ -11,16 +11,16 @@ is not blocked by visualization and validates repository API integrity. """ +import logging import queue import time -import logging from pathlib import Path -from typing import Optional, Dict, TYPE_CHECKING +from typing import TYPE_CHECKING from adapt.modules.acquisition.module import AwsNexradDownloader -from adapt.runtime.processor import RadarProcessor -from adapt.runtime.file_tracker import FileProcessingTracker from adapt.persistence import DataRepository +from adapt.runtime.file_tracker import FileProcessingTracker +from adapt.runtime.processor import RadarProcessor if TYPE_CHECKING: from adapt.configuration.schemas import InternalConfig @@ -130,7 +130,7 @@ def __init__( # DataRepository (initialized in start()) - use run_id from config or generate self.run_id = config.run_id - self.repository: Optional[DataRepository] = None + self.repository: DataRepository | None = None # Lifecycle state self._stop_event = False @@ -186,7 +186,7 @@ def _setup_logging(self): self.tracker = FileProcessingTracker(tracker_path) logger.debug("Processing tracker: %s", tracker_path) - def start(self, max_runtime: Optional[int] = None): + def start(self, max_runtime: int | None = None): """Start the pipeline and run until completion or user interrupt. This is a blocking call that starts the downloader and processor diff --git a/src/adapt/runtime/processor.py b/src/adapt/runtime/processor.py index e6ca954..10cd69d 100644 --- a/src/adapt/runtime/processor.py +++ b/src/adapt/runtime/processor.py @@ -18,16 +18,16 @@ import queue import threading import time -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING import pandas as pd from adapt.modules.base import ContractViolation from adapt.persistence import DataRepository, ProductType -from adapt.persistence.writer import RepositoryWriter from adapt.persistence.track_store import TrackStore +from adapt.persistence.writer import RepositoryWriter if TYPE_CHECKING: from adapt.configuration.schemas import InternalConfig @@ -69,7 +69,7 @@ def __init__( config: "InternalConfig", output_dirs: dict, file_tracker=None, - repository: Optional[DataRepository] = None, + repository: DataRepository | None = None, name: str = "RadarProcessor", ): super().__init__(daemon=True, name=name) @@ -253,7 +253,7 @@ def process_file(self, filepath) -> bool: logger.info( "Processed pair: %d cells | ingest=%.1fs detect=%.1fs project=%.1fs%s", n_cells, ingest_s, detect_s, project_s, - f" queue=%.1fs" % queue_wait_s if queue_wait_s is not None else "", + " queue=%.1fs" % queue_wait_s if queue_wait_s is not None else "", ) # Mark both files as processed @@ -290,13 +290,13 @@ def process_file(self, filepath) -> bool: # ── NetCDF persistence ──────────────────────────────────────────────────── - def _save_analysis_netcdf(self, ds, filepath: str, scan_time) -> Optional[str]: + def _save_analysis_netcdf(self, ds, filepath: str, scan_time) -> str | None: """Write the analysis dataset to a NetCDF artifact in the repository.""" try: radar = self.config.downloader.radar filename_stem = Path(filepath).stem if scan_time is None: - scan_time = datetime.now(timezone.utc) + scan_time = datetime.now(UTC) ds.attrs.update({ "source": str(filepath), @@ -342,8 +342,8 @@ def _run_ingest_detection_only(self, context: dict): Wall time for the detection (segmentation) step """ # Import modules directly (not through pipeline graph) - from adapt.modules.ingest.module import LoadModule from adapt.modules.detection.module import DetectModule + from adapt.modules.ingest.module import LoadModule # Instantiate if not cached (persist across calls) if not hasattr(self, '_ingest_module'): @@ -449,7 +449,7 @@ def _save_results(self, result: dict, scan_time): - tracked_cells, cell_events as SQLite via TrackStore (label→uid adjacency mapping in TrackStore) """ if scan_time is not None and scan_time.tzinfo is None: - scan_time = scan_time.replace(tzinfo=timezone.utc) + scan_time = scan_time.replace(tzinfo=UTC) # NetCDF: segmentation + projections + flow vectors projected_ds = result.get("projected_ds") diff --git a/src/adapt/visualization/__init__.py b/src/adapt/visualization/__init__.py index afb3d3a..eb50eeb 100644 --- a/src/adapt/visualization/__init__.py +++ b/src/adapt/visualization/__init__.py @@ -6,6 +6,6 @@ # OBSOLETE — RadarPlotter and PlotterThread are exported but never imported externally. # Only PlotConsumer is used (imported directly in cli.py). # Consider removing these exports or the classes themselves. -from .plotter import RadarPlotter, PlotterThread +from .plotter import PlotterThread, RadarPlotter __all__ = ['RadarPlotter', 'PlotterThread'] diff --git a/src/adapt/visualization/plotter.py b/src/adapt/visualization/plotter.py index c80ad68..ffafb37 100644 --- a/src/adapt/visualization/plotter.py +++ b/src/adapt/visualization/plotter.py @@ -7,17 +7,18 @@ Supports threaded queue-based processing for pipeline integration. """ -import threading -import queue import logging +import queue +import threading +from datetime import UTC, datetime from pathlib import Path -from datetime import datetime, timezone -from typing import Optional, Dict, List, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING +import matplotlib import numpy as np import pandas as pd import xarray as xr -import matplotlib + matplotlib.use('Agg') import matplotlib.pyplot as plt @@ -164,7 +165,7 @@ def _get_coord_name(self, coord_key: str, default: str) -> str: def _extract_timestamp(self, ds: xr.Dataset) -> datetime: """Extract timestamp from dataset.""" if 'time' not in ds.coords: - return datetime.now(timezone.utc) + return datetime.now(UTC) try: time_val = ds.coords['time'].values @@ -173,9 +174,9 @@ def _extract_timestamp(self, ds: xr.Dataset) -> datetime: else: return pd.Timestamp(time_val[0]).to_pydatetime() except Exception: - return datetime.now(timezone.utc) + return datetime.now(UTC) - def _get_coordinates_km(self, ds: xr.Dataset) -> Tuple[np.ndarray, np.ndarray]: + def _get_coordinates_km(self, ds: xr.Dataset) -> tuple[np.ndarray, np.ndarray]: """Get x, y coordinates in km.""" y_name = self._get_coord_name("y", "y") x_name = self._get_coord_name("x", "x") @@ -193,7 +194,7 @@ def _mask_reflectivity(self, refl: np.ndarray) -> np.ma.MaskedArray: refl_float ) - def _setup_figure(self) -> Tuple[plt.Figure, plt.Axes, plt.Axes]: + def _setup_figure(self) -> tuple[plt.Figure, plt.Axes, plt.Axes]: """Create figure with two subplots.""" fig, (ax1, ax2) = plt.subplots( 1, 2, @@ -202,7 +203,7 @@ def _setup_figure(self) -> Tuple[plt.Figure, plt.Axes, plt.Axes]: ) return fig, ax1, ax2 - def _get_radar_location(self, ds: xr.Dataset) -> Tuple[float, float]: + def _get_radar_location(self, ds: xr.Dataset) -> tuple[float, float]: """Extract radar lat/lon from dataset.""" def extract_float(val): """Convert various types to Python float scalar.""" @@ -486,7 +487,7 @@ def plot_reflectivity_with_cells( self, ds: xr.Dataset, frame_offset: int = 0, - output_path: Optional[Path] = None, + output_path: Path | None = None, ) -> str: """Generate publication-quality two-panel radar visualization. @@ -603,7 +604,7 @@ def plot_reflectivity_with_cells( def plot_from_netcdf( self, segmentation_nc: Path, - output_path: Optional[Path] = None, + output_path: Path | None = None, ) -> str: """Load analysis NetCDF and generate visualization. @@ -716,7 +717,7 @@ class PlotterThread(threading.Thread): def __init__( self, input_queue: queue.Queue, - output_dirs: Dict, + output_dirs: dict, config: "InternalConfig" = None, file_tracker = None, show_plots: bool = False, @@ -777,12 +778,12 @@ def run(self): logger.info(f"{self.name} stopped") - def _process_item(self, item: Dict): + def _process_item(self, item: dict): """Process plot item from queue.""" try: seg_nc = item.get('segmentation_nc') radar = item.get('radar', 'RADAR') - timestamp = item.get('timestamp', datetime.now(timezone.utc)) + timestamp = item.get('timestamp', datetime.now(UTC)) if not seg_nc or not Path(seg_nc).exists(): logger.warning(f"Segmentation file not found: {seg_nc}") @@ -920,7 +921,7 @@ def __init__( self.plotter = RadarPlotter(config=config, show_plots=show_live) # Track last processed artifact to detect new ones - self._last_seen_id: Optional[str] = None + self._last_seen_id: str | None = None self._processed_count = 0 # Import ProductType here to avoid circular imports @@ -976,7 +977,7 @@ def _poll_and_process(self): except Exception as e: logger.error(f"Error polling repository: {e}", exc_info=True) - def _process_artifact(self, artifact: Dict): + def _process_artifact(self, artifact: dict): """Generate plot from artifact.""" artifact_id = artifact['artifact_id'] file_path = Path(artifact['file_path']) @@ -987,7 +988,7 @@ def _process_artifact(self, artifact: Dict): if scan_time_str: scan_time = datetime.fromisoformat(scan_time_str) else: - scan_time = datetime.now(timezone.utc) + scan_time = datetime.now(UTC) # Load dataset from repository ds = self.repository.open_dataset(artifact_id) diff --git a/tests/modules/tracking/test_tracker_scan_local_outputs.py b/tests/modules/tracking/test_tracker_scan_local_outputs.py index b2ef60c..768ba2b 100644 --- a/tests/modules/tracking/test_tracker_scan_local_outputs.py +++ b/tests/modules/tracking/test_tracker_scan_local_outputs.py @@ -1,18 +1,18 @@ # Copyright © 2026, UChicago Argonne, LLC # See LICENSE for terms and disclaimer. -import numpy as np -import pandas as pd -import xarray as xr import tempfile from pathlib import Path +import numpy as np +import pandas as pd import pytest +import xarray as xr -from adapt.modules.tracking.module import RadarCellTracker from adapt.configuration.schemas.param import ParamConfig -from adapt.configuration.schemas.user import UserConfig from adapt.configuration.schemas.resolve import resolve_config +from adapt.configuration.schemas.user import UserConfig +from adapt.modules.tracking.module import RadarCellTracker @pytest.fixture diff --git a/tests/persistence/test_track_store.py b/tests/persistence/test_track_store.py index 592b3f0..a62194e 100644 --- a/tests/persistence/test_track_store.py +++ b/tests/persistence/test_track_store.py @@ -7,9 +7,7 @@ Inputs are synthetic DataFrames; no file I/O. """ import sqlite3 -import tempfile -from datetime import datetime, timezone -from pathlib import Path +from datetime import UTC, datetime import pandas as pd import pytest @@ -107,7 +105,7 @@ def store(db_path): def _t(iso: str) -> datetime: - return datetime.fromisoformat(iso).replace(tzinfo=timezone.utc) + return datetime.fromisoformat(iso).replace(tzinfo=UTC) def _cell_stats(cell_label: int, area: float = 4.0, refl: float = 40.0) -> pd.DataFrame: From 5a71245964e9778a917edf4c4d5da0bd3b94cb1c Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Fri, 1 May 2026 01:16:49 -0500 Subject: [PATCH 02/14] ADD: risky fixes made by Ruff --- src/adapt/api/client.py | 5 +---- src/adapt/cli.py | 5 ++--- src/adapt/gui/dashboard.py | 33 +++++++++------------------- src/adapt/modules/analysis/module.py | 5 ++--- src/adapt/modules/tracking/module.py | 7 +++--- src/adapt/persistence/repository.py | 5 +---- src/adapt/runtime/orchestrator.py | 5 ++--- src/adapt/runtime/processor.py | 2 +- src/adapt/visualization/plotter.py | 10 ++++----- 9 files changed, 27 insertions(+), 50 deletions(-) diff --git a/src/adapt/api/client.py b/src/adapt/api/client.py index 4690cba..d1e6b6b 100644 --- a/src/adapt/api/client.py +++ b/src/adapt/api/client.py @@ -630,10 +630,7 @@ def get_scan_at( radar = radars[0] # Convert to string for comparison - if isinstance(scan_time, datetime): - scan_time_str = scan_time.isoformat() - else: - scan_time_str = scan_time + scan_time_str = scan_time.isoformat() if isinstance(scan_time, datetime) else scan_time catalog = self._get_radar_catalog(radar) conn = catalog._get_connection() diff --git a/src/adapt/cli.py b/src/adapt/cli.py index 1385a09..52bda57 100644 --- a/src/adapt/cli.py +++ b/src/adapt/cli.py @@ -19,6 +19,7 @@ """ import argparse +import contextlib import os import signal import sys @@ -55,10 +56,8 @@ def _write_pid() -> None: def _remove_pid() -> None: - try: + with contextlib.suppress(Exception): _PID_FILE.unlink(missing_ok=True) - except Exception: - pass # --------------------------------------------------------------------------- diff --git a/src/adapt/gui/dashboard.py b/src/adapt/gui/dashboard.py index 81ee343..430bc78 100644 --- a/src/adapt/gui/dashboard.py +++ b/src/adapt/gui/dashboard.py @@ -166,12 +166,10 @@ def __init__(self, canvas, window, *, pack_toolbar=True, lat0=0.0, lon0=0.0): self._ltrans = None if HAS_PROJ and (lat0 or lon0): - try: + with contextlib.suppress(Exception): self._ltrans = Transformer.from_crs( f'+proj=aeqd +lat_0={lat0} +lon_0={lon0} +units=m', 'EPSG:4326', always_xy=True) - except Exception: - pass super().__init__(canvas, window, pack_toolbar=pack_toolbar) def set_message(self, s): @@ -666,7 +664,7 @@ def _update(*_, lv=lo_var, hv=hi_var, ll=lo_lbl, hl=hi_lbl, f=fmt): self.tv = ttk.Treeview(tv_frame, columns=self._tv_cols, show='headings', height=24) widths = [70, 60, 75, 80, 80, 85, 75, 90, 90] - for c, w in zip(self._tv_cols, widths): + for c, w in zip(self._tv_cols, widths, strict=False): hdr = (c.replace('radar_differential_reflectivity_mean', 'ZDR mean') .replace('radar_', '').replace('cell_', '') .replace('_', ' ')) @@ -836,10 +834,8 @@ def _on_close(self): # "invalid command name" errors from orphaned scheduled calls. self._nc_loop_running = False for after_id in self._after_ids: - try: + with contextlib.suppress(Exception): self.after_cancel(after_id) - except Exception: - pass self._after_ids.clear() # Close matplotlib figures @@ -1221,10 +1217,8 @@ def _draw_scan(self, ds, fig, ax=None): # Close previous dataset if self._current_nc_ds is not None and self._current_nc_ds is not ds: - try: + with contextlib.suppress(Exception): self._current_nc_ds.close() - except Exception: - pass self._current_nc_ds = ds self._cell_contours = {} for var in self._hv.values(): @@ -1459,9 +1453,8 @@ def _on_cell_click(self, event) -> None: if history_df is None or history_df.empty: df = self._current_cell_df - if df is not None: - if cell_uid is not None and 'cell_uid' in df.columns: - history_df = df[df['cell_uid'] == cell_uid].copy() + if df is not None and cell_uid is not None and 'cell_uid' in df.columns: + history_df = df[df['cell_uid'] == cell_uid].copy() self._clear_tracking_history() self._selected_cell_uid = str(cell_uid) if cell_uid is not None else None @@ -1504,10 +1497,8 @@ def _draw_tracking_history(self, ax, history_df: pd.DataFrame | None = None) -> def _clear_tracking_history(self) -> None: if self._track_overlay: for artist in self._track_overlay: - try: + with contextlib.suppress(Exception): artist.remove() - except Exception: - pass self._track_overlay = None self._selected_cell_uid = None @@ -1607,7 +1598,7 @@ def _clear_time_series(self) -> None: if self._ts_axes is None: return for ax, (ylabel, title) in zip(self._ts_axes, - [('km²', 'Area'), ('dBZ', 'Reflectivity'), ('dB', 'ZDR')]): + [('km²', 'Area'), ('dBZ', 'Reflectivity'), ('dB', 'ZDR')], strict=False): ax.cla() self._style_ts_ax(ax, ylabel, title) ax.text(0.5, 0.5, 'click a cell', transform=ax.transAxes, @@ -1648,10 +1639,8 @@ def _clear_canvas(self): self._canvas_refs = None self._hover_canvas = None if self._current_nc_ds is not None: - try: + with contextlib.suppress(Exception): self._current_nc_ds.close() - except Exception: - pass self._current_nc_ds = None self._cell_contours = {} for var in self._hv.values(): @@ -1836,10 +1825,8 @@ def _refresh_table(self): mask = pd.Series(True, index=df.index) for col, (lo_v, hi_v) in self._flt.items(): if col in df.columns: - try: + with contextlib.suppress(Exception): mask &= df[col].between(float(lo_v.get()), float(hi_v.get())) - except Exception: - pass # Cell UID prefix filter pid_prefix = self._cell_uid_filter.get().strip().upper() if self._cell_uid_filter else '' diff --git a/src/adapt/modules/analysis/module.py b/src/adapt/modules/analysis/module.py index a48e694..95c971e 100644 --- a/src/adapt/modules/analysis/module.py +++ b/src/adapt/modules/analysis/module.py @@ -22,6 +22,7 @@ Author: Bhupendra Raut """ +import contextlib import json import logging from datetime import UTC @@ -361,10 +362,8 @@ def _normalize_time_scalar(time_val): if isinstance(tv, np.ndarray): tv = tv.reshape(-1)[0] if hasattr(tv, "item"): - try: + with contextlib.suppress(Exception): tv = tv.item() - except Exception: - pass if getattr(type(tv), "__module__", "").startswith("cftime"): from datetime import datetime tv = datetime( diff --git a/src/adapt/modules/tracking/module.py b/src/adapt/modules/tracking/module.py index 9280d10..17ef3ee 100644 --- a/src/adapt/modules/tracking/module.py +++ b/src/adapt/modules/tracking/module.py @@ -29,6 +29,7 @@ References: Raut, B. A., Jackson, R., Picel, M., Collis, S. M., Bergemann, M., & Jakob, C. (2021). An adaptive tracking algorithm for convection in simulated and remote sensing data. Journal of Applied Meteorology and Climatology, 60(4), 513-526. """ +import contextlib import hashlib import logging import string @@ -449,10 +450,8 @@ def _normalize_time_scalar(time_val): tv = tv.reshape(-1)[0] if hasattr(tv, "item"): - try: + with contextlib.suppress(Exception): tv = tv.item() - except Exception: - pass # Handle cftime.* objects (pandas cannot convert them directly) if getattr(type(tv), "__module__", "").startswith("cftime"): @@ -655,7 +654,7 @@ def _track_frame_pair( matched_curr: dict[int, int] = {} # curr_idx → new curr node_id n_continue = 0 - for r, c in zip(row_ind, col_ind): + for r, c in zip(row_ind, col_ind, strict=False): if r >= n_prev or c >= n_curr: continue # dummy slot if square[r, c] <= self.keep_cost: diff --git a/src/adapt/persistence/repository.py b/src/adapt/persistence/repository.py index c9a0f1a..cd8f01e 100644 --- a/src/adapt/persistence/repository.py +++ b/src/adapt/persistence/repository.py @@ -501,10 +501,7 @@ def _normalize_item(self, item) -> dict: Resolves relative file_path to absolute, aliases item_id → artifact_id, and hoists producer from the metadata JSON for backward compatibility. """ - if hasattr(item, 'to_dict'): - d = item.to_dict() - else: - d = dict(item) + d = item.to_dict() if hasattr(item, 'to_dict') else dict(item) # Resolve relative path stored in catalog to absolute rel = d.get("file_path", "") if rel and not Path(rel).is_absolute(): diff --git a/src/adapt/runtime/orchestrator.py b/src/adapt/runtime/orchestrator.py index 3e901b9..54796b2 100644 --- a/src/adapt/runtime/orchestrator.py +++ b/src/adapt/runtime/orchestrator.py @@ -289,9 +289,8 @@ def _main_loop(self, mode: str): break # 1. Historical completion check (must run before downloader death check) - if mode == "historical": - if self._check_historical_complete(): - break + if mode == "historical" and self._check_historical_complete(): + break # 2. Check for thread failures or self-stops (e.g. ContractViolation) if self.processor.stopped(): diff --git a/src/adapt/runtime/processor.py b/src/adapt/runtime/processor.py index 10cd69d..7db853f 100644 --- a/src/adapt/runtime/processor.py +++ b/src/adapt/runtime/processor.py @@ -253,7 +253,7 @@ def process_file(self, filepath) -> bool: logger.info( "Processed pair: %d cells | ingest=%.1fs detect=%.1fs project=%.1fs%s", n_cells, ingest_s, detect_s, project_s, - " queue=%.1fs" % queue_wait_s if queue_wait_s is not None else "", + f" queue={queue_wait_s:.1f}s" if queue_wait_s is not None else "", ) # Mark both files as processed diff --git a/src/adapt/visualization/plotter.py b/src/adapt/visualization/plotter.py index ffafb37..1ed39e0 100644 --- a/src/adapt/visualization/plotter.py +++ b/src/adapt/visualization/plotter.py @@ -20,6 +20,8 @@ import xarray as xr matplotlib.use('Agg') +import contextlib + import matplotlib.pyplot as plt try: @@ -290,7 +292,7 @@ def _plot_reflectivity_field( def _add_colorbar(self, ax: plt.Axes, im: matplotlib.image.AxesImage) -> None: """Add colorbar to axis.""" - cbar = plt.colorbar(im, ax=ax, label='Reflectivity (dBZ)', fraction=0.046, pad=0.04) + plt.colorbar(im, ax=ax, label='Reflectivity (dBZ)', fraction=0.046, pad=0.04) def _plot_heading_yectors( self, @@ -980,7 +982,7 @@ def _poll_and_process(self): def _process_artifact(self, artifact: dict): """Generate plot from artifact.""" artifact_id = artifact['artifact_id'] - file_path = Path(artifact['file_path']) + Path(artifact['file_path']) scan_time_str = artifact.get('scan_time') try: @@ -1015,10 +1017,8 @@ def _process_artifact(self, artifact: dict): # Show live if enabled if self.show_live: - try: + with contextlib.suppress(Exception): plt.pause(0.1) - except Exception: - pass # Print table statistics self._print_table_stats() From e6ebdde2812b331f7ad5848bdc78212db7450db9 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Fri, 1 May 2026 02:19:52 -0500 Subject: [PATCH 03/14] STYLE:fix all ruff lint errors; untested --- src/adapt/api/client.py | 15 ++-- src/adapt/cli.py | 5 +- src/adapt/configuration/schemas/cli.py | 6 +- .../configuration/schemas/initialization.py | 17 ++-- src/adapt/configuration/schemas/internal.py | 11 ++- src/adapt/configuration/schemas/param.py | 40 +++++++-- src/adapt/configuration/schemas/user.py | 13 +-- src/adapt/gui/dashboard.py | 21 +++-- src/adapt/modules/acquisition/module.py | 6 +- src/adapt/modules/analysis/module.py | 81 +++++++++++++------ src/adapt/modules/detection/module.py | 28 ++++--- src/adapt/modules/ingest/module.py | 16 ++-- src/adapt/modules/projection/contracts.py | 6 +- src/adapt/modules/projection/module.py | 32 +++++--- src/adapt/modules/tracking/module.py | 76 ++++++++++++----- src/adapt/persistence/catalog.py | 9 ++- src/adapt/persistence/registry.py | 31 ++++--- src/adapt/persistence/repository.py | 9 ++- src/adapt/persistence/track_store.py | 51 ++++++++---- src/adapt/runtime/file_tracker.py | 24 ++++-- src/adapt/runtime/orchestrator.py | 4 +- src/adapt/runtime/processor.py | 3 +- src/adapt/visualization/plotter.py | 43 ++++++---- .../test_tracker_scan_local_outputs.py | 16 +++- tests/persistence/test_track_store.py | 18 +++-- 25 files changed, 397 insertions(+), 184 deletions(-) diff --git a/src/adapt/api/client.py b/src/adapt/api/client.py index d1e6b6b..ce23098 100644 --- a/src/adapt/api/client.py +++ b/src/adapt/api/client.py @@ -39,6 +39,7 @@ print(f"Got {len(batch)} new rows") """ +import contextlib import logging import time from datetime import UTC, datetime @@ -905,7 +906,9 @@ def get_scan_bundle( catalog = self._get_radar_catalog(radar) bundle: dict[str, Any] = { - 'scan_time': scan_time_dt.isoformat() if isinstance(scan_time_dt, datetime) else scan_time, + 'scan_time': ( + scan_time_dt.isoformat() if isinstance(scan_time_dt, datetime) else scan_time + ), 'radar': radar, 'segmentation2d': None, 'cells': None, @@ -915,11 +918,8 @@ def get_scan_bundle( # Try to get scan from scans table, fall back to items if table doesn't exist scan = None - try: + with contextlib.suppress(Exception): scan = catalog.get_scan(scan_time_dt) - except Exception: - # scans table doesn't exist - use fallback - pass # If no scan record, fall back to item-based lookup if not scan: @@ -1006,7 +1006,10 @@ def _get_scan_bundle_fallback( # Extract cell tracking info from cells DataFrame if 'cell_uid' in bundle['cells'].columns: - for uid in sorted(bundle['cells']['cell_uid'].dropna().astype(str).unique().tolist()): + uids = sorted( + bundle['cells']['cell_uid'].dropna().astype(str).unique().tolist() + ) + for uid in uids: bundle['tracks'].append({'cell_uid': uid}) return bundle diff --git a/src/adapt/cli.py b/src/adapt/cli.py index 52bda57..ce9bf15 100644 --- a/src/adapt/cli.py +++ b/src/adapt/cli.py @@ -351,7 +351,10 @@ def main() -> None: """Top-level CLI dispatcher.""" parser = argparse.ArgumentParser( prog='adapt', - description='Adapt - Real-Time data processing for informed adaptive scanning of ARM weather radars.', + description=( + 'Adapt - Real-Time data processing for informed adaptive scanning ' + 'of ARM weather radars.' + ), ) subparsers = parser.add_subparsers(dest='command', metavar='COMMAND') subparsers.required = True diff --git a/src/adapt/configuration/schemas/cli.py b/src/adapt/configuration/schemas/cli.py index b2190a9..cba98bf 100644 --- a/src/adapt/configuration/schemas/cli.py +++ b/src/adapt/configuration/schemas/cli.py @@ -65,10 +65,8 @@ def infer_historical_mode_from_times(self): the mode should automatically be historical. Runtime code should not make this decision. """ - if self.mode is None: - # Check if either start_time or end_time are provided - if self.start_time or self.end_time: - self.mode = "historical" + if self.mode is None and (self.start_time or self.end_time): + self.mode = "historical" return self diff --git a/src/adapt/configuration/schemas/initialization.py b/src/adapt/configuration/schemas/initialization.py index 27860fe..f1006c3 100644 --- a/src/adapt/configuration/schemas/initialization.py +++ b/src/adapt/configuration/schemas/initialization.py @@ -42,10 +42,10 @@ def _load_user_config_dict(config_path: str) -> dict: if path.suffix in ('.yaml', '.yml'): try: import yaml - except ImportError: + except ImportError as err: raise ImportError( "PyYAML is required for YAML config files: pip install pyyaml" - ) + ) from err with open(path) as f: data = yaml.safe_load(f) return data or {} @@ -115,7 +115,9 @@ def _handle_rerun_cleanup(base_dir: str, radar: str, rerun: bool) -> None: print("Radar output cleaned") -def _persist_runtime_config(config: InternalConfig, run_id: str, output_dirs: dict[str, Path]) -> None: +def _persist_runtime_config( + config: InternalConfig, run_id: str, output_dirs: dict[str, Path] +) -> None: """Persist final runtime configuration to output directory with run ID. Saves the complete resolved configuration for reproducibility and debugging. @@ -242,7 +244,10 @@ def init_runtime_config(args) -> InternalConfig: if _run_id_exists(base_dir_arg, normalized_run_id): print(f"Continuing existing run ID: {normalized_run_id}") - print("Ignoring user config file and CLI config overrides; reusing saved runtime config for this run.") + print( + "Ignoring user config file and CLI config overrides; " + "reusing saved runtime config for this run." + ) return _load_saved_runtime_config(base_dir_arg, normalized_run_id) # 1. Load and resolve configuration from all sources @@ -277,7 +282,9 @@ def init_runtime_config(args) -> InternalConfig: # 2. Handle --rerun cleanup BEFORE directory setup rerun = getattr(args, 'rerun', False) - _handle_rerun_cleanup(internal_config_dict["base_dir"], internal_config_dict["downloader"]["radar"], rerun) + _handle_rerun_cleanup( + internal_config_dict["base_dir"], internal_config_dict["downloader"]["radar"], rerun + ) # 3. Setup output directories output_dirs = _setup_output_directories(internal_config_dict["base_dir"]) diff --git a/src/adapt/configuration/schemas/internal.py b/src/adapt/configuration/schemas/internal.py index 0e140e9..1f82f84 100644 --- a/src/adapt/configuration/schemas/internal.py +++ b/src/adapt/configuration/schemas/internal.py @@ -159,7 +159,8 @@ class InternalProcessorConfig(AdaptBaseModel): """Runtime processor configuration.""" max_history: int = Field(default=2, ge=2, le=10) # Frame history for optical flow min_file_size: int = Field(default=5000, ge=1000) # Minimum file size in bytes - db_filename_pattern: str = Field(default="{radar}_cells_statistics.db") # Database filename pattern + # Database filename pattern + db_filename_pattern: str = Field(default="{radar}_cells_statistics.db") # ============================================================================= @@ -193,8 +194,12 @@ def __init__(self, config: InternalConfig): mode: Literal["realtime", "historical"] base_dir: str - run_id: str | None = Field(default=None, description="Unique run identifier generated during initialization") - output_dirs: dict[str, str] | None = Field(default=None, description="Output directory paths from initialization") + run_id: str | None = Field( + default=None, description="Unique run identifier generated during initialization" + ) + output_dirs: dict[str, str] | None = Field( + default=None, description="Output directory paths from initialization" + ) reader: InternalReaderConfig downloader: InternalDownloaderConfig regridder: InternalRegridderConfig diff --git a/src/adapt/configuration/schemas/param.py b/src/adapt/configuration/schemas/param.py index bbd597e..0094406 100644 --- a/src/adapt/configuration/schemas/param.py +++ b/src/adapt/configuration/schemas/param.py @@ -34,7 +34,9 @@ class DownloaderConfig(AdaptBaseModel): poll_interval_sec: int = Field(300, ge=1, description="Polling interval in seconds") start_time: str | None = None end_time: str | None = None - min_file_size: int = Field(1024, ge=1, description="Minimum file size in bytes to consider valid") + min_file_size: int = Field( + 1024, ge=1, description="Minimum file size in bytes to consider valid" + ) class RegridderConfig(AdaptBaseModel): @@ -114,7 +116,9 @@ class ProjectorConfig(AdaptBaseModel): nan_fill_value: float = 0.0 flow_params: FlowParamsConfig = Field(default_factory=FlowParamsConfig) min_motion_threshold: float = Field(0.5, ge=0) - max_flow_magnitude: float = Field(20.0, gt=0, description="Clip flow vectors exceeding this magnitude (pixels/frame)") + max_flow_magnitude: float = Field( + 20.0, gt=0, description="Clip flow vectors exceeding this magnitude (pixels/frame)" + ) @field_validator("method", mode="before") @classmethod @@ -149,7 +153,10 @@ class AnalyzerConfig(AdaptBaseModel): adjacency_min_touching_boundary_pixels: int = Field( 1, ge=1, - description="Min number of touching boundary pixels to count two labels as adjacent in the same scan", + description=( + "Min number of touching boundary pixels to count two labels " + "as adjacent in the same scan" + ), ) @@ -163,11 +170,28 @@ class CellUidConfig(AdaptBaseModel): width: int = Field(10, ge=1) alphabet: Literal["base36_upper"] = "base36_upper" - match_cost_threshold: float = Field(0.15, ge=0.0, description="Cost below this is forced to 0 before Hungarian (guaranteed match)") - keep_cost_threshold: float = Field(1.0, ge=0.0, description="Post-Hungarian: cost <= this confirms CONTINUE, else pair is rejected") - unmatch_cost_threshold: float = Field(2.0, ge=0.0, description="Cost above this is forced to dummy_cost before Hungarian (unlikely match)") - split_overlap_threshold: float = Field(0.8, ge=0.0, le=1.0, description="Min fraction of projected hull area overlapping born/surviving cell to confirm SPLIT or MERGE") - core_reflectivity_threshold: float = Field(40.0, ge=0.0, description="Reflectivity threshold for core area (dBZ)") + match_cost_threshold: float = Field( + 0.15, ge=0.0, + description="Cost below this is forced to 0 before Hungarian (guaranteed match)", + ) + keep_cost_threshold: float = Field( + 1.0, ge=0.0, + description="Post-Hungarian: cost <= this confirms CONTINUE, else pair is rejected", + ) + unmatch_cost_threshold: float = Field( + 2.0, ge=0.0, + description="Cost above this is forced to dummy_cost before Hungarian (unlikely match)", + ) + split_overlap_threshold: float = Field( + 0.8, ge=0.0, le=1.0, + description=( + "Min fraction of projected hull area overlapping born/surviving cell " + "to confirm SPLIT or MERGE" + ), + ) + core_reflectivity_threshold: float = Field( + 40.0, ge=0.0, description="Reflectivity threshold for core area (dBZ)" + ) cell_uid: CellUidConfig = Field(default_factory=CellUidConfig) diff --git a/src/adapt/configuration/schemas/user.py b/src/adapt/configuration/schemas/user.py index 12f19e7..c4796ee 100644 --- a/src/adapt/configuration/schemas/user.py +++ b/src/adapt/configuration/schemas/user.py @@ -143,7 +143,9 @@ class UserConfig(AdaptBaseModel): # Grid settings (flat aliases) grid_shape: tuple[int, int, int] | None = Field(None, alias="GRID_SHAPE") - grid_limits: tuple[tuple[float, float], tuple[float, float], tuple[float, float]] | None = Field(None, alias="GRID_LIMITS") + grid_limits: tuple[ + tuple[float, float], tuple[float, float], tuple[float, float] + ] | None = Field(None, alias="GRID_LIMITS") # Segmentation settings (flat aliases) z_level: float | None = Field(None, alias="Z_LEVEL") @@ -180,10 +182,11 @@ def infer_historical_mode_from_times(self): This is a schema responsibility: if user config indicates a time range, the mode should automatically be historical. """ - if self.mode is None: - # Check top-level times - if self.start_time and self.end_time or self.downloader and (self.downloader.start_time and self.downloader.end_time): - self.mode = "historical" + if self.mode is None and ( + (self.start_time and self.end_time) + or (self.downloader and (self.downloader.start_time and self.downloader.end_time)) + ): + self.mode = "historical" return self diff --git a/src/adapt/gui/dashboard.py b/src/adapt/gui/dashboard.py index 430bc78..65731fe 100644 --- a/src/adapt/gui/dashboard.py +++ b/src/adapt/gui/dashboard.py @@ -63,12 +63,12 @@ def _suppress_osx_stderr(): pass # ── Tkinter ─────────────────────────────────────────────────────────────────── -import tkinter as tk -from tkinter import filedialog, messagebox, scrolledtext, ttk +import tkinter as tk # noqa: E402 +from tkinter import filedialog, messagebox, scrolledtext, ttk # noqa: E402 # ── Optional deps ───────────────────────────────────────────────────────────── try: - from PIL import Image, ImageTk + import PIL # noqa: F401 HAS_PIL = True except ImportError: HAS_PIL = False @@ -76,7 +76,7 @@ def _suppress_osx_stderr(): try: import matplotlib matplotlib.use('TkAgg') - import cmweather.cm # registers ChaseSpectral and other radar colormaps — must follow use() + import cmweather # noqa: F401 — registers ChaseSpectral and other radar colormaps — must follow use() import matplotlib.dates as mdates import matplotlib.pyplot as plt from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk @@ -1538,7 +1538,8 @@ def _update_time_series(self, history_df: pd.DataFrame | None = None) -> None: cell_uid = str(track_df['cell_uid'].dropna().iloc[0]) else: cell_uid = self._selected_cell_uid - if not cell_uid or self._current_cell_df is None or 'cell_uid' not in self._current_cell_df.columns: + if (not cell_uid or self._current_cell_df is None + or 'cell_uid' not in self._current_cell_df.columns): return track_df = ( self._current_cell_df[self._current_cell_df['cell_uid'] == str(cell_uid)] @@ -1597,8 +1598,11 @@ def _update_time_series(self, history_df: pd.DataFrame | None = None) -> None: def _clear_time_series(self) -> None: if self._ts_axes is None: return - for ax, (ylabel, title) in zip(self._ts_axes, - [('km²', 'Area'), ('dBZ', 'Reflectivity'), ('dB', 'ZDR')], strict=False): + for ax, (ylabel, title) in zip( + self._ts_axes, + [('km²', 'Area'), ('dBZ', 'Reflectivity'), ('dB', 'ZDR')], + strict=False, + ): ax.cla() self._style_ts_ax(ax, ylabel, title) ax.text(0.5, 0.5, 'click a cell', transform=ax.transAxes, @@ -1863,7 +1867,8 @@ def _avg(col, fmt='.1f'): 'time_label': 65, 'cell_uid': 160, 'cell_label': 55, 'cell_area_sqkm': 70, 'area_40dbz_km2': 70, 'radar_reflectivity_max': 75, 'radar_reflectivity_mean': 75, - 'radar_differential_reflectivity_max': 75, 'radar_differential_reflectivity_mean': 75, + 'radar_differential_reflectivity_max': 75, + 'radar_differential_reflectivity_mean': 75, 'cell_centroid_mass_lat': 80, 'cell_centroid_mass_lon': 80, 'n_adjacent_cells': 65, } diff --git a/src/adapt/modules/acquisition/module.py b/src/adapt/modules/acquisition/module.py index b9b5fa4..6c8617e 100644 --- a/src/adapt/modules/acquisition/module.py +++ b/src/adapt/modules/acquisition/module.py @@ -530,7 +530,7 @@ def _file_exists(self, path: Path) -> bool: """Check if valid file exists.""" try: return path.exists() and path.stat().st_size >= self._min_file_size - except: + except Exception: return False def _download_scan(self, scan, local_path: Path) -> bool: @@ -573,7 +573,9 @@ def _notify_queue(self, path: Path, scan_time: datetime, is_new: bool, file_id = path.stem if tracker: tracker.register_file(file_id, self.radar, scan_time, path) - timings = {"download_seconds": download_seconds} if download_seconds is not None else None + timings = ( + {"download_seconds": download_seconds} if download_seconds is not None else None + ) tracker.mark_stage_complete(file_id, "downloaded", path=path, timings=timings) self.result_queue.put( diff --git a/src/adapt/modules/analysis/module.py b/src/adapt/modules/analysis/module.py index 95c971e..af525f5 100644 --- a/src/adapt/modules/analysis/module.py +++ b/src/adapt/modules/analysis/module.py @@ -271,13 +271,17 @@ def extract_adjacency(self, ds: xr.Dataset) -> pd.DataFrame: """ labels_name = self.config.global_.var_names.cell_labels if labels_name not in ds.data_vars: - raise ValueError(f"Missing required labels variable '{labels_name}' for adjacency extraction") + raise ValueError( + f"Missing required labels variable '{labels_name}' for adjacency extraction" + ) if "time" not in ds.coords: raise ValueError("Missing required coordinate 'time' for adjacency extraction") labels = ds[labels_name].values if labels.ndim != 2: - raise ValueError(f"Expected 2D labels array for adjacency extraction, got shape={labels.shape}") + raise ValueError( + f"Expected 2D labels array for adjacency extraction, got shape={labels.shape}" + ) scan_time = str(ds.time.values) adjacency = self._compute_boundary_adjacency( @@ -342,7 +346,9 @@ def _compute_boundary_adjacency(labels: np.ndarray, min_touching_pixels: int) -> if v >= min_touching_pixels ] if not rows: - return pd.DataFrame(columns=["cell_label_a", "cell_label_b", "touching_boundary_pixels"]) + return pd.DataFrame( + columns=["cell_label_a", "cell_label_b", "touching_boundary_pixels"] + ) df = pd.DataFrame(rows) df = df.sort_values(["cell_label_a", "cell_label_b"]).reset_index(drop=True) @@ -384,7 +390,8 @@ def _get_lat_lon_grids(self, ds): Returns lat/lon grids if available, otherwise returns placeholder grids of zeros (valid for in-memory analysis, invalid for geographic output). """ - if "lat" in ds.coords and "lon" in ds.coords or "lat" in ds.data_vars and "lon" in ds.data_vars: + if (("lat" in ds.coords and "lon" in ds.coords) + or ("lat" in ds.data_vars and "lon" in ds.data_vars)): return ds["lat"].values, ds["lon"].values else: # No lat/lon available - use placeholder zeros @@ -401,10 +408,9 @@ def _get_valid_data_vars(self, ds): """ available_vars = [] for var in self.radar_variables: - if var in ds.data_vars and var not in self.exclude_fields: - # Check if it's 2D (y, x) or 3D (z, y, x) - if ds[var].dims[-2:] == ("y", "x") or ds[var].dims[-3:] == ("z", "y", "x"): - available_vars.append(var) + if (var in ds.data_vars and var not in self.exclude_fields + and (ds[var].dims[-2:] == ("y", "x") or ds[var].dims[-3:] == ("z", "y", "x"))): + available_vars.append(var) return available_vars def _compute_geometric_centroid(self, mask, lat_grid=None, lon_grid=None): @@ -491,12 +497,18 @@ def _extract_region_props(self, region, label_array, refl, lat_grid, lon_grid, y_indices, x_indices = np.where(mask) valid_mask = np.isfinite(refl_cell) if np.any(valid_mask): - centroid_mass_y = int(np.round(np.average(y_indices[valid_mask], weights=refl_cell[valid_mask]))) - centroid_mass_x = int(np.round(np.average(x_indices[valid_mask], weights=refl_cell[valid_mask]))) + centroid_mass_y = int( + np.round(np.average(y_indices[valid_mask], weights=refl_cell[valid_mask])) + ) + centroid_mass_x = int( + np.round(np.average(x_indices[valid_mask], weights=refl_cell[valid_mask])) + ) else: - centroid_mass_y, centroid_mass_x = int(np.round(geom_props["centroid_y"])), int(np.round(geom_props["centroid_x"])) + centroid_mass_y = int(np.round(geom_props["centroid_y"])) + centroid_mass_x = int(np.round(geom_props["centroid_x"])) else: - centroid_mass_y, centroid_mass_x = int(np.round(geom_props["centroid_y"])), int(np.round(geom_props["centroid_x"])) + centroid_mass_y = int(np.round(geom_props["centroid_y"])) + centroid_mass_x = int(np.round(geom_props["centroid_x"])) lat_mass, lon_mass = self.get_lat_lon(centroid_mass_x, centroid_mass_y, lat_grid, lon_grid) @@ -545,18 +557,22 @@ def _extract_region_props(self, region, label_array, refl, lat_grid, lon_grid, projection_centroids = [] # Extract centroids for each projection step - for step_idx in range(min(projections.shape[0], self.max_projection_steps + 1)): + for step_idx in range( + min(projections.shape[0], self.max_projection_steps + 1) + ): proj_mask = projections[step_idx] == region.label if np.any(proj_mask): # Use reusable centroid function (already has lat/lon) - proj_centroid = self._compute_geometric_centroid(proj_mask, lat_grid, lon_grid) + proj_centroid = self._compute_geometric_centroid( + proj_mask, lat_grid, lon_grid + ) projection_centroids.append(proj_centroid) else: projection_centroids.append(None) - + # Store each centroid in both XY and lat/lon if projection_centroids: - # Index 0 = Registration centroid (projection from previous to current frame) + # Index 0 = Registration centroid (projection from previous to current) if projection_centroids[0] is not None: reg_cent = projection_centroids[0] props["cell_centroid_registration_x"] = reg_cent["centroid_x"] @@ -564,19 +580,32 @@ def _extract_region_props(self, region, label_array, refl, lat_grid, lon_grid, if "centroid_lat" in reg_cent: props["cell_centroid_registration_lat"] = reg_cent["centroid_lat"] props["cell_centroid_registration_lon"] = reg_cent["centroid_lon"] - + # Indices 1+ = Forward projection centroids for proj_idx, proj_cent in enumerate(projection_centroids[1:], start=1): if proj_cent is not None: - props[f"cell_centroid_projection{proj_idx}_x"] = proj_cent["centroid_x"] - props[f"cell_centroid_projection{proj_idx}_y"] = proj_cent["centroid_y"] + props[f"cell_centroid_projection{proj_idx}_x"] = ( + proj_cent["centroid_x"] + ) + props[f"cell_centroid_projection{proj_idx}_y"] = ( + proj_cent["centroid_y"] + ) if "centroid_lat" in proj_cent: - props[f"cell_centroid_projection{proj_idx}_lat"] = proj_cent["centroid_lat"] - props[f"cell_centroid_projection{proj_idx}_lon"] = proj_cent["centroid_lon"] - + props[f"cell_centroid_projection{proj_idx}_lat"] = ( + proj_cent["centroid_lat"] + ) + props[f"cell_centroid_projection{proj_idx}_lon"] = ( + proj_cent["centroid_lon"] + ) + # Also store full projection centroids as JSON for compact storage props["cell_projection_centroids_json"] = json.dumps([ - {k: v for k, v in c.items() if c and not (isinstance(v, float) and np.isnan(v))} if c else None + ( + { + k: v for k, v in c.items() + if c and not (isinstance(v, float) and np.isnan(v)) + } if c else None + ) for c in projection_centroids ]) except Exception as e: @@ -637,10 +666,10 @@ def get_lat_lon(ix, iy, lat_grid, lon_grid): # BaseModule wrapper — Step 6 # --------------------------------------------------------------------------- -from adapt.execution.module_registry import registry -from adapt.modules.base import BaseModule +from adapt.execution.module_registry import registry # noqa: E402 +from adapt.modules.base import BaseModule # noqa: E402 -from .contracts import assert_analysis_output, assert_cell_adjacency +from .contracts import assert_analysis_output, assert_cell_adjacency # noqa: E402 def _check_cell_stats(df): diff --git a/src/adapt/modules/detection/module.py b/src/adapt/modules/detection/module.py index 6d1131c..ea75bff 100644 --- a/src/adapt/modules/detection/module.py +++ b/src/adapt/modules/detection/module.py @@ -295,13 +295,17 @@ def _segment2D_threshold(self, ds: xr.Dataset) -> xr.Dataset: # we attach labels to original dataset ds_out = ds.copy() ds_out[self.labels_name] = labels_da - logger.debug(f"Labels attached: var={self.labels_name}, shape={labels.shape}, max={labels.max()}") + logger.debug( + f"Labels attached: var={self.labels_name}, shape={labels.shape}, max={labels.max()}" + ) return ds_out - def _binary_to_labels(self, binary_mask: np.ndarray, field: np.ndarray, - kernel_size: tuple, filter_by_size: bool, - min_gridpoints: int, max_gridpoints: int) -> np.ndarray: + def _binary_to_labels( + self, binary_mask: np.ndarray, field: np.ndarray, + kernel_size: tuple, filter_by_size: bool, + min_gridpoints: int, max_gridpoints: int, + ) -> np.ndarray: """Morphology, detect cells, filter.""" from skimage.morphology import closing, footprint_rectangle @@ -311,7 +315,9 @@ def _binary_to_labels(self, binary_mask: np.ndarray, field: np.ndarray, # if there are any cells, filter and/or renumber if labels.max() > 0: - labels = self._filter_and_relabel(labels, filter_by_size, min_gridpoints, max_gridpoints) + labels = self._filter_and_relabel( + labels, filter_by_size, min_gridpoints, max_gridpoints + ) return labels.astype(np.int32) @@ -344,7 +350,9 @@ def _filter_and_relabel(self, labels: np.ndarray, filter_by_size: bool, return labels_renumbered - def _relabel_by_size(self, labels: np.ndarray, labels_to_keep: np.ndarray, counts: np.ndarray) -> np.ndarray: + def _relabel_by_size( + self, labels: np.ndarray, labels_to_keep: np.ndarray, counts: np.ndarray + ) -> np.ndarray: """Renumber: largest=1.""" keep_indices = np.isin(np.arange(len(counts)), labels_to_keep) keep_counts = counts[keep_indices] @@ -362,11 +370,11 @@ def _relabel_by_size(self, labels: np.ndarray, labels_to_keep: np.ndarray, count # BaseModule wrapper — Step 6 # --------------------------------------------------------------------------- -from adapt.execution.module_registry import registry -from adapt.modules.base import BaseModule -from adapt.modules.ingest.contracts import assert_gridded +from adapt.execution.module_registry import registry # noqa: E402 +from adapt.modules.base import BaseModule # noqa: E402 +from adapt.modules.ingest.contracts import assert_gridded # noqa: E402 -from .contracts import assert_segmented +from .contracts import assert_segmented # noqa: E402 def _check_grid_ds_2d(ds): diff --git a/src/adapt/modules/ingest/module.py b/src/adapt/modules/ingest/module.py index d030a45..51b4c3e 100644 --- a/src/adapt/modules/ingest/module.py +++ b/src/adapt/modules/ingest/module.py @@ -349,17 +349,17 @@ def load_and_regrid(self, filepath: Path | str, grid_kwargs: dict = None, # BaseModule wrapper — Step 6 # --------------------------------------------------------------------------- -from datetime import UTC -from datetime import datetime as _dt +from datetime import UTC # noqa: E402 +from datetime import datetime as _dt # noqa: E402 -import numpy as np -import xarray as _xr +import numpy as np # noqa: E402 +import xarray as _xr # noqa: E402 -from adapt.configuration.schemas.directories import get_netcdf_path -from adapt.execution.module_registry import registry -from adapt.modules.base import BaseModule +from adapt.configuration.schemas.directories import get_netcdf_path # noqa: E402 +from adapt.execution.module_registry import registry # noqa: E402 +from adapt.modules.base import BaseModule # noqa: E402 -from .contracts import assert_gridded +from .contracts import assert_gridded # noqa: E402 def _check_grid_ds_2d(ds): diff --git a/src/adapt/modules/projection/contracts.py b/src/adapt/modules/projection/contracts.py index 804bbb5..0c5f57f 100644 --- a/src/adapt/modules/projection/contracts.py +++ b/src/adapt/modules/projection/contracts.py @@ -51,7 +51,8 @@ def assert_projected(ds: xr.Dataset, max_steps: int = 5) -> None: projections = ds["cell_projections"] require( projections.ndim == 3, - f"Projection contract violated: 'cell_projections' has {projections.ndim} dims, expected 3 (step, y, x)" + f"Projection contract violated: 'cell_projections' has {projections.ndim} dims, " + "expected 3 (step, y, x)", ) # Use stored config value if available (self-describing data pattern) @@ -62,5 +63,6 @@ def assert_projected(ds: xr.Dataset, max_steps: int = 5) -> None: expected_steps = max_steps_actual + 1 # 1 registration + N future require( num_steps == expected_steps, - f"Projection contract violated: found {num_steps} steps, expected {expected_steps} (1 registration + {max_steps_actual} projections from config)" + f"Projection contract violated: found {num_steps} steps, expected {expected_steps} " + f"(1 registration + {max_steps_actual} projections from config)" ) diff --git a/src/adapt/modules/projection/module.py b/src/adapt/modules/projection/module.py index c7a8551..1242a28 100644 --- a/src/adapt/modules/projection/module.py +++ b/src/adapt/modules/projection/module.py @@ -124,7 +124,9 @@ class RadarCellProjector: >>> projector = RadarCellProjector(config) >>> ds_with_motion = projector.project([ds_t1, ds_t0]) >>> num_projections = ds_with_motion["cell_projections"].shape[0] - >>> print(f"Generated {num_projections} projections (1 registration + {num_projections-1} future)") + >>> print( + ... f"Generated {num_projections} projections (1 registration + {num_projections-1} future)" + ... ) """ def __init__(self, config: "InternalConfig"): @@ -253,8 +255,12 @@ def _project_opticalflow(self, ds_list): # Get reflectivity from ds (already at correct z-level from processor) # Reflectivity is always 2D at the configured z-level - refl1 = np.nan_to_num(ds_list[0][self.refl_var].values, nan=self.nan_fill).astype(np.float32) - refl2 = np.nan_to_num(ds_list[1][self.refl_var].values, nan=self.nan_fill).astype(np.float32) + refl1 = ( + np.nan_to_num(ds_list[0][self.refl_var].values, nan=self.nan_fill).astype(np.float32) + ) + refl2 = ( + np.nan_to_num(ds_list[1][self.refl_var].values, nan=self.nan_fill).astype(np.float32) + ) refl1_norm, refl2_norm = self._normalize(refl1, refl2) flow = cv2.calcOpticalFlowFarneback(refl1_norm, refl2_norm, None, **self.flow_params) @@ -268,7 +274,8 @@ def _project_opticalflow(self, ds_list): # Generate projections: # - First projection (offset=0) is registration: t-1 → t0 (uses labels from t-1) - # - Subsequent projections (offset=1,2,...) are future: t0→t1, t1→t2, etc. (uses labels from t0) + # - Subsequent projections (offset=1,2,...) are future: t0→t1, t1→t2, etc. + # (uses labels from t0) # So total projections = max_proj_steps + 1 (1 for registration + N for future) labels_proj_list = [] @@ -279,8 +286,8 @@ def _project_opticalflow(self, ds_list): # Future projections - project current labels (t0) forward (n steps) # Each pixel carries its original flow value and uses accumulated displacement. - # @TODO I have removed more complecated logic of using flow at new positions for each step, - # because some cells did not move in noisy radar data during the test. I will test it again later. + # @TODO I have removed more complecated logic of using flow at new positions for each step, + # because some cells did not move in noisy radar data during the test. future_projections = self._project_frames(labels_curr, flow, n_steps=self.max_proj_steps) for i in range(self.max_proj_steps): labels_proj_list.append(future_projections[i]) @@ -432,8 +439,8 @@ def _validate_datasets(self, ds_list, max_interval_minutes): # Note: Processor already validated time gap, so we just warn if large if abs(time_diff_minutes) > max_interval_minutes: logger.warning( - f"Time interval {time_diff_minutes:.1f} min exceeds max {max_interval_minutes} min. " - "Processor should have filtered this pair." + f"Time interval {time_diff_minutes:.1f} min exceeds max " + f"{max_interval_minutes} min. Processor should have filtered this pair." ) return time_diff_minutes @@ -562,9 +569,9 @@ def _fill_concave_hull(self, label_mask, alpha=0.1): # BaseModule wrapper — Step 6 # --------------------------------------------------------------------------- -from adapt.execution.module_registry import registry -from adapt.modules.base import BaseModule -from adapt.modules.detection.contracts import assert_segmented +from adapt.execution.module_registry import registry # noqa: E402 +from adapt.modules.base import BaseModule # noqa: E402 +from adapt.modules.detection.contracts import assert_segmented # noqa: E402 def _check_segmented_ds(ds): @@ -629,7 +636,8 @@ def run(self, context: dict) -> dict: # Must have 2 frames (guaranteed by processor orchestration, but double-check) if len(self._dataset_history) < 2: raise ValueError( - f"ProjectionModule requires 2 frames, but only {len(self._dataset_history)} available. " + f"ProjectionModule requires 2 frames, but only " + f"{len(self._dataset_history)} available. " "Processor should orchestrate frame pairing before calling projection." ) diff --git a/src/adapt/modules/tracking/module.py b/src/adapt/modules/tracking/module.py index 17ef3ee..060c30b 100644 --- a/src/adapt/modules/tracking/module.py +++ b/src/adapt/modules/tracking/module.py @@ -3,7 +3,8 @@ """Track convective cells across consecutive radar scans using mask overlap and motion prediction. -This module implements a cell tracking algorithm inspired from TINT (Raut et al., 2021) with following improvements. +This module implements a cell tracking algorithm inspired from TINT (Raut et al., 2021) with +following improvements. - Motion-aware matching via projected label masks (no centroid-only matching) - Projected mask overlap with current frame allows split and merge (registration frame) - Split candidate: one cell → multiple cells (1 to N) in the projected area of a continuing parent @@ -18,7 +19,8 @@ 1. **tracked_cells**: Per-observation rows for the current scan 2. **cell_events**: Explicit lineage/event rows for the current scan -Tracking state is stored in a directed graph structure with nodes representing cell observations and edges representing temporal relationships. +Tracking state is stored in a directed graph structure with nodes representing cell observations +and edges representing temporal relationships. What is different from TINT: - No centroid-only matching (uses full mask overlap + motion prediction) @@ -26,7 +28,9 @@ Author: Bhupendra Raut, ANL. -References: Raut, B. A., Jackson, R., Picel, M., Collis, S. M., Bergemann, M., & Jakob, C. (2021). An adaptive tracking algorithm for convection in simulated and remote sensing data. Journal of Applied Meteorology and Climatology, 60(4), 513-526. +References: Raut, B. A., Jackson, R., Picel, M., Collis, S. M., Bergemann, M., & Jakob, C. +(2021). An adaptive tracking algorithm for convection in simulated and remote sensing data. +Journal of Applied Meteorology and Climatology, 60(4), 513-526. """ import contextlib @@ -661,11 +665,15 @@ def _track_frame_pair( prev_node = prev_node_ids[r] track_index = self.graph.get_node_attr(prev_node, 'track_index') curr_node = self._add_cell_node(curr_time, curr_cells[c], int(track_index or 0)) - self.graph.add_edge(prev_node, curr_node, edge_type="CONTINUE", cost=float(square[r, c])) + self.graph.add_edge( + prev_node, curr_node, edge_type="CONTINUE", cost=float(square[r, c]) + ) matched_prev[r] = curr_node matched_curr[c] = curr_node n_continue += 1 - events.append(self._event_continue(curr_time, prev_node, curr_node, float(square[r, c]))) + events.append( + self._event_continue(curr_time, prev_node, curr_node, float(square[r, c])) + ) dissipated = [prev_node_ids[i] for i in range(n_prev) if i not in matched_prev] born_indices = [i for i in range(n_curr) if i not in matched_curr] @@ -693,7 +701,9 @@ def _track_frame_pair( new_index = self.graph.get_new_track_index() cell_uid, track_signature = self._new_cell_identity(curr_cells[b_idx]) self._cell_identity[new_index] = (cell_uid, track_signature) - child_node = self._add_cell_node(curr_time, curr_cells[b_idx], new_index, cell_uid, track_signature) + child_node = self._add_cell_node( + curr_time, curr_cells[b_idx], new_index, cell_uid, track_signature + ) self.graph.add_edge(best_parent, child_node, edge_type="SPLIT", cost=0.0) split_born.add(b_idx) events.append(self._event_split(curr_time, best_parent, child_node)) @@ -738,13 +748,17 @@ def _track_frame_pair( new_index = self.graph.get_new_track_index() cell_uid, track_signature = self._new_cell_identity(curr_cells[b_idx]) self._cell_identity[new_index] = (cell_uid, track_signature) - node_id = self._add_cell_node(curr_time, curr_cells[b_idx], new_index, cell_uid, track_signature) + node_id = self._add_cell_node( + curr_time, curr_cells[b_idx], new_index, cell_uid, track_signature + ) n_births += 1 events.append(self._event_initiation(curr_time, node_id)) for d_node in dissipated: if d_node in merged_nodes: - events.append(self._event_termination(curr_time, d_node, target_node_id=merged_nodes[d_node])) + events.append( + self._event_termination(curr_time, d_node, target_node_id=merged_nodes[d_node]) + ) else: events.append(self._event_termination(curr_time, d_node, target_node_id=None)) @@ -808,7 +822,9 @@ def _build_cell_events_dataframe(events: list[dict]) -> pd.DataFrame: if col not in df.columns: df[col] = None df = df[cols] - df["time"] = df["time"].apply(lambda t: pd.Timestamp(RadarCellTracker._normalize_time_scalar(t))) + df["time"] = df["time"].apply( + lambda t: pd.Timestamp(RadarCellTracker._normalize_time_scalar(t)) + ) return df # ------------------------------------------------------------------ @@ -816,8 +832,12 @@ def _build_cell_events_dataframe(events: list[dict]) -> pd.DataFrame: # ------------------------------------------------------------------ def _event_continue(self, time, prev_node_id: int, curr_node_id: int, cost: float) -> dict: - source_cell_uid = self.get_cell_identity(int(self.graph.get_node_attr(prev_node_id, "track_index")))[0] - target_cell_uid = self.get_cell_identity(int(self.graph.get_node_attr(curr_node_id, "track_index")))[0] + source_cell_uid = self.get_cell_identity( + int(self.graph.get_node_attr(prev_node_id, "track_index")) + )[0] + target_cell_uid = self.get_cell_identity( + int(self.graph.get_node_attr(curr_node_id, "track_index")) + )[0] return { "time": time, "event_type": "CONTINUE", @@ -831,8 +851,12 @@ def _event_continue(self, time, prev_node_id: int, curr_node_id: int, cost: floa } def _event_split(self, time, parent_node_id: int, child_node_id: int) -> dict: - parent_uid = self.get_cell_identity(int(self.graph.get_node_attr(parent_node_id, "track_index")))[0] - child_uid = self.get_cell_identity(int(self.graph.get_node_attr(child_node_id, "track_index")))[0] + parent_uid = self.get_cell_identity( + int(self.graph.get_node_attr(parent_node_id, "track_index")) + )[0] + child_uid = self.get_cell_identity( + int(self.graph.get_node_attr(child_node_id, "track_index")) + )[0] return { "time": time, "event_type": "SPLIT", @@ -862,7 +886,9 @@ def _event_merge(self, time, source_node_id: int, target_node_id: int) -> dict: } def _event_initiation(self, time, node_id: int) -> dict: - target_uid = self.get_cell_identity(int(self.graph.get_node_attr(node_id, "track_index")))[0] + target_uid = self.get_cell_identity( + int(self.graph.get_node_attr(node_id, "track_index")) + )[0] return { "time": time, "event_type": "INITIATION", @@ -877,15 +903,23 @@ def _event_initiation(self, time, node_id: int) -> dict: def _event_termination(self, time, source_node_id: int, target_node_id: int | None) -> dict: source_path = int(self.graph.get_node_attr(source_node_id, "track_index")) - target_path = int(self.graph.get_node_attr(target_node_id, "track_index")) if target_node_id is not None else None + target_path = ( + int(self.graph.get_node_attr(target_node_id, "track_index")) + if target_node_id is not None else None + ) source_uid = self.get_cell_identity(source_path)[0] return { "time": time, "event_type": "TERMINATION", "source_cell_uid": source_uid, - "target_cell_uid": self.get_cell_identity(target_path)[0] if target_path is not None else None, + "target_cell_uid": ( + self.get_cell_identity(target_path)[0] if target_path is not None else None + ), "source_cell_label": int(self.graph.get_node_attr(source_node_id, "cell_id")), - "target_cell_label": int(self.graph.get_node_attr(target_node_id, "cell_id")) if target_node_id is not None else None, + "target_cell_label": ( + int(self.graph.get_node_attr(target_node_id, "cell_id")) + if target_node_id is not None else None + ), "cost": None, "is_dominant": False, "event_group_id": f"{self._time_key(time)}:TERMINATION:{source_uid}", @@ -896,11 +930,11 @@ def _event_termination(self, time, source_node_id: int, target_node_id: int | No # BaseModule wrapper (Phase 6 implementation placeholder) # ============================================================================= -from adapt.execution.module_registry import registry -from adapt.modules.base import BaseModule -from adapt.modules.projection.contracts import assert_projected +from adapt.execution.module_registry import registry # noqa: E402 +from adapt.modules.base import BaseModule # noqa: E402 +from adapt.modules.projection.contracts import assert_projected # noqa: E402 -from .contracts import assert_cell_events, assert_tracked_cells +from .contracts import assert_cell_events, assert_tracked_cells # noqa: E402 def _check_projected_ds(ds: xr.Dataset) -> None: diff --git a/src/adapt/persistence/catalog.py b/src/adapt/persistence/catalog.py index 9f483ae..79c6513 100644 --- a/src/adapt/persistence/catalog.py +++ b/src/adapt/persistence/catalog.py @@ -87,7 +87,10 @@ def _get_connection(self) -> sqlite3.Connection: def _init_database(self) -> None: """Initialize database schema from SQL file.""" - schema_path = Path(__file__).resolve().parents[1] / "configuration" / "schemas" / "radar_catalog_schema.sql" + schema_path = ( + Path(__file__).resolve().parents[1] / "configuration" / "schemas" + / "radar_catalog_schema.sql" + ) if not schema_path.exists(): # Fallback to embedded schema @@ -134,7 +137,9 @@ def _create_schema_inline(self) -> None: conn.execute("CREATE INDEX IF NOT EXISTS idx_items_run ON items(run_id)") conn.execute("CREATE INDEX IF NOT EXISTS idx_items_type ON items(item_type)") conn.execute("CREATE INDEX IF NOT EXISTS idx_items_scan_time ON items(scan_time DESC)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_items_type_time ON items(item_type, scan_time DESC)") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_items_type_time ON items(item_type, scan_time DESC)" + ) # Progress table conn.execute(""" diff --git a/src/adapt/persistence/registry.py b/src/adapt/persistence/registry.py index 90cf1d8..e4c98a4 100644 --- a/src/adapt/persistence/registry.py +++ b/src/adapt/persistence/registry.py @@ -159,7 +159,9 @@ def _create_schema_inline(self) -> None: last_updated TEXT NOT NULL ) """) - conn.execute("CREATE INDEX IF NOT EXISTS idx_radars_updated ON radars(last_updated DESC)") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_radars_updated ON radars(last_updated DESC)" + ) # Item types table conn.execute(""" @@ -219,11 +221,15 @@ def register_radar( conn = self._get_connection() with self._lock: - conn.execute(""" - INSERT OR REPLACE INTO radars - (radar, catalog_path, data_path, location_lat, location_lon, created_at, last_updated) + conn.execute( + """ + INSERT OR REPLACE INTO radars + (radar, catalog_path, data_path, + location_lat, location_lon, created_at, last_updated) VALUES (?, ?, ?, ?, ?, ?, ?) - """, (radar, catalog_path, data_path, lat, lon, now, now)) + """, + (radar, catalog_path, data_path, lat, lon, now, now), + ) conn.commit() logger.debug("Radar registered: %s at %s", radar, data_path) @@ -273,7 +279,8 @@ def ensure_radar_location(self, radar: str, lat: float, lon: float) -> None: return conn.execute( - "UPDATE radars SET location_lat = ?, location_lon = ?, last_updated = ? WHERE radar = ?", + "UPDATE radars SET location_lat = ?, location_lon = ?, " + "last_updated = ? WHERE radar = ?", (lat_f, lon_f, now, radar), ) conn.commit() @@ -346,11 +353,15 @@ def register_run( conn = self._get_connection() with self._lock: - conn.execute(""" - INSERT OR IGNORE INTO runs - (run_id, radar, start_time, status, mode, config_path, repository_version, created_at) + conn.execute( + """ + INSERT OR IGNORE INTO runs + (run_id, radar, start_time, status, mode, + config_path, repository_version, created_at) VALUES (?, ?, ?, 'running', ?, ?, ?, ?) - """, (run_id, radar, now, mode, config_path, repository_version, now)) + """, + (run_id, radar, now, mode, config_path, repository_version, now), + ) conn.commit() logger.debug("Run registered: %s for radar %s", run_id, radar) diff --git a/src/adapt/persistence/repository.py b/src/adapt/persistence/repository.py index cd8f01e..dc2dccf 100644 --- a/src/adapt/persistence/repository.py +++ b/src/adapt/persistence/repository.py @@ -699,7 +699,10 @@ def write_analysis2d_parquet( combined_df = pd.concat([existing_df, df], ignore_index=True) # Ensure datetime columns are properly typed (fix concat type coercion) - datetime_cols = ['time', 'scan_time', 'start_time', 'end_time', 'time_volume_start', 'time_volume_end'] + datetime_cols = [ + 'time', 'scan_time', 'start_time', 'end_time', + 'time_volume_start', 'time_volume_end', + ] for col in datetime_cols: if col in combined_df.columns: combined_df[col] = pd.to_datetime( @@ -709,7 +712,9 @@ def write_analysis2d_parquet( ).dt.tz_convert(None) table = pa.Table.from_pandas(combined_df) - logger.debug(f"Appended {len(df)} rows to {parquet_path} (schema evolution handled)") + logger.debug( + f"Appended {len(df)} rows to {parquet_path} (schema evolution handled)" + ) # Write or overwrite parquet file pq.write_table(table, parquet_path) diff --git a/src/adapt/persistence/track_store.py b/src/adapt/persistence/track_store.py index 74ab6b5..3bdb4c1 100644 --- a/src/adapt/persistence/track_store.py +++ b/src/adapt/persistence/track_store.py @@ -14,6 +14,7 @@ from __future__ import annotations +import contextlib import logging import sqlite3 import threading @@ -135,10 +136,13 @@ def write_scan( cell_uids = tracked_cells_df[uid_col].astype(str).unique().tolist() placeholders = ",".join("?" * len(cell_uids)) first_seen_rows = conn.execute( - f"SELECT cell_uid, first_seen_time FROM cell_tracks WHERE run_id=? AND cell_uid IN ({placeholders})", + "SELECT cell_uid, first_seen_time FROM cell_tracks " + f"WHERE run_id=? AND cell_uid IN ({placeholders})", [run_id] + cell_uids, ).fetchall() - first_seen_map: dict[str, str] = {r["cell_uid"]: r["first_seen_time"] for r in first_seen_rows} + first_seen_map: dict[str, str] = { + r["cell_uid"]: r["first_seen_time"] for r in first_seen_rows + } adjacency = self._build_uid_adjacency_summary( tracked_cells_df=tracked_cells_df, @@ -206,7 +210,8 @@ def get_cell_events(self, run_id: str, cell_uid: str | None = None) -> pd.DataFr ).fetchall() else: rows = conn.execute( - "SELECT * FROM cell_events WHERE run_id=? AND (source_cell_uid=? OR target_cell_uid=?) ORDER BY event_id", + "SELECT * FROM cell_events WHERE run_id=? " + "AND (source_cell_uid=? OR target_cell_uid=?) ORDER BY event_id", (run_id, cell_uid, cell_uid), ).fetchall() return pd.DataFrame([dict(r) for r in rows]) @@ -302,7 +307,9 @@ def _build_uid_adjacency_summary( lbl = int(r["cell_label"]) uid = str(r["cell_uid"]) if lbl in label_to_uid and label_to_uid[lbl] != uid: - raise ValueError(f"Non-unique mapping for cell_label={lbl}: {label_to_uid[lbl]} vs {uid}") + raise ValueError( + f"Non-unique mapping for cell_label={lbl}: {label_to_uid[lbl]} vs {uid}" + ) label_to_uid[lbl] = uid neighbors: dict[str, set[str]] = {uid: set() for uid in label_to_uid.values()} @@ -330,11 +337,9 @@ def _ensure_columns(self, conn: sqlite3.Connection, cell_stats_df: pd.DataFrame) if col in _SKIP_FROM_CELL_STATS or col in _FIXED_CBS_COLS or col in existing: continue sql_type = _infer_sql_type(col) - try: + with contextlib.suppress(sqlite3.OperationalError): conn.execute(f"ALTER TABLE cells_by_scan ADD COLUMN {col} {sql_type}") - #logger.info("cells_by_scan: added column %s %s", col, sql_type) - except sqlite3.OperationalError: - pass # race — column added concurrently + # logger.info("cells_by_scan: added column %s %s", col, sql_type) def _build_cells_rows( self, @@ -379,9 +384,12 @@ def _build_cells_rows( # Compute age_seconds from first_seen_time (0 for new initiations) age_seconds = 0.0 - if scan_dt is not None and cl not in initiated and first_seen_map and tid in first_seen_map: + if (scan_dt is not None and cl not in initiated + and first_seen_map and tid in first_seen_map): try: - first_dt = _dt.strptime(first_seen_map[tid], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=UTC) + first_dt = _dt.strptime( + first_seen_map[tid], "%Y-%m-%dT%H:%M:%SZ" + ).replace(tzinfo=UTC) age_seconds = max(0.0, (scan_dt - first_dt).total_seconds()) except ValueError: pass @@ -416,7 +424,9 @@ def _upsert_cells(self, conn: sqlite3.Connection, rows: list[dict]) -> None: cols = list(rows[0].keys()) placeholders = ", ".join("?" * len(cols)) col_list = ", ".join(cols) - update_set = ", ".join(f"{c}=excluded.{c}" for c in cols if c not in ("run_id", "scan_time", "cell_uid")) + update_set = ", ".join( + f"{c}=excluded.{c}" for c in cols if c not in ("run_id", "scan_time", "cell_uid") + ) sql = ( f"INSERT INTO cells_by_scan ({col_list}) VALUES ({placeholders}) " f"ON CONFLICT(run_id, scan_time, cell_uid) DO UPDATE SET {update_set}" @@ -454,7 +464,8 @@ def _update_retroactive_flags( def _update(flag: str, cell_uids: set) -> None: for tid in cell_uids: conn.execute( - f"UPDATE cells_by_scan SET {flag}=1 WHERE run_id=? AND scan_time=? AND cell_uid=?", + f"UPDATE cells_by_scan SET {flag}=1 " + "WHERE run_id=? AND scan_time=? AND cell_uid=?", (run_id, prev_iso, tid), ) @@ -548,7 +559,8 @@ def _upsert_cell_tracks( existing = { r["cell_uid"]: dict(r) for r in conn.execute( - "SELECT cell_uid, n_scans, max_area_sqkm, max_reflectivity FROM cell_tracks WHERE run_id=?", + "SELECT cell_uid, n_scans, max_area_sqkm, max_reflectivity " + "FROM cell_tracks WHERE run_id=?", (run_id,), ).fetchall() } @@ -592,8 +604,12 @@ def _upsert_cell_tracks( ON CONFLICT(run_id, cell_uid) DO UPDATE SET last_seen_time=excluded.last_seen_time, n_scans=cell_tracks.n_scans+1, - max_area_sqkm=MAX(COALESCE(cell_tracks.max_area_sqkm,0), excluded.max_area_sqkm), - max_reflectivity=MAX(COALESCE(cell_tracks.max_reflectivity,0), excluded.max_reflectivity)""", + max_area_sqkm=MAX( + COALESCE(cell_tracks.max_area_sqkm,0), excluded.max_area_sqkm + ), + max_reflectivity=MAX( + COALESCE(cell_tracks.max_reflectivity,0), excluded.max_reflectivity + )""", (run_id, tid, scan_iso, scan_iso, origin_type, origin_grp, origin_n, origin_parent, info["area"], info["refl"]), @@ -631,7 +647,10 @@ def _to_iso(dt: datetime) -> str: def _infer_sql_type(col: str) -> str: col_l = col.lower() - if any(col_l.endswith(s) for s in ("_lat", "_lon", "_mean", "_max", "_min", "_sqkm", "_km2", "_std", "_p25", "_p75")): + _real_suffixes = ( + "_lat", "_lon", "_mean", "_max", "_min", "_sqkm", "_km2", "_std", "_p25", "_p75" + ) + if any(col_l.endswith(s) for s in _real_suffixes): return "REAL" if any(col_l.endswith(s) for s in ("_x", "_y", "_count", "_pixels", "_index")): return "INTEGER" diff --git a/src/adapt/runtime/file_tracker.py b/src/adapt/runtime/file_tracker.py index c0c642d..70e7f84 100644 --- a/src/adapt/runtime/file_tracker.py +++ b/src/adapt/runtime/file_tracker.py @@ -7,6 +7,7 @@ Enables idempotent processing with stop/restart, progress tracking, and failure recovery. """ +import contextlib import logging import sqlite3 import threading @@ -134,9 +135,18 @@ def _init_database(self): ) """) - conn.execute("CREATE INDEX IF NOT EXISTS idx_radar_file_processing_radar_id ON radar_file_processing(radar)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_radar_file_processing_status ON radar_file_processing(status)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_radar_file_processing_scan_time ON radar_file_processing(scan_time)") + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_radar_file_processing_radar_id " + "ON radar_file_processing(radar)" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_radar_file_processing_status " + "ON radar_file_processing(status)" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_radar_file_processing_scan_time " + "ON radar_file_processing(scan_time)" + ) conn.commit() @@ -152,10 +162,8 @@ def _migrate_database(self): conn = self._get_connection() with self._lock: for col_def in timing_cols: - try: + with contextlib.suppress(sqlite3.OperationalError): conn.execute(f"ALTER TABLE radar_file_processing ADD COLUMN {col_def}") - except sqlite3.OperationalError: - pass # column already exists conn.commit() def register_file(self, file_id: str, radar: str, scan_time: datetime, @@ -193,7 +201,9 @@ def register_file(self, file_id: str, radar: str, scan_time: datetime, with self._lock: # Check if already exists - cursor = conn.execute("SELECT file_id FROM radar_file_processing WHERE file_id = ?", (file_id,)) + cursor = conn.execute( + "SELECT file_id FROM radar_file_processing WHERE file_id = ?", (file_id,) + ) if cursor.fetchone(): return False diff --git a/src/adapt/runtime/orchestrator.py b/src/adapt/runtime/orchestrator.py index 54796b2..b48da92 100644 --- a/src/adapt/runtime/orchestrator.py +++ b/src/adapt/runtime/orchestrator.py @@ -294,7 +294,9 @@ def _main_loop(self, mode: str): # 2. Check for thread failures or self-stops (e.g. ContractViolation) if self.processor.stopped(): - logger.critical("Processor has stopped (likely due to contract violation). Exiting.") + logger.critical( + "Processor has stopped (likely due to contract violation). Exiting." + ) break if not self.processor.is_alive(): diff --git a/src/adapt/runtime/processor.py b/src/adapt/runtime/processor.py index 7db853f..709ba25 100644 --- a/src/adapt/runtime/processor.py +++ b/src/adapt/runtime/processor.py @@ -446,7 +446,8 @@ def _save_results(self, result: dict, scan_time): Saves: - projected_ds as NetCDF artifact - cell_stats, cell_adjacency as Parquet artifacts - - tracked_cells, cell_events as SQLite via TrackStore (label→uid adjacency mapping in TrackStore) + - tracked_cells, cell_events as SQLite via TrackStore + (label→uid adjacency mapping in TrackStore) """ if scan_time is not None and scan_time.tzinfo is None: scan_time = scan_time.replace(tzinfo=UTC) diff --git a/src/adapt/visualization/plotter.py b/src/adapt/visualization/plotter.py index 1ed39e0..70e626e 100644 --- a/src/adapt/visualization/plotter.py +++ b/src/adapt/visualization/plotter.py @@ -26,7 +26,6 @@ try: import contextily as ctx - from pyproj import Transformer CONTEXTILY_AVAILABLE = True except ImportError: CONTEXTILY_AVAILABLE = False @@ -154,14 +153,13 @@ def _get_var_name(self, var_key: str, default: str) -> str: def _get_coord_name(self, coord_key: str, default: str) -> str: """Get coordinate name from config.""" if self.config: - if coord_key == "x": - return self.config.global_.coord_names.x - elif coord_key == "y": - return self.config.global_.coord_names.y - elif coord_key == "z": - return self.config.global_.coord_names.z - elif coord_key == "time": - return self.config.global_.coord_names.time + coord_map = { + "x": self.config.global_.coord_names.x, + "y": self.config.global_.coord_names.y, + "z": self.config.global_.coord_names.z, + "time": self.config.global_.coord_names.time, + } + return coord_map.get(coord_key, default) return default def _extract_timestamp(self, ds: xr.Dataset) -> datetime: @@ -246,7 +244,9 @@ def extract_float(val): return lat, lon - def _add_basemap(self, ax: plt.Axes, ds: xr.Dataset, x_coords: np.ndarray, y_coords: np.ndarray) -> None: + def _add_basemap( + self, ax: plt.Axes, ds: xr.Dataset, x_coords: np.ndarray, y_coords: np.ndarray + ) -> None: """Add OpenStreetMap basemap to axis.""" if not self.use_basemap or not CONTEXTILY_AVAILABLE: return @@ -255,7 +255,10 @@ def _add_basemap(self, ax: plt.Axes, ds: xr.Dataset, x_coords: np.ndarray, y_coo radar_lat, radar_lon = self._get_radar_location(ds) # Set CRS for azimuthal equidistant (km units) - crs_str = f"+proj=aeqd +lat_0={radar_lat} +lon_0={radar_lon} +x_0=0 +y_0=0 +datum=WGS84 +units=km" + crs_str = ( + f"+proj=aeqd +lat_0={radar_lat} +lon_0={radar_lon} " + "+x_0=0 +y_0=0 +datum=WGS84 +units=km" + ) ax.set_xlim(x_coords.min(), x_coords.max()) ax.set_ylim(y_coords.min(), y_coords.max()) @@ -339,7 +342,10 @@ def _plot_heading_yectors( zorder=45 ) - logger.info(f"Plotted optical flow field ({len(y_indices)}x{len(x_indices)} vectors, scale={self.flow_scale})") + logger.info( + f"Plotted optical flow field ({len(y_indices)}x{len(x_indices)} vectors, " + f"scale={self.flow_scale})" + ) return True def _plot_segmentation_contours( @@ -599,7 +605,9 @@ def plot_reflectivity_with_cells( plt.tight_layout() if output_path is None: - output_path = Path(f"/tmp/radar_plot_{timestamp.strftime('%Y%m%d_%H%M%S')}.{self.output_format}") + output_path = Path( + f"/tmp/radar_plot_{timestamp.strftime('%Y%m%d_%H%M%S')}.{self.output_format}" + ) return self._save_figure(fig, Path(output_path)) @@ -821,7 +829,10 @@ def _process_item(self, item: dict): tracker = self.file_tracker if tracker: - file_id = Path(item.get('segmentation_nc', '')).stem.replace('_analysis', '').replace('_segmentation', '') + file_id = ( + Path(item.get('segmentation_nc', '')).stem + .replace('_analysis', '').replace('_segmentation', '') + ) if file_id: tracker.mark_stage_complete(file_id, "plotted", error=str(e)) @@ -1093,7 +1104,9 @@ def _print_table_stats(self): if cols_available: display_df = recent[cols_available].copy() - display_df.columns = ['Label', 'Area (km2)', 'Mean dBZ', 'Max dBZ'][:len(cols_available)] + display_df.columns = ( + ['Label', 'Area (km2)', 'Mean dBZ', 'Max dBZ'][:len(cols_available)] + ) print(display_df.to_string(index=False)) print("=" * 60 + "\n") diff --git a/tests/modules/tracking/test_tracker_scan_local_outputs.py b/tests/modules/tracking/test_tracker_scan_local_outputs.py index 768ba2b..3501b12 100644 --- a/tests/modules/tracking/test_tracker_scan_local_outputs.py +++ b/tests/modules/tracking/test_tracker_scan_local_outputs.py @@ -88,7 +88,9 @@ def test_birth_and_continue_events(tracker): labels1 = np.zeros((6, 6), dtype=np.int32) labels1[2:4, 2:4] = 1 ds1 = _synthetic_ds(t1, labels1) - stats1 = _cell_stats(t1, [{"id": 1, "area": 4.0, "cx": 2.5, "cy": 2.5, "mean_refl": 40.0, "max_refl": 45.0}]) + stats1 = _cell_stats( + t1, [{"id": 1, "area": 4.0, "cx": 2.5, "cy": 2.5, "mean_refl": 40.0, "max_refl": 45.0}] + ) tracked1, events1 = tracker.track(ds1, stats1) assert len(tracked1) == 1 @@ -99,7 +101,9 @@ def test_birth_and_continue_events(tracker): labels2 = labels1.copy() ds2 = _synthetic_ds(t2, labels2, proj_labels=labels1) - stats2 = _cell_stats(t2, [{"id": 1, "area": 4.0, "cx": 2.5, "cy": 2.5, "mean_refl": 40.0, "max_refl": 45.0}]) + stats2 = _cell_stats( + t2, [{"id": 1, "area": 4.0, "cx": 2.5, "cy": 2.5, "mean_refl": 40.0, "max_refl": 45.0}] + ) tracked2, events2 = tracker.track(ds2, stats2) assert len(tracked2) == 1 @@ -117,7 +121,9 @@ def test_split_event(tracker): labels1 = np.zeros((8, 8), dtype=np.int32) labels1[3:5, 2:6] = 1 ds1 = _synthetic_ds(t1, labels1) - stats1 = _cell_stats(t1, [{"id": 1, "area": 8.0, "cx": 3.5, "cy": 3.5, "mean_refl": 40.0, "max_refl": 45.0}]) + stats1 = _cell_stats( + t1, [{"id": 1, "area": 8.0, "cx": 3.5, "cy": 3.5, "mean_refl": 40.0, "max_refl": 45.0}] + ) tracker.track(ds1, stats1) labels2 = np.zeros((8, 8), dtype=np.int32) @@ -162,7 +168,9 @@ def test_merge_event_emits_death(tracker): proj[4:6, 2:4] = 1 proj[4:6, 6:8] = 2 ds2 = _synthetic_ds(t2, labels2, proj_labels=proj) - stats2 = _cell_stats(t2, [{"id": 1, "area": 8.0, "cx": 4.5, "cy": 4.5, "mean_refl": 45.0, "max_refl": 50.0}]) + stats2 = _cell_stats( + t2, [{"id": 1, "area": 8.0, "cx": 4.5, "cy": 4.5, "mean_refl": 45.0, "max_refl": 50.0}] + ) tracked2, events2 = tracker.track(ds2, stats2) assert len(tracked2) == 1 diff --git a/tests/persistence/test_track_store.py b/tests/persistence/test_track_store.py index a62194e..43b7fcb 100644 --- a/tests/persistence/test_track_store.py +++ b/tests/persistence/test_track_store.py @@ -131,7 +131,9 @@ def _tracked_cells(cell_label: int, cell_uid: str) -> pd.DataFrame: }]) def _empty_cell_adjacency() -> pd.DataFrame: - return pd.DataFrame(columns=["time", "cell_label_a", "cell_label_b", "touching_boundary_pixels"]) + return pd.DataFrame( + columns=["time", "cell_label_a", "cell_label_b", "touching_boundary_pixels"] + ) def _initiation_event(scan_time: datetime, cell_uid: str, cell_label: int) -> pd.DataFrame: ts = scan_time.strftime("%Y-%m-%dT%H:%M:%SZ") @@ -148,7 +150,9 @@ def _initiation_event(scan_time: datetime, cell_uid: str, cell_label: int) -> pd }]) -def _continue_event(scan_time: datetime, cell_uid: str, src_label: int, tgt_label: int) -> pd.DataFrame: +def _continue_event( + scan_time: datetime, cell_uid: str, src_label: int, tgt_label: int +) -> pd.DataFrame: ts = scan_time.strftime("%Y-%m-%dT%H:%M:%SZ") return pd.DataFrame([{ "time": scan_time, @@ -178,7 +182,9 @@ def _termination_event(scan_time: datetime, cell_uid: str, cell_label: int) -> p }]) -def _split_events(scan_time: datetime, parent_id: str, child_id: str, parent_label: int, child_label: int) -> pd.DataFrame: +def _split_events( + scan_time: datetime, parent_id: str, child_id: str, parent_label: int, child_label: int +) -> pd.DataFrame: ts = scan_time.strftime("%Y-%m-%dT%H:%M:%SZ") return pd.DataFrame([{ "time": scan_time, @@ -193,7 +199,9 @@ def _split_events(scan_time: datetime, parent_id: str, child_id: str, parent_lab }]) -def _merge_events(scan_time: datetime, src_id: str, tgt_id: str, src_label: int, tgt_label: int) -> pd.DataFrame: +def _merge_events( + scan_time: datetime, src_id: str, tgt_id: str, src_label: int, tgt_label: int +) -> pd.DataFrame: ts = scan_time.strftime("%Y-%m-%dT%H:%M:%SZ") return pd.DataFrame([{ "time": scan_time, @@ -365,7 +373,7 @@ def test_unique_constraint_rejects_duplicate_cell_label_per_scan(store): _initiation_event(t, "NN", 1), _empty_cell_adjacency()) # Writing same cell_label with different cell_uid should raise on unique(cell_label) - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017 — sqlite3.IntegrityError on unique constraint store.write_scan("r1", t, _cell_stats(1), _tracked_cells(1, "OO"), _initiation_event(t, "OO", 1), _empty_cell_adjacency()) From e1467c0e827d2ad3b22827ab4e90e354090b61a0 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Fri, 1 May 2026 15:42:46 -0500 Subject: [PATCH 04/14] MOD: --version shows correct version and location --- src/adapt/cli.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/adapt/cli.py b/src/adapt/cli.py index ce9bf15..6216bd2 100644 --- a/src/adapt/cli.py +++ b/src/adapt/cli.py @@ -27,6 +27,8 @@ import time from pathlib import Path +from adapt import __version__ + # --------------------------------------------------------------------------- # Single-instance enforcement # --------------------------------------------------------------------------- @@ -356,6 +358,15 @@ def main() -> None: 'of ARM weather radars.' ), ) + + # Add version argument + adapt_module_path = Path(__file__).parent + parser.add_argument( + '--version', + action='version', + version=f'%(prog)s {__version__}\nInstalled at: {adapt_module_path}', + ) + subparsers = parser.add_subparsers(dest='command', metavar='COMMAND') subparsers.required = True From 25a96edc1204f6691396eb76b6e08a8cc301d544 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Fri, 1 May 2026 15:44:10 -0500 Subject: [PATCH 05/14] REF: module contracts moved to adapt.contracts --- src/adapt/modules/analysis/contracts.py | 116 -------- src/adapt/modules/analysis/module.py | 3 +- src/adapt/modules/base.py | 49 +--- src/adapt/modules/detection/contracts.py | 62 ----- src/adapt/modules/detection/module.py | 4 +- src/adapt/modules/ingest/contracts.py | 55 ---- src/adapt/modules/ingest/module.py | 3 +- src/adapt/modules/projection/contracts.py | 68 ----- src/adapt/modules/projection/module.py | 48 +--- src/adapt/modules/tracking/contracts.py | 89 ------ src/adapt/modules/tracking/module.py | 8 +- src/adapt/runtime/processor.py | 325 ++++++++-------------- 12 files changed, 134 insertions(+), 696 deletions(-) delete mode 100644 src/adapt/modules/analysis/contracts.py delete mode 100644 src/adapt/modules/detection/contracts.py delete mode 100644 src/adapt/modules/ingest/contracts.py delete mode 100644 src/adapt/modules/projection/contracts.py delete mode 100644 src/adapt/modules/tracking/contracts.py diff --git a/src/adapt/modules/analysis/contracts.py b/src/adapt/modules/analysis/contracts.py deleted file mode 100644 index 959236b..0000000 --- a/src/adapt/modules/analysis/contracts.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright © 2026, UChicago Argonne, LLC -# See LICENSE for terms and disclaimer. - -"""Output contracts for the analysis module. - -The analysis module produces cell statistics dataframes. This module defines -the contract that validates the analysis output structure. - -Enforces the guarantee that after cell analysis, output contains -required columns and fields are well-formed (no spurious NaNs in required fields). -""" - -import pandas as pd - -from adapt.modules.base import require - - -def assert_analysis_output(df: pd.DataFrame, min_expected_rows: int = 0) -> None: - """Enforce analysis stage contract. - - Called after analyzer.extract(). Verifies that the output DataFrame - has required columns and data is well-formed. - - We do NOT validate the scientific correctness of statistics — that's - the analyzer's responsibility. We only check structural requirements. - - Parameters - ---------- - df : pd.DataFrame - Output from analyzer.extract() - - min_expected_rows : int, optional - Minimum number of rows expected (default 0, allows no-cell frames) - - Raises - ------ - ContractViolation - If structural requirements are violated - """ - require( - isinstance(df, pd.DataFrame), - f"Analysis contract violated: output is {type(df)}, expected DataFrame" - ) - - # Required columns - required_cols = [ - "cell_label", - "cell_area_sqkm", - "time", - "time_volume_start", - "cell_centroid_mass_lat", - "cell_centroid_mass_lon", - "radar_reflectivity_max", - "radar_differential_reflectivity_max", - "area_40dbz_km2", - ] - - for col in required_cols: - require( - col in df.columns, - f"Analysis contract violated: missing required column '{col}'" - ) - - # If there are cells, verify they have valid labels - if len(df) > 0: - require( - (df["cell_label"] > 0).all(), - "Analysis contract violated: cell_label must be > 0 for all rows" - ) - - # Verify minimum rows if specified - require( - len(df) >= min_expected_rows, - f"Analysis contract violated: got {len(df)} cells, expected >= {min_expected_rows}" - ) - - -def assert_cell_adjacency(df: pd.DataFrame) -> None: - """Enforce analysis adjacency contract. - - The analysis module may produce a scan-local adjacency table describing - direct boundary touching between labeled cells. - """ - require( - isinstance(df, pd.DataFrame), - f"Cell adjacency contract violated: output is {type(df)}, expected DataFrame" - ) - - required_cols = [ - "time", - "cell_label_a", - "cell_label_b", - "touching_boundary_pixels", - ] - - for col in required_cols: - require( - col in df.columns, - f"Cell adjacency contract violated: missing required column '{col}'" - ) - - if len(df) == 0: - return - - require( - (df["cell_label_a"] > 0).all() and (df["cell_label_b"] > 0).all(), - "Cell adjacency contract violated: cell labels must be > 0" - ) - require( - (df["cell_label_a"] < df["cell_label_b"]).all(), - "Cell adjacency contract violated: expected canonical ordering cell_label_a < cell_label_b" - ) - require( - (df["touching_boundary_pixels"] >= 1).all(), - "Cell adjacency contract violated: touching_boundary_pixels must be >= 1" - ) diff --git a/src/adapt/modules/analysis/module.py b/src/adapt/modules/analysis/module.py index af525f5..a432476 100644 --- a/src/adapt/modules/analysis/module.py +++ b/src/adapt/modules/analysis/module.py @@ -666,11 +666,10 @@ def get_lat_lon(ix, iy, lat_grid, lon_grid): # BaseModule wrapper — Step 6 # --------------------------------------------------------------------------- +from adapt.contracts import assert_analysis_output, assert_cell_adjacency # noqa: E402 from adapt.execution.module_registry import registry # noqa: E402 from adapt.modules.base import BaseModule # noqa: E402 -from .contracts import assert_analysis_output, assert_cell_adjacency # noqa: E402 - def _check_cell_stats(df): assert_analysis_output(df) diff --git a/src/adapt/modules/base.py b/src/adapt/modules/base.py index 5f64ea1..39752bb 100644 --- a/src/adapt/modules/base.py +++ b/src/adapt/modules/base.py @@ -15,54 +15,7 @@ from abc import ABC, abstractmethod from typing import ClassVar -# ──────────────────────────────────────────────────────────────────────────── -# Contract Enforcement Infrastructure -# ──────────────────────────────────────────────────────────────────────────── - - -class ContractViolation(RuntimeError): - """Raised when a pipeline contract is violated. - - This indicates a bug in pipeline logic, not bad user input or recoverable - science edge cases. It means a pipeline stage did not produce the invariants - it promised. - - Key distinction: - - ValueError: User/config error (handled by Pydantic) - - ContractViolation: Pipeline bug (programmer error) - - Exception: Recoverable science issues (try/except in algorithms) - """ - pass - - -def require(condition: bool, message: str) -> None: - """Enforce a pipeline contract. - - This is called at stage boundaries to verify the preceding stage - produced the guaranteed invariants. It is fail-fast: no recovery, - no fallback, no silence. - - Parameters - ---------- - condition : bool - The invariant that must be true. If False, ContractViolation is raised. - - message : str - Error message explaining the contract violation (for debugging). - - Raises - ------ - ContractViolation - If condition is False. This indicates a bug in pipeline logic. - - Examples - -------- - >>> require("x" in ds.coords, "Grid contract: missing 'x' coordinate") - >>> require(df.shape[0] > 0, "Analysis contract: at least one cell expected") - """ - if not condition: - raise ContractViolation(message) - +from adapt.contracts import ContractViolation, require # noqa: F401 — re-exported for callers # ──────────────────────────────────────────────────────────────────────────── # BaseModule Interface diff --git a/src/adapt/modules/detection/contracts.py b/src/adapt/modules/detection/contracts.py deleted file mode 100644 index 5a87aed..0000000 --- a/src/adapt/modules/detection/contracts.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright © 2026, UChicago Argonne, LLC -# See LICENSE for terms and disclaimer. - -"""Output contracts for the detection module. - -The detection module produces segmented cell labels. This module defines -the contract that validates the segmentation output structure. - -Enforces the guarantee that after segmentation, cell labels are present, -properly typed, and in canonical form (largest cells first). -""" - -import numpy as np -import xarray as xr - -from adapt.modules.base import require - - -def assert_segmented(ds: xr.Dataset, labels_name: str) -> None: - """Enforce segmentation stage contract. - - Called immediately after segmentation. Verifies that the segmenter - produced valid, typed, labeled output. - - Parameters - ---------- - ds : xr.Dataset - Dataset from segmenter.segment() - - labels_name : str - Name of cell labels variable (from config) - - Raises - ------ - ContractViolation - If any invariant is violated - """ - require( - labels_name in ds.data_vars, - f"Segmentation contract violated: '{labels_name}' not found" - ) - - labels = ds[labels_name] - - # Verify type - require( - labels.dtype.kind in {"i", "u"}, - f"Segmentation contract violated: '{labels_name}' dtype is {labels.dtype}, expected integer" - ) - - # Verify range: background=0, cells=1..N - label_vals = labels.values - require( - np.min(label_vals) >= 0, - f"Segmentation contract violated: labels contain negative values (min={np.min(label_vals)})" - ) - - # Verify 2D shape (consistent with grid) - require( - labels.ndim == 2, - f"Segmentation contract violated: '{labels_name}' has {labels.ndim} dims, expected 2" - ) diff --git a/src/adapt/modules/detection/module.py b/src/adapt/modules/detection/module.py index ea75bff..c3d1dcb 100644 --- a/src/adapt/modules/detection/module.py +++ b/src/adapt/modules/detection/module.py @@ -370,11 +370,9 @@ def _relabel_by_size( # BaseModule wrapper — Step 6 # --------------------------------------------------------------------------- +from adapt.contracts import assert_gridded, assert_segmented # noqa: E402 from adapt.execution.module_registry import registry # noqa: E402 from adapt.modules.base import BaseModule # noqa: E402 -from adapt.modules.ingest.contracts import assert_gridded # noqa: E402 - -from .contracts import assert_segmented # noqa: E402 def _check_grid_ds_2d(ds): diff --git a/src/adapt/modules/ingest/contracts.py b/src/adapt/modules/ingest/contracts.py deleted file mode 100644 index 666b9ec..0000000 --- a/src/adapt/modules/ingest/contracts.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright © 2026, UChicago Argonne, LLC -# See LICENSE for terms and disclaimer. - -"""Output contracts for the ingest module. - -The ingest module produces gridded radar data. This module defines the -contract that validates the gridded dataset structure. - -Enforces the guarantee that after regridding, the dataset is valid -for downstream scientific processing. -""" - -import xarray as xr - -from adapt.modules.base import require - - -def assert_gridded(ds: xr.Dataset, reflectivity_var: str) -> None: - """Enforce grid stage contract. - - Called immediately after regridding. Verifies that the loader/regridder - produced a valid Cartesian grid. - - Parameters - ---------- - ds : xr.Dataset - Dataset from loader.load_and_regrid() - - reflectivity_var : str - Name of reflectivity variable (from config) - - Raises - ------ - ContractViolation - If any invariant is violated - """ - require( - "x" in ds.coords, - "Grid contract violated: missing 'x' coordinate" - ) - require( - "y" in ds.coords, - "Grid contract violated: missing 'y' coordinate" - ) - require( - reflectivity_var in ds.data_vars, - f"Grid contract violated: missing '{reflectivity_var}' variable" - ) - - # Verify 2D structure (should be sliced at z-level already) - refl = ds[reflectivity_var] - require( - refl.ndim == 2, - f"Grid contract violated: '{reflectivity_var}' has {refl.ndim} dims, expected 2" - ) diff --git a/src/adapt/modules/ingest/module.py b/src/adapt/modules/ingest/module.py index 51b4c3e..7d81160 100644 --- a/src/adapt/modules/ingest/module.py +++ b/src/adapt/modules/ingest/module.py @@ -356,11 +356,10 @@ def load_and_regrid(self, filepath: Path | str, grid_kwargs: dict = None, import xarray as _xr # noqa: E402 from adapt.configuration.schemas.directories import get_netcdf_path # noqa: E402 +from adapt.contracts import assert_gridded # noqa: E402 from adapt.execution.module_registry import registry # noqa: E402 from adapt.modules.base import BaseModule # noqa: E402 -from .contracts import assert_gridded # noqa: E402 - def _check_grid_ds_2d(ds): assert_gridded(ds, "reflectivity") diff --git a/src/adapt/modules/projection/contracts.py b/src/adapt/modules/projection/contracts.py deleted file mode 100644 index 0c5f57f..0000000 --- a/src/adapt/modules/projection/contracts.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright © 2026, UChicago Argonne, LLC -# See LICENSE for terms and disclaimer. - -"""Output contracts for the projection module. - -The projection module produces cell projections with motion vectors. This -module defines the contract that validates the projection output structure. - -Enforces the guarantee that when projections are computed (2+ frames), -the flow fields and projection arrays are present and well-formed. -""" - -import xarray as xr - -from adapt.modules.base import require - - -def assert_projected(ds: xr.Dataset, max_steps: int = 5) -> None: - """Enforce projection stage contract. - - Called after projection computation (when 2+ frames available). - Verifies that optical flow and projected labels are present and that - projection count matches runtime config (read from dataset attributes). - - Parameters - ---------- - ds : xr.Dataset - Dataset from projector.project() - - max_steps : int, optional - Maximum number of projection steps (default 5). If dataset has - 'max_projection_steps' in attrs, that value is used instead. - This enables config-aware validation without breaking validator isolation. - - Raises - ------ - ContractViolation - If any invariant is violated - """ - require( - "heading_x" in ds.data_vars, - "Projection contract violated: missing 'heading_x' " - ) - require( - "heading_y" in ds.data_vars, - "Projection contract violated: missing 'heading_y' " - ) - - # If projections are included, verify their structure - if "cell_projections" in ds.data_vars: - projections = ds["cell_projections"] - require( - projections.ndim == 3, - f"Projection contract violated: 'cell_projections' has {projections.ndim} dims, " - "expected 3 (step, y, x)", - ) - - # Use stored config value if available (self-describing data pattern) - # This allows validators to access runtime config without context coupling - max_steps_actual = ds.attrs.get("max_projection_steps", max_steps) - - num_steps = projections.shape[0] - expected_steps = max_steps_actual + 1 # 1 registration + N future - require( - num_steps == expected_steps, - f"Projection contract violated: found {num_steps} steps, expected {expected_steps} " - f"(1 registration + {max_steps_actual} projections from config)" - ) diff --git a/src/adapt/modules/projection/module.py b/src/adapt/modules/projection/module.py index 1242a28..b0d1f87 100644 --- a/src/adapt/modules/projection/module.py +++ b/src/adapt/modules/projection/module.py @@ -569,9 +569,9 @@ def _fill_concave_hull(self, label_mask, alpha=0.1): # BaseModule wrapper — Step 6 # --------------------------------------------------------------------------- +from adapt.contracts import assert_segmented # noqa: E402 from adapt.execution.module_registry import registry # noqa: E402 from adapt.modules.base import BaseModule # noqa: E402 -from adapt.modules.detection.contracts import assert_segmented # noqa: E402 def _check_segmented_ds(ds): @@ -582,67 +582,47 @@ class ProjectionModule(BaseModule): """BaseModule wrapper for RadarCellProjector. Computes optical flow between consecutive radar frames and projects - cell positions forward in time. Maintains a rolling frame history - as instance state so it persists across files. + cell positions forward in time. Stateless: receives the frame pair + via the context key ``dataset_history`` (injected by the processor). Context inputs -------------- segmented_ds : xr.Dataset - 2D segmented dataset (output of DetectModule). - nexrad_file : str - Current file path (used as history key). + 2D segmented dataset for the current frame (output of DetectModule). + dataset_history : list of (str, xr.Dataset) + Rolling history of (filepath, segmented_ds) tuples supplied by the + processor. Must contain exactly 2 entries before this module is called. config : InternalConfig Runtime configuration. Context outputs --------------- projected_ds : xr.Dataset - 2D dataset with projection fields added (or original if <2 frames). + 2D dataset with heading_x, heading_y, and cell_projections added. """ name = "projection" - inputs = ["segmented_ds", "nexrad_file", "config"] + inputs = ["segmented_ds", "dataset_history", "config"] outputs = ["projected_ds"] input_contracts = {"segmented_ds": _check_segmented_ds} - # Output: projected_ds only present when 2+ frames available and projection succeeds - # Raises exception if time gap too large or computation fails - # Returns context unchanged (no projected_ds) if insufficient history (< 2 frames) def __init__(self) -> None: self._projector = None - self._dataset_history = [] # list of (filepath, ds_2d) tuples def run(self, context: dict) -> dict: config = context["config"] - ds_2d = context["segmented_ds"] - filepath = context["nexrad_file"] + dataset_history = context["dataset_history"] # list of (filepath, ds_2d) if self._projector is None: self._projector = RadarCellProjector(config) - # Note: Processor orchestration injects history directly into - # self._dataset_history before calling pipeline. For standalone - # testing, we still support building history internally. - max_history = config.processor.max_history - - # If history is empty or doesn't contain current file, build internally - if not self._dataset_history or self._dataset_history[-1][0] != filepath: - # Build history internally (standalone mode) - self._dataset_history.append((filepath, ds_2d)) - if len(self._dataset_history) > max_history: - self._dataset_history.pop(0) - logger.debug("Building history internally (%d frames)", len(self._dataset_history)) - - # Must have 2 frames (guaranteed by processor orchestration, but double-check) - if len(self._dataset_history) < 2: + if len(dataset_history) < 2: raise ValueError( - f"ProjectionModule requires 2 frames, but only " - f"{len(self._dataset_history)} available. " - "Processor should orchestrate frame pairing before calling projection." + f"ProjectionModule requires 2 frames in dataset_history, " + f"got {len(dataset_history)}. Processor must pair frames before calling." ) - # Compute optical flow - ds_list = [ds for _, ds in self._dataset_history] + ds_list = [ds for _, ds in dataset_history] projected = self._projector.project(ds_list) return {"projected_ds": projected} diff --git a/src/adapt/modules/tracking/contracts.py b/src/adapt/modules/tracking/contracts.py deleted file mode 100644 index 5aba26e..0000000 --- a/src/adapt/modules/tracking/contracts.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright © 2026, UChicago Argonne, LLC -# See LICENSE for terms and disclaimer. - -"""Contracts for the tracking module outputs. - -Tracking emits scan-local outputs for persistence and downstream consumers: -- tracked_cells: one row per cell observation assigned to a track in the current scan -- cell_events: one row per lineage/event edge (continue/split/merge/initiation/termination) - -A "track" is a single connected chain of cell observations across scans identified by a -stable cell_uid. -""" - -from __future__ import annotations - -import pandas as pd - -from adapt.modules.base import require - - -def assert_tracked_cells(df: pd.DataFrame) -> None: - require( - isinstance(df, pd.DataFrame), - f"Tracked cells contract violated: output is {type(df)}, expected DataFrame", - ) - - required_cols = [ - "time", - "cell_label", - "cell_uid", - "area", - "centroid_x", - "centroid_y", - "mean_reflectivity", - "max_reflectivity", - "core_area", - ] - - for col in required_cols: - require( - col in df.columns, - f"Tracked cells contract violated: missing required column '{col}'", - ) - - if len(df) == 0: - return - - require( - (df["cell_label"] > 0).all(), - "Tracked cells contract violated: cell_label must be > 0 for all rows", - ) - require( - "cell_uid" in df.columns and df["cell_uid"].notna().all(), - "Tracked cells contract violated: cell_uid must be non-null for all rows", - ) - - -def assert_cell_events(df: pd.DataFrame) -> None: - require( - isinstance(df, pd.DataFrame), - f"Cell events contract violated: output is {type(df)}, expected DataFrame", - ) - - required_cols = [ - "time", - "event_type", - "source_cell_uid", - "target_cell_uid", - "source_cell_label", - "target_cell_label", - "cost", - "is_dominant", - "event_group_id", - ] - - for col in required_cols: - require( - col in df.columns, - f"Cell events contract violated: missing required column '{col}'", - ) - - if len(df) == 0: - return - - valid = {"CONTINUE", "SPLIT", "MERGE", "INITIATION", "TERMINATION"} - require( - df["event_type"].isin(valid).all(), - f"Cell events contract violated: invalid event_type present (valid={sorted(valid)})", - ) diff --git a/src/adapt/modules/tracking/module.py b/src/adapt/modules/tracking/module.py index 060c30b..a1dcf40 100644 --- a/src/adapt/modules/tracking/module.py +++ b/src/adapt/modules/tracking/module.py @@ -930,11 +930,13 @@ def _event_termination(self, time, source_node_id: int, target_node_id: int | No # BaseModule wrapper (Phase 6 implementation placeholder) # ============================================================================= +from adapt.contracts import ( # noqa: E402 + assert_cell_events, + assert_projected, + assert_tracked_cells, +) from adapt.execution.module_registry import registry # noqa: E402 from adapt.modules.base import BaseModule # noqa: E402 -from adapt.modules.projection.contracts import assert_projected # noqa: E402 - -from .contracts import assert_cell_events, assert_tracked_cells # noqa: E402 def _check_projected_ds(ds: xr.Dataset) -> None: diff --git a/src/adapt/runtime/processor.py b/src/adapt/runtime/processor.py index 709ba25..9ed6394 100644 --- a/src/adapt/runtime/processor.py +++ b/src/adapt/runtime/processor.py @@ -4,13 +4,18 @@ """Radar data processor thread. Reads NEXRAD file paths from the downloader queue and delegates all -scientific processing to NexradPipeline (the graph-based execution engine). -After each file the segmentation NetCDF is saved to the repository. +scientific processing to two GraphExecutors built at startup: + +- ``_single_executor``: ingest + detection (runs every file) +- ``_multi_executor``: projection + analysis + tracking (runs when 2-frame + pair is ready) Responsibilities of this class (orchestration only): - Queue management: pop filepath, mark task done - File deduplication via FileProcessingTracker -- NetCDF persistence after graph run +- Frame pairing: accumulate segmented history, validate time gap +- Context assembly: inject dataset_history before calling multi-executor +- NetCDF + Parquet persistence after graph run - Stop/start lifecycle """ @@ -24,7 +29,11 @@ import pandas as pd -from adapt.modules.base import ContractViolation +from adapt.contracts import ContractViolation +from adapt.execution.graph.builder import GraphBuilder +from adapt.execution.graph.executor import GraphExecutor +from adapt.execution.module_registry import registry +from adapt.execution.pipeline_builder import _ensure_modules_registered from adapt.persistence import DataRepository, ProductType from adapt.persistence.track_store import TrackStore from adapt.persistence.writer import RepositoryWriter @@ -38,16 +47,18 @@ class RadarProcessor(threading.Thread): - """Worker thread that processes NEXRAD files through the execution graph. + """Worker thread that processes NEXRAD files through two execution graphs. + + Receives file paths from the downloader queue. For each file: - Receives file paths from the downloader queue and runs them through - ``NexradPipeline``, which executes the module DAG (ingest → detection → - projection → analysis → tracking). Scientific module instances inside the pipeline - persist across files so stateful modules (e.g. ProjectionModule frame - history) work correctly. + 1. Runs the single-frame graph (ingest → detection) via ``_single_executor``. + 2. Accumulates segmented datasets in a rolling history. + 3. When a valid 2-frame pair is ready, runs the multi-frame graph + (projection → analysis → tracking) via ``_multi_executor``, + passing the frame history in context. - After the graph runs, this class saves the projected/segmented dataset - to a NetCDF artifact in the repository for downstream consumers. + Both executors enforce input/output contracts at every DAG edge via + ``GraphExecutor``. The processor itself performs no validation. Example usage (called by PipelineOrchestrator):: @@ -88,18 +99,28 @@ def __init__( "Initialize it in the orchestrator before creating the processor." ) - # Build the graph-based pipeline once — module instances (and their - # frame history) persist across process_file() calls. - from adapt.execution.pipeline_builder import NexradPipeline - self._pipeline = NexradPipeline(config, dict(output_dirs)) + # Build two execution graphs; module instances are shared (stateful + # projector/tracker state persists across files via the module objects). + _ensure_modules_registered() + modules = registry.create_modules() + + single_modules = [m for m in modules if m.name in {"ingest", "detection"}] + multi_modules = [m for m in modules if m.name in {"projection", "analysis", "tracking"}] + + self._single_executor = GraphExecutor(GraphBuilder(single_modules).build()) + self._multi_executor = GraphExecutor(GraphBuilder(multi_modules).build()) + + logger.info( + "RadarProcessor graphs: single=[%s] multi=[%s]", + ", ".join(m.name for m in single_modules), + ", ".join(m.name for m in multi_modules), + ) # Frame pairing orchestration state - # We maintain a rolling list of segmented datasets and only call - # projection/analysis/tracking when we have 2 valid frames - self._segmented_history = [] # List of (filepath, ds_2d, scan_time) tuples - self._max_history = config.processor.max_history # Should be 2 + self._segmented_history = [] # list of (filepath, ds_2d, scan_time) + self._max_history = config.processor.max_history self._max_time_gap_minutes = config.projector.max_time_interval_minutes - self._last_skipped = False # Set True when process_file skips an analyzed file + self._last_skipped = False # ── Lifecycle ───────────────────────────────────────────────────────────── @@ -145,26 +166,22 @@ def run(self): # ── Per-file processing ─────────────────────────────────────────────────── def process_file(self, filepath) -> bool: - """Process NEXRAD file with frame pairing orchestration. + """Process a NEXRAD file with frame-pairing orchestration. - New architecture: - - File 1: Load → Detect → Wait (build history, no projection yet) - - File 2: Load → Detect → Check pair → Projection → Analysis → Tracking + Phase 1 — ingest + detection (every file): + Runs single-frame executor. Contract-validated by GraphExecutor. - Only calls projection/analysis/tracking when we have 2 segmented - datasets with an acceptable time gap. This prevents crashes when - modules expect projected_ds but only 1 file has been processed. + Phase 2 — frame pairing: + Accumulates segmented datasets. Waits until 2 frames are ready. - Parameters - ---------- - filepath : str or dict - Path to the NEXRAD Level-II file. Dict format ``{"path": ...}`` - is accepted for backwards compatibility with the downloader queue. + Phase 3 — projection + analysis + tracking (when pair is ready): + Injects dataset_history into context. Runs multi-frame executor. + Contract-validated by GraphExecutor. Returns ------- bool - True if the file was processed (or ready to pair), False on error. + True if processed or deferred (waiting for pair), False on error. """ queued_at = None if isinstance(filepath, dict): @@ -183,86 +200,88 @@ def process_file(self, filepath) -> bool: logger.info("Processing: %s", Path(filepath).name) try: - # ────────────────────────────────────────────────────────────────── - # PHASE 1: Load + Detect (always run, even for first file) - # ────────────────────────────────────────────────────────────────── - context_initial = { + # ── Phase 1: ingest + detection ──────────────────────────────── + t0 = time.perf_counter() + base_ctx = { "nexrad_file": filepath, "config": self.config, "output_dirs": self.output_dirs, } if self.repository: - context_initial["repository"] = self.repository + base_ctx["repository"] = self.repository - ds_2d, scan_time, ingest_s, detect_s = self._run_ingest_detection_only(context_initial) + frame_ctx = self._single_executor.run(base_ctx) + single_s = time.perf_counter() - t0 - # ────────────────────────────────────────────────────────────────── - # PHASE 2: Add to rolling history - # ────────────────────────────────────────────────────────────────── - self._segmented_history.append((filepath, ds_2d, scan_time)) + # Register radar location from first scan (idempotent after that) + if self.repository: + grid_ds = frame_ctx.get("grid_ds") or frame_ctx.get("grid_ds_2d") + if grid_ds is not None: + lat = grid_ds.attrs.get("radar_latitude") + lon = grid_ds.attrs.get("radar_longitude") + if lat is not None and lon is not None: + self.repository.registry.ensure_radar_location( + self.config.downloader.radar, lat=float(lat), lon=float(lon) + ) + + scan_time = frame_ctx.get("scan_time") + + # ── Phase 2: accumulate frame history ────────────────────────── + self._segmented_history.append((filepath, frame_ctx["segmented_ds"], scan_time)) if len(self._segmented_history) > self._max_history: self._segmented_history.pop(0) - # ────────────────────────────────────────────────────────────────── - # PHASE 3: Check if ready for full processing - # ────────────────────────────────────────────────────────────────── if len(self._segmented_history) < 2: logger.info( - "Segmented %s, waiting for pair | ingest=%.1fs detect=%.1fs", - Path(filepath).name, ingest_s, detect_s, + "Segmented %s, waiting for pair | %.1fs", + Path(filepath).name, single_s, ) - return True # Success, but waiting for second file + return True - # ────────────────────────────────────────────────────────────────── - # PHASE 4: Validate time gap between frames - # ────────────────────────────────────────────────────────────────── + # ── Phase 3: validate time gap ───────────────────────────────── time_gap_valid, time_gap_minutes = self._validate_time_gap() if not time_gap_valid: logger.warning( - "Time gap %.1f min > %.1f min, discarding oldest frame. " - "Waiting for next file with smaller gap.", - time_gap_minutes, - self._max_time_gap_minutes + "Time gap %.1f min > %.1f min, discarding oldest frame.", + time_gap_minutes, self._max_time_gap_minutes, ) - # Keep newest frame in history, wait for next file - return True # Not an error, just waiting for better pair + return True - # ────────────────────────────────────────────────────────────────── - # PHASE 5: Run full pipeline (projection → analysis → tracking) - # ────────────────────────────────────────────────────────────────── logger.info( "Processing pair: %s + %s (gap: %.1f min)", Path(self._segmented_history[0][0]).name, Path(self._segmented_history[1][0]).name, - time_gap_minutes + time_gap_minutes, ) + # ── Phase 4: projection + analysis + tracking ────────────────── t_proj = time.perf_counter() - result = self._run_full_pipeline(context_initial) + pair_ctx = { + **frame_ctx, + "config": self.config, + "output_dirs": self.output_dirs, + "dataset_history": [(fp, ds) for fp, ds, _ in self._segmented_history], + } + if self.repository: + pair_ctx["repository"] = self.repository + + result = self._multi_executor.run(pair_ctx) project_s = time.perf_counter() - t_proj - # ────────────────────────────────────────────────────────────────── - # PHASE 6: Save outputs to repository - # ────────────────────────────────────────────────────────────────── + # ── Phase 5: persist results ─────────────────────────────────── if self.repository and result: self._save_results(result, scan_time) - # Log cell count + timing cell_stats = result.get("cell_stats") n_cells = len(cell_stats) if cell_stats is not None else 0 logger.info( - "Processed pair: %d cells | ingest=%.1fs detect=%.1fs project=%.1fs%s", - n_cells, ingest_s, detect_s, project_s, + "Processed pair: %d cells | %.1fs proj%s", + n_cells, project_s, f" queue={queue_wait_s:.1f}s" if queue_wait_s is not None else "", ) - # Mark both files as processed if tracker: - timings = { - "ingest_seconds": ingest_s, - "detect_seconds": detect_s, - "project_seconds": project_s, - } + timings = {"project_seconds": project_s} if queue_wait_s is not None: timings["queue_wait_seconds"] = queue_wait_s for fp, _, _ in self._segmented_history: @@ -272,14 +291,10 @@ def process_file(self, filepath) -> bool: return True except ContractViolation as e: - logger.critical( - "CRITICAL: Pipeline contract violated: %s. Stopping pipeline.", e - ) + logger.critical("CRITICAL: Pipeline contract violated: %s. Stopping pipeline.", e) self.stop() if tracker: - tracker.mark_stage_complete( - file_id, "analyzed", error=f"ContractViolation: {e}" - ) + tracker.mark_stage_complete(file_id, "analyzed", error=f"ContractViolation: {e}") return False except Exception as e: @@ -288,7 +303,18 @@ def process_file(self, filepath) -> bool: tracker.mark_stage_complete(file_id, "analyzed", error=str(e)) return False - # ── NetCDF persistence ──────────────────────────────────────────────────── + # ── Frame pairing helpers ───────────────────────────────────────────────── + + def _validate_time_gap(self): + """Return (valid, gap_minutes) for the two frames in history.""" + if len(self._segmented_history) < 2: + return False, 0.0 + time1 = self._segmented_history[0][2] + time2 = self._segmented_history[1][2] + gap_minutes = (time2 - time1).total_seconds() / 60.0 + return abs(gap_minutes) <= self._max_time_gap_minutes, gap_minutes + + # ── Persistence helpers ─────────────────────────────────────────────────── def _save_analysis_netcdf(self, ds, filepath: str, scan_time) -> str | None: """Write the analysis dataset to a NetCDF artifact in the repository.""" @@ -321,144 +347,16 @@ def _save_analysis_netcdf(self, ds, filepath: str, scan_time) -> str | None: logger.warning("Could not save analysis NetCDF: %s", e) return None - # ── Frame Pairing Orchestration Helpers ─────────────────────────────────── - - def _run_ingest_detection_only(self, context: dict): - """Run ONLY ingest + detection modules (segmentation). - - This runs the first part of the pipeline (ingest and detection) without - calling projection/analysis/tracking. Used to build up the rolling - history of segmented datasets. - - Returns - ------- - ds_2d : xr.Dataset - Segmented 2D dataset with cell_labels - scan_time : datetime - Scan timestamp - ingest_seconds : float - Wall time for the ingest (regridding) step - detect_seconds : float - Wall time for the detection (segmentation) step - """ - # Import modules directly (not through pipeline graph) - from adapt.modules.detection.module import DetectModule - from adapt.modules.ingest.module import LoadModule - - # Instantiate if not cached (persist across calls) - if not hasattr(self, '_ingest_module'): - self._ingest_module = LoadModule() - self._detection_module = DetectModule() - - # Run ingest module - t0 = time.perf_counter() - result = self._ingest_module.run(context) - context.update(result) - ingest_s = time.perf_counter() - t0 - - # Persist radar location from actual data on first file (idempotent after that). - if self.repository: - grid_ds = context.get("grid_ds") or context.get("grid_ds_2d") - if grid_ds is not None: - lat = grid_ds.attrs.get("radar_latitude") - lon = grid_ds.attrs.get("radar_longitude") - if lat is not None and lon is not None: - self.repository.registry.ensure_radar_location( - self.config.downloader.radar, lat=float(lat), lon=float(lon) - ) - - # Run detection module - t1 = time.perf_counter() - result = self._detection_module.run(context) - context.update(result) - detect_s = time.perf_counter() - t1 - - # Extract outputs - ds_2d = context.get("segmented_ds") - scan_time = context.get("scan_time") - - return ds_2d, scan_time, ingest_s, detect_s - - def _validate_time_gap(self): - """Check if time gap between frames is acceptable for optical flow. - - Returns - ------- - valid : bool - True if time gap is within max_time_interval_minutes - gap_minutes : float - Actual time gap in minutes - """ - if len(self._segmented_history) < 2: - return False, 0.0 - - # Extract scan times from history tuples - time1 = self._segmented_history[0][2] # (filepath, ds_2d, scan_time) - time2 = self._segmented_history[1][2] - - # Compute gap in minutes - gap_minutes = (time2 - time1).total_seconds() / 60.0 - valid = abs(gap_minutes) <= self._max_time_gap_minutes - - return valid, gap_minutes - - def _run_full_pipeline(self, context: dict): - """Run projection → analysis → tracking on validated frame pair. - - Reuses the context already populated by _run_ingest_detection_only - (which contains grid_ds, segmented_ds, scan_time, etc.) to avoid - re-running the expensive ingest step. - - Returns - ------- - dict - Updated context with projected_ds, cell_stats, pathed_cells, etc. - """ - # Inject segmented history into ProjectionModule - projection_module = self._get_projection_module() - if projection_module: - projection_module._dataset_history = [ - (fp, ds) for fp, ds, _ in self._segmented_history - ] - - # Run only the stages after ingest+detection — they are already in context - _skip = {"ingest", "detection"} - for node in self._pipeline._nodes: - if node.name not in _skip: - result = node.module.run(context) - context.update(result) - - return context - - def _get_projection_module(self): - """Get ProjectionModule instance from pipeline graph.""" - try: - for node in self._pipeline._nodes: - if node.name == "projection": - return node.module - except Exception: - logger.warning("Could not find ProjectionModule in pipeline") - return None - def _save_results(self, result: dict, scan_time): - """Save all pipeline outputs to the repository. - - Saves: - - projected_ds as NetCDF artifact - - cell_stats, cell_adjacency as Parquet artifacts - - tracked_cells, cell_events as SQLite via TrackStore - (label→uid adjacency mapping in TrackStore) - """ + """Save all pipeline outputs to the repository.""" if scan_time is not None and scan_time.tzinfo is None: scan_time = scan_time.replace(tzinfo=UTC) - # NetCDF: segmentation + projections + flow vectors projected_ds = result.get("projected_ds") if projected_ds is not None: - filepath = self._segmented_history[-1][0] # Most recent file + filepath = self._segmented_history[-1][0] self._save_analysis_netcdf(projected_ds, filepath, scan_time) - # Parquet: analysis outputs writer = RepositoryWriter(self.repository) cell_stats = result.get("cell_stats") @@ -471,7 +369,6 @@ def _save_results(self, result: dict, scan_time): if cell_adjacency is not None and not cell_adjacency.empty: writer.write_analysis(df=cell_adjacency, scan_time=scan_time, producer="cell_adjacency") - # SQLite: track identity outputs if tracked_cells is not None and not tracked_cells.empty: if cell_stats is None: raise ValueError("Missing required cell_stats for TrackStore persistence") From da47ded86e2ff8c1da3b1f5458599571e8760d2d Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Fri, 1 May 2026 17:38:21 -0500 Subject: [PATCH 06/14] REF:(import-linting) config calls Replaced a large config with config for each module. Injected these configs into implementations and removed redundant storage. Updated BaseModule wrappers and test fixtures to use the new schemas. --- src/adapt/configuration/schemas/__init__.py | 4 +- src/adapt/execution/graph/executor.py | 10 +++-- src/adapt/gui/dashboard.py | 12 +++--- src/adapt/modules/acquisition/module.py | 8 +--- src/adapt/modules/analysis/module.py | 29 ++++++------- src/adapt/modules/detection/module.py | 36 +++++++--------- src/adapt/modules/ingest/module.py | 38 ++++++++--------- src/adapt/modules/projection/module.py | 39 ++++++++---------- src/adapt/modules/tracking/module.py | 41 +++++++++---------- src/adapt/runtime/processor.py | 10 ++++- .../test_tracker_scan_local_outputs.py | 6 ++- 11 files changed, 108 insertions(+), 125 deletions(-) diff --git a/src/adapt/configuration/schemas/__init__.py b/src/adapt/configuration/schemas/__init__.py index d08733e..74dabae 100644 --- a/src/adapt/configuration/schemas/__init__.py +++ b/src/adapt/configuration/schemas/__init__.py @@ -14,6 +14,6 @@ """ from adapt.configuration.schemas.initialization import init_runtime_config +from adapt.configuration.schemas.materialization import materialize_module_configs -# Single public function - everything else is internal implementation -__all__ = ['init_runtime_config'] +__all__ = ['init_runtime_config', 'materialize_module_configs'] diff --git a/src/adapt/execution/graph/executor.py b/src/adapt/execution/graph/executor.py index 1e48608..6bde5a1 100644 --- a/src/adapt/execution/graph/executor.py +++ b/src/adapt/execution/graph/executor.py @@ -17,6 +17,7 @@ import logging +from adapt.contracts.pipeline import require from adapt.execution.graph.node import Node logger = logging.getLogger(__name__) @@ -85,10 +86,13 @@ def run(self, context: dict) -> dict: if not ready: continue - # Validate inputs declared by the module + # Validate inputs declared by the module — fail immediately if absent for key, validator in (node.module.input_contracts or {}).items(): - if key in context: - validator(context[key]) + require( + key in context, + f"Required input '{key}' missing for module '{node.name}'", + ) + validator(context[key]) outputs = node.module.run(context) diff --git a/src/adapt/gui/dashboard.py b/src/adapt/gui/dashboard.py index 65731fe..b40c015 100644 --- a/src/adapt/gui/dashboard.py +++ b/src/adapt/gui/dashboard.py @@ -1284,11 +1284,11 @@ def _draw_scan(self, ds, fig, ax=None): fraction=0.046, pad=0.04) # ── Cell contours ───────────────────────────────────────────────────── - # for cell_id in np.unique(labels_data[labels_data > 0]): - # cs = ax.contour(x_grid, y_grid, - # (labels_data == cell_id).astype(float), - # levels=[0.5], colors='#2C3539', linewidths=0.5, zorder=50) - # self._cell_contours[int(cell_id)] = cs + for cell_id in np.unique(labels_data[labels_data > 0]): + cs = ax.contour(x_grid, y_grid, + (labels_data == cell_id).astype(float), + levels=[0.8], colors='#2C3539', linewidths=0.5, zorder=50) + self._cell_contours[int(cell_id)] = cs # ── Projection contours ─────────────────────────────────────────────── if 'cell_projections' in ds.data_vars: @@ -1301,7 +1301,7 @@ def _draw_scan(self, ds, fig, ax=None): _ls_cycle = ['dashed', 'dashdot', 'dotted'] for i in range(1, end_frame): alpha = max(0.5, 1.0 - i / n_frames) - lw = max(0.5, 1.5 - i * 0.2) + lw = max(0.7, 1.6 - i * 0.2) ls = _ls_cycle[(i - 1) % len(_ls_cycle)] lp = proj_da.isel({fo: i}).values for cid in np.unique(lp[~np.isnan(lp) & (lp > 0)]): diff --git a/src/adapt/modules/acquisition/module.py b/src/adapt/modules/acquisition/module.py index 6c8617e..e4f8cf1 100644 --- a/src/adapt/modules/acquisition/module.py +++ b/src/adapt/modules/acquisition/module.py @@ -12,14 +12,10 @@ import time from datetime import UTC, datetime, timedelta from pathlib import Path -from typing import TYPE_CHECKING from nexradaws import NexradAwsInterface -from adapt.configuration.schemas.directories import get_nexrad_path - -if TYPE_CHECKING: - from adapt.configuration.schemas import InternalConfig +from adapt.utils.paths import get_nexrad_path __all__ = ['AwsNexradDownloader'] @@ -73,7 +69,7 @@ class AwsNexradDownloader(threading.Thread): def __init__( self, - config: "InternalConfig", + config, output_dir: Path = None, output_dirs: dict = None, result_queue=None, diff --git a/src/adapt/modules/analysis/module.py b/src/adapt/modules/analysis/module.py index a432476..55a51f6 100644 --- a/src/adapt/modules/analysis/module.py +++ b/src/adapt/modules/analysis/module.py @@ -26,7 +26,6 @@ import json import logging from datetime import UTC -from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -34,9 +33,6 @@ from scipy.ndimage import center_of_mass from skimage.measure import regionprops -if TYPE_CHECKING: - from adapt.configuration.schemas import InternalConfig - __all__ = ['RadarCellAnalyzer'] logger = logging.getLogger(__name__) @@ -120,7 +116,7 @@ class RadarCellAnalyzer: >>> print(len(df)) # number of cells in this frame """ - def __init__(self, config: "InternalConfig"): + def __init__(self, config): """Initialize analyzer with validated configuration. Parameters @@ -140,12 +136,12 @@ def __init__(self, config: "InternalConfig"): >>> config = resolve_config(ParamConfig()) >>> analyzer = RadarCellAnalyzer(config) """ - self.config = config - self.reflectivity_field = config.global_.var_names.reflectivity - self.radar_variables = config.analyzer.radar_variables - self.exclude_fields = config.analyzer.exclude_fields - self.max_projection_steps = config.projector.max_projection_steps - self._adjacency_min_touching = config.analyzer.adjacency_min_touching_boundary_pixels + self.reflectivity_field = config.reflectivity_var + self.labels_field = config.labels_var + self.radar_variables = config.radar_variables + self.exclude_fields = config.exclude_fields + self.max_projection_steps = config.max_projection_steps + self._adjacency_min_touching = config.adjacency_min_touching def extract(self, ds: xr.Dataset, z_level: int = None) -> pd.DataFrame: """Extract geometric and statistical properties from all labeled cells. @@ -221,7 +217,7 @@ def extract(self, ds: xr.Dataset, z_level: int = None) -> pd.DataFrame: >>> df.to_sql('cells', conn, if_exists='append') # Database storage """ # Get labels variable name from config - labels_name = self.config.global_.var_names.cell_labels + labels_name = self.labels_field # Extract reflectivity (already 2D) refl = ds[self.reflectivity_field].values @@ -269,7 +265,7 @@ def extract_adjacency(self, ds: xr.Dataset) -> pd.DataFrame: along a shared boundary with at least N touching boundary pixel-edges, where N is config-driven (`analyzer.adjacency_min_touching_boundary_pixels`). """ - labels_name = self.config.global_.var_names.cell_labels + labels_name = self.labels_field if labels_name not in ds.data_vars: raise ValueError( f"Missing required labels variable '{labels_name}' for adjacency extraction" @@ -703,7 +699,7 @@ class AnalysisModule(BaseModule): """ name = "analysis" - inputs = ["projected_ds", "config", "scan_time"] + inputs = ["projected_ds", "analysis_config", "scan_time"] outputs = ["cell_stats", "cell_adjacency"] output_contracts = {"cell_stats": _check_cell_stats, "cell_adjacency": _check_cell_adjacency} @@ -711,14 +707,13 @@ def __init__(self) -> None: self._analyzer = None def run(self, context: dict) -> dict: - config = context["config"] + config = context["analysis_config"] ds_2d = context["projected_ds"] if self._analyzer is None: self._analyzer = RadarCellAnalyzer(config) - z_level = config.global_.z_level - df_cells = self._analyzer.extract(ds_2d, z_level=z_level) + df_cells = self._analyzer.extract(ds_2d, z_level=config.z_level) df_adjacency = self._analyzer.extract_adjacency(ds_2d) return {"cell_stats": df_cells, "cell_adjacency": df_adjacency} diff --git a/src/adapt/modules/detection/module.py b/src/adapt/modules/detection/module.py index c3d1dcb..c89fc14 100644 --- a/src/adapt/modules/detection/module.py +++ b/src/adapt/modules/detection/module.py @@ -23,7 +23,6 @@ """ import logging -from typing import TYPE_CHECKING import numpy as np import xarray as xr @@ -31,9 +30,6 @@ from skimage.morphology import h_maxima from skimage.segmentation import watershed -if TYPE_CHECKING: - from adapt.configuration.schemas import InternalConfig - __all__ = ['RadarCellSegmenter'] logger = logging.getLogger(__name__) @@ -140,7 +136,7 @@ class RadarCellSegmenter: >>> print(f"Found {ds_labeled['cell_labels'].max()} cells") """ - def __init__(self, config: "InternalConfig"): + def __init__(self, config): """Initialize segmenter with validated configuration. Parameters @@ -160,19 +156,16 @@ def __init__(self, config: "InternalConfig"): >>> config = resolve_config(ParamConfig()) >>> segmenter = RadarCellSegmenter(config) """ - self.config = config - self.method = config.segmenter.method - self.threshold = config.segmenter.threshold - self.kernel_size = config.segmenter.closing_kernel - self.filter_by_size = config.segmenter.filter_by_size - self.min_gridpoints = config.segmenter.min_cellsize_gridpoint - self.max_gridpoints = config.segmenter.max_cellsize_gridpoint - self.h_maxima = config.segmenter.h_maxima - - # Variable names from global config - self.refl_name = config.global_.var_names.reflectivity - self.labels_name = config.global_.var_names.cell_labels - self.z_level = config.global_.z_level + self.method = config.method + self.threshold = config.threshold + self.kernel_size = config.closing_kernel + self.filter_by_size = config.filter_by_size + self.min_gridpoints = config.min_cellsize_gridpoint + self.max_gridpoints = config.max_cellsize_gridpoint + self.h_maxima = config.h_maxima + self.refl_name = config.reflectivity_var + self.labels_name = config.labels_var + self.z_level = config.z_level logger.info("RadarCellSegmenter initialized: method=%s, threshold=%s", self.method, self.threshold) @@ -405,7 +398,7 @@ class DetectModule(BaseModule): """ name = "detection" - inputs = ["grid_ds_2d", "config"] + inputs = ["grid_ds_2d", "detection_config"] outputs = ["segmented_ds", "num_cells"] input_contracts = {"grid_ds_2d": _check_grid_ds_2d} output_contracts = {"segmented_ds": _check_segmented_ds} @@ -414,15 +407,14 @@ def __init__(self) -> None: self._segmenter = None def run(self, context: dict) -> dict: - config = context["config"] + config = context["detection_config"] ds_2d = context["grid_ds_2d"] if self._segmenter is None: self._segmenter = RadarCellSegmenter(config) segmented = self._segmenter.segment(ds_2d) - labels_name = config.global_.var_names.cell_labels - num_cells = int(segmented[labels_name].max().item()) + num_cells = int(segmented[config.labels_var].max().item()) return {"segmented_ds": segmented, "num_cells": num_cells} diff --git a/src/adapt/modules/ingest/module.py b/src/adapt/modules/ingest/module.py index 7d81160..3c36414 100644 --- a/src/adapt/modules/ingest/module.py +++ b/src/adapt/modules/ingest/module.py @@ -19,14 +19,10 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING import pyart import xarray as xr -if TYPE_CHECKING: - from adapt.configuration.schemas import InternalConfig - __all__ = ['RadarDataLoader'] logger = logging.getLogger(__name__) @@ -82,7 +78,7 @@ class RadarDataLoader: >>> print(ds.data_vars) # reflectivity, velocity, etc. """ - def __init__(self, config: "InternalConfig"): + def __init__(self, config): """Initialize loader with validated configuration. Parameters @@ -102,14 +98,13 @@ def __init__(self, config: "InternalConfig"): >>> config = resolve_config(ParamConfig()) >>> loader = RadarDataLoader(config) """ - self.config = config - self.file_format = config.reader.file_format - self.grid_shape = config.regridder.grid_shape - self.grid_limits = config.regridder.grid_limits - self.roi_func = config.regridder.roi_func - self.min_radius = config.regridder.min_radius - self.weighting_function = config.regridder.weighting_function - self.save_netcdf = config.regridder.save_netcdf + self.file_format = config.file_format + self.grid_shape = config.grid_shape + self.grid_limits = config.grid_limits + self.roi_func = config.roi_func + self.min_radius = config.min_radius + self.weighting_function = config.weighting_function + self.save_netcdf = config.save_netcdf def read(self, filepath: Path | str) -> object: """Read a NEXRAD archive file into a Py-ART Radar object. @@ -355,10 +350,10 @@ def load_and_regrid(self, filepath: Path | str, grid_kwargs: dict = None, import numpy as np # noqa: E402 import xarray as _xr # noqa: E402 -from adapt.configuration.schemas.directories import get_netcdf_path # noqa: E402 from adapt.contracts import assert_gridded # noqa: E402 from adapt.execution.module_registry import registry # noqa: E402 from adapt.modules.base import BaseModule # noqa: E402 +from adapt.utils.paths import get_netcdf_path # noqa: E402 def _check_grid_ds_2d(ds): @@ -391,7 +386,7 @@ class LoadModule(BaseModule): """ name = "ingest" - inputs = ["nexrad_file", "config"] + inputs = ["nexrad_file", "ingest_config"] outputs = ["grid_ds", "grid_ds_2d", "scan_time"] output_contracts = {"grid_ds_2d": _check_grid_ds_2d} @@ -399,14 +394,14 @@ def __init__(self) -> None: self._loader = None def run(self, context: dict) -> dict: - config = context["config"] + config = context["ingest_config"] filepath = context["nexrad_file"] output_dirs = context.get("output_dirs", {}) if self._loader is None: self._loader = RadarDataLoader(config) - radar = config.downloader.radar + radar = config.radar nc_filename = Path(filepath).stem scan_time = _dt.now(UTC) try: @@ -421,17 +416,16 @@ def run(self, context: dict) -> dict: ds = self._loader.load_and_regrid( filepath, - save_netcdf=config.regridder.save_netcdf, + save_netcdf=config.save_netcdf, output_dir=output_dir, ) if ds is None: raise RuntimeError(f"Ingest failed: load_and_regrid returned None for {filepath}") - # Extract 2D slice at configured z-level - z_level = config.global_.z_level - z_name = config.global_.coord_names.z - time_name = config.global_.coord_names.time + z_level = config.z_level + z_name = config.z_coord + time_name = config.time_coord z_idx = int(np.argmin(np.abs(ds[z_name].values - z_level))) ds_2d = _xr.Dataset() diff --git a/src/adapt/modules/projection/module.py b/src/adapt/modules/projection/module.py index b0d1f87..4ddfadf 100644 --- a/src/adapt/modules/projection/module.py +++ b/src/adapt/modules/projection/module.py @@ -21,7 +21,6 @@ """ import logging -from typing import TYPE_CHECKING import cv2 import numpy as np @@ -29,9 +28,6 @@ from scipy.ndimage import binary_dilation from scipy.spatial import Delaunay -if TYPE_CHECKING: - from adapt.configuration.schemas import InternalConfig - __all__ = ['RadarCellProjector'] logger = logging.getLogger(__name__) @@ -129,7 +125,7 @@ class RadarCellProjector: ... ) """ - def __init__(self, config: "InternalConfig"): + def __init__(self, config): """Initialize projector with validated configuration. Parameters @@ -149,23 +145,22 @@ def __init__(self, config: "InternalConfig"): >>> config = resolve_config(ParamConfig()) >>> projector = RadarCellProjector(config) """ - self.config = config - self.method = config.projector.method - self.nan_fill = config.projector.nan_fill_value - self.max_interval_minutes = config.projector.max_time_interval_minutes - self.max_proj_steps = config.projector.max_projection_steps + self.method = config.method + self.nan_fill = config.nan_fill_value + self.max_interval_minutes = config.max_time_interval_minutes + self.max_proj_steps = config.max_projection_steps self.flow_params = { - "pyr_scale": config.projector.flow_params.pyr_scale, - "levels": config.projector.flow_params.levels, - "winsize": config.projector.flow_params.winsize, - "iterations": config.projector.flow_params.iterations, - "poly_n": config.projector.flow_params.poly_n, - "poly_sigma": config.projector.flow_params.poly_sigma, - "flags": config.projector.flow_params.flags, + "pyr_scale": config.pyr_scale, + "levels": config.levels, + "winsize": config.winsize, + "iterations": config.iterations, + "poly_n": config.poly_n, + "poly_sigma": config.poly_sigma, + "flags": config.flags, } - self.min_motion_threshold = config.projector.min_motion_threshold - self.max_flow_magnitude = config.projector.max_flow_magnitude - self.refl_var = config.global_.var_names.reflectivity + self.min_motion_threshold = config.min_motion_threshold + self.max_flow_magnitude = config.max_flow_magnitude + self.refl_var = config.reflectivity_var def project(self, ds_list): """Project cells forward using optical flow motion vectors. @@ -602,7 +597,7 @@ class ProjectionModule(BaseModule): """ name = "projection" - inputs = ["segmented_ds", "dataset_history", "config"] + inputs = ["segmented_ds", "dataset_history", "projection_config"] outputs = ["projected_ds"] input_contracts = {"segmented_ds": _check_segmented_ds} @@ -610,7 +605,7 @@ def __init__(self) -> None: self._projector = None def run(self, context: dict) -> dict: - config = context["config"] + config = context["projection_config"] dataset_history = context["dataset_history"] # list of (filepath, ds_2d) if self._projector is None: diff --git a/src/adapt/modules/tracking/module.py b/src/adapt/modules/tracking/module.py index a1dcf40..c2753cc 100644 --- a/src/adapt/modules/tracking/module.py +++ b/src/adapt/modules/tracking/module.py @@ -38,7 +38,6 @@ import logging import string from datetime import UTC -from typing import TYPE_CHECKING import networkx as nx import numpy as np @@ -46,8 +45,6 @@ import xarray as xr from scipy.optimize import linear_sum_assignment -if TYPE_CHECKING: - from adapt.configuration.schemas import InternalConfig __all__ = ['RadarCellTracker', 'TrackingModule'] @@ -274,8 +271,8 @@ def get_successors(self, node_id: int) -> list[tuple[int, str]]: class MatchingEngine: """Cost matrix builder using projected masks (cell_projections[0] is already the hull).""" - def __init__(self, config: "InternalConfig"): - self.core_threshold = config.tracker.core_reflectivity_threshold + def __init__(self, config): + self.core_threshold = config.core_reflectivity_threshold def compute_cost_matrix( self, @@ -361,15 +358,18 @@ class RadarCellTracker: 7. True births: remaining unmatched born cells """ - def __init__(self, config: "InternalConfig"): - self.config = config - self.match_cost = config.tracker.match_cost_threshold - self.keep_cost = config.tracker.keep_cost_threshold - self.unmatch_cost = config.tracker.unmatch_cost_threshold - self.split_overlap = config.tracker.split_overlap_threshold - self.core_threshold = config.tracker.core_reflectivity_threshold - self.refl_var = config.global_.var_names.reflectivity - self.labels_var = config.global_.var_names.cell_labels + def __init__(self, config): + self.match_cost = config.match_cost + self.keep_cost = config.keep_cost + self.unmatch_cost = config.unmatch_cost + self.split_overlap = config.split_overlap + self.core_threshold = config.core_reflectivity_threshold + self.refl_var = config.reflectivity_var + self.labels_var = config.labels_var + self.uid_time_step_s = config.uid_time_step_s + self.uid_latlon_step_deg = config.uid_latlon_step_deg + self.uid_area_step_km2 = config.uid_area_step_km2 + self.uid_width = config.uid_width self.graph = TrackingGraph() self.matcher = MatchingEngine(config) @@ -535,7 +535,6 @@ def _extract_cells_from_analyzer( return cells def _new_cell_identity(self, cell: dict) -> tuple[str, str]: - cfg = self.config.tracker.cell_uid max_zdr = float(cell['max_zdr']) if max_zdr < 0: max_zdr = 0.0 @@ -546,11 +545,11 @@ def _new_cell_identity(self, cell: dict) -> tuple[str, str]: max_dbz=float(cell['max_reflectivity']), max_zdr=max_zdr, area40_km2=float(cell['area_40dbz_km2']), - time_step_s=int(cfg.time_step_s), - latlon_step_deg=float(cfg.latlon_step_deg), - area_step_km2=float(cfg.area_step_km2), + time_step_s=self.uid_time_step_s, + latlon_step_deg=self.uid_latlon_step_deg, + area_step_km2=self.uid_area_step_km2, ) - cell_uid = _cell_uid_from_signature(signature, width=int(cfg.width)) + cell_uid = _cell_uid_from_signature(signature, width=self.uid_width) return cell_uid, signature # ------------------------------------------------------------------ @@ -968,7 +967,7 @@ class TrackingModule(BaseModule): """ name = "tracking" - inputs = ["projected_ds", "cell_stats", "config", "scan_time"] + inputs = ["projected_ds", "cell_stats", "tracking_config", "scan_time"] outputs = ["tracked_cells", "cell_events"] input_contracts = {"projected_ds": _check_projected_ds} output_contracts = { @@ -980,7 +979,7 @@ def __init__(self) -> None: self._tracker = None def run(self, context: dict) -> dict: - config = context["config"] + config = context["tracking_config"] ds_2d = context["projected_ds"] cell_stats = context["cell_stats"] diff --git a/src/adapt/runtime/processor.py b/src/adapt/runtime/processor.py index 9ed6394..fee6fb8 100644 --- a/src/adapt/runtime/processor.py +++ b/src/adapt/runtime/processor.py @@ -29,6 +29,7 @@ import pandas as pd +from adapt.configuration.schemas.materialization import materialize_module_configs from adapt.contracts import ContractViolation from adapt.execution.graph.builder import GraphBuilder from adapt.execution.graph.executor import GraphExecutor @@ -110,6 +111,8 @@ def __init__( self._single_executor = GraphExecutor(GraphBuilder(single_modules).build()) self._multi_executor = GraphExecutor(GraphBuilder(multi_modules).build()) + self._module_configs = materialize_module_configs(config) + logger.info( "RadarProcessor graphs: single=[%s] multi=[%s]", ", ".join(m.name for m in single_modules), @@ -204,7 +207,8 @@ def process_file(self, filepath) -> bool: t0 = time.perf_counter() base_ctx = { "nexrad_file": filepath, - "config": self.config, + "ingest_config": self._module_configs["ingest_config"], + "detection_config": self._module_configs["detection_config"], "output_dirs": self.output_dirs, } if self.repository: @@ -258,7 +262,9 @@ def process_file(self, filepath) -> bool: t_proj = time.perf_counter() pair_ctx = { **frame_ctx, - "config": self.config, + "projection_config": self._module_configs["projection_config"], + "analysis_config": self._module_configs["analysis_config"], + "tracking_config": self._module_configs["tracking_config"], "output_dirs": self.output_dirs, "dataset_history": [(fp, ds) for fp, ds, _ in self._segmented_history], } diff --git a/tests/modules/tracking/test_tracker_scan_local_outputs.py b/tests/modules/tracking/test_tracker_scan_local_outputs.py index 3501b12..559ead6 100644 --- a/tests/modules/tracking/test_tracker_scan_local_outputs.py +++ b/tests/modules/tracking/test_tracker_scan_local_outputs.py @@ -9,6 +9,7 @@ import pytest import xarray as xr +from adapt.configuration.schemas.materialization import materialize_module_configs from adapt.configuration.schemas.param import ParamConfig from adapt.configuration.schemas.resolve import resolve_config from adapt.configuration.schemas.user import UserConfig @@ -19,12 +20,13 @@ def config(): d = tempfile.mkdtemp() try: + import shutil param = ParamConfig() param.tracker.split_overlap_threshold = 0.4 user = UserConfig(base_dir=str(Path(d)), radar="TEST_RADAR") - return resolve_config(param, user, None) + internal = resolve_config(param, user, None) + return materialize_module_configs(internal)["tracking_config"] finally: - import shutil shutil.rmtree(d, ignore_errors=True) From aa9ca658c9d97e808104502ad5e4d0412691404a Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Fri, 1 May 2026 17:55:04 -0500 Subject: [PATCH 07/14] ADD:import-linter, FIX: Ruff error --- .github/workflows/ci.yml | 6 ++++++ src/adapt/modules/tracking/module.py | 1 - 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3688203..448b87c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,6 +47,9 @@ jobs: - name: Install ruff run: pip install ruff + - name: Install import-linter + run: pip install import-linter + - name: Show environment info run: | python --version @@ -55,6 +58,9 @@ jobs: - name: Lint with ruff run: ruff check src tests + - name: Check import architecture + run: lint-imports --no-cache + - name: Run tests run: | pytest -m "not integration" \ diff --git a/src/adapt/modules/tracking/module.py b/src/adapt/modules/tracking/module.py index c2753cc..4338ed8 100644 --- a/src/adapt/modules/tracking/module.py +++ b/src/adapt/modules/tracking/module.py @@ -45,7 +45,6 @@ import xarray as xr from scipy.optimize import linear_sum_assignment - __all__ = ['RadarCellTracker', 'TrackingModule'] logger = logging.getLogger(__name__) From 28087552e8838e4da8d789af23200dfe890e7a05 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Wed, 6 May 2026 16:42:53 -0500 Subject: [PATCH 08/14] REF: Clean up --- .github/workflows/ci.yml | 3 - .../configuration/schemas/directories.py | 135 ------------------ src/adapt/modules/acquisition/module.py | 16 +-- src/adapt/modules/ingest/module.py | 5 +- 4 files changed, 7 insertions(+), 152 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 448b87c..00179c6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,9 +47,6 @@ jobs: - name: Install ruff run: pip install ruff - - name: Install import-linter - run: pip install import-linter - - name: Show environment info run: | python --version diff --git a/src/adapt/configuration/schemas/directories.py b/src/adapt/configuration/schemas/directories.py index 5359a08..8ff5d7b 100644 --- a/src/adapt/configuration/schemas/directories.py +++ b/src/adapt/configuration/schemas/directories.py @@ -64,107 +64,6 @@ def setup_output_directories(base_dir: str) -> dict[str, Path]: return dirs -def get_nexrad_path( - output_dirs: dict[str, Path], - radar: str, - filename: str, - scan_time: datetime -) -> Path: - """Get path for NEXRAD Level-II file. - - Pattern: base_dir/RADAR_ID/nexrad/YYYYMMDD/filename - - Parameters - ---------- - output_dirs : Dict[str, Path] - Output directories from setup_output_directories() - radar : str - Radar identifier (e.g., "KDIX") - filename : str - NEXRAD filename (e.g., "KDIX20240115_123045_V06") - scan_time : datetime - Scan timestamp for date-based organization - - Returns - ------- - Path - Full path to NEXRAD file - """ - date_str = scan_time.strftime("%Y%m%d") - base_dir = output_dirs["base"] / radar / "nexrad" / date_str - base_dir.mkdir(parents=True, exist_ok=True) - return base_dir / filename - - -def get_netcdf_path( - output_dirs: dict[str, Path], - radar: str, - filename: str, - scan_time: datetime -) -> Path: - """Get path for gridded NetCDF file. - - Pattern: base_dir/RADAR_ID/gridnc/YYYYMMDD/filename - - Parameters - ---------- - output_dirs : Dict[str, Path] - Output directories from setup_output_directories() - radar : str - Radar identifier (e.g., "KDIX") - filename : str - NetCDF filename (e.g., "grid_KDIX_20240115_123045.nc") - scan_time : datetime - Scan timestamp for date-based organization - - Returns - ------- - Path - Full path to NetCDF file - """ - date_str = scan_time.strftime("%Y%m%d") - base_dir = output_dirs["base"] / radar / "gridnc" / date_str - base_dir.mkdir(parents=True, exist_ok=True) - return base_dir / filename - - -def get_analysis_path( - output_dirs: dict[str, Path], - radar: str, - filename: str = None, - scan_time: datetime = None, - analysis_type: str = None -) -> Path: - """Get path for analysis output file (Parquet, DB, or NetCDF). - - Pattern: base_dir/RADAR_ID/analysis/YYYYMMDD/filename - - Parameters - ---------- - output_dirs : Dict[str, Path] - Output directories from setup_output_directories() - radar : str - Radar identifier (e.g., "KDIX") - filename : str, optional - Analysis filename (e.g., "cells_KDIX_123045.parquet") - scan_time : datetime, optional - Scan timestamp for date-based organization - analysis_type : str, optional - Type of analysis (for backward compatibility) - - Returns - ------- - Path - Full path to analysis file - """ - if scan_time: - date_str = scan_time.strftime("%Y%m%d") - base_dir = output_dirs["base"] / radar / "analysis" / date_str - else: - base_dir = output_dirs["base"] / radar / "analysis" - base_dir.mkdir(parents=True, exist_ok=True) - return base_dir / filename if filename else base_dir - def get_plot_path( output_dirs: dict[str, Path], @@ -212,42 +111,8 @@ def get_plot_path( return base_dir -def get_log_path( - output_dirs: dict[str, Path], - radar: str = None, - log_name: str = None -) -> Path: - """Get path for log file. - - Pattern: base_dir/logs/log_name - - Parameters - ---------- - output_dirs : Dict[str, Path] - Output directories from setup_output_directories() - radar : str, optional - Radar identifier (e.g., "KDIX") - included in log name if provided - log_name : str, optional - Log filename (e.g., "adapt_20240115.log") - - Returns - ------- - Path - Full path to log file - """ - if log_name: - return output_dirs["logs"] / log_name - elif radar: - return output_dirs["logs"] / f"pipeline_{radar}.log" - else: - return output_dirs["logs"] / "pipeline.log" - __all__ = [ 'setup_output_directories', - 'get_nexrad_path', - 'get_netcdf_path', - 'get_analysis_path', 'get_plot_path', - 'get_log_path', ] diff --git a/src/adapt/modules/acquisition/module.py b/src/adapt/modules/acquisition/module.py index e4f8cf1..815de27 100644 --- a/src/adapt/modules/acquisition/module.py +++ b/src/adapt/modules/acquisition/module.py @@ -15,8 +15,6 @@ from nexradaws import NexradAwsInterface -from adapt.utils.paths import get_nexrad_path - __all__ = ['AwsNexradDownloader'] logger = logging.getLogger(__name__) @@ -506,20 +504,14 @@ def _process_scans(self, scans: list) -> list: return new_downloads def _get_local_path(self, scan) -> Path: - """Get local path for scan using new structure: base/RADAR_ID/nexrad/YYYYMMDD/filename.""" + """Get local path for scan: base/RADAR_ID/nexrad/YYYYMMDD/filename.""" filename = Path(scan.key).name + date_str = scan.scan_time.strftime("%Y%m%d") - # Use new path function if output_dirs available if self.output_dirs: - return get_nexrad_path( - self.output_dirs, - self.radar, - filename, - scan_time=scan.scan_time - ) + return self.output_dirs["base"] / self.radar / "nexrad" / date_str / filename - # Legacy fallback: use output_dir directly with old structure - date_str = scan.scan_time.strftime("%Y%m%d") + # Legacy fallback return (self.output_dir / date_str / self.radar / filename).resolve() def _file_exists(self, path: Path) -> bool: diff --git a/src/adapt/modules/ingest/module.py b/src/adapt/modules/ingest/module.py index 3c36414..832f228 100644 --- a/src/adapt/modules/ingest/module.py +++ b/src/adapt/modules/ingest/module.py @@ -353,7 +353,6 @@ def load_and_regrid(self, filepath: Path | str, grid_kwargs: dict = None, from adapt.contracts import assert_gridded # noqa: E402 from adapt.execution.module_registry import registry # noqa: E402 from adapt.modules.base import BaseModule # noqa: E402 -from adapt.utils.paths import get_netcdf_path # noqa: E402 def _check_grid_ds_2d(ds): @@ -411,7 +410,9 @@ def run(self, context: dict) -> dict: except Exception: pass - nc_path = get_netcdf_path(output_dirs, radar, nc_filename, scan_time=scan_time) + date_str = scan_time.strftime("%Y%m%d") + base = output_dirs.get("base") + nc_path = base / radar / "gridnc" / date_str / nc_filename if base else None output_dir = str(nc_path.parent) if nc_path else None ds = self._loader.load_and_regrid( From afb5a5b0cfbfaf4fe360af6b9abcd7a4195eaad4 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Wed, 13 May 2026 00:10:35 -0500 Subject: [PATCH 09/14] ENH: README and CI for release --- .github/workflows/docs.yml | 2 +- .github/workflows/pypi-release.yml | 2 +- .github/workflows/virus.yml | 1 + README.md | 19 +++++++++++++++++-- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index d20010a..07c8689 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -1,4 +1,4 @@ -name: Build and Deploy Docs +name: Deploy Docs on: push: diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index 7bc094e..d72cd1e 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ -1,4 +1,4 @@ -name: Build and Upload Adapt Release to PyPI +name: Release to PyPI on: release: types: diff --git a/.github/workflows/virus.yml b/.github/workflows/virus.yml index 27c8de0..31563ce 100644 --- a/.github/workflows/virus.yml +++ b/.github/workflows/virus.yml @@ -1,3 +1,4 @@ +name: Virus Scan on: pull_request: types: [assigned, opened, synchronize, reopened, closed] diff --git a/README.md b/README.md index cc6e3b1..76603be 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,26 @@ # Adapt + [![CI](https://github.com/ARM-DOE/Adapt/actions/workflows/ci.yml/badge.svg)](https://github.com/ARM-DOE/Adapt/actions?query=workflow%3ACI) [![Codecov](https://img.shields.io/codecov/c/github/ARM-DOE/Adapt.svg?logo=codecov)](https://codecov.io/gh/ARM-DOE/Adapt) -[![Docs](https://img.shields.io/badge/docs-users-4088b8.svg)](https://arm-doe.github.io/Adapt/) -[![PyPI Downloads](https://img.shields.io/pypi/dm/arm-adapt.svg)](https://pypi.org/project/arm-adapt/) +[![CodeFactor](https://www.codefactor.io/repository/github/arm-doe/adapt/badge)](https://www.codefactor.io/repository/github/arm-doe/adapt) + + +[![Docs](https://github.com/ARM-DOE/Adapt/actions/workflows/docs.yml/badge.svg)](https://arm-doe.github.io/Adapt/) +[![PyPi release](https://github.com/ARM-DOE/Adapt/actions/workflows/pypi-release.yml/badge.svg)](https://arm-doe.github.io/Adapt/) +[![PyPI - Version](https://img.shields.io/pypi/v/arm-adapt)](https://pypi.org/project/arm-adapt/) +[![PyPI Downloads](https://static.pepy.tech/personalized-badge/arm-adapt?period=total&units=INTERNATIONAL_SYSTEM&left_color=BLACK&right_color=GREEN&left_text=downloads)](https://pypi.org/project/arm-adapt/) + +[![Security](https://github.com/ARM-DOE/Adapt/actions/workflows/security-analysis.yml/badge.svg)](https://arm-doe.github.io/Adapt/) +[![Virus](https://github.com/ARM-DOE/Adapt/actions/workflows/virus.yml/badge.svg)](https://arm-doe.github.io/Adapt/) + + +[![PyPI - License](https://img.shields.io/pypi/l/arm-adapt)](https://github.com/ARM-DOE/Adapt?tab=License-1-ov-file#) [![ARM](https://img.shields.io/badge/Sponsor-ARM-blue.svg?colorA=00c1de&colorB=00539c)](https://www.arm.gov/) + + + **Real-time processing for informed adaptive scanning of ARM weather radar operations and field campaigns.** `Adapt` is a configuration-driven modular framework for near real-time analysis of convective systems designed to support the adaptive sampling and study of convective storms and their life cycles. The system implements a modular pipeline that ingests radar observations, performs gridding and segmentation to identify convective cells, and maintains their identity through time using tracking. It further derives cell-level properties and motion to characterize storm evolution and generate candidate targets for adaptive radar scanning. From 9c1755fa56dffd6b85a8ca33fee86e8f6f486742 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Wed, 13 May 2026 17:48:39 -0500 Subject: [PATCH 10/14] FIX:(ruff) pin adapt as first-party for isort --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 365a675..de09d81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,7 @@ precision = 2 [tool.ruff] line-length = 100 target-version = "py311" +src = ["src"] [tool.ruff.lint] select = [ @@ -115,3 +116,6 @@ select = [ "UP", # modern Python upgrades "SIM", # simplifications (often removes subtle mistakes) ] + +[tool.ruff.lint.isort] +known-first-party = ["adapt"] From 0d44a0b59e40b379c85168b08be2c601ba2dae62 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Wed, 13 May 2026 17:56:45 -0500 Subject: [PATCH 11/14] FIX:(CI.yml) install import-linter --- .github/workflows/ci.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 437ac54..147e805 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,6 +52,9 @@ jobs: - name: Install ruff run: pip install ruff + - name: Install import-linter + run: pip install import-linter + - name: Show environment info run: | python --version From 1c8093636bd4b85b0debbee059a4355973bdeb3f Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Wed, 13 May 2026 18:01:41 -0500 Subject: [PATCH 12/14] ADD:wq(CI) .importlinter --- .gitignore | 1 + .importlinter | 156 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+) create mode 100644 .importlinter diff --git a/.gitignore b/.gitignore index e878f22..df64284 100644 --- a/.gitignore +++ b/.gitignore @@ -99,4 +99,5 @@ data/ # ignore hidden files but not gitignore .* !.gitignore +!.importlinter __pycache__/ diff --git a/.importlinter b/.importlinter new file mode 100644 index 0000000..19fbba7 --- /dev/null +++ b/.importlinter @@ -0,0 +1,156 @@ +# Import Linter configuration for ADAPT +# +# Purpose: +# Enforce architectural boundaries so the codebase stays modular, +# testable, and free of accidental coupling. +# +# Run: +# lint-imports +# +# Philosophy: +# - Modules are scientific units and must remain independent. +# - Shared behaviour belongs in contracts, runtime, or persistence. +# - Prevent cross-layer imports early, before they calcify. + +[importlinter] +root_packages = + adapt + +include_external_packages = False + +# ========================================================== +# 1. Scientific modules must remain independent +# ========================================================== +# +# No module may import from any other module — directly or +# transitively. Shared types belong in adapt.contracts. + +[importlinter:contract:independent_modules] +name = Adapt modules remain independent +type = independence +modules = + adapt.modules.acquisition + adapt.modules.analysis + adapt.modules.detection + adapt.modules.ingest + adapt.modules.projection + adapt.modules.tracking + +# ========================================================== +# 2. Modules do not depend on runtime orchestration +# ========================================================== +# +# Runtime coordinates modules; modules must not call back into it. +# Science is stateless. Orchestration is runtime's job. + +[importlinter:contract:modules_do_not_import_runtime] +name = Modules do not depend on runtime +type = forbidden +source_modules = + adapt.modules +forbidden_modules = + adapt.runtime + +# ========================================================== +# 3. Modules do not reach into infrastructure layers +# ========================================================== +# +# Modules produce data. Infrastructure stores and serves it. +# A module that imports persistence or GUI has too many responsibilities. +# +# Includes configuration: modules must receive their settings +# via dependency injection from the execution layer — not read +# configuration themselves. + +[importlinter:contract:modules_do_not_import_infrastructure] +name = Modules do not import infrastructure layers +type = forbidden +source_modules = + adapt.modules +forbidden_modules = + adapt.persistence + adapt.api + adapt.gui + adapt.visualization + adapt.configuration + +# ========================================================== +# 4. Persistence is infrastructure — no science, no orchestration +# ========================================================== +# +# Persistence reads and writes. It must not depend on modules +# (would create a circular science ↔ storage loop) or runtime +# (would couple storage to one execution strategy). + +[importlinter:contract:persistence_is_infra] +name = Persistence does not depend on modules or runtime +type = forbidden +source_modules = + adapt.persistence +forbidden_modules = + adapt.modules + adapt.runtime + +# ========================================================== +# 5. Nothing inside core depends on the CLI +# ========================================================== +# +# CLI is the outermost shell. It may import anything. +# Core packages must not import it — that would make them +# impossible to use as a library. + +[importlinter:contract:no_internal_cli_dependency] +name = Internal packages do not depend on CLI +type = forbidden +source_modules = + adapt.modules + adapt.runtime + adapt.persistence + adapt.execution +forbidden_modules = + adapt.cli + +# ========================================================== +# 6. Core must not depend on optional extensions +# ========================================================== +# +# Extensions add capability without modifying core. +# If core imports extensions, extensions can no longer be +# optional and the plug-in model breaks. + +[importlinter:contract:core_not_depend_on_extensions] +name = Core packages do not depend on extensions +type = forbidden +source_modules = + adapt.modules + adapt.runtime + adapt.persistence + adapt.execution + adapt.configuration +forbidden_modules = + adapt.extensions + +# ========================================================== +# 7. Contracts package imports no adapt internals +# ========================================================== +# +# adapt.contracts is the only layer both modules and execution +# may share. It must depend on nothing inside adapt so it +# never creates import cycles. + +[importlinter:contract:contracts_are_pure] +name = Contracts package imports no adapt internals +type = forbidden +source_modules = + adapt.contracts +forbidden_modules = + adapt.modules + adapt.runtime + adapt.persistence + adapt.configuration + adapt.execution + adapt.api + adapt.gui + adapt.visualization + adapt.cli + adapt.extensions From ce6d12a872c0f702846fb2f8dac86a609f2f387f Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Wed, 13 May 2026 18:08:01 -0500 Subject: [PATCH 13/14] REF: centralized contracts --- src/adapt/contracts/__init__.py | 27 ++++++++ src/adapt/contracts/analysis.py | 96 +++++++++++++++++++++++++++++ src/adapt/contracts/grid.py | 40 ++++++++++++ src/adapt/contracts/pipeline.py | 38 ++++++++++++ src/adapt/contracts/projection.py | 48 +++++++++++++++ src/adapt/contracts/segmentation.py | 50 +++++++++++++++ src/adapt/contracts/tracking.py | 94 ++++++++++++++++++++++++++++ 7 files changed, 393 insertions(+) create mode 100644 src/adapt/contracts/__init__.py create mode 100644 src/adapt/contracts/analysis.py create mode 100644 src/adapt/contracts/grid.py create mode 100644 src/adapt/contracts/pipeline.py create mode 100644 src/adapt/contracts/projection.py create mode 100644 src/adapt/contracts/segmentation.py create mode 100644 src/adapt/contracts/tracking.py diff --git a/src/adapt/contracts/__init__.py b/src/adapt/contracts/__init__.py new file mode 100644 index 0000000..04aa1ba --- /dev/null +++ b/src/adapt/contracts/__init__.py @@ -0,0 +1,27 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Central contract definitions for the ADAPT pipeline. + +All pipeline stage validators and the ContractViolation exception live here. +Import from this package — never from individual contract submodules. +""" + +from adapt.contracts.analysis import assert_analysis_output, assert_cell_adjacency +from adapt.contracts.grid import assert_gridded +from adapt.contracts.pipeline import ContractViolation, require +from adapt.contracts.projection import assert_projected +from adapt.contracts.segmentation import assert_segmented +from adapt.contracts.tracking import assert_cell_events, assert_tracked_cells + +__all__ = [ + "ContractViolation", + "require", + "assert_gridded", + "assert_segmented", + "assert_projected", + "assert_analysis_output", + "assert_cell_adjacency", + "assert_tracked_cells", + "assert_cell_events", +] diff --git a/src/adapt/contracts/analysis.py b/src/adapt/contracts/analysis.py new file mode 100644 index 0000000..e50c1f2 --- /dev/null +++ b/src/adapt/contracts/analysis.py @@ -0,0 +1,96 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Analysis stage contracts. + +Enforces structural requirements on cell statistics and adjacency DataFrames. +Scientific correctness of the values is the module's responsibility, not checked here. +""" + +import pandas as pd + +from adapt.contracts.pipeline import require + +_REQUIRED_STATS_COLS = [ + "cell_label", + "cell_area_sqkm", + "time", + "time_volume_start", + "cell_centroid_mass_lat", + "cell_centroid_mass_lon", + "radar_reflectivity_max", + "radar_differential_reflectivity_max", + "area_40dbz_km2", +] + +_REQUIRED_ADJACENCY_COLS = [ + "time", + "cell_label_a", + "cell_label_b", + "touching_boundary_pixels", +] + + +def assert_analysis_output(df: pd.DataFrame, min_expected_rows: int = 0) -> None: + """Enforce analysis stage contract. + + Parameters + ---------- + df : pd.DataFrame + Output from analyzer.extract() + min_expected_rows : int, optional + Minimum number of rows expected (default 0, allows no-cell frames) + + Raises + ------ + ContractViolation + If structural requirements are violated + """ + require( + isinstance(df, pd.DataFrame), + f"Analysis contract violated: output is {type(df)}, expected DataFrame", + ) + for col in _REQUIRED_STATS_COLS: + require(col in df.columns, f"Analysis contract violated: missing required column '{col}'") + if len(df) > 0: + require( + (df["cell_label"] > 0).all(), + "Analysis contract violated: cell_label must be > 0 for all rows", + ) + require( + len(df) >= min_expected_rows, + f"Analysis contract violated: got {len(df)} cells, expected >= {min_expected_rows}", + ) + + +def assert_cell_adjacency(df: pd.DataFrame) -> None: + """Enforce cell adjacency contract. + + Raises + ------ + ContractViolation + If structural requirements are violated + """ + require( + isinstance(df, pd.DataFrame), + f"Cell adjacency contract violated: output is {type(df)}, expected DataFrame", + ) + for col in _REQUIRED_ADJACENCY_COLS: + require( + col in df.columns, + f"Cell adjacency contract violated: missing required column '{col}'", + ) + if len(df) == 0: + return + require( + (df["cell_label_a"] > 0).all() and (df["cell_label_b"] > 0).all(), + "Cell adjacency contract violated: cell labels must be > 0", + ) + require( + (df["cell_label_a"] < df["cell_label_b"]).all(), + "Cell adjacency contract violated: expected canonical ordering cell_label_a < cell_label_b", + ) + require( + (df["touching_boundary_pixels"] >= 1).all(), + "Cell adjacency contract violated: touching_boundary_pixels must be >= 1", + ) diff --git a/src/adapt/contracts/grid.py b/src/adapt/contracts/grid.py new file mode 100644 index 0000000..39ab496 --- /dev/null +++ b/src/adapt/contracts/grid.py @@ -0,0 +1,40 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Grid stage contract. + +Enforces that after regridding, the dataset is a valid 2D Cartesian grid +suitable for downstream segmentation and projection. +""" + +import xarray as xr + +from adapt.contracts.pipeline import require + + +def assert_gridded(ds: xr.Dataset, reflectivity_var: str) -> None: + """Enforce grid stage contract. + + Parameters + ---------- + ds : xr.Dataset + Dataset from loader.load_and_regrid() + reflectivity_var : str + Name of reflectivity variable (from config) + + Raises + ------ + ContractViolation + If any invariant is violated + """ + require("x" in ds.coords, "Grid contract violated: missing 'x' coordinate") + require("y" in ds.coords, "Grid contract violated: missing 'y' coordinate") + require( + reflectivity_var in ds.data_vars, + f"Grid contract violated: missing '{reflectivity_var}' variable", + ) + refl = ds[reflectivity_var] + require( + refl.ndim == 2, + f"Grid contract violated: '{reflectivity_var}' has {refl.ndim} dims, expected 2", + ) diff --git a/src/adapt/contracts/pipeline.py b/src/adapt/contracts/pipeline.py new file mode 100644 index 0000000..ab6a86b --- /dev/null +++ b/src/adapt/contracts/pipeline.py @@ -0,0 +1,38 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Core contract enforcement primitives. + +ContractViolation and require are the only two names every contract +function depends on. This module has zero adapt imports so contracts +can be imported from anywhere without creating import cycles. +""" + + +class ContractViolation(RuntimeError): + """Raised when a pipeline contract is violated. + + This indicates a bug in pipeline logic, not bad user input or recoverable + science edge cases. It means a pipeline stage did not produce the invariants + it promised. + + Key distinction: + - ValueError: User/config error (handled by Pydantic) + - ContractViolation: Pipeline bug (programmer error) + - Exception: Recoverable science issues (try/except in algorithms) + """ + pass + + +def require(condition: bool, message: str) -> None: + """Enforce a pipeline contract. + + Fail-fast: no recovery, no fallback, no silence. + + Raises + ------ + ContractViolation + If condition is False. + """ + if not condition: + raise ContractViolation(message) diff --git a/src/adapt/contracts/projection.py b/src/adapt/contracts/projection.py new file mode 100644 index 0000000..625d591 --- /dev/null +++ b/src/adapt/contracts/projection.py @@ -0,0 +1,48 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Projection stage contract. + +Enforces that after optical flow computation, motion vectors and optional +projection arrays are present and structurally valid. +""" + +import xarray as xr + +from adapt.contracts.pipeline import require + + +def assert_projected(ds: xr.Dataset, max_steps: int = 5) -> None: + """Enforce projection stage contract. + + Parameters + ---------- + ds : xr.Dataset + Dataset from projector.project() + max_steps : int, optional + Maximum number of projection steps (default 5). If dataset has + 'max_projection_steps' in attrs, that value is used instead. + + Raises + ------ + ContractViolation + If any invariant is violated + """ + require("heading_x" in ds.data_vars, "Projection contract violated: missing 'heading_x' ") + require("heading_y" in ds.data_vars, "Projection contract violated: missing 'heading_y' ") + + if "cell_projections" in ds.data_vars: + projections = ds["cell_projections"] + require( + projections.ndim == 3, + f"Projection contract violated: 'cell_projections' has {projections.ndim} dims, " + "expected 3 (step, y, x)", + ) + max_steps_actual = ds.attrs.get("max_projection_steps", max_steps) + num_steps = projections.shape[0] + expected_steps = max_steps_actual + 1 + require( + num_steps == expected_steps, + f"Projection contract violated: found {num_steps} steps, expected {expected_steps} " + f"(1 registration + {max_steps_actual} projections from config)", + ) diff --git a/src/adapt/contracts/segmentation.py b/src/adapt/contracts/segmentation.py new file mode 100644 index 0000000..f09d560 --- /dev/null +++ b/src/adapt/contracts/segmentation.py @@ -0,0 +1,50 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Segmentation stage contract. + +Enforces that after cell detection, labels are present, integer-typed, +non-negative, and 2D. +""" + +import numpy as np +import xarray as xr + +from adapt.contracts.pipeline import require + + +def assert_segmented(ds: xr.Dataset, labels_name: str) -> None: + """Enforce segmentation stage contract. + + Parameters + ---------- + ds : xr.Dataset + Dataset from segmenter.segment() + labels_name : str + Name of cell labels variable (from config) + + Raises + ------ + ContractViolation + If any invariant is violated + """ + require( + labels_name in ds.data_vars, + f"Segmentation contract violated: '{labels_name}' not found", + ) + labels = ds[labels_name] + require( + labels.dtype.kind in {"i", "u"}, + f"Segmentation contract violated: '{labels_name}' dtype is {labels.dtype}, " + "expected integer", + ) + label_vals = labels.values + require( + np.min(label_vals) >= 0, + "Segmentation contract violated: labels contain negative values " + f"(min={np.min(label_vals)})", + ) + require( + labels.ndim == 2, + f"Segmentation contract violated: '{labels_name}' has {labels.ndim} dims, expected 2", + ) diff --git a/src/adapt/contracts/tracking.py b/src/adapt/contracts/tracking.py new file mode 100644 index 0000000..c81b51f --- /dev/null +++ b/src/adapt/contracts/tracking.py @@ -0,0 +1,94 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Tracking stage contracts. + +Enforces structural requirements on tracked_cells and cell_events DataFrames. +""" + +from __future__ import annotations + +import pandas as pd + +from adapt.contracts.pipeline import require + +_REQUIRED_TRACKED_COLS = [ + "time", + "cell_label", + "cell_uid", + "area", + "centroid_x", + "centroid_y", + "mean_reflectivity", + "max_reflectivity", + "core_area", +] + +_REQUIRED_EVENTS_COLS = [ + "time", + "event_type", + "source_cell_uid", + "target_cell_uid", + "source_cell_label", + "target_cell_label", + "cost", + "is_dominant", + "event_group_id", +] + +_VALID_EVENT_TYPES = {"CONTINUE", "SPLIT", "MERGE", "INITIATION", "TERMINATION"} + + +def assert_tracked_cells(df: pd.DataFrame) -> None: + """Enforce tracked cells contract. + + Raises + ------ + ContractViolation + If structural requirements are violated + """ + require( + isinstance(df, pd.DataFrame), + f"Tracked cells contract violated: output is {type(df)}, expected DataFrame", + ) + for col in _REQUIRED_TRACKED_COLS: + require( + col in df.columns, + f"Tracked cells contract violated: missing required column '{col}'", + ) + if len(df) == 0: + return + require( + (df["cell_label"] > 0).all(), + "Tracked cells contract violated: cell_label must be > 0 for all rows", + ) + require( + "cell_uid" in df.columns and df["cell_uid"].notna().all(), + "Tracked cells contract violated: cell_uid must be non-null for all rows", + ) + + +def assert_cell_events(df: pd.DataFrame) -> None: + """Enforce cell events contract. + + Raises + ------ + ContractViolation + If structural requirements are violated + """ + require( + isinstance(df, pd.DataFrame), + f"Cell events contract violated: output is {type(df)}, expected DataFrame", + ) + for col in _REQUIRED_EVENTS_COLS: + require( + col in df.columns, + f"Cell events contract violated: missing required column '{col}'", + ) + if len(df) == 0: + return + require( + df["event_type"].isin(_VALID_EVENT_TYPES).all(), + f"Cell events contract violated: invalid event_type present " + f"(valid={sorted(_VALID_EVENT_TYPES)})", + ) From ba3342f6c3ff4ad784ae675b7f6c9c3e8e94370b Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Wed, 13 May 2026 18:12:13 -0500 Subject: [PATCH 14/14] ADD:every modules ahs it's own cofig object --- .../configuration/schemas/materialization.py | 165 ++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 src/adapt/configuration/schemas/materialization.py diff --git a/src/adapt/configuration/schemas/materialization.py b/src/adapt/configuration/schemas/materialization.py new file mode 100644 index 0000000..b23f95f --- /dev/null +++ b/src/adapt/configuration/schemas/materialization.py @@ -0,0 +1,165 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Per-module config materialization. + +Slices the frozen InternalConfig into one lightweight frozen dataclass per +pipeline module. Called once at processor startup; the resulting objects are +injected into executor contexts under module-specific keys. + +Shared fields (global_, cross-module references) are copied by value so each +module config is self-contained and independent of all others. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +from adapt.configuration.schemas.internal import InternalConfig + + +@dataclass(frozen=True) +class IngestModuleConfig: + file_format: str + grid_shape: tuple + grid_limits: tuple + roi_func: str + min_radius: float + weighting_function: str + save_netcdf: bool + radar: str + z_level: float + z_coord: str + time_coord: str + + +@dataclass(frozen=True) +class DetectionModuleConfig: + method: str + threshold: float + closing_kernel: tuple + filter_by_size: bool + min_cellsize_gridpoint: int + max_cellsize_gridpoint: int | None + h_maxima: float + reflectivity_var: str + labels_var: str + z_level: float + + +@dataclass(frozen=True) +class ProjectionModuleConfig: + method: str + nan_fill_value: float + max_time_interval_minutes: int + max_projection_steps: int + pyr_scale: float + levels: int + winsize: int + iterations: int + poly_n: int + poly_sigma: float + flags: int + min_motion_threshold: float + max_flow_magnitude: float + reflectivity_var: str + + +@dataclass(frozen=True) +class AnalysisModuleConfig: + radar_variables: tuple + exclude_fields: tuple + adjacency_min_touching: int + max_projection_steps: int + reflectivity_var: str + labels_var: str + z_level: float + + +@dataclass(frozen=True) +class TrackingModuleConfig: + match_cost: float + keep_cost: float + unmatch_cost: float + split_overlap: float + core_reflectivity_threshold: float + uid_time_step_s: int + uid_latlon_step_deg: float + uid_area_step_km2: float + uid_width: int + reflectivity_var: str + labels_var: str + + +def materialize_module_configs(cfg: InternalConfig) -> dict: + """Slice InternalConfig into one frozen config per module. + + Returns a dict keyed by the context key each module declares in + its ``inputs`` list. Shared values (global_, cross-module) are copied + by value — no module config holds a reference to another. + """ + return { + "ingest_config": IngestModuleConfig( + file_format=cfg.reader.file_format, + grid_shape=cfg.regridder.grid_shape, + grid_limits=cfg.regridder.grid_limits, + roi_func=cfg.regridder.roi_func, + min_radius=cfg.regridder.min_radius, + weighting_function=cfg.regridder.weighting_function, + save_netcdf=cfg.regridder.save_netcdf, + radar=cfg.downloader.radar, + z_level=cfg.global_.z_level, + z_coord=cfg.global_.coord_names.z, + time_coord=cfg.global_.coord_names.time, + ), + "detection_config": DetectionModuleConfig( + method=cfg.segmenter.method, + threshold=cfg.segmenter.threshold, + closing_kernel=cfg.segmenter.closing_kernel, + filter_by_size=cfg.segmenter.filter_by_size, + min_cellsize_gridpoint=cfg.segmenter.min_cellsize_gridpoint, + max_cellsize_gridpoint=cfg.segmenter.max_cellsize_gridpoint, + h_maxima=cfg.segmenter.h_maxima, + reflectivity_var=cfg.global_.var_names.reflectivity, + labels_var=cfg.global_.var_names.cell_labels, + z_level=cfg.global_.z_level, + ), + "projection_config": ProjectionModuleConfig( + method=cfg.projector.method, + nan_fill_value=cfg.projector.nan_fill_value, + max_time_interval_minutes=cfg.projector.max_time_interval_minutes, + max_projection_steps=cfg.projector.max_projection_steps, + pyr_scale=cfg.projector.flow_params.pyr_scale, + levels=cfg.projector.flow_params.levels, + winsize=cfg.projector.flow_params.winsize, + iterations=cfg.projector.flow_params.iterations, + poly_n=cfg.projector.flow_params.poly_n, + poly_sigma=cfg.projector.flow_params.poly_sigma, + flags=cfg.projector.flow_params.flags, + min_motion_threshold=cfg.projector.min_motion_threshold, + max_flow_magnitude=cfg.projector.max_flow_magnitude, + reflectivity_var=cfg.global_.var_names.reflectivity, + ), + "analysis_config": AnalysisModuleConfig( + radar_variables=tuple(cfg.analyzer.radar_variables), + exclude_fields=tuple(cfg.analyzer.exclude_fields), + adjacency_min_touching=cfg.analyzer.adjacency_min_touching_boundary_pixels, + max_projection_steps=cfg.projector.max_projection_steps, + reflectivity_var=cfg.global_.var_names.reflectivity, + labels_var=cfg.global_.var_names.cell_labels, + z_level=cfg.global_.z_level, + ), + "tracking_config": TrackingModuleConfig( + match_cost=cfg.tracker.match_cost_threshold, + keep_cost=cfg.tracker.keep_cost_threshold, + unmatch_cost=cfg.tracker.unmatch_cost_threshold, + split_overlap=cfg.tracker.split_overlap_threshold, + core_reflectivity_threshold=cfg.tracker.core_reflectivity_threshold, + uid_time_step_s=cfg.tracker.cell_uid.time_step_s, + uid_latlon_step_deg=cfg.tracker.cell_uid.latlon_step_deg, + uid_area_step_km2=cfg.tracker.cell_uid.area_step_km2, + uid_width=cfg.tracker.cell_uid.width, + reflectivity_var=cfg.global_.var_names.reflectivity, + labels_var=cfg.global_.var_names.cell_labels, + ), + }