Skip to content

Commit 6d762b2

Browse files
committed
Added convert_array_to_standard_type
1 parent a2fd484 commit 6d762b2

File tree

5 files changed

+55
-6
lines changed

5 files changed

+55
-6
lines changed

cdl/core/io/conv.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,49 @@ def data_to_xy(data: np.ndarray) -> list[np.ndarray]:
4040
if len(data) == 4:
4141
dx, dy = data[2:]
4242
return x, y, dx, dy
43+
44+
45+
def convert_array_to_standard_type(array: np.ndarray) -> np.ndarray:
46+
"""Convert an integer array to a standard type
47+
(int8, int16, int32, uint8, uint16, uint32).
48+
49+
Ignores floating point arrays.
50+
51+
Args:
52+
array: array to convert
53+
54+
Raises:
55+
ValueError: if array is not of integer type
56+
57+
Returns:
58+
Converted array
59+
"""
60+
# Determine the kind and size of the data type
61+
kind = array.dtype.kind
62+
itemsize = array.dtype.itemsize
63+
64+
if kind in ["f", "c"]: # 'f' for floating point, 'c' for complex
65+
return array
66+
67+
if kind == "b":
68+
# Convert to uint8 if it is not already
69+
if array.dtype != np.uint8:
70+
return array.astype(np.uint8)
71+
return array
72+
73+
if kind in ["i", "u"]: # 'i' for signed integers, 'u' for unsigned integers
74+
if itemsize == 1: # 8-bit
75+
new_type = np.dtype(f"{kind}1").newbyteorder("=")
76+
elif itemsize == 2: # 16-bit
77+
new_type = np.dtype(f"{kind}2").newbyteorder("=")
78+
elif itemsize == 4: # 32-bit
79+
new_type = np.dtype(f"{kind}4").newbyteorder("=")
80+
else:
81+
raise ValueError("Unsupported item size for integer type")
82+
83+
# Convert to the new type if it is different from the current type
84+
if array.dtype != new_type:
85+
return array.astype(new_type)
86+
return array
87+
88+
raise ValueError("Unsupported data type")

cdl/core/io/h5/common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import numpy as np
2020

2121
from cdl.config import Conf
22-
from cdl.core.io.conv import data_to_xy
22+
from cdl.core.io.conv import convert_array_to_standard_type, data_to_xy
2323
from cdl.utils.misc import to_string
2424

2525

@@ -118,6 +118,7 @@ def set_signal_data(self, obj):
118118
data = self.data
119119
if data.dtype not in (float, np.complex128):
120120
data = np.array(data, dtype=float)
121+
data = convert_array_to_standard_type(data)
121122
if len(data.shape) == 1:
122123
obj.set_xydata(np.arange(data.size), data)
123124
else:
@@ -131,7 +132,7 @@ def set_image_data(self, obj):
131132
self.uint32_wng = data.max() > np.iinfo(np.int32).max
132133
clipped_data = data.clip(0, np.iinfo(np.int32).max)
133134
data = np.array(clipped_data, dtype=np.int32)
134-
obj.data = data
135+
obj.data = convert_array_to_standard_type(data)
135136

136137

137138
class H5Importer:

cdl/core/io/image/formats.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from cdl.config import _
1717
from cdl.core.io.base import FormatInfo
18+
from cdl.core.io.conv import convert_array_to_standard_type
1819
from cdl.core.io.image import funcs
1920
from cdl.core.io.image.base import ImageFormatBase
2021
from cdl.core.model.image import ImageObj
@@ -67,7 +68,7 @@ class NumPyImageFormat(ImageFormatBase):
6768
@staticmethod
6869
def read_data(filename: str) -> np.ndarray:
6970
"""Read data and return it"""
70-
return np.load(filename)
71+
return convert_array_to_standard_type(np.load(filename))
7172

7273
def write(self, filename: str, obj: ImageObj) -> None:
7374
"""Write data to file"""

cdl/core/io/signal/formats.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from cdl.config import _
1313
from cdl.core.io.base import FormatInfo
14+
from cdl.core.io.conv import convert_array_to_standard_type
1415
from cdl.core.io.signal import funcs
1516
from cdl.core.io.signal.base import SignalFormatBase
1617
from cdl.core.model.signal import SignalObj
@@ -36,7 +37,7 @@ def read_xydata(self, filename: str, obj: SignalObj) -> np.ndarray:
3637
Returns:
3738
np.ndarray: xydata
3839
"""
39-
return np.load(filename)
40+
return convert_array_to_standard_type(np.load(filename))
4041

4142
def write(self, filename: str, obj: SignalObj) -> None:
4243
"""Write data to file
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
guidata:cb7baedddd8db3f67fb914d751f5a4c39e6545ee
2-
plotpy:34883d5c7d23262ec6b5f1d7bd8ea99858ad0dbf
1+
guidata:6189db04e502cc0a63453221df916e873b90fa0b
2+
plotpy:c73d5efc5a5101a745d73cca8b99419a1fe85d31

0 commit comments

Comments
 (0)