11"""Plotting utils."""
22
3- from collections .abc import Callable
43from functools import partial , wraps
4+ from typing import Any , Literal
55
66import numpy as np
77from matplotlib import pyplot as plt
8+ from matplotlib .axes import Axes
89
910__all__ = [
1011 "noticks" ,
@@ -114,14 +115,25 @@ def nospines(left=False, bottom=False, top=True, right=True, **kwargs):
114115 return ax
115116
116117
117- def get_bounds (axis , ax = None ):
118- if ax is None :
119- ax = plt .gca ()
118+ def get_bounds (axis : Literal ["x" , "y" ], ax : Axes | None = None ) -> tuple [float , float ]:
119+ """Return the axis spine bounds for the given axis.
120120
121+ Parameters
122+ ----------
123+ axis : str
124+ Axis to inspect, either ``"x"`` or ``"y"``.
125+ ax : matplotlib.axes.Axes | None, optional
126+ Axes object to inspect. If ``None``, the current axes are used.
121127
122- Result = tuple [Callable [[], list [float ]], Callable [[], list [str ]], Callable [[], tuple [float , float ]], str ]
128+ Returns
129+ -------
130+ tuple[float, float]
131+ Lower and upper bounds of the axis spine.
132+ """
133+ if ax is None :
134+ ax = plt .gca ()
123135
124- axis_map : dict [str , Result ] = {
136+ axis_map : dict [str , Any ] = {
125137 "x" : (ax .get_xticks , ax .get_xticklabels , ax .get_xlim , "bottom" ),
126138 "y" : (ax .get_yticks , ax .get_yticklabels , ax .get_ylim , "left" ),
127139 }
@@ -187,14 +199,20 @@ def identity(x):
187199
188200
189201@axwrapper
190- def yclamp (y0 = None , y1 = None , dt = None , ** kwargs ):
202+ def yclamp (
203+ y0 : float | None = None ,
204+ y1 : float | None = None ,
205+ dt : float | None = None ,
206+ ** kwargs ,
207+ ) -> Axes :
208+ """Clamp the y-axis to evenly spaced tick marks."""
191209 ax = kwargs ["ax" ]
192210
193211 lims = ax .get_ylim ()
194212 y0 = lims [0 ] if y0 is None else y0
195213 y1 = lims [1 ] if y1 is None else y1
196214
197- ticks : list [float ] = ax .get_yticks () # pyrefly: ignore
215+ ticks : list [float ] = ax .get_yticks () # pyrefly: ignore
198216 dt = float (np .mean (np .diff (ticks ))) if dt is None else float (dt )
199217
200218 new_ticks = np .arange (dt * np .floor (y0 / dt ), dt * (np .ceil (y1 / dt ) + 1 ), dt )
@@ -206,14 +224,20 @@ def yclamp(y0=None, y1=None, dt=None, **kwargs):
206224
207225
208226@axwrapper
209- def xclamp (x0 = None , x1 = None , dt = None , ** kwargs ):
227+ def xclamp (
228+ x0 : float | None = None ,
229+ x1 : float | None = None ,
230+ dt : float | None = None ,
231+ ** kwargs ,
232+ ) -> Axes :
233+ """Clamp the x-axis to evenly spaced tick marks."""
210234 ax = kwargs ["ax" ]
211235
212236 lims = ax .get_xlim ()
213237 x0 = lims [0 ] if x0 is None else x0
214238 x1 = lims [1 ] if x1 is None else x1
215239
216- ticks : list [float ] = ax .get_xticks () # pyrefly: ignore
240+ ticks : list [float ] = ax .get_xticks () # pyrefly: ignore
217241 dt = float (np .mean (np .diff (ticks ))) if dt is None else float (dt )
218242
219243 new_ticks = np .arange (dt * np .floor (x0 / dt ), dt * (np .ceil (x1 / dt ) + 1 ), dt )
0 commit comments