Skip to content

Commit 9507535

Browse files
committed
feat: add codec for log
1 parent 1e7258e commit 9507535

9 files changed

Lines changed: 292 additions & 123 deletions

File tree

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
with:
2828
enable-cache: true
2929
- name: Sync dependencies
30-
run: uv sync --group type-checking --group lint
30+
run: uv sync --group type-checking --group lint --extra full
3131
- name: Run tests
3232
run: uv run pytest
3333
- name: Run linter

noxfile.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# type: ignore
2+
# pyright: basic, reportMissingImports=false
3+
14
from __future__ import annotations
25

36
import nox
@@ -9,7 +12,7 @@
912

1013
def install_deps(s: nox.Session, groups: list[str]):
1114
s.env["UV_PROJECT_ENVIRONMENT"] = s.virtualenv.location
12-
cmd = ["uv", "sync", "--frozen"]
15+
cmd = ["uv", "sync", "--frozen", "--extra", "full"]
1316
for g in groups:
1417
cmd.extend(("--group", g))
1518
_ = s.run_install(*cmd)

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ dependencies = [
88
"typing-extensions>=4.15.0",
99
]
1010

11+
[project.optional-dependencies]
12+
full = [
13+
"pydantic>=2.11.9",
14+
]
15+
1116
[build-system]
1217
requires = ["uv_build>=0.8.19,<0.9.0"]
1318
build-backend = "uv_build"
@@ -18,9 +23,10 @@ dev = [
1823
"pytest-asyncio>=0.21",
1924
]
2025
type-checking = [
26+
{include-group = "dev"},
2127
"basedmypy>=2.10.0",
2228
"basedpyright>=1.31.2",
23-
"nox>=2025.5.1",
29+
"pydantic>=2.11.9",
2430
]
2531
lint = [
2632
"ruff>=0.12.8",

src/duron/log/codec.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
import importlib
5+
import pickle
6+
from typing import TYPE_CHECKING, cast, final
7+
8+
from typing_extensions import Protocol
9+
10+
if TYPE_CHECKING:
11+
from duron.log.entry import JSONValue
12+
13+
BaseModel: type[_pydantic.BaseModel] | None = None
14+
try:
15+
import pydantic as _pydantic
16+
17+
BaseModel = _pydantic.BaseModel
18+
except ImportError:
19+
pass
20+
21+
22+
class BaseCodec(Protocol):
23+
def encode_json(self, result: object) -> JSONValue: ...
24+
def decode_json(self, encoded: JSONValue) -> object: ...
25+
26+
def encode_state(self, obj: object) -> str: ...
27+
def decode_state(self, state: str) -> object: ...
28+
29+
30+
@final
class DefaultCodec:
    """Default codec for values and opaque state written to the log.

    JSON values pass through unchanged, except pydantic models (when pydantic
    is importable): those are dumped to a JSON object carrying an extra
    ``"_duron.pydantic"`` key that records ``module:qualname`` of the model
    class, so ``decode_json`` can rebuild the instance. Opaque state is
    pickled and base64-encoded to an ASCII string.
    """

    # Key added to a dumped model to record which class to rebuild on decode.
    _PYDANTIC_TAG = "_duron.pydantic"

    def __init__(self) -> None:
        # Already-resolved tags, so each class is imported/looked up once.
        self._type_cache: dict[str, type] = {}

    def _lookup_model(self, qual: str) -> type:
        """Resolve a ``"module:qualname"`` tag to the class it names.

        Raises:
            ValueError: if ``qual`` is not of the form ``module:qualname``.
            TypeError: if the resolved object is not a class.
        """
        cached = self._type_cache.get(qual)
        if cached is not None:
            return cached

        mod, sep, name = qual.partition(":")
        if not (mod and sep and name):
            raise ValueError(
                f"Invalid type tag (expected 'module:qualname'): {qual!r}"
            )

        # NOTE(security): this imports whichever module the tag names, so the
        # log being decoded must come from a trusted source.
        obj: object = importlib.import_module(mod)
        # A qualname may be dotted (nested classes); walk it attribute by
        # attribute. Classes defined inside functions ("<locals>" in the
        # qualname) cannot be resolved this way and raise AttributeError.
        for part in name.split("."):
            obj = getattr(obj, part)
        if not isinstance(obj, type):
            raise TypeError(f"Imported object is not a type: {obj!r}")

        self._type_cache[qual] = obj
        return obj

    def encode_json(self, result: object) -> JSONValue:
        """Return the JSON form of ``result``.

        Pydantic models are dumped in JSON mode and tagged with their class
        path; any other value is returned unchanged and is assumed by the
        caller to already be JSON-compatible.
        """
        if BaseModel is not None and isinstance(result, BaseModel):
            cls = type(result)
            model = result.model_dump(mode="json")
            model[self._PYDANTIC_TAG] = f"{cls.__module__}:{cls.__qualname__}"
            return model
        return cast("JSONValue", result)

    def decode_json(self, encoded: JSONValue) -> object:
        """Inverse of ``encode_json``.

        Raises:
            TypeError: if the tag is not a string, or names something that is
                not a pydantic ``BaseModel`` subclass (including the case
                where pydantic is not installed).
            ValueError: if the tag is not of the form ``module:qualname``.
        """
        if not isinstance(encoded, dict) or self._PYDANTIC_TAG not in encoded:
            return encoded

        tag = encoded[self._PYDANTIC_TAG]
        if not isinstance(tag, str):
            raise TypeError(f"Invalid type tag (expected a string): {tag!r}")

        model = self._lookup_model(tag)
        if BaseModel is None or not issubclass(model, BaseModel):
            raise TypeError(f"Decoded class is not a BaseModel subclass: {model}")
        # Drop every codec-private key before validation so models that
        # forbid extra fields still accept the payload.
        fields = {k: v for k, v in encoded.items() if not k.startswith("_duron.")}
        return model.model_validate(fields)

    def encode_state(self, obj: object) -> str:
        """Pickle ``obj`` (protocol 5) and return it as base64 ASCII text."""
        return base64.b64encode(pickle.dumps(obj, protocol=5)).decode("ascii")

    def decode_state(self, state: str) -> object:
        """Inverse of ``encode_state``.

        NOTE(security): unpickling can execute arbitrary code; only decode
        state that originates from a trusted log.
        """
        return pickle.loads(base64.b64decode(state.encode()))

src/duron/log/entry.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing_extensions import NotRequired, TypedDict
44

5+
JSONValue = dict[str, "JSONValue"] | list["JSONValue"] | str | int | float | bool | None
6+
57

68
class BaseEntry(TypedDict):
79
id: str
@@ -13,7 +15,7 @@ class BaseEntry(TypedDict):
1315
class ErrorInfo(TypedDict):
1416
code: int
1517
message: str
16-
data: NotRequired[object]
18+
state: NotRequired[str] # opaque
1719

1820

1921
class PromiseCreateEntry(BaseEntry):
@@ -23,7 +25,7 @@ class PromiseCreateEntry(BaseEntry):
2325
class PromiseCompleteEntry(BaseEntry):
2426
type: Literal["promise/complete"]
2527
promise_id: str
26-
result: NotRequired[object]
28+
result: NotRequired[JSONValue]
2729
error: NotRequired[ErrorInfo]
2830

2931

@@ -34,8 +36,8 @@ class StreamCreateEntry(BaseEntry):
3436
class StreamEmitEntry(BaseEntry):
3537
type: Literal["stream/emit"]
3638
stream_id: str
37-
value: NotRequired[object]
38-
state: NotRequired[object]
39+
value: NotRequired[JSONValue]
40+
state: NotRequired[str] # opaque
3941

4042

4143
class StreamCloseEntry(BaseEntry):

src/duron/py.typed

Whitespace-only changes.

src/duron/task_runner.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
)
1515

