diff --git a/src/google/adk/tools/_function_parameter_parse_util.py b/src/google/adk/tools/_function_parameter_parse_util.py index 1b9559b29c..3c848ef281 100644 --- a/src/google/adk/tools/_function_parameter_parse_util.py +++ b/src/google/adk/tools/_function_parameter_parse_util.py @@ -247,7 +247,7 @@ def _parse_schema_from_parameter( _raise_if_schema_unsupported(variant, schema) return schema if ( - get_origin(param.annotation) is Union + get_origin(param.annotation) in (Union, typing_types.UnionType) # only parse simple UnionType, example int | str | float | bool # complex types.UnionType will be invoked in raise branch and all( @@ -330,7 +330,7 @@ def _parse_schema_from_parameter( schema.default = param.default _raise_if_schema_unsupported(variant, schema) return schema - if origin is Union: + if origin in (Union, typing_types.UnionType): schema.any_of = [] schema.type = types.Type.OBJECT unique_types = set() diff --git a/src/google/adk/tools/base_tool.py b/src/google/adk/tools/base_tool.py index 8dd112a6c8..da6423ae0d 100644 --- a/src/google/adk/tools/base_tool.py +++ b/src/google/adk/tools/base_tool.py @@ -17,6 +17,7 @@ from abc import ABC import inspect import logging +import types as typing_types from typing import Any from typing import Callable from typing import get_args @@ -168,7 +169,7 @@ def from_config( value = config_dict[param_name] # Get the actual type T of the parameter if it's Optional[T] - if get_origin(param_type) is Union: + if get_origin(param_type) in (Union, typing_types.UnionType): # This is Optional[T] which is Union[T, None] args = get_args(param_type) if len(args) == 2 and type(None) in args: diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index de59755365..caf2491e2f 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -16,6 +16,7 @@ import inspect import logging +import types as typing_types from typing import Any from typing import Callable from typing import get_args @@ -122,7 +123,7 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]: target_type = param.annotation # Handle Optional[PydanticModel] types - if get_origin(param.annotation) is Union: + if get_origin(param.annotation) in (Union, typing_types.UnionType): union_args = get_args(param.annotation) # Find the non-None type in Optional[T] (which is Union[T, None]) non_none_types = [arg for arg in union_args if arg is not type(None)]