Skip to content

Commit 0084cbf

Browse files
committed
feat wip stream
1 parent db43f89 commit 0084cbf

7 files changed

Lines changed: 464 additions & 21 deletions

File tree

src/duron/context.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from contextvars import ContextVar
45
from random import Random
5-
from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast, final
6+
from typing import TYPE_CHECKING, Concatenate, ParamSpec, TypeVar, cast, final
67

78
from typing_extensions import overload
89

9-
from duron.ops import FnCall
10+
from duron.ops import FnCall, StreamCreate
11+
from duron.stream import StreamTask
1012

1113
if TYPE_CHECKING:
12-
from collections.abc import Awaitable, Callable
14+
from collections.abc import AsyncGenerator, Awaitable, Callable
1315
from contextvars import Token
1416
from types import TracebackType
1517

1618
from duron.event_loop import EventLoop
1719
from duron.fn import Fn
20+
from duron.stream import Observer, Stream
1821

1922
_T = TypeVar("_T")
23+
_S = TypeVar("_S")
2024
_P = ParamSpec("_P")
2125

26+
2227
_context: ContextVar[Context | None] = ContextVar("duron_context", default=None)
2328

2429

@@ -74,6 +79,8 @@ async def run(
7479
*args: _P.args,
7580
**kwargs: _P.kwargs,
7681
) -> _T:
82+
if asyncio.get_event_loop() is not self._loop:
83+
raise RuntimeError("Context time can only be used in the context loop")
7784
type_info = self._task.codec.inspect_function(fn)
7885
return cast(
7986
"_T",
@@ -87,11 +94,46 @@ async def run(
8794
),
8895
)
8996

97+
def run_stream(
98+
self,
99+
initial: _T,
100+
reducer: Callable[[_T, _S], _T],
101+
fn: Callable[Concatenate[_T, _P], AsyncGenerator[_S, _T]],
102+
/,
103+
*args: _P.args,
104+
**kwargs: _P.kwargs,
105+
) -> StreamTask[_S, _T]:
106+
if asyncio.get_event_loop() is not self._loop:
107+
raise RuntimeError("Context time can only be used in the context loop")
108+
return StreamTask(
109+
self._loop,
110+
initial,
111+
reducer,
112+
fn,
113+
*args,
114+
**kwargs,
115+
)
116+
117+
async def create_stream(self, observer: Observer[_T] | None) -> Stream[_T]:
118+
if asyncio.get_event_loop() is not self._loop:
119+
raise RuntimeError("Context time can only be used in the context loop")
120+
o = cast("Observer[object]", observer) if observer else None
121+
return cast(
122+
"Stream[_T]",
123+
await self._loop.create_op(StreamCreate(observer=o)),
124+
)
125+
90126
def time(self) -> float:
127+
if asyncio.get_event_loop() is not self._loop:
128+
raise RuntimeError("Context time can only be used in the context loop")
91129
return self._loop.time()
92130

93131
def time_ns(self) -> int:
132+
if asyncio.get_event_loop() is not self._loop:
133+
raise RuntimeError("Context time can only be used in the context loop")
94134
return self._loop.time_ns()
95135

96136
def random(self) -> Random:
137+
if asyncio.get_event_loop() is not self._loop:
138+
raise RuntimeError("Context random can only be used in the context loop")
97139
return Random(self._loop.generate_op_id())

src/duron/event_loop.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import contextvars
66
import heapq
77
import logging
8+
import os
89
from asyncio import AbstractEventLoop, Handle, Task, TimerHandle, events
910
from collections import deque
1011
from dataclasses import dataclass
@@ -47,7 +48,7 @@ class OpFuture(asyncio.Future[object]):
4748
id: bytes
4849
params: object
4950

50-
def __init__(self, id: bytes, params: object, loop: EventLoop) -> None:
51+
def __init__(self, id: bytes, params: object, loop: AbstractEventLoop) -> None:
5152
super().__init__(loop=loop)
5253
self.id = id
5354
self.params = params
@@ -212,11 +213,36 @@ def poll_completion(self, task: Future[_T]) -> WaitSet | None:
212213
finally:
213214
events._set_running_loop(old)
214215

