@@ -79,10 +79,10 @@ def __aiter__(self):
7979 return self
8080
8181 async def __anext__ (self ):
82- try :
83- return await self ._queue .get ()
84- except asyncio .QueueEmpty :
82+ item = await self ._queue .get ()
83+ if item is None :
8584 raise StopAsyncIteration ()
85+ return item
8686
8787
8888class AsyncQueueToSyncIteratorAsyncIO :
@@ -100,13 +100,10 @@ def __iter__(self):
100100 return self
101101
102102 def __next__ (self ):
103- try :
104- res = asyncio .run_coroutine_threadsafe (
105- self ._queue .get (), self ._loop
106- ).result ()
107- return res
108- except asyncio .QueueEmpty :
103+ item = asyncio .run_coroutine_threadsafe (self ._queue .get (), self ._loop ).result ()
104+ if item is None :
109105 raise StopIteration ()
106+ return item
110107
111108
112109class SyncIteratorToAsyncIterator :
@@ -133,6 +130,10 @@ async def receive(self) -> Any:
133130 def write (self , wrap_message : IToProto ):
134131 ...
135132
133+ @abc .abstractmethod
134+ def close (self ):
135+ ...
136+
136137
137138SupportedDriverType = Union [ydb .Driver , ydb .aio .Driver ]
138139
@@ -142,11 +143,15 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
142143 from_server_grpc : AsyncIterator
143144 convert_server_grpc_to_wrapper : Callable [[Any ], Any ]
144145 _connection_state : str
146+ _stream_call : Optional [
147+ Union [grpc .aio .StreamStreamCall , "grpc._channel._MultiThreadedRendezvous" ]
148+ ]
145149
146150 def __init__ (self , convert_server_grpc_to_wrapper ):
147151 self .from_client_grpc = asyncio .Queue ()
148152 self .convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper
149153 self ._connection_state = "new"
154+ self ._stream_call = None
150155
151156 async def start (self , driver : SupportedDriverType , stub , method ):
152157 if asyncio .iscoroutinefunction (driver .__call__ ):
@@ -155,13 +160,18 @@ async def start(self, driver: SupportedDriverType, stub, method):
155160 await self ._start_sync_driver (driver , stub , method )
156161 self ._connection_state = "started"
157162
163+ def close (self ):
164+ self .from_client_grpc .put_nowait (None )
165+ self ._stream_call .cancel ()
166+
158167 async def _start_asyncio_driver (self , driver : ydb .aio .Driver , stub , method ):
159168 requests_iterator = QueueToIteratorAsyncIO (self .from_client_grpc )
160169 stream_call = await driver (
161170 requests_iterator ,
162171 stub ,
163172 method ,
164173 )
174+ self ._stream_call = stream_call
165175 self .from_server_grpc = stream_call .__aiter__ ()
166176
167177 async def _start_sync_driver (self , driver : ydb .Driver , stub , method ):
@@ -172,6 +182,7 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
172182 stub ,
173183 method ,
174184 )
185+ self ._stream_call = stream_call
175186 self .from_server_grpc = SyncIteratorToAsyncIterator (stream_call .__iter__ ())
176187
177188 async def receive (self ) -> Any :
0 commit comments