Skip to content

Commit db9ffaf

Browse files
committed
feat: initial to_host stream
1 parent ce2e654 commit db9ffaf

5 files changed

Lines changed: 158 additions & 75 deletions

File tree

src/duron/context.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515

1616
from duron.ops import Barrier, FnCall, StreamCreate
17-
from duron.stream import PeekStream, ResumableStream
17+
from duron.stream import LogStream, ResumableStream
1818

1919
if TYPE_CHECKING:
2020
from collections.abc import AsyncGenerator, Callable, Coroutine
@@ -29,7 +29,6 @@
2929
_S = TypeVar("_S")
3030
_P = ParamSpec("_P")
3131

32-
3332
_context: ContextVar[Context | None] = ContextVar("duron_context", default=None)
3433

3534

@@ -110,7 +109,9 @@ async def barrier(self) -> int:
110109
await self._loop.create_op(Barrier()),
111110
)
112111

113-
async def create_stream(self, observer: Observer[_T] | None) -> StreamHandle[_T]:
112+
async def create_stream_handle(
113+
self, observer: Observer[_T] | None
114+
) -> StreamHandle[_T]:
114115
if asyncio.get_event_loop() is not self._loop:
115116
raise RuntimeError("Context time can only be used in the context loop")
116117
o = cast("Observer[object]", observer) if observer else None
@@ -119,19 +120,19 @@ async def create_stream(self, observer: Observer[_T] | None) -> StreamHandle[_T]
119120
await self._loop.create_op(StreamCreate(observer=o)),
120121
)
121122

