Skip to content

Commit 584e393

Browse files
authored
add support reader codecs (#214)
* add support reader codecs * fix writer codec bug
1 parent 047bc1a commit 584e393

File tree

8 files changed

+331
-14
lines changed

8 files changed

+331
-14
lines changed

tests/topics/test_topic_reader.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
22

3+
import ydb
4+
35

46
@pytest.mark.asyncio
57
class TestTopicReaderAsyncIO:
@@ -25,6 +27,36 @@ async def test_read_and_commit_message(
2527

2628
await reader.close()
2729

30+
async def test_read_compressed_messages(self, driver, topic_path, topic_consumer):
31+
async with driver.topic_client.writer(
32+
topic_path, codec=ydb.TopicCodec.GZIP
33+
) as writer:
34+
await writer.write("123")
35+
36+
async with driver.topic_client.reader(topic_consumer, topic_path) as reader:
37+
batch = await reader.receive_batch()
38+
assert batch.messages[0].data.decode() == "123"
39+
40+
async def test_read_custom_encoded(self, driver, topic_path, topic_consumer):
41+
codec = 10001
42+
43+
def encode(b: bytes):
44+
return bytes(reversed(b))
45+
46+
def decode(b: bytes):
47+
return bytes(reversed(b))
48+
49+
async with driver.topic_client.writer(
50+
topic_path, codec=codec, encoders={codec: encode}
51+
) as writer:
52+
await writer.write("123")
53+
54+
async with driver.topic_client.reader(
55+
topic_consumer, topic_path, decoders={codec: decode}
56+
) as reader:
57+
batch = await reader.receive_batch()
58+
assert batch.messages[0].data.decode() == "123"
59+
2860

2961
class TestTopicReaderSync:
3062
def test_read_message(
@@ -45,3 +77,33 @@ def test_read_and_commit_message(
4577
reader = driver_sync.topic_client.reader(topic_consumer, topic_path)
4678
batch2 = reader.receive_batch()
4779
assert batch.messages[0] != batch2.messages[0]
80+
81+
def test_read_compressed_messages(self, driver_sync, topic_path, topic_consumer):
82+
with driver_sync.topic_client.writer(
83+
topic_path, codec=ydb.TopicCodec.GZIP
84+
) as writer:
85+
writer.write("123")
86+
87+
with driver_sync.topic_client.reader(topic_consumer, topic_path) as reader:
88+
batch = reader.receive_batch()
89+
assert batch.messages[0].data.decode() == "123"
90+
91+
def test_read_custom_encoded(self, driver_sync, topic_path, topic_consumer):
92+
codec = 10001
93+
94+
def encode(b: bytes):
95+
return bytes(reversed(b))
96+
97+
def decode(b: bytes):
98+
return bytes(reversed(b))
99+
100+
with driver_sync.topic_client.writer(
101+
topic_path, codec=codec, encoders={codec: encode}
102+
) as writer:
103+
writer.write("123")
104+
105+
with driver_sync.topic_client.reader(
106+
topic_consumer, topic_path, decoders={codec: decode}
107+
) as reader:
108+
batch = reader.receive_batch()
109+
assert batch.messages[0].data.decode() == "123"

ydb/_topic_reader/datatypes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import datetime
1010
from typing import Mapping, Union, Any, List, Dict, Deque, Optional
1111

12-
from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange
12+
from ydb._grpc.grpcwrapper.ydb_topic import OffsetsRange, Codec
1313
from ydb._topic_reader import topic_reader_asyncio
1414

1515

@@ -168,6 +168,7 @@ class PublicBatch(ICommittable, ISessionAlive):
168168
messages: List[PublicMessage]
169169
_partition_session: PartitionSession
170170
_bytes_size: int
171+
_codec: Codec
171172

172173
def _commit_get_partition_session(self) -> PartitionSession:
173174
return self.messages[0]._commit_get_partition_session()

ydb/_topic_reader/topic_reader.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import concurrent.futures
12
import enum
23
import datetime
34
from dataclasses import dataclass
45
from typing import (
56
Union,
67
Optional,
78
List,
9+
Mapping,
10+
Callable,
811
)
912

1013
from ..table import RetrySettings
@@ -27,6 +30,13 @@ class PublicReaderSettings:
2730
consumer: str
2831
topic: str
2932
buffer_size_bytes: int = 50 * 1024 * 1024
33+
34+
decoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None
35+
"""decoders: map[codec_code] func(encoded_bytes)->decoded_bytes"""
36+
37+
# decoder_executor, must be set for handle non raw messages
38+
decoder_executor: Optional[concurrent.futures.Executor] = None
39+
3040
# on_commit: Callable[["Events.OnCommit"], None] = None
3141
# on_get_partition_start_offset: Callable[
3242
# ["Events.OnPartitionGetStartOffsetRequest"],
@@ -35,7 +45,6 @@ class PublicReaderSettings:
3545
# on_partition_session_start: Callable[["StubEvent"], None] = None
3646
# on_partition_session_stop: Callable[["StubEvent"], None] = None
3747
# on_partition_session_close: Callable[["StubEvent"], None] = None # todo?
38-
# decoder: Union[Mapping[int, Callable[[bytes], bytes]], None] = None
3948
# deserializer: Union[Callable[[bytes], Any], None] = None
4049
# one_attempt_connection_timeout: Union[float, None] = 1
4150
# connection_timeout: Union[float, None] = None

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import concurrent.futures
5+
import gzip
46
import typing
57
from asyncio import Task
68
from collections import deque
@@ -17,14 +19,18 @@
1719
SupportedDriverType,
1820
GrpcWrapperAsyncIO,
1921
)
20-
from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage
22+
from .._grpc.grpcwrapper.ydb_topic import StreamReadMessage, Codec
2123
from .._errors import check_retriable_error
2224

2325

2426
class TopicReaderError(YdbError):
2527
pass
2628

2729

30+
class TopicReaderUnexpectedCodec(YdbError):
31+
pass
32+
33+
2834
class TopicReaderCommitToExpiredPartition(TopicReaderError):
2935
"""
3036
Commit message when partition read session are dropped.
@@ -57,10 +63,10 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
5763
self._reconnector = ReaderReconnector(driver, settings)
5864

5965
async def __aenter__(self):
60-
raise NotImplementedError()
66+
return self
6167

6268
async def __aexit__(self, exc_type, exc_val, exc_tb):
63-
raise NotImplementedError()
69+
await self.close()
6470

6571
def __del__(self):
6672
if not self._closed:
@@ -259,6 +265,7 @@ def _set_first_error(self, err: issues.Error):
259265
class ReaderStream:
260266
_static_id_counter = AtomicCounter()
261267

268+
_loop: asyncio.AbstractEventLoop
262269
_id: int
263270
_reader_reconnector_id: int
264271
_session_id: str
@@ -267,6 +274,15 @@ class ReaderStream:
267274
_background_tasks: Set[asyncio.Task]
268275
_partition_sessions: Dict[int, datatypes.PartitionSession]
269276
_buffer_size_bytes: int # use for init request, then for debug purposes only
277+
_decode_executor: concurrent.futures.Executor
278+
_decoders: Dict[
279+
int, typing.Callable[[bytes], bytes]
280+
] # dict[codec_code] func(encoded_bytes)->decoded_bytes
281+
282+
if typing.TYPE_CHECKING:
283+
_batches_to_decode: asyncio.Queue[datatypes.PublicBatch]
284+
else:
285+
_batches_to_decode: asyncio.Queue
270286

271287
_state_changed: asyncio.Event
272288
_closed: bool
@@ -276,6 +292,7 @@ class ReaderStream:
276292
def __init__(
277293
self, reader_reconnector_id: int, settings: topic_reader.PublicReaderSettings
278294
):
295+
self._loop = asyncio.get_running_loop()
279296
self._id = ReaderStream._static_id_counter.inc_and_get()
280297
self._reader_reconnector_id = reader_reconnector_id
281298
self._session_id = "not initialized"
@@ -284,10 +301,16 @@ def __init__(
284301
self._background_tasks = set()
285302
self._partition_sessions = dict()
286303
self._buffer_size_bytes = settings.buffer_size_bytes
304+
self._decode_executor = settings.decoder_executor
305+
306+
self._decoders = {Codec.CODEC_GZIP: gzip.decompress}
307+
if settings.decoders:
308+
self._decoders.update(settings.decoders)
287309

288310
self._state_changed = asyncio.Event()
289311
self._closed = False
290312
self._first_error = asyncio.get_running_loop().create_future()
313+
self._batches_to_decode = asyncio.Queue()
291314
self._message_batches = deque()
292315

293316
@staticmethod
@@ -324,8 +347,10 @@ async def _start(
324347
"Unexpected message after InitRequest: %s", init_response
325348
)
326349

327-
read_messages_task = asyncio.create_task(self._read_messages_loop(stream))
328-
self._background_tasks.add(read_messages_task)
350+
self._background_tasks.add(
351+
asyncio.create_task(self._read_messages_loop(stream))
352+
)
353+
self._background_tasks.add(asyncio.create_task(self._decode_batches_loop()))
329354

330355
async def wait_error(self):
331356
raise await self._first_error
@@ -486,10 +511,12 @@ def _on_partition_session_stop(
486511
)
487512

488513
def _on_read_response(self, message: StreamReadMessage.ReadResponse):
489-
batches = self._read_response_to_batches(message)
490-
self._message_batches.extend(batches)
491514
self._buffer_consume_bytes(message.bytes_size)
492515

516+
batches = self._read_response_to_batches(message)
517+
for batch in batches:
518+
self._batches_to_decode.put_nowait(batch)
519+
493520
def _on_commit_response(self, message: StreamReadMessage.CommitOffsetResponse):
494521
for partition_offset in message.partitions_committed_offsets:
495522
session = self._partition_sessions.get(
@@ -561,12 +588,44 @@ def _read_response_to_batches(
561588
messages=messages,
562589
_partition_session=partition_session,
563590
_bytes_size=bytes_per_batch,
591+
_codec=Codec(server_batch.codec),
564592
)
565593
batches.append(batch)
566594

567595
batches[-1]._bytes_size += additional_bytes_to_last_batch
568596
return batches
569597

598+
async def _decode_batches_loop(self):
599+
while True:
600+
batch = await self._batches_to_decode.get()
601+
await self._decode_batch_inplace(batch)
602+
self._message_batches.append(batch)
603+
self._state_changed.set()
604+
605+
async def _decode_batch_inplace(self, batch):
606+
if batch._codec == Codec.CODEC_RAW:
607+
return
608+
609+
try:
610+
decode_func = self._decoders[batch._codec]
611+
except KeyError:
612+
raise TopicReaderUnexpectedCodec(
613+
"Receive message with unexpected codec: %s" % batch._codec
614+
)
615+
616+
decode_data_futures = []
617+
for message in batch.messages:
618+
future = self._loop.run_in_executor(
619+
self._decode_executor, decode_func, message.data
620+
)
621+
decode_data_futures.append(future)
622+
623+
decoded_data = await asyncio.gather(*decode_data_futures)
624+
for index, message in enumerate(batch.messages):
625+
message.data = decoded_data[index]
626+
627+
batch._codec = Codec.CODEC_RAW
628+
570629
def _set_first_error(self, err: YdbError):
571630
try:
572631
self._first_error.set_result(err)

0 commit comments

Comments
 (0)