4040class WriterAsyncIO :
4141 _loop : asyncio .AbstractEventLoop
4242 _reconnector : "WriterAsyncIOReconnector"
43- _lock : asyncio .Lock
4443 _closed : bool
4544
4645 @property
4746 def last_seqno (self ) -> int :
4847 raise NotImplementedError ()
4948
5049 def __init__ (self , driver : SupportedDriverType , settings : PublicWriterSettings ):
51- self ._lock = asyncio .Lock ()
5250 self ._loop = asyncio .get_running_loop ()
5351 self ._closed = False
5452 self ._reconnector = WriterAsyncIOReconnector (
@@ -68,10 +66,10 @@ def __del__(self):
6866 self ._loop .call_soon (self .close )
6967
7068 async def close (self ):
71- async with self ._lock :
72- if self . _closed :
73- return
74- self ._closed = True
69+ if self ._closed :
70+ return
71+
72+ self ._closed = True
7573
7674 await self ._reconnector .close ()
7775
@@ -159,70 +157,92 @@ async def wait_init(self) -> PublicWriterInitInfo:
159157
160158
161159class WriterAsyncIOReconnector :
160+ _closed : bool
162161 _credentials : Union [ydb .Credentials , None ]
163162 _driver : ydb .aio .Driver
164163 _update_token_interval : int
165164 _token_get_function : TokenGetterFuncType
166165 _init_message : StreamWriteMessage .InitRequest
167- _new_messages : asyncio .Queue
168166 _init_info : asyncio .Future
169167 _stream_connected : asyncio .Event
170168 _settings : WriterSettings
171169
172- _lock : asyncio .Lock
173170 _last_known_seq_no : int
174171 _messages : Deque [InternalMessage ]
175172 _messages_future : Deque [asyncio .Future ]
176- _stop_reason : Optional [Exception ]
173+ _new_messages : asyncio .Queue
174+ _stop_reason : asyncio .Future
177175 _background_tasks : List [asyncio .Task ]
178176
179177 def __init__ (self , driver : SupportedDriverType , settings : WriterSettings ):
178+ self ._closed = False
180179 self ._driver = driver
181180 self ._credentials = driver ._credentials
182181 self ._init_message = settings .create_init_request ()
183- self ._new_messages = asyncio .Queue ()
184182 self ._init_info = asyncio .Future ()
185183 self ._stream_connected = asyncio .Event ()
186184 self ._settings = settings
187185
188- self ._lock = asyncio .Lock ()
189186 self ._last_known_seq_no = 0
190187 self ._messages = deque ()
191188 self ._messages_future = deque ()
192- self ._stop_reason = None
189+ self ._new_messages = asyncio .Queue ()
190+ self ._stop_reason = asyncio .Future ()
193191 self ._background_tasks = [
194192 asyncio .create_task (self ._connection_loop (), name = "connection_loop" )
195193 ]
196194
197195 async def close (self ):
198- await self ._check_stop ()
199- await self ._stop (TopicWriterStopped ())
196+ if self ._closed :
197+ return
198+
199+ self ._closed = True
200+
201+ self ._stop (TopicWriterStopped ())
202+
203+ background_tasks = self ._background_tasks
204+
205+ for task in background_tasks :
206+ task .cancel ()
207+
208+ await asyncio .wait (self ._background_tasks )
200209
201210 async def wait_init (self ) -> PublicWriterInitInfo :
202- return await self ._init_info
211+ done , _ = await asyncio .wait (
212+ [self ._init_info , self ._stop_reason ], return_when = asyncio .FIRST_COMPLETED
213+ )
214+ res = done .pop () # type: asyncio.Future
215+ res_val = res .result ()
216+
217+ if isinstance (res_val , BaseException ):
218+ raise res_val
219+
220+ return res_val
221+
222+ async def wait_stop (self ) -> Exception :
223+ return await self ._stop_reason
203224
204225 async def write_with_ack (
205226 self , messages : List [PublicMessage ]
206227 ) -> List [asyncio .Future ]:
207228 # todo check internal buffer limit
208- await self ._check_stop ()
229+ self ._check_stop ()
209230
210231 if self ._settings .auto_seqno :
211232 await self .wait_init ()
212233
213- async with self ._lock :
214- internal_messages = self ._prepare_internal_messages_locked (messages )
215- messages_future = [asyncio .Future () for _ in internal_messages ]
234+ internal_messages = self ._prepare_internal_messages (messages )
235+ messages_future = [asyncio .Future () for _ in internal_messages ]
216236
217- self ._messages .extend (internal_messages )
218- self ._messages_future .extend (messages_future )
237+ self ._messages .extend (internal_messages )
238+ self ._messages_future .extend (messages_future )
219239
220240 for m in internal_messages :
221241 self ._new_messages .put_nowait (m )
222242
223243 return messages_future
224244
225- def _prepare_internal_messages_locked (self , messages : List [PublicMessage ]):
245+ def _prepare_internal_messages (self , messages : List [PublicMessage ]):
226246 if self ._settings .auto_created_at :
227247 now = datetime .datetime .now ()
228248 else :
@@ -263,10 +283,9 @@ def _prepare_internal_messages_locked(self, messages: List[PublicMessage]):
263283
264284 return res
265285
266- async def _check_stop (self ):
267- async with self ._lock :
268- if self ._stop_reason is not None :
269- raise self ._stop_reason
286+ def _check_stop (self ):
287+ if self ._stop_reason .done ():
288+ raise self ._stop_reason .result ()
270289
271290 async def _connection_loop (self ):
272291 retry_settings = RetrySettings () # todo
@@ -275,23 +294,16 @@ async def _connection_loop(self):
275294 attempt = 0 # todo calc and reset
276295 pending = []
277296
278- async def on_stop (e ):
279- for t in pending :
280- self ._background_tasks .append (t )
281- pending .clear ()
282- await self ._stop (e )
283-
284297 # noinspection PyBroadException
285298 try :
286299 stream_writer = await WriterAsyncIOStream .create (
287300 self ._driver , self ._init_message , self ._get_token
288301 )
289302 try :
290- async with self ._lock :
291- self ._last_known_seq_no = stream_writer .last_seqno
292- self ._init_info .set_result (
293- PublicWriterInitInfo (last_seqno = stream_writer .last_seqno )
294- )
303+ self ._last_known_seq_no = stream_writer .last_seqno
304+ self ._init_info .set_result (
305+ PublicWriterInitInfo (last_seqno = stream_writer .last_seqno )
306+ )
295307 except asyncio .InvalidStateError :
296308 pass
297309
@@ -316,13 +328,13 @@ async def on_stop(e):
316328
317329 err_info = check_retriable_error (err , retry_settings , attempt )
318330 if not err_info .is_retriable :
319- await on_stop (err )
331+ self . _stop (err )
320332 return
321333
322334 await asyncio .sleep (err_info .sleep_timeout_seconds )
323335
324- except Exception as e :
325- await on_stop ( e )
336+ except ( asyncio . CancelledError , Exception ) as err :
337+ self . _stop ( err )
326338 return
327339 finally :
328340 if len (pending ) > 0 :
@@ -333,11 +345,11 @@ async def on_stop(e):
333345 async def _read_loop (self , writer : "WriterAsyncIOStream" ):
334346 while True :
335347 resp = await writer .receive ()
336- async with self ._lock :
337- for ack in resp .acks :
338- self ._handle_receive_ack_need_lock (ack )
339348
340- def _handle_receive_ack_need_lock (self , ack ):
349+ for ack in resp .acks :
350+ self ._handle_receive_ack (ack )
351+
352+ def _handle_receive_ack (self , ack ):
341353 current_message = self ._messages .popleft ()
342354 message_future = self ._messages_future .popleft ()
343355 if current_message .seq_no != ack .seq_no :
@@ -351,8 +363,7 @@ def _handle_receive_ack_need_lock(self, ack):
351363
352364 async def _send_loop (self , writer : "WriterAsyncIOStream" ):
353365 try :
354- async with self ._lock :
355- messages = list (self ._messages )
366+ messages = list (self ._messages )
356367
357368 last_seq_no = 0
358369 for m in messages :
@@ -364,24 +375,18 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"):
364375 if m .seq_no > last_seq_no :
365376 writer .write ([m ])
366377 except Exception as e :
367- await self ._stop (e )
378+ self ._stop (e )
368379 finally :
369380 pass
370381
371- async def _stop (self , reason : Exception ):
382+ def _stop (self , reason : Exception ):
372383 if reason is None :
373384 raise Exception ("writer stop reason can not be None" )
374385
375- async with self ._lock :
376- if self ._stop_reason is not None :
377- return
378- self ._stop_reason = reason
379- background_tasks = self ._background_tasks
380-
381- for task in background_tasks :
382- task .cancel ()
386+ if self ._stop_reason .done ():
387+ return
383388
384- await asyncio . wait ( self ._background_tasks )
389+ self ._stop_reason . set_result ( reason )
385390
386391 def _get_token (self ) -> str :
387392 raise NotImplementedError ()
0 commit comments