diff --git a/lars/preprocessing/radar_preprocessing.py b/lars/preprocessing/radar_preprocessing.py index 5e3e5b8..af167a5 100644 --- a/lars/preprocessing/radar_preprocessing.py +++ b/lars/preprocessing/radar_preprocessing.py @@ -1,9 +1,10 @@ import xradar as xd import matplotlib.pyplot as plt import glob +import numpy as np import os import pandas as pd -import cmweather # noqa +import cmweather # noqa def preprocess_radar_data(file_path, output_path, date=None, @@ -66,7 +67,8 @@ def preprocess_radar_data(file_path, output_path, date=None, radar = radar.xradar.georeference() if 'sweep_0' in radar: sweep = radar['sweep_0'] - if sweep["sweep_mode"] == 'ppi' or sweep["sweep_mode"] == 'sector': + sweep_mode = str(sweep["sweep_mode"].values).split('\x00')[0].strip() + if sweep_mode in ('ppi', 'sector', 'azimuth_surveillance'): fig = plt.figure(figsize=(size_px/dpi, size_px/dpi)) ax = plt.axes() sweep["corrected_reflectivity"].where( @@ -74,11 +76,14 @@ def preprocess_radar_data(file_path, output_path, date=None, ax=ax, add_colorbar=False, **kwargs) - min_ref = sweep["corrected_reflectivity"].where( - sweep["corrected_reflectivity"] > min_ref).values.min() - max_ref = sweep["corrected_reflectivity"].where( - sweep["corrected_reflectivity"] > min_ref).values.max() - + masked = sweep["corrected_reflectivity"].where( + sweep["corrected_reflectivity"] > min_ref).values + ref_min = np.nanmin(masked) + ref_max = np.nanmax(masked) + ax.axis('off') + ax.set_title('') + ax.set_ylabel('') + ax.set_xlabel('') ax.set_xlim(x_bounds) ax.set_ylim(y_bounds) @@ -91,7 +96,7 @@ def preprocess_radar_data(file_path, output_path, date=None, os.path.basename(file).replace('.nc', '.png')), dpi=dpi, bbox_inches='tight', pad_inches=0) plt.close(fig) - out_df.loc[len(out_df)] = [file_name, time_str, label, min_ref, max_ref] + out_df.loc[len(out_df)] = [file_name, time_str, label, ref_min, ref_max] else: print(f"Sweep mode is not PPI or sector scan in {file}, skipping.") diff --git a/pyproject.toml b/pyproject.toml index be50e70..21fc9fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = ["xradar", "scikit-learn", "python-dotenv", "aiohttp", "asksagecl [project.optional-dependencies] dev = ["pytest>=6.0", "pytest-asyncio>=0.21", "black", "flake8", "openai", "xradar", "python-dotenv", "scikit-learn", "cmweather", "torchvision", "torch", "aiohttp", "matplotlib", "pandas", - "asksageclient", "pip_system_certs", "requests"] + "asksageclient", "pip_system_certs", "requests", "open-radar-data"] [project.urls] Homepage = "https://github.com/rcjackson/lars" diff --git a/tests/data/baseline/preprocessing/gucxprecipradarcmacppiS2.c1.20220314.021559.png b/tests/data/baseline/preprocessing/gucxprecipradarcmacppiS2.c1.20220314.021559.png new file mode 100644 index 0000000..1f26e19 Binary files /dev/null and b/tests/data/baseline/preprocessing/gucxprecipradarcmacppiS2.c1.20220314.021559.png differ diff --git a/tests/data/baseline/preprocessing/gucxprecipradarcmacppiS2.c1.20220314.024239.png b/tests/data/baseline/preprocessing/gucxprecipradarcmacppiS2.c1.20220314.024239.png new file mode 100644 index 0000000..898e2d5 Binary files /dev/null and b/tests/data/baseline/preprocessing/gucxprecipradarcmacppiS2.c1.20220314.024239.png differ diff --git a/tests/data/baseline/preprocessing/gucxprecipradarcmacppiS2.c1.20220314.025840.png b/tests/data/baseline/preprocessing/gucxprecipradarcmacppiS2.c1.20220314.025840.png new file mode 100644 index 0000000..45f11dc Binary files /dev/null and b/tests/data/baseline/preprocessing/gucxprecipradarcmacppiS2.c1.20220314.025840.png differ diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..dfb7bc6 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,25 @@ +import sys + +# The parent tests/conftest.py mocks these at import time. Remove the mocks +# so integration tests can use the real implementations. +_MOCKED = ["xradar", "cmweather", "asksageclient", "pip_system_certs"] +for _key in list(sys.modules): + if any(_key == m or _key.startswith(m + ".") for m in _MOCKED): + del sys.modules[_key] + +# Evict any cached lars imports so they re-link against the real deps. +for _key in list(sys.modules): + if _key == "lars" or _key.startswith("lars."): + del sys.modules[_key] + +import matplotlib +matplotlib.use("Agg") + + +def pytest_addoption(parser): + parser.addoption( + "--generate-baseline", + action="store_true", + default=False, + help="Write baseline images instead of comparing against them.", + ) diff --git a/tests/integration/test_radar_preprocessing.py b/tests/integration/test_radar_preprocessing.py new file mode 100644 index 0000000..8ce92fa --- /dev/null +++ b/tests/integration/test_radar_preprocessing.py @@ -0,0 +1,144 @@ +""" +Integration tests for lars.preprocessing.preprocess_radar_data. + +Downloads three GUC XPRECIP CMAC PPI radar files from open-radar-data, +runs the full preprocessing workflow, validates the returned DataFrame, +and compares each generated PNG against a stored baseline image. + +Generating baselines (first-time setup or after intentional changes): + pytest tests/integration/ --generate-baseline + +Running the tests normally: + pytest tests/integration/ +""" + +import os +import shutil + +import matplotlib.image as mpimg +import numpy as np +import pytest + +open_radar_data = pytest.importorskip("open_radar_data") +xradar = pytest.importorskip("xradar") + +RADAR_FILES = [ + "gucxprecipradarcmacppiS2.c1.20220314.021559.nc", + "gucxprecipradarcmacppiS2.c1.20220314.024239.nc", + "gucxprecipradarcmacppiS2.c1.20220314.025840.nc", +] + +BASELINE_DIR = os.path.join( + os.path.dirname(__file__), "..", "data", "baseline", "preprocessing" +) + +# Pixel-value tolerance for image comparison (values are float in [0, 1]). +IMAGE_TOLERANCE = 5 / 255 + + +@pytest.fixture(scope="module") +def radar_data_dir(tmp_path_factory): + """Download the three test radar files into an isolated temp directory.""" + from open_radar_data import DATASETS + + tmp_dir = tmp_path_factory.mktemp("radar_data") + for fname in RADAR_FILES: + src = DATASETS.fetch(fname) + shutil.copy(src, tmp_dir / fname) + return str(tmp_dir) + + +@pytest.fixture(scope="module") +def preprocessing_output(tmp_path_factory, radar_data_dir): + """Run preprocessing once and share the output across all tests.""" + from lars.preprocessing import preprocess_radar_data + + out_dir = str(tmp_path_factory.mktemp("preprocessing_output")) + label_df = preprocess_radar_data(radar_data_dir, out_dir) + return out_dir, label_df + + +# --------------------------------------------------------------------------- +# DataFrame tests +# --------------------------------------------------------------------------- + + +def test_dataframe_row_count(preprocessing_output): + _, label_df = preprocessing_output + assert len(label_df) == 3 + + +def test_dataframe_columns(preprocessing_output): + _, label_df = preprocessing_output + assert set(label_df.columns) == {"file_path", "label", "ref_min", "ref_max"} + + +def test_labels_are_unknown(preprocessing_output): + _, label_df = preprocessing_output + assert (label_df["label"] == "UNKNOWN").all() + + +def test_reflectivity_bounds(preprocessing_output): + _, label_df = preprocessing_output + assert (label_df["ref_min"] <= label_df["ref_max"]).all() + + +def test_timestamps_are_on_correct_date(preprocessing_output): + _, label_df = preprocessing_output + assert all("2022-03-14" in str(idx) for idx in label_df.index) + + +def test_index_is_sorted(preprocessing_output): + _, label_df = preprocessing_output + assert label_df.index.is_monotonic_increasing + + +# --------------------------------------------------------------------------- +# Image file tests +# --------------------------------------------------------------------------- + + +def test_png_files_created(preprocessing_output): + out_dir, _ = preprocessing_output + for fname in RADAR_FILES: + assert os.path.exists(os.path.join(out_dir, fname.replace(".nc", ".png"))) + + +# --------------------------------------------------------------------------- +# Baseline image-comparison tests +# --------------------------------------------------------------------------- + + +def _compare_to_baseline(generated_path, baseline_path, tolerance): + generated = mpimg.imread(generated_path).astype(np.float32) + baseline = mpimg.imread(baseline_path).astype(np.float32) + + assert generated.shape == baseline.shape, ( + f"Shape mismatch: generated {generated.shape} vs baseline {baseline.shape}" + ) + max_diff = np.max(np.abs(generated - baseline)) + assert max_diff <= tolerance, ( + f"Max pixel difference {max_diff:.4f} exceeds tolerance {tolerance:.4f} " + f"({os.path.basename(generated_path)})" + ) + + +@pytest.mark.parametrize("fname", RADAR_FILES) +def test_image_matches_baseline(request, preprocessing_output, fname): + out_dir, _ = preprocessing_output + png_name = fname.replace(".nc", ".png") + generated_path = os.path.join(out_dir, png_name) + baseline_path = os.path.join(BASELINE_DIR, png_name) + + if request.config.getoption("--generate-baseline"): + os.makedirs(BASELINE_DIR, exist_ok=True) + shutil.copy(generated_path, baseline_path) + pytest.skip(f"Baseline written to {baseline_path}") + + if not os.path.exists(baseline_path): + pytest.skip( + f"No baseline found at {baseline_path}. " + "Run with --generate-baseline to create it." + ) + + _compare_to_baseline(generated_path, baseline_path, IMAGE_TOLERANCE)