Skip to content

Commit c4ba6db

Browse files
author
Niru Maheswaranathan
authored
Merge pull request #27 from nirum/cleanup
- harden palette plotting, respect confusion-matrix opts, and keep generator inputs working across plot helpers - scope image colorbars to their axes, safeguard signal normalization/stable rank, and flush stopwatch timing output - document updated behaviours, satisfy pyrefly typing, and broaden tests (including ridgeline color handling and ellipse coverage)
2 parents 88ebce3 + 5fa8edf commit c4ba6db

12 files changed

Lines changed: 325 additions & 55 deletions

File tree

AGENTS.md

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
1-
# Developer Instructions
1+
# Repository Guidelines
22

3-
## Code Style
4-
- Run `uv run ruff check .` before committing to ensure all Python code passes linting.
5-
- Write clear docstrings for all public functions and classes.
6-
- Use relative imports within the `jetplot` package.
3+
## Project Structure & Module Organization
4+
Source lives under `src/jetplot/`, split into focused modules such as `colors.py`, `plots.py`, and `style.py`. Add new utilities in cohesive files and keep imports relative (e.g. `from .colors import Palette`). Tests reside in `tests/` and should mirror the module layout; prefer descriptive folders like `tests/test_plots.py` to match the feature under test. Documentation content is maintained in `docs/` and rendered via MkDocs, while build artefacts land in `build/` and `dist/`.
75

8-
## Testing
9-
- Run `uv run pytest --cov=jetplot --cov-report=term` before committing to ensure all tests pass before submitting a PR.
10-
- Run `uv run pyrefly check` before committing to ensure all pyrefly type checking passes.
6+
## Build, Test, and Development Commands
7+
Run `uv run ruff check .` to ensure linting passes and style expectations are met. Use `uv run pytest --cov=jetplot --cov-report=term` for the full suite with coverage feedback, and `uv run pyrefly check` to validate typing across the package. When iterating on documentation, launch `uv run mkdocs serve` for a live preview. Regenerate distributions with `uv build` once changes are ready to publish.
118

12-
## PR Guidelines
13-
- Your pull request description must contain a **Summary** section explaining the changes.
14-
- Include a **Testing** section describing the commands used to run lint and tests along with their results.
9+
## Coding Style & Naming Conventions
10+
Follow PEP 8 defaults: four-space indentation, snake_case for functions and module-level variables, and CapWords for classes. Exported constants stay upper-case with underscores. Provide clear docstrings for every public function or class that describes inputs, return values, and side effects. Keep modules small, favor pure functions, and rely on the shared helpers already defined in `style.py` and `chart_utils.py` instead of duplicating logic.
11+
12+
## Testing Guidelines
13+
Write pytest-based tests alongside new functionality, naming files `test_<feature>.py` and individual tests `test_<behavior>`. Prefer parametrization to cover edge cases concisely. Aim to maintain or raise the coverage reported by the standard coverage command; unexpected drops should block merges. Include regression tests when fixing bugs so future refactors stay guarded.
14+
15+
## Commit & Pull Request Guidelines
16+
Commits use short, imperative summaries (e.g. `Add palette cycler helper`). Break large efforts into logical commits that pass linting and tests independently. Pull requests must include **Summary** and **Testing** sections outlining what changed and the exact commands run (`uv run ruff check .`, `uv run pytest --cov=jetplot --cov-report=term`, `uv run pyrefly check`). Link relevant issues and add screenshots when UI-facing artifacts such as documentation pages change.
17+
18+
## Documentation Tips
19+
Reference existing examples in `docs/` when adding guides, and keep code snippets synced with the APIs under `src/jetplot/`. Rebuild the site locally with `uv run mkdocs serve` after edits to verify navigation, formatting, and cross-links before submitting your changes.

justfile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
default: test
1+
default: format lint test typecheck
22

33
build:
44
uv build
@@ -12,6 +12,9 @@ docs:
1212
format:
1313
uv run ruff format
1414

15+
lint:
16+
uv run ruff check
17+
1518
typecheck:
1619
uv run pyrefly check
1720