1616
from duron.event_loop import create_loop
17+
from duron.log.codec import DefaultCodec
1718
from duron.log.entry import is_entry
1819
from duron.ops import FnCall
1920

2021
if TYPE_CHECKING:
2122
from collections.abc import AsyncGenerator, Callable
2223

2324
from duron.event_loop import WaitSet
25+
from duron.log.codec import BaseCodec
2426
from duron.log.entry import Entry, ErrorInfo, UnknownEntry
2527
from duron.log.storage import BaseLogStorage, Lease, Offset
2628
from duron.mark import DurableFn
@@ -30,27 +32,32 @@
3032

3133

3234
class TaskRunner:
    """Entry point that runs a durable function against a log storage.

    Args:
        codec: Codec handed to every task this runner starts. When omitted,
            a ``DefaultCodec`` is created.
    """

    def __init__(self, codec: BaseCodec | None = None) -> None:
        # Test against None rather than truthiness: a caller-supplied codec
        # that happens to evaluate falsy must not be silently replaced.
        self._codec: BaseCodec = DefaultCodec() if codec is None else codec

    async def run(
        self,
        task_id: bytes,
        task_co: DurableFn[[], Coroutine[Any, Any, _T]],
        log: BaseLogStorage,
    ) -> _T:
        """Call ``task_co``, wrap the coroutine in a ``_Task`` and await it."""
        return await _Task[_T](task_id, task_co(), log, self._codec).run()
