Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/google/adk/tools/_function_parameter_parse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _parse_schema_from_parameter(
_raise_if_schema_unsupported(variant, schema)
return schema
if (
get_origin(param.annotation) is Union
get_origin(param.annotation) in (Union, typing_types.UnionType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This check will raise an AttributeError on Python versions below 3.10, as types.UnionType was introduced in Python 3.10. To ensure backward compatibility, it's safer to access UnionType using getattr with a default value that won't be matched.

Suggested change
get_origin(param.annotation) in (Union, typing_types.UnionType)
get_origin(param.annotation) in (Union, getattr(typing_types, 'UnionType', object()))

# only parse simple UnionType, example int | str | float | bool
# complex types.UnionType will be invoked in raise branch
and all(
Expand Down Expand Up @@ -330,7 +330,7 @@ def _parse_schema_from_parameter(
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if origin is Union:
if origin in (Union, typing_types.UnionType):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This check will raise an AttributeError on Python versions below 3.10, as types.UnionType was introduced in Python 3.10. To ensure backward compatibility, it's safer to access UnionType using getattr with a default value that won't be matched.

Suggested change
if origin in (Union, typing_types.UnionType):
if origin in (Union, getattr(typing_types, 'UnionType', object())):

schema.any_of = []
schema.type = types.Type.OBJECT
unique_types = set()
Expand Down
3 changes: 2 additions & 1 deletion src/google/adk/tools/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from abc import ABC
import inspect
import logging
import types as typing_types
from typing import Any
from typing import Callable
from typing import get_args
Expand Down Expand Up @@ -168,7 +169,7 @@ def from_config(
value = config_dict[param_name]

# Get the actual type T of the parameter if it's Optional[T]
if get_origin(param_type) is Union:
if get_origin(param_type) in (Union, typing_types.UnionType):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This check will raise an AttributeError on Python versions below 3.10, as types.UnionType was introduced in Python 3.10. To ensure backward compatibility, it's safer to access UnionType using getattr with a default value that won't be matched.

Suggested change
if get_origin(param_type) in (Union, typing_types.UnionType):
if get_origin(param_type) in (Union, getattr(typing_types, 'UnionType', object())):

# This is Optional[T] which is Union[T, None]
args = get_args(param_type)
if len(args) == 2 and type(None) in args:
Expand Down
3 changes: 2 additions & 1 deletion src/google/adk/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import inspect
import logging
import types as typing_types
from typing import Any
from typing import Callable
from typing import get_args
Expand Down Expand Up @@ -122,7 +123,7 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
target_type = param.annotation

# Handle Optional[PydanticModel] types
if get_origin(param.annotation) is Union:
if get_origin(param.annotation) in (Union, typing_types.UnionType):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This check will raise an AttributeError on Python versions below 3.10, as types.UnionType was introduced in Python 3.10. To ensure backward compatibility, it's safer to access UnionType using getattr with a default value that won't be matched.

Suggested change
if get_origin(param.annotation) in (Union, typing_types.UnionType):
if get_origin(param.annotation) in (Union, getattr(typing_types, 'UnionType', object())):

union_args = get_args(param.annotation)
# Find the non-None type in Optional[T] (which is Union[T, None])
non_none_types = [arg for arg in union_args if arg is not type(None)]
Expand Down