Skip to content

Commit 6312b14

Browse files
committed
feat: add pydantic codec testing
1 parent a50bc49 commit 6312b14

7 files changed

Lines changed: 275 additions & 72 deletions

File tree

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ build-backend = "uv_build"
1414

1515
[dependency-groups]
1616
dev = [
17+
"pydantic>=2.11.9",
1718
"pytest>=7.0",
1819
"pytest-asyncio>=0.21",
1920
]
@@ -46,8 +47,12 @@ preview = true
4647
future-annotations = true
4748

4849
[tool.ruff.lint.flake8-type-checking]
50+
runtime-evaluated-base-classes = ["typing_extensions.TypedDict"]
4951
runtime-evaluated-decorators = ["duron.fn"]
5052

53+
[tool.ruff.lint.flake8-tidy-imports.banned-api]
54+
"typing.TypedDict".msg = "Use typing_extensions.TypedDict instead."
55+
5156
[tool.ruff.format]
5257
preview = true
5358

src/duron/codec.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,36 @@
11
from __future__ import annotations
22

33
import inspect
4+
from abc import ABC, abstractmethod
45
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, final
6+
from typing import TYPE_CHECKING, TypeGuard, cast, final
67

7-
from typing_extensions import Protocol, override
8-
9-
from duron.log import is_json_value
8+
from typing_extensions import TypeAliasType, override
109

1110
if TYPE_CHECKING:
1211
from collections.abc import Callable
1312

14-
from duron.log import JSONValue
13+
JSONValue = (
14+
dict[str, "JSONValue"] | list["JSONValue"] | str | int | float | bool | None
15+
)
16+
else:
17+
JSONValue = TypeAliasType(
18+
"JSONValue",
19+
"dict[str, JSONValue] | list[JSONValue] | str | int | float | bool | None",
20+
)
21+
22+
23+
def is_json_value(x: object) -> TypeGuard[JSONValue]:
24+
if x is None or isinstance(x, (bool, int, float, str)):
25+
return True
26+
if isinstance(x, list):
27+
return all(is_json_value(item) for item in cast("list[object]", x))
28+
if isinstance(x, dict):
29+
return all(
30+
isinstance(k, str) and is_json_value(v)
31+
for k, v in cast("dict[object, object]", x).items()
32+
)
33+
return False
1534

1635

1736
@dataclass(slots=True)
@@ -21,31 +40,15 @@ class FunctionType:
2140
parameter_types: dict[str, type | None]
2241

2342

24-
class Codec(Protocol):
43+
class Codec(ABC):
44+
@abstractmethod
2545
def encode_json(self, result: object, /) -> JSONValue: ...
46+
47+
@abstractmethod
2648
def decode_json(
2749
self, encoded: JSONValue, expected_type: type | None, /
2850
) -> object: ...
2951

30-
def inspect_function(
31-
self,
32-
fn: Callable[..., object],
33-
) -> FunctionType: ...
34-
35-
36-
@final
37-
class DefaultCodec(Codec):
38-
@override
39-
def encode_json(self, result: object) -> JSONValue:
40-
if is_json_value(result):
41-
return result
42-
raise TypeError(f"Result is not JSON-serializable: {result!r}")
43-
44-
@override
45-
def decode_json(self, encoded: JSONValue, _expected_type: type | None) -> object:
46-
return encoded
47-
48-
@override
4952
def inspect_function(
5053
self,
5154
fn: Callable[..., object],
@@ -76,3 +79,16 @@ def inspect_function(
7679
parameters=parameter_names,
7780
parameter_types=parameter_types,
7881
)
82+
83+
84+
@final
85+
class DefaultCodec(Codec):
86+
@override
87+
def encode_json(self, result: object) -> JSONValue:
88+
if is_json_value(result):
89+
return result
90+
raise TypeError(f"Result is not JSON-serializable: {result!r}")
91+
92+
@override
93+
def decode_json(self, encoded: JSONValue, _expected_type: type | None) -> object:
94+
return encoded

src/duron/contrib/codecs.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,21 @@
44
import pickle
55
from typing import TYPE_CHECKING
66

7-
from duron.codec import DefaultCodec
7+
from typing_extensions import override
88

9-
if TYPE_CHECKING:
10-
from collections.abc import Callable
9+
from duron.codec import Codec
1110

12-
from duron.codec import FunctionType
13-
from duron.log import JSONValue
11+
if TYPE_CHECKING:
12+
from duron.codec import JSONValue
1413

1514

16-
class PickleCodec:
15+
class PickleCodec(Codec):
16+
@override
1717
def encode_json(self, result: object) -> str:
1818
return base64.b64encode(pickle.dumps(result)).decode()
1919

20+
@override
2021
def decode_json(self, encoded: JSONValue, _expected_type: type | None) -> object:
2122
if not isinstance(encoded, str):
2223
raise TypeError(f"Expected a string, got {type(encoded).__name__}")
2324
return pickle.loads(base64.b64decode(encoded.encode()))
24-
25-
def inspect_function(self, fn: Callable[..., object]) -> FunctionType:
26-
return DefaultCodec().inspect_function(fn)

src/duron/log.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Generic, Protocol, cast
3+
from abc import ABC, abstractmethod
4+
from typing import TYPE_CHECKING, Generic, Literal
45

5-
from typing_extensions import TypedDict, TypeVar
6+
from typing_extensions import NotRequired, TypedDict, TypeVar
7+
8+
from duron.codec import JSONValue
69

710
if TYPE_CHECKING:
811
from collections.abc import AsyncGenerator
9-
from typing import Literal, TypeGuard
10-
11-
from typing_extensions import NotRequired
12+
from typing import TypeGuard
1213

1314

1415
_TOffset = TypeVar("_TOffset")
@@ -78,31 +79,20 @@ def is_entry(entry: Entry | AnyEntry) -> TypeGuard[Entry]:
7879
}
7980