4345

4446

4547
@final
4648
class _Task(Generic[_T]):
4749
def __init__(
48-
self, id: bytes, task_co: Coroutine[Any, Any, _T], log: BaseLogStorage
50+
self,
51+
id: bytes,
52+
task_co: Coroutine[Any, Any, _T],
53+
log: BaseLogStorage,
54+
codec: BaseCodec,
4955
) -> None:
5056
self._id = id
5157
self._loop = create_loop(id)
5258
self._task = self._loop.create_task(task_co)
5359
self._log = log
60+
self._codec = codec
5461
self._running: Lease | None = None
5562
self._pending_msg: list[Entry] = []
5663
self._pending_task: dict[str, Callable[[], Coroutine[Any, Any, object]]] = {}
@@ -110,7 +117,7 @@ async def run(self) -> _T:
110117
"id": _encode_id(self._id, True),
111118
"ts": _encode_timestamp(self.now()),
112119
"promise_id": _encode_id(self._id, False),
113-
"result": res,
120+
"result": self._codec.encode_json(res),
114121
},
115122
)
116123
return res
@@ -121,7 +128,7 @@ async def run(self) -> _T:
121128
"id": _encode_id(self._id, True),
122129
"ts": _encode_timestamp(self.now()),
123130
"promise_id": _encode_id(self._id, False),
124-
"error": exception_to_error_info(e),
131+
"error": _encode_error(e, self._codec),
125132
},
126133
)
127134
raise
@@ -157,13 +164,13 @@ async def handle_message(self, e: Entry) -> None:
157164
if "error" in e:
158165
self._loop.post_completion_threadsafe(
159166
id,
160-
exception=error_info_to_exception(e["error"]),
167+
exception=_decode_error(e["error"], self._codec),
161168
)
162169
self._pending_ops.discard(id)
163170
elif "result" in e:
164171
self._loop.post_completion_threadsafe(
165172
id,
166-
result=e["result"],
173+
result=self._codec.decode_json(e["result"]),
167174
)
168175
self._pending_ops.discard(id)
169176
else:
@@ -200,17 +207,17 @@ async def cb() -> None:
200207
"id": _encode_id(id, True),
201208
"type": "promise/complete",
202209
"promise_id": _encode_id(id, False),
203-
"result": result,
210+
"result": self._codec.encode_json(result),
204211
}
205212
)
206213
except BaseException as e:
207214
await self.enqueue_log(
208215
{
216+
"type": "promise/complete",
209217
"ts": _encode_timestamp(self.now()),
210218
"id": _encode_id(id, True),
211-
"type": "promise/complete",
212219
"promise_id": _encode_id(id, False),
213-
"error": exception_to_error_info(e),
220+
"error": _encode_error(e, self._codec),
214221
}
215222
)
216223

@@ -223,20 +230,20 @@ async def cb() -> None:
223230
raise NotImplementedError(f"Unsupported op: {op!r}")
224231

225232

226-
def _encode_id(id: bytes, end: bool) -> str:
227-
if end:
233+
def _encode_id(id: bytes, is_end: bool) -> str:
    """Base64-encode ``id`` and append ``"-"`` if ``is_end`` else ``"+"``."""
    marker = "-" if is_end else "+"
    return base64.b64encode(id).decode() + marker
231238

232239

