Skip to content

Commit 72caaca

Browse files
committed
wip feat: stream merge
1 parent c02187c commit 72caaca

2 files changed

Lines changed: 64 additions & 14 deletions

File tree

src/duron/stream.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
from enum import Enum, auto
99
from typing import (
1010
TYPE_CHECKING,
11+
Any,
1112
Concatenate,
1213
Generic,
1314
ParamSpec,
1415
Protocol,
1516
TypeVar,
1617
cast,
1718
final,
19+
overload,
1820
)
1921

2022
from typing_extensions import override
@@ -31,6 +33,7 @@
3133
_In = TypeVar("_In", contravariant=True)
3234
_Out = TypeVar("_Out", covariant=True)
3335
_T = TypeVar("_T")
36+
_SentialObject = object()
3437

3538

3639
class EndOfStream(Exception):
@@ -128,6 +131,25 @@ async def close(self) -> None:
128131
await self._stop()
129132
self._state = StreamState.CLOSED
130133

134+
@overload
async def reduce(self, fn: Callable[[_T, _Out], _T], initial: _T, /) -> _T: ...
@overload
async def reduce(self, fn: Callable[[_Out, _Out], _Out], /) -> _Out: ...
async def reduce(
    self, fn: Callable[[Any, _Out], Any], initial: Any = _SentialObject, /
) -> Any:
    """Fold every item of the stream into a single value.

    Args:
        fn: Combining function called as ``fn(accumulator, item)``; its
            result becomes the accumulator for the next item.
        initial: Optional starting accumulator. When omitted, the stream is
            started explicitly and its first item (read via ``_get``) is
            used as the seed instead.

    Returns:
        The accumulator after the stream has been iterated to its end.
    """
    # Identity comparison against the module-level sentinel lets callers pass
    # ``None`` (or any other value) as a legitimate initial accumulator.
    if initial is _SentialObject:
        await self.start()
        folded = await self._get()
    else:
        folded = initial

    async for item in self:
        folded = fn(folded, item)
    return folded
152+
131153
@final
132154
def get_nowait(self) -> _Out:
133155
self._ensure(StreamState.STARTED)
@@ -159,13 +181,16 @@ async def _start(self) -> None: ...
159181
async def _stop(self) -> None: ...
160182

161183
def map(self, fn: Callable[[_Out], _T]) -> Stream[_T]:
    """Return a new stream that wraps this one in a ``_MapStream`` using *fn*.

    Any state check on this stream is performed by ``_MapStream.__init__``,
    not here.
    """
    mapped: Stream[_T] = _MapStream(self, fn)
    return mapped
164185

165186
def fork(self) -> tuple[Stream[_Out], Stream[_Out]]:
    """Split this stream into two streams.

    The stream is first wrapped in an ``_IntoBuffer``; the pair returned is
    whatever that buffered stream's own ``fork`` produces.
    """
    buffered = _IntoBuffer(self)
    return buffered.fork()
168188

189+
def merge(self, other: Stream[_Out]) -> Stream[_Out]:
    """Combine this stream and *other* into a single stream.

    Operands that are not already ``_BufferedStream`` instances are wrapped
    in ``_IntoBuffer`` (``self`` first, then ``other``); the actual merge is
    delegated to the buffered stream's ``merge``.
    """

    def as_buffered(candidate: Stream[_Out]) -> _BufferedStream[_Out]:
        # Reuse an existing buffer instead of stacking a second one on top.
        if isinstance(candidate, _BufferedStream):
            return candidate
        return _IntoBuffer(candidate)

    left = as_buffered(self)
    right = as_buffered(other)
    return left.merge(right)
193+
169194

170195
@final
171196
class _AsyncIterableStream(Generic[_Out], Stream[_Out]):
@@ -196,6 +221,7 @@ def _get_nowait(self) -> _Out:
196221
class _MapStream(Generic[_In, _T], Stream[_T]):
197222
def __init__(self, source: Stream[_In], fn: Callable[[_In], _T]) -> None:
    """Wrap *source*, remembering *fn* as the mapping function.

    ``source._ensure(StreamState.CONSUMED)`` runs before anything is stored.
    NOTE(review): presumably this claims the source so it cannot be consumed
    a second time — confirm against ``Stream._ensure``.
    """
    super().__init__()
    source._ensure(StreamState.CONSUMED)
    self._source, self._fn = source, fn
201227

@@ -219,12 +245,13 @@ class _Sentinel:
219245

220246

221247
class _BufferedStream(Generic[_T], Stream[_T]):
222-
def __init__(self, parent: _BufferedStream[_T] | None = None) -> None:
248+
def __init__(self, parents: list[_BufferedStream[_T]] | None = None) -> None:
    """Create an empty buffered stream.

    Args:
        parents: Upstream buffered streams that are started together with
            this one (see ``_start``); ``None`` when there are none.
    """
    super().__init__()
    self.__parent = parents
    # Items waiting to be read, or a ``_Sentinel`` marking the end.
    self.__buffer: deque[_T | _Sentinel] = deque()
    # Downstream buffered streams that ``_send`` forwards values to.
    self.__subscribers: list[_BufferedStream[_T]] = []
    self.__event = asyncio.Event()
    # Number of ``_send_close`` calls still required before the closing
    # sentinel is actually enqueued; ``merge`` raises it for multi-parent
    # streams.
    self.__closed = 1
228255

