Skip to content

Commit 53b6513

Browse files
committed
feat: remove pydantic and allow custom type codec
1 parent f3248a1 commit 53b6513

9 files changed

Lines changed: 46 additions & 244 deletions

File tree

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
with:
2828
enable-cache: true
2929
- name: Sync dependencies
30-
run: uv sync --group type-checking --group lint --extra full
30+
run: uv sync --group type-checking --group lint
3131
- name: Run tests
3232
run: uv run pytest
3333
- name: Run linter

noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
def install_deps(s: nox.Session, groups: list[str]):
1414
s.env["UV_PROJECT_ENVIRONMENT"] = s.virtualenv.location
15-
cmd = ["uv", "sync", "--frozen", "--extra", "full"]
15+
cmd = ["uv", "sync", "--frozen"]
1616
for g in groups:
1717
cmd.extend(("--group", g))
1818
_ = s.run_install(*cmd)

pyproject.toml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,6 @@ dependencies = [
88
"typing-extensions>=4.15.0",
99
]
1010

11-
[project.optional-dependencies]
12-
full = [
13-
"pydantic>=2.11.9",
14-
]
15-
1611
[build-system]
1712
requires = ["uv_build>=0.8.19,<0.9.0"]
1813
build-backend = "uv_build"
@@ -25,7 +20,6 @@ dev = [
2520
type-checking = [
2621
"basedmypy>=2.10.0",
2722
"basedpyright>=1.31.2",
28-
"pydantic>=2.11.9",
2923
]
3024
lint = [
3125
"ruff>=0.12.8",

src/duron/codec.py

Lines changed: 3 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,30 @@
11
from __future__ import annotations
22

3-
import base64
4-
import importlib
5-
import pickle
6-
from typing import TYPE_CHECKING, cast, final
3+
from typing import TYPE_CHECKING, final
74

85
from typing_extensions import Protocol, override
96

107
from duron.log import is_json_value
118

12-
BaseModel: type[_pydantic.BaseModel] | None = None
13-
try:
14-
import pydantic as _pydantic
15-
16-
BaseModel = _pydantic.BaseModel
17-
except ImportError:
18-
pass
19-
209
if TYPE_CHECKING:
2110
from duron.log import JSONValue
2211

23-
__all__ = ["Codec", "DEFAULT_CODEC"]
12+
__all__ = ["Codec", "DefaultCodec"]
2413

2514

2615
class Codec(Protocol):
2716
def encode_json(self, result: object) -> JSONValue: ...
2817
def decode_json(self, encoded: JSONValue) -> object: ...
2918

30-
def encode_state(self, obj: object) -> str: ...
31-
def decode_state(self, state: str) -> object: ...
32-
3319

3420
@final
35-
class _DefaultCodec(Codec):
36-
def __init__(self) -> None:
37-
self._type_cache: dict[str, type] = {}
38-
39-
def _lookup_model(self, qual: str) -> type:
40-
cls = self._type_cache.get(qual)
41-
if cls is not None:
42-
return cls
43-
44-
mod, name = qual.split(":", 1)
45-
obj: object = importlib.import_module(mod)
46-
for part in name.split("."):
47-
obj = getattr(obj, part)
48-
if isinstance(obj, type):
49-
self._type_cache[qual] = obj
50-
return obj
51-
else:
52-
raise TypeError(f"Imported object is not a type: {obj!r}")
53-
21+
class DefaultCodec(Codec):
5422
@override
5523
def encode_json(self, result: object) -> JSONValue:
56-
if BaseModel and isinstance(result, BaseModel):
57-
cls = result.__class__
58-
model = result.model_dump(mode="json")
59-
model["_duron.pydantic"] = f"{cls.__module__}:{cls.__qualname__}"
60-
return model
6124
if is_json_value(result):
6225
return result
6326
raise TypeError(f"Result is not JSON-serializable: {result!r}")
6427

6528
@override
6629
def decode_json(self, encoded: JSONValue) -> object:
67-
if isinstance(encoded, dict) and "_duron.pydantic" in encoded:
68-
model = self._lookup_model(cast("str", encoded["_duron.pydantic"]))
69-
if not BaseModel or not issubclass(model, BaseModel):
70-
raise TypeError(f"Decoded class is not a BaseModel subclass: {model}")
71-
return model.model_validate({
72-
k: v for k, v in encoded.items() if not k.startswith("_duron.")
73-
})
7430
return encoded
75-
76-
@override
77-
def encode_state(self, obj: object) -> str:
78-
return base64.b64encode(pickle.dumps(obj, protocol=5)).decode("ascii")
79-
80-
@override
81-
def decode_state(self, state: str) -> object:
82-
return pickle.loads(base64.b64decode(state.encode()))
83-
84-
85-
DEFAULT_CODEC: Codec = _DefaultCodec()

src/duron/fn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
TypeVar,
99
)
1010

11-
from duron.codec import DEFAULT_CODEC
11+
from duron.codec import DefaultCodec
1212

1313
if TYPE_CHECKING:
1414
from collections.abc import Callable
@@ -30,14 +30,14 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T_co:
3030

3131

3232
def durable(
33-
*, codec: Codec = DEFAULT_CODEC
33+
*, codec: Codec | None = None
3434
) -> Callable[[Callable[_P, _T_co]], DurableFn[_P, _T_co]]:
3535
"""
3636
Mark a function as durable, meaning its execution can be recorded and
3737
replayed.
3838
"""
3939

4040
def decorate(fn: Callable[_P, _T_co]) -> DurableFn[_P, _T_co]:
41-
return DurableFn(codec=codec, fn=fn)
41+
return DurableFn(codec=codec or DefaultCodec(), fn=fn)
4242

4343
return decorate

src/duron/log/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class _BaseEntry(TypedDict):
2727
class ErrorInfo(TypedDict):
2828
code: int
2929
message: str
30-
state: NotRequired[str] # opaque
3130

3231

3332
class PromiseCreateEntry(_BaseEntry):
@@ -49,7 +48,7 @@ class StreamEmitEntry(_BaseEntry):
4948
type: Literal["stream/emit"]
5049
stream_id: str
5150
value: NotRequired[JSONValue]
52-
state: NotRequired[str] # opaque
51+
state: NotRequired[JSONValue]
5352

5453

5554
class StreamCompleteEntry(_BaseEntry):

src/duron/task.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ async def handle_message(self, e: Entry) -> None:
242242
if "error" in e:
243243
self._loop.post_completion_threadsafe(
244244
id,
245-
exception=_decode_error(e["error"], self._codec),
245+
exception=_decode_error(e["error"]),
246246
)
247247
self._pending_ops.discard(id)
248248
elif "result" in e:
@@ -287,7 +287,7 @@ async def cb() -> None:
287287
result = await cast("Awaitable[object]", result)
288288
entry["result"] = self._codec.encode_json(result)
289289
except BaseException as e:
290-
entry["error"] = _encode_error(e, self._codec)
290+
entry["error"] = _encode_error(e)
291291
await self.enqueue_log(entry)
292292

293293
if self._running:
@@ -313,7 +313,7 @@ async def cb() -> None:
313313
result = await op.task
314314
entry["result"] = self._codec.encode_json(result)
315315
except BaseException as e:
316-
entry["error"] = _encode_error(e, self._codec)
316+
entry["error"] = _encode_error(e)
317317

318318
await self.enqueue_log(entry, True)
319319

@@ -343,20 +343,12 @@ def _decode_timestamp(ts: int) -> int:
343343
return ts * 1_000
344344

345345

346-
def _encode_error(error: BaseException, codec: Codec) -> ErrorInfo:
347-
"""Convert exception to ErrorInfo dict."""
346+
def _encode_error(error: BaseException) -> ErrorInfo:
348347
return {
349348
"code": -1,
350349
"message": str(error),
351-
"state": codec.encode_state(error),
352350
}
353351

354352

355-
def _decode_error(error_info: ErrorInfo, codec: Codec) -> BaseException:
356-
"""Convert ErrorInfo dict to exception."""
357-
try:
358-
if "state" not in error_info:
359-
return Exception(f"[{error_info['code']}] {error_info['message']}")
360-
return cast("BaseException", codec.decode_state(error_info["state"]))
361-
except Exception:
362-
return Exception(f"[{error_info['code']}] {error_info['message']}")
353+
def _decode_error(error_info: ErrorInfo) -> BaseException:
354+
return Exception(f"[{error_info['code']}] {error_info['message']}")

tests/test_task.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
from __future__ import annotations
2+
13
import asyncio
4+
import base64
5+
import pickle
26
import random
37
import uuid
4-
from typing import ClassVar
8+
from dataclasses import dataclass
59

610
import pytest
711

@@ -104,31 +108,33 @@ async def activity(s: str) -> str:
104108
assert x == "hello"
105109

106110

107-
try:
108-
import pydantic
111+
@dataclass
112+
class CustomPoint:
113+
x: int
114+
y: int
109115

110-
class Point(pydantic.BaseModel):
111-
model_config: ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(
112-
extra="forbid"
113-
)
114116

115-
x: int
116-
y: int
117+
class PickleCodec:
118+
def encode_json(self, result: object) -> str:
119+
return base64.b64encode(pickle.dumps(result)).decode()
117120

118-
@pytest.mark.asyncio
119-
async def test_pydantic():
120-
@durable()
121-
async def activity() -> Point:
122-
ctx = get_context()
123-
pt = await ctx.run(lambda: Point(x=1, y=2))
124-
return Point(x=pt.x + 5, y=pt.y + 10)
121+
def decode_json(self, encoded: object) -> object:
122+
if not isinstance(encoded, str):
123+
raise TypeError(f"Expected a string, got {type(encoded).__name__}")
124+
return pickle.loads(base64.b64decode(encoded.encode()))
125125

126-
log = MemoryLogStorage()
127-
async with task(activity, log) as t:
128-
await t.start()
129-
a = await t.wait()
130-
assert type(a) is Point
131-
assert a.x == 6 and a.y == 12
132126

133-
except ImportError:
134-
pass
127+
@pytest.mark.asyncio
128+
async def test_serialize():
129+
@durable(codec=PickleCodec())
130+
async def activity() -> CustomPoint:
131+
ctx = get_context()
132+
pt = await ctx.run(lambda: CustomPoint(x=1, y=2))
133+
return CustomPoint(x=pt.x + 5, y=pt.y + 10)
134+
135+
log = MemoryLogStorage()
136+
async with task(activity, log) as t:
137+
await t.start()
138+
a = await t.wait()
139+
assert type(a) is CustomPoint
140+
assert a.x == 6 and a.y == 12

0 commit comments

Comments
 (0)