11from __future__ import annotations
22
33import asyncio
4+ import concurrent .futures
5+ import gzip
46import typing
57from asyncio import Task
68from collections import deque
1719 SupportedDriverType ,
1820 GrpcWrapperAsyncIO ,
1921)
20- from .._grpc .grpcwrapper .ydb_topic import StreamReadMessage
22+ from .._grpc .grpcwrapper .ydb_topic import StreamReadMessage , Codec
2123from .._errors import check_retriable_error
2224
2325
class TopicReaderError(YdbError):
    """Topic reader error; more specific reader errors in this module derive from it."""
2628
2729
class TopicReaderUnexpectedCodec(YdbError):
    """Raised when a received batch uses a codec for which no decoder is registered."""
32+
33+
2834class TopicReaderCommitToExpiredPartition (TopicReaderError ):
2935 """
3036 Commit message when partition read session are dropped.
@@ -57,10 +63,10 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
5763 self ._reconnector = ReaderReconnector (driver , settings )
5864
async def __aenter__(self):
    """Enter ``async with``: the reader itself is the context value."""
    return self
6167
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """
    Leave ``async with``: close the reader.

    Nothing is returned, so an exception raised inside the ``async with`` body
    is not suppressed.
    """
    await self.close()
6470
6571 def __del__ (self ):
6672 if not self ._closed :
@@ -259,6 +265,7 @@ def _set_first_error(self, err: issues.Error):
259265class ReaderStream :
260266 _static_id_counter = AtomicCounter ()
261267
268+ _loop : asyncio .AbstractEventLoop
262269 _id : int
263270 _reader_reconnector_id : int
264271 _session_id : str
@@ -267,6 +274,15 @@ class ReaderStream:
267274 _background_tasks : Set [asyncio .Task ]
268275 _partition_sessions : Dict [int , datatypes .PartitionSession ]
269276 _buffer_size_bytes : int # use for init request, then for debug purposes only
277+ _decode_executor : concurrent .futures .Executor
278+ _decoders : Dict [
279+ int , typing .Callable [[bytes ], bytes ]
280+ ] # dict[codec_code] func(encoded_bytes)->decoded_bytes
281+
282+ if typing .TYPE_CHECKING :
283+ _batches_to_decode : asyncio .Queue [datatypes .PublicBatch ]
284+ else :
285+ _batches_to_decode : asyncio .Queue
270286
271287 _state_changed : asyncio .Event
272288 _closed : bool
@@ -276,6 +292,7 @@ class ReaderStream:
276292 def __init__ (
277293 self , reader_reconnector_id : int , settings : topic_reader .PublicReaderSettings
278294 ):
295+ self ._loop = asyncio .get_running_loop ()
279296 self ._id = ReaderStream ._static_id_counter .inc_and_get ()
280297 self ._reader_reconnector_id = reader_reconnector_id
281298 self ._session_id = "not initialized"
@@ -284,10 +301,16 @@ def __init__(
284301 self ._background_tasks = set ()
285302 self ._partition_sessions = dict ()
286303 self ._buffer_size_bytes = settings .buffer_size_bytes
304+ self ._decode_executor = settings .decoder_executor
305+
306+ self ._decoders = {Codec .CODEC_GZIP : gzip .decompress }
307+ if settings .decoders :
308+ self ._decoders .update (settings .decoders )
287309
288310 self ._state_changed = asyncio .Event ()
289311 self ._closed = False
290312 self ._first_error = asyncio .get_running_loop ().create_future ()
313+ self ._batches_to_decode = asyncio .Queue ()
291314 self ._message_batches = deque ()
292315
293316 @staticmethod
@@ -324,8 +347,10 @@ async def _start(
324347 "Unexpected message after InitRequest: %s" , init_response
325348 )
326349
327- read_messages_task = asyncio .create_task (self ._read_messages_loop (stream ))
328- self ._background_tasks .add (read_messages_task )
350+ self ._background_tasks .add (
351+ asyncio .create_task (self ._read_messages_loop (stream ))
352+ )
353+ self ._background_tasks .add (asyncio .create_task (self ._decode_batches_loop ()))
329354
async def wait_error(self):
    """
    Wait for the result of ``self._first_error`` and raise it.

    This coroutine never returns normally: it either keeps waiting or raises.
    """
    first_error = await self._first_error
    raise first_error
@@ -486,10 +511,12 @@ def _on_partition_session_stop(
486511 )
487512
def _on_read_response(self, message: StreamReadMessage.ReadResponse):
    """
    Handle a read response from the server.

    The response size is accounted for in the buffer first, then the response is
    split into batches which are queued in ``self._batches_to_decode``; they are
    decoded later by the background decode loop.
    """
    self._buffer_consume_bytes(message.bytes_size)

    enqueue = self._batches_to_decode.put_nowait
    for pending_batch in self._read_response_to_batches(message):
        enqueue(pending_batch)
519+
493520 def _on_commit_response (self , message : StreamReadMessage .CommitOffsetResponse ):
494521 for partition_offset in message .partitions_committed_offsets :
495522 session = self ._partition_sessions .get (
@@ -561,12 +588,44 @@ def _read_response_to_batches(
561588 messages = messages ,
562589 _partition_session = partition_session ,
563590 _bytes_size = bytes_per_batch ,
591+ _codec = Codec (server_batch .codec ),
564592 )
565593 batches .append (batch )
566594
567595 batches [- 1 ]._bytes_size += additional_bytes_to_last_batch
568596 return batches
569597
async def _decode_batches_loop(self):
    """
    Background loop: decode queued batches and publish them to the reader.

    Takes batches from ``self._batches_to_decode`` one at a time, decodes each in
    place, appends it to ``self._message_batches`` and sets ``self._state_changed``.

    A decoding failure (unknown codec, broken payload, ...) is passed to
    ``self._set_first_error`` before being re-raised, so that it is recorded as the
    stream error instead of only terminating this background task silently.
    """
    try:
        while True:
            batch = await self._batches_to_decode.get()
            await self._decode_batch_inplace(batch)
            self._message_batches.append(batch)
            self._state_changed.set()
    except asyncio.CancelledError:
        # Cancellation is a normal shutdown path, not a stream error.
        # Handled explicitly because CancelledError derives from Exception
        # on Python 3.7.
        raise
    except Exception as err:
        # NOTE(review): _set_first_error is annotated with YdbError, while a
        # decoder may raise any exception type - confirm that is acceptable.
        self._set_first_error(err)
        raise
604+
async def _decode_batch_inplace(self, batch):
    """
    Decode the data of every message of ``batch`` in place.

    A batch already marked ``Codec.CODEC_RAW`` is left untouched. Otherwise the
    decoder registered in ``self._decoders`` for the batch codec is run for each
    message through ``run_in_executor`` on ``self._decode_executor``; once all
    messages are decoded their ``data`` is replaced and the batch is marked
    ``Codec.CODEC_RAW``.

    Raises TopicReaderUnexpectedCodec when no decoder is registered for the codec.
    """
    if batch._codec == Codec.CODEC_RAW:
        return

    try:
        decode_func = self._decoders[batch._codec]
    except KeyError:
        # The message already names the codec; the KeyError context adds nothing.
        raise TopicReaderUnexpectedCodec(
            "Receive message with unexpected codec: %s" % batch._codec
        ) from None

    # Schedule every message before waiting, so they can be decoded concurrently.
    decode_data_futures = [
        self._loop.run_in_executor(self._decode_executor, decode_func, message.data)
        for message in batch.messages
    ]

    # gather() keeps the order of its arguments, so results line up with messages.
    decoded_data = await asyncio.gather(*decode_data_futures)
    for message, data in zip(batch.messages, decoded_data):
        message.data = data

    batch._codec = Codec.CODEC_RAW
628+
570629 def _set_first_error (self , err : YdbError ):
571630 try :
572631 self ._first_error .set_result (err )
0 commit comments