Skip to content

Commit 4e1b731

Browse files
committed
feat: add time and random for context
1 parent b81eb06 commit 4e1b731

6 files changed

Lines changed: 99 additions & 41 deletions

File tree

src/duron/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
from duron.context import Context as Context
2-
from duron.context import get_context as get_context
32
from duron.fn import durable as durable

src/duron/context.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,68 @@
11
from __future__ import annotations
22

3-
import asyncio
4-
from typing import TYPE_CHECKING, TypeVar, cast
3+
from contextvars import ContextVar
4+
from random import Random
5+
from typing import TYPE_CHECKING, TypeVar, cast, final
56

67
from typing_extensions import overload
78

8-
from duron.event_loop import EventLoop
99
from duron.ops import FnCall
1010

1111
if TYPE_CHECKING:
1212
from collections.abc import Awaitable, Callable
13+
from contextvars import Token
14+
from inspect import Traceback
15+
16+
from duron.event_loop import EventLoop
1317

1418
_T = TypeVar("_T")
1519

20+
_context: ContextVar[Context | None] = ContextVar("duron_context", default=None)
21+
1622

23+
@final
1724
class Context:
18-
def __init__(self):
19-
loop = asyncio.get_event_loop()
20-
assert isinstance(loop, EventLoop)
25+
def __init__(self, loop: EventLoop) -> None:
2126
self._loop: EventLoop = loop
22-
pass
27+
self._token: Token[Context | None] | None = None
28+
29+
def __enter__(self) -> Context:
30+
token = _context.set(self)
31+
self._token = token
32+
return self
33+
34+
def __exit__(
35+
self,
36+
exc_type: type[BaseException] | None,
37+
exc_val: BaseException | None,
38+
exc_tb: Traceback | None,
39+
):
40+
if self._token:
41+
_context.reset(self._token)
42+
43+
@classmethod
44+
def current(cls) -> Context:
45+
ctx = _context.get()
46+
if ctx is None:
47+
raise RuntimeError("No duron context is active")
48+
return ctx
2349

2450
@overload
25-
async def run(self, fn: Callable[[], Awaitable[_T]]) -> _T: ...
51+
async def run(self, fn: Callable[[], Awaitable[_T]], /) -> _T: ...
2652
@overload
27-
async def run(self, fn: Callable[[], _T]) -> _T: ...
28-
async def run(self, fn: Callable[[], Awaitable[_T] | _T]) -> _T:
53+
async def run(self, fn: Callable[[], _T], /) -> _T: ...
54+
async def run(
55+
self,
56+
fn: Callable[[], Awaitable[_T] | _T],
57+
/,
58+
) -> _T:
2959
return cast("_T", await self._loop.create_op(FnCall(callable=fn)))
3060

61+
def time(self) -> float:
62+
return self._loop.time()
63+
64+
def time_ns(self) -> int:
65+
return self._loop.time_ns()
3166

32-
def get_context() -> Context:
33-
"""
34-
Get the current duron execution context.
35-
"""
36-
return Context()
67+
def random(self) -> Random:
68+
return Random(self._loop.generate_op_id())

src/duron/event_loop.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(self, seed: bytes) -> None:
8787

8888
self._timers: list[TimerHandle] = []
8989

90-
def _generate_id(self) -> bytes:
90+
def generate_op_id(self) -> bytes:
9191
ctx = _task_ctx.get(self._ctx)
9292
ctx.seq += 1
9393
return _mix_id(ctx.parent_id, ctx.seq - 1)
@@ -150,6 +150,9 @@ def call_later(
150150
def time(self) -> float:
151151
return self._now_ns / 1e9
152152

153+
def time_ns(self) -> int:
154+
return self._now_ns
155+
153156
def tick(self, time: int) -> None:
154157
self._now_ns = time
155158

@@ -166,7 +169,7 @@ def create_task(
166169
context: Context | None = None,
167170
**kwargs: Any,
168171
) -> Task[_T]:
169-
ctx = self._context_new_task(context, self._generate_id())
172+
ctx = self._context_new_task(context, self.generate_op_id())
170173
return ctx.run(
171174
cast("type[Task[_T]]", Task), coro, name=name, loop=self, **kwargs
172175
)
@@ -217,7 +220,7 @@ def poll_completion(self, task: Future[_T]) -> WaitSet | None:
217220
events._set_running_loop(old)
218221

219222
def create_op(self, params: object) -> OpFuture:
220-
id = self._generate_id()
223+
id = self.generate_op_id()
221224
s = OpFuture(id, params, self)
222225
self._ops[id] = s
223226
return s

src/duron/fn.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ class DurableFn(Generic[_P, _T_co]):
3434
fn: Callable[Concatenate[Context, _P], Coroutine[Any, Any, _T_co]]
3535

3636
def __call__(
37-
self,
38-
log: LogStorage[_TOffset, _TLease],
39-
) -> TaskGuard[_P, _T_co]:
37+
self, ctx: Context, *args: _P.args, **kwargs: _P.kwargs
38+
) -> Coroutine[Any, Any, _T_co]:
39+
return self.fn(ctx, *args, **kwargs)
40+
41+
def create_task(self, log: LogStorage[_TOffset, _TLease]) -> TaskGuard[_P, _T_co]:
4042
return TaskGuard(Task(self, cast("LogStorage[object, object]", log)))
4143