229256
@override
230257
async def _get(self) -> _T:
@@ -251,6 +278,9 @@ def _send(self, value: _T):
251278
s._send(value)
252279

253280
def _send_close(self, exc: BaseException | None = None):
281+
self.__closed -= 1
282+
if self.__closed != 0:
283+
return
254284
self.__buffer.append(_Sentinel(exc or EndOfStream))
255285
self.__event.set()
256286
for s in self.__subscribers:
@@ -260,20 +290,39 @@ def _send_close(self, exc: BaseException | None = None):
260290
async def _start(self) -> None:
    """Run the base start hook, then start every parent stream in order."""
    await super()._start()
    # ``None`` and an empty list both mean "no parents": iterate nothing.
    for upstream in self.__parent or ():
        await upstream.start()
264295

265296
@override
def fork(self) -> tuple[Stream[_T], Stream[_T]]:
    """Return this stream plus a clone fed by the same source stream(s).

    The clone subscribes to this stream's parents when it has any, otherwise
    to this stream itself.

    Returns:
        ``(self, clone)``.
    """
    sources = self.__parent or [self]
    clone: _BufferedStream[_T] = _BufferedStream(sources)
    for source in sources:
        source.__subscribers.append(clone)
    # ``_send_close`` only enqueues the end-of-stream sentinel once the
    # counter reaches zero. A clone of a merged stream has several sources,
    # so it must wait for one close per source; leaving the constructor's
    # default of 1 would end the clone as soon as the first source closed.
    clone.__closed = len(sources)
    return self, clone
271303

304+
@override
def merge(self, other: Stream[_T]) -> Stream[_T]:
    """Return a buffered stream subscribed to both operands' sources.

    A non-buffered *other* is handed to the base implementation. Otherwise
    ``_ensure(StreamState.CONSUMED)`` is applied to ``self`` and then to
    ``other``, and a new stream is subscribed to each operand's parents (or
    to the operand itself when it has none).
    """
    if not isinstance(other, _BufferedStream):
        return super().merge(other)

    self._ensure(StreamState.CONSUMED)
    other._ensure(StreamState.CONSUMED)

    upstreams: list[_BufferedStream[_T]] = [
        *(self.__parent or [self]),
        *(other.__parent or [other]),
    ]
    merged: _BufferedStream[_T] = _BufferedStream(upstreams)
    for upstream in upstreams:
        upstream.__subscribers.append(merged)
    # The constructor starts the counter at 1; raise it so the merged stream
    # needs one ``_send_close`` per upstream before it closes.
    merged.__closed += len(upstreams) - 1
    return merged
319+
272320

273321
@final
274322
class _IntoBuffer(Generic[_T], _BufferedStream[_T]):
275323
def __init__(self, stream: Stream[_T]) -> None:
    """Prepare a buffer around *stream*; no task is created yet.

    ``stream._ensure(StreamState.CONSUMED)`` runs before anything is stored.
    NOTE(review): presumably this claims the wrapped stream for exclusive
    use — confirm against ``Stream._ensure``.
    """
    super().__init__()
    stream._ensure(StreamState.CONSUMED)
    # Stays ``None`` until a task is assigned; ``_stop`` checks and resets it.
    self._task: asyncio.Task[None] | None = None
    self._stream = stream
279328

@@ -292,7 +341,7 @@ async def pump() -> None:
292341
@override
293342
async def _stop(self) -> None:
294343
if self._task:
295-
self._task.cancel()
344+
_ = self._task.cancel()
296345
with contextlib.suppress(asyncio.CancelledError):
297346
await self._task
298347
self._task = None
@@ -326,12 +375,12 @@ def __init__(
326375
@override
def on_next(self, val: _In):
    """Apply the reducer to the current value and *val*, store and send it."""
    updated = self._reducer(self._current, val)
    self._current = updated
    self._send(updated)
330379

331380
@override
def on_close(self, exc: BaseException | None):
    """Record how the source closed, then pass the close on.

    ``_closed`` becomes ``True`` when *exc* is ``None``, otherwise *exc*
    itself; *exc* is then handed to ``_send_close``.
    """
    if exc is None:
        self._closed = True
    else:
        self._closed = exc
    self._send_close(exc)
335384

336385
@override
337386
async def _start(self) -> None:

tests/test_stream.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,15 @@ async def f(s: str) -> AsyncGenerator[str, str]:
134134
@pytest.mark.asyncio
135135
async def test_stream_generator():
136136
@fn()
137-
async def activity(_ctx: Context) -> None:
137+
async def activity(ctx: Context) -> None:
138138
async def f() -> AsyncGenerator[int]:
139139
for i in range(100):
140140
yield i
141-
await asyncio.sleep(0)
141+
await asyncio.sleep(ctx.random().randint(1, 10) * 0.001)
142142

143143
stream = Stream[int].from_iterator(f())
144-
m = stream.map(lambda x: x * 2)
144+
m = stream.map(lambda x: x + 2000)
145+
s3 = Stream[int].from_iterator(f())
145146

146147
with pytest.raises(RuntimeError):
147148
# cannot iterate twice
@@ -155,7 +156,7 @@ async def f() -> AsyncGenerator[int]:
155156
# def add(a: int, b: int) -> int:
156157
# return a + b
157158

158-
assert await d.collect() == list(range(0, 200, 2))
159+
print(await (s3.merge(d)).collect())
159160

160161
log = MemoryLogStorage()
161162
async with activity.create_task(log) as t:

0 commit comments

Comments
 (0)