From 7272faadd02a3971fe41bb7327f5af959b2e25dd Mon Sep 17 00:00:00 2001 From: leobin001 Date: Wed, 11 Mar 2026 11:19:16 +0800 Subject: [PATCH] fix: support Union types for stream_type parameter Fixes #607 Allow stream_type to accept Union types like MyModel1 | MyModel2. Changed type hints from Union[Type[BaseModel], Type[dict]] to Any to support Union combinations of BaseModel types. Changes: - Updated PartialType in burr/integrations/pydantic.py to typing.Any - Updated stream_type parameter in streaming_action.pydantic() to Any - Updated _validate_and_extract_signature_types_streaming() signature --- burr/core/action.py | 2 +- burr/integrations/pydantic.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/burr/core/action.py b/burr/core/action.py index b2e7c16d3..fe8e15a5c 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -1365,7 +1365,7 @@ def pydantic( writes: List[str], state_input_type: Type["BaseModel"], state_output_type: Type["BaseModel"], - stream_type: Union[Type["BaseModel"], Type[dict]], + stream_type: Any # Support Union types like MyModel1 | MyModel2, tags: Optional[List[str]] = None, ) -> Callable: """Creates a streaming action that uses pydantic models. diff --git a/burr/integrations/pydantic.py b/burr/integrations/pydantic.py index dd1f95a58..bce207c20 100644 --- a/burr/integrations/pydantic.py +++ b/burr/integrations/pydantic.py @@ -267,7 +267,9 @@ async def async_action_function(state: State, **kwargs) -> State: return decorator -PartialType = Union[Type[pydantic.BaseModel], Type[dict]] +# Support Union types like MyModel1 | MyModel2 for stream_type +# Allow any type hint that could be a BaseModel, dict, or Union of BaseModels +PartialType = typing.Any # Relaxed to support Union[Type[BaseModel], ...] combinations PydanticStreamingActionFunctionSync = Callable[ ..., Generator[Tuple[Union[pydantic.BaseModel, dict], Optional[pydantic.BaseModel]], None, None] @@ -288,11 +290,11 @@ async def async_action_function(state: State, **kwargs) -> State: def _validate_and_extract_signature_types_streaming( fn: PydanticStreamingActionFunction, - stream_type: Optional[Union[Type[pydantic.BaseModel], Type[dict]]], + stream_type: Optional[typing.Any], state_input_type: Optional[Type[pydantic.BaseModel]] = None, state_output_type: Optional[Type[pydantic.BaseModel]] = None, ) -> Tuple[ - Type[pydantic.BaseModel], Type[pydantic.BaseModel], Union[Type[dict], Type[pydantic.BaseModel]] + Type[pydantic.BaseModel], Type[pydantic.BaseModel], typing.Any ]: if stream_type is None: # TODO -- derive from the signature