Skip to content

Commit dd1a199

Browse files
committed
Mark chunked arrays in encode_decode_data_array
1 parent 39ea7b2 commit dd1a199

File tree

3 files changed

+75
-6
lines changed

3 files changed

+75
-6
lines changed

src/numcodecs_combinators/abc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ class CodecCombinatorMixin(ABC):
1515
Mixin class for combinators over [`Codec`][numcodecs.abc.Codec]s.
1616
"""
1717

18+
__slots__ = ()
19+
1820
@abstractmethod
1921
def map(self, mapper: Callable[[Codec], Codec]) -> Codec:
2022
"""

src/numcodecs_combinators/stack.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ def encode_decode(self, buf: Buffer) -> Buffer:
140140
buffer protocol.
141141
"""
142142

143+
chunked = getattr(buf, "chunked", False)
144+
143145
encoded = np.asarray(
144146
numcodecs.compat.ensure_contiguous_ndarray_like(buf, flatten=False)
145147
)
@@ -149,16 +151,23 @@ def encode_decode(self, buf: Buffer) -> Buffer:
149151
silhouettes.append((encoded.shape, encoded.dtype))
150152
encoded = np.asarray(
151153
numcodecs.compat.ensure_contiguous_ndarray_like(
152-
codec.encode((encoded)), flatten=False
154+
codec.encode(_MaybeChunkedNdArray(encoded) if chunked else encoded),
155+
flatten=False,
153156
)
154157
)
155158

156-
decoded = encoded
159+
decoded = encoded.view(np.ndarray)
157160

158161
for codec in reversed(self):
159162
shape, dtype = silhouettes.pop()
160163
out = np.empty(shape=shape, dtype=dtype)
161-
decoded = codec.decode(decoded, out).view(dtype).reshape(shape)
164+
decoded = (
165+
codec.decode(decoded, _MaybeChunkedNdArray(out) if chunked else out)
166+
.view(dtype)
167+
.reshape(shape)
168+
)
169+
170+
decoded = decoded.view(np.ndarray)
162171

163172
if isinstance(decoded, type(buf)):
164173
return decoded
@@ -167,7 +176,8 @@ def encode_decode(self, buf: Buffer) -> Buffer:
167176

168177
def encode_decode_data_array(self, da: "xr.DataArray") -> "xr.DataArray":
169178
"""
170-
Encode, then decode each chunk (independently) in the data array `da`.
179+
Encode, then decode the data array `da`. If `da` is chunked, each chunk
180+
is encoded and decoded *independently*.
171181
172182
Since each chunk is encoded *independently*, this method may cause
173183
chunk boundary artifacts. Do *not* use this method if the codec
@@ -195,6 +205,8 @@ def encode_decode_data_array(self, da: "xr.DataArray") -> "xr.DataArray":
195205

196206
import xarray as xr
197207

208+
chunked = da.chunks is not None
209+
198210
def encode_decode_data_array_single_chunk(
199211
da: xr.DataArray,
200212
) -> xr.DataArray:
@@ -205,9 +217,11 @@ def encode_decode_data_array_single_chunk(
205217
return da.copy(deep=False).chunk(single_chunk)
206218

207219
# eagerly compute the input chunk and encode and decode it
208-
decoded = self.encode_decode(da.values) # type: ignore
220+
decoded = self.encode_decode(_MaybeChunkedNdArray(da.values, chunked)) # type: ignore
209221

210-
return da.copy(deep=False, data=decoded).chunk(single_chunk)
222+
return da.copy(deep=False, data=np.array(decoded).view(np.ndarray)).chunk(
223+
single_chunk
224+
)
211225

212226
return xr.map_blocks(encode_decode_data_array_single_chunk, da)
213227

@@ -293,3 +307,22 @@ def __rmul__(self, other) -> "CodecStack":
293307

294308

295309
numcodecs.registry.register_codec(CodecStack)
310+
311+
312+
class _MaybeChunkedNdArray(np.ndarray):
    """
    Thin [`np.ndarray`][numpy.ndarray] subclass that carries a boolean
    `chunked` marker alongside the array data.

    The marker is exposed through the read-only `chunked` property, so that
    consumers can query it with `getattr(array, "chunked", False)` without
    depending on this private type.
    """

    __slots__ = ("_chunked",)
    _chunked: bool

    def __new__(cls, array, chunked: bool = True):
        """
        Wrap `array` in a view of this class and record the `chunked` marker.

        Parameters
        ----------
        array
            Anything accepted by `np.asarray`; the result is a view, not a
            copy, whenever `np.asarray` itself does not copy.
        chunked : bool
            Value of the marker stored on the new view.
        """
        wrapped = np.asarray(array).view(cls)
        # view casting has already run __array_finalize__, which filled in a
        # default marker; the explicit argument takes precedence over it
        wrapped._chunked = chunked
        return wrapped

    def __array_finalize__(self, obj):
        """
        Propagate the marker when numpy derives this array from `obj`.

        If `obj` has no `chunked` attribute, the marker defaults to `True`.
        Nothing is set when `obj` is `None`.
        """
        if obj is not None:
            self._chunked = getattr(obj, "chunked", True)

    @property
    def chunked(self) -> bool:
        """Whether this array is marked as chunked."""
        return self._chunked

tests/test_stack.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numcodecs
22
import numpy as np
33
import xarray as xr
4+
from numcodecs.abc import Codec
45

56
import numcodecs_combinators
67
from numcodecs_combinators.stack import CodecStack
@@ -51,6 +52,39 @@ def test_encode_decode():
5152
assert encoded_decoded.equals(xr.DataArray([1.0, 2.0, 3.0]))
5253

5354

55+
def test_chunked_encode_decode():
    """
    Check which buffers carry the `chunked` marker when they reach a codec
    through `encode_decode` and `encode_decode_data_array`.
    """

    class CheckChunkedCodec(Codec):
        """Pass-through codec that asserts the `chunked` marker on its buffers."""

        __slots__ = ("is_chunked",)
        is_chunked: bool

        def __init__(self, is_chunked: bool):
            self.is_chunked = is_chunked

        def encode(self, buf):
            marker = getattr(buf, "chunked", False)
            assert marker == self.is_chunked
            return buf

        def decode(self, buf, out=None):
            # the encoded input must be unmarked; only `out` carries the marker
            assert getattr(buf, "chunked", False) is False
            assert getattr(out, "chunked", False) == self.is_chunked
            return numcodecs.compat.ndarray_copy(buf, out)

    values = [1.0, 2.0, 3.0]

    # plain arrays and unchunked data arrays must not be marked as chunked
    unchunked_stack = CodecStack(CheckChunkedCodec(False))

    roundtrip = unchunked_stack.encode_decode(np.array(values))
    assert np.all(roundtrip == np.array(values))

    roundtrip = unchunked_stack.encode_decode_data_array(xr.DataArray(values))
    assert roundtrip.equals(xr.DataArray(values))

    # a chunked data array must be marked as chunked
    chunked_stack = CodecStack(CheckChunkedCodec(True))

    roundtrip = chunked_stack.encode_decode_data_array(
        xr.DataArray(values).chunk(1)
    )
    assert roundtrip.equals(xr.DataArray(values))
86+
87+
5488
def test_map():
5589
stack = CodecStack(numcodecs.Zlib(level=9), numcodecs.CRC32())
5690

0 commit comments

Comments
 (0)