Skip to content

Commit 94ff38f

Browse files
Add take_events() function (#392)
* Add take_events() * lint streams.py Co-authored-by: William Barnhart <william.barnhart@he360.com> Co-authored-by: William Barnhart <williambbarnhart@gmail.com>
1 parent ecee01b commit 94ff38f

1 file changed

Lines changed: 93 additions & 0 deletions

File tree

faust/streams.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,99 @@ async def add_to_buffer(value: T) -> T:
391391
self.enable_acks = stream_enable_acks
392392
self._processors.remove(add_to_buffer)
393393

394+
async def take_events(
395+
self, max_: int, within: Seconds
396+
) -> AsyncIterable[Sequence[EventT]]:
397+
"""Buffer n events at a time and yield a list of buffered events.
398+
Arguments:
399+
max_: Max number of messages to receive. When more than this
400+
number of messages are received within the specified number of
401+
seconds then we flush the buffer immediately.
402+
within: Timeout for when we give up waiting for another value,
403+
and process the values we have.
404+
Warning: If there's no timeout (i.e. `timeout=None`),
405+
the agent is likely to stall and block buffered events for an
406+
unreasonable length of time(!).
407+
"""
408+
buffer: List[T_co] = []
409+
events: List[EventT] = []
410+
buffer_add = buffer.append
411+
event_add = events.append
412+
buffer_size = buffer.__len__
413+
buffer_full = asyncio.Event()
414+
buffer_consumed = asyncio.Event()
415+
timeout = want_seconds(within) if within else None
416+
stream_enable_acks: bool = self.enable_acks
417+
418+
buffer_consuming: Optional[asyncio.Future] = None
419+
420+
channel_it = aiter(self.channel)
421+
422+
# We add this processor to populate the buffer, and the stream
423+
# is passively consumed in the background (enable_passive below).
424+
async def add_to_buffer(value: T) -> T:
425+
try:
426+
# buffer_consuming is set when consuming buffer after timeout.
427+
nonlocal buffer_consuming
428+
if buffer_consuming is not None:
429+
try:
430+
await buffer_consuming
431+
finally:
432+
buffer_consuming = None
433+
buffer_add(cast(T_co, value))
434+
event = self.current_event
435+
if event is None:
436+
raise RuntimeError("Take buffer found current_event is None")
437+
event_add(event)
438+
if buffer_size() >= max_:
439+
# signal that the buffer is full and should be emptied.
440+
buffer_full.set()
441+
# strict wait for buffer to be consumed after buffer full.
442+
# If max is 1000, we are not allowed to return 1001 values.
443+
buffer_consumed.clear()
444+
await self.wait(buffer_consumed)
445+
except CancelledError: # pragma: no cover
446+
raise
447+
except Exception as exc:
448+
self.log.exception("Error adding to take buffer: %r", exc)
449+
await self.crash(exc)
450+
return value
451+
452+
# Disable acks to ensure this method acks manually
453+
# events only after they are consumed by the user
454+
self.enable_acks = False
455+
456+
self.add_processor(add_to_buffer)
457+
self._enable_passive(cast(ChannelT, channel_it))
458+
try:
459+
while not self.should_stop:
460+
# wait until buffer full, or timeout
461+
await self.wait_for_stopped(buffer_full, timeout=timeout)
462+
if buffer:
463+
# make sure background thread does not add new items to
464+
# buffer while we read.
465+
buffer_consuming = self.loop.create_future()
466+
try:
467+
yield list(events)
468+
finally:
469+
buffer.clear()
470+
for event in events:
471+
await self.ack(event)
472+
events.clear()
473+
# allow writing to buffer again
474+
notify(buffer_consuming)
475+
buffer_full.clear()
476+
buffer_consumed.set()
477+
else: # pragma: no cover
478+
pass
479+
else: # pragma: no cover
480+
pass
481+
482+
finally:
483+
# Restore last behaviour of "enable_acks"
484+
self.enable_acks = stream_enable_acks
485+
self._processors.remove(add_to_buffer)
486+
394487
async def take_with_timestamp(
395488
self, max_: int, within: Seconds, timestamp_field_name: str
396489
) -> AsyncIterable[Sequence[T_co]]:

0 commit comments

Comments
 (0)