11import contextlib
22import inspect
33import json
4+ import typing
45from abc import ABC , ABCMeta , abstractmethod
56from base64 import b64decode , b64encode
67from collections .abc import Sequence
78from dataclasses import dataclass , field , fields , is_dataclass
9+ from types import NoneType , UnionType
810from typing import Any , cast , get_args , get_origin
911
1012# from python 3.11 onwards this is available as typing.dataclass_transform:
@@ -350,7 +352,7 @@ def deserialize_task(task_cls: type, task_input: bytes) -> Task:
350352 return task_cls () # empty task
351353 if len (task_fields ) == 1 :
352354 # if there is only one field, we deserialize it directly
353- field_type = task_fields [0 ].type
355+ field_type = _get_deserialization_field_type ( task_fields [0 ].type ) # type: ignore[arg-type]
354356 if hasattr (field_type , "FromString" ): # protobuf message
355357 value = field_type .FromString (task_input ) # type: ignore[arg-type]
356358 else :
@@ -372,6 +374,10 @@ def _deserialize_dataclass(cls: type, params: dict[str, Any]) -> Task:
372374
373375
374376def _deserialize_value (field_type : type , value : Any ) -> Any : # noqa: PLR0911
377+ if value is None :
378+ return None
379+
380+ field_type = _get_deserialization_field_type (field_type )
375381 if hasattr (field_type , "FromString" ):
376382 return field_type .FromString (b64decode (value ))
377383 if is_dataclass (field_type ) and isinstance (value , dict ):
@@ -398,3 +404,30 @@ def _deserialize_value(field_type: type, value: Any) -> Any: # noqa: PLR0911
398404 return {k : _deserialize_value (type_args [1 ], v ) for k , v in value .items ()}
399405
400406 return value
407+
408+
409+ def _get_deserialization_field_type (field_type : type ) -> type :
410+ """
411+ Get the actual underlying type we want to deserialize a field type annotated as.
412+
413+ This correctly handles optional and annotated type hints.
414+
415+ For example, all of the following fields should be deserialized as MyDataclass class
416+
417+ field1: MyDataclass
418+ field2: MyDataclass | None
419+ field3: Optional[MyDataclass]
420+ field4: Annotated[MyDataclass, "some description"]
421+ field5: Annotated[Optional[MyDataclass], "some description"
422+ """
423+ origin = typing .get_origin (field_type )
424+ if origin in (typing .Union , UnionType ): # handle Optional[type] and 'type | None'
425+ args = typing .get_args (field_type )
426+ if len (args ) == 2 and args [- 1 ] == NoneType :
427+ return _get_deserialization_field_type (args [0 ])
428+ if origin == typing .Annotated :
429+ args = typing .get_args (field_type )
430+ if len (args ) >= 1 :
431+ return _get_deserialization_field_type (args [0 ])
432+
433+ return field_type
0 commit comments