|
5 | 5 | __all__ = ["TokenizeCodec"] |
6 | 6 |
|
7 | 7 | from io import BytesIO |
8 | | -from sys import byteorder |
9 | 8 |
|
10 | 9 | import numcodecs.compat |
11 | 10 | import numcodecs.registry |
@@ -83,32 +82,28 @@ def encode( |
83 | 82 | else: |
84 | 83 | utype = a.dtype |
85 | 84 |
|
| 85 | + assert (dtype.itemsize % utype.itemsize) == 0 |
| 86 | + |
86 | 87 | # insert padding to align with itemsize |
87 | 88 | message.append( |
88 | 89 | b"\0" * (utype.itemsize - (sum(len(m) for m in message) % utype.itemsize)) |
89 | 90 | ) |
90 | 91 |
|
91 | 92 | # ensure that the table keys are encoded in little endian binary |
92 | 93 | table_keys_array = unique[argsort] |
93 | | - table_keys_byteorder = table_keys_array.dtype.byteorder |
94 | | - table_keys_byteorder = ( |
95 | | - table_keys_byteorder |
96 | | - if table_keys_byteorder in ("<", ">") |
97 | | - else ("<" if (byteorder == "little") else ">") |
| 94 | + message.append( |
| 95 | + table_keys_array.astype(table_keys_array.dtype.newbyteorder("<")).tobytes() |
98 | 96 | ) |
99 | | - if table_keys_byteorder != "<": |
100 | | - table_keys_array = table_keys_array.byteswap() |
101 | | - message.append(table_keys_array.tobytes()) |
102 | 97 |
|
103 | 98 | indices = argsortinv[inverse].astype(utype) |
104 | | - if table_keys_byteorder != "<": |
105 | | - indices = indices.byteswap() |
106 | | - message.append(indices.tobytes()) |
| 99 | + message.append(indices.astype(indices.dtype.newbyteorder("<")).tobytes()) |
107 | 100 |
|
108 | 101 | encoded_bytes = b"".join(message) |
109 | 102 |
|
110 | 103 | encoded: np.ndarray[tuple[int], np.dtype[np.unsignedinteger]] = np.frombuffer( |
111 | | - encoded_bytes, dtype=utype, count=len(encoded_bytes) // utype.itemsize |
| 104 | + encoded_bytes, |
| 105 | + dtype=utype.newbyteorder("<"), |
| 106 | + count=len(encoded_bytes) // utype.itemsize, |
112 | 107 | ) |
113 | 108 |
|
114 | 109 | return encoded # type: ignore |
@@ -168,24 +163,16 @@ def decode( |
168 | 163 | dtype=_dtype_bits(dtype).newbyteorder("<"), |
169 | 164 | count=table_len, |
170 | 165 | ) |
171 | | - dtype_bits_byteorder = _dtype_bits(dtype).byteorder |
172 | | - dtype_bits_byteorder = ( |
173 | | - dtype_bits_byteorder |
174 | | - if dtype_bits_byteorder in ("<", ">") |
175 | | - else ("<" if (byteorder == "little") else ">") |
176 | | - ) |
177 | | - if dtype_bits_byteorder != "<": |
178 | | - table_keys = table_keys.byteswap() |
179 | 166 |
|
180 | 167 | indices = np.frombuffer( |
181 | 168 | b_io.read(), |
182 | 169 | dtype=utype.newbyteorder("<"), |
183 | 170 | count=np.prod(shape, dtype=np.uintp), |
184 | 171 | ) |
185 | | - if dtype_bits_byteorder != "<": |
186 | | - indices = indices.byteswap() |
187 | 172 |
|
188 | | - decoded = table_keys[indices].view(dtype).reshape(shape) |
| 173 | + decoded = ( |
| 174 | + table_keys[indices].astype(_dtype_bits(dtype)).view(dtype).reshape(shape) |
| 175 | + ) |
189 | 176 |
|
190 | 177 | return numcodecs.compat.ndarray_copy(decoded, out) # type: ignore |
191 | 178 |
|
|
0 commit comments