Skip to content

Commit 4867090

Browse files
committed
feat: each invoke is seeded differently
1 parent 60ab26b commit 4867090

4 files changed

Lines changed: 33 additions & 38 deletions

File tree

src/duron/_core/invoke.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import base64
5+
import binascii
56
import contextlib
67
import os
78
import time
@@ -97,6 +98,7 @@ def get_init() -> InitParams:
9798
"version": _CURRENT_VERSION,
9899
"args": [codec.encode_json(arg) for arg in args],
99100
"kwargs": {k: codec.encode_json(v) for k, v in kwargs.items()},
101+
"seed": binascii.b2a_base64(os.urandom(12), newline=False).decode(),
100102
}
101103

102104
codec = self._fn.codec
@@ -237,6 +239,7 @@ class InitParams(TypedDict):
237239
version: int
238240
args: list[JSONValue]
239241
kwargs: dict[str, JSONValue]
242+
seed: str
240243

241244

242245
async def _invoke_prelude(
@@ -252,6 +255,7 @@ async def _invoke_prelude(
252255
if init_params["version"] != _CURRENT_VERSION:
253256
msg = "version mismatch"
254257
raise RuntimeError(msg)
258+
loop.set_key(binascii.a2b_base64(init_params["seed"]))
255259
extra_kwargs: dict[str, object] = {}
256260
for name, type_, dtype in job_fn.inject:
257261
if type_ is Stream:
@@ -320,7 +324,7 @@ def __init__(
320324
]
321325
| None = None,
322326
) -> None:
323-
self._loop = create_loop(asyncio.get_running_loop(), b"")
327+
self._loop = create_loop(asyncio.get_running_loop())
324328
self._task = self._loop.create_task(task)
325329
self._log = log
326330
self._codec = codec

src/duron/_loop.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,24 +92,28 @@ class EventLoop(asyncio.AbstractEventLoop):
9292
"_timers",
9393
)
9494

95-
def __init__(self, host: asyncio.AbstractEventLoop, seed: bytes) -> None:
95+
def __init__(self, host: asyncio.AbstractEventLoop) -> None:
9696
self._ready: deque[asyncio.Handle] = deque()
9797
self._debug: bool = False
9898
self._host: asyncio.AbstractEventLoop = host
9999
self._exc_handler: (
100100
Callable[[asyncio.AbstractEventLoop, dict[str, object]], object] | None
101101
) = None
102102
self._ops: dict[bytes, OpFuture[object]] = {}
103-
self._ctx: _TaskCtx = _TaskCtx(parent_id=seed)
103+
self._ctx: _TaskCtx = _TaskCtx(parent_id=b"")
104+
self._key: bytes = b""
104105
self._now_us: int = 0
105106
self._closed: bool = False
106107
self._event: asyncio.Event = asyncio.Event() # loop = _host
107108
self._timers: list[asyncio.TimerHandle] = []
108109

110+
def set_key(self, key: bytes) -> None:
111+
self._key = key
112+
109113
def generate_op_id(self) -> bytes:
110114
ctx = _task_ctx.get(self._ctx)
111115
ctx.seq += 1
112-
return _mix_id(ctx.parent_id, ctx.seq - 1)
116+
return _mix_id(ctx.parent_id, self._key, ctx.seq - 1)
113117