8081

81-
JSONValue = dict[str, "JSONValue"] | list["JSONValue"] | str | int | float | bool | None
82-
83-
84-
def is_json_value(x: object) -> TypeGuard[JSONValue]:
85-
if x is None or isinstance(x, (bool, int, float, str)):
86-
return True
87-
if isinstance(x, list):
88-
return all(is_json_value(item) for item in cast("list[object]", x))
89-
if isinstance(x, dict):
90-
return all(
91-
isinstance(k, str) and is_json_value(v)
92-
for k, v in cast("dict[object, object]", x).items()
93-
)
94-
return False
95-
96-
97-
class LogStorage(Protocol, Generic[_TOffset, _TLease]):
82+
class LogStorage(ABC, Generic[_TOffset, _TLease]):
83+
@abstractmethod
9884
def stream(
9985
self, start: _TOffset | None, live: bool, /
10086
) -> AsyncGenerator[tuple[_TOffset, AnyEntry], None]: ...
10187

88+
@abstractmethod
10289
async def acquire_lease(self) -> _TLease: ...
10390

91+
@abstractmethod
10492
async def release_lease(self, token: _TLease, /): ...
10593

94+
@abstractmethod
10695
async def append(self, token: _TLease, entry: Entry, /): ...
10796

97+
@abstractmethod
10898
async def flush(self, token: _TLease, /): ...

src/duron/task.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
Any,
1212
Generic,
1313
ParamSpec,
14-
TypedDict,
1514
TypeVar,
1615
cast,
1716
final,
1817
)
1918