215-
def create_op(self, params: object) -> OpFuture:
216-
id = self.generate_op_id()
217-
s = OpFuture(id, params, self)
218-
self._ops[id] = s
219-
return s
216+
def create_op(
217+
self, params: object, /, *, loop: AbstractEventLoop | None = None
218+
) -> OpFuture:
219+
if loop is not None and loop is not self:
220+
id = os.urandom(12)
221+
host_fut: asyncio.Future[object] = OpFuture(id, params, loop)
222+
op_fut = OpFuture(id, params, self)
223+
self._ops[id] = op_fut
224+
225+
def op_to_host(f: asyncio.Future[object]):
226+
if f.cancelled():
227+
_ = loop.call_soon(host_fut.cancel)
228+
elif e := f.exception():
229+
_ = loop.call_soon(host_fut.set_exception, e)
230+
else:
231+
_ = loop.call_soon(host_fut.set_result, f.result())
232+
233+
def host_to_op(f: asyncio.Future[object]):
234+
if f.cancelled() and not op_fut.done():
235+
_ = self.call_soon(op_fut.cancel)
236+
237+
op_fut.add_done_callback(op_to_host)
238+
host_fut.add_done_callback(host_to_op)
239+
self._event.set()
240+
return host_fut
241+
else:
242+
id = self.generate_op_id()
243+
op_fut = OpFuture(id, params, self)
244+
self._ops[id] = op_fut
245+
return op_fut
220246

