Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lars/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .confusion_matrix import plot_confusion_matrix, calculate_cohen_kappa # noqa: F401
from .image_grid import plot_label_images # noqa: F401
91 changes: 91 additions & 0 deletions lars/util/image_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import math
import random

import matplotlib.pyplot as plt
import matplotlib.image as mpimg


def plot_label_images(
df,
hand_label,
llm_label,
n,
output_path,
label_col="label",
pred_col="llm_label",
seed=None,
):
"""
Randomly sample images where hand label and LLM label match the given values,
plot them in a single grid figure, and save to a PNG file.

Parameters
----------
df : pd.DataFrame
DataFrame with at minimum columns for file paths, hand labels, and LLM labels.
hand_label : str
The human-annotated label to filter on.
llm_label : str
The LLM-generated label to filter on.
n : int
Number of images to sample and display.
output_path : str
File path for the saved PNG (e.g. "output/grid.png").
label_col : str
Column name for hand labels. Default is "label".
pred_col : str
Column name for LLM labels. Default is "llm_label".
seed : int or None
Random seed for reproducibility. Default is None.

Raises
------
ValueError
If fewer matching images exist than requested.
"""
mask = (df[label_col].str.lower() == hand_label.lower()) & (
df[pred_col].str.lower() == llm_label.lower()
)
subset = df[mask]

if len(subset) == 0:
raise ValueError(
f"No images found where {label_col}='{hand_label}' and {pred_col}='{llm_label}'."
)
if len(subset) < n:
raise ValueError(
f"Requested {n} images but only {len(subset)} match the given labels."
)

rng = random.Random(seed)
sampled = subset.sample(n=n, random_state=seed if seed is not None else rng.randint(0, 2**31))

ncols = math.ceil(math.sqrt(n))
nrows = math.ceil(n / ncols)

fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows))

# Normalise axes to always be a flat list
if n == 1:
axes = [axes]
else:
axes = list(axes.flat)

for ax, (idx, row) in zip(axes, sampled.iterrows()):
img = mpimg.imread(row["file_path"])
ax.imshow(img)
ax.set_title(str(idx), fontsize=8)
ax.axis("off")

# Hide any unused subplot panels
for ax in axes[n:]:
ax.set_visible(False)

fig.suptitle(
f"Hand label: '{hand_label}' | LLM label: '{llm_label}' | n={n}",
fontsize=12,
y=1.01,
)
fig.tight_layout()
fig.savefig(output_path, dpi=150, bbox_inches="tight")
plt.close(fig)
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ classifiers = [
]
dependencies = ["xradar", "scikit-learn", "python-dotenv", "aiohttp", "asksageclient", "pip_system_certs", "requests"]

[tool.setuptools.packages.find]
where = ["."]
include = ["lars"]
exclude = ["notebooks*", "example*"]