114118
def host_loop(self) -> asyncio.AbstractEventLoop:
115119
return self._host
@@ -284,7 +288,7 @@ def post_completion(
284288
if op := self._ops.pop(id_, None):
285289
if op.done():
286290
return
287-
tid = _mix_id(op.id, -1)
291+
tid = _mix_id(op.id, self._key, -1)
288292
token = _task_ctx.set(_TaskCtx(parent_id=tid))
289293
if exception is None:
290294
_ = self.call_soon(op.set_result, result)
@@ -352,17 +356,16 @@ def _timer_handle_cancelled(self, _th: asyncio.TimerHandle) -> None:
352356
pass
353357

354358

355-
def _mix_id(a: bytes, b: int) -> bytes:
359+
def _mix_id(a: bytes, key: bytes, b: int) -> bytes:
356360
if b == -1:
357-
return blake2b(a, digest_size=12).digest()
358-
return blake2b(b.to_bytes(4, "little") + a, digest_size=12).digest()
361+
return blake2b(a, key=key, digest_size=12).digest()
362+
return blake2b(b.to_bytes(4, "little") + a, key=key, digest_size=12).digest()
359363

360364

361365
def create_loop(
362366
parent_loop: asyncio.AbstractEventLoop,
363-
seed: bytes,
364367
) -> EventLoop:
365-
return EventLoop(parent_loop, seed) # type: ignore[abstract]
368+
return EventLoop(parent_loop) # type: ignore[abstract]
366369

367370

368371
def _copy_future_state(source: asyncio.Future[_T], dest: asyncio.Future[_T]) -> None:

tests/test_invoke.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,10 @@ async def activity(ctx: Context, i: str) -> str:
2929
_ = await ctx.run(lambda: asyncio.sleep(0.1))
3030
return i + ":".join(x)
3131

32-
ids = {
33-
"+qPYuDgKBdMkb8ME",
34-
"20QZakraLA2aFVh0",
35-
"9nLMU+itD7QHcCsf",
36-
"BCLA1azFK4LrrEHg",
37-
"D7qSBNZIThKa2P+H",
38-
"d0eMKz6qFm4C+TYX",
39-
"tlmy+UdKzdxIIykQ",
40-
"vP2AH9GHFTbZRJQ7",
41-
}
42-
4332
log = MemoryLogStorage()
4433
async with activity.invoke(log) as t:
4534
await t.start("test")
4635
a = await t.wait()
47-
assert {e["id"] for e in await log.entries()} == ids
4836

4937
async with activity.invoke(log) as t:
5038
await t.start("test")
@@ -56,7 +44,6 @@ async def activity(ctx: Context, i: str) -> str:
5644
await t.start("test")
5745
c = await t.wait()
5846
assert a == c
59-
assert {e["id"] for e in await log2.entries()} == ids
6047

6148

6249
@pytest.mark.asyncio
@@ -162,7 +149,6 @@ async def activity(ctx: Context) -> int:
162149
await t.start()
163150
a = await t.wait()
164151

165-
log = MemoryLogStorage()
166152
async with activity.invoke(log) as t:
167153
await t.start()
168154
b = await t.wait()
@@ -172,10 +158,12 @@ async def activity(ctx: Context) -> int:
172158

173159
@pytest.mark.asyncio
174160
async def test_external_promise() -> None:
161+
v: dict[str, str] = {}
162+
175163
@fn
176164
async def activity(ctx: Context) -> int:
177165
a, b = await ctx.create_promise(int)
178-
assert a == "9mcIBsvU2ej9uDsV"
166+
v["data"] = a
179167
return await b
180168

181169
log = MemoryLogStorage()
@@ -184,11 +172,11 @@ async def activity(ctx: Context) -> int:
184172

185173
async def do() -> None:
186174
while True:
187-
try:
188-
await t.complete_promise("9mcIBsvU2ej9uDsV", result=9)
189-
break
190-
except ValueError:
191-
await asyncio.sleep(0.1)
175+
if v.get("data") is None:
176+
await asyncio.sleep(0.01)
177+
continue
178+
await t.complete_promise(v["data"], result=9)
179+
break
192180

193181
bg = asyncio.create_task(do())
194182
assert await t.wait() == 9

tests/test_loop.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ async def timer() -> int:
2020
)
2121
return 0
2222

23-
loop = create_loop(asyncio.get_event_loop(), b"")
23+
loop = create_loop(asyncio.get_event_loop())
2424
loop.tick(time.time_ns())
2525
tsk = loop.create_task(timer())
2626
while (waitset := loop.poll_completion(tsk)) is not None:
@@ -45,7 +45,7 @@ async def op() -> None:
4545
_ = await loop.create_op(6)
4646

4747
ids: set[bytes] = set()
48-
loop = create_loop(asyncio.get_event_loop(), b"tsk")
48+
loop = create_loop(asyncio.get_event_loop())
4949
loop.tick(time.time_ns())
5050
tsk = loop.create_task(op())
5151

@@ -77,12 +77,12 @@ def tick(n: int, expect: set[int] | None) -> list[bytes] | None:
7777
@pytest.mark.asyncio
7878
async def test_op() -> None: # noqa: RUF029
7979
baseline = {
80-
b"\x89U\x82\xd9\xe9\xa1\x01\x0fb\xab}\xba",
81-
b"\xa2VE\xcb\xf3\x81\x82\xe75@\xe9\xdf",
82-
b"\xb2\x1d\xb3\xd5c\xe3J\no\xa6U\x18",
83-
b"\xe7y\tDQ-\xe8\xfb\x9dBX\xde",
84-
b"\t\xf9?\xf8kJ\xd4\ry8\xf2\xae",
85-
b"\xc6\xbbEu\xdf\xd1uc\xf5M\x11'",
80+
b'\x04"\xc0\xd5\xac\xc5+\x82\xeb\xacA\xe0',
81+
b"\x07\x04%\x7f\xbf\xc3x*\x89}Jb",
82+
b"\x8f\x1f\x1dMR\x19\xe7\xbf\xa2D\xbe\x7f",
83+
b'\xb4"\xd9\x1a\x89\x896\xcb.\x9b#P',
84+
b"\xf6g\x08\x06\xcb\xd4\xd9\xe8\xfd\xb8;\x15",
85+
b"\xfa\xa3\xd8\xb88\n\x05\xd3$o\xc3\x04",
8686
}
8787
for _ in range(4):
8888
assert op_single() == baseline

0 commit comments

Comments
 (0)