Skip to content

Commit 39ea7b2

Browse files
committed
Implement the PickBestCodec combinator
1 parent 4b4588c commit 39ea7b2

File tree

6 files changed

+304
-1
lines changed

6 files changed

+304
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ The following combinators, implementing the `CodecCombinatorMixin` are provided:
1212

1313
- `CodecStack`: a stack of codecs
1414
- `FramedCodecStack`: a stack of codecs that is framed with array data type and shape information
15+
- `PickBestCodec`: pick the best codec to encode the data
1516

1617
[`numcodecs`]: https://numcodecs.readthedocs.io/en/stable/
1718

docs/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ The following combinators, implementing the [`CodecCombinatorMixin`][numcodecs_c
1212

1313
- [`CodecStack`][numcodecs_combinators.stack.CodecStack]: a stack of codecs
1414
- [`FramedCodecStack`][numcodecs_combinators.framed.FramedCodecStack]: a stack of codecs that is framed with array data type and shape information
15+
- [`PickBestCodec`][numcodecs_combinators.best.PickBestCodec]: pick the best codec to encode the data
1516

1617
## Funding
1718

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ optional-dependencies.xarray = [ "xarray>=2024.06", "dask>=2024.6" ]
2121
dev = ["mypy~=1.14", "pytest~=8.3"]
2222

2323
[project.entry-points."numcodecs.codecs"]
24-
"combinators.stack" = "numcodecs_combinators.stack:CodecStack"
24+
"combinators.best" = "numcodecs_combinators.best:PickBestCodec"
2525
"combinators.framed" = "numcodecs_combinators.framed:FramedCodecStack"
26+
"combinators.stack" = "numcodecs_combinators.stack:CodecStack"
2627

2728
[tool.setuptools.packages.find]
2829
where = ["src"]

src/numcodecs_combinators/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
- [`CodecStack`][numcodecs_combinators.stack.CodecStack]: a stack of codecs
99
- [`FramedCodecStack`][numcodecs_combinators.framed.FramedCodecStack]: a stack
1010
of codecs that is framed with array data type and shape information
11+
- [`PickBestCodec`][numcodecs_combinators.best.PickBestCodec]: pick the best
12+
codec to encode the data
1113
"""
1214

1315
__all__ = ["map_codec"]
@@ -18,6 +20,7 @@
1820
from numcodecs.abc import Codec
1921

2022
from . import abc as abc
23+
from . import best as best
2124
from . import framed as framed
2225
from . import stack as stack
2326

src/numcodecs_combinators/best.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""
2+
This module defines the [`PickBestCodec`][numcodecs_combinators.best.PickBestCodec] class, which picks the codec that encoded the data best.
3+
"""
4+
5+
__all__ = ["PickBestCodec"]
6+
7+
from io import BytesIO
8+
from typing import Callable, Optional
9+
10+
import numcodecs
11+
import numcodecs.compat
12+
import numcodecs.registry
13+
import numpy as np
14+
import varint
15+
from numcodecs.abc import Codec
16+
from typing_extensions import Buffer, Self # MSPV 3.12
17+
18+
from .abc import CodecCombinatorMixin
19+
20+
21+
class PickBestCodec(Codec, CodecCombinatorMixin, tuple[Codec, ...]):
    """
    A codec that tries encoding with all combined codecs and then picks the one with the fewest bytes.

    The inner codecs must all encode to 1D byte arrays. To use a codec not
    encoding to bytes with this combinator, you can wrap it using the
    [`FramedCodecStack(codec)`][numcodecs_combinators.framed.FramedCodecStack]
    combinator.

    This combinator uses the ULEB128 variable length integer encoding to encode
    the index of the codec that was chosen to encode and uses this index as a
    header before the encoded bytes. The header index is only included if this
    combinator wraps at least two codecs. If this combinator wraps zero codecs,
    it passes the original data through unchanged.
    """

    __slots__ = ()

    codec_id: str = "combinators.best"  # type: ignore

    def __init__(self, *args: dict | Codec):
        # The codecs are stored by `__new__`, which builds the underlying
        # tuple; this initialiser only accepts the same arguments.
        pass

    def __new__(cls, *args: dict | Codec) -> Self:
        # Codec instances are used as-is, anything else is treated as a codec
        # configuration and resolved through the numcodecs registry.
        return super(PickBestCodec, cls).__new__(
            cls,
            tuple(
                codec
                if isinstance(codec, Codec)
                else numcodecs.registry.get_codec(codec)
                for codec in args
            ),
        )

    def encode(self, buf: Buffer) -> bytes:
        """Encode the data in `buf`.

        Every inner codec encodes its own copy of the data and the shortest
        encoding is kept. On a tie, the codec with the lowest index wins.

        Parameters
        ----------
        buf : Buffer
            Data to be encoded. May be any object supporting the new-style
            buffer protocol.

        Returns
        -------
        enc : bytes
            Encoded data as a bytestring. If this combinator wraps zero
            codecs, `buf` itself is returned unchanged instead.

        Raises
        ------
        AssertionError
            If an inner codec does not encode to a byte array with at most one
            dimension.
        """

        if len(self) == 0:
            return buf

        data = numcodecs.compat.ensure_ndarray(buf)

        best_size = np.inf
        best_index = None
        best_encoded = None

        for i, codec in enumerate(self):
            # each codec gets its own copy so that an in-place codec cannot
            # change the input seen by the codecs that are tried after it
            encoded = numcodecs.compat.ensure_ndarray(codec.encode(np.copy(data)))

            # Raised explicitly (instead of using `assert`) so that the checks
            # are not stripped under `python -O`; the exception type and
            # messages are kept for backwards compatibility.
            if encoded.dtype != np.dtype("uint8"):
                raise AssertionError(f"codec best[{i}] must encode to bytes")
            if encoded.ndim > 1:
                raise AssertionError(f"codec best[{i}] must encode to 1D bytes")

            if encoded.nbytes < best_size:
                best_size = encoded.nbytes
                best_index = i
                best_encoded = encoded

        encoded_bytes = numcodecs.compat.ensure_bytes(best_encoded)

        # with a single codec there is no choice to record, so no header
        if len(self) == 1:
            return encoded_bytes

        return varint.encode(best_index) + encoded_bytes

    def decode(self, buf: Buffer, out: Optional[Buffer] = None) -> Buffer:
        """Decode the data in `buf`.

        Parameters
        ----------
        buf : Buffer
            Encoded data. Must be an object representing a bytestring, e.g.
            [`bytes`][bytes] or a 1D array of [`np.uint8`][numpy.uint8]s etc.
        out : Buffer, optional
            Writeable buffer to store decoded data. N.B. if provided, this buffer must
            be exactly the right size to store the decoded data.

        Returns
        -------
        dec : Buffer
            Decoded data. May be any object supporting the new-style
            buffer protocol.

        Raises
        ------
        IndexError
            If the header index in `buf` does not refer to one of the combined
            codecs.
        """

        if len(self) == 0:
            return numcodecs.compat.ndarray_copy(buf, out)

        b = numcodecs.compat.ensure_bytes(buf)
        b_io = BytesIO(b)

        if len(self) == 1:
            # no header is written when only one codec is combined
            best_index = 0
        else:
            # reading the header advances the stream to the encoded payload
            best_index = varint.decode_stream(b_io)

            if best_index >= len(self):
                raise IndexError(
                    f"encoded data selects codec best[{best_index}] but only "
                    f"{len(self)} codecs are combined"
                )

        return self[best_index].decode(b_io.read(), out=out)

    def get_config(self) -> dict:
        """
        Returns the configuration of the best codec combinator.

        [`numcodecs.registry.get_codec(config)`][numcodecs.registry.get_codec]
        can be used to reconstruct this combinator from the returned config.

        Returns
        -------
        config : dict
            Configuration of the best codec combinator.
        """

        return dict(
            id=type(self).codec_id,
            codecs=tuple(codec.get_config() for codec in self),
        )

    @classmethod
    def from_config(cls, config: dict) -> Self:
        """
        Instantiate the best codec combinator from a configuration [`dict`][dict].

        Parameters
        ----------
        config : dict
            Configuration of the best codec combinator.

        Returns
        -------
        best : PickBestCodec
            Instantiated best codec combinator.
        """

        return cls(*config["codecs"])

    def __repr__(self) -> str:
        inner = ", ".join(f"{codec!r}" for codec in self)

        return f"{type(self).__name__}({inner})"

    def map(self, mapper: Callable[[Codec], Codec]) -> "PickBestCodec":
        """
        Apply the `mapper` to all codecs that are in this combinator.
        In the returned combinator, each codec is replaced by its mapped codec.

        The `mapper` should recursively apply itself to any inner codecs that
        also implement the [`CodecCombinatorMixin`][numcodecs_combinators.abc.CodecCombinatorMixin]
        mixin.

        To automatically handle the recursive application as a caller, you can
        use
        ```python
        numcodecs_combinators.map_codec(best, mapper)
        ```
        instead.

        Parameters
        ----------
        mapper : Callable[[Codec], Codec]
            The callable that should be applied to each codec to map over this
            best codec combinator.

        Returns
        -------
        mapped : PickBestCodec
            The mapped best codec combinator.
        """

        return PickBestCodec(*map(mapper, self))

    def __add__(self, other) -> "PickBestCodec":
        """Concatenate with another tuple of codecs into a new combinator."""
        return PickBestCodec(*tuple.__add__(self, other))

    def __mul__(self, other) -> "PickBestCodec":
        """Repeat the combined codecs `other` times in a new combinator."""
        return PickBestCodec(*tuple.__mul__(self, other))

    def __rmul__(self, other) -> "PickBestCodec":
        """Repeat the combined codecs `other` times in a new combinator."""
        return PickBestCodec(*tuple.__rmul__(self, other))
210+
211+
212+
# Register the combinator with numcodecs so that a config carrying its
# `codec_id` can be resolved by `numcodecs.registry.get_codec`.
numcodecs.registry.register_codec(PickBestCodec)

tests/test_best.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import numcodecs
2+
import numcodecs.compat
3+
import numpy as np
4+
5+
import numcodecs_combinators
6+
from numcodecs_combinators.best import PickBestCodec
7+
from numcodecs_combinators.framed import FramedCodecStack
8+
9+
10+
def assert_config_roundtrip(codec: numcodecs.abc.Codec):
    """Check that rebuilding `codec` from its own config yields an equal codec."""
    rebuilt = numcodecs.get_codec(codec.get_config())
    assert rebuilt == codec
14+
15+
16+
def test_init_config():
    """Combinators of zero, one and two codecs have the right length and config."""
    # (constructor arguments, expected number of combined codecs)
    cases = [
        ((), 0),
        ((dict(id="zlib", level=9),), 1),
        ((dict(id="zlib", level=9), numcodecs.CRC32()), 2),
    ]

    for args, expected_len in cases:
        best = PickBestCodec(*args)
        assert len(best) == expected_len
        assert_config_roundtrip(best)
28+
29+
30+
def test_encode_decode():
    """Data survives an encode/decode roundtrip for zero to three inner codecs."""
    combinators = [
        PickBestCodec(),
        PickBestCodec(dict(id="combinators.framed", codecs=[dict(id="zlib", level=9)])),
        PickBestCodec(
            FramedCodecStack(numcodecs.Zlib(level=9)),
            FramedCodecStack(numcodecs.CRC32()),
        ),
        PickBestCodec(
            FramedCodecStack(numcodecs.Zlib(level=9)),
            FramedCodecStack(numcodecs.CRC32()),
            FramedCodecStack(numcodecs.Zstd(level=20)),
        ),
    ]

    for best in combinators:
        # the samples are rebuilt for every combinator, so each one starts
        # from fresh arrays
        samples = [
            np.zeros(shape=(0,)),
            np.array(3),
            np.array([97, 98, 99], dtype=np.uint8),
            np.linspace(1, 100, 100).reshape(10, 10),
            np.linspace(1, 100, 100).reshape(10, 10).byteswap(),
        ]

        for data in samples:
            encoded = best.encode(data)
            # the empty combinator passes the data through unchanged, so only
            # non-empty combinators have to produce a bytestring
            assert len(best) == 0 or isinstance(encoded, bytes)

            decoded = best.decode(encoded)
            print(best)
            assert np.all(decoded == data)
57+
58+
59+
def test_map():
    """`map_codec` applies the mapper at every nesting level of the combinator."""

    def zlib():
        # fresh instance, so comparisons rely on codec equality, not identity
        return numcodecs.Zlib(level=9)

    def crc32():
        return numcodecs.CRC32()

    def wrap(codec):
        return PickBestCodec(codec)

    def nest(codec, depth):
        # wrap `codec` in `depth` single-codec combinators
        for _ in range(depth):
            codec = PickBestCodec(codec)
        return codec

    best = PickBestCodec(zlib(), crc32())

    # the identity mapper leaves the combinator unchanged
    mapped = numcodecs_combinators.map_codec(best, lambda c: c)
    assert mapped == best

    # one pass wraps every inner codec and the combinator itself once
    mapped = numcodecs_combinators.map_codec(best, wrap)
    expected_once = nest(PickBestCodec(nest(zlib(), 1), nest(crc32(), 1)), 1)
    assert mapped == expected_once

    # a second pass wraps every level of the already nested combinator again
    mapped = numcodecs_combinators.map_codec(mapped, wrap)
    expected_twice = nest(PickBestCodec(nest(zlib(), 3), nest(crc32(), 3)), 3)
    assert mapped == expected_twice

0 commit comments

Comments
 (0)