221247
@overload
222248
def post_completion(

src/duron/log.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ class StreamCreateEntry(_BaseEntry):
4646
class StreamEmitEntry(_BaseEntry):
4747
type: Literal["stream/emit"]
4848
stream_id: str
49-
value: NotRequired[JSONValue]
50-
state: NotRequired[JSONValue]
49+
value: JSONValue
5150

5251

5352
class StreamCompleteEntry(_BaseEntry):

src/duron/ops.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
if TYPE_CHECKING:
77
from collections.abc import Awaitable, Callable, Coroutine
88

9+
from duron.stream import Observer
10+
911

1012
@dataclass(slots=True)
1113
class FnCall:
@@ -21,4 +23,21 @@ class TaskRun:
2123
return_type: type | None = None
2224

2325

24-
Op = FnCall | TaskRun
26+
@dataclass(slots=True)
27+
class StreamCreate:
28+
observer: Observer[object] | None
29+
30+
31+
@dataclass(slots=True)
32+
class StreamEmit:
33+
stream_id: str
34+
value: object
35+
36+
37+
@dataclass(slots=True)
38+
class StreamClose:
39+
stream_id: str
40+
exception: BaseException | None
41+
42+
43+
Op = FnCall | StreamCreate | StreamEmit | StreamClose | TaskRun

src/duron/stream.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from typing import (
5+
TYPE_CHECKING,
6+
Concatenate,
7+
Generic,
8+
ParamSpec,
9+
Protocol,
10+
TypeVar,
11+
cast,
12+
final,
13+
)
14+
15+
from typing_extensions import override
16+
17+
from duron.ops import FnCall, StreamClose, StreamCreate, StreamEmit
18+
19+
if TYPE_CHECKING:
20+
from collections.abc import AsyncGenerator, Callable
21+
22+
from duron.event_loop import EventLoop
23+
24+
_P = ParamSpec("_P")
25+
26+
_In = TypeVar("_In", contravariant=True)
27+
_Out = TypeVar("_Out", covariant=True)
28+
29+
30+
class Observer(Generic[_In], Protocol):
31+
def on_next(self, value: _In, /) -> None: ...
32+
def on_close(self, error: BaseException | None, /) -> None: ...
33+
34+
35+
@final
36+
class Stream(Generic[_In]):
37+
def __init__(self, id: str, loop: EventLoop) -> None:
38+
self._stream_id = id
39+
self._loop = loop
40+
self._event = asyncio.Event()
41+
42+
async def send(self, value: _In, /) -> None:
43+
_ = await self._loop.create_op(
44+
StreamEmit(stream_id=self._stream_id, value=value),
45+
loop=asyncio.get_running_loop(),
46+
)
47+
48+
async def close(self, error: BaseException | None = None, /) -> None:
49+
_ = await self._loop.create_op(
50+
StreamClose(stream_id=self._stream_id, exception=error),
51+
loop=asyncio.get_running_loop(),
52+
)
53+
self._event.set()
54+
55+
async def wait(self) -> None:
56+
_ = await self._event.wait()
57+
58+
59+
@final
60+
class _StreamObserver(Generic[_In, _Out], Observer[_In]):
61+
def __init__(self, initial: _Out, reducer: Callable[[_Out, _In], _Out]):
62+
self.current = initial
63+
self.enable = True
64+
self._reducer = reducer
65+
self.data: list[_Out] = []
66+
self.closed: bool | BaseException = False
67+
68+
@override
69+
def on_next(self, val: _In):
70+
if self.enable:
71+
self.current = self._reducer(self.current, val)
72+
self.data.append(self.current)
73+
74+
@override
75+
def on_close(self, exc: BaseException | None):
76+
self.closed = True if exc is None else exc
77+
78+
79+
@final
80+
class StreamTask(Generic[_In, _Out]):
81+
def __init__(
82+
self,
83+
loop: EventLoop,
84+
initial: _Out,
85+
reducer: Callable[[_Out, _In], _Out],
86+
fn: Callable[Concatenate[_Out, _P], AsyncGenerator[_In, _Out]],
87+
/,
88+
*args: _P.args,
89+
**kwargs: _P.kwargs,
90+
) -> None:
91+
self._loop = loop
92+
self._reducer = reducer
93+
self._fn = fn
94+
self._args = args
95+
self._kwargs = kwargs
96+
self._obs = _StreamObserver(initial, self._reducer)
97+
self._op = self._loop.create_op(
98+
StreamCreate(observer=cast("Observer[object]", self._obs))
99+
)
100+
self._queue: asyncio.Queue[tuple[_Out] | None | BaseException] = (
101+
self._setup_stream()
102+
)
103+
104+
def __aiter__(self) -> StreamTask[_In, _Out]:
105+
return self
106+
107+
async def __anext__(self) -> _Out:
108+
item = await self._queue.get()
109+
if item is None:
110+
raise StopAsyncIteration
111+
if isinstance(item, BaseException):
112+
raise item
113+
return item[0]
114+
115+
async def discard(self) -> None:
116+
async for _ in self:
117+
...
118+
119+
def _setup_stream(self) -> asyncio.Queue[tuple[_Out] | None | BaseException]:
120+
queue: asyncio.Queue[tuple[_Out] | None | BaseException] = asyncio.Queue()
121+
122+
async def worker():
123+
stream = cast("Stream[_In]", await self._op)
124+
try:
125+
state = self._obs.current
126+
self._obs.enable = False
127+
128+
for d in self._obs.data:
129+
await queue.put((d,))
130+
if self._obs.closed is True:
131+
return
132+
elif isinstance(self._obs.closed, BaseException):
133+
raise self._obs.closed
134+
135+
gen = self._fn(state, *self._args, **self._kwargs)
136+
state_partial = await gen.__anext__()
137+
138+
while True:
139+
state = self._reducer(state, state_partial)
140+
await stream.send(state_partial)
141+
await queue.put((state,))
142+
state_partial = await gen.asend(state)
143+
except StopAsyncIteration as _e:
144+
await stream.close()
145+
except BaseException as e:
146+
await queue.put(e)
147+
raise
148+
finally:
149+
await queue.put(None)
150+
151+
_ = self._loop.create_op(
152+
FnCall(callable=worker, args=(), kwargs={}, return_type=None)
153+
)
154+
return queue

0 commit comments

Comments
 (0)