4040class WriterAsyncIO :
4141 _loop : asyncio .AbstractEventLoop
4242 _reconnector : "WriterAsyncIOReconnector"
43- _lock : asyncio .Lock
4443 _closed : bool
4544
4645 @property
4746 def last_seqno (self ) -> int :
4847 raise NotImplementedError ()
4948
5049 def __init__ (self , driver : SupportedDriverType , settings : PublicWriterSettings ):
51- self ._lock = asyncio .Lock ()
5250 self ._loop = asyncio .get_running_loop ()
5351 self ._closed = False
5452 self ._reconnector = WriterAsyncIOReconnector (
@@ -68,10 +66,10 @@ def __del__(self):
6866 self ._loop .call_soon (self .close )
6967
7068 async def close (self ):
71- async with self ._lock :
72- if self . _closed :
73- return
74- self ._closed = True
69+ if self ._closed :
70+ return
71+
72+ self ._closed = True
7573
7674 await self ._reconnector .close ()
7775
@@ -164,65 +162,81 @@ class WriterAsyncIOReconnector:
164162 _update_token_interval : int
165163 _token_get_function : TokenGetterFuncType
166164 _init_message : StreamWriteMessage .InitRequest
167- _new_messages : asyncio .Queue
168165 _init_info : asyncio .Future
169166 _stream_connected : asyncio .Event
170167 _settings : WriterSettings
171168
172- _lock : asyncio .Lock
173169 _last_known_seq_no : int
174170 _messages : Deque [InternalMessage ]
175171 _messages_future : Deque [asyncio .Future ]
176- _stop_reason : Optional [Exception ]
172+ _new_messages : asyncio .Queue
173+ _stop_reason : asyncio .Future
177174 _background_tasks : List [asyncio .Task ]
178175
179176 def __init__ (self , driver : SupportedDriverType , settings : WriterSettings ):
180177 self ._driver = driver
181178 self ._credentials = driver ._credentials
182179 self ._init_message = settings .create_init_request ()
183- self ._new_messages = asyncio .Queue ()
184180 self ._init_info = asyncio .Future ()
185181 self ._stream_connected = asyncio .Event ()
186182 self ._settings = settings
187183
188- self ._lock = asyncio .Lock ()
189184 self ._last_known_seq_no = 0
190185 self ._messages = deque ()
191186 self ._messages_future = deque ()
192- self ._stop_reason = None
187+ self ._new_messages = asyncio .Queue ()
188+ self ._stop_reason = asyncio .Future ()
193189 self ._background_tasks = [
194190 asyncio .create_task (self ._connection_loop (), name = "connection_loop" )
195191 ]
196192
197193 async def close (self ):
198- await self ._check_stop ()
199- await self ._stop (TopicWriterStopped ())
194+ self ._check_stop ()
195+ self ._stop (TopicWriterStopped ())
196+
197+ background_tasks = self ._background_tasks
198+
199+ for task in background_tasks :
200+ task .cancel ()
201+
202+ await asyncio .wait (self ._background_tasks )
200203
201204 async def wait_init (self ) -> PublicWriterInitInfo :
202- return await self ._init_info
205+ done , _ = await asyncio .wait (
206+ [self ._init_info , self ._stop_reason ], return_when = asyncio .FIRST_COMPLETED
207+ )
208+ res = done .pop () # type: asyncio.Future
209+ res_val = res .result ()
210+
211+ if isinstance (res_val , Exception ):
212+ raise res_val
213+
214+ return res_val
215+
216+ async def wait_stop (self ) -> Exception :
217+ return await self ._stop_reason
203218
204219 async def write_with_ack (
205220 self , messages : List [PublicMessage ]
206221 ) -> List [asyncio .Future ]:
207222 # todo check internal buffer limit
208- await self ._check_stop ()
223+ self ._check_stop ()
209224
210225 if self ._settings .auto_seqno :
211226 await self .wait_init ()
212227
213- async with self ._lock :
214- internal_messages = self ._prepare_internal_messages_locked (messages )
215- messages_future = [asyncio .Future () for _ in internal_messages ]
228+ internal_messages = self ._prepare_internal_messages (messages )
229+ messages_future = [asyncio .Future () for _ in internal_messages ]
216230
217- self ._messages .extend (internal_messages )
218- self ._messages_future .extend (messages_future )
231+ self ._messages .extend (internal_messages )
232+ self ._messages_future .extend (messages_future )
219233
220234 for m in internal_messages :
221235 self ._new_messages .put_nowait (m )
222236
223237 return messages_future
224238
225- def _prepare_internal_messages_locked (self , messages : List [PublicMessage ]):
239+ def _prepare_internal_messages (self , messages : List [PublicMessage ]):
226240 if self ._settings .auto_created_at :
227241 now = datetime .datetime .now ()
228242 else :
@@ -263,10 +277,9 @@ def _prepare_internal_messages_locked(self, messages: List[PublicMessage]):
263277
264278 return res
265279
266- async def _check_stop (self ):
267- async with self ._lock :
268- if self ._stop_reason is not None :
269- raise self ._stop_reason
280+ def _check_stop (self ):
281+ if self ._stop_reason .done ():
282+ raise self ._stop_reason .result ()
270283
271284 async def _connection_loop (self ):
272285 retry_settings = RetrySettings () # todo
@@ -275,23 +288,16 @@ async def _connection_loop(self):
275288 attempt = 0 # todo calc and reset
276289 pending = []
277290
278- async def on_stop (e ):
279- for t in pending :
280- self ._background_tasks .append (t )
281- pending .clear ()
282- await self ._stop (e )
283-
284291 # noinspection PyBroadException
285292 try :
286293 stream_writer = await WriterAsyncIOStream .create (
287294 self ._driver , self ._init_message , self ._get_token
288295 )
289296 try :
290- async with self ._lock :
291- self ._last_known_seq_no = stream_writer .last_seqno
292- self ._init_info .set_result (
293- PublicWriterInitInfo (last_seqno = stream_writer .last_seqno )
294- )
297+ self ._last_known_seq_no = stream_writer .last_seqno
298+ self ._init_info .set_result (
299+ PublicWriterInitInfo (last_seqno = stream_writer .last_seqno )
300+ )
295301 except asyncio .InvalidStateError :
296302 pass
297303
@@ -316,13 +322,13 @@ async def on_stop(e):
316322
317323 err_info = check_retriable_error (err , retry_settings , attempt )
318324 if not err_info .is_retriable :
319- await on_stop (err )
325+ self . _stop (err )
320326 return
321327
322328 await asyncio .sleep (err_info .sleep_timeout_seconds )
323329
324- except Exception as e :
325- await on_stop ( e )
330+ except ( asyncio . CancelledError , Exception ) as err :
331+ self . _stop ( err )
326332 return
327333 finally :
328334 if len (pending ) > 0 :
@@ -333,11 +339,11 @@ async def on_stop(e):
333339 async def _read_loop (self , writer : "WriterAsyncIOStream" ):
334340 while True :
335341 resp = await writer .receive ()
336- async with self ._lock :
337- for ack in resp .acks :
338- self ._handle_receive_ack_need_lock (ack )
339342
340- def _handle_receive_ack_need_lock (self , ack ):
343+ for ack in resp .acks :
344+ self ._handle_receive_ack (ack )
345+
346+ def _handle_receive_ack (self , ack ):
341347 current_message = self ._messages .popleft ()
342348 message_future = self ._messages_future .popleft ()
343349 if current_message .seq_no != ack .seq_no :
@@ -351,8 +357,7 @@ def _handle_receive_ack_need_lock(self, ack):
351357
352358 async def _send_loop (self , writer : "WriterAsyncIOStream" ):
353359 try :
354- async with self ._lock :
355- messages = list (self ._messages )
360+ messages = list (self ._messages )
356361
357362 last_seq_no = 0
358363 for m in messages :
@@ -364,24 +369,18 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"):
364369 if m .seq_no > last_seq_no :
365370 writer .write ([m ])
366371 except Exception as e :
367- await self ._stop (e )
372+ self ._stop (e )
368373 finally :
369374 pass
370375
371- async def _stop (self , reason : Exception ):
376+ def _stop (self , reason : Exception ):
372377 if reason is None :
373378 raise Exception ("writer stop reason can not be None" )
374379
375- async with self ._lock :
376- if self ._stop_reason is not None :
377- return
378- self ._stop_reason = reason
379- background_tasks = self ._background_tasks
380-
381- for task in background_tasks :
382- task .cancel ()
380+ if self ._stop_reason .done ():
381+ return
383382
384- await asyncio . wait ( self ._background_tasks )
383+ self ._stop_reason . set_result ( reason )
385384
386385 def _get_token (self ) -> str :
387386 raise NotImplementedError ()
0 commit comments