From cd3bd6b4b192a41bcdc2e3e2bd03979e0851bdff Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Thu, 14 May 2026 14:26:30 -0500 Subject: [PATCH 01/14] ADD:(tests) all infrastructure and module tests all test infrastructure (conftest, module tests, graph tests, runtime tests, etc.) --- tests/__init__.py | 0 tests/cli/test_config_command.py | 49 ++ tests/conftest.py | 175 +++++++ tests/execution/__init__.py | 0 tests/execution/test_module_registry.py | 239 +++++++++ tests/graph/__init__.py | 0 tests/graph/test_graph.py | 476 ++++++++++++++++++ tests/helpers/__init__.py | 0 tests/helpers/fake_grid.py | 59 +++ tests/helpers/fake_netcdf.py | 40 ++ tests/modules/__init__.py | 0 tests/modules/conftest.py | 294 +++++++++++ tests/runtime/__init__.py | 0 tests/runtime/conftest.py | 80 +++ tests/runtime/test_file_tracker.py | 158 ++++++ tests/runtime/test_orchestrator.py | 109 ++++ .../test_orchestrator_historical_shutdown.py | 94 ++++ tests/runtime/test_processor_core.py | 111 ++++ tests/runtime/test_processor_failures.py | 150 ++++++ .../runtime/test_processor_with_fake_grid.py | 67 +++ tests/test_architecture.py | 120 +++++ 21 files changed, 2221 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/cli/test_config_command.py create mode 100644 tests/conftest.py create mode 100644 tests/execution/__init__.py create mode 100644 tests/execution/test_module_registry.py create mode 100644 tests/graph/__init__.py create mode 100644 tests/graph/test_graph.py create mode 100644 tests/helpers/__init__.py create mode 100644 tests/helpers/fake_grid.py create mode 100644 tests/helpers/fake_netcdf.py create mode 100644 tests/modules/__init__.py create mode 100644 tests/modules/conftest.py create mode 100644 tests/runtime/__init__.py create mode 100644 tests/runtime/conftest.py create mode 100644 tests/runtime/test_file_tracker.py create mode 100644 tests/runtime/test_orchestrator.py create mode 100644 tests/runtime/test_orchestrator_historical_shutdown.py 
create mode 100644 tests/runtime/test_processor_core.py create mode 100644 tests/runtime/test_processor_failures.py create mode 100644 tests/runtime/test_processor_with_fake_grid.py create mode 100644 tests/test_architecture.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/cli/test_config_command.py b/tests/cli/test_config_command.py new file mode 100644 index 0000000..db6678b --- /dev/null +++ b/tests/cli/test_config_command.py @@ -0,0 +1,49 @@ +import os +import shutil +from argparse import Namespace +from pathlib import Path + +from adapt.cli import _config_cmd + + +def test_adapt_config_handles_deleted_cwd(tmp_path, monkeypatch): + # Create and chdir into a temp directory, then delete it to simulate stale cwd. + cwd = tmp_path / "gone" + cwd.mkdir() + os.chdir(cwd) + shutil.rmtree(cwd) + + home = tmp_path / "home" + home.mkdir() + monkeypatch.setenv("HOME", str(home)) + + # No output arg: must fail loudly (cannot resolve ./config.yaml). + args = Namespace(output=None) + try: + _config_cmd(args) + except FileNotFoundError as e: + assert "Current working directory no longer exists" in str(e) + else: + raise AssertionError( + "Expected FileNotFoundError when cwd is missing and no output is provided" + ) + + # Absolute output path should still work even when cwd is missing. 
+ os.chdir(home) + out = Path(home) / "config.yaml" + args2 = Namespace(output=str(out)) + _config_cmd(args2) + assert out.exists() + text = out.read_text() + assert f'base_dir: "{str(home)}"' in text + + +def test_adapt_config_sets_base_dir_to_output_parent(tmp_path): + out_dir = tmp_path / "nested" + out_path = out_dir / "my_config.yaml" + args = Namespace(output=str(out_path)) + _config_cmd(args) + + assert out_path.exists() + text = out_path.read_text() + assert f'base_dir: "{str(out_dir)}"' in text diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ba3e114 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,175 @@ +"""Root-level pytest fixtures for Adapt test suite. + +Provides shared configuration fixtures following Pydantic-based architecture. +All tests must use these fixtures instead of creating raw dict configs. +""" + +import shutil +import tempfile +from pathlib import Path + +import pytest + +from adapt.configuration.schemas.materialization import materialize_module_configs +from adapt.configuration.schemas.resolve import resolve_config +from adapt.configuration.schemas.user import UserConfig + +# ============================================================================= +# Configuration Fixtures (Pydantic-based) +# ============================================================================= + +@pytest.fixture +def param_config(): + """Expert configuration with all defaults. + + Use this as the base for all test configs. Override specific values + using user_config or by creating custom UserConfig instances. 
+ """ + # For tests, provide a default radar_id since it's required at runtime + from adapt.configuration.schemas.param import ParamConfig as PC + + config = PC() + # Override radar with a test default (field name is 'radar', not 'radar_id') + config.downloader.radar = "TEST_RADAR" + return config + + +@pytest.fixture +def internal_config(param_config, temp_dir): + """Fully validated runtime configuration (no overrides). + + Use this when tests don't care about specific config values and just + need a valid InternalConfig to pass to constructors. + + Examples + -------- + >>> def test_segmenter_init(internal_config): + ... seg = RadarCellSegmenter(internal_config) + ... assert seg.method == "threshold" + """ + user = UserConfig(base_dir=str(temp_dir)) + return resolve_config(param_config, user, None) + + +@pytest.fixture +def make_config(param_config, temp_dir): + """Factory fixture for creating custom test configs. + + Use this when you need to override specific values for a test. + Returns a callable that accepts UserConfig-compatible kwargs. + + Examples + -------- + >>> def test_custom_threshold(make_config): + ... config = make_config(threshold=35) + ... seg = RadarCellSegmenter(config) + ... 
assert seg.threshold == 35.0 + """ + def _make(**user_overrides): + """Create InternalConfig with user overrides.""" + # Ensure base_dir is always present in tests + if "base_dir" not in user_overrides: + user_overrides["base_dir"] = str(temp_dir) + + user = UserConfig(**user_overrides) + return resolve_config(param_config, user, None) + + return _make + + +# ============================================================================= +# Directory Fixtures +# ============================================================================= + +@pytest.fixture +def temp_dir(): + """Temporary directory that is cleaned up after test.""" + d = tempfile.mkdtemp() + yield Path(d) + shutil.rmtree(d, ignore_errors=True) + + +# ============================================================================= +# Per-Module Config Fixtures +# ============================================================================= + +@pytest.fixture +def detection_module_config(internal_config): + return materialize_module_configs(internal_config)["detection_config"] + + +@pytest.fixture +def analysis_module_config(internal_config): + return materialize_module_configs(internal_config)["analysis_config"] + + +@pytest.fixture +def projection_module_config(internal_config): + return materialize_module_configs(internal_config)["projection_config"] + + +@pytest.fixture +def tracking_module_config(internal_config): + return materialize_module_configs(internal_config)["tracking_config"] + + +@pytest.fixture +def ingest_module_config(internal_config): + return materialize_module_configs(internal_config)["ingest_config"] + + +@pytest.fixture +def make_detection_config(make_config): + def _make(**kw): + return materialize_module_configs(make_config(**kw))["detection_config"] + return _make + + +@pytest.fixture +def make_analysis_config(make_config): + def _make(**kw): + return materialize_module_configs(make_config(**kw))["analysis_config"] + return _make + + +@pytest.fixture +def 
make_projection_config(make_config): + def _make(**kw): + return materialize_module_configs(make_config(**kw))["projection_config"] + return _make + + +@pytest.fixture +def make_tracking_config(make_config): + def _make(**kw): + return materialize_module_configs(make_config(**kw))["tracking_config"] + return _make + + +@pytest.fixture +def make_ingest_config(make_config): + def _make(**kw): + return materialize_module_configs(make_config(**kw))["ingest_config"] + return _make + + +@pytest.fixture +def output_dirs(temp_dir): + """Standard Adapt output directory structure. + + Returns dict with keys: nexrad, gridded, gridnc, analysis, plots, logs + All directories are created and cleaned up automatically. + """ + dirs = { + "nexrad": temp_dir / "nexrad", + "gridded": temp_dir / "gridded", + "gridnc": temp_dir / "gridnc", # Alias for gridded + "analysis": temp_dir / "analysis", + "plots": temp_dir / "plots", + "logs": temp_dir / "logs", + } + + for d in dirs.values(): + d.mkdir(parents=True, exist_ok=True) + + return dirs diff --git a/tests/execution/__init__.py b/tests/execution/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/execution/test_module_registry.py b/tests/execution/test_module_registry.py new file mode 100644 index 0000000..c6576a8 --- /dev/null +++ b/tests/execution/test_module_registry.py @@ -0,0 +1,239 @@ +"""Tests for ModuleRegistry. + +Unit tests only — no IO, no radar data. +All modules are lightweight stubs inheriting BaseModule. 
+""" + +import pytest + +from adapt.execution.module_registry import ModuleRegistry +from adapt.modules.base import BaseModule + +# --------------------------------------------------------------------------- +# Stub modules for testing +# --------------------------------------------------------------------------- + +class StubA(BaseModule): + name = "stub_a" + inputs = [] + outputs = ["a_out"] + + def run(self, context): + return {"a_out": "a_result"} + + +class StubB(BaseModule): + name = "stub_b" + inputs = ["a_out"] + outputs = ["b_out"] + + def run(self, context): + return {"b_out": "b_result"} + + +class StubC(BaseModule): + name = "stub_c" + inputs = ["b_out"] + outputs = ["c_out"] + + def run(self, context): + return {"c_out": "c_result"} + + +class EmptyNameModule(BaseModule): + name = "" + inputs = [] + outputs = [] + + def run(self, context): + return {} + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def reg(): + """Fresh registry per test — isolated from the global singleton.""" + return ModuleRegistry() + + +# --------------------------------------------------------------------------- +# Registration tests +# --------------------------------------------------------------------------- + +class TestRegistration: + + @pytest.mark.unit + def test_register_single_module(self, reg): + reg.register(StubA) + assert "stub_a" in reg + + @pytest.mark.unit + def test_register_multiple_modules(self, reg): + reg.register(StubA) + reg.register(StubB) + assert "stub_a" in reg + assert "stub_b" in reg + + @pytest.mark.unit + def test_list_modules_returns_names(self, reg): + reg.register(StubA) + reg.register(StubB) + names = reg.list_modules() + assert "stub_a" in names + assert "stub_b" in names + assert len(names) == 2 + + @pytest.mark.unit + def test_len_reflects_registered_count(self, reg): + assert len(reg) == 0 + 
reg.register(StubA) + assert len(reg) == 1 + reg.register(StubB) + assert len(reg) == 2 + + @pytest.mark.unit + def test_register_empty_name_raises(self, reg): + with pytest.raises(ValueError, match="empty name"): + reg.register(EmptyNameModule) + + @pytest.mark.unit + def test_duplicate_registration_raises(self, reg): + reg.register(StubA) + + class StubADuplicate(BaseModule): + name = "stub_a" # same name as StubA + inputs = [] + outputs = ["different_out"] + + def run(self, context): + return {} + + with pytest.raises(RuntimeError, match="stub_a"): + reg.register(StubADuplicate) + + @pytest.mark.unit + def test_unregister_removes_module(self, reg): + reg.register(StubA) + reg.unregister("stub_a") + assert "stub_a" not in reg + + @pytest.mark.unit + def test_unregister_nonexistent_is_noop(self, reg): + reg.unregister("does_not_exist") # should not raise + + @pytest.mark.unit + def test_clear_removes_all(self, reg): + reg.register(StubA) + reg.register(StubB) + reg.clear() + assert len(reg) == 0 + + +# --------------------------------------------------------------------------- +# Retrieval tests +# --------------------------------------------------------------------------- + +class TestRetrieval: + + @pytest.mark.unit + def test_get_returns_class(self, reg): + reg.register(StubA) + cls = reg.get("stub_a") + assert cls is StubA + + @pytest.mark.unit + def test_get_unknown_raises(self, reg): + with pytest.raises(KeyError, match="not registered"): + reg.get("unknown") + + @pytest.mark.unit + def test_contains_true_after_register(self, reg): + reg.register(StubA) + assert "stub_a" in reg + + @pytest.mark.unit + def test_contains_false_before_register(self, reg): + assert "stub_a" not in reg + + +# --------------------------------------------------------------------------- +# create_modules tests +# --------------------------------------------------------------------------- + +class TestCreateModules: + + @pytest.mark.unit + def 
test_create_modules_returns_instances(self, reg): + reg.register(StubA) + modules = reg.create_modules() + assert len(modules) == 1 + assert isinstance(modules[0], StubA) + + @pytest.mark.unit + def test_create_modules_returns_fresh_instances(self, reg): + reg.register(StubA) + m1 = reg.create_modules()[0] + m2 = reg.create_modules()[0] + assert m1 is not m2 # different instances + + @pytest.mark.unit + def test_create_modules_preserves_order(self, reg): + reg.register(StubA) + reg.register(StubB) + reg.register(StubC) + modules = reg.create_modules() + names = [m.name for m in modules] + assert names == ["stub_a", "stub_b", "stub_c"] + + @pytest.mark.unit + def test_create_modules_empty_registry(self, reg): + assert reg.create_modules() == [] + + @pytest.mark.unit + def test_created_instances_are_runnable(self, reg): + reg.register(StubA) + module = reg.create_modules()[0] + result = module.run({}) + assert result == {"a_out": "a_result"} + + +# --------------------------------------------------------------------------- +# Integration: registry → graph builder +# --------------------------------------------------------------------------- + +class TestRegistryGraphIntegration: + + @pytest.mark.unit + def test_create_modules_feeds_graph_builder(self, reg): + """Modules from registry can be used directly with GraphBuilder.""" + from adapt.execution.graph.builder import GraphBuilder + from adapt.execution.graph.executor import GraphExecutor + + reg.register(StubA) + reg.register(StubB) + + modules = reg.create_modules() + nodes = GraphBuilder(modules).build() + ctx = GraphExecutor(nodes).run({}) + + assert ctx["a_out"] == "a_result" + assert ctx["b_out"] == "b_result" + + @pytest.mark.unit + def test_full_linear_pipeline_via_registry(self, reg): + """Three-module chain registers, builds, and executes correctly.""" + from adapt.execution.graph.builder import GraphBuilder + from adapt.execution.graph.executor import GraphExecutor + + reg.register(StubA) + 
reg.register(StubB) + reg.register(StubC) + + modules = reg.create_modules() + nodes = GraphBuilder(modules).build() + ctx = GraphExecutor(nodes).run({}) + + assert ctx["c_out"] == "c_result" diff --git a/tests/graph/__init__.py b/tests/graph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py new file mode 100644 index 0000000..356c13c --- /dev/null +++ b/tests/graph/test_graph.py @@ -0,0 +1,476 @@ +"""Tests for the execution graph: Node, GraphBuilder, GraphExecutor. + +These are pure unit tests — no IO, no radar data, no dependencies on +scientific modules. All modules are lightweight stubs. +""" + +import pytest + +from adapt.contracts import ContractViolation +from adapt.execution.graph.builder import GraphBuilder +from adapt.execution.graph.executor import GraphExecutor +from adapt.execution.graph.node import Node +from adapt.modules.base import BaseModule + +# --------------------------------------------------------------------------- +# Stub modules for testing +# --------------------------------------------------------------------------- + +class StubModule(BaseModule): + """A configurable stub module that records execution order.""" + + def __init__(self, name, inputs, outputs, side_effect=None): + self._name = name + self._inputs = inputs + self._outputs = outputs + self._side_effect = side_effect # callable(context) -> dict + + @property + def name(self): + return self._name + + @property + def inputs(self): + return self._inputs + + @property + def outputs(self): + return self._outputs + + def run(self, context): + if self._side_effect: + return self._side_effect(context) + return {k: f"{self._name}_result" for k in self._outputs} + + +# Execution order tracker shared across test stubs +_execution_order = [] + + +def _make_tracking_module(name, inputs, outputs): + """Create a stub that appends its name to _execution_order when run.""" + def run_fn(ctx): + 
_execution_order.append(name) + return {k: f"{name}_result" for k in outputs} + return StubModule(name, inputs, outputs, side_effect=run_fn) + + +# --------------------------------------------------------------------------- +# Node tests +# --------------------------------------------------------------------------- + +class TestNode: + + @pytest.mark.unit + def test_node_wraps_module(self): + mod = StubModule("grid", ["radar_volume"], ["grid_volume"]) + node = Node(mod) + assert node.name == "grid" + assert node.inputs == ["radar_volume"] + assert node.outputs == ["grid_volume"] + + @pytest.mark.unit + def test_node_starts_with_empty_deps(self): + mod = StubModule("grid", ["radar_volume"], ["grid_volume"]) + node = Node(mod) + assert node.dependencies == [] + assert node.dependents == [] + + @pytest.mark.unit + def test_node_repr(self): + mod = StubModule("detect", ["grid_volume"], ["storm_cells"]) + node = Node(mod) + r = repr(node) + assert "detect" in r + assert "grid_volume" in r + assert "storm_cells" in r + + +# --------------------------------------------------------------------------- +# GraphBuilder tests +# --------------------------------------------------------------------------- + +class TestGraphBuilder: + + @pytest.mark.unit + def test_build_linear_chain(self): + """A → B → C should produce correct dependencies.""" + a = StubModule("a", [], ["x"]) + b = StubModule("b", ["x"], ["y"]) + c = StubModule("c", ["y"], ["z"]) + + nodes = {n.name: n for n in GraphBuilder([a, b, c]).build()} + + assert nodes["b"].dependencies == [nodes["a"]] + assert nodes["c"].dependencies == [nodes["b"]] + assert nodes["a"].dependencies == [] + + @pytest.mark.unit + def test_build_dependents_wired(self): + """Dependents should mirror dependencies.""" + a = StubModule("a", [], ["x"]) + b = StubModule("b", ["x"], ["y"]) + + nodes = {n.name: n for n in GraphBuilder([a, b]).build()} + + assert nodes["b"] in nodes["a"].dependents + + @pytest.mark.unit + def 
test_build_fan_out(self): + """One output feeding two modules.""" + source = StubModule("source", [], ["data"]) + consumer1 = StubModule("consumer1", ["data"], ["out1"]) + consumer2 = StubModule("consumer2", ["data"], ["out2"]) + + nodes = {n.name: n for n in GraphBuilder([source, consumer1, consumer2]).build()} + + assert nodes["source"] in nodes["consumer1"].dependencies + assert nodes["source"] in nodes["consumer2"].dependencies + assert len(nodes["source"].dependents) == 2 + + @pytest.mark.unit + def test_build_root_node_no_deps(self): + """A module with no declared inputs has no dependencies.""" + root = StubModule("root", [], ["data"]) + nodes = GraphBuilder([root]).build() + assert nodes[0].dependencies == [] + + @pytest.mark.unit + def test_build_unconnected_inputs_ignored(self): + """Inputs that no module produces are simply ignored (external data).""" + mod = StubModule("mod", ["external_input"], ["result"]) + nodes = GraphBuilder([mod]).build() + assert nodes[0].dependencies == [] + + @pytest.mark.unit + def test_build_duplicate_output_raises(self): + """Two modules declaring the same output key should raise ValueError.""" + a = StubModule("a", [], ["shared_key"]) + b = StubModule("b", [], ["shared_key"]) + with pytest.raises(ValueError, match="shared_key"): + GraphBuilder([a, b]).build() + + @pytest.mark.unit + def test_build_returns_all_nodes(self): + mods = [ + StubModule("a", [], ["x"]), + StubModule("b", ["x"], ["y"]), + StubModule("c", ["y"], ["z"]), + ] + nodes = GraphBuilder(mods).build() + assert len(nodes) == 3 + + +# --------------------------------------------------------------------------- +# GraphExecutor tests +# --------------------------------------------------------------------------- + +class TestGraphExecutor: + + def setup_method(self): + """Clear shared tracker before each test.""" + _execution_order.clear() + + @pytest.mark.unit + def test_executor_linear_order(self): + """Nodes must execute in topological order.""" + a = 
_make_tracking_module("a", [], ["x"]) + b = _make_tracking_module("b", ["x"], ["y"]) + c = _make_tracking_module("c", ["y"], ["z"]) + + nodes = GraphBuilder([a, b, c]).build() + GraphExecutor(nodes).run({}) + + assert _execution_order == ["a", "b", "c"] + + @pytest.mark.unit + def test_executor_root_before_dependents(self): + """Root nodes (no deps) must execute before their dependents.""" + root = _make_tracking_module("root", [], ["data"]) + child = _make_tracking_module("child", ["data"], ["out"]) + + nodes = GraphBuilder([root, child]).build() + GraphExecutor(nodes).run({}) + + assert _execution_order.index("root") < _execution_order.index("child") + + @pytest.mark.unit + def test_executor_fan_out_both_consumers_run(self): + """Both consumers should run when source produces their shared input.""" + source = StubModule("source", [], ["data"]) + c1 = StubModule("c1", ["data"], ["out1"]) + c2 = StubModule("c2", ["data"], ["out2"]) + + nodes = GraphBuilder([source, c1, c2]).build() + result = GraphExecutor(nodes).run({}) + + assert "out1" in result + assert "out2" in result + + @pytest.mark.unit + def test_executor_outputs_merged_into_context(self): + """Module outputs should appear in the returned context.""" + mod = StubModule("mod", [], ["result"], side_effect=lambda ctx: {"result": 42}) + nodes = GraphBuilder([mod]).build() + ctx = GraphExecutor(nodes).run({}) + assert ctx["result"] == 42 + + @pytest.mark.unit + def test_executor_initial_context_available(self): + """Modules should see initial context values.""" + received = {} + + def capture(ctx): + received.update(ctx) + return {"out": True} + + mod = StubModule("mod", [], ["out"], side_effect=capture) + nodes = GraphBuilder([mod]).build() + GraphExecutor(nodes).run({"initial_key": "hello"}) + + assert received.get("initial_key") == "hello" + + @pytest.mark.unit + def test_executor_cycle_raises(self): + """A cyclic graph must raise RuntimeError, not hang.""" + # Create two nodes that depend on each other by 
manually wiring + a = StubModule("a", ["b_out"], ["a_out"]) + b = StubModule("b", ["a_out"], ["b_out"]) + + # Manually build nodes with circular deps (bypasses GraphBuilder) + node_a = Node(a) + node_b = Node(b) + node_a.dependencies.append(node_b) + node_b.dependencies.append(node_a) + + with pytest.raises(RuntimeError, match="cycle"): + GraphExecutor([node_a, node_b]).run({}) + + @pytest.mark.unit + def test_executor_single_node(self): + """Single node with no deps runs and returns output.""" + mod = StubModule("solo", [], ["result"], side_effect=lambda ctx: {"result": "done"}) + nodes = GraphBuilder([mod]).build() + ctx = GraphExecutor(nodes).run({}) + assert ctx["result"] == "done" + + +# --------------------------------------------------------------------------- +# GraphExecutor contract enforcement tests +# --------------------------------------------------------------------------- + +class _ContractStub(BaseModule): + """Minimal BaseModule subclass with configurable contracts for testing.""" + + def __init__(self, name, inputs, outputs, run_fn=None, + input_contracts=None, output_contracts=None): + self._name = name + self._inputs = inputs + self._outputs = outputs + self._run_fn = run_fn or (lambda ctx: {k: f"{name}_out" for k in outputs}) + self.input_contracts = input_contracts or {} + self.output_contracts = output_contracts or {} + + @property + def name(self): return self._name + @property + def inputs(self): return self._inputs + @property + def outputs(self): return self._outputs + + def run(self, context): + return self._run_fn(context) + + +class TestGraphExecutorContracts: + + @pytest.mark.unit + def test_input_contract_is_called_before_run(self): + """Input validator must be called before module.run().""" + call_order = [] + + def _validate(val): + call_order.append("validate") + + def _run(ctx): + call_order.append("run") + return {"out": 1} + + mod = _ContractStub("m", ["inp"], ["out"], run_fn=_run, + input_contracts={"inp": _validate}) + nodes 
= GraphBuilder([mod]).build() + GraphExecutor(nodes).run({"inp": "value"}) + + assert call_order == ["validate", "run"] + + @pytest.mark.unit + def test_output_contract_is_called_after_run(self): + """Output validator must be called after module.run().""" + call_order = [] + + def _run(ctx): + call_order.append("run") + return {"out": 42} + + def _validate(val): + call_order.append("validate") + + mod = _ContractStub("m", [], ["out"], run_fn=_run, + output_contracts={"out": _validate}) + nodes = GraphBuilder([mod]).build() + GraphExecutor(nodes).run({}) + + assert call_order == ["run", "validate"] + + @pytest.mark.unit + def test_input_contract_violation_propagates(self): + """ContractViolation raised by input validator propagates out of executor.""" + def _bad_validator(val): + raise ContractViolation("input contract broken") + + mod = _ContractStub("m", ["inp"], ["out"], + input_contracts={"inp": _bad_validator}) + nodes = GraphBuilder([mod]).build() + + with pytest.raises(ContractViolation, match="input contract broken"): + GraphExecutor(nodes).run({"inp": "anything"}) + + @pytest.mark.unit + def test_output_contract_violation_propagates(self): + """ContractViolation raised by output validator propagates out of executor.""" + def _bad_validator(val): + raise ContractViolation("output contract broken") + + mod = _ContractStub("m", [], ["out"], + output_contracts={"out": _bad_validator}) + nodes = GraphBuilder([mod]).build() + + with pytest.raises(ContractViolation, match="output contract broken"): + GraphExecutor(nodes).run({}) + + @pytest.mark.unit + def test_missing_input_key_raises_contract_violation(self): + """Executor must raise ContractViolation when a required input is absent. + + Previously the executor silently skipped validation if the key was missing + from context (guarded by 'if key in context:'). After the fix, it raises + immediately with a clear message. 
+ """ + def _validate(val): + pass # should never be called — key is absent + + mod = _ContractStub("m", ["required_key"], ["out"], + input_contracts={"required_key": _validate}) + nodes = GraphBuilder([mod]).build() + + with pytest.raises(ContractViolation, match="required_key"): + GraphExecutor(nodes).run({}) # required_key intentionally absent + + @pytest.mark.unit + def test_input_contract_receives_correct_value(self): + """Input validator receives the actual value from context.""" + received = [] + + def _capture(val): + received.append(val) + + mod = _ContractStub("m", ["x"], ["out"], + input_contracts={"x": _capture}) + nodes = GraphBuilder([mod]).build() + GraphExecutor(nodes).run({"x": 99}) + + assert received == [99] + + @pytest.mark.unit + def test_output_contract_receives_correct_value(self): + """Output validator receives the value the module returned.""" + received = [] + + def _capture(val): + received.append(val) + + mod = _ContractStub("m", [], ["result"], + run_fn=lambda ctx: {"result": "hello"}, + output_contracts={"result": _capture}) + nodes = GraphBuilder([mod]).build() + GraphExecutor(nodes).run({}) + + assert received == ["hello"] + + +# --------------------------------------------------------------------------- +# Branch-coverage gap tests (executor lines 82, 100→106, 102→101 and +# builder branches 71→73, 73→68) +# --------------------------------------------------------------------------- + +class TestExecutorBranchGaps: + """Targeted tests for executor branches not reachable by the main suite.""" + + @pytest.mark.unit + def test_completed_node_skipped_on_second_iteration(self): + """Executor loop skips already-completed nodes (line 82 branch). + + Modules are given in REVERSE dependency order so B comes before A + in self.nodes but cannot run until A completes. On the second while + iteration A is already in `completed` — the `continue` on line 82 fires. 
+ """ + order = [] + A = StubModule("a", [], ["x"], side_effect=lambda ctx: (order.append("a") or {"x": 1})) + B = StubModule("b", ["x"], ["y"], side_effect=lambda ctx: (order.append("b") or {"y": 2})) + # Reverse order: B first, then A + nodes = GraphBuilder([B, A]).build() + result = GraphExecutor(nodes).run({}) + assert result["x"] == 1 + assert result["y"] == 2 + assert order == ["a", "b"] + + @pytest.mark.unit + def test_module_returning_none_is_treated_as_no_output(self): + """If run() returns None the executor does not update context (branch 100→106).""" + mod = StubModule("sink", ["x"], [], side_effect=lambda ctx: None) + nodes = GraphBuilder([mod]).build() + result = GraphExecutor(nodes).run({"x": 42}) + # x stays in context (unchanged), no crash + assert result["x"] == 42 + + @pytest.mark.unit + def test_output_contract_key_absent_from_outputs_is_skipped(self): + """If an output_contract key is not in the returned dict, validator is not called + (branch 102→101). Executor must not raise — the missing key is silently skipped. + """ + validated = [] + + def _validator(val): + validated.append(val) + + # Module declares contract for "missing_key" but only returns "real_key" + mod = _ContractStub( + "m", [], ["real_key"], + run_fn=lambda ctx: {"real_key": "ok"}, + output_contracts={"missing_key": _validator}, + ) + nodes = GraphBuilder([mod]).build() + result = GraphExecutor(nodes).run({}) + # Validator never called because key absent + assert validated == [] + assert result["real_key"] == "ok" + + +class TestBuilderBranchGaps: + """Targeted tests for GraphBuilder branches 71→73 and 73→68.""" + + @pytest.mark.unit + def test_single_parent_shared_by_multiple_inputs_wired_once(self): + """When module B consumes two keys both produced by module A, the + dependency is added only once (branches 71→73 and 73→68 fire on + the second shared input). 
+ """ + A = StubModule("a", [], ["x", "y"]) # produces two keys + B = StubModule("b", ["x", "y"], ["z"]) # consumes both + nodes = GraphBuilder([A, B]).build() + node_map = {n.name: n for n in nodes} + + # A is a dependency of B exactly once, not twice + assert node_map["b"].dependencies.count(node_map["a"]) == 1 + # B is a dependent of A exactly once + assert node_map["a"].dependents.count(node_map["b"]) == 1 diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/helpers/fake_grid.py b/tests/helpers/fake_grid.py new file mode 100644 index 0000000..bdbf500 --- /dev/null +++ b/tests/helpers/fake_grid.py @@ -0,0 +1,59 @@ +import numpy as np +import xarray as xr + + +def make_fake_grid_ds( + time_len=1, + z_levels=(0, 1000, 2000), + shape=(5, 5), + variables=("reflectivity",), + with_labels=True, +): + """ + Create a minimal Py-ART-like grid dataset suitable + for processor integration tests. + """ + + time = np.array(["2025-01-01"], dtype="datetime64[ns]") + z = np.array(z_levels) + y = np.arange(shape[0]) + x = np.arange(shape[1]) + + data_vars = {} + for var in variables: + data = np.random.rand(time_len, len(z), shape[0], shape[1]) * 50 + data_vars[var] = (("time", "z", "y", "x"), data) + + ds = xr.Dataset( + data_vars=data_vars, + coords={"time": time, "z": z, "y": y, "x": x}, + attrs={ + "radar_latitude": 40.0, + "radar_longitude": -100.0, + "radar_altitude": 100.0, + }, + ) + + return ds + +def make_fake_grid_ds_with_labels(): + data = { + "reflectivity": (("y", "x"), np.ones((4, 4))) + } + + data["cell_labels"] = (("y", "x"), np.array([ + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 2, 2], + [0, 0, 2, 2], + ])) + + ds = xr.Dataset( + data, + coords={ + "x": np.arange(4), + "y": np.arange(4), + }, + attrs={"z_level_m": 2000} + ) + return ds diff --git a/tests/helpers/fake_netcdf.py b/tests/helpers/fake_netcdf.py new file mode 100644 index 0000000..dbbeb7b --- /dev/null +++ 
b/tests/helpers/fake_netcdf.py @@ -0,0 +1,40 @@ +from pathlib import Path + +import numpy as np +import xarray as xr + + +def write_fake_segmentation_netcdf( + path: Path, + with_labels: bool = True, +): + data_vars = { + "reflectivity": (("y", "x"), np.ones((4, 4), dtype="float32")), + } + + if with_labels: + labels = np.array([ + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 2, 2], + [0, 0, 2, 2], + ], dtype="int32") + data_vars["cell_labels"] = (("y", "x"), labels) + + ds = xr.Dataset( + data_vars=data_vars, + coords={ + "y": np.arange(4), + "x": np.arange(4), + }, + attrs={ + "z_level_m": 2000, + "radar_id": "TEST", + } + ) + + path.parent.mkdir(parents=True, exist_ok=True) + ds.to_netcdf(path) + ds.close() + + return path diff --git a/tests/modules/__init__.py b/tests/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/modules/conftest.py b/tests/modules/conftest.py new file mode 100644 index 0000000..62312bf --- /dev/null +++ b/tests/modules/conftest.py @@ -0,0 +1,294 @@ +# tests/modules/conftest.py +import shutil +import tempfile +from datetime import UTC, datetime +from pathlib import Path + +import numpy as np +import pytest +import xarray as xr + +from adapt.configuration.schemas.directories import setup_output_directories +from adapt.configuration.schemas.internal import InternalConfig +from adapt.configuration.schemas.materialization import materialize_module_configs +from adapt.configuration.schemas.param import ParamConfig +from adapt.configuration.schemas.resolve import resolve_config +from adapt.configuration.schemas.user import UserConfig + + +# ---- AwsNexradDownloader fixtures ---- +class FakeScan: + def __init__(self, key, scan_time=None): + self.key = key + self.scan_time = scan_time or datetime.now(UTC) + + +class FakeAwsConn: + def __init__(self, scans): + self.scans = scans + + def get_avail_scans_in_range(self, start, end, radar_id): + return self.scans + + def download(self, scans, target_dir, keep_aws_folders=False): + 
class Result: + def __init__(self, path): + self.filepath = path + + results = [] + for scan in scans: + path = target_dir / scan.key + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(b"x" * 2048) + results.append(Result(path)) + + class DownloadResults: + def iter_success(self): + return results + + return DownloadResults() + + +@pytest.fixture +def fake_scan(): + return FakeScan + + +@pytest.fixture +def fake_aws_conn(): + return FakeAwsConn + + +# ---- RadarCellSegmenter fixtures ---- +# these are for non-closing tests, so default kernel size of (1,1) is used. +@pytest.fixture +def simple_2d_ds(): + """ + 2D reflectivity field with one clear cell. + """ + data = np.array([ + [10, 10, 10, 10], + [10, 40, 40, 10], + [10, 40, 40, 10], + [10, 10, 10, 10], + ], dtype=np.float32) + + ds = xr.Dataset( + { + "reflectivity": (("y", "x"), data) + }, + coords={ + "y": np.arange(data.shape[0]), + "x": np.arange(data.shape[1]), + }, + attrs={"z_level_m": 2000}, + ) + return ds + + +@pytest.fixture +def empty_2d_ds(): + """ + All values below threshold, so no cells. + """ + data = np.zeros((4, 4), dtype=np.float32) + + return xr.Dataset( + {"reflectivity": (("y", "x"), data)}, + coords={"y": range(4), "x": range(4)}, + attrs={"z_level_m": 1000}, + ) + + +@pytest.fixture +def two_cell_ds(): + """ + Two separate cells of different sizes. + """ + data = np.array([ + [50, 50, 0, 0, 0], + [50, 50, 0, 30, 30], + [ 0, 0, 0, 30, 30], + [ 0, 0, 0, 0, 0], + ], dtype=np.float32) + + return xr.Dataset( + {"reflectivity": (("y", "x"), data)}, + coords={"y": range(4), "x": range(5)}, + ) + + +# This is for testing segmentation with multiple cells and closing operations +@pytest.fixture +def large_multi_cell_ds(): + """ + Larger domain with multiple well-separated cells. + No closing should keep all separate.
+ """ + data = np.zeros((10, 10), dtype=np.float32) + + # Cell 1 (top-left) + data[1:3, 1:3] = 45 + + # Cell 2 (top-right) + data[1:3, 7:9] = 50 + + # Cell 3 (bottom-left) + data[7:9, 1:3] = 55 + + # Cell 4 (bottom-right) + data[7:9, 7:9] = 60 + + return xr.Dataset( + {"reflectivity": (("y", "x"), data)}, + coords={"y": range(10), "x": range(10)}, + attrs={"z_level_m": 2000}, + ) + + +@pytest.fixture +def close_cells_ds(): + """ + Two nearby cells separated by a 1-pixel gap. + Closing (2,2) should merge them. + """ + data = np.zeros((6, 6), dtype=np.float32) + + # Cell A + data[2:4, 1:3] = 40 + + # 1-pixel gap + + # Cell B + data[2:4, 4:6] = 40 + + return xr.Dataset( + {"reflectivity": (("y", "x"), data)}, + coords={"y": range(6), "x": range(6)}, + ) + + +# For testing motion projection + +@pytest.fixture +def simple_labeled_ds_pair(): + """ + Two small 2D datasets with: + - reflectivity + - cell_labels + - valid time coordinate + Zero motion between frames. + """ + data = np.array([ + [0, 40, 40, 0], + [0, 40, 40, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + ], dtype=np.float32) + + labels = np.array([ + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + ], dtype=np.int32) + + t0 = np.datetime64("2024-01-01T00:00") + t1 = np.datetime64("2024-01-01T00:05") + + ds1 = xr.Dataset( + { + "reflectivity": (("y", "x"), data), + "cell_labels": (("y", "x"), labels), + }, + coords={"y": range(4), "x": range(4)}, + ) + ds1 = ds1.assign_coords(time=t0) + + ds2 = xr.Dataset( + { + "reflectivity": (("y", "x"), data), + "cell_labels": (("y", "x"), labels), + }, + coords={"y": range(4), "x": range(4)}, + ) + ds2 = ds2.assign_coords(time=t1) + + return [ds1, ds2] + + +# Analyzer fixtures +@pytest.fixture +def labeled_ds_with_extras(simple_2d_ds): + """ + 2D dataset with: + - cell_labels + - reflectivity + - heading vectors + - projections + """ + ds = simple_2d_ds.copy() + + labels = np.array([ + [0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0], + ], 
dtype=np.int32) + + ds["cell_labels"] = (("y", "x"), labels) + + ds["heading_x"] = (("y", "x"), np.ones_like(labels, dtype=np.float32)) + ds["heading_y"] = (("y", "x"), np.zeros_like(labels, dtype=np.float32)) + ds["differential_reflectivity"] = (("y", "x"), np.full_like(labels, 1.0, dtype=np.float32)) + + projections = np.stack([labels, labels], axis=0) + ds["cell_projections"] = ( + ("frame_offset", "y", "x"), + projections + ) + + ds = ds.assign_coords(frame_offset=[0, 1]) + ds = ds.assign_coords(time=np.datetime64("2024-01-01T00:00")) + + return ds + + +@pytest.fixture +def temp_dir(): + d = tempfile.mkdtemp() + yield Path(d) + shutil.rmtree(d) + + +@pytest.fixture +def radar_config(temp_dir) -> InternalConfig: + """InternalConfig for radar module tests.""" + param = ParamConfig() + user = UserConfig(base_dir=str(temp_dir), radar="TEST_RADAR") + return resolve_config(param, user, None) + + +@pytest.fixture +def ingest_module_config_from_radar(radar_config): + """IngestModuleConfig derived from radar_config.""" + return materialize_module_configs(radar_config)["ingest_config"] + + +@pytest.fixture +def radar_output_dirs(temp_dir): + """Output directories for radar tests. + + Returns dict with 'base' and 'logs' from setup_output_directories, + plus backward-compatible keys that point to base for legacy tests. 
+ """ + dirs = setup_output_directories(temp_dir) + # Add legacy keys for backward compatibility in tests + # These point to base since the actual paths are now under RADAR_ID/ + dirs["nexrad"] = dirs["base"] + dirs["gridnc"] = dirs["base"] + dirs["analysis"] = dirs["base"] + dirs["plots"] = dirs["base"] + return dirs + + diff --git a/tests/runtime/__init__.py b/tests/runtime/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/runtime/conftest.py b/tests/runtime/conftest.py new file mode 100644 index 0000000..aea4552 --- /dev/null +++ b/tests/runtime/conftest.py @@ -0,0 +1,80 @@ +import queue +import shutil +import tempfile +from pathlib import Path + +import pytest + +from adapt.configuration.schemas.directories import setup_output_directories +from adapt.configuration.schemas.internal import InternalConfig +from adapt.configuration.schemas.param import ParamConfig +from adapt.configuration.schemas.resolve import resolve_config +from adapt.configuration.schemas.user import UserConfig +from adapt.persistence import DataRepository +from adapt.runtime.file_tracker import FileProcessingTracker + + +@pytest.fixture +def temp_dir(): + d = tempfile.mkdtemp() + yield Path(d) + shutil.rmtree(d) + + +@pytest.fixture +def tracker(temp_dir): + db_path = temp_dir / "tracker.db" + return FileProcessingTracker(db_path) + + +@pytest.fixture +def pipeline_config(temp_dir) -> InternalConfig: + """InternalConfig for pipeline tests.""" + param = ParamConfig() + # For tests, provide defaults since radar_id and base_dir are required at runtime + user = UserConfig( + radar="TEST_RADAR", + base_dir=str(temp_dir) + ) + config_dict = resolve_config(param, user, None).model_dump() + + # Add required fields for new architecture + output_dirs = setup_output_directories(str(temp_dir)) + config_dict["output_dirs"] = {k: str(v) for k, v in output_dirs.items()} + config_dict["run_id"] = DataRepository.generate_run_id("TEST") + + return 
InternalConfig.model_validate(config_dict) + + +@pytest.fixture +def pipeline_output_dirs(temp_dir): + """Output directories for pipeline tests. + + Returns dict with 'base' and 'logs' from setup_output_directories, + plus backward-compatible keys that point to base for legacy tests. + """ + dirs = setup_output_directories(temp_dir) + # Add legacy keys for backward compatibility in tests + dirs["nexrad"] = dirs["base"] + dirs["gridnc"] = dirs["base"] + dirs["analysis"] = dirs["base"] + dirs["plots"] = dirs["base"] + return dirs + + +# made for processor tests +@pytest.fixture +def processor_queues(): + return queue.Queue(), queue.Queue() + + +@pytest.fixture +def test_repository(temp_dir): + """DataRepository for processor tests.""" + run_id = DataRepository.generate_run_id("TEST") + return DataRepository( + run_id=run_id, + base_dir=temp_dir, + radar="TEST_RADAR" + ) + diff --git a/tests/runtime/test_file_tracker.py b/tests/runtime/test_file_tracker.py new file mode 100644 index 0000000..3a69088 --- /dev/null +++ b/tests/runtime/test_file_tracker.py @@ -0,0 +1,158 @@ +from datetime import UTC, datetime + +import pytest + +pytestmark = [pytest.mark.unit, pytest.mark.pipeline] + + +def test_register_and_fetch_file(tracker): + file_id = "KTEST_0001" + radar_id = "KTEST" + scan_time = datetime.now(UTC) + + created = tracker.register_file(file_id, radar_id, scan_time) + assert created is True + + status = tracker.get_file_status(file_id) + assert status["file_id"] == file_id + assert status["radar"] == radar_id + assert status["status"] == "pending" + + +def test_register_duplicate_is_noop(tracker): + file_id = "KTEST_0002" + tracker.register_file(file_id, "KTEST", datetime.now(UTC)) + created = tracker.register_file(file_id, "KTEST", datetime.now(UTC)) + assert created is False + + +def test_stage_progression(tracker, tmp_path): + file_id = "KTEST_0003" + tracker.register_file(file_id, "KTEST", datetime.now(UTC)) + + nc = tmp_path / "grid.nc" + 
tracker.mark_stage_complete(file_id, "regridded", path=nc) + + status = tracker.get_file_status(file_id) + assert status["regridded_at"] is not None + assert status["gridnc_path"] == str(nc) + assert status["status"] == "processing" + + +def test_mark_analyzed_sets_cells(tracker): + file_id = "KTEST_0004" + tracker.register_file(file_id, "KTEST", datetime.now(UTC)) + + tracker.mark_stage_complete(file_id, "analyzed", num_cells=7) + status = tracker.get_file_status(file_id) + + assert status["num_cells"] == 7 + assert status["status"] == "processing" + + +def test_mark_plotted_completes(tracker): + file_id = "KTEST_0005" + tracker.register_file(file_id, "KTEST", datetime.now(UTC)) + + tracker.mark_stage_complete(file_id, "plotted") + status = tracker.get_file_status(file_id) + + assert status["status"] == "completed" + + +def test_failure_sets_failed(tracker): + file_id = "KTEST_0006" + tracker.register_file(file_id, "KTEST", datetime.now(UTC)) + + tracker.mark_stage_complete(file_id, "analyzed", error="boom") + status = tracker.get_file_status(file_id) + + assert status["status"] == "failed" + assert "boom" in status["error_message"] + + +def test_should_process_logic(tracker): + file_id = "KTEST_0007" + tracker.register_file(file_id, "KTEST", datetime.now(UTC)) + + assert tracker.should_process(file_id, "analyzed") is True + + tracker.mark_stage_complete(file_id, "analyzed") + + +def test_get_pending_files(tracker): + """Test getting pending files.""" + file_id1 = "KTEST_0008" + file_id2 = "KTEST_0009" + + tracker.register_file(file_id1, "KTEST", datetime.now(UTC)) + tracker.register_file(file_id2, "KTEST", datetime.now(UTC)) + tracker.mark_stage_complete(file_id1, "downloaded") + + pending = tracker.get_pending_files() + assert any(f["file_id"] == file_id2 for f in pending) + + +def test_get_statistics(tracker): + """Test tracker statistics retrieval.""" + tracker.register_file("KTEST_0010", "KTEST", datetime.now(UTC)) + tracker.register_file("KTEST_0011", 
"KTEST", datetime.now(UTC)) + tracker.mark_stage_complete("KTEST_0010", "downloaded") + + stats = tracker.get_statistics() + assert stats["total"] >= 2 + assert stats["pending"] >= 1 + assert stats["processing"] >= 1 + + +def test_reset_failed_files(tracker): + """Test resetting failed files.""" + file_id = "KTEST_0012" + tracker.register_file(file_id, "KTEST", datetime.now(UTC)) + tracker.mark_stage_complete(file_id, "analyzed", error="Download error") + + tracker.reset_failed("KTEST") + status = tracker.get_file_status(file_id) + assert status["status"] == "pending" + + +def test_mark_multiple_stages(tracker, tmp_path): + """Test marking file through multiple stages.""" + file_id = "KTEST_0016" + tracker.register_file(file_id, "KTEST", datetime.now(UTC)) + + tracker.mark_stage_complete(file_id, "downloaded") + assert tracker.get_file_status(file_id)["status"] == "processing" + + nc = tmp_path / "grid.nc" + tracker.mark_stage_complete(file_id, "regridded", path=nc) + assert tracker.get_file_status(file_id)["gridnc_path"] == str(nc) + + tracker.mark_stage_complete(file_id, "analyzed", num_cells=5) + assert tracker.get_file_status(file_id)["num_cells"] == 5 + + tracker.mark_stage_complete(file_id, "plotted") + assert tracker.get_file_status(file_id)["status"] == "completed" + + +def test_cleanup_deleted_files(tracker): + """Test cleanup of deleted files.""" + file_id = "KTEST_0017" + tracker.register_file(file_id, "KTEST", datetime.now(UTC)) + + # This method should not raise + tracker.cleanup_deleted_files("KTEST") + + +def test_get_statistics_by_radar(tracker): + """Test getting statistics for specific radar.""" + tracker.register_file("KTEST_0018", "KTEST", datetime.now(UTC)) + tracker.register_file("KMOB_0001", "KMOB", datetime.now(UTC)) + + stats = tracker.get_statistics(radar="KTEST") + assert stats["total"] >= 1 + + pending = tracker.get_pending_files(radar="KTEST") + assert any(f["file_id"] == "KTEST_0018" for f in pending) + # Newly registered file should 
need processing for 'analyzed' (no timestamps set) + assert tracker.should_process("KTEST_0018", "analyzed") is True diff --git a/tests/runtime/test_orchestrator.py b/tests/runtime/test_orchestrator.py new file mode 100644 index 0000000..7fdf53c --- /dev/null +++ b/tests/runtime/test_orchestrator.py @@ -0,0 +1,109 @@ +import pytest + +from adapt.runtime.orchestrator import PipelineOrchestrator + +pytestmark = [pytest.mark.unit, pytest.mark.pipeline] + + +def test_orchestrator_initialization(pipeline_config): + """Orchestrator initializes with config.""" + orch = PipelineOrchestrator(pipeline_config) + assert orch.downloader_queue is not None + + +def test_orchestrator_logging_and_tracker(pipeline_config): + """Orchestrator sets up logging and file tracker.""" + orch = PipelineOrchestrator(pipeline_config) + orch._setup_logging() + + assert orch.tracker is not None + + +def test_orchestrator_queue_wiring(pipeline_config): + """Orchestrator creates queues with correct size limits.""" + orch = PipelineOrchestrator(pipeline_config) + + assert orch.downloader_queue.maxsize == 20 + + +def test_orchestrator_stop_is_idempotent(pipeline_config): + """Calling stop() multiple times is safe.""" + orch = PipelineOrchestrator(pipeline_config) + + orch.stop() + orch.stop() # should not raise + + +def test_orchestrator_has_stop_event(pipeline_config): + """Test orchestrator has internal stop flag and stop() sets it.""" + orch = PipelineOrchestrator(pipeline_config) + + assert hasattr(orch, '_stop_event') + assert orch._stop_event is False + + orch.stop() + assert orch._stop_event is True + + +def test_orchestrator_config_storage(pipeline_config): + """Test orchestrator stores config correctly.""" + orch = PipelineOrchestrator(pipeline_config) + + assert orch.config == pipeline_config + assert orch.output_dirs is not None # Should be extracted from config + + +def test_orchestrator_queue_types(pipeline_config): + """Test orchestrator creates correct queue types.""" + import queue + 
orch = PipelineOrchestrator(pipeline_config) + + assert isinstance(orch.downloader_queue, queue.Queue) + + +def test_orchestrator_tracker_database_path(pipeline_config): + """Test orchestrator creates tracker with correct database path (after setup).""" + orch = PipelineOrchestrator(pipeline_config) + orch._setup_logging() + + assert orch.tracker is not None + # Database should be in RADAR_ID/analysis/ directory + radar_id = pipeline_config.downloader.radar + expected_db = ( + orch.output_dirs["base"] / radar_id / "analysis" / f"{radar_id}_processing_tracker.db" + ) + assert expected_db.exists() + + +def test_orchestrator_mode_from_config(pipeline_config): + """Test orchestrator respects mode from config.""" + orch = PipelineOrchestrator(pipeline_config) + + # Should use mode from config + assert orch.config.mode in ["realtime", "historical"] + + +def test_orchestrator_stop_clears_queues(pipeline_config): + """Test that stop() sets stop flag and is idempotent.""" + orch = PipelineOrchestrator(pipeline_config) + + # Add some items to queues + orch.downloader_queue.put("test1") + + orch.stop() + + # Internal stop flag should be set + assert orch._stop_event is True + + # Calling stop again should be safe + orch.stop() + assert orch._stop_event is True + + +def test_orchestrator_processor_config_accessible(pipeline_config): + """Test orchestrator can access processor config.""" + orch = PipelineOrchestrator(pipeline_config) + + assert hasattr(orch.config, 'processor') + assert orch.config.processor.max_history >= 0 + assert orch.config.processor.min_file_size > 0 diff --git a/tests/runtime/test_orchestrator_historical_shutdown.py b/tests/runtime/test_orchestrator_historical_shutdown.py new file mode 100644 index 0000000..ccffb93 --- /dev/null +++ b/tests/runtime/test_orchestrator_historical_shutdown.py @@ -0,0 +1,94 @@ +import pytest + +from adapt.runtime.orchestrator import PipelineOrchestrator + +pytestmark = [pytest.mark.unit, pytest.mark.pipeline] + + +class 
_FakeDownloader: + def __init__(self, complete: bool, alive: bool, processed: int = 0, expected: int = 0): + self._complete = complete + self._alive = alive + self._processed = processed + self._expected = expected + self.stop_called = False + + def is_historical_complete(self) -> bool: + return self._complete + + def is_alive(self) -> bool: + return self._alive + + def get_historical_progress(self): + return self._processed, self._expected + + def stop(self): + self.stop_called = True + + def join(self, timeout=None): + self._alive = False + + +class _FakeProcessor: + def __init__(self): + self._alive = True + self.stop_called = False + + def is_alive(self) -> bool: + return self._alive + + def stop(self): + self.stop_called = True + self._alive = False + + def join(self, timeout=None): + self._alive = False + + +class _FakeRepository: + def __init__(self): + self.finalized = False + self.closed = False + + def finalize_run(self, status: str): + self.finalized = True + + def close(self): + self.closed = True + + +def test_historical_complete_returns_true_and_stops_processor(pipeline_config): + pipeline_config = pipeline_config.model_copy(update={"mode": "historical"}) + orch = PipelineOrchestrator(pipeline_config) + orch.downloader = _FakeDownloader(complete=True, alive=False, processed=5, expected=5) + orch.processor = _FakeProcessor() + + done = orch._check_historical_complete() + + assert done is True + assert orch.downloader.stop_called is True + assert orch.processor.stop_called is True + + +def test_historical_not_complete_returns_false_when_downloader_dead(pipeline_config): + pipeline_config = pipeline_config.model_copy(update={"mode": "historical"}) + orch = PipelineOrchestrator(pipeline_config) + orch.downloader = _FakeDownloader(complete=False, alive=False) + + done = orch._check_historical_complete() + + assert done is False + + +def test_stop_skips_repository_close_when_owned_externally(pipeline_config): + orch = PipelineOrchestrator(pipeline_config, 
close_repository_on_stop=False) + repo = _FakeRepository() + orch.repository = repo + + orch.stop() + + assert repo.finalized is True + assert repo.closed is False + + orch.close_repository() + assert repo.closed is True diff --git a/tests/runtime/test_processor_core.py b/tests/runtime/test_processor_core.py new file mode 100644 index 0000000..87d3bd5 --- /dev/null +++ b/tests/runtime/test_processor_core.py @@ -0,0 +1,111 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Tests for RadarProcessor graph-based processing. + +The processor delegates scientific work to two GraphExecutors built at startup: +_single_executor (ingest + detection) and _multi_executor (projection + analysis + tracking). +These tests verify the orchestration layer: initialization, stop/start lifecycle. +""" + +import queue + +import pandas as pd +import pytest + +from adapt.execution.graph.executor import GraphExecutor +from adapt.runtime.processor import RadarProcessor + +pytestmark = [pytest.mark.unit, pytest.mark.pipeline] + + +def _make_proc(pipeline_config, pipeline_output_dirs, test_repository): + return RadarProcessor( + queue.Queue(), pipeline_config, pipeline_output_dirs, + repository=test_repository, + ) + + +def test_processor_initializes_with_two_executors( + pipeline_config, pipeline_output_dirs, test_repository +): + """Processor creates two GraphExecutors (single-frame and multi-frame) on init.""" + proc = _make_proc(pipeline_config, pipeline_output_dirs, test_repository) + assert isinstance(proc._single_executor, GraphExecutor) + assert isinstance(proc._multi_executor, GraphExecutor) + + +def test_single_executor_contains_ingest_and_detection( + pipeline_config, pipeline_output_dirs, test_repository +): + """_single_executor graph covers exactly the ingest and detection nodes.""" + proc = _make_proc(pipeline_config, pipeline_output_dirs, test_repository) + single_names = {n.name for n in proc._single_executor.nodes} + assert "ingest" 
in single_names + assert "detection" in single_names + + +def test_multi_executor_contains_projection_analysis_tracking( + pipeline_config, pipeline_output_dirs, test_repository +): + """_multi_executor graph covers projection, analysis, and tracking nodes.""" + proc = _make_proc(pipeline_config, pipeline_output_dirs, test_repository) + multi_names = {n.name for n in proc._multi_executor.nodes} + assert "projection" in multi_names + assert "analysis" in multi_names + assert "tracking" in multi_names + + +def test_processor_stop_sets_flag( + pipeline_config, pipeline_output_dirs, test_repository +): + """stop() signals the run loop to exit.""" + proc = _make_proc(pipeline_config, pipeline_output_dirs, test_repository) + assert not proc.stopped() + proc.stop() + assert proc.stopped() + + +def test_processor_stop_is_idempotent( + pipeline_config, pipeline_output_dirs, test_repository +): + """Calling stop() twice is safe.""" + proc = _make_proc(pipeline_config, pipeline_output_dirs, test_repository) + proc.stop() + proc.stop() + assert proc.stopped() + + +def test_processor_get_results_returns_empty_dataframe( + pipeline_config, pipeline_output_dirs, test_repository +): + """get_results() returns an empty DataFrame — results live in the repository.""" + proc = _make_proc(pipeline_config, pipeline_output_dirs, test_repository) + result = proc.get_results() + assert isinstance(result, pd.DataFrame) + assert result.empty + + +def test_processor_save_results_returns_none( + pipeline_config, pipeline_output_dirs, test_repository +): + """save_results() is a no-op; persistence is handled by RepositoryWriter.""" + proc = _make_proc(pipeline_config, pipeline_output_dirs, test_repository) + result = proc.save_results() + assert result is None + + +def test_processor_close_database_returns_none( + pipeline_config, pipeline_output_dirs, test_repository +): + """close_database() is a no-op; the repository owns its own lifecycle.""" + proc = _make_proc(pipeline_config, 
pipeline_output_dirs, test_repository) + result = proc.close_database() + assert result is None + + +def test_processor_requires_repository(pipeline_config, pipeline_output_dirs): + """RadarProcessor raises ValueError when repository is None.""" + with pytest.raises(ValueError, match="DataRepository is required"): + RadarProcessor(queue.Queue(), pipeline_config, pipeline_output_dirs, + repository=None) diff --git a/tests/runtime/test_processor_failures.py b/tests/runtime/test_processor_failures.py new file mode 100644 index 0000000..17b1afb --- /dev/null +++ b/tests/runtime/test_processor_failures.py @@ -0,0 +1,150 @@ +"""Tests for RadarProcessor error handling and success paths. + +The processor orchestrates ingest+detection via _single_executor and +projection+analysis+tracking via _multi_executor. These tests patch the +executors to keep the focus on orchestration rather than scientific behavior. +""" + +import queue +from datetime import UTC, datetime + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from adapt.contracts import ContractViolation +from adapt.runtime.processor import RadarProcessor + +pytestmark = [pytest.mark.unit, pytest.mark.pipeline] + + +def _make_proc(pipeline_config, pipeline_output_dirs, test_repository): + q = queue.Queue() + return RadarProcessor(q, pipeline_config, pipeline_output_dirs, + repository=test_repository) + + +def _fake_ds(): + return xr.Dataset( + { + "reflectivity": (("y", "x"), np.ones((4, 4))), + "cell_labels": (("y", "x"), np.zeros((4, 4), dtype=int)), + }, + coords={"x": np.arange(4), "y": np.arange(4)}, + attrs={"z_level_m": 2000}, + ) + + +def _fake_single_result(scan_time): + """Return what _single_executor.run() would produce.""" + return { + "grid_ds": _fake_ds(), + "grid_ds_2d": _fake_ds(), + "segmented_ds": _fake_ds(), + "scan_time": scan_time, + "num_cells": 0, + } + + +# ── Error paths ─────────────────────────────────────────────────────────────── + +def 
test_process_file_pipeline_exception_returns_false( + monkeypatch, pipeline_config, pipeline_output_dirs, test_repository +): + """process_file returns False when single-frame executor raises.""" + proc = _make_proc(pipeline_config, pipeline_output_dirs, test_repository) + + def _boom(context): + raise OSError("disk failure") + + monkeypatch.setattr(proc._single_executor, "run", _boom) + + ok = proc.process_file("/fake/path/file") + assert ok is False + + +def test_process_file_contract_violation_stops_processor( + monkeypatch, pipeline_config, pipeline_output_dirs, test_repository +): + """ContractViolation during multi-frame executor causes processor to stop.""" + proc = _make_proc(pipeline_config, pipeline_output_dirs, test_repository) + + scan_times = [ + datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC), + datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC), + ] + + def _fake_single(context): + return _fake_single_result(scan_times.pop(0)) + + def _boom_multi(context): + raise ContractViolation("bad grid") + + monkeypatch.setattr(proc._single_executor, "run", _fake_single) + monkeypatch.setattr(proc._multi_executor, "run", _boom_multi) + + ok1 = proc.process_file("/fake/path/file_1") + ok2 = proc.process_file("/fake/path/file_2") + assert ok1 is True + assert ok2 is False + assert proc.stopped() + + +# ── Success path ────────────────────────────────────────────────────────────── + +def test_process_file_success_saves_netcdf_and_returns_true( + monkeypatch, pipeline_config, pipeline_output_dirs, test_repository +): + """process_file returns True and attempts NetCDF save on success.""" + proc = _make_proc(pipeline_config, pipeline_output_dirs, test_repository) + + scan_times = [ + datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC), + datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC), + ] + + def _fake_single(context): + return _fake_single_result(scan_times.pop(0)) + + fake_result = { + "projected_ds": _fake_ds(), + "scan_time": datetime(2024, 1, 1, 12, 5, 0, tzinfo=UTC), + "cell_stats": 
pd.DataFrame({"cell_label": [1]}), + "cell_adjacency": pd.DataFrame(), + } + + monkeypatch.setattr(proc._single_executor, "run", _fake_single) + monkeypatch.setattr(proc._multi_executor, "run", lambda ctx: fake_result) + + saved = [] + monkeypatch.setattr(proc, "_save_analysis_netcdf", + lambda ds, fp, st: saved.append(fp) or "/tmp/out.nc") + monkeypatch.setattr(proc, "_save_results", lambda result, st: None) + + ok1 = proc.process_file("/fake/path/file_1") + ok2 = proc.process_file("/fake/path/file_2") + assert ok1 is True + assert ok2 is True + + +def test_process_file_skips_already_analyzed( + monkeypatch, pipeline_config, pipeline_output_dirs, test_repository +): + """process_file skips a file that the tracker marks as done.""" + proc = _make_proc(pipeline_config, pipeline_output_dirs, test_repository) + + class _FakeTracker: + def should_process(self, file_id, stage): + return False + + proc.file_tracker = _FakeTracker() + called = [] + monkeypatch.setattr( + proc._single_executor, "run", + lambda ctx: called.append(1) or _fake_single_result(datetime.now(UTC)), + ) + + ok = proc.process_file("/fake/path/file") + assert ok is True + assert called == [] # single executor was NOT called diff --git a/tests/runtime/test_processor_with_fake_grid.py b/tests/runtime/test_processor_with_fake_grid.py new file mode 100644 index 0000000..19e7b18 --- /dev/null +++ b/tests/runtime/test_processor_with_fake_grid.py @@ -0,0 +1,67 @@ +"""Processor orchestration test using fake module outputs. + +The processor runs a single-frame executor (ingest+detection) to build a +2-frame history, then runs the multi-frame executor (projection+analysis+tracking). +These tests patch the executors to avoid touching real scientific modules. 
+"""
+
+import queue
+from datetime import UTC, datetime
+
+import numpy as np
+import pandas as pd
+import pytest
+import xarray as xr
+
+from adapt.runtime.processor import RadarProcessor
+
+pytestmark = [pytest.mark.unit, pytest.mark.pipeline]
+
+
+def _fake_ds():
+    """Return a 4x4 dataset with unit reflectivity and all-zero cell labels."""
+    return xr.Dataset(
+        {
+            "reflectivity": (("y", "x"), np.ones((4, 4))),
+            "cell_labels": (("y", "x"), np.zeros((4, 4), dtype=int)),
+        },
+        coords={"x": np.arange(4), "y": np.arange(4)},
+        attrs={"z_level_m": 2000},
+    )
+
+
+def test_processor_accepts_fake_grid(
+    tmp_path, monkeypatch, pipeline_config, pipeline_output_dirs, test_repository
+):
+    """Two consecutive files are processed successfully with patched executors."""
+    proc = RadarProcessor(
+        queue.Queue(), pipeline_config, pipeline_output_dirs, repository=test_repository
+    )
+
+    # Scan times handed out to the two frames, five minutes apart.
+    scan_times = [datetime(2024, 5, 18, 12, minute, 0, tzinfo=UTC) for minute in (0, 5)]
+
+    # Stand-in for the single-frame executor: one canned result per call.
+    def _fake_single(context):
+        return {
+            "grid_ds": _fake_ds(),
+            "grid_ds_2d": _fake_ds(),
+            "segmented_ds": _fake_ds(),
+            "scan_time": scan_times.pop(0),
+            "num_cells": 0,
+        }
+
+    # Canned result returned by the patched multi-frame executor.
+    fake_multi_result = {
+        "projected_ds": _fake_ds(),
+        "cell_stats": pd.DataFrame(),
+        "cell_adjacency": pd.DataFrame(),
+    }
+
+    monkeypatch.setattr(proc._single_executor, "run", _fake_single)
+    monkeypatch.setattr(proc._multi_executor, "run", lambda ctx: fake_multi_result)
+    monkeypatch.setattr(proc, "_save_results", lambda result, st: None)
+
+    # Process both files first, then check every outcome is exactly True.
+    results = [proc.process_file(path) for path in ("/fake/file_1", "/fake/file_2")]
+    assert all(ok is True for ok in results)
diff --git a/tests/test_architecture.py b/tests/test_architecture.py
new file mode 100644
index 0000000..369474f
--- /dev/null
+++ b/tests/test_architecture.py
@@ -0,0 +1,120 @@
+# Copyright © 2026, UChicago Argonne, LLC
+# See LICENSE for terms and disclaimer.
+
+"""Architecture tests: enforce module-independence without hardcoding module names.
+
+These tests discover adapt.modules subpackages at runtime and verify that
+no scientific module imports from any other scientific module. New modules
+are picked up automatically — no test edits required.
+
+Run: pytest tests/test_architecture.py
+"""
+
+import ast
+import importlib
+import pkgutil
+from pathlib import Path
+
+import pytest
+
+# Skip the entire file gracefully if adapt is not installed in this environment.
+# This prevents VSCode pytest discovery errors when the wrong interpreter is active.
+adapt_modules = pytest.importorskip(
+    "adapt.modules",
+    reason="adapt not installed in this Python environment — activate adapt_env",
+)
+
+
+def _discover_module_packages() -> list[str]:
+    """Return the dotted names of all immediate subpackages of adapt.modules."""
+    return [
+        f"adapt.modules.{info.name}"
+        for info in pkgutil.iter_modules(adapt_modules.__path__)
+        if info.ispkg
+    ]
+
+
+def _source_files(package_name: str) -> list[Path]:
+    """Return all .py files under a package's directory, in sorted order."""
+    mod = importlib.import_module(package_name)
+    pkg_dir = Path(mod.__file__).parent
+    return sorted(pkg_dir.rglob("*.py"))
+
+
+def _imported_adapt_modules(py_file: Path) -> set[str]:
+    """Parse a .py file and return the set of adapt.modules.* names it imports.
+
+    Files that fail to parse are skipped (empty set) rather than failing the test.
+    """
+    try:
+        tree = ast.parse(py_file.read_text(encoding="utf-8"))
+    except SyntaxError:
+        return set()
+
+    imports: set[str] = set()
+    for node in ast.walk(tree):
+        if isinstance(node, ast.Import):
+            imports.update(a.name for a in node.names)
+        elif isinstance(node, ast.ImportFrom) and node.module == "adapt.modules":
+            # "from adapt.modules import x": the imported name is in the alias.
+            imports.update(f"adapt.modules.{a.name}" for a in node.names)
+        elif isinstance(node, ast.ImportFrom) and node.module:
+            imports.add(node.module)
+    return {name for name in imports if name.startswith("adapt.modules.")}
+
+
+# Build the test matrix at collection time — works for any future module.
+_PACKAGES = _discover_module_packages() +_SKIP = {"adapt.modules.base"} # base.py is shared infrastructure, not a science module + + +@pytest.mark.parametrize("pkg", [p for p in _PACKAGES if p not in _SKIP]) +def test_module_does_not_import_other_modules(pkg: str) -> None: + """Scientific module must not import from any other adapt.modules subpackage. + + This test is parameterised over every subpackage discovered under adapt.modules. + Adding a new module directory makes it appear here automatically. + """ + files = _source_files(pkg) + violations: list[str] = [] + + for py_file in files: + for imported in _imported_adapt_modules(py_file): + # Allow self-imports (within the same subpackage) + if not imported.startswith(pkg): + violations.append(f" {py_file.name}: imports {imported!r}") + + assert not violations, ( + f"\n{pkg} imports from other scientific modules — " + "shared types belong in adapt.contracts:\n" + + "\n".join(violations) + ) + + +@pytest.mark.parametrize("pkg", [p for p in _PACKAGES if p not in _SKIP]) +def test_module_does_not_import_execution_or_runtime(pkg: str) -> None: + """Scientific module must not import from adapt.execution or adapt.runtime.""" + forbidden_prefixes = ("adapt.execution", "adapt.runtime", "adapt.persistence") + files = _source_files(pkg) + violations: list[str] = [] + + for py_file in files: + try: + tree = ast.parse(py_file.read_text()) + except SyntaxError: + continue + for node in ast.walk(tree): + names: list[str] = [] + if isinstance(node, ast.Import): + names = [a.name for a in node.names] + elif isinstance(node, ast.ImportFrom) and node.module: + names = [node.module] + for name in names: + if any(name.startswith(p) for p in forbidden_prefixes): + violations.append(f" {py_file.name}: imports {name!r}") + + assert not violations, ( + f"\n{pkg} imports from layers above it — " + "modules must only depend on contracts/ and utils/:\n" + + "\n".join(violations) + ) From fdf0c8ecadc2fab4a22fd5b6b00711c7f3853610 Mon Sep 17 
00:00:00 2001 From: Bhupendra Raut Date: Thu, 14 May 2026 14:51:09 -0500 Subject: [PATCH 02/14] ADD:(tests) configuration CLI config, resolution, setup directories, userconfig normalization --- tests/configuration/__init__.py | 0 tests/configuration/test_cli_config.py | 115 ++++ tests/configuration/test_cli_precedence.py | 78 +++ tests/configuration/test_config_resolution.py | 491 ++++++++++++++++++ .../test_initialization_run_id.py | 134 +++++ tests/configuration/test_rerun_cleanup.py | 55 ++ tests/configuration/test_setup_directories.py | 37 ++ .../test_setup_directories_extended.py | 62 +++ .../test_userconfig_normalization.py | 27 + 9 files changed, 999 insertions(+) create mode 100644 tests/configuration/__init__.py create mode 100644 tests/configuration/test_cli_config.py create mode 100644 tests/configuration/test_cli_precedence.py create mode 100644 tests/configuration/test_config_resolution.py create mode 100644 tests/configuration/test_initialization_run_id.py create mode 100644 tests/configuration/test_rerun_cleanup.py create mode 100644 tests/configuration/test_setup_directories.py create mode 100644 tests/configuration/test_setup_directories_extended.py create mode 100644 tests/configuration/test_userconfig_normalization.py diff --git a/tests/configuration/__init__.py b/tests/configuration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/configuration/test_cli_config.py b/tests/configuration/test_cli_config.py new file mode 100644 index 0000000..d78fa75 --- /dev/null +++ b/tests/configuration/test_cli_config.py @@ -0,0 +1,115 @@ +"""Tests for CLIConfig schema and conversion to internal overrides.""" + +from adapt.configuration.schemas.cli import CLIConfig + + +def test_cli_to_internal_overrides_with_mode(): + """Test CLI config conversion with mode override.""" + cli = CLIConfig(mode="historical") + overrides = cli.to_internal_overrides() + assert overrides["mode"] == "historical" + + +def 
test_cli_to_internal_overrides_with_realtime_mode(): + """Test CLI config conversion with realtime mode.""" + cli = CLIConfig(mode="realtime") + overrides = cli.to_internal_overrides() + assert overrides["mode"] == "realtime" + + +def test_cli_to_internal_overrides_with_radar_id(): + """Test CLI config conversion with radar_id override.""" + cli = CLIConfig(radar="KMOB") + overrides = cli.to_internal_overrides() + assert overrides["downloader"]["radar"] == "KMOB" + + +def test_cli_to_internal_overrides_with_log_level(): + """Test CLI config conversion with log_level override.""" + cli = CLIConfig(log_level="DEBUG") + overrides = cli.to_internal_overrides() + assert overrides["logging"]["level"] == "DEBUG" + + +def test_cli_to_internal_overrides_with_multiple_fields(): + """Test CLI config conversion with multiple overrides.""" + cli = CLIConfig(mode="historical", radar="KHTX", log_level="INFO") + overrides = cli.to_internal_overrides() + assert overrides["mode"] == "historical" + assert overrides["downloader"]["radar"] == "KHTX" + assert overrides["logging"]["level"] == "INFO" + + +def test_cli_to_internal_overrides_empty(): + """Test CLI config conversion with no overrides.""" + cli = CLIConfig() + overrides = cli.to_internal_overrides() + assert overrides == {} + + +def test_cli_config_accepts_base_dir(): + """Test that base_dir is accepted and in overrides.""" + cli = CLIConfig(base_dir="/path/to/output") + assert cli.base_dir == "/path/to/output" + overrides = cli.to_internal_overrides() + assert overrides["base_dir"] == "/path/to/output" + +def test_cli_config_all_log_levels(): + """Test all valid log levels.""" + for level in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: + cli = CLIConfig(log_level=level) + overrides = cli.to_internal_overrides() + assert overrides["logging"]["level"] == level + + +def test_cli_infers_historical_mode_from_start_time(): + """CLI automatically sets mode=historical if start_time provided without explicit mode.""" + cli = 
CLIConfig(start_time="2024-01-01T00:00:00Z") + assert cli.mode == "historical" + overrides = cli.to_internal_overrides() + assert overrides["mode"] == "historical" + + +def test_cli_infers_historical_mode_from_end_time(): + """CLI automatically sets mode=historical if end_time provided without explicit mode.""" + cli = CLIConfig(end_time="2024-01-01T23:59:59Z") + assert cli.mode == "historical" + overrides = cli.to_internal_overrides() + assert overrides["mode"] == "historical" + + +def test_cli_infers_historical_mode_from_both_times(): + """CLI automatically sets mode=historical if both times provided without explicit mode.""" + cli = CLIConfig( + start_time="2024-01-01T00:00:00Z", + end_time="2024-01-01T23:59:59Z" + ) + assert cli.mode == "historical" + overrides = cli.to_internal_overrides() + assert overrides["mode"] == "historical" + assert overrides["downloader"]["start_time"] == "2024-01-01T00:00:00Z" + assert overrides["downloader"]["end_time"] == "2024-01-01T23:59:59Z" + + +def test_cli_explicit_mode_overrides_time_inference(): + """Explicit mode in CLI is not overridden by time inference.""" + cli = CLIConfig( + mode="realtime", + start_time="2024-01-01T00:00:00Z" + ) + # Explicit mode should be respected + assert cli.mode == "realtime" + + +def test_cli_time_fields_in_overrides(): + """Test that start_time and end_time are included in overrides.""" + cli = CLIConfig( + start_time="2024-01-01T00:00:00Z", + end_time="2024-01-01T23:59:59Z", + radar="KMOB" + ) + overrides = cli.to_internal_overrides() + assert overrides["downloader"]["start_time"] == "2024-01-01T00:00:00Z" + assert overrides["downloader"]["end_time"] == "2024-01-01T23:59:59Z" + assert overrides["downloader"]["radar"] == "KMOB" + diff --git a/tests/configuration/test_cli_precedence.py b/tests/configuration/test_cli_precedence.py new file mode 100644 index 0000000..113c005 --- /dev/null +++ b/tests/configuration/test_cli_precedence.py @@ -0,0 +1,78 @@ +from adapt.configuration.schemas.cli 
import CLIConfig +from adapt.configuration.schemas.param import ParamConfig +from adapt.configuration.schemas.resolve import resolve_config +from adapt.configuration.schemas.user import UserConfig + + +def test_cli_overrides_do_not_mutate_user(): + user = UserConfig.model_validate({"RADAR_ID": "KABC", "MODE": "realtime", "BASE_DIR": "/tmp"}) + + cli = CLIConfig.model_validate({"radar": "KHTX"}) + + internal = resolve_config(ParamConfig(), user, cli) + + # CLI should take precedence + assert internal.downloader.radar == "KHTX" + + # But the original user model should remain unchanged + assert user.radar == "KABC" + + +def test_cli_minimal_overrides_radar_id(): + """CLI radar_id override should work correctly.""" + user = UserConfig(base_dir="/tmp", radar="KABC") + cli = CLIConfig(radar="KHTX") + + config = resolve_config(ParamConfig(), user, cli) + + assert config.downloader.radar == "KHTX" # CLI wins + assert config.base_dir == "/tmp" # User value preserved + + +def test_cli_minimal_overrides_mode(): + """CLI mode override should work correctly.""" + user = UserConfig( + base_dir="/tmp", + radar="KABC", + mode="realtime", + start_time="2024-01-01T00:00:00Z", + end_time="2024-01-01T12:00:00Z" + ) + cli = CLIConfig(mode="historical") + + config = resolve_config(ParamConfig(), user, cli) + + assert config.mode == "historical" # CLI wins + assert config.downloader.radar == "KABC" # User value preserved + # Historical mode validation should pass since start/end times provided + + +def test_cli_precedence_no_user_config(): + """CLI should work even without UserConfig.""" + cli = CLIConfig(radar="KHTX", mode="realtime") + + # This will need minimal UserConfig for required fields + user = UserConfig(base_dir="/tmp") + config = resolve_config(ParamConfig(), user, cli) + + assert config.downloader.radar == "KHTX" + assert config.mode == "realtime" + + +def test_cli_only_overrides_specified_fields(): + """CLI should only override fields that are explicitly set.""" + user = 
UserConfig( + base_dir="/tmp", + radar="KABC", + mode="realtime", + threshold=35 + ) + + # CLI only sets radar_id + cli = CLIConfig(radar="KHTX") + + config = resolve_config(ParamConfig(), user, cli) + + assert config.downloader.radar == "KHTX" # CLI override + assert config.mode == "realtime" # User value preserved + assert config.segmenter.threshold == 35.0 # User value preserved diff --git a/tests/configuration/test_config_resolution.py b/tests/configuration/test_config_resolution.py new file mode 100644 index 0000000..1317e59 --- /dev/null +++ b/tests/configuration/test_config_resolution.py @@ -0,0 +1,491 @@ +"""Test config resolution and validation with Pydantic.""" + +import pytest + +from adapt.configuration.schemas.cli import CLIConfig +from adapt.configuration.schemas.param import ParamConfig +from adapt.configuration.schemas.resolve import deep_merge, resolve_config +from adapt.configuration.schemas.user import ( + UserAnalyzerConfig, + UserConfig, + UserProjectorConfig, + UserSegmenterConfig, +) + + +class TestConfigResolution: + """Test resolve_config() precedence and merging.""" + + def test_resolve_config_all_defaults(self): + """Resolving with no user/CLI overrides fails due to missing required fields.""" + import pytest + from pydantic import ValidationError + with pytest.raises(ValidationError): + resolve_config(ParamConfig(), None, None) + + def test_user_config_overrides_param_config(self): + """UserConfig values override ParamConfig defaults.""" + user = UserConfig(threshold=40, base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, None) + + assert config.segmenter.threshold == 40.0 + assert config.base_dir == "/tmp" + + def test_cli_config_with_valid_structure(self): + """CLIConfig structure validation (if implemented).""" + # CLIConfig currently has limited fields - test what exists + # Just verify it can be instantiated + CLIConfig() + + def test_precedence_param_user(self): + """Full precedence: User > Param.""" + 
param = ParamConfig() + user = UserConfig( + threshold=40, + radar="KDLH", + base_dir="/tmp" + ) + config = resolve_config(param, user, None) + + # User won on threshold + assert config.segmenter.threshold == 40.0 + # User won on radar_id + assert config.downloader.radar == "KDLH" + + def test_empty_user_config_uses_all_param_defaults(self): + """Empty UserConfig() doesn't override anything, fails if required fields missing.""" + import pytest + from pydantic import ValidationError + with pytest.raises(ValidationError): + resolve_config(ParamConfig(), UserConfig(), None) + + def test_none_user_config_uses_all_param_defaults(self): + """None UserConfig doesn't override anything, but still fails if required fields missing.""" + import pytest + from pydantic import ValidationError + with pytest.raises(ValidationError): + resolve_config(ParamConfig(), None, None) + + +class TestUserConfigAliases: + """Test UserConfig flat aliases map correctly.""" + + def test_threshold_alias(self): + """threshold flat alias maps to segmenter.threshold.""" + user = UserConfig(threshold=35, base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, None) + + assert config.segmenter.threshold == 35.0 + + def test_radar_id_alias(self): + """radar_id flat alias maps to downloader.radar_id.""" + user = UserConfig(radar="KDIX", base_dir="/tmp") + config = resolve_config(ParamConfig(), user, None) + + assert config.downloader.radar == "KDIX" + + def test_reflectivity_var_alias(self): + """reflectivity_var alias maps to global var_names.""" + user = UserConfig(reflectivity_var="dbz", base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, None) + + assert config.global_.var_names.reflectivity == "dbz" + + def test_max_projection_steps_alias(self): + """max_projection_steps alias maps to projector.max_projection_steps.""" + user = UserConfig(max_projection_steps=5, base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, None) + + 
assert config.projector.max_projection_steps == 5 + + def test_min_cellsize_gridpoint_alias(self): + """min_cellsize_gridpoint alias maps to segmenter.min_cellsize_gridpoint.""" + user = UserConfig(min_cellsize_gridpoint=10, base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, None) + + assert config.segmenter.min_cellsize_gridpoint == 10 + + def test_nested_segmenter_override(self): + """Nested segmenter config overrides flat alias.""" + user = UserConfig( + base_dir="/tmp", + radar="KHTX", + threshold=30, + segmenter=UserSegmenterConfig(threshold=40) + ) + config = resolve_config(ParamConfig(), user, None) + + # Nested should win + assert config.segmenter.threshold == 40.0 + + +class TestTypeCoercion: + """Test UserConfig type coercion.""" + + def test_int_coerced_to_float_for_threshold(self): + """Integer threshold is coerced to float.""" + user = UserConfig(base_dir="/tmp", radar="KHTX", threshold=35) # int + config = resolve_config(ParamConfig(), user, None) + + assert isinstance(config.segmenter.threshold, float) + assert config.segmenter.threshold == 35.0 + + def test_int_coerced_to_float_for_z_level(self): + """Integer z_level is coerced to float.""" + user = UserConfig(base_dir="/tmp", radar="KHTX", z_level=1500) # int + config = resolve_config(ParamConfig(), user, None) + + assert isinstance(config.global_.z_level, float) + assert config.global_.z_level == 1500.0 + + def test_method_normalized_to_lowercase(self): + """Method names are normalized to lowercase.""" + user = UserConfig(base_dir="/tmp", radar="KHTX", segmentation_method="THRESHOLD") + config = resolve_config(ParamConfig(), user, None) + + assert config.segmenter.method == "threshold" + + def test_uppercase_radar_id_preserved(self): + """Radar IDs are preserved in uppercase.""" + user = UserConfig(base_dir="/tmp", radar="KDIX") + config = resolve_config(ParamConfig(), user, None) + + assert config.downloader.radar == "KDIX" + + +class TestEdgeCases: + """Test config 
edge cases and error conditions.""" + + def test_none_values_dont_override(self): + """None values in UserConfig don't override ParamConfig.""" + user = UserConfig(threshold=None, radar="KDIX", base_dir="/tmp") + config = resolve_config(ParamConfig(), user, None) + + assert config.segmenter.threshold == 30.0 # default, not overridden + assert config.downloader.radar == "KDIX" + + def test_dict_user_config_accepted(self): + """Dict can be passed as UserConfig (converted by Pydantic).""" + user_dict = {"threshold": 35, "radar": "KDLH", "base_dir": "/tmp"} + config = resolve_config(ParamConfig(), user_dict, None) + + assert config.segmenter.threshold == 35.0 + assert config.downloader.radar == "KDLH" + + def test_empty_cli_config_dict_accepted(self): + """Empty dict can be passed as CLIConfig (converted by Pydantic).""" + cli_dict = {} + user = UserConfig(threshold=40, base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, cli_dict) + + # Empty CLI dict doesn't override anything + assert config.segmenter.threshold == 40.0 + + def test_incomplete_param_config_dict_rejected(self): + """Incomplete dict raises validation error.""" + with pytest.raises(Exception): # noqa: B017 — Pydantic ValidationError or TypeError + resolve_config({"incomplete": "dict"}, None, None) + + def test_internal_config_is_complete(self): + """Returned InternalConfig is complete with all fields.""" + user = UserConfig(base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, None) + + assert config.segmenter is not None + assert config.projector is not None + assert config.downloader is not None + + +class TestDefaultValues: + """Test ParamConfig default values match old behavior.""" + + def test_segmenter_defaults(self): + """Segmenter defaults match old hardcoded values.""" + user = UserConfig(base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, None) + + assert config.segmenter.threshold == 30.0 + assert 
config.segmenter.closing_kernel == (1, 1) + assert config.segmenter.min_cellsize_gridpoint == 5 + assert config.segmenter.filter_by_size is True + + def test_projector_defaults(self): + """Projector defaults match old hardcoded values.""" + user = UserConfig(base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, None) + + assert config.projector.method == "adapt_default" + assert config.projector.max_projection_steps == 3 # Updated default + assert config.projector.flow_params.winsize == 10 + assert config.projector.flow_params.iterations == 3 + + def test_downloader_defaults(self): + """Downloader defaults are initialized.""" + user = UserConfig(base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, None) + + assert config.downloader.latest_files > 0 + assert config.downloader.poll_interval_sec > 0 + + def test_regridder_defaults(self): + """Regridder defaults are complete.""" + user = UserConfig(base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, None) + + assert config.regridder.grid_shape is not None + assert len(config.regridder.grid_shape) == 3 + assert config.regridder.save_netcdf is True + + +class TestConfigValidation: + """Test Pydantic validation of configs.""" + + def test_invalid_method_rejected(self): + """Invalid segmentation method raises validation error.""" + with pytest.raises(Exception): # noqa: B017 — Pydantic ValidationError + resolve_config( + ParamConfig(), + UserConfig(segmentation_method="invalid_method_xyz", base_dir="/tmp", radar="KHTX"), + None + ) + + def test_negative_threshold_rejected(self): + """Negative threshold is coerced to float but should work.""" + # Note: Pydantic may allow negative threshold if no constraint + # This test documents current behavior + user = UserConfig(threshold=-10, base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, None) + assert config.segmenter.threshold == -10.0 + + def 
test_zero_min_cellsize_allowed(self): + """Zero min_cellsize is valid (means no filtering).""" + user = UserConfig(min_cellsize_gridpoint=0, base_dir="/tmp", radar="KHTX") + config = resolve_config(ParamConfig(), user, None) + assert config.segmenter.min_cellsize_gridpoint == 0 + + def test_valid_field_accepted(self): + """Valid fields in user config are accepted.""" + user_dict = { + "threshold": 35, + "radar": "KDIX", + "base_dir": "/tmp" + } + config = resolve_config(ParamConfig(), user_dict, None) + assert config.segmenter.threshold == 35.0 + assert config.downloader.radar == "KDIX" + + +class TestIntegration: + """Integration tests combining multiple features.""" + + def test_full_workflow_with_all_overrides(self): + """Full workflow: param + user.""" + user = UserConfig( + mode="historical", + threshold=35, + radar="KDLH", + base_dir="/tmp", + start_time="2024-01-01T00:00:00Z", + end_time="2024-01-01T12:00:00Z", + max_projection_steps=3, + segmenter=UserSegmenterConfig( + filter_by_size=False + ) + ) + config = resolve_config(ParamConfig(), user, None) + + # Verify all overrides took effect + assert config.mode == "historical" + assert config.segmenter.threshold == 35.0 + assert config.downloader.radar == "KDLH" + assert config.downloader.start_time == "2024-01-01T00:00:00Z" + assert config.projector.max_projection_steps == 3 + assert config.segmenter.filter_by_size is False + + def test_real_use_case_custom_radar(self): + """Real use case: custom radar with strict threshold.""" + user = UserConfig( + radar="KLTX", + base_dir="/tmp", + threshold=40, + reflectivity_var="reflectivity_dbz", + min_cellsize_gridpoint=20 + ) + config = resolve_config(ParamConfig(), user, None) + + assert config.downloader.radar == "KLTX" + assert config.segmenter.threshold == 40.0 + assert config.global_.var_names.reflectivity == "reflectivity_dbz" + assert config.segmenter.min_cellsize_gridpoint == 20 + + def test_nested_config_complex_flow_params(self): + """Complex nested config 
with custom flow parameters.""" + user = UserConfig( + base_dir="/tmp", + radar="KHTX", + projector=UserProjectorConfig( + max_projection_steps=5, + flow_params={ + "winsize": 15, + "iterations": 5, + "poly_n": 7, + } + ) + ) + config = resolve_config(ParamConfig(), user, None) + + assert config.projector.max_projection_steps == 5 + assert config.projector.flow_params.winsize == 15 + assert config.projector.flow_params.iterations == 5 + assert config.projector.flow_params.poly_n == 7 + + def test_analyzer_exclude_fields_union(self): + """analyzer.exclude_fields should union defaults with user-provided fields.""" + # Get default excludes from ParamConfig + param = ParamConfig() + default_excludes = set(param.analyzer.exclude_fields) + + # User adds additional excludes + user = UserConfig( + base_dir="/tmp", + radar="KHTX", + analyzer=UserAnalyzerConfig( + exclude_fields=["new_field1", "new_field2"] + ) + ) + + config = resolve_config(param, user, None) + + # Result should include both defaults AND user additions + actual_excludes = set(config.analyzer.exclude_fields) + expected_excludes = default_excludes | {"new_field1", "new_field2"} + + assert actual_excludes == expected_excludes + assert "new_field1" in config.analyzer.exclude_fields + assert "new_field2" in config.analyzer.exclude_fields + # Original defaults should still be there + for default_field in default_excludes: + assert default_field in config.analyzer.exclude_fields + + def test_analyzer_exclude_fields_via_top_level_alias(self): + """analyzer.exclude_fields union also works via top-level UserConfig alias.""" + param = ParamConfig() + default_excludes = set(param.analyzer.exclude_fields) + + # User sets exclude_fields at top level (alias) + user = UserConfig( + base_dir="/tmp", + radar="KHTX", + exclude_fields=["top_level_exclude"] + ) + + config = resolve_config(param, user, None) + + actual_excludes = set(config.analyzer.exclude_fields) + expected_excludes = default_excludes | {"top_level_exclude"} 
+ + assert actual_excludes == expected_excludes + + +class TestDeepMergeSemantics: + """Test deep_merge behavior for lists, dicts, and values.""" + + def test_deep_merge_list_behavior_replacement(self): + """Lists should be replaced entirely, not concatenated.""" + base = { + "list_field": ["a", "b", "c"], + "other_field": "base_value" + } + override = { + "list_field": ["x", "y"], + "new_field": "override_value" + } + + result = deep_merge(base, override) + + # List should be completely replaced, not merged/concatenated + assert result["list_field"] == ["x", "y"] + assert result["other_field"] == "base_value" + assert result["new_field"] == "override_value" + + def test_deep_merge_nested_dict_behavior(self): + """Nested dicts should merge recursively.""" + base = { + "nested": { + "keep_this": "base_value", + "override_this": "old_value" + }, + "top_level": "base" + } + override = { + "nested": { + "override_this": "new_value", + "add_this": "added" + } + } + + result = deep_merge(base, override) + + # Nested dict should merge, not replace + assert result["nested"]["keep_this"] == "base_value" + assert result["nested"]["override_this"] == "new_value" + assert result["nested"]["add_this"] == "added" + assert result["top_level"] == "base" + + def test_deep_merge_multiple_overrides(self): + """Multiple overrides should apply in order.""" + base = {"field": "base"} + override1 = {"field": "middle"} + override2 = {"field": "final"} + + result = deep_merge(base, override1, override2) + + assert result["field"] == "final" + + +class TestParamConfigCompleteness: + """Test ParamConfig provides all runtime-critical defaults.""" + + def test_paramconfig_completeness(self): + """ParamConfig should provide all required fields for runtime.""" + param = ParamConfig() + param_dict = param.model_dump() + + # Critical runtime fields that must be present + required_top_level = ["mode", "global_", "downloader", "regridder", + "segmenter", "analyzer", "projector"] + + for field in 
required_top_level: + assert field in param_dict, f"Missing required top-level field: {field}" + assert param_dict[field] is not None, f"Field {field} is None" + + # Critical downloader fields + downloader = param_dict["downloader"] + downloader_required = ["output_dir", "latest_files", "latest_minutes", "poll_interval_sec"] + for field in downloader_required: + assert field in downloader, f"Missing downloader field: {field}" + + # Critical segmenter fields + segmenter = param_dict["segmenter"] + segmenter_required = ["method", "threshold", "min_cellsize_gridpoint"] + for field in segmenter_required: + assert field in segmenter, f"Missing segmenter field: {field}" + assert segmenter[field] is not None, f"Segmenter field {field} is None" + + # Critical regridder fields + regridder = param_dict["regridder"] + regridder_required = ["grid_shape", "grid_limits", "weighting_function"] + for field in regridder_required: + assert field in regridder, f"Missing regridder field: {field}" + assert regridder[field] is not None, f"Regridder field {field} is None" + + def test_paramconfig_can_instantiate_without_errors(self): + """ParamConfig should instantiate successfully with no validation errors.""" + # This should not raise any ValidationError + param = ParamConfig() + assert param is not None + + # Should be able to convert to dict + param_dict = param.model_dump() + assert isinstance(param_dict, dict) + assert len(param_dict) > 0 diff --git a/tests/configuration/test_initialization_run_id.py b/tests/configuration/test_initialization_run_id.py new file mode 100644 index 0000000..3861b03 --- /dev/null +++ b/tests/configuration/test_initialization_run_id.py @@ -0,0 +1,134 @@ +"""Tests for init_runtime_config run-id behavior.""" + +import json +import re +from argparse import Namespace + +import pytest + +from adapt.configuration.schemas.initialization import init_runtime_config +from adapt.persistence.registry import RepositoryRegistry + +_USE_TMP_BASE = object() + + +def 
_args(tmp_path, run_id=None, config=None, base_dir=_USE_TMP_BASE, radar="KBOX"): + resolved_base_dir = str(tmp_path) if base_dir is _USE_TMP_BASE else base_dir + return Namespace( + config=config, + radar=radar, + mode="historical", + start_time="2026-03-23T02:00:00Z", + end_time="2026-03-23T03:00:00Z", + base_dir=resolved_base_dir, + verbose=False, + run_id=run_id, + rerun=False, + ) + + +def test_init_runtime_config_accepts_valid_user_run_id(tmp_path, capsys): + """Valid user-provided run_id is accepted as a new run when not found.""" + run_id = "2026MAR23-0206-KBOX" + config = init_runtime_config(_args(tmp_path, run_id=run_id)) + out = capsys.readouterr().out + + assert config.run_id == run_id + assert f"Using user-provided run ID (new run): {run_id}" in out + + +def test_init_runtime_config_rejects_invalid_user_run_id(tmp_path): + """Invalid run_id format raises a ValueError.""" + with pytest.raises(ValueError, match="Invalid run_id: must match YYYYMONDD-HHMM-RADAR"): + init_runtime_config(_args(tmp_path, run_id="bad-run-id")) + + +def test_init_runtime_config_continues_existing_run_id(tmp_path, capsys): + """Existing user-provided run_id is treated as continuation.""" + run_id = "2026MAR23-0206-KBOX" + baseline = init_runtime_config(_args(tmp_path, run_id=None, radar="KBOX")) + saved_cfg = baseline.model_dump() + saved_cfg["run_id"] = run_id + saved_cfg["created_at"] = "2026-03-23T02:06:00+00:00" + with open(tmp_path / f"runtime_config_{run_id}.json", "w") as f: + json.dump(saved_cfg, f) + + registry = RepositoryRegistry.get_instance(tmp_path) + registry.register_radar("KBOX") + registry.register_run(run_id=run_id, radar="KBOX", mode="historical") + + config = init_runtime_config(_args(tmp_path, run_id=run_id)) + out = capsys.readouterr().out + + assert config.run_id == run_id + assert f"Continuing existing run ID: {run_id}" in out + + +def test_init_runtime_config_requires_base_dir_with_run_id(tmp_path): + """--base-dir is required when --run-id is 
provided.""" + with pytest.raises(ValueError, match="--base-dir is required when --run-id is provided"): + init_runtime_config(_args(tmp_path, run_id="2026MAR23-0206-KBOX", base_dir=None)) + + +def test_init_runtime_config_existing_run_ignores_config_and_cli(tmp_path, capsys): + """Existing run_id must load saved runtime config and ignore incoming config/CLI overrides.""" + run_id = "2026MAR23-0206-KBOX" + + # Create a saved runtime config for the run. + baseline = init_runtime_config(_args(tmp_path, run_id=None, radar="KBOX")) + saved_cfg = baseline.model_dump() + saved_cfg["run_id"] = run_id + saved_cfg["downloader"]["radar"] = "KBOX" + saved_cfg["segmenter"]["threshold"] = 37.0 + saved_cfg["created_at"] = "2026-03-23T02:06:00+00:00" + with open(tmp_path / f"runtime_config_{run_id}.json", "w") as f: + json.dump(saved_cfg, f) + + # Register existing run in repository registry. + registry = RepositoryRegistry.get_instance(tmp_path) + registry.register_radar("KBOX") + registry.register_run(run_id=run_id, radar="KBOX", mode="historical") + + # Provide conflicting config file and CLI radar override - should be ignored. 
+ conflict_cfg = tmp_path / "conflict_config.py" + conflict_cfg.write_text( + "CONFIG = {\n" + " 'radar': 'KTLX',\n" + " 'base_dir': '/tmp/should_not_be_used',\n" + " 'threshold': 99,\n" + "}\n" + ) + config = init_runtime_config( + _args( + tmp_path, + run_id=run_id, + config=str(conflict_cfg), + radar="KTLX", + ) + ) + out = capsys.readouterr().out + + assert config.run_id == run_id + assert config.downloader.radar == "KBOX" + assert config.segmenter.threshold == 37.0 + assert "Ignoring user config file and CLI config overrides" in out + + +def test_init_runtime_config_existing_run_missing_saved_runtime_config(tmp_path): + """Existing run_id without runtime_config_.json should fail loudly.""" + run_id = "2026MAR23-0206-KBOX" + registry = RepositoryRegistry.get_instance(tmp_path) + registry.register_radar("KBOX") + registry.register_run(run_id=run_id, radar="KBOX", mode="historical") + + with pytest.raises(FileNotFoundError, match="Saved runtime config not found"): + init_runtime_config(_args(tmp_path, run_id=run_id)) + + +def test_init_runtime_config_auto_generates_formatted_run_id(tmp_path): + """Auto-generated run_id follows YYYYMONDD-HHMM-RADAR format.""" + config = init_runtime_config(_args(tmp_path, run_id=None)) + assert re.match( + r"^\d{4}(JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)\d{2}-\d{4}-KBOX$", + config.run_id, + ) diff --git a/tests/configuration/test_rerun_cleanup.py b/tests/configuration/test_rerun_cleanup.py new file mode 100644 index 0000000..c4b3438 --- /dev/null +++ b/tests/configuration/test_rerun_cleanup.py @@ -0,0 +1,55 @@ +from argparse import Namespace + +from adapt.configuration.schemas.initialization import init_runtime_config + + +def _args(tmp_path, radar="KPOE", rerun=True): + return Namespace( + config=None, + radar=radar, + mode="realtime", + start_time=None, + end_time=None, + base_dir=str(tmp_path), + verbose=False, + run_id=None, + rerun=rerun, + max_runtime=1, + no_plot=True, + plot_interval=2.0, + show_plots=False, + ) 
+ + +def test_rerun_cleanup_does_not_delete_user_files(tmp_path, capsys): + # User-owned file in base dir (must survive) + cfg = tmp_path / "config.yaml" + cfg.write_text("user config\n") + notes = tmp_path / "notes" + notes.mkdir() + (notes / "keep.txt").write_text("keep\n") + + radar = "KPOE" + # Program-created radar output dir (should be deleted) + radar_dir = tmp_path / radar + (radar_dir / "analysis").mkdir(parents=True) + (radar_dir / "analysis" / "KPOE_processing_tracker.db").write_text("db") + + # Program-created runtime config (should be deleted) + (tmp_path / f"runtime_config_2026APR04-0000-{radar}.json").write_text("{}") + + # Program-created legacy pipeline catalog (should be deleted) + catalog_dir = tmp_path / "catalog" + catalog_dir.mkdir() + (catalog_dir / f"2026APR04-0000-{radar}_pipeline_catalog.db").write_text("db") + + # Trigger init with rerun cleanup + init_runtime_config(_args(tmp_path, radar=radar, rerun=True)) + _ = capsys.readouterr() + + assert cfg.exists() + assert (notes / "keep.txt").exists() + + assert not radar_dir.exists() + assert not (tmp_path / f"runtime_config_2026APR04-0000-{radar}.json").exists() + assert not (catalog_dir / f"2026APR04-0000-{radar}_pipeline_catalog.db").exists() diff --git a/tests/configuration/test_setup_directories.py b/tests/configuration/test_setup_directories.py new file mode 100644 index 0000000..1b92439 --- /dev/null +++ b/tests/configuration/test_setup_directories.py @@ -0,0 +1,37 @@ +from pathlib import Path + +import pytest + +from adapt.configuration.schemas.directories import setup_output_directories + +pytestmark = pytest.mark.unit + + +def test_setup_output_directories_creates_all(tmp_path): + """Test that setup_output_directories creates base, catalog, and logs dirs.""" + dirs = setup_output_directories(tmp_path) + + # New structure: base, catalog, and logs at root level + expected = {"base", "catalog", "logs"} + + assert set(dirs.keys()) == expected + + for path in dirs.values(): + assert 
isinstance(path, Path) + assert path.exists() + assert path.is_dir() + + +def test_setup_output_directories_is_idempotent(tmp_path): + dirs1 = setup_output_directories(tmp_path) + dirs2 = setup_output_directories(tmp_path) + + assert dirs1 == dirs2 + + +def test_base_and_log_dirs_exist(tmp_path): + """Test that base and logs directories exist after setup.""" + dirs = setup_output_directories(tmp_path) + + assert dirs["base"].exists() + assert dirs["logs"].exists() diff --git a/tests/configuration/test_setup_directories_extended.py b/tests/configuration/test_setup_directories_extended.py new file mode 100644 index 0000000..3be5b2e --- /dev/null +++ b/tests/configuration/test_setup_directories_extended.py @@ -0,0 +1,62 @@ +"""Tests for setup_directories utility functions.""" + +from datetime import UTC, datetime + +from adapt.configuration.schemas.directories import ( + get_plot_path, + setup_output_directories, +) + + +def test_setup_output_directories_with_explicit_path(tmp_path): + dirs = setup_output_directories(tmp_path) + + assert dirs["base"] == tmp_path + assert dirs["logs"] == tmp_path / "logs" + + for key, path in dirs.items(): + assert path.exists(), f"{key} directory not created" + + +def test_setup_output_directories_creates_subdirs(tmp_path): + setup_output_directories(tmp_path) + + assert tmp_path.is_dir() + assert (tmp_path / "logs").is_dir() + + # Type-specific dirs are created dynamically under RADAR_ID/, not at root + assert not (tmp_path / "nexrad").exists() + assert not (tmp_path / "gridnc").exists() + assert not (tmp_path / "analysis").exists() + assert not (tmp_path / "plots").exists() + + +def test_setup_directories_expands_tilde(tmp_path): + dirs = setup_output_directories(tmp_path) + assert "~" not in str(dirs["base"]) + + +def test_setup_directories_resolves_relative_paths(tmp_path): + dirs = setup_output_directories(tmp_path) + assert dirs["base"].is_absolute() + + +def test_get_plot_path_reflectivity(tmp_path): + dirs = 
setup_output_directories(tmp_path) + scan_time = datetime(2024, 1, 15, 12, 30, tzinfo=UTC) + + plot_path = get_plot_path(dirs, radar="KMOB", plot_type="reflectivity", scan_time=scan_time) + + assert plot_path is not None + assert "20240115" in str(plot_path) + assert "KMOB" in str(plot_path) + + +def test_get_plot_path_cells(tmp_path): + dirs = setup_output_directories(tmp_path) + scan_time = datetime(2024, 1, 15, 12, 30, tzinfo=UTC) + + plot_path = get_plot_path(dirs, radar="KMOB", plot_type="cells", scan_time=scan_time) + + assert plot_path is not None + assert "KMOB" in str(plot_path) diff --git a/tests/configuration/test_userconfig_normalization.py b/tests/configuration/test_userconfig_normalization.py new file mode 100644 index 0000000..085cfaf --- /dev/null +++ b/tests/configuration/test_userconfig_normalization.py @@ -0,0 +1,27 @@ + +from adapt.configuration.schemas.user import UserConfig + + +def test_uppercase_keys_are_handled(): + raw = { + "MODE": "historical", + "RADAR_ID": "KHTX", + "THRESHOLD_DBZ": 40, + "BASE_DIR": "/tmp/adapt_out", + } + + user = UserConfig.model_validate(raw) + + assert user.mode == "historical" + assert user.radar == "KHTX" + assert isinstance(user.threshold, float) and user.threshold == 40.0 + assert user.base_dir == "/tmp/adapt_out" + + +def test_unknown_keys_are_ignored(): + raw = {"MODE": "realtime", "UNKNOWN_LEGACY": 12345} + user = UserConfig.model_validate(raw) + + assert user.mode == "realtime" + # Unknown key should not become an attribute nor raise + assert not hasattr(user, "UNKNOWN_LEGACY") From cf61805c928ff56143913be1f560fb0a2fa35d39 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Thu, 14 May 2026 14:52:21 -0500 Subject: [PATCH 03/14] ADD:(tests) acquisition module downloader init, queue, realtime, historical, failures, availability --- tests/modules/acquisition/__init__.py | 0 .../test_downloader_availability.py | 105 ++++++++++++++++ .../acquisition/test_downloader_failures.py | 45 +++++++ 
.../acquisition/test_downloader_historical.py | 59 +++++++++ .../acquisition/test_downloader_init.py | 119 ++++++++++++++++++ .../test_downloader_integration.py | 27 ++++ .../acquisition/test_downloader_queue.py | 55 ++++++++ .../acquisition/test_downloader_realtime.py | 61 +++++++++ ...downloader_realtime_availability_window.py | 54 ++++++++ .../acquisition/test_downloader_run.py | 26 ++++ 10 files changed, 551 insertions(+) create mode 100644 tests/modules/acquisition/__init__.py create mode 100644 tests/modules/acquisition/test_downloader_availability.py create mode 100644 tests/modules/acquisition/test_downloader_failures.py create mode 100644 tests/modules/acquisition/test_downloader_historical.py create mode 100644 tests/modules/acquisition/test_downloader_init.py create mode 100644 tests/modules/acquisition/test_downloader_integration.py create mode 100644 tests/modules/acquisition/test_downloader_queue.py create mode 100644 tests/modules/acquisition/test_downloader_realtime.py create mode 100644 tests/modules/acquisition/test_downloader_realtime_availability_window.py create mode 100644 tests/modules/acquisition/test_downloader_run.py diff --git a/tests/modules/acquisition/__init__.py b/tests/modules/acquisition/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/modules/acquisition/test_downloader_availability.py b/tests/modules/acquisition/test_downloader_availability.py new file mode 100644 index 0000000..bd8ca7e --- /dev/null +++ b/tests/modules/acquisition/test_downloader_availability.py @@ -0,0 +1,105 @@ +"""Test radar availability warning behavior. + +Verifies that availability warnings are only logged when the check explicitly +succeeds and finds no radar (not when the check fails/times out). 
+""" + +import logging +from datetime import UTC, datetime +from unittest.mock import MagicMock + +import pytest + +from adapt.modules.acquisition.module import AwsNexradDownloader + +pytestmark = pytest.mark.unit + + +def test_availability_check_warns_when_radar_explicitly_not_found(caplog, temp_dir): + """Warn when availability check succeeds but radar not found.""" + config = MagicMock() + config.downloader.radar = "KOHX" + config.downloader.poll_interval_sec = 1 + config.downloader.latest_files = 5 + config.downloader.latest_minutes = 60 + config.downloader.start_time = None + config.downloader.end_time = None + config.downloader.min_file_size = 1024 + + fake_conn = MagicMock() + fake_conn.get_avail_radars.return_value = ["KDFW", "KLBB"] # KOHX not in list + + downloader = AwsNexradDownloader( + config=config, + output_dir=temp_dir, + conn=fake_conn, + ) + + start = datetime(2025, 12, 18, tzinfo=UTC) + end = datetime(2025, 12, 18, tzinfo=UTC) + + with caplog.at_level(logging.WARNING): + downloader._check_radar_available(start, end) + + # Should warn because check succeeded but radar not found + assert any("Radar KOHX not found in AWS" in record.message for record in caplog.records) + + +def test_availability_check_does_not_warn_when_check_fails(caplog, temp_dir): + """Do NOT warn when availability check fails (exception or all failures).""" + config = MagicMock() + config.downloader.radar = "KOHX" + config.downloader.poll_interval_sec = 1 + config.downloader.latest_files = 5 + config.downloader.latest_minutes = 60 + config.downloader.start_time = None + config.downloader.end_time = None + config.downloader.min_file_size = 1024 + + fake_conn = MagicMock() + fake_conn.get_avail_radars.side_effect = Exception("AWS unavailable") + + downloader = AwsNexradDownloader( + config=config, + output_dir=temp_dir, + conn=fake_conn, + ) + + start = datetime(2025, 12, 18, tzinfo=UTC) + end = datetime(2025, 12, 18, tzinfo=UTC) + + with caplog.at_level(logging.WARNING): + 
downloader._check_radar_available(start, end) + + # Should NOT warn because check failed (exception) + assert not any("Radar KOHX not found in AWS" in record.message for record in caplog.records) + + +def test_availability_check_does_not_warn_when_radar_found(caplog, temp_dir): + """Do NOT warn when radar is found in availability check.""" + config = MagicMock() + config.downloader.radar = "KOHX" + config.downloader.poll_interval_sec = 1 + config.downloader.latest_files = 5 + config.downloader.latest_minutes = 60 + config.downloader.start_time = None + config.downloader.end_time = None + config.downloader.min_file_size = 1024 + + fake_conn = MagicMock() + fake_conn.get_avail_radars.return_value = ["KDFW", "KOHX", "KLBB"] # KOHX IS in list + + downloader = AwsNexradDownloader( + config=config, + output_dir=temp_dir, + conn=fake_conn, + ) + + start = datetime(2025, 12, 18, tzinfo=UTC) + end = datetime(2025, 12, 18, tzinfo=UTC) + + with caplog.at_level(logging.WARNING): + downloader._check_radar_available(start, end) + + # Should NOT warn because radar was found + assert not any("Radar KOHX not found in AWS" in record.message for record in caplog.records) diff --git a/tests/modules/acquisition/test_downloader_failures.py b/tests/modules/acquisition/test_downloader_failures.py new file mode 100644 index 0000000..82197e9 --- /dev/null +++ b/tests/modules/acquisition/test_downloader_failures.py @@ -0,0 +1,45 @@ +# tests/test_downloader_failures.py +from datetime import UTC, datetime + +import pytest + +from adapt.modules.acquisition.module import AwsNexradDownloader + +pytestmark = pytest.mark.unit + + +def test_download_failure_does_not_queue(tmp_path, fake_scan, make_config): + class FailingConn: + def get_avail_scans_in_range(self, *a): + return [fake_scan("bad", datetime.now(UTC))] + + def download(self, *a, **k): + class R: + def iter_success(self): return [] + return R() + + config = make_config() + d = AwsNexradDownloader( + config, + output_dir=tmp_path, + 
conn=FailingConn() + ) + + downloads = d._download_realtime() + assert downloads == [] + + +def test_fetch_scans_exception_returns_empty(tmp_path, make_config): + class ExplodingConn: + def get_avail_scans_in_range(self, *a): + raise RuntimeError("AWS down") + + config = make_config() + d = AwsNexradDownloader( + config, + output_dir=tmp_path, + conn=ExplodingConn() + ) + + scans = d._fetch_scans(datetime.now(UTC), datetime.now(UTC)) + assert scans == [] diff --git a/tests/modules/acquisition/test_downloader_historical.py b/tests/modules/acquisition/test_downloader_historical.py new file mode 100644 index 0000000..c87f50b --- /dev/null +++ b/tests/modules/acquisition/test_downloader_historical.py @@ -0,0 +1,59 @@ +# tests/test_downloader_historical.py +from datetime import UTC, datetime + +import pytest + +from adapt.modules.acquisition.module import AwsNexradDownloader + +pytestmark = pytest.mark.unit + + +def test_historical_mode_completes(tmp_path, fake_scan, fake_aws_conn, make_config): + scans = [ + fake_scan("h1", datetime(2024, 1, 1, tzinfo=UTC)), + fake_scan("h2", datetime(2024, 1, 1, 1, tzinfo=UTC)), + ] + + config = make_config( + start_time="2024-01-01T00:00:00Z", + end_time="2024-01-01T02:00:00Z", + ) + + d = AwsNexradDownloader( + config, + output_dir=tmp_path, + conn=fake_aws_conn(scans), + sleeper=lambda _: None, + ) + + downloads = d._download_task() + + assert d.is_historical_complete() + processed, expected = d.get_historical_progress() + assert expected == 2 + assert len(downloads) == 2 + +# Mock AWS completely. 
+def test_fetch_scans_filters_and_sorts(monkeypatch, tmp_path, make_config): + class FakeScan: + def __init__(self, key, scan_time): + self.key = key + self.scan_time = scan_time + + scans = [ + FakeScan("file2", datetime(2024,1,1,1, tzinfo=UTC)), + FakeScan("file1_MDM", datetime(2024,1,1,0, tzinfo=UTC)), + FakeScan("file0", datetime(2024,1,1,0, tzinfo=UTC)), + ] + + config = make_config() + d = AwsNexradDownloader(config, output_dir=tmp_path) + monkeypatch.setattr( + d.conn, + "get_avail_scans_in_range", + lambda *args, **kwargs: scans + ) + + result = d._fetch_scans(datetime.now(UTC), datetime.now(UTC)) + assert len(result) == 2 + assert result[0].key == "file0" diff --git a/tests/modules/acquisition/test_downloader_init.py b/tests/modules/acquisition/test_downloader_init.py new file mode 100644 index 0000000..0f23afd --- /dev/null +++ b/tests/modules/acquisition/test_downloader_init.py @@ -0,0 +1,119 @@ +# tests/test_downloader_init.py +import pytest + +from adapt.modules.acquisition.module import AwsNexradDownloader + +pytestmark = pytest.mark.unit + + +def test_init_custom_config(make_config, radar_output_dirs): + """Downloader initializes with custom config.""" + from adapt.configuration.schemas.user import UserDownloaderConfig + config = make_config( + downloader=UserDownloaderConfig(radar="KDIX", latest_files=5, latest_minutes=60) + ) + d = AwsNexradDownloader(config, radar_output_dirs["nexrad"]) + + assert d.config.downloader.radar == "KDIX" + assert d.config.downloader.latest_files == 5 + assert d.config.downloader.latest_minutes == 60 + + +def test_stop_sets_event(radar_config, radar_output_dirs): + """Stop event prevents downloader from polling.""" + d = AwsNexradDownloader(radar_config, radar_output_dirs["nexrad"]) + assert not d.stopped() + d.stop() + assert d.stopped() + + +def test_historical_mode_from_config(make_config, radar_output_dirs): + """Downloader detects historical mode from config.downloader.mode.""" + from 
adapt.configuration.schemas.user import UserDownloaderConfig + config = make_config( + downloader=UserDownloaderConfig( + start_time="2024-01-01T00:00:00Z", + end_time="2024-01-01T01:00:00Z", + ) + ) + d = AwsNexradDownloader(config, radar_output_dirs["nexrad"]) + # Mode is decided by schema, not by is_historical_mode() method + assert d.config.downloader.mode == "historical" + + d2 = AwsNexradDownloader(make_config(), radar_output_dirs["nexrad"]) + assert d2.config.downloader.mode == "realtime" + + +def test_parse_time_range(make_config, radar_output_dirs): + """Downloader parses time range correctly.""" + from adapt.configuration.schemas.user import UserDownloaderConfig + config = make_config( + downloader=UserDownloaderConfig( + start_time="2024-01-01T00:00:00Z", + end_time="2024-01-01T01:00:00Z", + ) + ) + d = AwsNexradDownloader(config, radar_output_dirs["nexrad"]) + + start, end = d._parse_time_range() + + assert start.tzinfo is not None + assert end > start + assert (end - start).total_seconds() == 3600 + + +def test_file_exists_rejects_small_files(tmp_path, radar_config, radar_output_dirs): + """Downloader rejects files below minimum size.""" + d = AwsNexradDownloader(radar_config, tmp_path) + + p = tmp_path / "tiny" + p.write_bytes(b"x") + + assert not d._file_exists(p) + + +def test_file_exists_true(tmp_path, radar_config, radar_output_dirs): + """Downloader accepts files above minimum size.""" + d = AwsNexradDownloader(radar_config, radar_output_dirs["nexrad"]) + p = tmp_path / "f" + p.write_bytes(b"x" * 2048) + assert d._file_exists(p) + + +from datetime import datetime # noqa: E402 + + +def test_get_local_path(make_config, radar_output_dirs): + """Downloader generates correct local file paths with new structure.""" + class FakeScan: + key = "foo/bar/testfile" + scan_time = datetime(2024, 1, 1) + + from adapt.configuration.schemas.user import UserDownloaderConfig + config = make_config(downloader=UserDownloaderConfig(radar="KDIX")) + # Use output_dirs 
for new path structure (RADAR_ID/nexrad/YYYYMMDD/) + d = AwsNexradDownloader(config, output_dirs=radar_output_dirs) + + path = d._get_local_path(FakeScan()) + assert "20240101" in str(path) + assert "KDIX" in str(path) + assert path.name == "testfile" + # New structure: base/KDIX/nexrad/20240101/testfile + assert "KDIX/nexrad/20240101" in str(path) or "KDIX\\nexrad\\20240101" in str(path) + + +def test_get_local_path_legacy(make_config, radar_output_dirs): + """Downloader generates correct local file paths with legacy output_dir.""" + class FakeScan: + key = "foo/bar/testfile" + scan_time = datetime(2024, 1, 1) + + from adapt.configuration.schemas.user import UserDownloaderConfig + config = make_config(downloader=UserDownloaderConfig(radar="KDIX")) + # Use legacy output_dir parameter + d = AwsNexradDownloader(config, output_dir=radar_output_dirs["nexrad"]) + + path = d._get_local_path(FakeScan()) + assert "20240101" in str(path) + assert "KDIX" in str(path) + assert path.name == "testfile" diff --git a/tests/modules/acquisition/test_downloader_integration.py b/tests/modules/acquisition/test_downloader_integration.py new file mode 100644 index 0000000..b96e22d --- /dev/null +++ b/tests/modules/acquisition/test_downloader_integration.py @@ -0,0 +1,27 @@ +from datetime import UTC, datetime, timedelta + +import pytest + +from adapt.modules.acquisition.module import AwsNexradDownloader + + +@pytest.mark.integration +def test_real_aws_listing(tmp_path, make_config): + """Test real AWS NEXRAD data listing. + + Uses a known radar ID (KHTX) to ensure we get real data from AWS. + Skips if no scans are available (expected during low-activity periods). 
+ """ + config = make_config(radar_id="KHTX") # Use a known radar with consistent data + d = AwsNexradDownloader(config, output_dir=tmp_path) + end = datetime.now(UTC) + start = end - timedelta(minutes=60) + scans = d._fetch_scans(start, end) + + # Skip if no scans available (integration test depends on real AWS data availability) + if not scans: + pytest.skip( + "No NEXRAD scans available in AWS for the past 60 minutes " + "(expected during low-activity periods)" + ) + diff --git a/tests/modules/acquisition/test_downloader_queue.py b/tests/modules/acquisition/test_downloader_queue.py new file mode 100644 index 0000000..43036c0 --- /dev/null +++ b/tests/modules/acquisition/test_downloader_queue.py @@ -0,0 +1,55 @@ +# tests/test_downloader_queue.py +from datetime import UTC, datetime +from queue import Queue + +import pytest + +from adapt.modules.acquisition.module import AwsNexradDownloader + +pytestmark = pytest.mark.unit + + +def test_notify_queue_puts_item(tmp_path, make_config): + q = Queue() + config = make_config() + d = AwsNexradDownloader(config, output_dir=tmp_path, result_queue=q) + + path = tmp_path / "file1" + + d._notify_queue( + path=path, + scan_time=datetime.now(UTC), + is_new=True, + ) + + item = q.get_nowait() + + assert item["radar"] == d.radar + assert item["path"] == path + assert "scan_time" in item + assert "file_id" in item + + +def test_notify_queue_calls_tracker(tmp_path, fake_scan, make_config): + class FakeTracker: + def __init__(self): + self.registered = False + + def register_file(self, *a, **k): + self.registered = True + + def mark_stage_complete(self, *a, **k): + pass + + tracker = FakeTracker() + from queue import Queue + + q = Queue() + config = make_config() + d = AwsNexradDownloader(config, output_dir=tmp_path, result_queue=q, file_tracker=tracker) + + d._notify_queue( + path=tmp_path / "f", scan_time=fake_scan("x").scan_time, is_new=True + ) + + assert tracker.registered diff --git 
a/tests/modules/acquisition/test_downloader_realtime.py b/tests/modules/acquisition/test_downloader_realtime.py new file mode 100644 index 0000000..0b0efa7 --- /dev/null +++ b/tests/modules/acquisition/test_downloader_realtime.py @@ -0,0 +1,61 @@ +# tests/test_downloader_realtime.py +from datetime import UTC, datetime +from queue import Queue + +import pytest + +from adapt.modules.acquisition.module import AwsNexradDownloader + +pytestmark = pytest.mark.unit + + +def test_realtime_download_hybrid(tmp_path, fake_scan, fake_aws_conn, make_config): + now = datetime(2024, 1, 1, tzinfo=UTC) + + scans = [ + fake_scan("scan1", now), + fake_scan("scan2", now), + ] + + q = Queue() + + config = make_config( + radar_id="KDIX", + latest_n=2, + minutes=30, + ) + + d = AwsNexradDownloader( + config, + output_dir=tmp_path, + result_queue=q, + conn=fake_aws_conn(scans), + clock=lambda: now, + sleeper=lambda _: None, + ) + + downloads = d._download_realtime() + + assert len(downloads) == 2 + assert q.qsize() == 2 + + for path in downloads: + assert path.exists() + assert path.stat().st_size >= 1024 + + +def test_realtime_idempotent(tmp_path, fake_scan, fake_aws_conn, make_config): + scans = [fake_scan("same")] + + config = make_config() + d = AwsNexradDownloader( + config, + output_dir=tmp_path, + conn=fake_aws_conn(scans), + sleeper=lambda _: None, + ) + + d._download_realtime() + d._download_realtime() + + assert len(d._known_files) == 1 diff --git a/tests/modules/acquisition/test_downloader_realtime_availability_window.py b/tests/modules/acquisition/test_downloader_realtime_availability_window.py new file mode 100644 index 0000000..99adea5 --- /dev/null +++ b/tests/modules/acquisition/test_downloader_realtime_availability_window.py @@ -0,0 +1,54 @@ +"""Realtime downloader should check availability over the whole lookback window. + +This prevents spammy false warnings around UTC midnight when the lookback window +spans two dates (yesterday's inventory has the radar, today's may not yet). 
+""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock + +import pytest + +from adapt.modules.acquisition.module import AwsNexradDownloader + +pytestmark = pytest.mark.unit + + +def test_realtime_availability_check_uses_start_and_end_dates(temp_dir): + config = MagicMock() + config.downloader.mode = "realtime" + config.downloader.radar = "KPOE" + config.downloader.poll_interval_sec = 1 + config.downloader.latest_files = 5 + config.downloader.latest_minutes = 60 + config.downloader.start_time = None + config.downloader.end_time = None + config.downloader.min_file_size = 1024 + + fake_conn = MagicMock() + fake_conn.get_avail_scans_in_range.return_value = [] + fake_conn.get_avail_radars.return_value = ["KPOE"] + + # 2026-04-05 00:10Z lookback spans previous date. + now = datetime(2026, 4, 5, 0, 10, 0, tzinfo=UTC) + + downloader = AwsNexradDownloader( + config=config, + output_dir=temp_dir, + conn=fake_conn, + clock=lambda: now, + ) + + captured = {} + + def _capture(start, end): + captured["start"] = start + captured["end"] = end + return AwsNexradDownloader._check_radar_available(downloader, start, end) + + downloader._check_radar_available = _capture + + downloader._download_realtime() + + assert captured["end"] == now + assert captured["start"] == now - timedelta(minutes=60) diff --git a/tests/modules/acquisition/test_downloader_run.py b/tests/modules/acquisition/test_downloader_run.py new file mode 100644 index 0000000..f02e8f2 --- /dev/null +++ b/tests/modules/acquisition/test_downloader_run.py @@ -0,0 +1,26 @@ +# tests/test_downloader_run.py +import pytest + +from adapt.modules.acquisition.module import AwsNexradDownloader + +pytestmark = pytest.mark.unit + + +def test_run_exits_after_historical_complete(tmp_path, fake_scan, fake_aws_conn, make_config): + scans = [fake_scan("one")] + + config = make_config( + start_time="2024-01-01T00:00:00Z", + end_time="2024-01-01T01:00:00Z", + ) + + d = AwsNexradDownloader( + config, + 
output_dir=tmp_path, + conn=fake_aws_conn(scans), + sleeper=lambda _: None, + ) + + d.run() + + assert d.is_historical_complete() From 36b9ef9849268d5cdf1bd1d4e5ab9ecc53b378ff Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Thu, 14 May 2026 14:54:04 -0500 Subject: [PATCH 04/14] ADD:(tests) detection module segmenter init, threshold, filtering, morphology, contract, failures --- tests/modules/detection/__init__.py | 0 .../detection/test_segmenter_contract.py | 22 +++++++ .../detection/test_segmenter_failures.py | 32 ++++++++++ .../detection/test_segmenter_filtering.py | 54 ++++++++++++++++ .../modules/detection/test_segmenter_init.py | 45 +++++++++++++ .../detection/test_segmenter_morphology.py | 41 ++++++++++++ .../detection/test_segmenter_threshold.py | 64 +++++++++++++++++++ 7 files changed, 258 insertions(+) create mode 100644 tests/modules/detection/__init__.py create mode 100644 tests/modules/detection/test_segmenter_contract.py create mode 100644 tests/modules/detection/test_segmenter_failures.py create mode 100644 tests/modules/detection/test_segmenter_filtering.py create mode 100644 tests/modules/detection/test_segmenter_init.py create mode 100644 tests/modules/detection/test_segmenter_morphology.py create mode 100644 tests/modules/detection/test_segmenter_threshold.py diff --git a/tests/modules/detection/__init__.py b/tests/modules/detection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/modules/detection/test_segmenter_contract.py b/tests/modules/detection/test_segmenter_contract.py new file mode 100644 index 0000000..01a2065 --- /dev/null +++ b/tests/modules/detection/test_segmenter_contract.py @@ -0,0 +1,22 @@ +"""Test RadarCellSegmenter output contract and data structure.""" + +import numpy as np +import pytest + +from adapt.modules.detection.module import RadarCellSegmenter + +pytestmark = pytest.mark.unit + + +def test_output_contract(simple_2d_ds, detection_module_config): + """Segmenter output has correct shape, 
dtype, and metadata.""" + seg = RadarCellSegmenter(detection_module_config) + + out = seg.segment(simple_2d_ds) + da = out["cell_labels"] + + assert da.dims == ("y", "x") + assert da.dtype == np.int32 + assert "threshold" in da.attrs + assert "z_level_m" in da.attrs + assert da.attrs["method"] == "threshold" diff --git a/tests/modules/detection/test_segmenter_failures.py b/tests/modules/detection/test_segmenter_failures.py new file mode 100644 index 0000000..fa9ad2c --- /dev/null +++ b/tests/modules/detection/test_segmenter_failures.py @@ -0,0 +1,32 @@ +"""Test RadarCellSegmenter error handling and edge cases.""" + +import numpy as np +import pytest +import xarray as xr + +from adapt.modules.detection.module import RadarCellSegmenter + +pytestmark = pytest.mark.unit + + +def test_missing_reflectivity_var(detection_module_config): + """Segmenter fails gracefully when reflectivity variable missing.""" + ds = xr.Dataset( + {"wrong_var": (("y", "x"), np.ones((3, 3)))} + ) + + seg = RadarCellSegmenter(detection_module_config) + with pytest.raises(KeyError): + seg.segment(ds) + + +def test_non_2d_data_fails(detection_module_config): + """Segmenter rejects 3D data (must be 2D slice).""" + ds = xr.Dataset( + {"reflectivity": (("z", "y", "x"), np.ones((2, 3, 3)))} + ) + + seg = RadarCellSegmenter(detection_module_config) + with pytest.raises(Exception): # noqa: B017 — ValueError or similar from segmenter + seg.segment(ds) + diff --git a/tests/modules/detection/test_segmenter_filtering.py b/tests/modules/detection/test_segmenter_filtering.py new file mode 100644 index 0000000..bf2a0cc --- /dev/null +++ b/tests/modules/detection/test_segmenter_filtering.py @@ -0,0 +1,54 @@ +"""Test RadarCellSegmenter filtering logic.""" + +import pytest + +from adapt.modules.detection.module import RadarCellSegmenter + +pytestmark = pytest.mark.unit + + +def test_min_cellsize_filter(two_cell_ds, make_detection_config): + """Small cells below min_cellsize threshold are filtered out.""" + 
config = make_detection_config(threshold=20, min_cellsize_gridpoint=4) + seg = RadarCellSegmenter(config) + + out = seg.segment(two_cell_ds) + labels = out["cell_labels"].values + + # Both cells meet min_size (4 pixels each), so expect both or merged + assert labels.max() >= 1 + + +def test_disable_size_filter(two_cell_ds, make_detection_config): + """All detected cells are retained when filter_by_size=False.""" + # threshold=20 will detect both cells (50 and 30 dBZ) + from adapt.configuration.schemas.user import UserSegmenterConfig + config = make_detection_config( + threshold=20.0, + segmenter=UserSegmenterConfig(filter_by_size=False) + ) + seg = RadarCellSegmenter(config) + + out = seg.segment(two_cell_ds) + labels = out["cell_labels"].values + + # Both cells should be detected + assert labels.max() == 2 + + +def test_relabeling_is_contiguous(two_cell_ds, make_detection_config): + """Cell labels are contiguous integers starting from 1.""" + from adapt.configuration.schemas.user import UserSegmenterConfig + config = make_detection_config( + threshold=20.0, + segmenter=UserSegmenterConfig(filter_by_size=False) + ) + seg = RadarCellSegmenter(config) + + labels = seg.segment(two_cell_ds)["cell_labels"].values + unique = sorted(set(labels.flatten()) - {0}) + + # Labels should be [1, 2] for two cells + assert unique == list(range(1, len(unique) + 1)) + + diff --git a/tests/modules/detection/test_segmenter_init.py b/tests/modules/detection/test_segmenter_init.py new file mode 100644 index 0000000..fee038f --- /dev/null +++ b/tests/modules/detection/test_segmenter_init.py @@ -0,0 +1,45 @@ +"""Test RadarCellSegmenter initialization with Pydantic configs.""" + +import pytest + +from adapt.configuration.schemas.param import ParamConfig +from adapt.configuration.schemas.resolve import resolve_config +from adapt.configuration.schemas.user import UserConfig +from adapt.modules.detection.module import RadarCellSegmenter + +pytestmark = pytest.mark.unit + + +def 
test_default_config(detection_module_config): + """Segmenter uses expert defaults when no user overrides provided.""" + seg = RadarCellSegmenter(detection_module_config) + assert seg.method == "threshold" + assert seg.threshold == 30.0 + assert seg.filter_by_size is True + + +def test_custom_config(make_detection_config): + """Segmenter respects user config overrides.""" + config = make_detection_config( + threshold=45, + min_cellsize_gridpoint=10, + # Note: filter_by_size not exposed in UserConfig yet, uses default + ) + + seg = RadarCellSegmenter(config) + assert seg.threshold == 45.0 + assert seg.min_gridpoints == 10 + + +def test_unknown_method_raises(): + """Invalid segmentation method fails at config validation time.""" + # Pydantic validation happens at model creation, not at runtime + # This test verifies the old behavior is no longer needed + # Invalid methods are caught by Literal["threshold"] in ParamConfig + + with pytest.raises(Exception): # noqa: B017 — ValidationError from Pydantic + # Try to create config with invalid method + param = ParamConfig() + user = UserConfig(segmentation_method="watershed") # Invalid + resolve_config(param, user, None) + diff --git a/tests/modules/detection/test_segmenter_morphology.py b/tests/modules/detection/test_segmenter_morphology.py new file mode 100644 index 0000000..9067f39 --- /dev/null +++ b/tests/modules/detection/test_segmenter_morphology.py @@ -0,0 +1,41 @@ +"""Test RadarCellSegmenter morphological operations.""" + +import pytest + +from adapt.modules.detection.module import RadarCellSegmenter + +pytestmark = pytest.mark.unit + + +def test_close_cells_without_closing(close_cells_ds, make_detection_config): + """Without morphological closing, nearby cells remain separate.""" + from adapt.configuration.schemas.user import UserSegmenterConfig + config = make_detection_config( + threshold=30, + segmenter=UserSegmenterConfig(filter_by_size=False) + ) + seg = RadarCellSegmenter(config) + + labels = 
def test_threshold_filters_all(simple_2d_ds, make_detection_config):
    """A threshold above the field maximum leaves the label array empty."""
    # 50 is higher than the 40 found in simple_2d_ds
    segmenter = RadarCellSegmenter(make_detection_config(threshold=50))

    result = segmenter.segment(simple_2d_ds)
    assert "cell_labels" in result

    cell_labels = result["cell_labels"].values
    # Nothing may be labelled: the largest label is 0 and no pixel is non-zero.
    assert cell_labels.max() == 0
    assert np.count_nonzero(cell_labels) == 0
def test__multiple_cells(large_multi_cell_ds, make_detection_config):
    """Every separate cell in the multi-cell grid receives its own label."""
    from adapt.configuration.schemas.user import UserSegmenterConfig

    # Size filtering is switched off for this test.
    segmenter_options = UserSegmenterConfig(filter_by_size=False)
    segmenter = RadarCellSegmenter(
        make_detection_config(threshold=30, segmenter=segmenter_options)
    )

    result = segmenter.segment(large_multi_cell_ds)
    cell_labels = result["cell_labels"].values

    # Four distinct cells are expected, so the highest label is 4.
    assert cell_labels.max() == 4
def test_regrid_propagates_exception(monkeypatch, ingest_module_config_from_radar):
    """An error raised while regridding reaches the caller unchanged."""

    def failing_gridder(*args, **kwargs):
        raise RuntimeError("fail")

    loader = RadarDataLoader(ingest_module_config_from_radar)
    # Replace the gridding entry point with a stub that always raises.
    monkeypatch.setattr("pyart.map.grid_from_radars", failing_gridder)

    with pytest.raises(RuntimeError, match="fail"):
        loader.regrid(object())
def test_loader_with_custom_grid_limits(make_ingest_config):
    """Grid limits supplied through the config are exposed on the loader."""
    limits = ((0, 10000), (-50000, 50000), (-50000, 50000))
    loader = RadarDataLoader(make_ingest_config(grid_limits=limits))

    # Only the first range of the triple is compared.
    first_range = loader.grid_limits[0]
    assert first_range == (0, 10000)
@pytest.fixture
def config():
    """Return the analysis module config resolved for a throwaway base directory.

    The temporary directory only exists while the config is being resolved
    and materialized; it is removed before the fixture value reaches the
    test, so tests must not rely on ``base_dir`` existing on disk.
    """
    # The context manager replaces the manual mkdtemp/try/finally cleanup;
    # ignore_cleanup_errors mirrors the previous rmtree(ignore_errors=True).
    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as d:
        param = ParamConfig()
        user = UserConfig(base_dir=str(Path(d)), radar="TEST_RADAR")
        internal = resolve_config(param, user, None)
        return materialize_module_configs(internal)["analysis_config"]
def test_extract_adjacency_simple_touch(config):
    """Two cells sharing one vertical boundary yield a single adjacency row."""
    analyzer = RadarCellAnalyzer(config)

    # Left half of the 4x4 grid is cell 1, right half is cell 2.
    grid = np.zeros((4, 4), dtype=np.int32)
    grid[:, :2] = 1
    grid[:, 2:] = 2

    scan_time = np.datetime64("2024-01-01T00:00:00")
    adjacency = analyzer.extract_adjacency(_ds_with_labels(scan_time, grid))

    expected_columns = ["time", "cell_label_a", "cell_label_b", "touching_boundary_pixels"]
    assert list(adjacency.columns) == expected_columns
    assert len(adjacency) == 1

    pair = adjacency.iloc[0]
    assert int(pair["cell_label_a"]) == 1
    assert int(pair["cell_label_b"]) == 2
    # The boundary between col=1 and col=2 contributes one touching edge per row.
    assert int(pair["touching_boundary_pixels"]) == 4
def test_geometric_centroid_is_inside_cell(labeled_ds_with_extras, make_analysis_config):
    """Geometric centroid of the first extracted cell lies within the grid extent."""
    from adapt.modules.analysis.module import RadarCellAnalyzer

    config = make_analysis_config()
    analyzer = RadarCellAnalyzer(config)
    df = analyzer.extract(labeled_ds_with_extras)

    row = df.iloc[0]

    # Dataset.sizes is the stable name -> length mapping; Dataset.dims is
    # slated to return a set of names, which would break ["x"] indexing.
    sizes = labeled_ds_with_extras.sizes
    assert 0 <= row["cell_centroid_geom_x"] < sizes["x"]
    assert 0 <= row["cell_centroid_geom_y"] < sizes["y"]
def test_init_custom_config(make_analysis_config):
    """Overrides for the reflectivity variable and projection steps reach the analyzer."""
    from adapt.configuration.schemas.user import UserProjectorConfig

    projector_options = UserProjectorConfig(max_projection_steps=2)
    custom_config = make_analysis_config(
        reflectivity_var="dbz",
        projector=projector_options,
    )

    analyzer = RadarCellAnalyzer(custom_config)

    assert analyzer.reflectivity_field == "dbz"
    assert analyzer.max_projection_steps == 2
def test_get_lat_lon_bounds():
    """get_lat_lon yields NaN for both values when called with (100, 100) on 5x5 grids."""
    grid_shape = (5, 5)
    lat_grid = np.ones(grid_shape)
    lon_grid = np.ones(grid_shape)

    looked_up = RadarCellAnalyzer.get_lat_lon(100, 100, lat_grid, lon_grid)
    lat_val, lon_val = looked_up

    assert np.isnan(lat_val)
    assert np.isnan(lon_val)
def test_init_stores_method(make_projection_config):
    """The projector exposes the method named in its config."""
    projection_config = make_projection_config()

    projector = RadarCellProjector(projection_config)
    stored_method = projector.method

    assert stored_method == projection_config.method
def test_normalize_constant_field(make_projection_config):
    """Normalizing two identical constant frames returns uint8 arrays."""
    projector = RadarCellProjector(make_projection_config())

    # Two 4x4 float32 frames filled with the constant value 10.
    first = np.full((4, 4), 10, dtype=np.float32)
    second = np.full((4, 4), 10, dtype=np.float32)

    first_norm, second_norm = projector._normalize(first, second)

    assert first_norm.dtype == np.uint8
    assert second_norm.dtype == np.uint8
def test_projection_dimensions(simple_labeled_ds_pair, make_projection_config):
    """The projection stack is laid out as (frame_offset, y, x)."""
    from adapt.configuration.schemas.user import UserProjectorConfig

    two_steps = UserProjectorConfig(max_projection_steps=2)
    projector = RadarCellProjector(make_projection_config(projector=two_steps))

    result = projector.project(simple_labeled_ds_pair)
    projections = result["cell_projections"]

    assert projections.dims == ("frame_offset", "y", "x")
    # 1 registration frame + 2 future frames
    assert projections.shape[0] == 3
def test_validate_requires_two_datasets(simple_labeled_ds_pair, make_projection_config):
    """Projecting from a single dataset is rejected with ValueError."""
    projector = RadarCellProjector(make_projection_config())

    # Only the first element of the pair is handed over.
    with pytest.raises(ValueError):
        projector.project(simple_labeled_ds_pair[:1])
def test_track_signature_format():
    """A birth signature is a pipe-separated string of seven fields tagged v1."""
    birth = {
        "scan_start_time_epoch_s": 1700000000.0,
        "centroid_lat_deg": 35.04,
        "centroid_lon_deg": -97.02,
        "max_dbz": 52.4,
        "max_zdr": 1.2,
        "area40_km2": 12.2,
        "time_step_s": 10,
        "latlon_step_deg": 0.1,
        "area_step_km2": 5.0,
    }

    signature = _track_signature_from_birth(**birth)
    assert signature.startswith("v1|")

    fields = signature.split("|")
    assert len(fields) == 7
    assert fields[0] == "v1"
_track_signature_from_birth( + scan_start_time_epoch_s=1700000011.0, + centroid_lat_deg=35.16, + centroid_lon_deg=-97.16, + max_dbz=51.6, + max_zdr=0.56, + area40_km2=20.2, + time_step_s=10, + latlon_step_deg=0.1, + area_step_km2=5.0, + ) + pid_c = _cell_uid_from_signature(sig_c, width=10) + assert pid_a != pid_c From a7ebd76410f26a8c80f2ba43d1e0cfd7fee6f664 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Thu, 14 May 2026 15:06:40 -0500 Subject: [PATCH 09/14] ADD:(tests) persistence data repository, catalog schema, registry radar location --- tests/persistence/__init__.py | 1 + tests/persistence/test_data_repository.py | 635 ++++++++++++++++++ .../persistence/test_radar_catalog_schema.py | 20 + .../test_registry_radar_location.py | 28 + ...est_repository_registers_radar_location.py | 32 + 5 files changed, 716 insertions(+) create mode 100644 tests/persistence/__init__.py create mode 100644 tests/persistence/test_data_repository.py create mode 100644 tests/persistence/test_radar_catalog_schema.py create mode 100644 tests/persistence/test_registry_radar_location.py create mode 100644 tests/persistence/test_repository_registers_radar_location.py diff --git a/tests/persistence/__init__.py b/tests/persistence/__init__.py new file mode 100644 index 0000000..82325da --- /dev/null +++ b/tests/persistence/__init__.py @@ -0,0 +1 @@ +# Tests for adapt.persistence module diff --git a/tests/persistence/test_data_repository.py b/tests/persistence/test_data_repository.py new file mode 100644 index 0000000..639ccbd --- /dev/null +++ b/tests/persistence/test_data_repository.py @@ -0,0 +1,635 @@ +"""Tests for DataRepository artifact management.""" + +import json +import re +import shutil +import sqlite3 +import tempfile +from datetime import UTC, datetime +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from adapt.persistence import DataRepository, ProductType + +# 
@pytest.fixture
def sample_dataset():
    """Create a sample xarray Dataset with a reproducible reflectivity field.

    The dataset holds a single 10x10 float32 ``reflectivity`` variable on
    ``y``/``x`` coordinates spaced 1000.0 apart.
    """
    # A seeded generator yields identical values on every run, so a failure
    # that depends on the field contents can be reproduced.
    rng = np.random.default_rng(0)
    return xr.Dataset({
        'reflectivity': xr.DataArray(
            rng.standard_normal((10, 10)).astype(np.float32),
            dims=['y', 'x'],
            coords={
                'y': np.arange(10) * 1000.0,
                'x': np.arange(10) * 1000.0,
            }
        )
    })
repository): + """RadarCatalog items table should exist.""" + conn = repository.catalog._get_connection() + cursor = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='items'" + ) + assert cursor.fetchone() is not None + + def test_run_registered_in_registry(self, repository): + """Run should be registered in the root RepositoryRegistry.""" + runs = repository.registry.list_runs() + assert not runs.empty + assert repository.run_id in runs["run_id"].values + + +# ========================================================================= +# Test: Artifact Registration +# ========================================================================= + + +class TestArtifactRegistration: + """Test artifact registration.""" + + def test_register_artifact(self, repository, temp_base_dir): + """Should register artifact and return ID.""" + file_path = temp_base_dir / "KDIX" / "nexrad" / "20260211" / "test.nc" + file_path.parent.mkdir(parents=True) + file_path.touch() + + artifact_id = repository.register_artifact( + product_type=ProductType.NEXRAD_RAW, + file_path=file_path, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test", + parent_ids=[], + metadata={"test": True} + ) + + assert len(artifact_id) == 16 + + def test_register_artifact_with_scan_time(self, repository, temp_base_dir): + """Should register artifact with scan_time.""" + file_path = temp_base_dir / "KDIX" / "test_file.db" + file_path.touch() + + artifact_id = repository.register_artifact( + product_type=ProductType.CELLS_DB, + file_path=file_path, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + + assert len(artifact_id) == 16 + + def test_query_artifacts(self, repository, temp_base_dir): + """Should query registered artifacts.""" + file_path = temp_base_dir / "KDIX" / "test.nc" + file_path.touch() + + artifact_id = repository.register_artifact( + product_type=ProductType.GRIDDED_NC, + file_path=file_path, + scan_time=datetime(2026, 2, 
11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + + results = repository.query(product_type=ProductType.GRIDDED_NC) + assert len(results) == 1 + assert results[0]['artifact_id'] == artifact_id + + def test_query_by_time_range(self, repository, temp_base_dir): + """Should filter by time range.""" + file1 = temp_base_dir / "KDIX" / "file1.nc" + file2 = temp_base_dir / "KDIX" / "file2.nc" + file1.touch() + file2.touch() + + repository.register_artifact( + product_type=ProductType.GRIDDED_NC, + file_path=file1, + scan_time=datetime(2026, 2, 11, 10, 0, 0, tzinfo=UTC), + producer="test" + ) + repository.register_artifact( + product_type=ProductType.GRIDDED_NC, + file_path=file2, + scan_time=datetime(2026, 2, 11, 14, 0, 0, tzinfo=UTC), + producer="test" + ) + + results = repository.query( + product_type=ProductType.GRIDDED_NC, + time_range=( + datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + datetime(2026, 2, 11, 16, 0, 0, tzinfo=UTC) + ) + ) + assert len(results) == 1 + + def test_get_artifact(self, repository, temp_base_dir): + """Should retrieve artifact by ID with expected fields.""" + file_path = temp_base_dir / "KDIX" / "test.nc" + file_path.touch() + + artifact_id = repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file_path, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="processor" + ) + + artifact = repository.get_artifact(artifact_id) + assert artifact is not None + assert artifact['product_type'] == ProductType.ANALYSIS_NC + assert artifact['producer'] == "processor" + + def test_get_artifact_file_path_is_absolute(self, repository, temp_base_dir): + """Returned artifact file_path should be absolute.""" + file_path = temp_base_dir / "KDIX" / "test.nc" + file_path.touch() + + artifact_id = repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file_path, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + + artifact = repository.get_artifact(artifact_id) + 
assert Path(artifact['file_path']).is_absolute() + + +# ========================================================================= +# Test: Write Operations +# ========================================================================= + + +class TestWriteOperations: + """Test atomic write operations.""" + + def test_write_netcdf(self, repository, sample_dataset): + """Should write NetCDF and register artifact.""" + artifact_id = repository.write_netcdf( + ds=sample_dataset, + product_type=ProductType.ANALYSIS_NC, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + + artifact = repository.get_artifact(artifact_id) + assert artifact is not None + assert Path(artifact['file_path']).exists() + + filename = Path(artifact['file_path']).name + assert "test1234" in filename # run_id + assert "analysis" in filename + assert filename.endswith(".nc") + + def test_write_netcdf_gridded(self, repository, sample_dataset): + """Should write gridded NetCDF with correct path.""" + artifact_id = repository.write_netcdf( + ds=sample_dataset, + product_type=ProductType.GRIDDED_NC, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="loader" + ) + + artifact = repository.get_artifact(artifact_id) + file_path = Path(artifact['file_path']) + + assert "KDIX" in str(file_path) + assert "gridnc" in str(file_path) + assert "20260211" in str(file_path) + assert "gridded" in file_path.name + + def test_write_parquet(self, repository, sample_dataframe): + """Should write Parquet and register artifact.""" + artifact_id = repository.write_parquet( + df=sample_dataframe, + product_type=ProductType.CELLS_PARQUET, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + + artifact = repository.get_artifact(artifact_id) + assert artifact is not None + assert Path(artifact['file_path']).exists() + + metadata = json.loads(artifact['metadata']) + assert metadata['row_count'] == 3 + + def test_get_or_create_cells_db(self, repository): + 
"""Should create cells database.""" + artifact_id = repository.get_or_create_cells_db( + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="processor" + ) + + artifact = repository.get_artifact(artifact_id) + assert artifact is not None + assert artifact['product_type'] == ProductType.CELLS_DB + assert Path(artifact['file_path']).exists() + + def test_get_or_create_cells_db_reuse(self, repository): + """Should reuse existing cells database.""" + id1 = repository.get_or_create_cells_db( + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="processor" + ) + id2 = repository.get_or_create_cells_db( + scan_time=datetime(2026, 2, 11, 13, 0, 0, tzinfo=UTC), + producer="processor" + ) + + assert id1 == id2 + + def test_write_sqlite_table(self, repository, sample_dataframe): + """Should write DataFrame to SQLite table.""" + db_artifact_id = repository.get_or_create_cells_db( + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="processor" + ) + + repository.write_sqlite_table( + df=sample_dataframe, + table_name='cells', + artifact_id=db_artifact_id + ) + + artifact = repository.get_artifact(db_artifact_id) + with sqlite3.connect(artifact['file_path']) as conn: + df_read = pd.read_sql("SELECT * FROM cells", conn) + assert len(df_read) == 3 + + +# ========================================================================= +# Test: Data Access +# ========================================================================= + + +class TestDataAccess: + """Test data access operations.""" + + def test_open_dataset(self, repository, sample_dataset): + """Should open NetCDF as xarray Dataset.""" + artifact_id = repository.write_netcdf( + ds=sample_dataset, + product_type=ProductType.ANALYSIS_NC, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + + opened_ds = repository.open_dataset(artifact_id) + assert 'reflectivity' in opened_ds.data_vars + opened_ds.close() + + def test_open_dataset_invalid_type(self, 
repository, temp_base_dir): + """Should raise error for non-NetCDF artifact.""" + file_path = temp_base_dir / "KDIX" / "test.db" + file_path.touch() + + artifact_id = repository.register_artifact( + product_type=ProductType.CELLS_DB, + file_path=file_path, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + + with pytest.raises(ValueError, match="Cannot open as dataset"): + repository.open_dataset(artifact_id) + + def test_open_table_parquet(self, repository, sample_dataframe): + """Should open Parquet as DataFrame.""" + artifact_id = repository.write_parquet( + df=sample_dataframe, + product_type=ProductType.CELLS_PARQUET, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + + opened_df = repository.open_table(artifact_id) + assert len(opened_df) == 3 + assert 'cell_label' in opened_df.columns + + def test_open_table_sqlite(self, repository, sample_dataframe): + """Should open SQLite table as DataFrame.""" + db_artifact_id = repository.get_or_create_cells_db( + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="processor" + ) + repository.write_sqlite_table( + df=sample_dataframe, + table_name='cells', + artifact_id=db_artifact_id + ) + + opened_df = repository.open_table(db_artifact_id, table_name='cells') + assert len(opened_df) == 3 + + def test_open_nonexistent_artifact(self, repository): + """Should raise error for nonexistent artifact.""" + with pytest.raises(ValueError, match="Artifact not found"): + repository.open_dataset("nonexistent") + + +# ========================================================================= +# Test: Lifecycle +# ========================================================================= + + +class TestLifecycle: + """Test repository lifecycle.""" + + def test_finalize_run(self, repository): + """Should mark run as complete in registry.""" + repository.finalize_run("completed") + + runs = repository.registry.list_runs() + row = runs[runs["run_id"] == 
repository.run_id] + assert not row.empty + assert row.iloc[0]["status"] == "completed" + + def test_context_manager(self, temp_base_dir): + """Should work as context manager.""" + with DataRepository( + run_id="ctx12345", + base_dir=temp_base_dir, + radar="KHTX" + ) as repo: + assert repo.run_id == "ctx12345" + + def test_generate_run_id(self): + """Should generate valid run IDs.""" + run_id = DataRepository.generate_run_id("KBOX") + assert re.match( + r"^\d{4}(JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)\d{2}-\d{4}-KBOX$", + run_id + ) + + +# ========================================================================= +# Test: Path Generation +# ========================================================================= + + +class TestPathGeneration: + """Test path generation methods.""" + + def test_generate_plot_path(self, repository): + """Should generate correct plot path.""" + path = repository.generate_plot_path( + plot_type="reflectivity", + scan_time=datetime(2026, 2, 11, 12, 30, 45, tzinfo=UTC) + ) + + assert "KDIX" in str(path) + assert "plots" in str(path) + assert "20260211" in str(path) + assert "reflectivity" in path.name + assert "123045" in path.name # HHMMSS + assert "test1234" in path.name # run_id + assert path.suffix == ".png" + + +# ========================================================================= +# Test: Get Latest (PlotConsumer API) +# ========================================================================= + + +class TestGetLatest: + """Test get_latest method for PlotConsumer polling.""" + + def test_get_latest_no_artifacts(self, repository): + """Should return None when no artifacts exist.""" + result = repository.get_latest(ProductType.ANALYSIS_NC) + assert result is None + + def test_get_latest_single_artifact(self, repository, temp_base_dir): + """Should return the only artifact.""" + file_path = temp_base_dir / "KDIX" / "test.nc" + file_path.touch() + + artifact_id = repository.register_artifact( + 
product_type=ProductType.ANALYSIS_NC, + file_path=file_path, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + + result = repository.get_latest(ProductType.ANALYSIS_NC) + assert result is not None + assert result['artifact_id'] == artifact_id + + def test_get_latest_returns_most_recent(self, repository, temp_base_dir): + """Should return artifact with most recent scan_time.""" + file1 = temp_base_dir / "KDIX" / "file1.nc" + file2 = temp_base_dir / "KDIX" / "file2.nc" + file3 = temp_base_dir / "KDIX" / "file3.nc" + file1.touch() + file2.touch() + file3.touch() + + repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file1, + scan_time=datetime(2026, 2, 11, 10, 0, 0, tzinfo=UTC), + producer="test" + ) + latest_id = repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file2, + scan_time=datetime(2026, 2, 11, 14, 0, 0, tzinfo=UTC), + producer="test" + ) + repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file3, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + + result = repository.get_latest(ProductType.ANALYSIS_NC) + assert result['artifact_id'] == latest_id + + def test_get_latest_filters_by_product_type(self, repository, temp_base_dir): + """Should only return artifacts of requested type.""" + file1 = temp_base_dir / "KDIX" / "analysis.nc" + file2 = temp_base_dir / "KDIX" / "gridded.nc" + file1.touch() + file2.touch() + + analysis_id = repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file1, + scan_time=datetime(2026, 2, 11, 10, 0, 0, tzinfo=UTC), + producer="test" + ) + repository.register_artifact( + product_type=ProductType.GRIDDED_NC, + file_path=file2, + scan_time=datetime(2026, 2, 11, 14, 0, 0, tzinfo=UTC), + producer="test" + ) + + result = repository.get_latest(ProductType.ANALYSIS_NC) + assert result['artifact_id'] == analysis_id + assert result['product_type'] == 
ProductType.ANALYSIS_NC + + +class TestGetAllSince: + """Test get_all_since method for catching up missed artifacts.""" + + def test_get_all_since_no_reference(self, repository, temp_base_dir): + """Should return all artifacts when no reference provided.""" + file1 = temp_base_dir / "KDIX" / "file1.nc" + file2 = temp_base_dir / "KDIX" / "file2.nc" + file1.touch() + file2.touch() + + repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file1, + scan_time=datetime(2026, 2, 11, 10, 0, 0, tzinfo=UTC), + producer="test" + ) + repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file2, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + + results = repository.get_all_since(ProductType.ANALYSIS_NC) + assert len(results) == 2 + + def test_get_all_since_with_reference(self, repository, temp_base_dir): + """Should return only artifacts after reference.""" + file1 = temp_base_dir / "KDIX" / "file1.nc" + file2 = temp_base_dir / "KDIX" / "file2.nc" + file3 = temp_base_dir / "KDIX" / "file3.nc" + file1.touch() + file2.touch() + file3.touch() + + id1 = repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file1, + scan_time=datetime(2026, 2, 11, 10, 0, 0, tzinfo=UTC), + producer="test" + ) + id2 = repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file2, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + id3 = repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file3, + scan_time=datetime(2026, 2, 11, 14, 0, 0, tzinfo=UTC), + producer="test" + ) + + results = repository.get_all_since(ProductType.ANALYSIS_NC, since_artifact_id=id1) + assert len(results) == 2 + result_ids = [r['artifact_id'] for r in results] + assert id2 in result_ids + assert id3 in result_ids + assert id1 not in result_ids + + def test_get_all_since_returns_chronological_order(self, repository, temp_base_dir): 
+ """Should return artifacts in chronological order (oldest first).""" + file1 = temp_base_dir / "KDIX" / "file1.nc" + file2 = temp_base_dir / "KDIX" / "file2.nc" + file3 = temp_base_dir / "KDIX" / "file3.nc" + file1.touch() + file2.touch() + file3.touch() + + id1 = repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file1, + scan_time=datetime(2026, 2, 11, 10, 0, 0, tzinfo=UTC), + producer="test" + ) + id2 = repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file2, + scan_time=datetime(2026, 2, 11, 14, 0, 0, tzinfo=UTC), + producer="test" + ) + id3 = repository.register_artifact( + product_type=ProductType.ANALYSIS_NC, + file_path=file3, + scan_time=datetime(2026, 2, 11, 12, 0, 0, tzinfo=UTC), + producer="test" + ) + + results = repository.get_all_since(ProductType.ANALYSIS_NC) + # Should be in chronological order: id1 (10h), id3 (12h), id2 (14h) + assert results[0]['artifact_id'] == id1 + assert results[1]['artifact_id'] == id3 + assert results[2]['artifact_id'] == id2 diff --git a/tests/persistence/test_radar_catalog_schema.py b/tests/persistence/test_radar_catalog_schema.py new file mode 100644 index 0000000..61c47c4 --- /dev/null +++ b/tests/persistence/test_radar_catalog_schema.py @@ -0,0 +1,20 @@ +import sqlite3 + +from adapt.persistence.catalog import RadarCatalog + + +def test_radar_catalog_initializes_track_tables(tmp_path): + radar_dir = tmp_path / "KPOE" + radar_dir.mkdir() + + catalog = RadarCatalog(radar_dir) + conn = sqlite3.connect(catalog.db_path) + tables = { + row[0] + for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table'") + } + conn.close() + catalog.close() + + assert {"items", "progress", "schemas", "scans", + "cells_by_scan", "cell_events", "cell_tracks"}.issubset(tables) diff --git a/tests/persistence/test_registry_radar_location.py b/tests/persistence/test_registry_radar_location.py new file mode 100644 index 0000000..517a491 --- /dev/null +++ 
b/tests/persistence/test_registry_radar_location.py @@ -0,0 +1,28 @@ +from adapt.persistence.registry import RepositoryRegistry + + +def test_registry_ensure_radar_location_populates_missing(tmp_path): + registry = RepositoryRegistry.get_instance(tmp_path) + registry.register_radar("KPOE", lat=None, lon=None) + + lat0, lon0 = registry.get_radar_location("KPOE") + assert lat0 is None + assert lon0 is None + + registry.ensure_radar_location("KPOE", lat=31.155277252197266, lon=-92.97611236572266) + + lat1, lon1 = registry.get_radar_location("KPOE") + assert lat1 == 31.155277252197266 + assert lon1 == -92.97611236572266 + + +def test_registry_ensure_radar_location_is_idempotent(tmp_path): + registry = RepositoryRegistry.get_instance(tmp_path) + registry.register_radar("KPOE", lat=31.0, lon=-92.0) + + registry.ensure_radar_location("KPOE", lat=31.155277252197266, lon=-92.97611236572266) + + lat, lon = registry.get_radar_location("KPOE") + assert lat == 31.0 + assert lon == -92.0 + diff --git a/tests/persistence/test_repository_registers_radar_location.py b/tests/persistence/test_repository_registers_radar_location.py new file mode 100644 index 0000000..9508b3c --- /dev/null +++ b/tests/persistence/test_repository_registers_radar_location.py @@ -0,0 +1,32 @@ +import adapt.persistence.repository as repo_mod +from adapt.persistence.registry import RepositoryRegistry +from adapt.persistence.repository import DataRepository + + +def test_repository_does_not_use_external_radar_location_lookup(tmp_path, monkeypatch): + def _should_not_be_called(_radar: str): + raise AssertionError("_lookup_radar_location_pyart should not be called") + + if hasattr(repo_mod, "_lookup_radar_location_pyart"): + monkeypatch.setattr(repo_mod, "_lookup_radar_location_pyart", _should_not_be_called) + repo = DataRepository(run_id="TESTRUN", base_dir=tmp_path, radar="KPOE", config=None) + radars = repo.registry.list_radars() + row = radars[radars["radar"] == "KPOE"].iloc[0] + assert row["location_lat"] 
is None + assert row["location_lon"] is None + + +def test_repository_does_not_overwrite_existing_radar_location(tmp_path, monkeypatch): + registry = RepositoryRegistry.get_instance(tmp_path) + registry.register_radar("KPOE", lat=9.0, lon=10.0) + + def _should_not_be_called(_radar: str): + raise AssertionError("_lookup_radar_location_pyart should not be called") + + if hasattr(repo_mod, "_lookup_radar_location_pyart"): + monkeypatch.setattr(repo_mod, "_lookup_radar_location_pyart", _should_not_be_called) + repo2 = DataRepository(run_id="TESTRUN2", base_dir=tmp_path, radar="KPOE", config=None) + radars = repo2.registry.list_radars() + row = radars[radars["radar"] == "KPOE"].iloc[0] + assert float(row["location_lat"]) == 9.0 + assert float(row["location_lon"]) == 10.0 From bf4439fd0b766d437f2598be7af38361aab4b02b Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Thu, 14 May 2026 15:07:17 -0500 Subject: [PATCH 10/14] ADD:(tests) validation contract assert primitives and check bound wrappers --- tests/validation/__init__.py | 1 + tests/validation/test_check_functions.py | 305 +++++++++++++++++++++++ tests/validation/test_contracts.py | 261 +++++++++++++++++++ 3 files changed, 567 insertions(+) create mode 100644 tests/validation/__init__.py create mode 100644 tests/validation/test_check_functions.py create mode 100644 tests/validation/test_contracts.py diff --git a/tests/validation/__init__.py b/tests/validation/__init__.py new file mode 100644 index 0000000..24490cb --- /dev/null +++ b/tests/validation/__init__.py @@ -0,0 +1 @@ +"""Conftest for contracts tests.""" diff --git a/tests/validation/test_check_functions.py b/tests/validation/test_check_functions.py new file mode 100644 index 0000000..269bf83 --- /dev/null +++ b/tests/validation/test_check_functions.py @@ -0,0 +1,305 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Tests for the bound check_* contract wrappers. 
+ +The assert_* primitives are tested in test_contracts.py. +These tests verify the check_* wrappers: pass on valid data, raise on invalid, +and preserve the same ContractViolation semantics. +""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from adapt.contracts import ( + ContractViolation, + check_cell_adjacency, + check_cell_events, + check_cell_stats, + check_grid_ds_2d, + check_projected_ds, + check_segmented_ds, + check_time_normalized, + check_tracked_cells, +) + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _valid_grid_ds(): + return xr.Dataset( + {"reflectivity": (("y", "x"), np.ones((4, 4), dtype=np.float32))}, + coords={"x": range(4), "y": range(4)}, + ) + + +def _valid_segmented_ds(): + labels = np.array([[0, 0, 1, 1], [0, 1, 1, 0]], dtype=np.int32) + return xr.Dataset( + {"cell_labels": (("y", "x"), labels)}, + coords={"x": range(4), "y": range(2)}, + ) + + +def _valid_projected_ds(): + return xr.Dataset( + { + "heading_x": (("y", "x"), np.ones((4, 4))), + "heading_y": (("y", "x"), np.zeros((4, 4))), + }, + coords={"x": range(4), "y": range(4)}, + ) + + +def _valid_cell_stats(): + return pd.DataFrame({ + "cell_label": [1, 2], + "cell_area_sqkm": [1.5, 2.5], + "time": pd.to_datetime(["2025-01-01", "2025-01-01"]), + "time_volume_start": ["2025-01-01T00:00:00+00:00"] * 2, + "cell_centroid_mass_lat": [35.0, 35.1], + "cell_centroid_mass_lon": [-97.0, -97.1], + "radar_reflectivity_max": [45.0, 50.0], + "radar_differential_reflectivity_max": [1.0, 1.5], + "area_40dbz_km2": [1.0, 2.0], + }) + + +def _valid_adjacency(): + return pd.DataFrame({ + "time": pd.to_datetime(["2025-01-01"]), + "cell_label_a": [1], + "cell_label_b": [2], + "touching_boundary_pixels": [5], + }) + + +def _valid_tracked_cells(): + return pd.DataFrame({ + "time": 
pd.to_datetime(["2025-01-01"]), + "cell_label": [1], + "cell_uid": ["abc-001"], + "area": [4.0], + "centroid_x": [2.5], + "centroid_y": [2.5], + "mean_reflectivity": [40.0], + "max_reflectivity": [45.0], + "core_area": [2.0], + }) + + +def _valid_cell_events(): + return pd.DataFrame({ + "time": pd.to_datetime(["2025-01-01"]), + "event_type": ["INITIATION"], + "source_cell_uid": [None], + "target_cell_uid": ["abc-001"], + "source_cell_label": [None], + "target_cell_label": [1], + "cost": [0.0], + "is_dominant": [True], + "event_group_id": [1], + }) + + +# --------------------------------------------------------------------------- +# check_grid_ds_2d +# --------------------------------------------------------------------------- + +class TestCheckGridDs2d: + def test_passes_on_valid_ds(self): + check_grid_ds_2d(_valid_grid_ds()) # must not raise + + def test_fails_on_missing_x(self): + ds = xr.Dataset( + {"reflectivity": (("y", "x"), np.ones((4, 4)))}, + coords={"y": range(4)}, + ) + with pytest.raises(ContractViolation): + check_grid_ds_2d(ds) + + def test_fails_on_missing_reflectivity(self): + ds = xr.Dataset(coords={"x": range(4), "y": range(4)}) + with pytest.raises(ContractViolation): + check_grid_ds_2d(ds) + + +# --------------------------------------------------------------------------- +# check_segmented_ds +# --------------------------------------------------------------------------- + +class TestCheckSegmentedDs: + def test_passes_on_valid_ds(self): + check_segmented_ds(_valid_segmented_ds()) + + def test_fails_on_float_labels(self): + labels = np.ones((4, 4), dtype=np.float32) + ds = xr.Dataset( + {"cell_labels": (("y", "x"), labels)}, + coords={"x": range(4), "y": range(4)}, + ) + with pytest.raises(ContractViolation): + check_segmented_ds(ds) + + def test_fails_on_missing_labels_var(self): + ds = xr.Dataset(coords={"x": range(4), "y": range(4)}) + with pytest.raises(ContractViolation): + check_segmented_ds(ds) + + +# 
--------------------------------------------------------------------------- +# check_projected_ds +# --------------------------------------------------------------------------- + +class TestCheckProjectedDs: + def test_passes_on_valid_ds(self): + check_projected_ds(_valid_projected_ds()) + + def test_fails_on_missing_heading_x(self): + ds = xr.Dataset( + {"heading_y": (("y", "x"), np.zeros((4, 4)))}, + coords={"x": range(4), "y": range(4)}, + ) + with pytest.raises(ContractViolation): + check_projected_ds(ds) + + +# --------------------------------------------------------------------------- +# check_cell_stats +# --------------------------------------------------------------------------- + +class TestCheckCellStats: + def test_passes_on_valid_df(self): + check_cell_stats(_valid_cell_stats()) + + def test_passes_on_empty_df(self): + empty = pd.DataFrame(columns=[ + "cell_label", "cell_area_sqkm", "time", "time_volume_start", + "cell_centroid_mass_lat", "cell_centroid_mass_lon", + "radar_reflectivity_max", "radar_differential_reflectivity_max", + "area_40dbz_km2", + ]) + check_cell_stats(empty) + + def test_fails_on_missing_column(self): + df = _valid_cell_stats().drop(columns=["cell_label"]) + with pytest.raises(ContractViolation): + check_cell_stats(df) + + def test_fails_on_zero_cell_label(self): + df = _valid_cell_stats().copy() + df.loc[0, "cell_label"] = 0 + with pytest.raises(ContractViolation): + check_cell_stats(df) + + +# --------------------------------------------------------------------------- +# check_cell_adjacency +# --------------------------------------------------------------------------- + +class TestCheckCellAdjacency: + def test_passes_on_valid_df(self): + check_cell_adjacency(_valid_adjacency()) + + def test_passes_on_empty_df(self): + empty = pd.DataFrame(columns=[ + "time", "cell_label_a", "cell_label_b", "touching_boundary_pixels" + ]) + check_cell_adjacency(empty) + + def test_fails_on_missing_column(self): + df = 
_valid_adjacency().drop(columns=["touching_boundary_pixels"]) + with pytest.raises(ContractViolation): + check_cell_adjacency(df) + + def test_fails_on_wrong_label_order(self): + df = _valid_adjacency().copy() + df["cell_label_a"], df["cell_label_b"] = df["cell_label_b"], df["cell_label_a"] + with pytest.raises(ContractViolation): + check_cell_adjacency(df) + + +# --------------------------------------------------------------------------- +# check_tracked_cells +# --------------------------------------------------------------------------- + +class TestCheckTrackedCells: + def test_passes_on_valid_df(self): + check_tracked_cells(_valid_tracked_cells()) + + def test_passes_on_empty_df(self): + # check_tracked_cells skips validation on empty frames + check_tracked_cells(pd.DataFrame()) + + def test_fails_on_missing_column(self): + df = _valid_tracked_cells().drop(columns=["cell_uid"]) + with pytest.raises(ContractViolation): + check_tracked_cells(df) + + def test_fails_on_zero_cell_label(self): + df = _valid_tracked_cells().copy() + df["cell_label"] = 0 + with pytest.raises(ContractViolation): + check_tracked_cells(df) + + def test_fails_on_null_uid(self): + df = _valid_tracked_cells().copy() + df["cell_uid"] = None + with pytest.raises(ContractViolation): + check_tracked_cells(df) + + +# --------------------------------------------------------------------------- +# check_cell_events +# --------------------------------------------------------------------------- + +class TestCheckCellEvents: + def test_passes_on_valid_df(self): + check_cell_events(_valid_cell_events()) + + def test_passes_on_empty_df(self): + check_cell_events(pd.DataFrame()) + + def test_fails_on_missing_column(self): + df = _valid_cell_events().drop(columns=["event_type"]) + with pytest.raises(ContractViolation): + check_cell_events(df) + + def test_fails_on_invalid_event_type(self): + df = _valid_cell_events().copy() + df["event_type"] = "UNKNOWN" + with pytest.raises(ContractViolation): + 
check_cell_events(df) + + +# --------------------------------------------------------------------------- +# check_time_normalized +# --------------------------------------------------------------------------- + +class TestCheckTimeNormalized: + def test_passes_with_numpy_datetime64_coord(self): + ds = xr.Dataset(coords={"time": np.datetime64("2025-01-01T12:00:00")}) + check_time_normalized(ds) # must not raise + + def test_passes_with_no_time_coord_when_attr_present(self): + ds = xr.Dataset(attrs={"time": "2025-01-01"}) + check_time_normalized(ds) + + def test_fails_without_time(self): + ds = xr.Dataset() + with pytest.raises(ContractViolation): + check_time_normalized(ds) + + def test_check_is_same_as_assert(self): + """check_time_normalized delegates to assert_time_normalized.""" + from adapt.contracts import assert_time_normalized + ds = xr.Dataset(coords={"time": np.datetime64("2025-01-01T12:00:00")}) + # Both should pass without raising + assert_time_normalized(ds) + check_time_normalized(ds) diff --git a/tests/validation/test_contracts.py b/tests/validation/test_contracts.py new file mode 100644 index 0000000..45ba00e --- /dev/null +++ b/tests/validation/test_contracts.py @@ -0,0 +1,261 @@ +"""Tests for pipeline contracts. + +These tests verify that contracts are enforced at stage boundaries. +They test contract violations directly, without defensive logic downstream. 
+""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +pytestmark = pytest.mark.unit + +from adapt.contracts import ( # noqa: E402 + ContractViolation, + assert_analysis_output, + assert_gridded, + assert_projected, + assert_segmented, +) + + +class TestGridContract: + """Test grid stage contract.""" + + def test_grid_contract_passes_with_valid_dataset(self): + """Grid contract passes when x, y, and reflectivity exist.""" + ds = xr.Dataset( + {"reflectivity": (("y", "x"), np.ones((4, 4)))}, + coords={"x": range(4), "y": range(4)}, + ) + # Should not raise + assert_gridded(ds, "reflectivity") + + def test_grid_contract_fails_without_x(self): + """Grid contract fails when x coordinate is missing.""" + ds = xr.Dataset( + {"reflectivity": (("y", "x"), np.ones((4, 4)))}, + coords={"y": range(4)}, + ) + with pytest.raises(ContractViolation, match="missing 'x'"): + assert_gridded(ds, "reflectivity") + + def test_grid_contract_fails_without_y(self): + """Grid contract fails when y coordinate is missing.""" + ds = xr.Dataset( + {"reflectivity": (("y", "x"), np.ones((4, 4)))}, + coords={"x": range(4)}, + ) + with pytest.raises(ContractViolation, match="missing 'y'"): + assert_gridded(ds, "reflectivity") + + def test_grid_contract_fails_without_reflectivity(self): + """Grid contract fails when reflectivity variable is missing.""" + ds = xr.Dataset( + coords={"x": range(4), "y": range(4)}, + ) + with pytest.raises(ContractViolation, match="missing 'reflectivity'"): + assert_gridded(ds, "reflectivity") + + def test_grid_contract_fails_with_wrong_dims(self): + """Grid contract fails when reflectivity is 3D instead of 2D.""" + ds = xr.Dataset( + {"reflectivity": (("z", "y", "x"), np.ones((2, 4, 4)))}, + coords={"x": range(4), "y": range(4), "z": range(2)}, + ) + with pytest.raises(ContractViolation, match="3 dims"): + assert_gridded(ds, "reflectivity") + + +class TestSegmentationContract: + """Test segmentation stage contract.""" + + def 
test_segmentation_contract_passes_with_valid_labels(self): + """Segmentation contract passes with integer labels.""" + labels = np.array([[0, 0, 1, 1], [0, 1, 1, 0]], dtype=np.int32) + ds = xr.Dataset( + {"cell_labels": (("y", "x"), labels)}, + coords={"x": range(4), "y": range(2)}, + ) + # Should not raise + assert_segmented(ds, "cell_labels") + + def test_segmentation_contract_fails_without_labels(self): + """Segmentation contract fails when labels variable is missing.""" + ds = xr.Dataset(coords={"x": range(4), "y": range(2)}) + with pytest.raises(ContractViolation, match="not found"): + assert_segmented(ds, "cell_labels") + + def test_segmentation_contract_fails_with_float_labels(self): + """Segmentation contract fails when labels are float instead of integer.""" + labels = np.array([[0, 0, 1, 1], [0, 1, 1, 0]], dtype=np.float32) + ds = xr.Dataset( + {"cell_labels": (("y", "x"), labels)}, + coords={"x": range(4), "y": range(2)}, + ) + with pytest.raises(ContractViolation, match="dtype"): + assert_segmented(ds, "cell_labels") + + def test_segmentation_contract_fails_with_negative_labels(self): + """Segmentation contract fails when labels contain negative values.""" + labels = np.array([[0, -1, 1, 1], [0, 1, 1, 0]], dtype=np.int32) + ds = xr.Dataset( + {"cell_labels": (("y", "x"), labels)}, + coords={"x": range(4), "y": range(2)}, + ) + with pytest.raises(ContractViolation, match="negative"): + assert_segmented(ds, "cell_labels") + + def test_segmentation_contract_fails_with_wrong_dims(self): + """Segmentation contract fails when labels are 3D instead of 2D.""" + labels = np.array([[[0, 1], [1, 0]]], dtype=np.int32) + ds = xr.Dataset( + {"cell_labels": (("z", "y", "x"), labels)}, + coords={"x": range(2), "y": range(2), "z": range(1)}, + ) + with pytest.raises(ContractViolation, match="3 dims"): + assert_segmented(ds, "cell_labels") + + +class TestProjectionContract: + """Test projection stage contract.""" + + def 
test_projection_contract_passes_with_valid_flow(self): + """Projection contract passes when flow fields exist.""" + ds = xr.Dataset( + { + "heading_x": (("y", "x"), np.ones((4, 4))), + "heading_y": (("y", "x"), np.zeros((4, 4))), + }, + coords={"x": range(4), "y": range(4)}, + ) + # Should not raise + assert_projected(ds, max_steps=5) + + def test_projection_contract_fails_without_flow_u(self): + """Projection contract fails when heading_x is missing.""" + ds = xr.Dataset( + {"heading_y": (("y", "x"), np.zeros((4, 4)))}, + coords={"x": range(4), "y": range(4)}, + ) + with pytest.raises(ContractViolation, match="heading_x"): + assert_projected(ds) + + def test_projection_contract_fails_without_flow_v(self): + """Projection contract fails when heading_y is missing.""" + ds = xr.Dataset( + {"heading_x": (("y", "x"), np.ones((4, 4)))}, + coords={"x": range(4), "y": range(4)}, + ) + with pytest.raises(ContractViolation, match="heading_y"): + assert_projected(ds) + + def test_projection_contract_fails_with_too_many_steps(self): + """Projection contract fails when projections exceed max_steps.""" + projections = np.zeros((7, 4, 4), dtype=np.int32) # 7 steps > 5 max + ds = xr.Dataset( + { + "heading_x": (("y", "x"), np.ones((4, 4))), + "heading_y": (("y", "x"), np.zeros((4, 4))), + "cell_projections": (("frame_offset", "y", "x"), projections), + }, + coords={"x": range(4), "y": range(4)}, + ) + with pytest.raises(ContractViolation, match="expected 6"): + assert_projected(ds, max_steps=5) + + +class TestAnalysisContract: + """Test analysis stage contract.""" + + def test_analysis_contract_passes_with_valid_dataframe(self): + """Analysis contract passes with valid output DataFrame.""" + df = pd.DataFrame({ + "cell_label": [1, 2], + "cell_area_sqkm": [1.5, 2.5], + "time": pd.to_datetime(["2025-01-01", "2025-01-01"]), + "time_volume_start": ["2025-01-01T00:00:00+00:00", "2025-01-01T00:00:00+00:00"], + "cell_centroid_mass_lat": [35.0, 35.1], + "cell_centroid_mass_lon": [-97.0, 
-97.1], + "radar_reflectivity_max": [45.0, 50.0], + "radar_differential_reflectivity_max": [1.0, 1.5], + "area_40dbz_km2": [1.0, 2.0], + }) + # Should not raise + assert_analysis_output(df) + + def test_analysis_contract_passes_with_empty_dataframe(self): + """Analysis contract passes with empty DataFrame (no cells detected).""" + df = pd.DataFrame({ + "cell_label": [], + "cell_area_sqkm": [], + "time": [], + "time_volume_start": [], + "cell_centroid_mass_lat": [], + "cell_centroid_mass_lon": [], + "radar_reflectivity_max": [], + "radar_differential_reflectivity_max": [], + "area_40dbz_km2": [], + }) + # Should not raise + assert_analysis_output(df) + + def test_analysis_contract_fails_with_missing_cell_label(self): + """Analysis contract fails when cell_label column is missing.""" + df = pd.DataFrame({ + "cell_area_sqkm": [1.5], + "time": pd.to_datetime(["2025-01-01"]), + }) + with pytest.raises(ContractViolation, match="cell_label"): + assert_analysis_output(df) + + def test_analysis_contract_fails_with_zero_cell_label(self): + """Analysis contract fails when cell_label is 0 or negative.""" + df = pd.DataFrame({ + "cell_label": [0, 1], + "cell_area_sqkm": [1.5, 2.5], + "time": pd.to_datetime(["2025-01-01", "2025-01-01"]), + "time_volume_start": ["2025-01-01T00:00:00+00:00", "2025-01-01T00:00:00+00:00"], + "cell_centroid_mass_lat": [35.0, 35.1], + "cell_centroid_mass_lon": [-97.0, -97.1], + "radar_reflectivity_max": [45.0, 50.0], + "radar_differential_reflectivity_max": [1.0, 1.5], + "area_40dbz_km2": [1.0, 2.0], + }) + with pytest.raises(ContractViolation, match="cell_label must be > 0"): + assert_analysis_output(df) + + def test_analysis_contract_fails_with_insufficient_rows(self): + """Analysis contract fails when row count below minimum.""" + df = pd.DataFrame({ + "cell_label": [1], + "cell_area_sqkm": [1.5], + "time": pd.to_datetime(["2025-01-01"]), + "time_volume_start": ["2025-01-01T00:00:00+00:00"], + "cell_centroid_mass_lat": [35.0], + 
"cell_centroid_mass_lon": [-97.0], + "radar_reflectivity_max": [45.0], + "radar_differential_reflectivity_max": [1.0], + "area_40dbz_km2": [1.0], + }) + with pytest.raises(ContractViolation, match="expected >= 5"): + assert_analysis_output(df, min_expected_rows=5) + + +class TestContractViolationException: + """Test ContractViolation exception type and semantics.""" + + def test_contract_violation_is_runtime_error(self): + """ContractViolation is a RuntimeError subclass.""" + assert issubclass(ContractViolation, RuntimeError) + + def test_contract_violation_has_clear_message(self): + """ContractViolation carries clear error message.""" + with pytest.raises(ContractViolation) as exc_info: + ds = xr.Dataset() + assert_segmented(ds, "cell_labels") + + msg = str(exc_info.value) + assert "contract violated" in msg.lower() + assert "cell_labels" in msg From 69715d33087efa0163d647e8feeace07619b1d1a Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Thu, 14 May 2026 15:08:11 -0500 Subject: [PATCH 11/14] ADD:(tests) time normalization and edge cases (detection, tracker, contracts, executor) --- tests/unit/__init__.py | 0 tests/unit/test_edge_cases.py | 283 ++++++++++++++++++++++++++++++++++ tests/unit/test_utils_time.py | 69 +++++++++ 3 files changed, 352 insertions(+) create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_edge_cases.py create mode 100644 tests/unit/test_utils_time.py diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_edge_cases.py b/tests/unit/test_edge_cases.py new file mode 100644 index 0000000..5e5ddbf --- /dev/null +++ b/tests/unit/test_edge_cases.py @@ -0,0 +1,283 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Edge cases, chaos, and adversarial tests. + +Tests in this file probe boundary conditions, extreme inputs, and unusual +but valid combinations. All use synthetic data — no IO, no real files. 
+""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from adapt.contracts import ContractViolation + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _refl_ds(data: np.ndarray) -> xr.Dataset: + H, W = data.shape + return xr.Dataset( + {"reflectivity": (("y", "x"), data.astype(np.float32))}, + coords={"y": np.arange(H) * 1000.0, "x": np.arange(W) * 1000.0}, + attrs={"z_level_m": 2000}, + ) + + +# --------------------------------------------------------------------------- +# Detection edge cases +# --------------------------------------------------------------------------- + +class TestDetectionEdgeCases: + def test_single_pixel_cell_filtered_by_min_size(self, make_detection_config): + """A single-pixel cell below min_cellsize_gridpoint is discarded.""" + from adapt.modules.detection.module import RadarCellSegmenter + data = np.zeros((8, 8), dtype=np.float32) + data[4, 4] = 50.0 # one pixel at 50 dBZ + ds = _refl_ds(data) + config = make_detection_config(threshold=35, min_cellsize_gridpoint=2) + result = RadarCellSegmenter(config).segment(ds) + assert result["cell_labels"].values.max() == 0 # filtered out + + def test_entire_domain_above_threshold_is_one_cell(self, make_detection_config): + """A single contiguous blob above threshold produces exactly one cell label.""" + from adapt.configuration.schemas.user import UserSegmenterConfig + from adapt.modules.detection.module import RadarCellSegmenter + data = np.zeros((8, 8), dtype=np.float32) + data[2:6, 2:6] = 50.0 # 4×4 = 16 pixels, clearly above any min_size + ds = _refl_ds(data) + config = make_detection_config( + threshold=35, + min_cellsize_gridpoint=1, + segmenter=UserSegmenterConfig(filter_by_size=False), + ) + result = RadarCellSegmenter(config).segment(ds) + labels = result["cell_labels"].values + assert labels.max() 
== 1 # exactly one connected component + + def test_extreme_reflectivity_values_do_not_crash(self, make_detection_config): + """75 dBZ (extreme hail) is handled without error.""" + from adapt.modules.detection.module import RadarCellSegmenter + data = np.zeros((10, 10), dtype=np.float32) + data[2:6, 2:6] = 75.0 # 4×4 = 16 pixels — well above any min_size threshold + ds = _refl_ds(data) + config = make_detection_config(threshold=40) + result = RadarCellSegmenter(config).segment(ds) + assert result["cell_labels"].values.max() >= 1 + + def test_all_nan_reflectivity_returns_no_cells(self, make_detection_config): + """NaN reflectivity should not produce any labelled cells.""" + from adapt.modules.detection.module import RadarCellSegmenter + data = np.full((6, 6), np.nan, dtype=np.float32) + ds = _refl_ds(data) + config = make_detection_config(threshold=40) + result = RadarCellSegmenter(config).segment(ds) + # NaN is below threshold — no cells + assert result["cell_labels"].values.max() == 0 + + +# --------------------------------------------------------------------------- +# Tracker edge cases +# --------------------------------------------------------------------------- + +class TestTrackerEdgeCases: + @pytest.fixture + def tracker(self, tracking_module_config): + from adapt.modules.tracking.module import RadarCellTracker + return RadarCellTracker(tracking_module_config) + + def _make_ds(self, labels, time): + H, W = labels.shape + refl = np.where(labels > 0, 45.0, 0.0).astype(np.float32) + proj = np.stack([labels.astype(np.int32)], axis=0) + ds = xr.Dataset( + { + "cell_labels": (["y", "x"], labels.astype(np.int32)), + "reflectivity": (["y", "x"], refl), + "cell_projections": (["frame_offset", "y", "x"], proj), + "heading_x": (["y", "x"], np.zeros_like(labels, dtype=np.float32)), + "heading_y": (["y", "x"], np.zeros_like(labels, dtype=np.float32)), + }, + coords={"y": np.arange(H) * 1000.0, "x": np.arange(W) * 1000.0, + "frame_offset": [0], "time": time}, + ) + return 
ds + + def _stats(self, label, time, cx=3.0, cy=3.0, area=4.0): + return pd.DataFrame([{ + "time": time, "time_volume_start": time, + "cell_label": label, "cell_area_sqkm": area, "area_40dbz_km2": area, + "cell_centroid_geom_x": cx, "cell_centroid_geom_y": cy, + "cell_centroid_mass_lat": 35.0, "cell_centroid_mass_lon": -97.0, + "radar_reflectivity_mean": 42.0, "radar_reflectivity_max": 50.0, + "radar_differential_reflectivity_max": 1.5, + }]) + + def test_tracker_uid_is_deterministic(self, tracking_module_config): + """Same input frames produce the same cell_uid on every run.""" + from adapt.modules.tracking.module import RadarCellTracker + t1 = np.datetime64("2025-01-01T12:00:00") + labels = np.zeros((6, 6), dtype=np.int32) + labels[2:4, 2:4] = 1 + ds1 = self._make_ds(labels, t1) + stats1 = self._stats(1, t1) + + tracker_a = RadarCellTracker(tracking_module_config) + tracker_b = RadarCellTracker(tracking_module_config) + + tracked_a, _ = tracker_a.track(ds1, stats1) + tracked_b, _ = tracker_b.track(ds1, stats1) + assert tracked_a.iloc[0]["cell_uid"] == tracked_b.iloc[0]["cell_uid"] + + def test_tracker_handles_large_time_gap(self, tracker): + """Two frames 60 min apart are processed without error.""" + t1 = np.datetime64("2025-01-01T11:00:00") + t2 = np.datetime64("2025-01-01T12:00:00") # 60 min gap + labels = np.zeros((6, 6), dtype=np.int32) + labels[2:4, 2:4] = 1 + ds1 = self._make_ds(labels, t1) + ds2 = self._make_ds(labels, t2) + stats1 = self._stats(1, t1) + stats2 = self._stats(1, t2) + tracked1, events1 = tracker.track(ds1, stats1) + tracked2, events2 = tracker.track(ds2, stats2) + # Both frames produce valid tracked output — gap handling does not crash + assert not tracked1.empty + assert not tracked2.empty + # Either the track continues with the same uid, or a new track is initiated + assert "cell_uid" in tracked1.columns + assert "cell_uid" in tracked2.columns + + +# --------------------------------------------------------------------------- +# Contract 
adversarial tests +# --------------------------------------------------------------------------- + +class TestContractAdversarial: + def test_contract_rejects_non_integer_labels(self): + """Float labels must raise ContractViolation, not silently pass.""" + from adapt.contracts import check_segmented_ds + labels = np.array([[0.0, 0.5, 1.0, 1.0]], dtype=np.float32) + ds = xr.Dataset( + {"cell_labels": (("y", "x"), labels)}, + coords={"x": range(4), "y": range(1)}, + ) + with pytest.raises(ContractViolation): + check_segmented_ds(ds) + + def test_contract_rejects_negative_cell_labels(self): + """Negative label values must raise ContractViolation.""" + from adapt.contracts import assert_segmented + labels = np.array([[-1, 0, 1]], dtype=np.int32) + ds = xr.Dataset( + {"cell_labels": (("y", "x"), labels)}, + coords={"x": range(3), "y": range(1)}, + ) + with pytest.raises(ContractViolation): + assert_segmented(ds, "cell_labels") + + def test_contract_rejects_3d_grid(self): + """3D reflectivity field must raise — contract expects 2D.""" + from adapt.contracts import assert_gridded + data = np.ones((3, 4, 4), dtype=np.float32) + ds = xr.Dataset( + {"reflectivity": (("z", "y", "x"), data)}, + coords={"x": range(4), "y": range(4), "z": range(3)}, + ) + with pytest.raises(ContractViolation): + assert_gridded(ds, "reflectivity") + + def test_contract_rejects_invalid_event_type(self): + """Unknown event_type string in cell events must raise.""" + from adapt.contracts import check_cell_events + df = pd.DataFrame({ + "time": pd.to_datetime(["2025-01-01"]), + "event_type": ["ALIEN_STORM"], # not a valid type + "source_cell_uid": [None], + "target_cell_uid": ["x"], + "source_cell_label": [None], + "target_cell_label": [1], + "cost": [0.0], + "is_dominant": [True], + "event_group_id": [1], + }) + with pytest.raises(ContractViolation): + check_cell_events(df) + + def test_tracked_cells_rejects_null_uid(self): + """Null cell_uid in tracked_cells must raise ContractViolation.""" + from 
adapt.contracts import check_tracked_cells + df = pd.DataFrame({ + "time": pd.to_datetime(["2025-01-01"]), + "cell_label": [1], + "cell_uid": [None], + "area": [4.0], + "centroid_x": [2.5], + "centroid_y": [2.5], + "mean_reflectivity": [40.0], + "max_reflectivity": [45.0], + "core_area": [2.0], + }) + with pytest.raises(ContractViolation): + check_tracked_cells(df) + + +# --------------------------------------------------------------------------- +# Execution graph adversarial tests +# --------------------------------------------------------------------------- + +class TestExecutorAdversarial: + def test_executor_extra_keys_in_initial_context_are_ignored(self): + """Additional keys not declared as inputs are silently passed through.""" + from adapt.execution.graph.builder import GraphBuilder + from adapt.execution.graph.executor import GraphExecutor + from adapt.modules.base import BaseModule + + class Sink(BaseModule): + name = "sink" + inputs = ["x"] + outputs = [] + def run(self, ctx): return {} + + nodes = GraphBuilder([Sink()]).build() + result = GraphExecutor(nodes).run({"x": 1, "unexpected": 99}) + assert result["unexpected"] == 99 + + def test_builder_with_zero_modules_returns_empty_list(self): + """Building a graph with no modules produces an empty node list.""" + from adapt.execution.graph.builder import GraphBuilder + nodes = GraphBuilder([]).build() + assert nodes == [] + + def test_executor_with_empty_graph_returns_context_unchanged(self): + """Running an executor with no nodes returns the initial context.""" + from adapt.execution.graph.executor import GraphExecutor + result = GraphExecutor([]).run({"key": "value"}) + assert result == {"key": "value"} + + +# --------------------------------------------------------------------------- +# Utils time adversarial tests +# --------------------------------------------------------------------------- + +class TestUtilsTimeAdversarial: + def test_numpy_datetime64_scalar_unwraps_to_date(self): + """np.datetime64 
scalar is unwrapped to a Python date via .item().""" + from datetime import date + + from adapt.utils.time import normalize_time_scalar + result = normalize_time_scalar(np.datetime64("2025-01-01")) + assert isinstance(result, date) + + def test_object_without_item_passthrough(self): + """Non-numpy objects without .item() are returned unchanged.""" + from adapt.utils.time import normalize_time_scalar + obj = {"a": 1} # dict — no .item() + result = normalize_time_scalar(obj) + assert result is obj diff --git a/tests/unit/test_utils_time.py b/tests/unit/test_utils_time.py new file mode 100644 index 0000000..7da1662 --- /dev/null +++ b/tests/unit/test_utils_time.py @@ -0,0 +1,69 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""Unit tests for adapt.utils.time.normalize_time_scalar. + +All tests use synthetic numpy/Python scalars — no radar files, no IO. +""" + +from datetime import UTC, datetime + +import numpy as np +import pytest + +from adapt.utils.time import normalize_time_scalar + +pytestmark = pytest.mark.unit + + +class TestNormalizeTimeScalar: + def test_python_datetime_passthrough(self): + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=UTC) + result = normalize_time_scalar(dt) + assert result == dt + + def test_numpy_datetime64_unwrapped_to_python(self): + val = np.datetime64("2025-06-15T08:00:00") + result = normalize_time_scalar(val) + # numpy datetime64.item() returns a Python datetime + assert isinstance(result, datetime) + + def test_single_element_array_unwrapped(self): + arr = np.array(["2025-03-01T06:00:00"], dtype="datetime64[s]") + result = normalize_time_scalar(arr) + assert not isinstance(result, np.ndarray) + + def test_size_1_array_with_ndim_gt_1(self): + arr = np.array([["2025-01-01T00:00:00"]], dtype="datetime64[s]") + result = normalize_time_scalar(arr) + assert not isinstance(result, np.ndarray) + + def test_nat_returns_none(self): + """NaT.item() returns None; normalize_time_scalar propagates that.""" 
+ result = normalize_time_scalar(np.datetime64("NaT")) + assert result is None + + def test_plain_integer_passthrough(self): + result = normalize_time_scalar(42) + assert result == 42 + + def test_cftime_converted_to_datetime(self): + """cftime objects are converted to Python datetime with UTC timezone.""" + pytest.importorskip("cftime") + import cftime + cf = cftime.DatetimeGregorian(2025, 6, 15, 10, 30, 0) + result = normalize_time_scalar(cf) + assert isinstance(result, datetime) + assert result.year == 2025 + assert result.month == 6 + assert result.day == 15 + assert result.hour == 10 + assert result.minute == 30 + + def test_numpy_scalar_without_item_method(self): + """Objects without .item() are returned as-is.""" + class FakeScalar: + pass + obj = FakeScalar() + result = normalize_time_scalar(obj) + assert result is obj From 4013f1b76b2a8ac50ddbbd364f0d7d37380cd175 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Thu, 14 May 2026 15:09:15 -0500 Subject: [PATCH 12/14] ADD:(tests) story: detection and tracking --- tests/stories/__init__.py | 0 tests/stories/test_scientist_stories.py | 188 ++++++++++++++++++++++++ 2 files changed, 188 insertions(+) create mode 100644 tests/stories/__init__.py create mode 100644 tests/stories/test_scientist_stories.py diff --git a/tests/stories/__init__.py b/tests/stories/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/stories/test_scientist_stories.py b/tests/stories/test_scientist_stories.py new file mode 100644 index 0000000..020eb49 --- /dev/null +++ b/tests/stories/test_scientist_stories.py @@ -0,0 +1,188 @@ +# Copyright © 2026, UChicago Argonne, LLC +# See LICENSE for terms and disclaimer. + +"""User story tests — scientist perspective. + +These tests describe end-user scientific outcomes rather than implementation +details. All use synthetic numpy/xarray data; no real NEXRAD files, no IO. + +Pattern: Given / When / Then, one scenario per test. 
+""" + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +pytestmark = pytest.mark.unit + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +def _reflectivity_ds(values_2d: np.ndarray) -> xr.Dataset: + H, W = values_2d.shape + return xr.Dataset( + {"reflectivity": (("y", "x"), values_2d.astype(np.float32))}, + coords={"y": np.arange(H) * 1000.0, "x": np.arange(W) * 1000.0}, + attrs={"z_level_m": 2000}, + ) + + +def _labeled_ds(labels: np.ndarray, time=None) -> xr.Dataset: + """Dataset with cell_labels, reflectivity, projections, headings.""" + H, W = labels.shape + refl = np.zeros((H, W), dtype=np.float32) + refl[labels > 0] = 45.0 + projections = np.stack([labels.astype(np.int32)], axis=0) + ds = xr.Dataset( + { + "cell_labels": (["y", "x"], labels.astype(np.int32)), + "reflectivity": (["y", "x"], refl), + "cell_projections": (["frame_offset", "y", "x"], projections), + "heading_x": (["y", "x"], np.zeros((H, W), dtype=np.float32)), + "heading_y": (["y", "x"], np.zeros((H, W), dtype=np.float32)), + }, + coords={ + "y": np.arange(H) * 1000.0, + "x": np.arange(W) * 1000.0, + "frame_offset": [0], + }, + ) + if time is None: + time = np.datetime64("2025-06-15T12:00:00") + return ds.assign_coords(time=time) + + +def _cell_stats_row(label: int, time, *, cx=5.0, cy=5.0, area=4.0) -> dict: + return { + "time": time, + "time_volume_start": time, + "cell_label": label, + "cell_area_sqkm": area, + "area_40dbz_km2": area, + "cell_centroid_geom_x": cx, + "cell_centroid_geom_y": cy, + "cell_centroid_mass_lat": 35.0, + "cell_centroid_mass_lon": -97.0, + "radar_reflectivity_mean": 42.0, + "radar_reflectivity_max": 50.0, + "radar_differential_reflectivity_max": 1.5, + } + + +# --------------------------------------------------------------------------- +# Detection stories +# 
--------------------------------------------------------------------------- + +class TestScientistCanDetectCells: + def test_user_can_detect_cells_from_threshold(self, make_detection_config): + """Given: 2D data with a clear 40-dBZ cluster. + When: segmenter runs with threshold=35. + Then: at least one labeled cell appears at the cluster location. + """ + from adapt.modules.detection.module import RadarCellSegmenter + refl = np.zeros((10, 10), dtype=np.float32) + refl[4:7, 4:7] = 45.0 # 9-pixel cluster at centre + ds = _reflectivity_ds(refl) + config = make_detection_config(threshold=35, min_cellsize_gridpoint=4) + result = RadarCellSegmenter(config).segment(ds) + labels = result["cell_labels"].values + assert labels[5, 5] > 0, "Centre of cluster should be labelled" + + def test_user_sees_no_cells_when_storm_below_threshold(self, detection_module_config): + """Given: all reflectivity below detection threshold. + When: segmenter runs. + Then: output has no cells (all labels == 0). + """ + from adapt.modules.detection.module import RadarCellSegmenter + refl = np.full((8, 8), 20.0, dtype=np.float32) # threshold is 40 dBZ by default + ds = _reflectivity_ds(refl) + result = RadarCellSegmenter(detection_module_config).segment(ds) + assert result["cell_labels"].values.max() == 0 + + def test_user_can_separate_two_distinct_storms(self, make_detection_config): + """Given: two separated storm cores. + When: segmenter runs. + Then: exactly two cell labels appear. 
+ """ + from adapt.configuration.schemas.user import UserSegmenterConfig + from adapt.modules.detection.module import RadarCellSegmenter + refl = np.zeros((12, 12), dtype=np.float32) + refl[2:4, 1:4] = 48.0 # storm A + refl[8:10, 8:11] = 46.0 # storm B + ds = _reflectivity_ds(refl) + config = make_detection_config( + threshold=40, + segmenter=UserSegmenterConfig(filter_by_size=False), + ) + result = RadarCellSegmenter(config).segment(ds) + assert result["cell_labels"].values.max() == 2 + + +# --------------------------------------------------------------------------- +# Tracking stories +# --------------------------------------------------------------------------- + +class TestScientistCanTrackStorms: + @pytest.fixture + def tracker(self, tracking_module_config): + from adapt.modules.tracking.module import RadarCellTracker + return RadarCellTracker(tracking_module_config) + + def test_user_can_track_a_persistent_storm(self, tracker): + """Given: one cell that persists at the same location across two frames. + When: tracker runs on both frames. + Then: the same cell_uid appears in both tracked outputs. + """ + t1 = np.datetime64("2025-01-01T12:00:00") + t2 = np.datetime64("2025-01-01T12:05:00") + labels = np.zeros((8, 8), dtype=np.int32) + labels[3:5, 3:5] = 1 # stationary 2×2 cell + ds1 = _labeled_ds(labels, t1) + ds2 = _labeled_ds(labels, t2) + stats1 = pd.DataFrame([_cell_stats_row(1, t1, cx=3.5, cy=3.5)]) + stats2 = pd.DataFrame([_cell_stats_row(1, t2, cx=3.5, cy=3.5)]) + tracked1, events1 = tracker.track(ds1, stats1) + tracked2, events2 = tracker.track(ds2, stats2) + assert tracked1.iloc[0]["cell_uid"] == tracked2.iloc[0]["cell_uid"] + assert events1["event_type"].iloc[0] == "INITIATION" + assert events2["event_type"].iloc[0] == "CONTINUE" + + def test_user_sees_empty_output_when_no_storm(self, tracker): + """Given: no cells in any frame. + When: tracker runs. + Then: tracked_cells and events are empty DataFrames, no exception raised. 
+ """ + t1 = np.datetime64("2025-01-01T12:00:00") + t2 = np.datetime64("2025-01-01T12:05:00") + empty_labels = np.zeros((6, 6), dtype=np.int32) + ds1 = _labeled_ds(empty_labels, t1) + ds2 = _labeled_ds(empty_labels, t2) + stats_empty = pd.DataFrame(columns=[ + "time", "time_volume_start", "cell_label", "cell_area_sqkm", + "area_40dbz_km2", "cell_centroid_geom_x", "cell_centroid_geom_y", + "cell_centroid_mass_lat", "cell_centroid_mass_lon", + "radar_reflectivity_mean", "radar_reflectivity_max", + "radar_differential_reflectivity_max", + ]) + tracked1, events1 = tracker.track(ds1, stats_empty) + tracked2, events2 = tracker.track(ds2, stats_empty) + assert tracked1.empty + assert tracked2.empty + + def test_user_can_identify_storm_initiation(self, tracker): + """Given: a cell appears for the first time. + When: tracker runs. + Then: events contain an INITIATION event with a non-null cell_uid. + """ + t1 = np.datetime64("2025-01-01T12:00:00") + labels = np.zeros((6, 6), dtype=np.int32) + labels[2:4, 2:4] = 1 + ds1 = _labeled_ds(labels, t1) + stats1 = pd.DataFrame([_cell_stats_row(1, t1)]) + tracked, events = tracker.track(ds1, stats1) + initiations = events[events["event_type"] == "INITIATION"] + assert len(initiations) == 1 + assert initiations.iloc[0]["target_cell_uid"] is not None From 83fc20c832036b4ebd9cdd2a34cdf3464a3a0fa3 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Thu, 14 May 2026 16:06:16 -0500 Subject: [PATCH 13/14] FIX: use contextlib.suppress in time contract (ruff SIM105) --- src/adapt/contracts/time.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/adapt/contracts/time.py b/src/adapt/contracts/time.py index e5aae92..607c1ee 100644 --- a/src/adapt/contracts/time.py +++ b/src/adapt/contracts/time.py @@ -10,6 +10,8 @@ returning a dataset to the context. 
""" +import contextlib + import numpy as np import xarray as xr @@ -41,10 +43,8 @@ def assert_time_normalized(ds: xr.Dataset) -> None: tv = raw.flat[0] if isinstance(raw, np.ndarray) and raw.ndim > 0 else raw # unwrap numpy scalar wrapper if needed if hasattr(tv, "item"): - try: + with contextlib.suppress(Exception): tv = tv.item() - except Exception: - pass module = getattr(type(tv), "__module__", "") require( not module.startswith("cftime"), From d476648e5c83ae9abb6a776f04bd4006df40fcd8 Mon Sep 17 00:00:00 2001 From: Bhupendra Raut Date: Thu, 14 May 2026 16:09:19 -0500 Subject: [PATCH 14/14] ADD: utils --- src/adapt/utils/__init__.py | 0 src/adapt/utils/time.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) create mode 100644 src/adapt/utils/__init__.py create mode 100644 src/adapt/utils/time.py diff --git a/src/adapt/utils/__init__.py b/src/adapt/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/adapt/utils/time.py b/src/adapt/utils/time.py new file mode 100644 index 0000000..c893a51 --- /dev/null +++ b/src/adapt/utils/time.py @@ -0,0 +1,33 @@ +"""Time normalization helpers shared across ADAPT modules.""" + +import contextlib +from datetime import UTC, datetime + +import numpy as np + + +def normalize_time_scalar(time_val): + """Normalize xarray/cftime/numpy time representations to a scalar.""" + tv = time_val + while isinstance(tv, np.ndarray) and tv.size == 1: + tv = tv.reshape(-1)[0] + if isinstance(tv, np.ndarray): + tv = tv.reshape(-1)[0] + + if hasattr(tv, "item"): + with contextlib.suppress(TypeError, ValueError): + tv = tv.item() + + if getattr(type(tv), "__module__", "").startswith("cftime"): + tv = datetime( + int(tv.year), + int(tv.month), + int(tv.day), + int(tv.hour), + int(tv.minute), + int(tv.second), + int(getattr(tv, "microsecond", 0) or 0), + tzinfo=UTC, + ) + + return tv