Skip to content

Commit bece3a7

Browse files
authored
Merge pull request #151 Simplify writer
2 parents 889d147 + a4146ff commit bece3a7

File tree

1 file changed

+62
-57
lines changed

1 file changed

+62
-57
lines changed

ydb/_topic_writer/topic_writer_asyncio.py

Lines changed: 62 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,13 @@
4040
class WriterAsyncIO:
4141
_loop: asyncio.AbstractEventLoop
4242
_reconnector: "WriterAsyncIOReconnector"
43-
_lock: asyncio.Lock
4443
_closed: bool
4544

4645
@property
4746
def last_seqno(self) -> int:
4847
raise NotImplementedError()
4948

5049
def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings):
51-
self._lock = asyncio.Lock()
5250
self._loop = asyncio.get_running_loop()
5351
self._closed = False
5452
self._reconnector = WriterAsyncIOReconnector(
@@ -68,10 +66,10 @@ def __del__(self):
6866
self._loop.call_soon(self.close)
6967

7068
async def close(self):
71-
async with self._lock:
72-
if self._closed:
73-
return
74-
self._closed = True
69+
if self._closed:
70+
return
71+
72+
self._closed = True
7573

7674
await self._reconnector.close()
7775

@@ -159,70 +157,92 @@ async def wait_init(self) -> PublicWriterInitInfo:
159157

160158

161159
class WriterAsyncIOReconnector:
160+
_closed: bool
162161
_credentials: Union[ydb.Credentials, None]
163162
_driver: ydb.aio.Driver
164163
_update_token_interval: int
165164
_token_get_function: TokenGetterFuncType
166165
_init_message: StreamWriteMessage.InitRequest
167-
_new_messages: asyncio.Queue
168166
_init_info: asyncio.Future
169167
_stream_connected: asyncio.Event
170168
_settings: WriterSettings
171169

172-
_lock: asyncio.Lock
173170
_last_known_seq_no: int
174171
_messages: Deque[InternalMessage]
175172
_messages_future: Deque[asyncio.Future]
176-
_stop_reason: Optional[Exception]
173+
_new_messages: asyncio.Queue
174+
_stop_reason: asyncio.Future
177175
_background_tasks: List[asyncio.Task]
178176

179177
def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
178+
self._closed = False
180179
self._driver = driver
181180
self._credentials = driver._credentials
182181
self._init_message = settings.create_init_request()
183-
self._new_messages = asyncio.Queue()
184182
self._init_info = asyncio.Future()
185183
self._stream_connected = asyncio.Event()
186184
self._settings = settings
187185

188-
self._lock = asyncio.Lock()
189186
self._last_known_seq_no = 0
190187
self._messages = deque()
191188
self._messages_future = deque()
192-
self._stop_reason = None
189+
self._new_messages = asyncio.Queue()
190+
self._stop_reason = asyncio.Future()
193191
self._background_tasks = [
194192
asyncio.create_task(self._connection_loop(), name="connection_loop")
195193
]
196194

197195
async def close(self):
198-
await self._check_stop()
199-
await self._stop(TopicWriterStopped())
196+
if self._closed:
197+
return
198+
199+
self._closed = True
200+
201+
self._stop(TopicWriterStopped())
202+
203+
background_tasks = self._background_tasks
204+
205+
for task in background_tasks:
206+
task.cancel()
207+
208+
await asyncio.wait(self._background_tasks)
200209

201210
async def wait_init(self) -> PublicWriterInitInfo:
202-
return await self._init_info
211+
done, _ = await asyncio.wait(
212+
[self._init_info, self._stop_reason], return_when=asyncio.FIRST_COMPLETED
213+
)
214+
res = done.pop() # type: asyncio.Future
215+
res_val = res.result()
216+
217+
if isinstance(res_val, BaseException):
218+
raise res_val
219+
220+
return res_val
221+
222+
async def wait_stop(self) -> Exception:
223+
return await self._stop_reason
203224

204225
async def write_with_ack(
205226
self, messages: List[PublicMessage]
206227
) -> List[asyncio.Future]:
207228
# todo check internal buffer limit
208-
await self._check_stop()
229+
self._check_stop()
209230

210231
if self._settings.auto_seqno:
211232
await self.wait_init()
212233

213-
async with self._lock:
214-
internal_messages = self._prepare_internal_messages_locked(messages)
215-
messages_future = [asyncio.Future() for _ in internal_messages]
234+
internal_messages = self._prepare_internal_messages(messages)
235+
messages_future = [asyncio.Future() for _ in internal_messages]
216236

217-
self._messages.extend(internal_messages)
218-
self._messages_future.extend(messages_future)
237+
self._messages.extend(internal_messages)
238+
self._messages_future.extend(messages_future)
219239

220240
for m in internal_messages:
221241
self._new_messages.put_nowait(m)
222242

223243
return messages_future
224244

225-
def _prepare_internal_messages_locked(self, messages: List[PublicMessage]):
245+
def _prepare_internal_messages(self, messages: List[PublicMessage]):
226246
if self._settings.auto_created_at:
227247
now = datetime.datetime.now()
228248
else:
@@ -263,10 +283,9 @@ def _prepare_internal_messages_locked(self, messages: List[PublicMessage]):
263283

264284
return res
265285

266-
async def _check_stop(self):
267-
async with self._lock:
268-
if self._stop_reason is not None:
269-
raise self._stop_reason
286+
def _check_stop(self):
287+
if self._stop_reason.done():
288+
raise self._stop_reason.result()
270289

271290
async def _connection_loop(self):
272291
retry_settings = RetrySettings() # todo
@@ -275,23 +294,16 @@ async def _connection_loop(self):
275294
attempt = 0 # todo calc and reset
276295
pending = []
277296

278-
async def on_stop(e):
279-
for t in pending:
280-
self._background_tasks.append(t)
281-
pending.clear()
282-
await self._stop(e)
283-
284297
# noinspection PyBroadException
285298
try:
286299
stream_writer = await WriterAsyncIOStream.create(
287300
self._driver, self._init_message, self._get_token
288301
)
289302
try:
290-
async with self._lock:
291-
self._last_known_seq_no = stream_writer.last_seqno
292-
self._init_info.set_result(
293-
PublicWriterInitInfo(last_seqno=stream_writer.last_seqno)
294-
)
303+
self._last_known_seq_no = stream_writer.last_seqno
304+
self._init_info.set_result(
305+
PublicWriterInitInfo(last_seqno=stream_writer.last_seqno)
306+
)
295307
except asyncio.InvalidStateError:
296308
pass
297309

@@ -316,13 +328,13 @@ async def on_stop(e):
316328

317329
err_info = check_retriable_error(err, retry_settings, attempt)
318330
if not err_info.is_retriable:
319-
await on_stop(err)
331+
self._stop(err)
320332
return
321333

322334
await asyncio.sleep(err_info.sleep_timeout_seconds)
323335

324-
except Exception as e:
325-
await on_stop(e)
336+
except (asyncio.CancelledError, Exception) as err:
337+
self._stop(err)
326338
return
327339
finally:
328340
if len(pending) > 0:
@@ -333,11 +345,11 @@ async def on_stop(e):
333345
async def _read_loop(self, writer: "WriterAsyncIOStream"):
334346
while True:
335347
resp = await writer.receive()
336-
async with self._lock:
337-
for ack in resp.acks:
338-
self._handle_receive_ack_need_lock(ack)
339348

340-
def _handle_receive_ack_need_lock(self, ack):
349+
for ack in resp.acks:
350+
self._handle_receive_ack(ack)
351+
352+
def _handle_receive_ack(self, ack):
341353
current_message = self._messages.popleft()
342354
message_future = self._messages_future.popleft()
343355
if current_message.seq_no != ack.seq_no:
@@ -351,8 +363,7 @@ def _handle_receive_ack_need_lock(self, ack):
351363

352364
async def _send_loop(self, writer: "WriterAsyncIOStream"):
353365
try:
354-
async with self._lock:
355-
messages = list(self._messages)
366+
messages = list(self._messages)
356367

357368
last_seq_no = 0
358369
for m in messages:
@@ -364,24 +375,18 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"):
364375
if m.seq_no > last_seq_no:
365376
writer.write([m])
366377
except Exception as e:
367-
await self._stop(e)
378+
self._stop(e)
368379
finally:
369380
pass
370381

371-
async def _stop(self, reason: Exception):
382+
def _stop(self, reason: Exception):
372383
if reason is None:
373384
raise Exception("writer stop reason can not be None")
374385

375-
async with self._lock:
376-
if self._stop_reason is not None:
377-
return
378-
self._stop_reason = reason
379-
background_tasks = self._background_tasks
380-
381-
for task in background_tasks:
382-
task.cancel()
386+
if self._stop_reason.done():
387+
return
383388

384-
await asyncio.wait(self._background_tasks)
389+
self._stop_reason.set_result(reason)
385390

386391
def _get_token(self) -> str:
387392
raise NotImplementedError()

0 commit comments

Comments
 (0)