src/jetplot/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Jetplot is a set of useful utility functions for scientific python."""
22

3-
__version__ = "0.6.5"
3+
__version__ = "0.6.6"
44

55
from . import colors as c # noqa: F401
66
from .chart_utils import *

src/jetplot/colors.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,19 @@ def cmap(self) -> LinearSegmentedColormap:
3030

3131
def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, list[Axes]]:
3232
"""Visualize the colors in the palette."""
33+
if not self:
34+
raise ValueError("Palette has no colors to plot.")
35+
3336
fig, axs = plt.subplots(1, len(self), figsize=figsize)
34-
for c, ax in zip(self, axs, strict=True): # pyrefly: ignore
35-
ax.set_facecolor(c)
37+
axs_array = np.atleast_1d(axs)
38+
axes_list = [cast(Axes, ax) for ax in axs_array.flat]
39+
40+
for c, ax in zip(self, axes_list, strict=True):
41+
ax.set_facecolor(c) # pyrefly: ignore
3642
ax.set_aspect("equal")
3743
noticks(ax=ax)
3844

39-
return fig, cast(list[Axes], axs)
45+
return fig, axes_list
4046

4147

4248
def cubehelix(

src/jetplot/images.py

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Any, cast
66

77
import numpy as np
8-
from matplotlib import pyplot as plt
98
from matplotlib.axes import Axes
109
from matplotlib.image import AxesImage
1110
from matplotlib.ticker import FixedLocator
@@ -31,12 +30,21 @@ def img(
3130
"""Visualize a matrix as an image.
3231
3332
Args:
34-
img: array_like, The array to visualize.
35-
mode: string, One of 'div' for a diverging image, 'seq' for
36-
sequential, 'cov' for covariance matrices, or 'corr' for
37-
correlation matrices (default: 'div').
38-
cmap: string, Colormap to use.
39-
aspect: string, Either 'equal' or 'auto'
33+
data: Array to visualize.
34+
mode: One of ``"div"``, ``"seq"``, ``"cov"``, or ``"corr"``.
35+
cmap: Matplotlib colormap name. Mode defaults are used when ``None``.
36+
aspect: Either ``"equal"`` or ``"auto"``.
37+
vmin: Lower bound for normalization.
38+
vmax: Upper bound for normalization.
39+
cbar: Whether to draw a colorbar attached to the provided axes.
40+
interpolation: Interpolation strategy passed to ``imshow``.
41+
42+
Raises:
43+
ValueError: If ``mode`` is not recognized.
44+
45+
Notes:
46+
When ``cbar`` is ``True``, the colorbar is added to the supplied axes/figure
47+
so multi-axes layouts keep their layout intact.
4048
"""
4149
# work with a copy of the original image data
4250
img = np.squeeze(data.copy())
@@ -68,16 +76,17 @@ def img(
6876
raise ValueError("Unrecognized mode: '" + mode + "'")
6977

7078
# make the image
71-
im = kwargs["ax"].imshow(
79+
fig, ax = kwargs["fig"], kwargs["ax"]
80+
im = ax.imshow(
7281
img, cmap=cmap, interpolation=interpolation, vmin=vmin, vmax=vmax, aspect=aspect
7382
)
7483

7584
# colorbar
7685
if cbar:
77-
plt.colorbar(im)
86+
fig.colorbar(im, ax=ax)
7887

7988
# clear ticks
80-
noticks(ax=kwargs["ax"])
89+
noticks(ax=ax)
8190

8291
return im
8392

@@ -131,24 +140,60 @@ def cmat(
131140
vmax: float = 1.0,
132141
**kwargs: Any,
133142
) -> tuple[AxesImage, Axes]:
134-
"""Plot confusion matrix."""
143+
"""Plot a confusion matrix with optional annotations.
144+
145+
Args:
146+
arr: Square matrix of scores in [0, 1].
147+
labels: Optional axis labels. Must match matrix dimensions.
148+
annot: Whether to draw text annotations for each cell.
149+
cmap: Colormap used for the heatmap.
150+
cbar: Whether to include a colorbar.
151+
fmt: Format string applied to annotation labels.
152+
dark_color: Text color used when ``value <= theta``.
153+
light_color: Text color used when ``value > theta``.
154+
grid_color: Grid line color.
155+
theta: Threshold for choosing between ``dark_color`` and ``light_color``.
156+
label_fontsize: Tick label font size.
157+
fontsize: Annotation font size.
158+
vmin: Lower bound for normalization.
159+
vmax: Upper bound for normalization.
160+
161+
Raises:
162+
ValueError: If labels are provided but do not match the matrix dimensions.
163+
"""
135164
num_rows, num_cols = arr.shape
136165

166+
label_list: list[str] | None = None
167+
if labels is not None:
168+
label_list = list(labels)
169+
if len(label_list) != num_cols or num_rows != num_cols:
170+
raise ValueError(
171+
"Labels must match confusion matrix dimensions and matrix must be square."
172+
)
173+
137174
ax = kwargs.pop("ax")
138175
cb = imv(arr, ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, cbar=cbar)
139176

140177
xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy")
141178

142-
for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore
143-
color = dark_color if (value <= theta) else light_color
144-
label = f"{{:{fmt}}}".format(value)
145-
ax.text(x, y, label, ha="center", va="center", color=color, fontsize=fontsize)
146-
147-
if labels is not None:
179+
if annot:
180+
for x, y, value in zip( # pyrefly: ignore
181+
xs.flat, # pyrefly: ignore
182+
ys.flat,
183+
arr.flat,
184+
strict=True, # pyrefly: ignore
185+
):
186+
color = dark_color if (value <= theta) else light_color
187+
label = f"{{:{fmt}}}".format(value)
188+
ax.text(
189+
x, y, label, ha="center", va="center", color=color, fontsize=fontsize
190+
)
191+
192+
if label_list is not None:
148193
ax.set_xticks(np.arange(num_cols))
149-
ax.set_xticklabels(labels, rotation=90, fontsize=label_fontsize)
194+
ax.set_xticklabels(label_list, rotation=90, fontsize=label_fontsize)
150195
ax.set_yticks(np.arange(num_rows))
151-
ax.set_yticklabels(labels, fontsize=label_fontsize)
196+
ax.set_yticklabels(label_list, fontsize=label_fontsize)
152197

153198
ax.xaxis.set_minor_locator(FixedLocator((np.arange(num_cols) - 0.5).tolist()))
154199

src/jetplot/plots.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,19 @@ def errorplot(
163163
"""Plot a line with error bars."""
164164
ax = kwargs["ax"]
165165

166-
if np.isscalar(yerr) or len(yerr) == len(y): # pyrefly: ignore
166+
if np.isscalar(yerr):
167167
ymin = y - yerr # pyrefly: ignore
168168
ymax = y + yerr # pyrefly: ignore
169-
elif len(yerr) == 2:
169+
elif isinstance(yerr, tuple):
170+
if len(yerr) != 2:
171+
raise ValueError("Invalid yerr tuple length: ", yerr)
170172
ymin, ymax = yerr # pyrefly: ignore
171173
else:
172-
raise ValueError("Invalid yerr value: ", yerr)
174+
yerr_array = np.asarray(yerr)
175+
if yerr_array.shape != y.shape:
176+
raise ValueError("Invalid yerr value: ", yerr)
177+
ymin = y - yerr_array
178+
ymax = y + yerr_array
173179

174180
if method == "line":
175181
ax.plot(x, y, fmt, color=color, linewidth=4, clip_on=clip_on)
@@ -295,11 +301,27 @@ def waterfall(
295301
ew: float = 2.0,
296302
**kwargs: Any,
297303
) -> None:
298-
"""Waterfall plot."""
304+
"""Waterfall plot for stacked sequences.
305+
306+
Args:
307+
x: Common x-axis samples shared by every series.
308+
ys: Iterable of y-series. Generators are supported and are consumed once.
309+
dy: Vertical scaling applied to each successive series.
310+
pad: Offset applied so the outline sits slightly above the fill.
311+
color: Fill color for each series.
312+
ec: Edge color for the outline.
313+
ew: Edge line width.
314+
315+
Raises:
316+
ValueError: If ``ys`` yields no series.
317+
"""
299318
ax = kwargs["ax"]
300-
total = cast(int, len(ys))
319+
ys_list = list(ys)
320+
if not ys_list:
321+
raise ValueError("ys must contain at least one series.")
322+
total = len(ys_list)
301323

302-
for index, y in enumerate(ys):
324+
for index, y in enumerate(ys_list):
303325
zorder = total - index
304326
y = y * dy + index
305327
ax.plot(x, y + pad, color=ec, clip_on=False, lw=ew, zorder=zorder)
@@ -318,16 +340,40 @@ def ridgeline(
318340
ymax: float = 0.6,
319341
**kwargs: Any,
320342
) -> tuple[Figure, list[Axes]]:
321-
"""Stacked density plots reminiscent of a ridgeline plot."""
343+
"""Stacked density plots reminiscent of a ridgeline plot.
344+
345+
Args:
346+
t: Grid used when evaluating the kernel density estimate.
347+
xs: Iterable of 1-D samples. Accepts generators and consumes them once.
348+
colors: Iterable of colors. Must provide at least as many entries as ``xs``.
349+
edgecolor: Line color used for the outline.
350+
ymax: Upper y-limit for each subplot.
351+
352+
Raises:
353+
ValueError: If ``xs`` is empty or ``colors`` provides too few values.
354+
"""
322355
fig = kwargs["fig"]
356+
xs_list = list(xs)
357+
colors_iter = iter(colors)
358+
359+
if not xs_list:
360+
raise ValueError("xs must contain at least one series.")
361+
323362
axs = []
324363

325-
for k, (x, c) in enumerate(zip(xs, colors, strict=False)):
326-
ax = fig.add_subplot(cast(int, len(xs)), 1, k + 1)
364+
for k, x in enumerate(xs_list):
365+
try:
366+
palette_color = next(colors_iter)
367+
except StopIteration as exc:
368+
raise ValueError(
369+
"colors must provide at least as many items as xs."
370+
) from exc
371+
372+
ax = fig.add_subplot(len(xs_list), 1, k + 1)
327373
y = gaussian_kde(x).evaluate(t)
328-
ax.fill_between(t, y, color=c, clip_on=False)
374+
ax.fill_between(t, y, color=palette_color, clip_on=False)
329375
ax.plot(t, y, color=edgecolor, clip_on=False)
330-
ax.axhline(0.0, lw=2, color=c, clip_on=False)
376+
ax.axhline(0.0, lw=2, color=palette_color, clip_on=False)
331377

332378
ax.set_xlim(t[0], t[-1])
333379
ax.set_xticks([])
@@ -378,7 +424,8 @@ def ellipse(
378424
-------
379425
matplotlib.patches.Ellipse
380426
"""
381-
ax = cast(Axes, kwargs.get("ax"))
427+
ax = cast(Axes, kwargs.pop("ax", None))
428+
kwargs.pop("fig", None)
382429

