@@ -391,6 +391,99 @@ async def add_to_buffer(value: T) -> T:
391391 self .enable_acks = stream_enable_acks
392392 self ._processors .remove (add_to_buffer )
393393
394+ async def take_events (self , max_ : int ,
395+ within : Seconds ) -> AsyncIterable [Sequence [EventT ]]:
396+ """Buffer n events at a time and yield a list of buffered events.
397+ Arguments:
398+ max_: Max number of messages to receive. When more than this
399+ number of messages are received within the specified number of
400+ seconds then we flush the buffer immediately.
401+ within: Timeout for when we give up waiting for another value,
402+ and process the values we have.
403+ Warning: If there's no timeout (i.e. `timeout=None`),
404+ the agent is likely to stall and block buffered events for an
405+ unreasonable length of time(!).
406+ """
407+ buffer : List [T_co ] = []
408+ events : List [EventT ] = []
409+ buffer_add = buffer .append
410+ event_add = events .append
411+ buffer_size = buffer .__len__
412+ buffer_full = asyncio .Event ()
413+ buffer_consumed = asyncio .Event ()
414+ timeout = want_seconds (within ) if within else None
415+ stream_enable_acks : bool = self .enable_acks
416+
417+ buffer_consuming : Optional [asyncio .Future ] = None
418+
419+ channel_it = aiter (self .channel )
420+
421+ # We add this processor to populate the buffer, and the stream
422+ # is passively consumed in the background (enable_passive below).
423+ async def add_to_buffer (value : T ) -> T :
424+ try :
425+ # buffer_consuming is set when consuming buffer after timeout.
426+ nonlocal buffer_consuming
427+ if buffer_consuming is not None :
428+ try :
429+ await buffer_consuming
430+ finally :
431+ buffer_consuming = None
432+ buffer_add (cast (T_co , value ))
433+ event = self .current_event
434+ if event is None :
435+ raise RuntimeError (
436+ 'Take buffer found current_event is None' )
437+ event_add (event )
438+ if buffer_size () >= max_ :
439+ # signal that the buffer is full and should be emptied.
440+ buffer_full .set ()
441+ # strict wait for buffer to be consumed after buffer full.
442+ # If max is 1000, we are not allowed to return 1001 values.
443+ buffer_consumed .clear ()
444+ await self .wait (buffer_consumed )
445+ except CancelledError : # pragma: no cover
446+ raise
447+ except Exception as exc :
448+ self .log .exception ('Error adding to take buffer: %r' , exc )
449+ await self .crash (exc )
450+ return value
451+
452+ # Disable acks to ensure this method acks manually
453+ # events only after they are consumed by the user
454+ self .enable_acks = False
455+
456+ self .add_processor (add_to_buffer )
457+ self ._enable_passive (cast (ChannelT , channel_it ))
458+ try :
459+ while not self .should_stop :
460+ # wait until buffer full, or timeout
461+ await self .wait_for_stopped (buffer_full , timeout = timeout )
462+ if buffer :
463+ # make sure background thread does not add new items to
464+ # buffer while we read.
465+ buffer_consuming = self .loop .create_future ()
466+ try :
467+ yield list (events )
468+ finally :
469+ buffer .clear ()
470+ for event in events :
471+ await self .ack (event )
472+ events .clear ()
473+ # allow writing to buffer again
474+ notify (buffer_consuming )
475+ buffer_full .clear ()
476+ buffer_consumed .set ()
477+ else : # pragma: no cover
478+ pass
479+ else : # pragma: no cover
480+ pass
481+
482+ finally :
483+ # Restore last behaviour of "enable_acks"
484+ self .enable_acks = stream_enable_acks
485+ self ._processors .remove (add_to_buffer )
486+
394487 async def take_with_timestamp (
395488 self , max_ : int , within : Seconds , timestamp_field_name : str
396489 ) -> AsyncIterable [Sequence [T_co ]]:
0 commit comments