Skip to content

Commit 9da0026

Browse files
committed
add stop marker
fix close path
1 parent 190dd9f commit 9da0026

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

tests/topics/test_topic_reader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ async def test_read_message(
99
reader = driver.topic_client.topic_reader(topic_consumer, topic_path)
1010

1111
assert await reader.receive_batch() is not None
12+
await reader.close()

ydb/_grpc/grpcwrapper/common_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ class UnknownGrpcMessageError(issues.Error):
6969
pass
7070

7171

72+
_stop_grpc_connection_marker = object()
73+
74+
7275
class QueueToIteratorAsyncIO:
7376
__slots__ = ("_queue",)
7477

@@ -80,7 +83,7 @@ def __aiter__(self):
8083

8184
async def __anext__(self):
8285
item = await self._queue.get()
83-
if item is None:
86+
if item is _stop_grpc_connection_marker:
8487
raise StopAsyncIteration()
8588
return item
8689

@@ -101,7 +104,7 @@ def __iter__(self):
101104

102105
def __next__(self):
103106
item = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result()
104-
if item is None:
107+
if item is _stop_grpc_connection_marker:
105108
raise StopIteration()
106109
return item
107110

@@ -161,8 +164,9 @@ async def start(self, driver: SupportedDriverType, stub, method):
161164
self._connection_state = "started"
162165

163166
def close(self):
164-
self.from_client_grpc.put_nowait(None)
165-
self._stream_call.cancel()
167+
self.from_client_grpc.put_nowait(_stop_grpc_connection_marker)
168+
if self._stream_call:
169+
self._stream_call.cancel()
166170

167171
async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method):
168172
requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc)

ydb/_topic_writer/topic_writer_asyncio.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ async def _connection_loop(self):
296296
pending = []
297297

298298
# noinspection PyBroadException
299+
stream_writer = None
299300
try:
300301
stream_writer = await WriterAsyncIOStream.create(
301302
self._driver, self._init_message, self._get_token
@@ -339,6 +340,8 @@ async def _connection_loop(self):
339340
self._stop(err)
340341
return
341342
finally:
343+
if stream_writer:
344+
stream_writer.close()
342345
if len(pending) > 0:
343346
for task in pending:
344347
task.cancel()

0 commit comments

Comments
 (0)