Skip to content

Commit 2578d55

Browse files
committed
feat: initial event loop
1 parent d3f47b0 commit 2578d55

7 files changed

Lines changed: 376 additions & 14 deletions

File tree

flake.nix

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,11 @@
3030
default = pkgs.mkShell {
3131
buildInputs = with pkgs; [
3232
gnumake
33-
(python313.withPackages (
34-
p: with p; [
35-
uv
36-
]
37-
))
33+
uv
3834
];
35+
env = {
36+
UV_PYTHON = pkgs.python310.interpreter;
37+
};
3938
};
4039
};
4140

main.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ readme = "README.md"
66
requires-python = ">=3.10"
77
dependencies = [
88
"pydantic>=2.11.7",
9+
"typing-extensions>=4.15.0",
910
]
1011

1112
[build-system]

src/duron/event_loop.py

Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import contextvars
5+
import heapq
6+
import logging
7+
import threading
8+
from asyncio import AbstractEventLoop, Handle, Task, TimerHandle, events
9+
from collections import deque
10+
from dataclasses import dataclass
11+
from typing import TYPE_CHECKING, cast
12+
13+
from typing_extensions import (
14+
Any,
15+
TypeVar,
16+
TypeVarTuple,
17+
Unpack,
18+
override,
19+
)
20+
21+
from duron.utils import mix_id
22+
23+
if TYPE_CHECKING:
24+
from collections.abc import Callable, Coroutine, Generator
25+
from contextvars import Context
26+
27+
T = TypeVar("T")
28+
Ts = TypeVarTuple("Ts")
29+
30+
logger = logging.getLogger(__name__)
31+
32+
33+
_task_id_ctx: contextvars.ContextVar[TaskCtx] = contextvars.ContextVar("task_id")
34+
35+
36+
@dataclass(slots=True)
37+
class SyscallFuture:
38+
id: str
39+
params: object
40+
_future: asyncio.Future[Any]
41+
42+
def __await__(self) -> Generator[Any, None, Any]:
43+
return self._future.__await__()
44+
45+
@override
46+
def __hash__(self) -> int:
47+
return hash(self.id)
48+
49+
@override
50+
def __eq__(self, other: object) -> bool:
51+
if not isinstance(other, SyscallFuture):
52+
return NotImplemented
53+
return self.id == other.id
54+
55+
56+
@dataclass(eq=True, frozen=True, slots=True)
57+
class WaitSet:
58+
syscalls: frozenset[SyscallFuture]
59+
timer: float | None
60+
61+
62+
@dataclass(slots=True)
63+
class TaskCtx:
64+
parent_id: str
65+
value: int
66+
67+
68+
class EventLoop(AbstractEventLoop):
69+
def __init__(self, event: asyncio.Event) -> None:
70+
self._ready: deque[Handle] = deque()
71+
self._timers: list[TimerHandle] = []
72+
self._cv: threading.Condition = threading.Condition()
73+
self._stopping: bool = False
74+
self._closed: bool = False
75+
self._debug: bool = False
76+
self._exc_handler: (
77+
Callable[[AbstractEventLoop, dict[str, Any]], object] | None
78+
) = None
79+
self._blocked: dict[str, SyscallFuture] = {}
80+
self._root_seq: int = 0
81+
self._now: int = 0
82+
self._event: asyncio.Event = event
83+
84+
def syscall(self, params: object) -> SyscallFuture:
85+
id = self._generate_id()
86+
fut = self.create_future()
87+
s = SyscallFuture(_future=fut, id=id, params=params)
88+
self._blocked[id] = s
89+
return s
90+
91+
def _generate_id(self) -> str:
92+
ctx = _task_id_ctx.get(None)
93+
if not ctx:
94+
parent_id = ""
95+
seq = self._root_seq
96+
self._root_seq += 1
97+
else:
98+
parent_id = ctx.parent_id
99+
seq = ctx.value
100+
ctx.value += 1
101+
102+
return mix_id(parent_id.encode(), int(seq).to_bytes(8, "big"))
103+
104+
@override
105+
def call_soon(
106+
self,
107+
callback: Callable[[*Ts], Any],
108+
*args: Unpack[Ts],
109+
context: Context | None = None,
110+
) -> Handle:
111+
h = Handle(
112+
callback,
113+
args,
114+
self,
115+
context=context,
116+
)
117+
with self._cv:
118+
self._ready.append(h)
119+
self._cv.notify()
120+
return h
121+
122+
@override
123+
def call_soon_threadsafe(
124+
self,
125+
callback: Callable[[*Ts], Any],
126+
*args: Unpack[Ts],
127+
context: Context | None = None,
128+
task_id: str | None = None,
129+
) -> Handle:
130+
self._event.set()
131+
h = Handle(
132+
callback,
133+
args,
134+
self,
135+
context=self._context_with_task_id(context, task_id=task_id),
136+
)
137+
with self._cv:
138+
self._ready.append(h)
139+
self._cv.notify()
140+
return h
141+
142+
@override
143+
def call_at(
144+
self,
145+
when: float,
146+
callback: Callable[[*Ts], Any],
147+
*args: Unpack[Ts],
148+
context: Context | None = None,
149+
) -> TimerHandle:
150+
th = TimerHandle(
151+
when,
152+
callback,
153+
args,
154+
loop=self,
155+
context=self._context_with_task_id(context),
156+
)
157+
with self._cv:
158+
heapq.heappush(self._timers, th)
159+
self._cv.notify()
160+
return th
161+
162+
@override
163+
def call_later(
164+
self,
165+
delay: float,
166+
callback: Callable[[*Ts], Any],
167+
*args: Unpack[Ts],
168+
context: Context | None = None,
169+
) -> TimerHandle:
170+
return self.call_at(self.time() + delay, callback, *args, context=context)
171+
172+
@override
173+
def time(self) -> float:
174+
return self._now / 1000.0
175+
176+
def tick(self, time: int) -> None:
177+
self._now = time
178+
179+
@override
180+
def create_future(self) -> asyncio.Future[Any]:
181+
return asyncio.Future(loop=self)
182+
183+
@override
184+
def create_task(
185+
self,
186+
coro: Generator[Any, None, T] | Coroutine[Any, Any, T],
187+
*,
188+
name: str | None = None,
189+
context: Context | None = None,
190+
**kwargs: Any,
191+
) -> Task[T]:
192+
ctx = self._context_with_task_id(context, task_id=self._generate_id())
193+
return ctx.run(
194+
cast("type[Task[T]]", Task), coro, name=name, loop=self, **kwargs
195+
)
196+
197+
def _run_once(self) -> None:
198+
now = self.time()
199+
# promote due timers
200+
with self._cv:
201+
while self._timers and self._timers[0].when() <= now:
202+
ht = heapq.heappop(self._timers)
203+
if not ht.cancelled():
204+
self._ready.append(ht)
205+
206+
# drain ready queue in batches for better performance
207+
batch_size = min(len(self._ready), 100) # Process up to 100 at a time
208+
for _ in range(batch_size):
209+
if not self._ready:
210+
break
211+
h = self._ready.popleft()
212+
if h.cancelled():
213+
continue
214+
try:
215+
h._run()
216+
except BaseException as exc:
217+
self.call_exception_handler(
218+
{"message": "exception in callback", "exception": exc, "handle": h}
219+
)
220+
221+
@override
222+
def stop(self) -> None:
223+
with self._cv:
224+
self._stopping = True
225+
self._cv.notify_all()
226+
227+
def poll_completion(self, task: Task[T]) -> None | WaitSet:
228+
old = events._get_running_loop()
229+
events._set_running_loop(self)
230+
try:
231+
while not task.done():
232+
self._run_once()
233+
234+
with self._cv:
235+
if self._ready:
236+
continue
237+
if self._timers and self._timers[0].when() <= self.time():
238+
# timer is due
239+
continue
240+
241+
break
242+
243+
if task.done():
244+
return None
245+
246+
return WaitSet(
247+
syscalls=frozenset(self._blocked.values()),
248+
timer=self._timers[0].when() if self._timers else None,
249+
)
250+
finally:
251+
events._set_running_loop(old)
252+
253+
def complete_syscall(
254+
self,
255+
syscall_id: str,
256+
*,
257+
result: Any = None,
258+
exception: BaseException | None = None,
259+
) -> None:
260+
sc = self._blocked.pop(syscall_id, None)
261+
if sc is None:
262+
return
263+
264+
fut = sc._future # pyright: ignore[reportPrivateUsage]
265+
tid = mix_id(sc.id.encode(), b"end")
266+
if exception is not None:
267+
_ = self.call_soon_threadsafe(fut.set_exception, exception, task_id=tid)
268+
else:
269+
_ = self.call_soon_threadsafe(fut.set_result, result, task_id=tid)
270+
271+
@override
272+
def is_running(self) -> bool:
273+
return not self._stopping
274+
275+
@override
276+
def is_closed(self) -> bool:
277+
return self._closed
278+
279+
@override
280+
def close(self) -> None:
281+
self._closed = True
282+
283+
@override
284+
def get_debug(self) -> bool:
285+
return self._debug
286+
287+
@override
288+
def set_debug(self, enabled: bool) -> None:
289+
self._debug = enabled
290+
291+
@override
292+
def default_exception_handler(self, context: dict[str, Any]) -> None:
293+
msg = context.get("message", "Unhandled exception")
294+
exc = context.get("exception")
295+
if exc:
296+
logger.error("%s: %r", msg, exc)
297+
else:
298+
logger.error("%s", msg)
299+
300+
@override
301+
def set_exception_handler(
302+
self, handler: Callable[[AbstractEventLoop, dict[str, Any]], object] | None
303+
) -> None:
304+
self._exc_handler = handler
305+
306+
@override
307+
def call_exception_handler(self, context: dict[str, Any]) -> None:
308+
if self._exc_handler is None:
309+
self.default_exception_handler(context)
310+
else:
311+
_ = self._exc_handler(self, context)
312+
313+
@override
314+
async def shutdown_asyncgens(self):
315+
pass
316+
317+
@override
318+
async def shutdown_default_executor(self):
319+
pass
320+
321+
def _timer_handle_cancelled(self, th: TimerHandle) -> None:
322+
try:
323+
self._timers.remove(th)
324+
heapq.heapify(self._timers)
325+
except ValueError:
326+
pass
327+
328+
def _context_with_task_id(
329+
self, context: Context | None, task_id: str | None = None
330+
) -> Context:
331+
base = context or contextvars.copy_context()
332+
if task_id is None:
333+
task_id = self._generate_id()
334+
_ = base.run(_task_id_ctx.set, TaskCtx(parent_id=task_id, value=0))
335+
return base
336+
337+
338+
def create_loop(event: asyncio.Event) -> EventLoop:
339+
return EventLoop(event) # type: ignore[abstract]

src/duron/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import base64
2+
import hashlib
3+
4+
_SEP = b"\x00"
5+
6+
7+
def mix_id(*parts: bytes) -> str:
8+
buf = _SEP.join(parts)
9+
h = hashlib.blake2b(buf, digest_size=12)
10+
return base64.b64encode(h.digest()).decode("ascii")

0 commit comments

Comments
 (0)