@@ -391,6 +391,99 @@ async def add_to_buffer(value: T) -> T:
391391 self .enable_acks = stream_enable_acks
392392 self ._processors .remove (add_to_buffer )
393393
394+ async def take_events (
395+ self , max_ : int , within : Seconds
396+ ) -> AsyncIterable [Sequence [EventT ]]:
397+ """Buffer n events at a time and yield a list of buffered events.
398+ Arguments:
399+ max_: Max number of messages to receive. When more than this
400+ number of messages are received within the specified number of
401+ seconds then we flush the buffer immediately.
402+ within: Timeout for when we give up waiting for another value,
403+ and process the values we have.
404+ Warning: If there's no timeout (i.e. `timeout=None`),
405+ the agent is likely to stall and block buffered events for an
406+ unreasonable length of time(!).
407+ """
408+ buffer : List [T_co ] = []
409+ events : List [EventT ] = []
410+ buffer_add = buffer .append
411+ event_add = events .append
412+ buffer_size = buffer .__len__
413+ buffer_full = asyncio .Event ()
414+ buffer_consumed = asyncio .Event ()
415+ timeout = want_seconds (within ) if within else None
416+ stream_enable_acks : bool = self .enable_acks
417+
418+ buffer_consuming : Optional [asyncio .Future ] = None
419+
420+ channel_it = aiter (self .channel )
421+
422+ # We add this processor to populate the buffer, and the stream
423+ # is passively consumed in the background (enable_passive below).
424+ async def add_to_buffer (value : T ) -> T :
425+ try :
426+ # buffer_consuming is set when consuming buffer after timeout.
427+ nonlocal buffer_consuming
428+ if buffer_consuming is not None :
429+ try :
430+ await buffer_consuming
431+ finally :
432+ buffer_consuming = None
433+ buffer_add (cast (T_co , value ))
434+ event = self .current_event
435+ if event is None :
436+ raise RuntimeError ("Take buffer found current_event is None" )
437+ event_add (event )
438+ if buffer_size () >= max_ :
439+ # signal that the buffer is full and should be emptied.
440+ buffer_full .set ()
441+ # strict wait for buffer to be consumed after buffer full.
442+ # If max is 1000, we are not allowed to return 1001 values.
443+ buffer_consumed .clear ()
444+ await self .wait (buffer_consumed )
445+ except CancelledError : # pragma: no cover
446+ raise
447+ except Exception as exc :
448+ self .log .exception ("Error adding to take buffer: %r" , exc )
449+ await self .crash (exc )
450+ return value
451+
452+ # Disable acks to ensure this method acks manually
453+ # events only after they are consumed by the user
454+ self .enable_acks = False
455+
456+ self .add_processor (add_to_buffer )
457+ self ._enable_passive (cast (ChannelT , channel_it ))
458+ try :
459+ while not self .should_stop :
460+ # wait until buffer full, or timeout
461+ await self .wait_for_stopped (buffer_full , timeout = timeout )
462+ if buffer :
463+ # make sure background thread does not add new items to
464+ # buffer while we read.
465+ buffer_consuming = self .loop .create_future ()
466+ try :
467+ yield list (events )
468+ finally :
469+ buffer .clear ()
470+ for event in events :
471+ await self .ack (event )
472+ events .clear ()
473+ # allow writing to buffer again
474+ notify (buffer_consuming )
475+ buffer_full .clear ()
476+ buffer_consumed .set ()
477+ else : # pragma: no cover
478+ pass
479+ else : # pragma: no cover
480+ pass
481+
482+ finally :
483+ # Restore last behaviour of "enable_acks"
484+ self .enable_acks = stream_enable_acks
485+ self ._processors .remove (add_to_buffer )
486+
394487 async def take_with_timestamp (
395488 self , max_ : int , within : Seconds , timestamp_field_name : str
396489 ) -> AsyncIterable [Sequence [T_co ]]:
0 commit comments