Skip to content

Commit 4f11229

Browse files
authored
fix(adk): implement _get_declaration to natively support ADK schema builder (#534)
* fix(adk): Implement generating `FunctionDeclaration` from tool parameters. * test(adk): fix test_3lo_flow_simulation by properly mocking credential_service * chore: improve comments * chore: delint * chore: move `google.genai.types` import to top of file.
1 parent 437908c commit 4f11229

3 files changed

Lines changed: 128 additions & 6 deletions

File tree

packages/toolbox-adk/src/toolbox_adk/tool.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
OAuth2Auth,
2828
)
2929
from google.adk.auth.auth_tool import AuthConfig
30+
from google.genai.types import FunctionDeclaration, Type, Schema
3031
from google.adk.tools.base_tool import BaseTool
3132
from google.adk.tools.tool_context import ToolContext
3233
from toolbox_core.tool import ToolboxTool as CoreToolboxTool
@@ -79,17 +80,57 @@ def __init__(
7980
self._post_hook = post_hook
8081
self._auth_config = auth_config
8182

83+
84+
def _param_type_to_schema_type(self, param_type: str) -> Type:
85+
type_map = {
86+
"string": Type.STRING,
87+
"integer": Type.INTEGER,
88+
"number": Type.NUMBER,
89+
"boolean": Type.BOOLEAN,
90+
"array": Type.ARRAY,
91+
"object": Type.OBJECT,
92+
}
93+
return type_map.get(param_type, Type.STRING)
94+
95+
@override
96+
def _get_declaration(self) -> Optional[FunctionDeclaration]:
97+
"""Gets the function declaration for the tool."""
98+
properties = {}
99+
required = []
100+
101+
# We do not use `google.genai.types.FunctionDeclaration.from_callable`
102+
# here because it explicitly drops argument descriptions from the schema
103+
# properties, lumping them all into the root description instead.
104+
if hasattr(self._core_tool, '_params') and self._core_tool._params:
105+
for param in self._core_tool._params:
106+
properties[param.name] = Schema(
107+
type=self._param_type_to_schema_type(param.type),
108+
description=param.description or ""
109+
)
110+
if param.required:
111+
required.append(param.name)
112+
113+
parameters = Schema(
114+
type=Type.OBJECT,
115+
properties=properties,
116+
required=required
117+
) if properties else None
118+
119+
return FunctionDeclaration(
120+
name=self.name,
121+
description=self.description,
122+
parameters=parameters
123+
)
124+
82125
@override
83126
async def run_async(
84127
self,
85128
args: Dict[str, Any],
86129
tool_context: ToolContext,
87130
) -> Any:
88-
# 1. Pre-hook
89131
if self._pre_hook:
90132
await self._pre_hook(tool_context, args)
91133

92-
# 2. ADK Auth Integration (3LO)
93134
# Check if USER_IDENTITY is configured
94135
reset_token = None
95136

packages/toolbox-adk/tests/integration/test_integration.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
from typing import Any, Optional
1818
from inspect import signature, Parameter
19-
from unittest.mock import MagicMock
19+
from unittest.mock import MagicMock, AsyncMock
2020

2121
import pytest
2222
from pydantic import ValidationError
@@ -60,6 +60,14 @@ async def test_load_toolset_and_run(self):
6060
tool = next((t for t in tools if t.name == "get-row-by-id"), None)
6161
assert tool is not None
6262
assert isinstance(tool, ToolboxTool)
63+
64+
# Verify the function declaration schema builds correctly end-to-end
65+
declaration = tool._get_declaration()
66+
assert declaration is not None
67+
assert declaration.name == "get-row-by-id"
68+
assert declaration.parameters is not None
69+
assert hasattr(declaration.parameters, 'properties')
70+
assert "id" in declaration.parameters.properties
6371

6472
# Run it
6573
ctx = MagicMock()
@@ -173,19 +181,27 @@ async def test_3lo_flow_simulation(self):
173181
tool = tools[0]
174182
assert isinstance(tool, ToolboxTool)
175183
assert tool.name == "get-n-rows"
184+
185+
# Verify the function declaration schema builds correctly end-to-end
186+
declaration = tool._get_declaration()
187+
assert declaration is not None
188+
assert declaration.name == "get-n-rows"
189+
assert "num_rows" in declaration.parameters.properties
176190

177191
# Create a mock context that behaves like ADK's ReadonlyContext
178192
mock_ctx_first = MagicMock()
179193
# Simulate "No Auth Response Found"
180194
mock_ctx_first.get_auth_response.return_value = None
195+
mock_cred_service_first = AsyncMock()
196+
mock_cred_service_first.load_credential.return_value = None
197+
mock_ctx_first._invocation_context = MagicMock()
198+
mock_ctx_first._invocation_context.credential_service = mock_cred_service_first
181199

182200
print("Running tool first time (expecting auth request)...")
183201
result_first = await tool.run_async({"num_rows": "1"}, mock_ctx_first)
184202

185203
# The wrapper should catch the missing creds and request them.
186-
assert (
187-
result_first is None
188-
), "Tool should return None to signal auth requirement"
204+
assert result_first is None, "Tool should return None sig for auth requirement"
189205
mock_ctx_first.request_credential.assert_called_once()
190206

191207
# Inspect the requested config

packages/toolbox-adk/tests/unit/test_tool.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from toolbox_adk.credentials import CredentialConfig, CredentialType
2020
from toolbox_adk.tool import ToolboxTool
21+
from google.genai.types import Type
2122

2223

2324
class TestToolboxTool:
@@ -267,6 +268,70 @@ async def test_3lo_exception_fallback(self):
267268
# Should catch RuntimeError, call request_credential, and return None
268269
assert result is None
269270
ctx.request_credential.assert_called_once()
271+
272+
def test_param_type_to_schema_type(self):
273+
core_tool = MagicMock()
274+
core_tool.__name__ = "mock_tool"
275+
core_tool.__doc__ = "mock doc"
276+
tool = ToolboxTool(core_tool)
277+
278+
assert tool._param_type_to_schema_type("string") == Type.STRING
279+
assert tool._param_type_to_schema_type("integer") == Type.INTEGER
280+
assert tool._param_type_to_schema_type("boolean") == Type.BOOLEAN
281+
assert tool._param_type_to_schema_type("number") == Type.NUMBER
282+
assert tool._param_type_to_schema_type("array") == Type.ARRAY
283+
assert tool._param_type_to_schema_type("object") == Type.OBJECT
284+
assert tool._param_type_to_schema_type("unknown") == Type.STRING
285+
286+
def test_get_declaration(self):
287+
# Create a mock for core tool parameters
288+
class MockParam:
289+
def __init__(self, name, param_type, description, required):
290+
self.name = name
291+
self.type = param_type
292+
self.description = description
293+
self.required = required
294+
295+
core_tool = MagicMock()
296+
core_tool.__name__ = "mock_tool"
297+
core_tool.__doc__ = "mock doc"
298+
core_tool._params = [
299+
MockParam("city", "string", "The city name", True),
300+
MockParam("count", "integer", "Number of results", False)
301+
]
302+
303+
tool = ToolboxTool(core_tool)
304+
declaration = tool._get_declaration()
305+
306+
assert declaration.name == "mock_tool"
307+
assert declaration.description == "mock doc"
308+
309+
parameters = declaration.parameters
310+
assert parameters is not None
311+
assert parameters.type == Type.OBJECT
312+
assert "city" in parameters.properties
313+
assert "count" in parameters.properties
314+
315+
assert parameters.properties["city"].type == Type.STRING
316+
assert parameters.properties["city"].description == "The city name"
317+
318+
assert parameters.properties["count"].type == Type.INTEGER
319+
assert parameters.properties["count"].description == "Number of results"
320+
321+
assert parameters.required == ["city"]
322+
323+
def test_get_declaration_no_params(self):
324+
core_tool = MagicMock()
325+
core_tool.__name__ = "mock_tool"
326+
core_tool.__doc__ = "mock doc"
327+
core_tool._params = []
328+
329+
tool = ToolboxTool(core_tool)
330+
declaration = tool._get_declaration()
331+
332+
assert declaration.name == "mock_tool"
333+
assert declaration.description == "mock doc"
334+
assert getattr(declaration, "parameters", None) is None
270335

271336
def test_init_defaults(self):
272337
# Test initialization with minimal tool metadata checks

0 commit comments

Comments
 (0)