Skip to content

Commit 6c74f1f

Browse files
Correctly deserialize nested dataclasses in task args in case they are optional (#15)
1 parent d7d8155 commit 6c74f1f

3 files changed

Lines changed: 89 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1919

2020
- `tilebox-workflows`: Registering duplicate task identifiers with a task runner now raises a `ValueError` instead of
2121
overwriting the existing task.
22+
- `tilebox-workflows`: Fixed a bug where the `deserialize_task` function would fail to deserialize nested dataclasses or
23+
protobuf messages that are wrapped in an `Optional` or `Annotated` type hint.
2224

2325
## [0.41.0] - 2025-08-01
2426

tilebox-workflows/tests/test_task.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
from dataclasses import dataclass
3+
from typing import Annotated
34

45
import pytest
56

@@ -9,6 +10,7 @@
910
ExecutionContext,
1011
Task,
1112
TaskMeta,
13+
_get_deserialization_field_type,
1214
deserialize_task,
1315
serialize_task,
1416
)
@@ -348,3 +350,54 @@ def test_serialize_deserialize_task_nested_protobuf_in_nested_dict() -> None:
348350
{"a": {"b": [SampleArgs(some_string="World", some_int=123), SampleArgs(some_string="!", some_int=456)]}},
349351
)
350352
assert deserialize_task(ExampleTaskWithNestedProtobufInNestedDict, serialize_task(task)) == task
353+
354+
355+
class ExampleTaskWithOptionalNestedJson(Task):
356+
x: str
357+
optional_args: NestedJson | None = None
358+
359+
360+
def test_serialize_deserialize_task_nested_optional_json() -> None:
361+
task = ExampleTaskWithOptionalNestedJson("Hello")
362+
assert deserialize_task(ExampleTaskWithOptionalNestedJson, serialize_task(task)) == task
363+
364+
task = ExampleTaskWithOptionalNestedJson("Hello", NestedJson(nested_x="World"))
365+
assert deserialize_task(ExampleTaskWithOptionalNestedJson, serialize_task(task)) == task
366+
367+
368+
class ExampleTaskWithOptionalNestedProtobuf(Task):
369+
x: str
370+
optional_args: SampleArgs | None = None
371+
372+
373+
def test_serialize_deserialize_task_nested_optional_protobuf() -> None:
374+
task = ExampleTaskWithOptionalNestedProtobuf("Hello")
375+
assert deserialize_task(ExampleTaskWithOptionalNestedProtobuf, serialize_task(task)) == task
376+
377+
task = ExampleTaskWithOptionalNestedProtobuf("Hello", SampleArgs(some_string="World", some_int=123))
378+
assert deserialize_task(ExampleTaskWithOptionalNestedProtobuf, serialize_task(task)) == task
379+
380+
381+
class FieldTypesTest(Task):
382+
field1: str
383+
field2: str | None
384+
field3: NestedJson | None
385+
field4: NestedJson | None
386+
field5: Annotated[NestedJson, "some description"]
387+
field6: Annotated[NestedJson, "some description"] | None
388+
field7: Annotated[NestedJson | None, "some description"]
389+
field8: Annotated[NestedJson | None, "some description"]
390+
field9: Annotated[list[NestedJson] | None, "some description"]
391+
392+
393+
def test_get_deserialization_field_type() -> None:
394+
fields = FieldTypesTest.__dataclass_fields__
395+
assert _get_deserialization_field_type(fields["field1"].type) is str
396+
assert _get_deserialization_field_type(fields["field2"].type) is str
397+
assert _get_deserialization_field_type(fields["field3"].type) is NestedJson
398+
assert _get_deserialization_field_type(fields["field4"].type) is NestedJson
399+
assert _get_deserialization_field_type(fields["field5"].type) is NestedJson
400+
assert _get_deserialization_field_type(fields["field6"].type) is NestedJson
401+
assert _get_deserialization_field_type(fields["field7"].type) is NestedJson
402+
assert _get_deserialization_field_type(fields["field8"].type) is NestedJson
403+
assert _get_deserialization_field_type(fields["field9"].type) == list[NestedJson]

tilebox-workflows/tilebox/workflows/task.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import contextlib
22
import inspect
33
import json
4+
import typing
45
from abc import ABC, ABCMeta, abstractmethod
56
from base64 import b64decode, b64encode
67
from collections.abc import Sequence
78
from dataclasses import dataclass, field, fields, is_dataclass
9+
from types import NoneType, UnionType
810
from typing import Any, cast, get_args, get_origin
911

1012
# from python 3.11 onwards this is available as typing.dataclass_transform:
@@ -350,7 +352,7 @@ def deserialize_task(task_cls: type, task_input: bytes) -> Task:
350352
return task_cls() # empty task
351353
if len(task_fields) == 1:
352354
# if there is only one field, we deserialize it directly
353-
field_type = task_fields[0].type
355+
field_type = _get_deserialization_field_type(task_fields[0].type) # type: ignore[arg-type]
354356
if hasattr(field_type, "FromString"): # protobuf message
355357
value = field_type.FromString(task_input) # type: ignore[arg-type]
356358
else:
@@ -372,6 +374,10 @@ def _deserialize_dataclass(cls: type, params: dict[str, Any]) -> Task:
372374

373375

374376
def _deserialize_value(field_type: type, value: Any) -> Any: # noqa: PLR0911
377+
if value is None:
378+
return None
379+
380+
field_type = _get_deserialization_field_type(field_type)
375381
if hasattr(field_type, "FromString"):
376382
return field_type.FromString(b64decode(value))
377383
if is_dataclass(field_type) and isinstance(value, dict):
@@ -398,3 +404,30 @@ def _deserialize_value(field_type: type, value: Any) -> Any: # noqa: PLR0911
398404
return {k: _deserialize_value(type_args[1], v) for k, v in value.items()}
399405

400406
return value
407+
408+
409+
def _get_deserialization_field_type(field_type: type) -> type:
410+
"""
411+
Get the actual underlying type we want to deserialize a field type annotated as.
412+
413+
This correctly handles optional and annotated type hints.
414+
415+
For example, all of the following fields should be deserialized as MyDataclass class
416+
417+
field1: MyDataclass
418+
field2: MyDataclass | None
419+
field3: Optional[MyDataclass]
420+
field4: Annotated[MyDataclass, "some description"]
421+
field5: Annotated[Optional[MyDataclass], "some description"
422+
"""
423+
origin = typing.get_origin(field_type)
424+
if origin in (typing.Union, UnionType): # handle Optional[type] and 'type | None'
425+
args = typing.get_args(field_type)
426+
if len(args) == 2 and args[-1] == NoneType:
427+
return _get_deserialization_field_type(args[0])
428+
if origin == typing.Annotated:
429+
args = typing.get_args(field_type)
430+
if len(args) >= 1:
431+
return _get_deserialization_field_type(args[0])
432+
433+
return field_type

0 commit comments

Comments
 (0)