Skip to content

Commit c8db459

Browse files
authored
Merge pull request #164 from ydb-platform/topic-writer-flush
topic writer add flush
2 parents e683c57 + 1df4618 commit c8db459

File tree

5 files changed

+73
-44
lines changed

5 files changed

+73
-44
lines changed

examples/topic/writer_async_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ async def send_message_without_block_if_internal_buffer_is_full(
6565
return False
6666

6767

68-
def send_messages_with_manual_seqno(writer: ydb.TopicWriter):
68+
async def send_messages_with_manual_seqno(writer: ydb.TopicWriter):
6969
await writer.write(ydb.TopicWriterMessage("mess")) # send text
7070

7171

tests/topics/test_topic_writer.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ async def test_send_message(self, driver: ydb.aio.Driver, topic_path):
99
writer = driver.topic_client.topic_writer(
1010
topic_path, producer_and_message_group_id="test"
1111
)
12-
writer.write(ydb.TopicWriterMessage(data="123".encode()))
12+
await writer.write(ydb.TopicWriterMessage(data="123".encode()))
13+
await writer.close()
1314

1415
async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path):
1516
async with driver.topic_client.topic_writer(
@@ -28,3 +29,24 @@ async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path):
2829
) as writer2:
2930
init_info = await writer2.wait_init()
3031
assert init_info.last_seqno == 5
32+
33+
async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path):
34+
async with driver.topic_client.topic_writer(
35+
topic_path,
36+
producer_and_message_group_id="test",
37+
auto_seqno=False,
38+
) as writer:
39+
last_seqno = 0
40+
for i in range(10):
41+
last_seqno = i + 1
42+
await writer.write(
43+
ydb.TopicWriterMessage(data=f"msg-{i}", seqno=last_seqno)
44+
)
45+
46+
async with driver.topic_client.topic_writer(
47+
topic_path,
48+
producer_and_message_group_id="test",
49+
get_last_seqno=True,
50+
) as writer:
51+
init_info = await writer.wait_init()
52+
assert init_info.last_seqno == last_seqno

ydb/_topic_writer/topic_writer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def async_flush(self):
104104
"""
105105
raise NotImplementedError()
106106

107-
def flush(self, timeout: Union[float, None] = None) -> concurrent.futures.Future:
107+
def flush(self, timeout: Optional[float] = None) -> concurrent.futures.Future:
108108
"""
109109
Force send all messages from internal buffer and wait acks from server for all
110110
messages.
@@ -122,7 +122,7 @@ def async_wait_init(self) -> concurrent.futures.Future:
122122
"""
123123
raise NotImplementedError()
124124

125-
def wait_init(self, timeout: Union[float, None] = None):
125+
def wait_init(self, timeout: Optional[float] = None):
126126
"""
127127
Wait until underling connection established
128128
@@ -141,15 +141,15 @@ class PublicWriterSettings:
141141
session_metadata: Optional[Dict[str, str]] = None
142142
encoders: Union[Mapping[int, Callable[[bytes], bytes]], None] = None
143143
serializer: Union[Callable[[Any], bytes], None] = None
144-
send_buffer_count: Union[int, None] = 10000
145-
send_buffer_bytes: Union[int, None] = 100 * 1024 * 1024
144+
send_buffer_count: Optional[int] = 10000
145+
send_buffer_bytes: Optional[int] = 100 * 1024 * 1024
146146
partition_id: Optional[int] = None
147-
codec: Union[int, None] = None
147+
codec: Optional[int] = None
148148
codec_autoselect: bool = True
149149
auto_seqno: bool = True
150150
auto_created_at: bool = True
151151
get_last_seqno: bool = False
152-
retry_policy: Union["RetryPolicy", None] = None
152+
retry_policy: Optional["RetryPolicy"] = None
153153
update_token_interval: Union[int, float] = 3600
154154

155155

@@ -251,7 +251,7 @@ def to_message_data(self) -> StreamWriteMessage.WriteRequest.MessageData:
251251

252252

253253
class MessageSendResult:
254-
offset: Union[None, int]
254+
offset: Optional[int]
255255
write_status: "MessageWriteStatus"
256256

257257

