Skip to content

Commit 70fb3a4

Browse files
committed
feat: initial stream
1 parent db43f89 commit 70fb3a4

8 files changed

Lines changed: 489 additions & 27 deletions

File tree

src/duron/context.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from contextvars import ContextVar
45
from random import Random
5-
from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast, final
6+
from typing import TYPE_CHECKING, Concatenate, ParamSpec, TypeVar, cast, final
67

78
from typing_extensions import overload
89

9-
from duron.ops import FnCall
10+
from duron.ops import FnCall, StreamCreate
11+
from duron.stream import StreamTask
1012

1113
if TYPE_CHECKING:
12-
from collections.abc import Awaitable, Callable
14+
from collections.abc import AsyncGenerator, Awaitable, Callable
1315
from contextvars import Token
1416
from types import TracebackType
1517

1618
from duron.event_loop import EventLoop
1719
from duron.fn import Fn
20+
from duron.stream import Observer, RawStream
1821

1922
_T = TypeVar("_T")
23+
_S = TypeVar("_S")
2024
_P = ParamSpec("_P")
2125

26+
2227
_context: ContextVar[Context | None] = ContextVar("duron_context", default=None)
2328

2429

@@ -74,6 +79,8 @@ async def run(
7479
*args: _P.args,
7580
**kwargs: _P.kwargs,
7681
) -> _T:
82+
if asyncio.get_event_loop() is not self._loop:
83+
raise RuntimeError("Context time can only be used in the context loop")
7784
type_info = self._task.codec.inspect_function(fn)
7885
return cast(
7986
"_T",
@@ -87,11 +94,46 @@ async def run(
8794
),
8895
)
8996

97+
def run_stream(
98+
self,
99+
initial: _T,
100+
reducer: Callable[[_T, _S], _T],
101+
fn: Callable[Concatenate[_T, _P], AsyncGenerator[_S, _T]],
102+
/,
103+
*args: _P.args,
104+
**kwargs: _P.kwargs,
105+
) -> StreamTask[_S, _T]:
106+
if asyncio.get_event_loop() is not self._loop:
107+
raise RuntimeError("Context time can only be used in the context loop")
108+
return StreamTask(
109+
self._loop,
110+
initial,
111+
reducer,
112+
fn,
113+
*args,
114+
**kwargs,
115+
)
116+
117+
async def create_stream(self, observer: Observer[_T] | None) -> RawStream[_T]:
118+
if asyncio.get_event_loop() is not self._loop:
119+
raise RuntimeError("Context time can only be used in the context loop")
120+
o = cast("Observer[object]", observer) if observer else None
121+
return cast(
122+
"RawStream[_T]",
123+
await self._loop.create_op(StreamCreate(observer=o)),
124+
)
125+
90126
def time(self) -> float:
127+
if asyncio.get_event_loop() is not self._loop:
128+
raise RuntimeError("Context time can only be used in the context loop")
91129
return self._loop.time()
92130

93131
def time_ns(self) -> int:
132+
if asyncio.get_event_loop() is not self._loop:
133+
raise RuntimeError("Context time can only be used in the context loop")
94134
return self._loop.time_ns()
95135

96136
def random(self) -> Random:
137+
if asyncio.get_event_loop() is not self._loop:
138+
raise RuntimeError("Context random can only be used in the context loop")
97139
return Random(self._loop.generate_op_id())

src/duron/event_loop.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import contextvars
66
import heapq
77
import logging
8+
import os
89
from asyncio import AbstractEventLoop, Handle, Task, TimerHandle, events
910
from collections import deque
1011
from dataclasses import dataclass
@@ -47,7 +48,7 @@ class OpFuture(asyncio.Future[object]):
4748
id: bytes
4849
params: object
4950

50-
def __init__(self, id: bytes, params: object, loop: EventLoop) -> None:
51+
def __init__(self, id: bytes, params: object, loop: AbstractEventLoop) -> None:
5152
super().__init__(loop=loop)
5253
self.id = id
5354
self.params = params
@@ -76,9 +77,10 @@ class _TaskCtx:
7677

7778

7879
class EventLoop(AbstractEventLoop):
79-
def __init__(self, seed: bytes) -> None:
80+
def __init__(self, ambient: asyncio.AbstractEventLoop, seed: bytes) -> None:
8081
self._ready: deque[Handle] = deque()
8182
self._debug: bool = False
83+
self._ambient: asyncio.AbstractEventLoop = ambient
8284
self._exc_handler: (
8385
Callable[[AbstractEventLoop, dict[str, object]], object] | None
8486
) = None
@@ -95,6 +97,9 @@ def generate_op_id(self) -> bytes:
9597
ctx.seq += 1
9698
return _mix_id(ctx.parent_id, ctx.seq - 1)
9799

