From 68452c250fb04ebc380b2ee5a9591b0fda3f6603 Mon Sep 17 00:00:00 2001 From: Robert Jackson Date: Thu, 23 Apr 2026 15:58:13 -0500 Subject: [PATCH 1/3] FIX: Pyproject.toml explictly includes only lars --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 21fc9fe..e440e37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,11 @@ classifiers = [ ] dependencies = ["xradar", "scikit-learn", "python-dotenv", "aiohttp", "asksageclient", "pip_system_certs", "requests"] +[tool.setuptools.packages.find] +where = ["."] +include = ["lars"] +exclude = ["notebooks*", "example*"] + [project.optional-dependencies] dev = ["pytest>=6.0", "pytest-asyncio>=0.21", "black", "flake8", "openai", "xradar", "python-dotenv", "scikit-learn", "cmweather", "torchvision", "torch", "aiohttp", "matplotlib", "pandas", From 620700deea5a1a6dd2888a7029512203fa90e440 Mon Sep 17 00:00:00 2001 From: Robert Jackson Date: Mon, 27 Apr 2026 16:19:26 -0500 Subject: [PATCH 2/3] ADD: plot_label_images utility for random image grid by label pair Adds lars.util.plot_label_images, which filters a labeled DataFrame by hand label and LLM label, randomly samples n images, and saves them as a single grid PNG. Co-Authored-By: Claude Sonnet 4.6 --- lars/util/__init__.py | 1 + lars/util/image_grid.py | 91 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 lars/util/image_grid.py diff --git a/lars/util/__init__.py b/lars/util/__init__.py index a4898db..fa0864c 100644 --- a/lars/util/__init__.py +++ b/lars/util/__init__.py @@ -1 +1,2 @@ from .confusion_matrix import plot_confusion_matrix, calculate_cohen_kappa # noqa: F401 +from .image_grid import plot_label_images # noqa: F401 diff --git a/lars/util/image_grid.py b/lars/util/image_grid.py new file mode 100644 index 0000000..f7e950b --- /dev/null +++ b/lars/util/image_grid.py @@ -0,0 +1,91 @@ +import math +import random + +import matplotlib.pyplot as plt +import matplotlib.image as mpimg + + +def plot_label_images( + df, + hand_label, + llm_label, + n, + output_path, + label_col="label", + pred_col="llm_label", + seed=None, +): + """ + Randomly sample images where hand label and LLM label match the given values, + plot them in a single grid figure, and save to a PNG file. + + Parameters + ---------- + df : pd.DataFrame + DataFrame with at minimum columns for file paths, hand labels, and LLM labels. + hand_label : str + The human-annotated label to filter on. + llm_label : str + The LLM-generated label to filter on. + n : int + Number of images to sample and display. + output_path : str + File path for the saved PNG (e.g. "output/grid.png"). + label_col : str + Column name for hand labels. Default is "label". + pred_col : str + Column name for LLM labels. Default is "llm_label". + seed : int or None + Random seed for reproducibility. Default is None. + + Raises + ------ + ValueError + If fewer matching images exist than requested. + """ + mask = (df[label_col].str.lower() == hand_label.lower()) & ( + df[pred_col].str.lower() == llm_label.lower() + ) + subset = df[mask] + + if len(subset) == 0: + raise ValueError( + f"No images found where {label_col}='{hand_label}' and {pred_col}='{llm_label}'." + ) + if len(subset) < n: + raise ValueError( + f"Requested {n} images but only {len(subset)} match the given labels." + ) + + rng = random.Random(seed) + sampled = subset.sample(n=n, random_state=seed if seed is not None else rng.randint(0, 2**31)) + + ncols = math.ceil(math.sqrt(n)) + nrows = math.ceil(n / ncols) + + fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows)) + + # Normalise axes to always be a flat list + if n == 1: + axes = [axes] + else: + axes = list(axes.flat) + + for ax, (idx, row) in zip(axes, sampled.iterrows()): + img = mpimg.imread(row["file_path"]) + ax.imshow(img) + ax.set_title(str(idx), fontsize=8) + ax.axis("off") + + # Hide any unused subplot panels + for ax in axes[n:]: + ax.set_visible(False) + + fig.suptitle( + f"Hand label: '{hand_label}' | LLM label: '{llm_label}' | n={n}", + fontsize=12, + y=1.01, + ) + fig.tight_layout() + fig.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) From 55aa8f234e31f12fadb78f0ace6370c1e2c02ca9 Mon Sep 17 00:00:00 2001 From: Robert Jackson Date: Mon, 27 Apr 2026 16:22:45 -0500 Subject: [PATCH 3/3] ADD: Unit tests for plot_label_images using mock dataset Tests cover output file creation, valid PNG output, correct grid layout, suptitle content, error handling, case-insensitive matching, custom column names, seed reproducibility, and the n=1 edge case. Co-Authored-By: Claude Sonnet 4.6 --- tests/test_image_grid.py | 181 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 tests/test_image_grid.py diff --git a/tests/test_image_grid.py b/tests/test_image_grid.py new file mode 100644 index 0000000..65f3da6 --- /dev/null +++ b/tests/test_image_grid.py @@ -0,0 +1,181 @@ +import pytest +import numpy as np +import pandas as pd +import matplotlib +import matplotlib.pyplot as plt + +matplotlib.use("Agg") + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +def _make_png(path): + """Save a tiny 4x4 solid-colour PNG to *path* and return the path.""" + fig, ax = plt.subplots(figsize=(0.1, 0.1)) + ax.imshow(np.zeros((4, 4, 3), dtype=np.uint8)) + ax.axis("off") + fig.savefig(path) + plt.close(fig) + return path + + +@pytest.fixture() +def image_df(tmp_path): + """DataFrame with 8 labelled images backed by real PNG files.""" + rows = [ + # hand_label, llm_label, filename + ("stratiform", "stratiform", "img_00.png"), + ("stratiform", "stratiform", "img_01.png"), + ("stratiform", "stratiform", "img_02.png"), + ("stratiform", "convective", "img_03.png"), + ("convective", "convective", "img_04.png"), + ("convective", "convective", "img_05.png"), + ("convective", "stratiform", "img_06.png"), + ("no precipitation", "no precipitation", "img_07.png"), + ] + records = [] + for hand, llm, fname in rows: + path = _make_png(tmp_path / fname) + records.append({"label": hand, "llm_label": llm, "file_path": str(path)}) + return pd.DataFrame(records) + + +@pytest.fixture(autouse=True) +def close_figures(): + yield + plt.close("all") + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +def test_output_file_is_created(image_df, tmp_path): + from lars.util.image_grid import plot_label_images + + out = tmp_path / "grid.png" + plot_label_images(image_df, "stratiform", "stratiform", n=2, output_path=str(out), seed=0) + assert out.exists() + + +def test_output_is_valid_png(image_df, tmp_path): + from lars.util.image_grid import plot_label_images + + out = tmp_path / "grid.png" + plot_label_images(image_df, "stratiform", "stratiform", n=2, output_path=str(out), seed=0) + with open(out, "rb") as f: + header = f.read(8) + assert header == b"\x89PNG\r\n\x1a\n" + + +def test_correct_number_of_visible_axes(image_df, tmp_path, monkeypatch): + from lars.util import image_grid + + captured = {} + + original_savefig = matplotlib.figure.Figure.savefig + + def mock_savefig(self, *args, **kwargs): + captured["fig"] = self + original_savefig(self, *args, **kwargs) + + monkeypatch.setattr(matplotlib.figure.Figure, "savefig", mock_savefig) + + out = tmp_path / "grid.png" + plot_label_images = image_grid.plot_label_images + plot_label_images(image_df, "stratiform", "stratiform", n=3, output_path=str(out), seed=0) + + fig = captured["fig"] + visible = [ax for ax in fig.axes if ax.get_visible()] + assert len(visible) == 3 + + +def test_suptitle_contains_labels(image_df, tmp_path, monkeypatch): + from lars.util import image_grid + + captured = {} + + original_savefig = matplotlib.figure.Figure.savefig + + def mock_savefig(self, *args, **kwargs): + captured["fig"] = self + original_savefig(self, *args, **kwargs) + + monkeypatch.setattr(matplotlib.figure.Figure, "savefig", mock_savefig) + + out = tmp_path / "grid.png" + image_grid.plot_label_images( + image_df, "stratiform", "stratiform", n=2, output_path=str(out), seed=0 + ) + + title = captured["fig"].texts[0].get_text() + assert "stratiform" in title.lower() + + +def test_raises_when_no_matching_rows(image_df, tmp_path): + from lars.util.image_grid import plot_label_images + + with pytest.raises(ValueError, match="No images found"): + plot_label_images( + image_df, "anvil", "anvil", n=1, output_path=str(tmp_path / "out.png") + ) + + +def test_raises_when_not_enough_images(image_df, tmp_path): + from lars.util.image_grid import plot_label_images + + # Only 3 stratiform/stratiform images exist + with pytest.raises(ValueError, match="only 3 match"): + plot_label_images( + image_df, "stratiform", "stratiform", n=10, + output_path=str(tmp_path / "out.png"), + ) + + +def test_case_insensitive_matching(image_df, tmp_path): + from lars.util.image_grid import plot_label_images + + out = tmp_path / "grid.png" + plot_label_images(image_df, "Stratiform", "STRATIFORM", n=2, output_path=str(out), seed=0) + assert out.exists() + + +def test_custom_column_names(tmp_path): + from lars.util.image_grid import plot_label_images + + rows = [] + for i in range(3): + path = _make_png(tmp_path / f"custom_{i}.png") + rows.append({"true": "stratiform", "pred": "stratiform", "file_path": str(path)}) + df = pd.DataFrame(rows) + + out = tmp_path / "grid.png" + plot_label_images( + df, "stratiform", "stratiform", n=2, output_path=str(out), + label_col="true", pred_col="pred", seed=0, + ) + assert out.exists() + + +def test_seed_reproducibility(image_df, tmp_path): + from lars.util.image_grid import plot_label_images + + out1 = tmp_path / "g1.png" + out2 = tmp_path / "g2.png" + plot_label_images(image_df, "stratiform", "stratiform", n=2, output_path=str(out1), seed=42) + plot_label_images(image_df, "stratiform", "stratiform", n=2, output_path=str(out2), seed=42) + + import hashlib + h1 = hashlib.md5(out1.read_bytes()).hexdigest() + h2 = hashlib.md5(out2.read_bytes()).hexdigest() + assert h1 == h2 + + +def test_n_equals_one(image_df, tmp_path): + from lars.util.image_grid import plot_label_images + + out = tmp_path / "single.png" + plot_label_images(image_df, "no precipitation", "no precipitation", n=1, output_path=str(out), seed=0) + assert out.exists()