88from enum import Enum , auto
99from typing import (
1010 TYPE_CHECKING ,
11+ Any ,
1112 Concatenate ,
1213 Generic ,
1314 ParamSpec ,
1415 Protocol ,
1516 TypeVar ,
1617 cast ,
1718 final ,
19+ overload ,
1820)
1921
2022from typing_extensions import override
3133_In = TypeVar ("_In" , contravariant = True )
3234_Out = TypeVar ("_Out" , covariant = True )
3335_T = TypeVar ("_T" )
36+ _SentialObject = object ()
3437
3538
3639class EndOfStream (Exception ):
@@ -128,6 +131,25 @@ async def close(self) -> None:
128131 await self ._stop ()
129132 self ._state = StreamState .CLOSED
130133
134+ @overload
135+ async def reduce (self , fn : Callable [[_T , _Out ], _T ], initial : _T , / ) -> _T : ...
136+ @overload
137+ async def reduce (self , fn : Callable [[_Out , _Out ], _Out ], / ) -> _Out : ...
138+ async def reduce (
139+ self , fn : Callable [[Any , _Out ], Any ], initial : Any = _SentialObject , /
140+ ) -> Any :
141+ if initial is _SentialObject :
142+ await self .start ()
143+ acc = await self ._get ()
144+ async for v in self :
145+ acc = fn (acc , v )
146+ return acc
147+ else :
148+ acc = initial
149+ async for v in self :
150+ acc = fn (acc , v )
151+ return acc
152+
131153 @final
132154 def get_nowait (self ) -> _Out :
133155 self ._ensure (StreamState .STARTED )
@@ -159,13 +181,16 @@ async def _start(self) -> None: ...
159181 async def _stop (self ) -> None : ...
160182
161183 def map (self , fn : Callable [[_Out ], _T ]) -> Stream [_T ]:
162- self ._ensure (StreamState .CONSUMED )
163184 return _MapStream (self , fn )
164185
165186 def fork (self ) -> tuple [Stream [_Out ], Stream [_Out ]]:
166- self ._ensure (StreamState .CONSUMED )
167187 return _IntoBuffer (self ).fork ()
168188
189+ def merge (self , other : Stream [_Out ]) -> Stream [_Out ]:
190+ a = self if isinstance (self , _BufferedStream ) else _IntoBuffer (self )
191+ b = other if isinstance (other , _BufferedStream ) else _IntoBuffer (other )
192+ return a .merge (b )
193+
169194
170195@final
171196class _AsyncIterableStream (Generic [_Out ], Stream [_Out ]):
@@ -196,6 +221,7 @@ def _get_nowait(self) -> _Out:
196221class _MapStream (Generic [_In , _T ], Stream [_T ]):
197222 def __init__ (self , source : Stream [_In ], fn : Callable [[_In ], _T ]) -> None :
198223 super ().__init__ ()
224+ source ._ensure (StreamState .CONSUMED )
199225 self ._source = source
200226 self ._fn = fn
201227
@@ -219,12 +245,13 @@ class _Sentinel:
219245
220246
221247class _BufferedStream (Generic [_T ], Stream [_T ]):
222- def __init__ (self , parent : _BufferedStream [_T ] | None = None ) -> None :
248+ def __init__ (self , parents : list [ _BufferedStream [_T ] ] | None = None ) -> None :
223249 super ().__init__ ()
224250 self .__buffer : deque [_T | _Sentinel ] = deque ()
225251 self .__subscribers : list [_BufferedStream [_T ]] = []
226- self .__parent = parent
252+ self .__parent = parents
227253 self .__event = asyncio .Event ()
254+ self .__closed = 1
228255
229256 @override
230257 async def _get (self ) -> _T :
@@ -251,6 +278,9 @@ def _send(self, value: _T):
251278 s ._send (value )
252279
253280 def _send_close (self , exc : BaseException | None = None ):
281+ self .__closed -= 1
282+ if self .__closed != 0 :
283+ return
254284 self .__buffer .append (_Sentinel (exc or EndOfStream ))
255285 self .__event .set ()
256286 for s in self .__subscribers :
@@ -260,20 +290,39 @@ def _send_close(self, exc: BaseException | None = None):
260290 async def _start (self ) -> None :
261291 await super ()._start ()
262292 if self .__parent :
263- await self .__parent .start ()
293+ for p in self .__parent :
294+ await p .start ()
264295
265296 @override
266297 def fork (self ) -> tuple [Stream [_T ], Stream [_T ]]:
267- parent = self .__parent or self
298+ parent = self .__parent or [ self ]
268299 clone : _BufferedStream [_T ] = _BufferedStream (parent )
269- parent .__subscribers .append (clone )
300+ for p in parent :
301+ p .__subscribers .append (clone )
270302 return self , clone
271303
304+ @override
305+ def merge (self , other : Stream [_T ]) -> Stream [_T ]:
306+ if not isinstance (other , _BufferedStream ):
307+ return super ().merge (other )
308+
309+ self ._ensure (StreamState .CONSUMED )
310+ other ._ensure (StreamState .CONSUMED )
311+ parents : list [_BufferedStream [_T ]] = []
312+ parents .extend (self .__parent or [self ])
313+ parents .extend (other .__parent or [other ])
314+ clone : _BufferedStream [_T ] = _BufferedStream (parents )
315+ for p in parents :
316+ p .__subscribers .append (clone )
317+ clone .__closed += len (parents ) - 1
318+ return clone
319+
272320
273321@final
274322class _IntoBuffer (Generic [_T ], _BufferedStream [_T ]):
275323 def __init__ (self , stream : Stream [_T ]) -> None :
276324 super ().__init__ ()
325+ stream ._ensure (StreamState .CONSUMED )
277326 self ._stream = stream
278327 self ._task : asyncio .Task [None ] | None = None
279328
@@ -292,7 +341,7 @@ async def pump() -> None:
292341 @override
293342 async def _stop (self ) -> None :
294343 if self ._task :
295- self ._task .cancel ()
344+ _ = self ._task .cancel ()
296345 with contextlib .suppress (asyncio .CancelledError ):
297346 await self ._task
298347 self ._task = None
@@ -326,12 +375,12 @@ def __init__(
326375 @override
327376 def on_next (self , val : _In ):
328377 self ._current = self ._reducer (self ._current , val )
329- self ._send (self ._current ) # pyright: ignore[reportPrivateUsage]
378+ self ._send (self ._current )
330379
331380 @override
332381 def on_close (self , exc : BaseException | None ):
333382 self ._closed = True if exc is None else exc
334- self ._send_close (exc ) # pyright: ignore[reportPrivateUsage]
383+ self ._send_close (exc )
335384
336385 @override
337386 async def _start (self ) -> None :
0 commit comments