[project.optional-dependencies]
dev = ["pytest>=6.0", "pytest-asyncio>=0.21", "black", "flake8", "openai", "xradar", "python-dotenv",
"scikit-learn", "cmweather", "torchvision", "torch", "aiohttp", "matplotlib", "pandas",
Expand Down
181 changes: 181 additions & 0 deletions tests/test_image_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import pytest
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

matplotlib.use("Agg")


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

def _make_png(path):
"""Save a tiny 4x4 solid-colour PNG to *path* and return the path."""
fig, ax = plt.subplots(figsize=(0.1, 0.1))
ax.imshow(np.zeros((4, 4, 3), dtype=np.uint8))
ax.axis("off")
fig.savefig(path)
plt.close(fig)
return path


@pytest.fixture()
def image_df(tmp_path):
"""DataFrame with 8 labelled images backed by real PNG files."""
rows = [
# hand_label, llm_label, filename
("stratiform", "stratiform", "img_00.png"),
("stratiform", "stratiform", "img_01.png"),
("stratiform", "stratiform", "img_02.png"),
("stratiform", "convective", "img_03.png"),
("convective", "convective", "img_04.png"),
("convective", "convective", "img_05.png"),
("convective", "stratiform", "img_06.png"),
("no precipitation", "no precipitation", "img_07.png"),
]
records = []
for hand, llm, fname in rows:
path = _make_png(tmp_path / fname)
records.append({"label": hand, "llm_label": llm, "file_path": str(path)})
return pd.DataFrame(records)


@pytest.fixture(autouse=True)
def close_figures():
yield
plt.close("all")


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

def test_output_file_is_created(image_df, tmp_path):
from lars.util.image_grid import plot_label_images

out = tmp_path / "grid.png"
plot_label_images(image_df, "stratiform", "stratiform", n=2, output_path=str(out), seed=0)
assert out.exists()


def test_output_is_valid_png(image_df, tmp_path):
from lars.util.image_grid import plot_label_images

out = tmp_path / "grid.png"
plot_label_images(image_df, "stratiform", "stratiform", n=2, output_path=str(out), seed=0)
with open(out, "rb") as f:
header = f.read(8)
assert header == b"\x89PNG\r\n\x1a\n"


def test_correct_number_of_visible_axes(image_df, tmp_path, monkeypatch):
from lars.util import image_grid

captured = {}

original_savefig = matplotlib.figure.Figure.savefig

def mock_savefig(self, *args, **kwargs):
captured["fig"] = self
original_savefig(self, *args, **kwargs)

monkeypatch.setattr(matplotlib.figure.Figure, "savefig", mock_savefig)

out = tmp_path / "grid.png"
plot_label_images = image_grid.plot_label_images
plot_label_images(image_df, "stratiform", "stratiform", n=3, output_path=str(out), seed=0)

fig = captured["fig"]
visible = [ax for ax in fig.axes if ax.get_visible()]
assert len(visible) == 3


def test_suptitle_contains_labels(image_df, tmp_path, monkeypatch):
from lars.util import image_grid

captured = {}

original_savefig = matplotlib.figure.Figure.savefig

def mock_savefig(self, *args, **kwargs):
captured["fig"] = self
original_savefig(self, *args, **kwargs)

monkeypatch.setattr(matplotlib.figure.Figure, "savefig", mock_savefig)

out = tmp_path / "grid.png"
image_grid.plot_label_images(
image_df, "stratiform", "stratiform", n=2, output_path=str(out), seed=0
)

title = captured["fig"].texts[0].get_text()
assert "stratiform" in title.lower()


def test_raises_when_no_matching_rows(image_df, tmp_path):
from lars.util.image_grid import plot_label_images

with pytest.raises(ValueError, match="No images found"):
plot_label_images(
image_df, "anvil", "anvil", n=1, output_path=str(tmp_path / "out.png")
)


def test_raises_when_not_enough_images(image_df, tmp_path):
from lars.util.image_grid import plot_label_images

# Only 3 stratiform/stratiform images exist
with pytest.raises(ValueError, match="only 3 match"):
plot_label_images(
image_df, "stratiform", "stratiform", n=10,
output_path=str(tmp_path / "out.png"),
)


def test_case_insensitive_matching(image_df, tmp_path):
from lars.util.image_grid import plot_label_images

out = tmp_path / "grid.png"
plot_label_images(image_df, "Stratiform", "STRATIFORM", n=2, output_path=str(out), seed=0)
assert out.exists()


def test_custom_column_names(tmp_path):
from lars.util.image_grid import plot_label_images

rows = []
for i in range(3):
path = _make_png(tmp_path / f"custom_{i}.png")
rows.append({"true": "stratiform", "pred": "stratiform", "file_path": str(path)})
df = pd.DataFrame(rows)

out = tmp_path / "grid.png"
plot_label_images(
df, "stratiform", "stratiform", n=2, output_path=str(out),
label_col="true", pred_col="pred", seed=0,
)
assert out.exists()


def test_seed_reproducibility(image_df, tmp_path):
from lars.util.image_grid import plot_label_images

out1 = tmp_path / "g1.png"
out2 = tmp_path / "g2.png"
plot_label_images(image_df, "stratiform", "stratiform", n=2, output_path=str(out1), seed=42)
plot_label_images(image_df, "stratiform", "stratiform", n=2, output_path=str(out2), seed=42)

import hashlib
h1 = hashlib.md5(out1.read_bytes()).hexdigest()
h2 = hashlib.md5(out2.read_bytes()).hexdigest()
assert h1 == h2


def test_n_equals_one(image_df, tmp_path):
from lars.util.image_grid import plot_label_images

out = tmp_path / "single.png"
plot_label_images(image_df, "no precipitation", "no precipitation", n=1, output_path=str(out), seed=0)
assert out.exists()
Loading