diff --git a/burr/core/action.py b/burr/core/action.py index b2e7c16d3..fe8e15a5c 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -1365,7 +1365,7 @@ def pydantic( writes: List[str], state_input_type: Type["BaseModel"], state_output_type: Type["BaseModel"], - stream_type: Union[Type["BaseModel"], Type[dict]], + stream_type: Any # Support Union types like MyModel1 | MyModel2, tags: Optional[List[str]] = None, ) -> Callable: """Creates a streaming action that uses pydantic models. diff --git a/burr/integrations/pydantic.py b/burr/integrations/pydantic.py index dd1f95a58..bce207c20 100644 --- a/burr/integrations/pydantic.py +++ b/burr/integrations/pydantic.py @@ -267,7 +267,9 @@ async def async_action_function(state: State, **kwargs) -> State: return decorator -PartialType = Union[Type[pydantic.BaseModel], Type[dict]] +# Support Union types like MyModel1 | MyModel2 for stream_type +# Allow any type hint that could be a BaseModel, dict, or Union of BaseModels +PartialType = typing.Any # Relaxed to support Union[Type[BaseModel], ...] combinations PydanticStreamingActionFunctionSync = Callable[ ..., Generator[Tuple[Union[pydantic.BaseModel, dict], Optional[pydantic.BaseModel]], None, None] @@ -288,11 +290,11 @@ async def async_action_function(state: State, **kwargs) -> State: def _validate_and_extract_signature_types_streaming( fn: PydanticStreamingActionFunction, - stream_type: Optional[Union[Type[pydantic.BaseModel], Type[dict]]], + stream_type: Optional[typing.Any], state_input_type: Optional[Type[pydantic.BaseModel]] = None, state_output_type: Optional[Type[pydantic.BaseModel]] = None, ) -> Tuple[ - Type[pydantic.BaseModel], Type[pydantic.BaseModel], Union[Type[dict], Type[pydantic.BaseModel]] + Type[pydantic.BaseModel], Type[pydantic.BaseModel], typing.Any ]: if stream_type is None: # TODO -- derive from the signature