@@ -69,6 +69,9 @@ class UnknownGrpcMessageError(issues.Error):
6969 pass
7070
7171
72+ _stop_grpc_connection_marker = object ()
73+
74+
7275class QueueToIteratorAsyncIO :
7376 __slots__ = ("_queue" ,)
7477
@@ -79,10 +82,10 @@ def __aiter__(self):
7982 return self
8083
8184 async def __anext__ (self ):
82- try :
83- return await self ._queue .get ()
84- except asyncio .QueueEmpty :
85+ item = await self ._queue .get ()
86+ if item is _stop_grpc_connection_marker :
8587 raise StopAsyncIteration ()
88+ return item
8689
8790
8891class AsyncQueueToSyncIteratorAsyncIO :
@@ -100,13 +103,10 @@ def __iter__(self):
100103 return self
101104
102105 def __next__ (self ):
103- try :
104- res = asyncio .run_coroutine_threadsafe (
105- self ._queue .get (), self ._loop
106- ).result ()
107- return res
108- except asyncio .QueueEmpty :
106+ item = asyncio .run_coroutine_threadsafe (self ._queue .get (), self ._loop ).result ()
107+ if item is _stop_grpc_connection_marker :
109108 raise StopIteration ()
109+ return item
110110
111111
112112class SyncIteratorToAsyncIterator :
@@ -133,6 +133,10 @@ async def receive(self) -> Any:
133133 def write (self , wrap_message : IToProto ):
134134 ...
135135
136+ @abc .abstractmethod
137+ def close (self ):
138+ ...
139+
136140
137141SupportedDriverType = Union [ydb .Driver , ydb .aio .Driver ]
138142
@@ -142,11 +146,15 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
142146 from_server_grpc : AsyncIterator
143147 convert_server_grpc_to_wrapper : Callable [[Any ], Any ]
144148 _connection_state : str
149+ _stream_call : Optional [
150+ Union [grpc .aio .StreamStreamCall , "grpc._channel._MultiThreadedRendezvous" ]
151+ ]
145152
146153 def __init__ (self , convert_server_grpc_to_wrapper ):
147154 self .from_client_grpc = asyncio .Queue ()
148155 self .convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper
149156 self ._connection_state = "new"
157+ self ._stream_call = None
150158
151159 async def start (self , driver : SupportedDriverType , stub , method ):
152160 if asyncio .iscoroutinefunction (driver .__call__ ):
@@ -155,13 +163,19 @@ async def start(self, driver: SupportedDriverType, stub, method):
155163 await self ._start_sync_driver (driver , stub , method )
156164 self ._connection_state = "started"
157165
166+ def close (self ):
167+ self .from_client_grpc .put_nowait (_stop_grpc_connection_marker )
168+ if self ._stream_call :
169+ self ._stream_call .cancel ()
170+
158171 async def _start_asyncio_driver (self , driver : ydb .aio .Driver , stub , method ):
159172 requests_iterator = QueueToIteratorAsyncIO (self .from_client_grpc )
160173 stream_call = await driver (
161174 requests_iterator ,
162175 stub ,
163176 method ,
164177 )
178+ self ._stream_call = stream_call
165179 self .from_server_grpc = stream_call .__aiter__ ()
166180
167181 async def _start_sync_driver (self , driver : ydb .Driver , stub , method ):
@@ -172,6 +186,7 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
172186 stub ,
173187 method ,
174188 )
189+ self ._stream_call = stream_call
175190 self .from_server_grpc = SyncIteratorToAsyncIterator (stream_call .__iter__ ())
176191
177192 async def receive (self ) -> Any :
0 commit comments