383430
if x.size != y.size:
384431
raise ValueError("x and y must be the same size")
@@ -419,4 +466,4 @@ def ellipse(
419466
)
420467

421468
ellipse.set_transform(transform + ax.transData) # pyrefly: ignore
422-
return ax.add_patch(ellipse)
469+
return cast(Ellipse, ax.add_patch(ellipse))

src/jetplot/signals.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,21 @@ def smooth(x: ArrayLike, sigma: float = 1.0, axis: int = 0) -> NDArray[np.floati
2323
Returns:
2424
xs: array_like, A smoothed version of the input signal
2525
"""
26-
return gaussian_filter1d(x, sigma, axis=axis)
26+
arr = np.asarray(x)
27+
return gaussian_filter1d(arr, sigma, axis=axis)
2728

2829

2930
def stable_rank(X: NDArray[np.floating[Any]]) -> float:
30-
"""Computes the stable rank of a matrix"""
31-
assert X.ndim == 2, "X must be a matrix"
31+
"""Compute the stable rank of a matrix.
32+
33+
Args:
34+
X: Two-dimensional array representing a matrix.
35+
36+
Raises:
37+
ValueError: If ``X`` is not two-dimensional.
38+
"""
39+
if X.ndim != 2:
40+
raise ValueError("X must be a matrix")
3241

3342
# pyrefly: ignore
3443
svals_sq = np.linalg.svd(X, compute_uv=False, full_matrices=False) ** 2
@@ -98,6 +107,22 @@ def normalize(
98107
norm: Function that computes the norm (Default: np.linalg.norm).
99108
100109
Returns:
101-
Xn: Arrays that have been normalized using to the given function.
110+
Normalized array with the same shape as ``X``.
111+
112+
Notes:
113+
Any vectors whose norm is zero remain zero after normalization instead of
114+
producing NaNs or infinities.
102115
"""
103-
return np.asarray(X) / norm(X, axis=axis, keepdims=True)
116+
arr = np.asarray(X, dtype=float)
117+
denom = norm(arr, axis=axis, keepdims=True)
118+
zero_mask = denom == 0
119+
120+
# Avoid divide-by-zero warnings and keep zeros in place by dividing only where safe.
121+
safe_denom = np.where(zero_mask, 1.0, denom)
122+
normalized = np.zeros_like(arr, dtype=float)
123+
np.divide(arr, safe_denom, out=normalized, where=~zero_mask)
124+
125+
if np.any(zero_mask):
126+
normalized = np.where(zero_mask, 0.0, normalized)
127+
128+
return normalized

src/jetplot/timepiece.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ def elapsed(self) -> float:
2929
return elapsed
3030

3131
def checkpoint(self, name: str = "") -> None:
32-
print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip())
32+
print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip(), flush=True)
3333

3434
def __enter__(self) -> "Stopwatch":
3535
return self
3636

3737
def __exit__(self, *_: object) -> None:
3838
total = hrtime(time.perf_counter() - self.absolute_start)
39-
print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}")
39+
print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}", flush=True)
4040

4141

4242
def hrtime(t: float) -> str:

0 commit comments

Comments
 (0)