100+
def ambient_loop(self) -> asyncio.AbstractEventLoop:
101+
return self._ambient
102+
98103
@override
99104
def call_soon(
100105
self,
@@ -212,11 +217,36 @@ def poll_completion(self, task: Future[_T]) -> WaitSet | None:
212217
finally:
213218
events._set_running_loop(old)
214219

215-
def create_op(self, params: object) -> OpFuture:
216-
id = self.generate_op_id()
217-
s = OpFuture(id, params, self)
218-
self._ops[id] = s
219-
return s
220+
def create_op(
221+
self, params: object, /, *, loop: AbstractEventLoop | None = None
222+
) -> OpFuture:
223+
if loop is not None and loop is not self:
224+
id = os.urandom(12)
225+
host_fut: asyncio.Future[object] = OpFuture(id, params, loop)
226+
op_fut = OpFuture(id, params, self)
227+
self._ops[id] = op_fut
228+
229+
def op_to_host(f: asyncio.Future[object]):
230+
if f.cancelled():
231+
_ = loop.call_soon(host_fut.cancel)
232+
elif e := f.exception():
233+
_ = loop.call_soon(host_fut.set_exception, e)
234+
else:
235+
_ = loop.call_soon(host_fut.set_result, f.result())
236+
237+
def host_to_op(f: asyncio.Future[object]):
238+
if f.cancelled() and not op_fut.done():
239+
_ = self.call_soon(op_fut.cancel)
240+
241+
op_fut.add_done_callback(op_to_host)
242+
host_fut.add_done_callback(host_to_op)
243+
self._event.set()
244+
return host_fut
245+
else:
246+
id = self.generate_op_id()
247+
op_fut = OpFuture(id, params, self)
248+
self._ops[id] = op_fut
249+
return op_fut
220250

221251
@overload
222252
def post_completion(
@@ -316,8 +346,11 @@ def _mix_id(a: bytes, b: int) -> bytes:
316346
return blake2b(b.to_bytes(4, "little", signed=True) + a, digest_size=12).digest()
317347

318348

319-
def create_loop(seed: bytes) -> EventLoop:
320-
return EventLoop(seed) # type: ignore[abstract]
349+
def create_loop(
350+
parent_loop: asyncio.AbstractEventLoop,
351+
seed: bytes,
352+
) -> EventLoop:
353+
return EventLoop(parent_loop, seed) # type: ignore[abstract]
321354

322355

323356
def create_op(params: object) -> OpFuture:

src/duron/log.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ class StreamCreateEntry(_BaseEntry):
4646
class StreamEmitEntry(_BaseEntry):
4747
type: Literal["stream/emit"]
4848
stream_id: str
49-
value: NotRequired[JSONValue]
50-
state: NotRequired[JSONValue]
49+
value: JSONValue
5150

5251

5352
class StreamCompleteEntry(_BaseEntry):

src/duron/ops.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
if TYPE_CHECKING:
77
from collections.abc import Awaitable, Callable, Coroutine
88

9+
from duron.stream import Observer
10+
911

1012
@dataclass(slots=True)
1113
class FnCall:
@@ -21,4 +23,21 @@ class TaskRun:
2123
return_type: type | None = None
2224

2325

24-
Op = FnCall | TaskRun
26+
@dataclass(slots=True)
27+
class StreamCreate:
28+
observer: Observer[object] | None
29+
30+
31+
@dataclass(slots=True)
32+
class StreamEmit:
33+
stream_id: str
34+
value: object
35+
36+
37+
@dataclass(slots=True)
38+
class StreamClose:
39+
stream_id: str
40+
exception: BaseException | None
41+
42+
43+
Op = FnCall | StreamCreate | StreamEmit | StreamClose | TaskRun

src/duron/stream.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from typing import (
5+
TYPE_CHECKING,
6+
Concatenate,
7+
Generic,
8+
ParamSpec,
9+
Protocol,
10+
TypeVar,
11+
cast,
12+
final,
13+
)
14+
15+
from typing_extensions import override
16+
17+
from duron.ops import FnCall, StreamClose, StreamCreate, StreamEmit
18+
19+
if TYPE_CHECKING:
20+
from collections.abc import AsyncGenerator, Callable
21+
22+
from duron.event_loop import EventLoop
23+
24+
_P = ParamSpec("_P")
25+
26+
_In = TypeVar("_In", contravariant=True)
27+
_Out = TypeVar("_Out", covariant=True)
28+
29+
30+
class Observer(Generic[_In], Protocol):
31+
def on_next(self, value: _In, /) -> None: ...
32+
def on_close(self, error: BaseException | None, /) -> None: ...
33+
34+
35+
class AmbientRawStream(Protocol[_In]):
36+
async def send(self, value: _In, /) -> None: ...
37+
38+
async def close(self, error: BaseException | None = None, /) -> None: ...
39+
40+
41+
@final
42+
class RawStream(Generic[_In]):
43+
def __init__(self, id: str, loop: EventLoop) -> None:
44+
self._stream_id = id
45+
self._loop = loop
46+
self._target_loop: asyncio.AbstractEventLoop | None = None
47+
self._event = asyncio.Event()
48+
49+
async def send(self, value: _In, /) -> None:
50+
_ = await self._loop.create_op(
51+
StreamEmit(stream_id=self._stream_id, value=value),
52+
loop=self._target_loop,
53+
)
54+
55+
async def close(self, error: BaseException | None = None, /) -> None:
56+
_ = await self._loop.create_op(
57+
StreamClose(stream_id=self._stream_id, exception=error),
58+
loop=self._target_loop,
59+
)
60+
self._event.set()
61+
62+
async def wait(self) -> None:
63+
_ = await self._event.wait()
64+
65+
def to_ambient(self) -> AmbientRawStream[_In]:
66+
s: RawStream[_In] = RawStream(self._stream_id, self._loop)
67+
s._target_loop = self._loop.ambient_loop()
68+
return s
69+
70+
71+
@final
72+
class _StreamObserver(Generic[_In, _Out], Observer[_In]):
73+
def __init__(self, initial: _Out, reducer: Callable[[_Out, _In], _Out]):
74+
self.current = initial
75+
self.enable = True
76+
self._reducer = reducer
77+
self.data: list[_Out] = []
78+
self.closed: bool | BaseException = False
79+
80+
@override
81+
def on_next(self, val: _In):
82+
if self.enable:
83+
self.current = self._reducer(self.current, val)
84+
self.data.append(self.current)
85+
86+
@override
87+
def on_close(self, exc: BaseException | None):
88+
self.closed = True if exc is None else exc
89+
90+
91+
@final
92+
class StreamTask(Generic[_In, _Out]):
93+
def __init__(
94+
self,
95+
loop: EventLoop,
96+
initial: _Out,
97+
reducer: Callable[[_Out, _In], _Out],
98+
fn: Callable[Concatenate[_Out, _P], AsyncGenerator[_In, _Out]],
99+
/,
100+
*args: _P.args,
101+
**kwargs: _P.kwargs,
102+
) -> None:
103+
self._loop = loop
104+
self._reducer = reducer
105+
self._fn = fn
106+
self._args = args
107+
self._kwargs = kwargs
108+
self._obs = _StreamObserver(initial, self._reducer)
109+
self._op = self._loop.create_op(
110+
StreamCreate(observer=cast("Observer[object]", self._obs))
111+
)
112+
self._queue: asyncio.Queue[tuple[_Out] | None | BaseException] = (
113+
self._setup_stream()
114+
)
115+
116+
def __aiter__(self) -> StreamTask[_In, _Out]:
117+
return self
118+
119+
async def __anext__(self) -> _Out:
120+
item = await self._queue.get()
121+
if item is None:
122+
raise StopAsyncIteration
123+
if isinstance(item, BaseException):
124+
raise item
125+
return item[0]
126+
127+
async def discard(self) -> None:
128+
async for _ in self:
129+
...
130+
131+
def _setup_stream(self) -> asyncio.Queue[tuple[_Out] | None | BaseException]:
132+
queue: asyncio.Queue[tuple[_Out] | None | BaseException] = asyncio.Queue()
133+
134+
async def worker():
135+
stream = cast("RawStream[_In]", await self._op).to_ambient()
136+
try:
137+
state = self._obs.current
138+
self._obs.enable = False
139+
140+
for d in self._obs.data:
141+
await queue.put((d,))
142+
if self._obs.closed is True:
143+
return
144+
elif isinstance(self._obs.closed, BaseException):
145+
raise self._obs.closed
146+
147+
gen = self._fn(state, *self._args, **self._kwargs)
148+
state_partial = await gen.__anext__()
149+
150+
while True:
151+
state = self._reducer(state, state_partial)
152+
await stream.send(state_partial)
153+
await queue.put((state,))
154+
state_partial = await gen.asend(state)
155+
except StopAsyncIteration as _e:
156+
await stream.close()
157+
except BaseException as e:
158+
await queue.put(e)
159+
raise
160+
finally:
161+
await queue.put(None)
162+
163+
_ = self._loop.create_op(
164+
FnCall(callable=worker, args=(), kwargs={}, return_type=None)
165+
)
166+
return queue

0 commit comments

Comments
 (0)