Skip to content

Commit d40202b

Browse files
committed
mathutils/image items: code refactoring
1 parent 93d687b commit d40202b

File tree

5 files changed

+119
-68
lines changed

5 files changed

+119
-68
lines changed

plotpy/coords.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,32 @@
77
Plot coordinates
88
----------------
99
10+
Overview
11+
^^^^^^^^
12+
1013
The :mod:`plotpy.coords` module provides functions to convert coordinates
1114
between canvas and axes coordinates systems.
1215
16+
The following functions are available:
17+
18+
* :py:func:`.canvas_to_axes`
19+
* :py:func:`.axes_to_canvas`
20+
* :py:func:`.pixelround`
21+
1322
Reference
14-
~~~~~~~~~
23+
^^^^^^^^^
1524
1625
.. autofunction:: canvas_to_axes
1726
.. autofunction:: axes_to_canvas
27+
.. autofunction:: pixelround
1828
"""
1929

2030
from __future__ import annotations
2131

2232
from typing import TYPE_CHECKING
2333

34+
import numpy as np
35+
2436
if TYPE_CHECKING: # pragma: no cover
2537
from qtpy.QtCore import QPointF
2638
from qwt import QwtPlot, QwtPlotItem
@@ -59,3 +71,23 @@ def axes_to_canvas(item: QwtPlotItem, x: float, y: float) -> tuple[float, float]
5971
plot: QwtPlot = item.plot()
6072
ax, ay = item.xAxis(), item.yAxis()
6173
return plot.transform(ax, x), plot.transform(ay, y)
74+
75+
76+
def pixelround(x: float, corner: str | None = None) -> int:
77+
"""Get pixel index from pixel coordinate
78+
79+
Args:
80+
x: Pixel coordinate
81+
corner: None (not a corner), 'TL' (top-left corner),
82+
'BR' (bottom-right corner)
83+
84+
Returns:
85+
int: Pixel index
86+
"""
87+
assert corner is None or corner in ("TL", "BR")
88+
if corner is None:
89+
return np.floor(x)
90+
elif corner == "BR":
91+
return np.ceil(x)
92+
elif corner == "TL":
93+
return np.floor(x)

plotpy/items/image/base.py

Lines changed: 5 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from plotpy.config import _
2929
from plotpy.constants import LUTAlpha
30+
from plotpy.coords import pixelround
3031
from plotpy.interfaces import (
3132
IBaseImageItem,
3233
IBasePlotItem,
@@ -42,6 +43,7 @@
4243
)
4344
from plotpy.items.shape.rectangle import RectangleShape
4445
from plotpy.lutrange import lut_range_threshold
46+
from plotpy.mathutils.arrayfuncs import get_nan_range
4547
from plotpy.mathutils.colormap import FULLRANGE, get_cmap, get_cmap_name
4648
from plotpy.styles.image import RawImageParam
4749

@@ -60,60 +62,6 @@
6062
LUT_MAX = float(LUT_SIZE - 1)
6163

6264

63-
def _nanmin(data: np.ndarray) -> float:
64-
"""Return minimum value of data, ignoring NaNs
65-
66-
Args:
67-
data: Data array
68-
69-
Returns:
70-
float: Minimum value of data, ignoring NaNs
71-
"""
72-
if isinstance(data, np.ma.MaskedArray):
73-
data = data.data
74-
if data.dtype.name in ("float32", "float64", "float128"):
75-
return np.nanmin(data)
76-
else:
77-
return data.min()
78-
79-
80-
def _nanmax(data: np.ndarray) -> float:
81-
"""Return maximum value of data, ignoring NaNs
82-
83-
Args:
84-
data: Data array
85-
86-
Returns:
87-
float: Maximum value of data, ignoring NaNs
88-
"""
89-
if isinstance(data, np.ma.MaskedArray):
90-
data = data.data
91-
if data.dtype.name in ("float32", "float64", "float128"):
92-
return np.nanmax(data)
93-
else:
94-
return data.max()
95-
96-
97-
def pixelround(x: float, corner: str | None = None) -> int:
98-
"""Get pixel index from pixel coordinate
99-
100-
Args:
101-
x: Pixel coordinate
102-
corner: None (not a corner), 'TL' (top-left corner),
103-
'BR' (bottom-right corner)
104-
105-
Returns:
106-
int: Pixel index
107-
"""
108-
assert corner is None or corner in ("TL", "BR")
109-
if corner is None:
110-
return np.floor(x)
111-
elif corner == "BR":
112-
return np.ceil(x)
113-
elif corner == "TL":
114-
return np.floor(x)
115-
116-
11765
class BaseImageItem(QwtPlotItem):
11866
"""Base class for image items
11967
@@ -391,7 +339,7 @@ def set_data(
391339
if lut_range is not None:
392340
_min, _max = lut_range
393341
else:
394-
_min, _max = _nanmin(data), _nanmax(data)
342+
_min, _max = get_nan_range(data)
395343

396344
self.data = data
397345
self.histogram_cache = None
@@ -592,7 +540,7 @@ def get_lut_range_full(self) -> tuple[float, float]:
592540
Returns:
593541
tuple[float, float]: Lut range, tuple(min, max)
594542
"""
595-
return _nanmin(self.data), _nanmax(self.data)
543+
return get_nan_range(self.data)
596544

597545
def get_lut_range_max(self) -> tuple[float, float]:
598546
"""Get maximum range for this dataset
@@ -1019,8 +967,7 @@ def get_histogram(self, nbins: int) -> tuple[np.ndarray, np.ndarray]:
1019967
res = np.histogram(self.data[~np.isnan(self.data)], nbins)
1020968
else:
1021969
# TODO: _histogram is faster, but caching is buggy in this version
1022-
_min = _nanmin(self.data)
1023-
_max = _nanmax(self.data)
970+
_min, _max = get_nan_range(self.data)
1024971
if self.data.dtype in (np.float64, np.float32):
1025972
bins = np.unique(
1026973
np.array(

plotpy/items/image/image_items.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from plotpy import io
1414
from plotpy.config import _
1515
from plotpy.constants import LUTAlpha
16-
from plotpy.coords import canvas_to_axes
16+
from plotpy.coords import canvas_to_axes, pixelround
1717
from plotpy.interfaces import (
1818
IBaseImageItem,
1919
IBasePlotItem,
@@ -26,9 +26,8 @@
2626
ITrackableItemType,
2727
IVoiImageItemType,
2828
)
29-
from plotpy.items.image.base import RawImageItem, pixelround
29+
from plotpy.items.image.base import RawImageItem
3030
from plotpy.items.image.filter import XYImageFilterItem, to_bins
31-
from plotpy.mathutils.geometry import colvector
3231
from plotpy.styles.image import ImageParam, RGBImageParam, XYImageParam
3332

3433
if TYPE_CHECKING: # pragma: no cover

plotpy/items/image/misc.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
ITrackableItemType,
2626
IVoiImageItemType,
2727
)
28-
from plotpy.items.image.base import BaseImageItem, RawImageItem, _nanmax, _nanmin
28+
from plotpy.items.image.base import BaseImageItem, RawImageItem
2929
from plotpy.items.image.transform import TrImageItem
30+
from plotpy.mathutils.arrayfuncs import get_nan_range
3031
from plotpy.styles import Histogram2DParam, ImageParam, QuadGridParam
3132

3233
try:
@@ -136,7 +137,7 @@ def set_data(
136137
if lut_range is not None:
137138
_min, _max = lut_range
138139
else:
139-
_min, _max = _nanmin(data), _nanmax(data)
140+
_min, _max = get_nan_range(data)
140141

141142
self.data = data
142143
self.histogram_cache = None
@@ -330,8 +331,7 @@ def draw_image(
330331
else:
331332
self.data[self.data_tmp == 0.0] = np.nan
332333
if self.histparam.auto_lut:
333-
nmin = _nanmin(self.data)
334-
nmax = _nanmax(self.data)
334+
nmin, nmax = get_nan_range(self.data)
335335
self.set_lut_range([nmin, nmax])
336336
self.plot().update_colormap_axis(self)
337337
src_rect = (0, 0, self.nx_bins, self.ny_bins)
@@ -417,8 +417,7 @@ def get_histogram(self, nbins: int) -> tuple[np.ndarray, np.ndarray]:
417417
"""
418418
if self.data is None:
419419
return [0], [0, 1]
420-
_min = _nanmin(self.data)
421-
_max = _nanmax(self.data)
420+
_min, _max = get_nan_range(self.data)
422421
if self.data.dtype in (np.float64, np.float32):
423422
bins = np.unique(
424423
np.array(np.linspace(_min, _max, nbins + 1), dtype=self.data.dtype)

plotpy/mathutils/arrayfuncs.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
Array functions
5+
---------------
6+
7+
Overview
8+
^^^^^^^^
9+
10+
The :py:mod:`.arrayfuncs` module provides miscellaneous array functions.
11+
12+
The following functions are available:
13+
14+
* :py:func:`.get_nan_min`
15+
* :py:func:`.get_nan_max`
16+
* :py:func:`.get_nan_range`
17+
18+
Reference
19+
^^^^^^^^^
20+
21+
.. autofunction:: get_nan_min
22+
.. autofunction:: get_nan_max
23+
.. autofunction:: get_nan_range
24+
"""
25+
26+
from __future__ import annotations
27+
28+
import numpy as np
29+
30+
31+
def get_nan_min(data: np.ndarray | np.ma.MaskedArray) -> float:
32+
"""Return minimum value of data, ignoring NaNs
33+
34+
Args:
35+
data: Data array (or masked array)
36+
37+
Returns:
38+
float: Minimum value of data, ignoring NaNs
39+
"""
40+
if isinstance(data, np.ma.MaskedArray):
41+
data = data.data
42+
if data.dtype.name in ("float32", "float64", "float128"):
43+
return np.nanmin(data)
44+
else:
45+
return data.min()
46+
47+
48+
def get_nan_max(data: np.ndarray | np.ma.MaskedArray) -> float:
49+
"""Return maximum value of data, ignoring NaNs
50+
51+
Args:
52+
data: Data array (or masked array)
53+
54+
Returns:
55+
float: Maximum value of data, ignoring NaNs
56+
"""
57+
if isinstance(data, np.ma.MaskedArray):
58+
data = data.data
59+
if data.dtype.name in ("float32", "float64", "float128"):
60+
return np.nanmax(data)
61+
else:
62+
return data.max()
63+
64+
65+
def get_nan_range(data: np.ndarray | np.ma.MaskedArray) -> tuple[float, float]:
66+
"""Return range of data, i.e. (min, max), ignoring NaNs
67+
68+
Args:
69+
data: Data array (or masked array)
70+
71+
Returns:
72+
tuple: Minimum and maximum value of data, ignoring NaNs
73+
"""
74+
return get_nan_min(data), get_nan_max(data)

0 commit comments

Comments
 (0)