4244

src/duron/task.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
final,
1818
)
1919

20-
from duron.context import get_context
21-
from duron.event_loop import create_loop
20+
from duron.context import Context
21+
from duron.event_loop import EventLoop, create_loop
2222
from duron.log import is_entry
2323
from duron.ops import FnCall, TaskRun
2424

@@ -113,14 +113,16 @@ async def _task_prelude(
113113
task_fn: DurableFn[..., object],
114114
init: Callable[[], TaskInitParams],
115115
) -> object:
116-
ctx = get_context()
117-
init_params = await ctx.run(init)
118-
if init_params["version"] != _CURRENT_VERSION:
119-
raise Exception("version mismatch")
120-
codec = task_fn.codec
121-
args = (codec.decode_json(arg) for arg in init_params["args"])
122-
kwargs = {k: codec.decode_json(v) for k, v in init_params["kwargs"].items()}
123-
return await task_fn.fn(get_context(), *args, **kwargs)
116+
loop = asyncio.get_event_loop()
117+
assert isinstance(loop, EventLoop)
118+
with Context(loop) as ctx:
119+
init_params = await ctx.run(init)
120+
if init_params["version"] != _CURRENT_VERSION:
121+
raise Exception("version mismatch")
122+
codec = task_fn.codec
123+
args = (codec.decode_json(arg) for arg in init_params["args"])
124+
kwargs = {k: codec.decode_json(v) for k, v in init_params["kwargs"].items()}
125+
return await task_fn.fn(ctx, *args, **kwargs)
124126

125127

126128
@final

tests/test_task.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,18 @@ async def activity(ctx: Context, i: str) -> str:
4747
}
4848

4949
log = MemoryLogStorage()
50-
async with activity(log) as t:
50+
async with activity.create_task(log) as t:
5151
await t.start("test")
5252
a = await t.wait()
5353
assert set(e["id"] for e in await log.entries()) == IDS
5454

55-
async with activity(log) as t:
55+
async with activity.create_task(log) as t:
5656
await t.start("test")
5757
b = await t.wait()
5858
assert a == b
5959

6060
log2 = MemoryLogStorage((await log.entries())[:-2])
61-
async with activity(log2) as t:
61+
async with activity.create_task(log2) as t:
6262
await t.start("test")
6363
c = await t.wait()
6464
assert a == c
@@ -78,11 +78,11 @@ async def error():
7878

7979
log = MemoryLogStorage()
8080
with pytest.raises(check=lambda v: "test error" in str(v)):
81-
async with activity(log) as t:
81+
async with activity.create_task(log) as t:
8282
await t.start()
8383
await t.wait()
8484
with pytest.raises(check=lambda v: "test error" in str(v)):
85-
async with activity(log) as t:
85+
async with activity.create_task(log) as t:
8686
await t.start()
8787
await t.wait()
8888

@@ -97,12 +97,12 @@ async def activity(ctx: Context, s: str) -> str:
9797
return s
9898

9999
log = MemoryLogStorage()
100-
async with activity(log) as t:
100+
async with activity.create_task(log) as t:
101101
await t.start("hello")
102102
with pytest.raises(asyncio.TimeoutError):
103103
_ = await asyncio.wait_for(t.wait(), 0.1)
104104

105-
async with activity(log) as t:
105+
async with activity.create_task(log) as t:
106106
sleep = 0
107107
await t.resume()
108108
x = await t.wait()
@@ -119,7 +119,7 @@ async def activity(ctx: Context, s: str) -> str:
119119
return s
120120

121121
log = MemoryLogStorage()
122-
async with activity(log) as t:
122+
async with activity.create_task(log) as t:
123123
await t.start("hello")
124124
try:
125125
_ = await t.wait()
@@ -143,8 +143,28 @@ async def activity(ctx: Context) -> CustomPoint:
143143
return CustomPoint(x=pt.x + 5, y=pt.y + 10)
144144

145145
log = MemoryLogStorage()
146-
async with activity(log) as t:
146+
async with activity.create_task(log) as t:
147147
await t.start()
148148
a = await t.wait()
149149
assert type(a) is CustomPoint
150150
assert a.x == 6 and a.y == 12
151+
152+
153+
@pytest.mark.asyncio
154+
async def test_random():
155+
@durable()
156+
async def activity(ctx: Context) -> int:
157+
assert ctx.time_ns() == ctx.time_ns()
158+
return ctx.random().randint(1, 100)
159+
160+
log = MemoryLogStorage()
161+
async with activity.create_task(log) as t:
162+
await t.start()
163+
a = await t.wait()
164+
165+
log = MemoryLogStorage()
166+
async with activity.create_task(log) as t:
167+
await t.start()
168+
b = await t.wait()
169+
170+
assert a == b

0 commit comments

Comments
 (0)