ydb/_topic_writer/topic_writer_asyncio.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -86,21 +86,14 @@ async def write_with_ack(
8686
8787
For wait with timeout use asyncio.wait_for.
8888
"""
89-
if isinstance(messages, PublicMessage):
90-
futures = await self._reconnector.write_with_ack([messages])
91-
return await futures[0]
92-
if isinstance(messages, list):
93-
for m in messages:
94-
if not isinstance(m, PublicMessage):
95-
raise NotImplementedError()
89+
futures = await self.write_with_ack_future(messages, *args)
90+
if not isinstance(futures, list):
91+
futures = [futures]
9692

97-
futures = await self._reconnector.write_with_ack(messages)
98-
await asyncio.wait(futures)
93+
await asyncio.wait(futures)
94+
results = [f.result() for f in futures]
9995

100-
results = [f.result() for f in futures]
101-
return results
102-
103-
raise NotImplementedError()
96+
return results if isinstance(messages, list) else results[0]
10497

10598
async def write_with_ack_future(
10699
self,
@@ -145,7 +138,7 @@ async def flush(self):
145138
146139
For wait with timeout use asyncio.wait_for.
147140
"""
148-
raise NotImplementedError()
141+
return await self._reconnector.flush()
149142

150143
async def wait_init(self) -> PublicWriterInitInfo:
151144
"""
@@ -192,10 +185,13 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
192185
asyncio.create_task(self._connection_loop(), name="connection_loop")
193186
]
194187

195-
async def close(self):
188+
async def close(self, flush: bool = True):
196189
if self._closed:
197190
return
198191

192+
if flush:
193+
await self.flush()
194+
199195
self._closed = True
200196
self._stop(TopicWriterStopped())
201197

@@ -396,6 +392,14 @@ def _stop(self, reason: Exception):
396392
def _get_token(self) -> str:
397393
raise NotImplementedError()
398394

395+
async def flush(self):
396+
self._check_stop()
397+
if not self._messages_future:
398+
return
399+
400+
# wait last message
401+
await asyncio.wait((self._messages_future[-1],))
402+
399403

400404
class WriterAsyncIOStream:
401405
# todo slots

ydb/_topic_writer/topic_writer_asyncio_test.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,20 @@ def default_write_statistic(
239239
topic_quota_wait_time=datetime.timedelta(milliseconds=5),
240240
)
241241

242+
def make_default_ack_message(self, seq_no=1) -> StreamWriteMessage.WriteResponse:
243+
return StreamWriteMessage.WriteResponse(
244+
partition_id=1,
245+
acks=[
246+
StreamWriteMessage.WriteResponse.WriteAck(
247+
seq_no=seq_no,
248+
message_write_status=StreamWriteMessage.WriteResponse.WriteAck.StatusWritten(
249+
offset=1
250+
),
251+
)
252+
],
253+
write_statistics=self.default_write_statistic,
254+
)
255+
242256
@pytest.fixture
243257
async def reconnector(
244258
self, default_driver, default_settings
@@ -275,20 +289,7 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error(
275289
assert [InternalMessage(message2)] == messages
276290

277291
# ack first message
278-
stream_writer.from_server.put_nowait(
279-
StreamWriteMessage.WriteResponse(
280-
partition_id=1,
281-
acks=[
282-
StreamWriteMessage.WriteResponse.WriteAck(
283-
seq_no=1,
284-
message_write_status=StreamWriteMessage.WriteResponse.WriteAck.StatusWritten(
285-
offset=1
286-
),
287-
)
288-
],
289-
write_statistics=default_write_statistic,
290-
)
291-
)
292+
stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1))
292293

293294
stream_writer.from_server.put_nowait(issues.Overloaded("test"))
294295

@@ -297,6 +298,8 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error(
297298

298299
expected_messages = [InternalMessage(message2)]
299300
assert second_sent_msg == expected_messages
301+
302+
second_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=2))
300303
await reconnector.close()
301304

302305
async def test_stop_on_unexpected_exception(
@@ -323,7 +326,7 @@ async def wait_stop():
323326
await asyncio.wait_for(wait_stop(), 1)
324327

325328
with pytest.raises(TestException):
326-
await reconnector.close()
329+
await reconnector.close(flush=False)
327330

328331
async def test_wait_init(self, default_driver, default_settings, get_stream_writer):
329332
init_seqno = 100
@@ -350,7 +353,7 @@ async def test_wait_init(self, default_driver, default_settings, get_stream_writ
350353
info = await reconnector.wait_init()
351354
assert info == expected_init_info
352355

353-
await reconnector.close()
356+
await reconnector.close(flush=False)
354357

355358
async def test_write_message(
356359
self, reconnector: WriterAsyncIOReconnector, get_stream_writer
@@ -365,7 +368,7 @@ async def test_write_message(
365368
sent_messages = await asyncio.wait_for(stream_writer.from_client.get(), 1)
366369
assert sent_messages == [InternalMessage(message)]
367370

368-
await reconnector.close()
371+
await reconnector.close(flush=False)
369372

370373
async def test_auto_seq_no(
371374
self, default_driver, default_settings, get_stream_writer
@@ -399,7 +402,7 @@ async def test_auto_seq_no(
399402
[PublicMessage(seqno=last_seq_no + 3, data="123")]
400403
)
401404

402-
await reconnector.close()
405+
await reconnector.close(flush=False)
403406

404407
async def test_deny_double_seqno(self, reconnector: WriterAsyncIOReconnector):
405408
await reconnector.write_with_ack([PublicMessage(seqno=10, data="123")])
@@ -412,7 +415,7 @@ async def test_deny_double_seqno(self, reconnector: WriterAsyncIOReconnector):
412415

413416
await reconnector.write_with_ack([PublicMessage(seqno=11, data="123")])
414417

415-
await reconnector.close()
418+
await reconnector.close(flush=False)
416419

417420
@freezegun.freeze_time("2022-01-13 20:50:00", tz_offset=0)
418421
async def test_auto_created_at(
@@ -431,7 +434,7 @@ async def test_auto_created_at(
431434
assert [
432435
InternalMessage(PublicMessage(seqno=4, data="123", created_at=now))
433436
] == sent
434-
await reconnector.close()
437+
await reconnector.close(flush=False)
435438

436439

437440
@pytest.mark.asyncio

0 commit comments

Comments
 (0)