20-
from duron.codec import Codec
19+
from typing_extensions import TypedDict
20+
21+
from duron.codec import Codec, JSONValue
2122
from duron.context import Context
2223
from duron.event_loop import EventLoop, create_loop
2324
from duron.log import is_entry
@@ -33,7 +34,6 @@
3334
from duron.log import (
3435
Entry,
3536
ErrorInfo,
36-
JSONValue,
3737
LogStorage,
3838
PromiseCompleteEntry,
3939
)
@@ -73,14 +73,16 @@ def __init__(
7373
self._run: _TaskRun | None = None
7474

7575
async def start(self, *args: _P.args, **kwargs: _P.kwargs) -> None:
76+
def get_init() -> TaskInitParams:
77+
return {
78+
"version": _CURRENT_VERSION,
79+
"args": [codec.encode_json(arg) for arg in args],
80+
"kwargs": {k: codec.encode_json(v) for k, v in kwargs.items()},
81+
}
82+
7683
codec = self._task_fn.codec
77-
init: TaskInitParams = {
78-
"version": _CURRENT_VERSION,
79-
"args": [codec.encode_json(arg) for arg in args],
80-
"kwargs": {k: codec.encode_json(v) for k, v in kwargs.items()},
81-
}
8284
type_info = codec.inspect_function(self._task_fn.fn)
83-
task_prelude = _task_prelude(self._task_fn, type_info, lambda: init)
85+
task_prelude = _task_prelude(self._task_fn, type_info, get_init)
8486
self._run = _TaskRun(
8587
TaskRun(task=task_prelude, return_type=type_info.return_type),
8688
self._log,
@@ -162,7 +164,7 @@ def __init__(
162164
self._pending_ops: set[bytes] = set()
163165
self._now = 0
164166
self._offset: object | None = None
165-
self._tasks: dict[str, tuple[asyncio.Task[None], type | None]] = {}
167+
self._tasks: dict[str, tuple[asyncio.Future[None], type | None]] = {}
166168

167169
def now(self) -> int:
168170
if self._running:
@@ -256,25 +258,34 @@ async def handle_message(self, e: Entry) -> None:
256258
pending_info = self._pending_task.pop(e["promise_id"], None)
257259
task_info = self._tasks.pop(e["promise_id"], None)
258260

259-
# Get return type from either pending or running task
261+
id = _decode_id(e["promise_id"])
262+
if id not in self._pending_ops:
263+
return
264+
260265
return_type = None
261266
if pending_info is not None:
262267
_, return_type = pending_info
263268
elif task_info is not None:
264269
_, return_type = task_info
270+
else:
271+
print(e)
272+
raise AssertionError("unreachable")
265273

266-
id = _decode_id(e["promise_id"])
267274
if "error" in e:
268275
self._loop.post_completion(
269276
id,
270277
exception=_decode_error(e["error"]),
271278
)
272279
self._pending_ops.discard(id)
273280
elif "result" in e:
274-
self._loop.post_completion(
275-
id,
276-
result=self._codec.decode_json(e["result"], return_type),
277-
)
281+
try:
282+
result = self._codec.decode_json(e["result"], return_type)
283+
self._loop.post_completion(id, result=result)
284+
except BaseException as exc:
285+
self._loop.post_completion(
286+
id,
287+
exception=exc,
288+
)
278289
self._pending_ops.discard(id)
279290
else:
280291
raise ValueError(f"Invalid promise/complete entry: {e!r}")
@@ -339,6 +350,8 @@ def done(f: OpFuture) -> None:
339350
"type": "promise/create",
340351
})
341352

353+
fut_host: asyncio.Future[None] = asyncio.Future()
354+
342355
async def cb() -> None:
343356
entry: PromiseCompleteEntry = {
344357
"ts": _encode_timestamp(self.now()),
@@ -353,9 +366,15 @@ async def cb() -> None:
353366
entry["error"] = _encode_error(e)
354367
finally:
355368
await self.enqueue_log(entry, True)
369+
fut_host.set_result(None)
356370

357371
_ = self._loop.create_task(cb())
358372

373+
self._tasks[_encode_id(id)] = (
374+
fut_host,
375+
op.return_type,
376+
)
377+
359378
case _:
360379
raise NotImplementedError(f"Unsupported op: {op!r}")
361380

tests/test_pydantic.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, cast
4+
5+
import pytest
6+
from pydantic import BaseModel, TypeAdapter
7+
from typing_extensions import override
8+
9+
from duron import fn
10+
from duron.codec import Codec
11+
from duron.context import Context
12+
from duron.contrib.storage import MemoryLogStorage
13+
14+
if TYPE_CHECKING:
15+
from duron.codec import JSONValue
16+
17+
18+
class PydanticPoint(BaseModel):
19+
x: int
20+
y: int
21+
22+
23+
class PydanticCodec(Codec):
24+
@override
25+
def encode_json(self, result: object) -> JSONValue:
26+
return cast(
27+
"JSONValue", TypeAdapter(type(result)).dump_python(result, mode="json")
28+
)
29+
30+
@override
31+
def decode_json(self, encoded: JSONValue, expected_type: type | None) -> object:
32+
return cast("object", TypeAdapter(expected_type).validate_python(encoded))
33+
34+
35+
@pytest.mark.asyncio
36+
async def test_pydantic_serialize():
37+
@fn(codec=PydanticCodec())
38+
async def activity(ctx: Context) -> PydanticPoint:
39+
def new_pt() -> PydanticPoint:
40+
return PydanticPoint(x=1, y=2)
41+
42+
pt = await ctx.run(new_pt)
43+
return PydanticPoint(x=pt.x + 5, y=pt.y + 10)
44+
45+
log = MemoryLogStorage()
46+
async with activity.create_task(log) as t:
47+
await t.start()
48+
a = await t.wait()
49+
assert type(a) is PydanticPoint
50+
assert a.x == 6 and a.y == 12

0 commit comments

Comments
 (0)