Skip to content

Commit 1f7e25c

Browse files
author
Niru Maheswaranathan
authored
Merge pull request #25 from nirum/p5ztb0-codex/add-missing-docstrings
2 parents fa0c10c + f671e6d commit 1f7e25c

6 files changed

Lines changed: 99 additions & 28 deletions

File tree

src/jetplot/chart_utils.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
"""Plotting utils."""
22

3-
from collections.abc import Callable
43
from functools import partial, wraps
4+
from typing import Any, Literal
55

66
import numpy as np
77
from matplotlib import pyplot as plt
8+
from matplotlib.axes import Axes
89

910
__all__ = [
1011
"noticks",
@@ -114,14 +115,25 @@ def nospines(left=False, bottom=False, top=True, right=True, **kwargs):
114115
return ax
115116

116117

117-
def get_bounds(axis, ax=None):
118-
if ax is None:
119-
ax = plt.gca()
118+
def get_bounds(axis: Literal["x", "y"], ax: Axes | None = None) -> tuple[float, float]:
119+
"""Return the axis spine bounds for the given axis.
120120
121+
Parameters
122+
----------
123+
axis : str
124+
Axis to inspect, either ``"x"`` or ``"y"``.
125+
ax : matplotlib.axes.Axes | None, optional
126+
Axes object to inspect. If ``None``, the current axes are used.
121127
122-
Result = tuple[Callable[[], list[float]], Callable[[], list[str]], Callable[[], tuple[float, float]], str]
128+
Returns
129+
-------
130+
tuple[float, float]
131+
Lower and upper bounds of the axis spine.
132+
"""
133+
if ax is None:
134+
ax = plt.gca()
123135

124-
axis_map: dict[str, Result] = {
136+
axis_map: dict[str, Any] = {
125137
"x": (ax.get_xticks, ax.get_xticklabels, ax.get_xlim, "bottom"),
126138
"y": (ax.get_yticks, ax.get_yticklabels, ax.get_ylim, "left"),
127139
}
@@ -187,14 +199,20 @@ def identity(x):
187199

188200

189201
@axwrapper
190-
def yclamp(y0=None, y1=None, dt=None, **kwargs):
202+
def yclamp(
203+
y0: float | None = None,
204+
y1: float | None = None,
205+
dt: float | None = None,
206+
**kwargs,
207+
) -> Axes:
208+
"""Clamp the y-axis to evenly spaced tick marks."""
191209
ax = kwargs["ax"]
192210

193211
lims = ax.get_ylim()
194212
y0 = lims[0] if y0 is None else y0
195213
y1 = lims[1] if y1 is None else y1
196214

197-
ticks: list[float] = ax.get_yticks() # pyrefly: ignore
215+
ticks: list[float] = ax.get_yticks() # pyrefly: ignore
198216
dt = float(np.mean(np.diff(ticks))) if dt is None else float(dt)
199217

200218
new_ticks = np.arange(dt * np.floor(y0 / dt), dt * (np.ceil(y1 / dt) + 1), dt)
@@ -206,14 +224,20 @@ def yclamp(y0=None, y1=None, dt=None, **kwargs):
206224

207225

208226
@axwrapper
209-
def xclamp(x0=None, x1=None, dt=None, **kwargs):
227+
def xclamp(
228+
x0: float | None = None,
229+
x1: float | None = None,
230+
dt: float | None = None,
231+
**kwargs,
232+
) -> Axes:
233+
"""Clamp the x-axis to evenly spaced tick marks."""
210234
ax = kwargs["ax"]
211235

212236
lims = ax.get_xlim()
213237
x0 = lims[0] if x0 is None else x0
214238
x1 = lims[1] if x1 is None else x1
215239

216-
ticks: list[float] = ax.get_xticks() # pyrefly: ignore
240+
ticks: list[float] = ax.get_xticks() # pyrefly: ignore
217241
dt = float(np.mean(np.diff(ticks))) if dt is None else float(dt)
218242

219243
new_ticks = np.arange(dt * np.floor(x0 / dt), dt * (np.ceil(x1 / dt) + 1), dt)

src/jetplot/colors.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
import numpy as np
44
from matplotlib import cm
55
from matplotlib import pyplot as plt
6+
from matplotlib.axes import Axes
67
from matplotlib.colors import LinearSegmentedColormap, to_hex
8+
from matplotlib.figure import Figure
79
from matplotlib.typing import ColorType
10+
from numpy.typing import NDArray
811

912
from .chart_utils import noticks
1013

@@ -16,15 +19,18 @@ class Palette(list[ColorType]):
1619

1720
@property
1821
def hex(self):
22+
"""Return the palette colors as hexadecimal strings."""
1923
return Palette([to_hex(rgb) for rgb in self])
2024

2125
@property
22-
def cmap(self):
26+
def cmap(self) -> LinearSegmentedColormap:
27+
"""Return the palette as a Matplotlib colormap."""
2328
return LinearSegmentedColormap.from_list("", self)
2429

25-
def plot(self, figsize=(5, 1)):
30+
def plot(self, figsize: tuple[int, int] = (5, 1)) -> tuple[Figure, NDArray[Axes]]:
31+
"""Visualize the colors in the palette."""
2632
fig, axs = plt.subplots(1, len(self), figsize=figsize)
27-
for c, ax in zip(self, axs, strict=True): # pyrefly: ignore
33+
for c, ax in zip(self, axs, strict=True): # pyrefly: ignore
2834
ax.set_facecolor(c)
2935
ax.set_aspect("equal")
3036
noticks(ax=ax)
@@ -54,7 +60,13 @@ def cubehelix(
5460
return Palette(colors)
5561

5662

57-
def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0):
63+
def cmap_colors(
64+
cmap: str,
65+
n: int,
66+
vmin: float = 0.0,
67+
vmax: float = 1.0,
68+
) -> Palette:
69+
"""Extract ``n`` colors from a Matplotlib colormap."""
5870
return Palette(getattr(cm, cmap)(np.linspace(vmin, vmax, n)))
5971

6072

@@ -371,6 +383,8 @@ def cmap_colors(cmap: str, n: int, vmin: float = 0.0, vmax: float = 1.0):
371383

372384

373385
def rainbow(k: int) -> Palette:
386+
"""Return a palette of distinct colors from several base palettes."""
387+
374388
_colors = (
375389
blue,
376390
orange,

src/jetplot/images.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Image visualization tools."""
22

3+
from collections.abc import Callable
34
from functools import partial
45

56
import numpy as np
@@ -79,7 +80,15 @@ def img(
7980

8081

8182
@plotwrapper
82-
def fsurface(func, xrng=None, yrng=None, n=100, nargs=2, **kwargs):
83+
def fsurface(
84+
func: Callable[..., np.ndarray],
85+
xrng: tuple[float, float] | None = None,
86+
yrng: tuple[float, float] | None = None,
87+
n: int = 100,
88+
nargs: int = 2,
89+
**kwargs,
90+
) -> None:
91+
"""Plot a 2‑D function as a filled surface."""
8392
xrng = (-1, 1) if xrng is None else xrng
8493
yrng = xrng if yrng is None else yrng
8594

@@ -127,7 +136,7 @@ def cmat(
127136

128137
xs, ys = np.meshgrid(np.arange(num_cols), np.arange(num_rows), indexing="xy")
129138

130-
for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore
139+
for x, y, value in zip(xs.flat, ys.flat, arr.flat, strict=True): # pyrefly: ignore
131140
color = dark_color if (value <= theta) else light_color
132141
annot = f"{{:{fmt}}}".format(value)
133142
ax.text(x, y, annot, ha="center", va="center", color=color, fontsize=fontsize)

src/jetplot/plots.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from matplotlib.patches import Ellipse
55
from matplotlib.transforms import Affine2D
66
from matplotlib.typing import ColorType
7+
from matplotlib.figure import Figure
8+
from matplotlib.axes import Axes
9+
from collections.abc import Sequence
710
from numpy.typing import NDArray
811
from scipy.stats import gaussian_kde
912
from sklearn.covariance import EmpiricalCovariance, MinCovDet
@@ -35,7 +38,8 @@ def violinplot(
3538
showmeans=False,
3639
showquartiles=True,
3740
**kwargs,
38-
):
41+
) -> Axes:
42+
"""Violin plot with customizable elements."""
3943
_ = kwargs.pop("fig")
4044
ax = kwargs.pop("ax")
4145

@@ -86,6 +90,8 @@ def violinplot(
8690
zorder=20,
8791
)
8892

93+
return ax
94+
8995

9096
@plotwrapper
9197
def hist(*args, histtype="stepfilled", alpha=0.85, density=True, **kwargs):
@@ -247,11 +253,17 @@ def bar(
247253

248254

249255
@plotwrapper
250-
def lines(x, lines=None, cmap="viridis", **kwargs):
256+
def lines(
257+
x: NDArray[np.floating] | NDArray[np.integer],
258+
lines: list[NDArray[np.floating]] | None = None,
259+
cmap: str = "viridis",
260+
**kwargs,
261+
) -> Axes:
262+
"""Plot multiple lines using a color map."""
251263
ax = kwargs["ax"]
252264

253265
if lines is None:
254-
lines = list(x)
266+
lines = list(x) # pyrefly: ignore
255267
x = np.arange(len(lines[0]))
256268

257269
else:
@@ -261,6 +273,8 @@ def lines(x, lines=None, cmap="viridis", **kwargs):
261273
for line, color in zip(lines, colors, strict=False):
262274
ax.plot(x, line, color=color)
263275

276+
return ax
277+
264278

265279
@plotwrapper
266280
def waterfall(x, ys, dy=1.0, pad=0.1, color="#444444", ec="#cccccc", ew=2.0, **kwargs):
@@ -279,7 +293,15 @@ def waterfall(x, ys, dy=1.0, pad=0.1, color="#444444", ec="#cccccc", ew=2.0, **k
279293

280294

281295
@figwrapper
282-
def ridgeline(t, xs, colors, edgecolor="#ffffff", ymax=0.6, **kwargs):
296+
def ridgeline(
297+
t: NDArray[np.floating],
298+
xs: Sequence[NDArray[np.floating]],
299+
colors: Sequence[ColorType],
300+
edgecolor: ColorType = "#ffffff",
301+
ymax: float = 0.6,
302+
**kwargs,
303+
) -> tuple[Figure, list[Axes]]:
304+
"""Stacked density plots reminiscent of a ridgeline plot."""
283305
fig = kwargs["fig"]
284306
axs = []
285307

src/jetplot/style.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def set_defaults(
140140

141141
def available_fonts() -> list[str]:
142142
"""Returns a list of available fonts."""
143-
return sorted(set([f.name for f in fm.fontManager.ttflist])) # pyrefly: ignore
143+
return sorted(set([f.name for f in fm.fontManager.ttflist])) # pyrefly: ignore
144144

145145

146146
def install_fonts(filepath: str):
@@ -150,7 +150,7 @@ def install_fonts(filepath: str):
150150
font_files = fm.findSystemFonts(fontpaths=[filepath])
151151

152152
for font_file in font_files:
153-
fm.fontManager.addfont(font_file) # pyrefly: ignore
153+
fm.fontManager.addfont(font_file) # pyrefly: ignore
154154

155155
new_fonts = set(available_fonts()) - original_fonts
156156
if new_fonts:

src/jetplot/timepiece.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,30 @@
99

1010

1111
class Stopwatch:
12-
def __init__(self, name=""):
12+
"""Simple timer utility for measuring code execution time."""
13+
14+
def __init__(self, name: str = "") -> None:
1315
self.name = name
1416
self.start = time.perf_counter()
1517
self.absolute_start = time.perf_counter()
1618

17-
def __str__(self):
19+
def __str__(self) -> str:
1820
return "\u231a Stopwatch for: " + self.name
1921

2022
@property
21-
def elapsed(self):
23+
def elapsed(self) -> float:
2224
current = time.perf_counter()
2325
elapsed = current - self.start
2426
self.start = time.perf_counter()
2527
return elapsed
2628

27-
def checkpoint(self, name=""):
29+
def checkpoint(self, name: str = "") -> None:
2830
print(f"{self.name} {name} took {hrtime(self.elapsed)}".strip())
2931

30-
def __enter__(self):
32+
def __enter__(self) -> "Stopwatch":
3133
return self
3234

33-
def __exit__(self, *_):
35+
def __exit__(self, *_: object) -> None:
3436
total = hrtime(time.perf_counter() - self.absolute_start)
3537
print(f"{self.name} Finished! \u2714\nTotal elapsed time: {total}")
3638

0 commit comments

Comments
 (0)