Skip to content

Commit 2f7deed

Browse files
committed
feat: move id encoding to log
1 parent 01e3109 commit 2f7deed

6 files changed

Lines changed: 92 additions & 66 deletions

File tree

src/duron/_core/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import binascii
54
from contextlib import contextmanager
65
from contextvars import ContextVar
76
from random import Random
@@ -19,6 +18,7 @@
1918
from duron._core.signal import create_signal
2019
from duron._core.stream import create_stream, run_stream
2120
from duron._decorator.op import CheckpointOp, Op
21+
from duron.log import encode_id
2222
from duron.typing import inspect_function
2323

2424
if TYPE_CHECKING:
@@ -185,7 +185,7 @@ async def create_promise(
185185
ExternalPromiseCreate(metadata=self._get_metadata(None), return_type=dtype),
186186
)
187187
return (
188-
binascii.b2a_base64(fut.id, newline=False).decode(),
188+
encode_id(fut.id),
189189
cast("asyncio.Future[_T]", fut),
190190
)
191191

src/duron/_core/invoke.py

Lines changed: 25 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import binascii
54
import contextlib
6-
import os
75
import time
8-
from hashlib import blake2b
96
from typing import TYPE_CHECKING, Generic, Literal, cast
107
from typing_extensions import (
118
Any,
@@ -31,7 +28,7 @@
3128
from duron._core.stream import ObserverStream, Stream, StreamWriter
3229
from duron._loop import EventLoop, create_loop
3330
from duron.codec import Codec, JSONValue
34-
from duron.log import is_entry
31+
from duron.log import decode_id, derive_id, encode_id, is_entry, random_id
3532
from duron.typing import Unspecified, inspect_function
3633

3734
if TYPE_CHECKING:
@@ -257,7 +254,7 @@ async def _invoke_prelude(
257254
if init_params["version"] != _CURRENT_VERSION:
258255
msg = "version mismatch"
259256
raise RuntimeError(msg)
260-
loop.set_key(_decode_id(init_params["seed"]))
257+
loop.set_key(decode_id(init_params["seed"]))
261258
extra_kwargs: dict[str, object] = {}
262259
for name, type_, dtype in job_fn.inject:
263260
with ctx.metadata({"param.name": name}):
@@ -423,7 +420,7 @@ async def handle_message(
423420
pending_info = self._pending_task.pop(e["promise_id"], None)
424421
task_info = self._tasks.get(e["promise_id"], None)
425422

426-
id_ = _decode_id(e["promise_id"])
423+
id_ = decode_id(e["promise_id"])
427424

428425
return_type: TypeHint[Any] = Unspecified
429426
if pending_info is not None:
@@ -454,7 +451,7 @@ async def handle_message(
454451
msg = f"Invalid promise/complete entry: {e!r}"
455452
raise ValueError(msg)
456453
elif e["type"] == "stream/create":
457-
id_ = _decode_id(e["id"])
454+
id_ = decode_id(e["id"])
458455
if e["id"] not in self._streams:
459456
self._loop.post_completion(
460457
id_, exception=ValueError("Stream not found")
@@ -463,7 +460,7 @@ async def handle_message(
463460
self._loop.post_completion(id_, result=e["id"])
464461
self._pending_ops.discard(id_)
465462
elif e["type"] == "stream/emit":
466-
id_ = _decode_id(e["id"])
463+
id_ = decode_id(e["id"])
467464
if e["stream_id"] not in self._streams:
468465
self._loop.post_completion(
469466
id_, exception=ValueError("Stream not found")
@@ -478,7 +475,7 @@ async def handle_message(
478475
self._loop.post_completion(id_, result=None)
479476
self._pending_ops.discard(id_)
480477
elif e["type"] == "stream/complete":
481-
id_ = _decode_id(e["id"])
478+
id_ = decode_id(e["id"])
482479
if e["stream_id"] not in self._streams:
483480
self._loop.post_completion(
484481
id_, exception=ValueError("Stream not found")
@@ -495,7 +492,7 @@ async def handle_message(
495492
_ = self._streams.pop(e["stream_id"], None)
496493
self._pending_ops.discard(id_)
497494
elif e["type"] == "barrier":
498-
id_ = _decode_id(e["id"])
495+
id_ = decode_id(e["id"])
499496
self._loop.post_completion(id_, result=offset)
500497
self._pending_ops.discard(id_)
501498

@@ -524,7 +521,7 @@ async def enqueue_op(self, id_: bytes, fut: OpFuture[object]) -> None:
524521
case FnCall():
525522
promise_create_entry: PromiseCreateEntry = {
526523
"ts": self.now(),
527-
"id": _encode_id(id_),
524+
"id": encode_id(id_),
528525
"type": "promise/create",
529526
}
530527
if op.metadata:
@@ -538,11 +535,12 @@ async def enqueue_op(self, id_: bytes, fut: OpFuture[object]) -> None:
538535
await self.enqueue_log(promise_create_entry)
539536

540537
async def cb() -> None:
538+
now_us = self.now()
541539
entry: PromiseCompleteEntry = {
542-
"ts": self.now(),
543-
"id": _encode_id(id_, ack=True),
540+
"ts": now_us,
541+
"id": encode_id(derive_id(id_)),
544542
"type": "promise/complete",
545-
"promise_id": _encode_id(id_),
543+
"promise_id": encode_id(id_),
546544
}
547545
try:
548546
result = op.callable(*op.args, **op.kwargs)
@@ -558,7 +556,7 @@ async def cb() -> None:
558556

559557
def done(f: OpFuture[object]) -> None:
560558
if f.cancelled():
561-
sid = _encode_id(f.id)
559+
sid = encode_id(f.id)
562560
if self._pending_task.get(sid, None):
563561
# pending task cancelled
564562
pass
@@ -568,14 +566,14 @@ def done(f: OpFuture[object]) -> None:
568566
_ = task.get_loop().call_soon(task.cancel)
569567

570568
fut.add_done_callback(done)
571-
sid = _encode_id(id_)
569+
sid = encode_id(id_)
572570
if self._running:
573571
self._tasks[sid] = (asyncio.create_task(cb()), op.return_type)
574572
else:
575573
self._pending_task[sid] = (cb, op.return_type)
576574

577575
case StreamCreate():
578-
stream_id = _encode_id(id_)
576+
stream_id = encode_id(id_)
579577

580578
# Determine which observer to use
581579
ob = [op.observer] if op.observer else []
@@ -599,7 +597,7 @@ def done(f: OpFuture[object]) -> None:
599597
case StreamEmit():
600598
stream_emit_entry: StreamEmitEntry = {
601599
"ts": self.now(),
602-
"id": _encode_id(id_),
600+
"id": encode_id(id_),
603601
"stream_id": op.stream_id,
604602
"type": "stream/emit",
605603
"value": self._codec.encode_json(op.value),
@@ -609,7 +607,7 @@ def done(f: OpFuture[object]) -> None:
609607
if op.exception:
610608
stream_close_entry_err: StreamCompleteEntry = {
611609
"ts": self.now(),
612-
"id": _encode_id(id_),
610+
"id": encode_id(id_),
613611
"stream_id": op.stream_id,
614612
"type": "stream/complete",
615613
"error": _encode_error(op.exception),
@@ -618,27 +616,27 @@ def done(f: OpFuture[object]) -> None:
618616
else:
619617
stream_close_entry: StreamCompleteEntry = {
620618
"ts": self.now(),
621-
"id": _encode_id(id_),
619+
"id": encode_id(id_),
622620
"stream_id": op.stream_id,
623621
"type": "stream/complete",
624622
}
625623
await self.enqueue_log(stream_close_entry)
626624
case Barrier():
627625
barrier_entry: BarrierEntry = {
628626
"ts": self.now(),
629-
"id": _encode_id(id_),
627+
"id": encode_id(id_),
630628
"type": "barrier",
631629
}
632630
await self.enqueue_log(barrier_entry, flush=True)
633631
case ExternalPromiseCreate():
634632
promise_create_entry = {
635633
"ts": self.now(),
636-
"id": _encode_id(id_),
634+
"id": encode_id(id_),
637635
"type": "promise/create",
638636
}
639637
if op.metadata:
640638
promise_create_entry["metadata"] = op.metadata
641-
self._tasks[_encode_id(id_)] = (asyncio.Future(), op.return_type)
639+
self._tasks[encode_id(id_)] = (asyncio.Future(), op.return_type)
642640
await self.enqueue_log(promise_create_entry)
643641
case _:
644642
assert_never(op)
@@ -653,9 +651,10 @@ async def complete_external_promise(
653651
if id_ not in self._tasks:
654652
msg = "Promise not found"
655653
raise ValueError(msg)
654+
now_us = self.now()
656655
entry: PromiseCompleteEntry = {
657-
"ts": self.now(),
658-
"id": _encode_id(_decode_id(id_), ack=True),
656+
"ts": now_us,
657+
"id": encode_id(derive_id(decode_id(id_))),
659658
"type": "promise/complete",
660659
"promise_id": id_,
661660
}
@@ -723,21 +722,8 @@ async def close_stream(
723722
return cnt
724723

725724

726-
def _encode_id(id_: bytes, *, ack: bool = False) -> str:
727-
if ack:
728-
id_ = blake2b(
729-
id_,
730-
digest_size=12,
731-
).digest()
732-
return binascii.b2a_base64(id_, newline=False).decode()
733-
734-
735-
def _decode_id(encoded: str) -> bytes:
736-
return binascii.a2b_base64(encoded)
737-
738-
739725
def _generate_id() -> str:
740-
return _encode_id(os.urandom(12))
726+
return encode_id(random_id())
741727

742728

743729
def _encode_error(error: BaseException) -> ErrorInfo:

src/duron/_loop.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
import contextlib
55
import contextvars
66
import logging
7-
import os
87
from asyncio import events, tasks
98
from collections import deque
109
from dataclasses import dataclass
11-
from hashlib import blake2b
1210
from heapq import heappop, heappush
1311
from typing import TYPE_CHECKING, Generic
1412
from typing_extensions import (
@@ -20,6 +18,8 @@
2018
override,
2119
)
2220

21+
from duron.log import derive_id, random_id
22+
2323
_T = TypeVar("_T")
2424

2525
if TYPE_CHECKING:
@@ -108,15 +108,14 @@ def __init__(self, host: asyncio.AbstractEventLoop) -> None:
108108
self._timers: list[asyncio.TimerHandle] = []
109109

110110
def set_key(self, key: bytes) -> None:
111-
self._key = key
111+
self._key = derive_id(key, key=self._key)
112112

113113
def generate_op_id(self) -> bytes:
114114
ctx = _task_ctx.get(self._ctx)
115115
ctx.seq += 1
116-
return _mix_id(ctx.parent_id, self._key, ctx.seq - 1)
117-
118-
def host_loop(self) -> asyncio.AbstractEventLoop:
119-
return self._host
116+
return derive_id(
117+
ctx.parent_id, context=(ctx.seq - 1).to_bytes(4, "big"), key=self._key
118+
)
120119

121120
@override
122121
def call_soon(
@@ -256,7 +255,7 @@ def poll_completion(self, task: Future[_T]) -> WaitSet | None:
256255

257256
def create_op(self, params: object, *, external: bool = False) -> OpFuture[object]:
258257
if external:
259-
id_ = os.urandom(12)
258+
id_ = random_id()
260259
self._event.set()
261260
else:
262261
id_ = self.generate_op_id()
@@ -288,7 +287,7 @@ def post_completion(
288287
if op := self._ops.pop(id_, None):
289288
if op.done():
290289
return
291-
tid = _mix_id(op.id, self._key, -1)
290+
tid = derive_id(op.id, key=self._key)
292291
token = _task_ctx.set(_TaskCtx(parent_id=tid))
293292
if exception is None:
294293
_ = self.call_soon(op.set_result, result)
@@ -356,12 +355,6 @@ def _timer_handle_cancelled(self, _th: asyncio.TimerHandle) -> None:
356355
pass
357356

358357

359-
def _mix_id(a: bytes, key: bytes, b: int) -> bytes:
360-
if b == -1:
361-
return blake2b(a, key=key, digest_size=12).digest()
362-
return blake2b(b.to_bytes(4, "little") + a, key=key, digest_size=12).digest()
363-
364-
365358
def create_loop(
366359
parent_loop: asyncio.AbstractEventLoop,
367360
) -> EventLoop:

src/duron/log.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from __future__ import annotations
22

3+
import binascii
4+
import os
35
from abc import ABC, abstractmethod
6+
from hashlib import blake2b
47
from typing import TYPE_CHECKING, Literal
58
from typing_extensions import NotRequired, TypedDict
69

@@ -102,3 +105,19 @@ async def append(self, token: bytes, entry: Entry, /) -> int: ...
102105

103106
@abstractmethod
104107
async def flush(self, token: bytes, /) -> None: ...
108+
109+
110+
def encode_id(raw: bytes) -> str:
111+
return binascii.b2a_base64(raw, newline=False).decode()
112+
113+
114+
def decode_id(encoded: str) -> bytes:
115+
return binascii.a2b_base64(encoded)
116+
117+
118+
def random_id() -> bytes:
119+
return os.urandom(12)
120+
121+
122+
def derive_id(base: bytes, *, context: bytes = b"", key: bytes = b"") -> bytes:
123+
return blake2b(base, salt=context, key=key, digest_size=12).digest()

tests/test_id.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import pytest
2+
3+
from duron.log import (
4+
decode_id,
5+
derive_id,
6+
encode_id,
7+
random_id,
8+
)
9+
10+
11+
def test_generates_unique_ids() -> None:
12+
ids = {random_id() for _ in range(1000)}
13+
assert len(ids) == 1000
14+
15+
16+
def test_derive_id_deterministic() -> None:
17+
base = b"test base"
18+
key = b"test key"
19+
20+
id1 = derive_id(base, key=key)
21+
id2 = derive_id(base, key=key)
22+
assert id1 == id2
23+
assert decode_id(encode_id(id1)) == id2
24+
25+
26+
@pytest.mark.benchmark
27+
def test_bench_derive_id() -> None:
28+
_ = decode_id(encode_id(derive_id(b"hello", context=b"key")))

0 commit comments

Comments
 (0)