Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/google/adk/tools/_function_parameter_parse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,9 @@ def _parse_schema_from_parameter(
_raise_if_schema_unsupported(variant, schema)
return schema
if (
get_origin(param.annotation) is Union
get_origin(param.annotation) in (Union, typing_types.UnionType)
# only parse simple UnionType, example int | str | float | bool
# complex types.UnionType will be invoked in raise branch
# complex UnionType will be handled in GenericAlias block below
and all(
(_is_builtin_primitive_or_compound(arg) or arg is type(None))
for arg in get_args(param.annotation)
Expand Down Expand Up @@ -287,8 +287,10 @@ def _parse_schema_from_parameter(
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if isinstance(param.annotation, _GenericAlias) or isinstance(
param.annotation, typing_types.GenericAlias
if (
isinstance(param.annotation, _GenericAlias)
or isinstance(param.annotation, typing_types.GenericAlias)
or isinstance(param.annotation, typing_types.UnionType)
):
origin = get_origin(param.annotation)
args = get_args(param.annotation)
Expand Down Expand Up @@ -330,7 +332,7 @@ def _parse_schema_from_parameter(
schema.default = param.default
_raise_if_schema_unsupported(variant, schema)
return schema
if origin is Union:
if origin in (Union, typing_types.UnionType):
schema.any_of = []
schema.type = types.Type.OBJECT
unique_types = set()
Expand Down
117 changes: 117 additions & 0 deletions tests/unittests/tools/test_build_function_declaration.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,3 +661,120 @@ def greet(name: str = 'World') -> str:
schema = decl.parameters_json_schema
assert schema['properties']['name']['default'] == 'World'
assert 'name' not in schema.get('required', [])


# ── Pipe-union (X | Y) tests ──────────────────────────────────────────


def test_pipe_union_optional_list():
"""list[str] | None should parse as ARRAY with nullable=True."""

def func(a: list[str] | None):
pass

decl = _automatic_function_calling_util.build_function_declaration(
func=func, variant=GoogleLLMVariant.VERTEX_AI
)
prop = decl.parameters.properties['a']
assert prop.type == types.Type.ARRAY
assert prop.nullable is True


def test_pipe_union_optional_dict():
"""dict[str, int] | None should parse as OBJECT with nullable=True."""

def func(a: dict[str, int] | None):
pass

decl = _automatic_function_calling_util.build_function_declaration(
func=func, variant=GoogleLLMVariant.VERTEX_AI
)
prop = decl.parameters.properties['a']
assert prop.type == types.Type.OBJECT
assert prop.nullable is True


def test_pipe_union_optional_list_with_default():
"""list[str] | None = None should parse as ARRAY, nullable, no default."""

def func(a: list[str] | None = None):
pass

decl = _automatic_function_calling_util.build_function_declaration(
func=func, variant=GoogleLLMVariant.VERTEX_AI
)
prop = decl.parameters.properties['a']
assert prop.type == types.Type.ARRAY
assert prop.nullable is True


def test_pipe_union_simple_primitives():
"""int | str should produce any_of with two types."""

def func(a: int | str):
pass

decl = _automatic_function_calling_util.build_function_declaration(
func=func, variant=GoogleLLMVariant.VERTEX_AI
)
prop = decl.parameters.properties['a']
assert prop.any_of is not None
assert len(prop.any_of) == 2


def test_pipe_union_simple_primitives_with_none():
"""int | str | None should produce any_of + nullable."""

def func(a: int | str | None):
pass

decl = _automatic_function_calling_util.build_function_declaration(
func=func, variant=GoogleLLMVariant.VERTEX_AI
)
prop = decl.parameters.properties['a']
assert prop.any_of is not None
assert len(prop.any_of) == 2
assert prop.nullable is True


def test_pipe_union_complex_multi_type():
"""list[str] | dict[str, int] should produce any_of (VERTEX_AI)."""

def func(a: list[str] | dict[str, int]):
pass

decl = _automatic_function_calling_util.build_function_declaration(
func=func, variant=GoogleLLMVariant.VERTEX_AI
)
prop = decl.parameters.properties['a']
assert prop.any_of is not None
assert len(prop.any_of) == 2


def test_pipe_union_complex_falls_back_for_gemini_api():
"""Complex pipe union for GEMINI_API falls back to pydantic schema."""

def func(a: list[str] | dict[str, int]):
pass

decl = _automatic_function_calling_util.build_function_declaration(
func=func, variant=GoogleLLMVariant.GEMINI_API
)
# GEMINI_API does not support any_of, so the parser falls back to
# pydantic-based json schema generation.
assert decl.name == 'func'


def test_typing_union_optional_list_still_works():
"""Regression: typing.Union[list[str], None] must still work."""
import typing

def func(a: typing.Union[list[str], None]):
pass

decl = _automatic_function_calling_util.build_function_declaration(
func=func, variant=GoogleLLMVariant.VERTEX_AI
)
prop = decl.parameters.properties['a']
assert prop.type == types.Type.ARRAY
assert prop.nullable is True