122-
async def create_peek_stream(self) -> tuple[PeekStream[_T], StreamHandle[_T]]:
123+
async def create_stream(self) -> tuple[Stream[_T], StreamHandle[_T]]:
123124
if asyncio.get_event_loop() is not self._loop:
124125
raise RuntimeError("Context time can only be used in the context loop")
125-
ps: PeekStream[_T] = PeekStream(self._loop)
126-
return (
127-
ps,
128-
cast(
129-
"StreamHandle[_T]",
130-
await self._loop.create_op(
131-
StreamCreate(observer=cast("Observer[object]", ps))
132-
),
126+
log: LogStream[_T] = LogStream(self._loop)
127+
hdl = cast(
128+
"StreamHandle[_T]",
129+
await self._loop.create_op(
130+
StreamCreate(
131+
observer=cast("Observer[object]", cast("Observer[_T]", log))
132+
)
133133
),
134134
)
135+
return (log, hdl)
135136

136137
def time(self) -> float:
137138
if asyncio.get_event_loop() is not self._loop:

src/duron/event_loop.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -224,19 +224,26 @@ def poll_completion(self, task: Future[_T]) -> WaitSet | None:
224224
events._set_running_loop(old)
225225

226226
def create_op(
227-
self, params: object, /, *, loop: AbstractEventLoop | None = None
227+
self,
228+
params: object,
229+
/,
228230
) -> asyncio.Future[object]:
229231
self._event.set()
230-
if loop is not None and loop is not self:
231-
id = os.urandom(12)
232-
op_fut = OpFuture(id, params, self)
233-
self._ops[id] = op_fut
234-
return wrap_future(op_fut, loop=loop)
235-
else:
236-
id = self.generate_op_id()
237-
op_fut = OpFuture(id, params, self)
238-
self._ops[id] = op_fut
239-
return op_fut
232+
id = self.generate_op_id()
233+
op_fut = OpFuture(id, params, self)
234+
self._ops[id] = op_fut
235+
return op_fut
236+
237+
def create_host_op(
238+
self,
239+
params: object,
240+
/,
241+
) -> asyncio.Future[object]:
242+
self._event.set()
243+
id = os.urandom(12)
244+
op_fut = OpFuture(id, params, self)
245+
self._ops[id] = op_fut
246+
return wrap_future(op_fut, loop=self._host)
240247

241248
@overload
242249
def post_completion(

src/duron/stream.py

Lines changed: 74 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from typing_extensions import override
2121

22+
from duron.event_loop import wrap_future
2223
from duron.ops import Barrier, FnCall, StreamClose, StreamCreate, StreamEmit
2324

2425
if TYPE_CHECKING:
@@ -57,16 +58,24 @@ def __init__(self, id: str, loop: EventLoop) -> None:
5758
self._event = asyncio.Event()
5859

5960
async def send(self, value: _In, /) -> None:
60-
_ = await self._loop.create_op(
61-
StreamEmit(stream_id=self._stream_id, value=value),
62-
loop=self._target_loop,
63-
)
61+
if self._target_loop is None:
62+
_ = await self._loop.create_op(
63+
StreamEmit(stream_id=self._stream_id, value=value),
64+
)
65+
else:
66+
_ = await self._loop.create_host_op(
67+
StreamEmit(stream_id=self._stream_id, value=value),
68+
)
6469

6570
async def close(self, error: BaseException | None = None, /) -> None:
66-
_ = await self._loop.create_op(
67-
StreamClose(stream_id=self._stream_id, exception=error),
68-
loop=self._target_loop,
69-
)
71+
if self._target_loop is None:
72+
_ = await self._loop.create_op(
73+
StreamClose(stream_id=self._stream_id, exception=error),
74+
)
75+
else:
76+
_ = await self._loop.create_host_op(
77+
StreamClose(stream_id=self._stream_id, exception=error),
78+
)
7079
self._event.set()
7180

7281
async def wait(self) -> None:
@@ -138,6 +147,12 @@ async def _start(self) -> None: ...
138147
@abstractmethod
139148
async def _close(self) -> None: ...
140149

150+
def peek(self) -> AsyncGenerator[_Out]:
151+
raise NotImplementedError("peek is not supported for this stream")
152+
153+
async def to_host(self) -> Stream[_Out]:
154+
raise NotImplementedError("to_host is not supported for this stream")
155+
141156
def map(self, fn: Callable[[_Out], _T]) -> Stream[_T]:
142157
return _Map(self, fn)
143158

@@ -187,6 +202,16 @@ async def _start(self) -> None:
187202
async def _close(self) -> None:
188203
await self._source.close()
189204

205+
@override
206+
async def peek(self) -> AsyncGenerator[_T]:
207+
async for v in self._source.peek():
208+
yield self._fn(v)
209+
210+
@override
211+
async def to_host(self) -> Stream[_T]:
212+
self._source = await self._source.to_host()
213+
return self
214+
190215

191216
@dataclass(slots=True)
192217
class _Sentinel:
@@ -225,6 +250,14 @@ def _send_close(self, offset: int, exc: BaseException | None = None):
225250
for s in self.__subscribers:
226251
s._send_close(offset, exc)
227252

253+
async def _peek(self, offset: int) -> AsyncGenerator[_T]:
254+
while self.__buffer and self.__buffer[0][0] <= offset:
255+
_, item = self.__buffer.popleft()
256+
if isinstance(item, _Sentinel):
257+
raise item.exception
258+
else:
259+
yield item
260+
228261
@override
229262
async def _start(self) -> None:
230263
if self.__parent:
@@ -250,6 +283,10 @@ def tee(self) -> tuple[Stream[_T], Stream[_T]]:
250283
parent.__pending_subscribers += 1
251284
return self, a
252285

286+
@override
287+
async def to_host(self) -> Stream[_T]:
288+
return self
289+
253290

254291
@final
255292
class _IntoBuffer(Generic[_T], _BufferedStream[_T]):
@@ -281,9 +318,14 @@ async def _close(self) -> None:
281318
self._task = None
282319
await super()._close()
283320

321+
@override
322+
async def to_host(self) -> Stream[_T]:
323+
self._stream = await self._stream.to_host()
324+
return self
325+
284326

285327
@final
286-
class ResumableStream(Generic[_In, _T], _BufferedStream[_T], Observer[_In]):
328+
class ResumableStream(Generic[_In, _T], _BufferedStream[_T]):
287329
def __init__(
288330
self,
289331
loop: EventLoop,
@@ -300,19 +342,18 @@ def __init__(
300342
self._closed: bool | BaseException = False
301343
self._current: _T = initial
302344
self._op = self._loop.create_op(
303-
StreamCreate(observer=cast("Observer[object]", self))
345+
StreamCreate(observer=cast("Observer[object]", cast("Observer[_In]", self)))
304346
)
305347
self._fn = fn
306348
self._args = args
307349
self._kwargs = kwargs
308350
self._task: asyncio.Future[object] | None = None
351+
self._target_loop: asyncio.AbstractEventLoop | None = None
309352

310-
@override
311353
def on_next(self, offset: int, val: _In):
312354
self._current = self._reducer(self._current, val)
313355
self._send(offset, self._current)
314356

315-
@override
316357
def on_close(self, offset: int, exc: BaseException | None):
317358
self._closed = True if exc is None else exc
318359
self._send_close(offset, exc)
@@ -323,7 +364,7 @@ async def _start(self) -> None:
323364

324365
async def worker():
325366
if self._closed is True:
326-
raise StopAsyncIteration
367+
return
327368

328369
gen = self._fn(self._current, *self._args, **self._kwargs)
329370
try:
@@ -339,8 +380,11 @@ async def worker():
339380
finally:
340381
await gen.aclose()
341382

342-
self._task = self._loop.create_op(
343-
FnCall(callable=worker, args=(), kwargs={}, return_type=None)
383+
self._task = wrap_future(
384+
self._loop.create_op(
385+
FnCall(callable=worker, args=(), kwargs={}, return_type=None),
386+
),
387+
loop=self._target_loop,
344388
)
345389

346390
@override
@@ -352,27 +396,30 @@ async def _close(self) -> None:
352396
self._task = None
353397
await super()._close()
354398

399+
@override
400+
async def to_host(self) -> Stream[_T]:
401+
self._target_loop = self._loop.host_loop()
402+
await self.start()
403+
return self
404+
355405

356406
@final
357-
class PeekStream(Generic[_T], Observer[_T]):
407+
class LogStream(Generic[_T], _BufferedStream[_T]):
358408
def __init__(self, loop: EventLoop) -> None:
359409
super().__init__()
360410
self._loop = loop
361-
self._data: deque[tuple[int, _T | _Sentinel]] = deque()
362411

363-
@override
364412
def on_next(self, _offset: int, val: _T):
365-
self._data.append((_offset, val))
413+
self._send(_offset, val)
366414

367-
@override
368415
def on_close(self, _offset: int, exc: BaseException | None):
369-
self._data.append((_offset, _Sentinel(exc or EndOfStream)))
416+
self._send_close(_offset, exc)
370417

418+
@override
371419
async def peek(self) -> AsyncGenerator[_T]:
420+
assert asyncio.get_event_loop() is self._loop, (
421+
"peek can only be used in the context loop"
422+
)
372423
offset = cast("int", await self._loop.create_op(Barrier()))
373-
while self._data and self._data[0][0] <= offset:
374-
_off, item = self._data.popleft()
375-
if isinstance(item, _Sentinel):
376-
raise item.exception
377-
else:
378-
yield item
424+
async for value in self._peek(offset):
425+
yield value

src/duron/task.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,8 @@ def done(f: OpFuture) -> None:
395395
pass
396396
elif task_info := self._tasks.get(sid, None):
397397
task, _ = task_info
398-
_ = task.get_loop().call_soon(task.cancel)
398+
if not task.done():
399+
_ = task.get_loop().call_soon(task.cancel)
399400

400401
fut.add_done_callback(done)
401402
sid = _encode_id(id)

0 commit comments

Comments
 (0)