Skip to content

Commit 190dd9f

Browse files
committed
close grpc streams when stream reader/writer closed
1 parent 6307e59 commit 190dd9f

File tree

6 files changed

+59
-9
lines changed

6 files changed

+59
-9
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Close grpc streams while closing readers/writers
12
* Add control plane operations for topic api: create, drop
23

34
## 3.0.1b4 ##

ydb/_grpc/grpcwrapper/common_utils.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,10 @@ def __aiter__(self):
7979
return self
8080

8181
async def __anext__(self):
82-
try:
83-
return await self._queue.get()
84-
except asyncio.QueueEmpty:
82+
item = await self._queue.get()
83+
if item is None:
8584
raise StopAsyncIteration()
85+
return item
8686

8787

8888
class AsyncQueueToSyncIteratorAsyncIO:
@@ -100,13 +100,10 @@ def __iter__(self):
100100
return self
101101

102102
def __next__(self):
103-
try:
104-
res = asyncio.run_coroutine_threadsafe(
105-
self._queue.get(), self._loop
106-
).result()
107-
return res
108-
except asyncio.QueueEmpty:
103+
item = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop).result()
104+
if item is None:
109105
raise StopIteration()
106+
return item
110107

111108

112109
class SyncIteratorToAsyncIterator:
@@ -133,6 +130,10 @@ async def receive(self) -> Any:
133130
def write(self, wrap_message: IToProto):
134131
...
135132

133+
@abc.abstractmethod
134+
def close(self):
135+
...
136+
136137

137138
SupportedDriverType = Union[ydb.Driver, ydb.aio.Driver]
138139

@@ -142,11 +143,15 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
142143
from_server_grpc: AsyncIterator
143144
convert_server_grpc_to_wrapper: Callable[[Any], Any]
144145
_connection_state: str
146+
_stream_call: Optional[
147+
Union[grpc.aio.StreamStreamCall, "grpc._channel._MultiThreadedRendezvous"]
148+
]
145149

146150
def __init__(self, convert_server_grpc_to_wrapper):
147151
self.from_client_grpc = asyncio.Queue()
148152
self.convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper
149153
self._connection_state = "new"
154+
self._stream_call = None
150155

151156
async def start(self, driver: SupportedDriverType, stub, method):
152157
if asyncio.iscoroutinefunction(driver.__call__):
@@ -155,13 +160,18 @@ async def start(self, driver: SupportedDriverType, stub, method):
155160
await self._start_sync_driver(driver, stub, method)
156161
self._connection_state = "started"
157162

163+
def close(self):
164+
self.from_client_grpc.put_nowait(None)
165+
self._stream_call.cancel()
166+
158167
async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method):
159168
requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc)
160169
stream_call = await driver(
161170
requests_iterator,
162171
stub,
163172
method,
164173
)
174+
self._stream_call = stream_call
165175
self.from_server_grpc = stream_call.__aiter__()
166176

167177
async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
@@ -172,6 +182,7 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
172182
stub,
173183
method,
174184
)
185+
self._stream_call = stream_call
175186
self.from_server_grpc = SyncIteratorToAsyncIterator(stream_call.__iter__())
176187

177188
async def receive(self) -> Any:

ydb/_topic_common/test_helpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,36 @@
88
class StreamMock(IGrpcWrapperAsyncIO):
99
from_server: asyncio.Queue
1010
from_client: asyncio.Queue
11+
_closed: bool
1112

1213
def __init__(self):
1314
self.from_server = asyncio.Queue()
1415
self.from_client = asyncio.Queue()
16+
self._closed = False
1517

1618
async def receive(self) -> typing.Any:
19+
if self._closed:
20+
raise Exception("read from closed StreamMock")
21+
1722
item = await self.from_server.get()
23+
if item is None:
24+
raise StopAsyncIteration()
1825
if isinstance(item, Exception):
1926
raise item
2027
return item
2128

2229
def write(self, wrap_message: IToProto):
30+
if self._closed:
31+
raise Exception("write to closed StreamMock")
2332
self.from_client.put_nowait(wrap_message)
2433

34+
def close(self):
35+
if self._closed:
36+
return
37+
38+
self._closed = True
39+
self.from_server.put_nowait(None)
40+
2541

2642
async def wait_condition(f: typing.Callable[[], bool], timeout=1):
2743
start = time.monotonic()

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ async def close(self):
496496
self._closed = True
497497
self._set_first_error(TopicReaderStreamClosedError())
498498
self._state_changed.set()
499+
self._stream.close()
499500

500501
for task in self._background_tasks:
501502
task.cancel()

ydb/_topic_writer/topic_writer_asyncio.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ async def _connection_loop(self):
322322
done, pending = await asyncio.wait(
323323
[send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED
324324
)
325+
stream_writer.close()
325326
done.pop().result()
326327
except issues.Error as err:
327328
# todo log error
@@ -417,6 +418,9 @@ def __init__(
417418
):
418419
self._token_getter = token_getter
419420

421+
def close(self):
422+
self._stream.close()
423+
420424
@staticmethod
421425
async def create(
422426
driver: SupportedDriverType,

ydb/_topic_writer/topic_writer_asyncio_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,20 +158,37 @@ class StreamWriterMock:
158158
from_client: asyncio.Queue
159159
from_server: asyncio.Queue
160160

161+
_closed: bool
162+
161163
def __init__(self):
162164
self.last_seqno = 0
163165
self.from_server = asyncio.Queue()
164166
self.from_client = asyncio.Queue()
167+
self._closed = False
165168

166169
def write(self, messages: typing.List[InternalMessage]):
170+
if self._closed:
171+
raise Exception("write to closed StreamWriterMock")
172+
167173
self.from_client.put_nowait(messages)
168174

169175
async def receive(self) -> StreamWriteMessage.WriteResponse:
176+
if self._closed:
177+
raise Exception("read from closed StreamWriterMock")
178+
170179
item = await self.from_server.get()
171180
if isinstance(item, Exception):
172181
raise item
173182
return item
174183

184+
def close(self):
185+
if self._closed:
186+
return
187+
188+
self.from_server.put_nowait(
189+
Exception("waited message while StreamWriterMock closed")
190+
)
191+
175192
@pytest.fixture(autouse=True)
176193
async def stream_writer_double_queue(self, monkeypatch):
177194
class DoubleQueueWriters:

0 commit comments

Comments
 (0)