233-
def _decode_id(s: str) -> tuple[bytes, bool]:
234-
if s.endswith("-"):
235-
return base64.b64decode(s[:-1]), True
236-
elif s.endswith("+"):
237-
return base64.b64decode(s[:-1]), False
240+
def _decode_id(encoded: str) -> tuple[bytes, bool]:
    """Split an encoded id into its raw bytes and its end flag.

    The last character is the marker: ``"-"`` yields ``True``, ``"+"``
    yields ``False``. Everything before it is base64-decoded.

    Raises:
        ValueError: if ``encoded`` does not end with ``"-"`` or ``"+"``.
    """
    payload, marker = encoded[:-1], encoded[-1:]
    if marker not in ("-", "+"):
        raise ValueError(f"Invalid encoded id: {encoded!r}")
    return base64.b64decode(payload), marker == "-"
240247

241248

242249
def _encode_timestamp(ts_ns: int) -> int:
@@ -247,12 +254,20 @@ def _decode_timestamp(ts: int) -> int:
247254
return ts * 1_000
248255

249256

250-
def exception_to_error_info(e: BaseException) -> ErrorInfo:
257+
def _encode_error(error: BaseException, codec: BaseCodec) -> ErrorInfo:
    """Convert exception to ErrorInfo dict.

    ``code`` is always ``-1`` and ``message`` is ``str(error)``. ``state``
    holds ``codec.encode_state(error)`` when that succeeds; if the codec
    cannot encode the exception, ``state`` is omitted.
    """
    info: ErrorInfo = {
        "code": -1,
        "message": str(error),
    }
    try:
        info["state"] = codec.encode_state(error)
    except Exception:
        # Deliberate best effort: this runs while another error is being
        # recorded, so an encode failure (e.g. an unpicklable exception)
        # must not replace it. "state" is optional in ErrorInfo.
        pass
    return info
255264

256265

257-
def error_info_to_exception(e: ErrorInfo) -> Exception:
258-
return Exception(f"[{e['code']}] {e['message']}")
266+
def _decode_error(error_info: ErrorInfo, codec: BaseCodec) -> BaseException:
    """Convert ErrorInfo dict to exception.

    Returns the object produced by ``codec.decode_state`` when ``state`` is
    present and decodes to a ``BaseException`` instance. Otherwise (no
    state, a decode failure, or a decoded object that is not an exception)
    returns a generic ``Exception`` whose text is ``"[code] message"``.
    """
    state = error_info.get("state")
    if state is not None:
        try:
            decoded = codec.decode_state(state)
        except Exception:
            # Unreadable state: fall through to the generic exception below.
            decoded = None
        # The codec returns ``object``; hand back only real exceptions so the
        # caller never receives a non-exception as an exception.
        if isinstance(decoded, BaseException):
            return decoded
    return Exception(f"[{error_info['code']}] {error_info['message']}")

tests/test_task.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import uuid
3+
from typing import ClassVar
34

45
import pytest
56

@@ -10,7 +11,7 @@
1011

1112

1213
@pytest.mark.asyncio
13-
async def test_task_runner():
14+
async def test_task():
1415
@durable()
1516
async def activity() -> str:
1617
ctx = duron.context.Context()
@@ -50,3 +51,31 @@ async def error():
5051
_ = await tr.run(b"1", activity, log)
5152
with pytest.raises(check=lambda v: "test error" in str(v)):
5253
_ = await tr.run(b"1", activity, log)
54+
55+
56+
try:
57+
import pydantic
58+
59+
class Point(pydantic.BaseModel):
60+
model_config: ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(
61+
extra="forbid"
62+
)
63+
64+
x: int
65+
y: int
66+
67+
@pytest.mark.asyncio
68+
async def test_pydantic():
69+
@durable()
70+
async def activity() -> Point:
71+
ctx = duron.context.Context()
72+
pt = await ctx.run(lambda: Point(x=1, y=2))
73+
return Point(x=pt.x + 5, y=pt.y + 10)
74+
75+
tr = duron.task_runner.TaskRunner()
76+
log = MemoryLogStorage()
77+
a = await tr.run(b"1", activity, log)
78+
assert type(a) is Point
79+
assert a.x == 6 and a.y == 12
80+
except ImportError:
81+
pass

0 commit comments

Comments
 (0)