From 5402e072f5fb9adbc6029ae60a0335cfe19cf4bb Mon Sep 17 00:00:00 2001 From: Jon Sharkey Date: Tue, 6 Jan 2026 13:55:48 -0500 Subject: [PATCH] Move example A2UI tools to SDK --- a2a_agents/python/a2ui_extension/README.md | 5 +- .../python/a2ui_extension/pyproject.toml | 16 + .../a2ui_extension/src/a2ui/a2ui_extension.py | 144 +++---- .../src/a2ui/a2ui_schema_utils.py | 34 ++ .../src/a2ui/send_a2ui_to_client_toolset.py | 331 ++++++++++++++++ .../tests/test_a2ui_extension.py | 111 ++++++ .../tests/test_a2ui_schema_utils.py | 14 +- .../a2ui_extension/tests/test_extension.py | 102 ----- .../tests/test_send_a2ui_to_client_toolset.py | 369 ++++++++++++++++++ samples/agent/adk/rizzcharts/a2ui_toolset.py | 136 ------- samples/agent/adk/rizzcharts/agent.py | 92 +++-- .../agent/adk/rizzcharts/agent_executor.py | 93 +++-- .../agent/adk/rizzcharts/part_converter.py | 81 ---- 13 files changed, 1078 insertions(+), 450 deletions(-) create mode 100644 a2a_agents/python/a2ui_extension/src/a2ui/a2ui_schema_utils.py create mode 100644 a2a_agents/python/a2ui_extension/src/a2ui/send_a2ui_to_client_toolset.py create mode 100644 a2a_agents/python/a2ui_extension/tests/test_a2ui_extension.py rename samples/agent/adk/rizzcharts/a2ui_session_util.py => a2a_agents/python/a2ui_extension/tests/test_a2ui_schema_utils.py (66%) delete mode 100644 a2a_agents/python/a2ui_extension/tests/test_extension.py create mode 100644 a2a_agents/python/a2ui_extension/tests/test_send_a2ui_to_client_toolset.py delete mode 100644 samples/agent/adk/rizzcharts/a2ui_toolset.py delete mode 100644 samples/agent/adk/rizzcharts/part_converter.py diff --git a/a2a_agents/python/a2ui_extension/README.md b/a2a_agents/python/a2ui_extension/README.md index 57679150..139ff1fe 100644 --- a/a2a_agents/python/a2ui_extension/README.md +++ b/a2a_agents/python/a2ui_extension/README.md @@ -1,6 +1,7 @@ # A2UI Extension Implementation -This is the Python implementation of the a2ui extension. +a2ui_extension.py is the Python implementation of the a2ui extension. +send_a2ui_to_client_toolset.py is an example Python implementation of using ADK toolcalls to implement A2UI. ## Running Tests @@ -13,7 +14,7 @@ This is the Python implementation of the a2ui extension. 2. Run the tests ```bash - uv run --with pytest pytest tests/test_extension.py + uv run --with pytest pytest tests/*.py ``` ## Disclaimer diff --git a/a2a_agents/python/a2ui_extension/pyproject.toml b/a2a_agents/python/a2ui_extension/pyproject.toml index 24e4e9ad..f26350a6 100644 --- a/a2a_agents/python/a2ui_extension/pyproject.toml +++ b/a2a_agents/python/a2ui_extension/pyproject.toml @@ -13,3 +13,19 @@ build-backend = "hatchling.build" [[tool.uv.index]] url = "https://pypi.org/simple" default = true + +[tool.pyink] +unstable = true +target-version = [] +pyink-indentation = 2 +pyink-use-majority-quotes = true +pyink-annotation-pragmas = [ + "noqa", + "pylint:", + "type: ignore", + "pytype:", + "mypy:", + "pyright:", + "pyre-", +] + diff --git a/a2a_agents/python/a2ui_extension/src/a2ui/a2ui_extension.py b/a2a_agents/python/a2ui_extension/src/a2ui/a2ui_extension.py index 53f1ef99..8cb581d9 100644 --- a/a2a_agents/python/a2ui_extension/src/a2ui/a2ui_extension.py +++ b/a2a_agents/python/a2ui_extension/src/a2ui/a2ui_extension.py @@ -31,95 +31,99 @@ STANDARD_CATALOG_ID = "https://github.com/google/A2UI/blob/main/specification/0.8/json/standard_catalog_definition.json" + def create_a2ui_part(a2ui_data: dict[str, Any]) -> Part: - """Creates an A2A Part containing A2UI data. - - Args: - a2ui_data: The A2UI data dictionary. - - Returns: - An A2A Part with a DataPart containing the A2UI data. - """ - return Part( - root=DataPart( - data=a2ui_data, - metadata={ - MIME_TYPE_KEY: A2UI_MIME_TYPE, - }, - ) - ) + """Creates an A2A Part containing A2UI data. + + Args: + a2ui_data: The A2UI data dictionary. + + Returns: + An A2A Part with a DataPart containing the A2UI data. + """ + return Part( + root=DataPart( + data=a2ui_data, + metadata={ + MIME_TYPE_KEY: A2UI_MIME_TYPE, + }, + ) + ) def is_a2ui_part(part: Part) -> bool: - """Checks if an A2A Part contains A2UI data. - - Args: - part: The A2A Part to check. - - Returns: - True if the part contains A2UI data, False otherwise. - """ - return ( - isinstance(part.root, DataPart) - and part.root.metadata - and part.root.metadata.get(MIME_TYPE_KEY) == A2UI_MIME_TYPE - ) + """Checks if an A2A Part contains A2UI data. + + Args: + part: The A2A Part to check. + + Returns: + True if the part contains A2UI data, False otherwise. + """ + return ( + isinstance(part.root, DataPart) + and part.root.metadata + and part.root.metadata.get(MIME_TYPE_KEY) == A2UI_MIME_TYPE + ) def get_a2ui_datapart(part: Part) -> Optional[DataPart]: - """Extracts the DataPart containing A2UI data from an A2A Part, if present. + """Extracts the DataPart containing A2UI data from an A2A Part, if present. - Args: - part: The A2A Part to extract A2UI data from. + Args: + part: The A2A Part to extract A2UI data from. - Returns: - The DataPart containing A2UI data if present, None otherwise. - """ - if is_a2ui_part(part): - return part.root - return None + Returns: + The DataPart containing A2UI data if present, None otherwise. + """ + if is_a2ui_part(part): + return part.root + return None AGENT_EXTENSION_SUPPORTED_CATALOG_IDS_KEY = "supportedCatalogIds" AGENT_EXTENSION_ACCEPTS_INLINE_CATALOGS_KEY = "acceptsInlineCatalogs" + def get_a2ui_agent_extension( accepts_inline_catalogs: bool = False, supported_catalog_ids: List[str] = [], ) -> AgentExtension: - """Creates the A2UI AgentExtension configuration. - - Args: - accepts_inline_catalogs: Whether the agent accepts inline custom catalogs. - supported_catalog_ids: All pre-defined catalogs the agent is known to support. - - Returns: - The configured A2UI AgentExtension. - """ - params = {} - if accepts_inline_catalogs: - params[AGENT_EXTENSION_ACCEPTS_INLINE_CATALOGS_KEY] = True # Only set if not default of False - - if supported_catalog_ids: - params[AGENT_EXTENSION_SUPPORTED_CATALOG_IDS_KEY] = supported_catalog_ids - - return AgentExtension( - uri=A2UI_EXTENSION_URI, - description="Provides agent driven UI using the A2UI JSON format.", - params=params if params else None, + """Creates the A2UI AgentExtension configuration. + + Args: + accepts_inline_catalogs: Whether the agent accepts inline custom catalogs. + supported_catalog_ids: All pre-defined catalogs the agent is known to support. + + Returns: + The configured A2UI AgentExtension. + """ + params = {} + if accepts_inline_catalogs: + params[AGENT_EXTENSION_ACCEPTS_INLINE_CATALOGS_KEY] = ( + True # Only set if not default of False ) + if supported_catalog_ids: + params[AGENT_EXTENSION_SUPPORTED_CATALOG_IDS_KEY] = supported_catalog_ids + + return AgentExtension( + uri=A2UI_EXTENSION_URI, + description="Provides agent driven UI using the A2UI JSON format.", + params=params if params else None, + ) + def try_activate_a2ui_extension(context: RequestContext) -> bool: - """Activates the A2UI extension if requested. - - Args: - context: The request context to check. - - Returns: - True if activated, False otherwise. - """ - if A2UI_EXTENSION_URI in context.requested_extensions: - context.add_activated_extension(A2UI_EXTENSION_URI) - return True - return False + """Activates the A2UI extension if requested. + + Args: + context: The request context to check. + + Returns: + True if activated, False otherwise. + """ + if A2UI_EXTENSION_URI in context.requested_extensions: + context.add_activated_extension(A2UI_EXTENSION_URI) + return True + return False diff --git a/a2a_agents/python/a2ui_extension/src/a2ui/a2ui_schema_utils.py b/a2a_agents/python/a2ui_extension/src/a2ui/a2ui_schema_utils.py new file mode 100644 index 00000000..4315dfad --- /dev/null +++ b/a2a_agents/python/a2ui_extension/src/a2ui/a2ui_schema_utils.py @@ -0,0 +1,34 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for A2UI Schema manipulation.""" + +from typing import Any + + +def wrap_as_json_array(a2ui_schema: dict[str, Any]) -> dict[str, Any]: + """Wraps the A2UI schema in an array object to support multiple parts. + + Args: + a2ui_schema: The A2UI schema to wrap. + + Returns: + The wrapped A2UI schema object. + + Raises: + ValueError: If the A2UI schema is empty. + """ + if not a2ui_schema: + raise ValueError("A2UI schema is empty") + return {"type": "array", "items": a2ui_schema} diff --git a/a2a_agents/python/a2ui_extension/src/a2ui/send_a2ui_to_client_toolset.py b/a2a_agents/python/a2ui_extension/src/a2ui/send_a2ui_to_client_toolset.py new file mode 100644 index 00000000..7a503d54 --- /dev/null +++ b/a2a_agents/python/a2ui_extension/src/a2ui/send_a2ui_to_client_toolset.py @@ -0,0 +1,331 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module for the SendA2uiToClientToolset and Part Converter. + +This module provides the necessary components to enable an agent to send A2UI (Agent-to-User Interface) +JSON payloads to a client. It includes the `SendA2uiToClientToolset` for managing A2UI tools, a specific tool for the LLM +to send JSON, and a part converter to translate the LLM's tool calls into A2A (Agent-to-Agent) parts. + +This is just one approach for capturing A2UI JSON payloads from an LLM. + +Key Components: + * `SendA2uiToClientToolset`: The main entry point. It accepts providers for determining + if A2UI is enabled and for fetching the A2UI schema. It manages the lifecycle of the + `_SendA2uiJsonToClientTool`. + * `_SendA2uiJsonToClientTool`: A tool exposed to the LLM. It allows the LLM to "call" a function + that effectively sends a JSON payload to the client. This tool validates the JSON against + the provided schema. It automatically wraps the provided schema in an array structure, + instructing the LLM that it can send a list of UI items. + * `convert_send_a2ui_to_client_genai_part_to_a2a_part`: A utility function that intercepts the `send_a2ui_json_to_client` + tool calls from the LLM and converts them into `a2a_types.Part` objects, which are then + returned by the A2A Agent Executor. + +Usage Examples: + + 1. Defining Providers: + You can use simple values or callables (sync or async) for enablement and schema. + + ```python + # Simple boolean and dict + toolset = SendA2uiToClientToolset(a2ui_enabled=True, a2ui_schema=MY_SCHEMA) + + # Async providers + async def check_enabled(ctx: ReadonlyContext) -> bool: + return await some_condition(ctx) + + async def get_schema(ctx: ReadonlyContext) -> dict[str, Any]: + return await fetch_schema(ctx) + + toolset = SendA2uiToClientToolset(a2ui_enabled=check_enabled, a2ui_schema=get_schema) + ``` + + 2. Integration with Agent: + Typically used when initializing an agent's toolset. + + ```python + # In your agent initialization + LlmAgent( + tools=[ + SendA2uiToClientToolset( + a2ui_enabled=True, + a2ui_schema=MY_SCHEMA + ) + ] + ) + ``` + + 3. Integration with Executor: + Configure the executor to use the A2UI part converter. + + ```python + config = A2aAgentExecutorConfig( + genai_part_converter=convert_send_a2ui_to_client_genai_part_to_a2a_part + ) + executor = A2aAgentExecutor(config) + ``` +""" + +import inspect +import json +import logging +from typing import Any, Awaitable, Callable, Optional, TypeAlias, Union + +import jsonschema + +from a2a import types as a2a_types +from a2ui.a2ui_extension import create_a2ui_part +from a2ui.a2ui_schema_utils import wrap_as_json_array +from google.adk.a2a.converters import part_converter +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.models import LlmRequest +from google.adk.tools import base_toolset +from google.adk.tools.base_tool import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.adk.utils.feature_decorator import experimental +from google.genai import types as genai_types + +logger = logging.getLogger(__name__) + +A2uiEnabledProvider: TypeAlias = Callable[ + [ReadonlyContext], Union[bool, Awaitable[bool]] +] +A2uiSchemaProvider: TypeAlias = Callable[ + [ReadonlyContext], Union[dict[str, Any], Awaitable[dict[str, Any]]] +] + + +@experimental +class SendA2uiToClientToolset(base_toolset.BaseToolset): + """A toolset that provides A2UI Tools and can be enabled/disabled.""" + + def __init__( + self, + a2ui_enabled: Union[bool, A2uiEnabledProvider], + a2ui_schema: Union[dict[str, Any], A2uiSchemaProvider], + ): + super().__init__() + self._a2ui_enabled = a2ui_enabled + self._ui_tools = [self._SendA2uiJsonToClientTool(a2ui_schema)] + + async def _resolve_a2ui_enabled(self, ctx: ReadonlyContext) -> bool: + """The resolved self.a2ui_enabled field to construct instruction for this agent. + + Args: + ctx: The ReadonlyContext to resolve the provider with. + + Returns: + If A2UI is enabled, return True. Otherwise, return False. + """ + if isinstance(self._a2ui_enabled, bool): + return self._a2ui_enabled + else: + a2ui_enabled = self._a2ui_enabled(ctx) + if inspect.isawaitable(a2ui_enabled): + a2ui_enabled = await a2ui_enabled + return a2ui_enabled + + async def get_tools( + self, + readonly_context: Optional[ReadonlyContext] = None, + ) -> list[BaseTool]: + """Returns the list of tools provided by this toolset. + + Args: + readonly_context: The ReadonlyContext for resolving tool enablement. + + Returns: + A list of tools. + """ + use_ui = False + if readonly_context is not None: + use_ui = await self._resolve_a2ui_enabled(readonly_context) + if use_ui: + logger.info("A2UI is ENABLED, adding ui tools") + return self._ui_tools + else: + logger.info("A2UI is DISABLED, not adding ui tools") + return [] + + class _SendA2uiJsonToClientTool(BaseTool): + TOOL_NAME = "send_a2ui_json_to_client" + VALIDATED_A2UI_JSON_KEY = "validated_a2ui_json" + A2UI_JSON_ARG_NAME = "a2ui_json" + TOOL_ERROR_KEY = "error" + + def __init__(self, a2ui_schema: Union[dict[str, Any], A2uiSchemaProvider]): + self._a2ui_schema = a2ui_schema + super().__init__( + name=self.TOOL_NAME, + description=( + "Sends A2UI JSON to the client to render rich UI for the user." + " This tool can be called multiple times in the same call to" + " render multiple UI surfaces.Args: " + f" {self.A2UI_JSON_ARG_NAME}: Valid A2UI JSON Schema to send to" + " the client. The A2UI JSON Schema definition is between" + " ---BEGIN A2UI JSON SCHEMA--- and ---END A2UI JSON SCHEMA--- in" + " the system instructions." + ), + ) + + def _get_declaration(self) -> genai_types.FunctionDeclaration | None: + return genai_types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=genai_types.Schema( + type=genai_types.Type.OBJECT, + properties={ + self.A2UI_JSON_ARG_NAME: genai_types.Schema( + type=genai_types.Type.STRING, + description="valid A2UI JSON Schema to send to the client.", + ), + }, + required=[self.A2UI_JSON_ARG_NAME], + ), + ) + + async def _resolve_a2ui_schema(self, ctx: ReadonlyContext) -> dict[str, Any]: + """The resolved self.a2ui_schema field to construct instruction for this agent. + + Args: + ctx: The ReadonlyContext to resolve the provider with. + + Returns: + The A2UI schema to send to the client. + """ + if isinstance(self._a2ui_schema, dict): + return self._a2ui_schema + else: + a2ui_schema = self._a2ui_schema(ctx) + if inspect.isawaitable(a2ui_schema): + a2ui_schema = await a2ui_schema + return a2ui_schema + + async def get_a2ui_schema(self, ctx: ReadonlyContext) -> dict[str, Any]: + """Retrieves and wraps the A2UI schema. + + Args: + ctx: The ReadonlyContext for resolving the schema. + + Returns: + The wrapped A2UI schema. + """ + a2ui_schema = await self._resolve_a2ui_schema(ctx) + return wrap_as_json_array(a2ui_schema) + + async def process_llm_request( + self, *, tool_context: ToolContext, llm_request: LlmRequest + ) -> None: + await super().process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + a2ui_schema = await self.get_a2ui_schema(tool_context) + + llm_request.append_instructions([f""" +---BEGIN A2UI JSON SCHEMA--- +{json.dumps(a2ui_schema)} +---END A2UI JSON SCHEMA--- +"""]) + + logger.info("Added a2ui_schema to system instructions") + + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + try: + a2ui_json = args.get(self.A2UI_JSON_ARG_NAME) + if not a2ui_json: + raise ValueError( + f"Failed to call tool {self.TOOL_NAME} because missing required" + f" arg {self.A2UI_JSON_ARG_NAME} " + ) + + a2ui_json_payload = json.loads(a2ui_json) + + # Auto-wrap single object in list + if not isinstance(a2ui_json_payload, list): + logger.info( + "Received a single JSON object, wrapping in a list for validation." + ) + a2ui_json_payload = [a2ui_json_payload] + + a2ui_schema = await self.get_a2ui_schema(tool_context) + jsonschema.validate(instance=a2ui_json_payload, schema=a2ui_schema) + + logger.info( + f"Validated call to tool {self.TOOL_NAME} with {self.A2UI_JSON_ARG_NAME}" + ) + + # Don't do a second LLM inference call for the JSON response + tool_context.actions.skip_summarization = True + + # Return the validated JSON so the converter can use it. + # We return it in a dict under "result" key for consistent JSON structure. + return {self.VALIDATED_A2UI_JSON_KEY: a2ui_json_payload} + + except Exception as e: + err = f"Failed to call A2UI tool {self.TOOL_NAME}: {e}" + logger.error(err) + + return {self.TOOL_ERROR_KEY: err} + + +@experimental +def convert_send_a2ui_to_client_genai_part_to_a2a_part( + part: genai_types.Part, +) -> list[a2a_types.Part]: + if ( + (function_response := part.function_response) + and function_response.name + == SendA2uiToClientToolset._SendA2uiJsonToClientTool.TOOL_NAME + ): + if ( + SendA2uiToClientToolset._SendA2uiJsonToClientTool.TOOL_ERROR_KEY + in function_response.response + ): + logger.warning( + "A2UI tool call failed:" + f" {function_response.response[SendA2uiToClientToolset._SendA2uiJsonToClientTool.TOOL_ERROR_KEY]}" + ) + return [] + + # The tool returns the list of messages directly on success + json_data = function_response.response.get( + SendA2uiToClientToolset._SendA2uiJsonToClientTool.VALIDATED_A2UI_JSON_KEY + ) + if not json_data: + logger.info("No result in A2UI tool response") + return [] + + final_parts = [] + for message in json_data: + logger.info(f"Found {len(json_data)} messages. Creating individual DataParts.") + final_parts.append(create_a2ui_part(message)) + + return final_parts + + # Don't send a2ui tool call to client + elif ( + (function_call := part.function_call) + and function_call.name + == SendA2uiToClientToolset._SendA2uiJsonToClientTool.TOOL_NAME + ): + return [] + + # Use default part converter for other types (images, etc) + converted_part = part_converter.convert_genai_part_to_a2a_part(part) + + logger.info(f"Returning converted part: {converted_part}") + return [converted_part] if converted_part else [] diff --git a/a2a_agents/python/a2ui_extension/tests/test_a2ui_extension.py b/a2a_agents/python/a2ui_extension/tests/test_a2ui_extension.py new file mode 100644 index 00000000..a410bdb8 --- /dev/null +++ b/a2a_agents/python/a2ui_extension/tests/test_a2ui_extension.py @@ -0,0 +1,111 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from a2a.server.agent_execution import RequestContext +from a2a.types import DataPart, TextPart, Part +from a2ui import a2ui_extension +from a2ui.a2ui_extension import AGENT_EXTENSION_ACCEPTS_INLINE_CATALOGS_KEY, AGENT_EXTENSION_SUPPORTED_CATALOG_IDS_KEY +from unittest.mock import MagicMock + + +def test_a2ui_part_serialization(): + a2ui_data = { + "beginRendering": {"surfaceId": "test-surface", "root": "root-column"} + } + + part = a2ui_extension.create_a2ui_part(a2ui_data) + + assert a2ui_extension.is_a2ui_part(part), "Should be identified as A2UI part" + + data_part = a2ui_extension.get_a2ui_datapart(part) + assert data_part is not None, "Should contain DataPart" + assert a2ui_data == data_part.data, "Deserialized data should match original" + + +def test_non_a2ui_data_part(): + part = Part( + root=DataPart( + data={"foo": "bar"}, + metadata={"mimeType": "application/json"}, # Not A2UI + ) + ) + assert not a2ui_extension.is_a2ui_part( + part + ), "Should not be identified as A2UI part" + assert ( + a2ui_extension.get_a2ui_datapart(part) is None + ), "Should not return A2UI DataPart" + + +def test_non_a2ui_part(): + text_part = TextPart(text="this is some text") + part = Part(root=text_part) + + assert not a2ui_extension.is_a2ui_part( + part + ), "Should not be identified as A2UI part" + assert ( + a2ui_extension.get_a2ui_datapart(part) is None + ), "Should not return A2UI DataPart" + + +def test_get_a2ui_agent_extension(): + agent_extension = a2ui_extension.get_a2ui_agent_extension() + assert agent_extension.uri == a2ui_extension.A2UI_EXTENSION_URI + assert agent_extension.params is None + + +def test_get_a2ui_agent_extension_with_accepts_inline_catalogs(): + accepts_inline_catalogs = True + agent_extension = a2ui_extension.get_a2ui_agent_extension( + accepts_inline_catalogs=accepts_inline_catalogs + ) + assert agent_extension.uri == a2ui_extension.A2UI_EXTENSION_URI + assert agent_extension.params is not None + assert ( + agent_extension.params.get(AGENT_EXTENSION_ACCEPTS_INLINE_CATALOGS_KEY) + == accepts_inline_catalogs + ) + + +def test_get_a2ui_agent_extension_with_supported_catalog_ids(): + supported_catalog_ids = ["a", "b", "c"] + agent_extension = a2ui_extension.get_a2ui_agent_extension( + supported_catalog_ids=supported_catalog_ids + ) + assert agent_extension.uri == a2ui_extension.A2UI_EXTENSION_URI + assert agent_extension.params is not None + assert ( + agent_extension.params.get(AGENT_EXTENSION_SUPPORTED_CATALOG_IDS_KEY) + == supported_catalog_ids + ) + + +def test_try_activate_a2ui_extension(): + context = MagicMock(spec=RequestContext) + context.requested_extensions = [a2ui_extension.A2UI_EXTENSION_URI] + + assert a2ui_extension.try_activate_a2ui_extension(context) + context.add_activated_extension.assert_called_once_with( + a2ui_extension.A2UI_EXTENSION_URI + ) + + +def test_try_activate_a2ui_extension_not_requested(): + context = MagicMock(spec=RequestContext) + context.requested_extensions = [] + + assert not a2ui_extension.try_activate_a2ui_extension(context) + context.add_activated_extension.assert_not_called() diff --git a/samples/agent/adk/rizzcharts/a2ui_session_util.py b/a2a_agents/python/a2ui_extension/tests/test_a2ui_schema_utils.py similarity index 66% rename from samples/agent/adk/rizzcharts/a2ui_session_util.py rename to a2a_agents/python/a2ui_extension/tests/test_a2ui_schema_utils.py index d4012e51..652a3d95 100644 --- a/samples/agent/adk/rizzcharts/a2ui_session_util.py +++ b/a2a_agents/python/a2ui_extension/tests/test_a2ui_schema_utils.py @@ -12,6 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -A2UI_ENABLED_STATE_KEY = "user:a2ui_enabled" -A2UI_CATALOG_URI_STATE_KEY = "user:a2ui_catalog_uri" -A2UI_SCHEMA_STATE_KEY = "user:a2ui_schema" \ No newline at end of file +import pytest +from a2ui.a2ui_schema_utils import wrap_as_json_array + + +def test_wrap_as_json_array(): + schema = {"type": "object"} + wrapped = wrap_as_json_array(schema) + assert wrapped == {"type": "array", "items": schema} + + with pytest.raises(ValueError): + wrap_as_json_array({}) diff --git a/a2a_agents/python/a2ui_extension/tests/test_extension.py b/a2a_agents/python/a2ui_extension/tests/test_extension.py deleted file mode 100644 index 6ecb6ab5..00000000 --- a/a2a_agents/python/a2ui_extension/tests/test_extension.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from a2a.server.agent_execution import RequestContext -from a2a.types import DataPart, TextPart, Part -from a2ui import a2ui_extension -from a2ui.a2ui_extension import AGENT_EXTENSION_ACCEPTS_INLINE_CATALOGS_KEY, AGENT_EXTENSION_SUPPORTED_CATALOG_IDS_KEY -from unittest.mock import MagicMock - - -def test_a2ui_part_serialization(): - a2ui_data = {"beginRendering": {"surfaceId": "test-surface", "root": "root-column"}} - - part = a2ui_extension.create_a2ui_part(a2ui_data) - - assert a2ui_extension.is_a2ui_part(part), "Should be identified as A2UI part" - - data_part = a2ui_extension.get_a2ui_datapart(part) - assert data_part is not None, "Should contain DataPart" - assert a2ui_data == data_part.data, "Deserialized data should match original" - - -def test_non_a2ui_data_part(): - part = Part( - root=DataPart( - data={"foo": "bar"}, metadata={"mimeType": "application/json"} # Not A2UI - ) - ) - assert not a2ui_extension.is_a2ui_part( - part - ), "Should not be identified as A2UI part" - assert ( - a2ui_extension.get_a2ui_datapart(part) is None - ), "Should not return A2UI DataPart" - - -def test_non_a2ui_part(): - text_part = TextPart(text="this is some text") - part = Part(root=text_part) - - assert not a2ui_extension.is_a2ui_part( - part - ), "Should not be identified as A2UI part" - assert ( - a2ui_extension.get_a2ui_datapart(part) is None - ), "Should not return A2UI DataPart" - - -def test_get_a2ui_agent_extension(): - agent_extension = a2ui_extension.get_a2ui_agent_extension() - assert agent_extension.uri == a2ui_extension.A2UI_EXTENSION_URI - assert agent_extension.params is None - - -def test_get_a2ui_agent_extension_with_accepts_inline_catalogs(): - accepts_inline_catalogs = True - agent_extension = a2ui_extension.get_a2ui_agent_extension( - accepts_inline_catalogs=accepts_inline_catalogs - ) - assert agent_extension.uri == a2ui_extension.A2UI_EXTENSION_URI - assert agent_extension.params is not None - assert agent_extension.params.get(AGENT_EXTENSION_ACCEPTS_INLINE_CATALOGS_KEY) == accepts_inline_catalogs - - -def test_get_a2ui_agent_extension_with_supported_catalog_ids(): - supported_catalog_ids = ["a", "b", "c"] - agent_extension = a2ui_extension.get_a2ui_agent_extension( - supported_catalog_ids=supported_catalog_ids - ) - assert agent_extension.uri == a2ui_extension.A2UI_EXTENSION_URI - assert agent_extension.params is not None - assert agent_extension.params.get(AGENT_EXTENSION_SUPPORTED_CATALOG_IDS_KEY) == supported_catalog_ids - - -def test_try_activate_a2ui_extension(): - context = MagicMock(spec=RequestContext) - context.requested_extensions = [a2ui_extension.A2UI_EXTENSION_URI] - - assert a2ui_extension.try_activate_a2ui_extension(context) - context.add_activated_extension.assert_called_once_with( - a2ui_extension.A2UI_EXTENSION_URI - ) - - -def test_try_activate_a2ui_extension_not_requested(): - context = MagicMock(spec=RequestContext) - context.requested_extensions = [] - - assert not a2ui_extension.try_activate_a2ui_extension(context) - context.add_activated_extension.assert_not_called() diff --git a/a2a_agents/python/a2ui_extension/tests/test_send_a2ui_to_client_toolset.py b/a2a_agents/python/a2ui_extension/tests/test_send_a2ui_to_client_toolset.py new file mode 100644 index 00000000..1ed380bf --- /dev/null +++ b/a2a_agents/python/a2ui_extension/tests/test_send_a2ui_to_client_toolset.py @@ -0,0 +1,369 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from a2a import types as a2a_types +from a2ui.a2ui_extension import create_a2ui_part + +from a2ui.send_a2ui_to_client_toolset import convert_send_a2ui_to_client_genai_part_to_a2a_part +from a2ui.send_a2ui_to_client_toolset import SendA2uiToClientToolset +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.tools.tool_context import ToolContext +from google.genai import types as genai_types + +# Basic A2UI Schema for testing +TEST_A2UI_SCHEMA = { + "type": "object", + "properties": {"type": {"const": "Text"}, "text": {"type": "string"}}, + "required": ["type", "text"], +} + +# region SendA2uiToClientToolset Tests +"""Tests for the SendA2uiToClientToolset class.""" + + +@pytest.mark.asyncio +async def test_toolset_init_bool(): + toolset = SendA2uiToClientToolset( + a2ui_enabled=True, a2ui_schema=TEST_A2UI_SCHEMA + ) + ctx = MagicMock(spec=ReadonlyContext) + assert await toolset._resolve_a2ui_enabled(ctx) == True + + # Access the tool to check schema resolution + tool = toolset._ui_tools[0] + assert await tool._resolve_a2ui_schema(ctx) == TEST_A2UI_SCHEMA + + +@pytest.mark.asyncio +async def test_toolset_init_callable(): + enabled_mock = MagicMock(return_value=True) + schema_mock = MagicMock(return_value=TEST_A2UI_SCHEMA) + toolset = SendA2uiToClientToolset( + a2ui_enabled=enabled_mock, a2ui_schema=schema_mock + ) + ctx = MagicMock(spec=ReadonlyContext) + assert await toolset._resolve_a2ui_enabled(ctx) == True + + # Access the tool to check schema resolution + tool = toolset._ui_tools[0] + assert await tool._resolve_a2ui_schema(ctx) == TEST_A2UI_SCHEMA + enabled_mock.assert_called_once_with(ctx) + schema_mock.assert_called_once_with(ctx) + + +@pytest.mark.asyncio +async def test_toolset_init_async_callable(): + async def async_enabled(_ctx): + return True + + async def async_schema(_ctx): + return TEST_A2UI_SCHEMA + + toolset = SendA2uiToClientToolset( + a2ui_enabled=async_enabled, a2ui_schema=async_schema + ) + ctx = MagicMock(spec=ReadonlyContext) + assert await toolset._resolve_a2ui_enabled(ctx) == True + + # Access the tool to check schema resolution + tool = toolset._ui_tools[0] + assert await tool._resolve_a2ui_schema(ctx) == TEST_A2UI_SCHEMA + + +@pytest.mark.asyncio +async def test_toolset_get_tools_enabled(): + toolset = SendA2uiToClientToolset( + a2ui_enabled=True, a2ui_schema=TEST_A2UI_SCHEMA + ) + tools = await toolset.get_tools(MagicMock(spec=ReadonlyContext)) + assert len(tools) == 1 + assert isinstance(tools[0], SendA2uiToClientToolset._SendA2uiJsonToClientTool) + + +@pytest.mark.asyncio +async def test_toolset_get_tools_disabled(): + toolset = SendA2uiToClientToolset( + a2ui_enabled=False, a2ui_schema=TEST_A2UI_SCHEMA + ) + tools = await toolset.get_tools(MagicMock(spec=ReadonlyContext)) + assert len(tools) == 0 + + +# endregion + +# region SendA2uiJsonToClientTool Tests +"""Tests for the _SendA2uiJsonToClientTool class.""" + + +def test_send_tool_init(): + tool = SendA2uiToClientToolset._SendA2uiJsonToClientTool(TEST_A2UI_SCHEMA) + assert ( + tool.name == SendA2uiToClientToolset._SendA2uiJsonToClientTool.TOOL_NAME + ) + assert tool._a2ui_schema == TEST_A2UI_SCHEMA + + +def test_send_tool_get_declaration(): + tool = SendA2uiToClientToolset._SendA2uiJsonToClientTool(TEST_A2UI_SCHEMA) + declaration = tool._get_declaration() + assert declaration is not None + assert ( + declaration.name + == SendA2uiToClientToolset._SendA2uiJsonToClientTool.TOOL_NAME + ) + assert ( + SendA2uiToClientToolset._SendA2uiJsonToClientTool.A2UI_JSON_ARG_NAME + in declaration.parameters.properties + ) + assert ( + SendA2uiToClientToolset._SendA2uiJsonToClientTool.A2UI_JSON_ARG_NAME + in declaration.parameters.required + ) + + +@pytest.mark.asyncio +async def test_send_tool_get_a2ui_schema(): + schema_mock = MagicMock(return_value=TEST_A2UI_SCHEMA) + tool = SendA2uiToClientToolset._SendA2uiJsonToClientTool(schema_mock) + schema = await tool.get_a2ui_schema(MagicMock(spec=ReadonlyContext)) + assert schema == {"type": "array", "items": TEST_A2UI_SCHEMA} + + +@pytest.mark.asyncio +async def test_send_tool_get_a2ui_schema_empty(): + schema_mock = MagicMock(return_value=None) + tool = SendA2uiToClientToolset._SendA2uiJsonToClientTool(schema_mock) + with pytest.raises(ValueError): + await tool.get_a2ui_schema(MagicMock(spec=ReadonlyContext)) + + +@pytest.mark.asyncio +async def test_send_tool_process_llm_request(): + tool = SendA2uiToClientToolset._SendA2uiJsonToClientTool(TEST_A2UI_SCHEMA) + tool_context_mock = MagicMock(spec=ToolContext) + tool_context_mock.state = {} + llm_request_mock = MagicMock() + llm_request_mock.append_instructions = MagicMock() + + await tool.process_llm_request( + tool_context=tool_context_mock, llm_request=llm_request_mock + ) + + llm_request_mock.append_instructions.assert_called_once() + args, _ = llm_request_mock.append_instructions.call_args + instruction = args[0][0] + assert "---BEGIN A2UI JSON SCHEMA---" in instruction + assert json.dumps({"type": "array", "items": TEST_A2UI_SCHEMA}) in instruction + assert "---END A2UI JSON SCHEMA---" in instruction + + +@pytest.mark.asyncio +async def test_send_tool_run_async_valid(): + tool = SendA2uiToClientToolset._SendA2uiJsonToClientTool(TEST_A2UI_SCHEMA) + tool_context_mock = MagicMock(spec=ToolContext) + tool_context_mock.state = {} + tool_context_mock.actions = MagicMock(skip_summarization=False) + + valid_a2ui = [{"type": "Text", "text": "Hello"}] + args = { + SendA2uiToClientToolset._SendA2uiJsonToClientTool.A2UI_JSON_ARG_NAME: ( + json.dumps(valid_a2ui) + ) + } + + result = await tool.run_async(args=args, tool_context=tool_context_mock) + assert result == { + SendA2uiToClientToolset._SendA2uiJsonToClientTool.VALIDATED_A2UI_JSON_KEY: ( + valid_a2ui + ) + } + assert tool_context_mock.actions.skip_summarization == True + + +@pytest.mark.asyncio +async def test_send_tool_run_async_valid_list(): + tool = SendA2uiToClientToolset._SendA2uiJsonToClientTool(TEST_A2UI_SCHEMA) + tool_context_mock = MagicMock(spec=ToolContext) + tool_context_mock.state = {} + tool_context_mock.actions = MagicMock(skip_summarization=False) + + valid_a2ui = [{"type": "Text", "text": "Hello"}] + args = { + SendA2uiToClientToolset._SendA2uiJsonToClientTool.A2UI_JSON_ARG_NAME: ( + json.dumps(valid_a2ui) + ) + } + + result = await tool.run_async(args=args, tool_context=tool_context_mock) + assert result == { + SendA2uiToClientToolset._SendA2uiJsonToClientTool.VALIDATED_A2UI_JSON_KEY: ( + valid_a2ui + ) + } + assert tool_context_mock.actions.skip_summarization == True + + +@pytest.mark.asyncio +async def test_send_tool_run_async_missing_arg(): + tool = SendA2uiToClientToolset._SendA2uiJsonToClientTool(TEST_A2UI_SCHEMA) + result = await tool.run_async(args={}, tool_context=MagicMock()) + assert "error" in result + assert ( + SendA2uiToClientToolset._SendA2uiJsonToClientTool.A2UI_JSON_ARG_NAME + in result["error"] + ) + + +@pytest.mark.asyncio +async def test_send_tool_run_async_invalid_json(): + tool = SendA2uiToClientToolset._SendA2uiJsonToClientTool(TEST_A2UI_SCHEMA) + args = { + SendA2uiToClientToolset._SendA2uiJsonToClientTool.A2UI_JSON_ARG_NAME: ( + "{invalid" + ) + } + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert "error" in result + assert "Failed to call A2UI tool" in result["error"] + + +@pytest.mark.asyncio +async def test_send_tool_run_async_schema_validation_fail(): + tool = SendA2uiToClientToolset._SendA2uiJsonToClientTool(TEST_A2UI_SCHEMA) + invalid_a2ui = [{"type": "Text"}] # Missing 'text' + args = { + SendA2uiToClientToolset._SendA2uiJsonToClientTool.A2UI_JSON_ARG_NAME: ( + json.dumps(invalid_a2ui) + ) + } + result = await tool.run_async(args=args, tool_context=MagicMock()) + assert "error" in result + assert "Failed to call A2UI tool" in result["error"] + assert "'text' is a required property" in result["error"] + + +# endregion + +# region send_a2ui_to_client_part_converter Tests +"""Tests for the send_a2ui_to_client_part_converter function.""" + + +def test_converter_convert_valid_response_single(): + valid_a2ui = {"type": "Text", "text": "Hello"} + function_response = genai_types.FunctionResponse( + name=SendA2uiToClientToolset._SendA2uiJsonToClientTool.TOOL_NAME, + response={ + SendA2uiToClientToolset._SendA2uiJsonToClientTool.VALIDATED_A2UI_JSON_KEY: [ + valid_a2ui + ] + }, + ) + part = genai_types.Part(function_response=function_response) + + a2a_parts = convert_send_a2ui_to_client_genai_part_to_a2a_part(part) + assert len(a2a_parts) == 1 + assert a2a_parts[0] == create_a2ui_part(valid_a2ui) + + +def test_converter_convert_valid_response_list(): + valid_a2ui = [ + {"type": "Text", "text": "Hello"}, + {"type": "Text", "text": "World"}, + ] + function_response = genai_types.FunctionResponse( + name=SendA2uiToClientToolset._SendA2uiJsonToClientTool.TOOL_NAME, + response={ + SendA2uiToClientToolset._SendA2uiJsonToClientTool.VALIDATED_A2UI_JSON_KEY: ( + valid_a2ui + ) + }, + ) + part = genai_types.Part(function_response=function_response) + + a2a_parts = convert_send_a2ui_to_client_genai_part_to_a2a_part(part) + assert len(a2a_parts) == 2 + assert a2a_parts[0] == create_a2ui_part(valid_a2ui[0]) + assert a2a_parts[1] == create_a2ui_part(valid_a2ui[1]) + + +def test_converter_convert_function_call_returns_empty(): + # Converter should ignore the function call itself + function_call = genai_types.FunctionCall( + name=SendA2uiToClientToolset._SendA2uiJsonToClientTool.TOOL_NAME, + args={ + SendA2uiToClientToolset._SendA2uiJsonToClientTool.A2UI_JSON_ARG_NAME: ( + "..." + ) + }, + ) + part = genai_types.Part(function_call=function_call) + a2a_parts = convert_send_a2ui_to_client_genai_part_to_a2a_part(part) + assert len(a2a_parts) == 0 + + +def test_converter_convert_error_response(): + function_response = genai_types.FunctionResponse( + name=SendA2uiToClientToolset._SendA2uiJsonToClientTool.TOOL_NAME, + response={"error": "Something went wrong"}, + ) + part = genai_types.Part(function_response=function_response) + a2a_parts = convert_send_a2ui_to_client_genai_part_to_a2a_part(part) + assert len(a2a_parts) == 0 + + +def test_converter_convert_empty_result_response(): + function_response = genai_types.FunctionResponse( + name=SendA2uiToClientToolset._SendA2uiJsonToClientTool.TOOL_NAME, + response={}, # Missing result + ) + part = genai_types.Part(function_response=function_response) + a2a_parts = convert_send_a2ui_to_client_genai_part_to_a2a_part(part) + assert len(a2a_parts) == 0 + + +@patch( + "google.adk.a2a.converters.part_converter.convert_genai_part_to_a2a_part" +) +def test_converter_convert_non_a2ui_function_call(mock_convert): + function_call = genai_types.FunctionCall(name="other_tool", args={}) + part = genai_types.Part(function_call=function_call) + mock_a2a_part = a2a_types.Part(root=a2a_types.TextPart(text="test")) + mock_convert.return_value = mock_a2a_part + + a2a_parts = convert_send_a2ui_to_client_genai_part_to_a2a_part(part) + assert len(a2a_parts) == 1 + assert a2a_parts[0] is mock_a2a_part + mock_convert.assert_called_once_with(part) + + +@patch( + "google.adk.a2a.converters.part_converter.convert_genai_part_to_a2a_part" +) +def test_converter_convert_other_part(mock_convert): + part = genai_types.Part(text="Hello") + mock_a2a_part = a2a_types.Part(root=a2a_types.TextPart(text="Hello")) + mock_convert.return_value = mock_a2a_part + + a2a_parts = convert_send_a2ui_to_client_genai_part_to_a2a_part(part) + assert len(a2a_parts) == 1 + assert a2a_parts[0] is mock_a2a_part + mock_convert.assert_called_once_with(part) + + +# endregion diff --git a/samples/agent/adk/rizzcharts/a2ui_toolset.py b/samples/agent/adk/rizzcharts/a2ui_toolset.py deleted file mode 100644 index ffb7cbfb..00000000 --- a/samples/agent/adk/rizzcharts/a2ui_toolset.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import jsonschema -import logging -from typing import Any, List, Optional - -from google.genai import types as genai_types - -from google.adk.models import LlmRequest -from google.adk.tools.base_tool import BaseTool -from google.adk.tools import base_toolset -from google.adk.tools.tool_context import ToolContext -from google.adk.agents.readonly_context import ReadonlyContext -from a2ui_session_util import A2UI_ENABLED_STATE_KEY, A2UI_SCHEMA_STATE_KEY - -logger = logging.getLogger(__name__) - - -class A2uiToolset(base_toolset.BaseToolset): - """A toolset that provides A2UI Tools and can be enabled/disabled.""" - - def __init__(self): - super().__init__() - self._ui_tools = [SendA2uiJsonToClientTool()] - - async def get_tools( - self, - readonly_context: Optional[ReadonlyContext] = None, - ) -> List[BaseTool]: - use_ui = readonly_context and readonly_context.state.get(A2UI_ENABLED_STATE_KEY) - if use_ui: - logger.info("A2UI is ENABLED, adding ui tools") - return self._ui_tools - else: - logger.info("A2UI is DISABLED, not adding ui tools") - return [] - - -class SendA2uiJsonToClientTool(BaseTool): - TOOL_NAME = "send_a2ui_json_to_client" - A2UI_JSON_ARG_NAME = "a2ui_json" - - def __init__(self): - super().__init__( - name=self.TOOL_NAME, - description="Sends A2UI JSON to the client to render rich UI for the user. This tool can be called multiple times in the same call to render multiple UI surfaces." - "Args:" - f" {self.A2UI_JSON_ARG_NAME}: Valid A2UI JSON Schema to send to the client. The A2UI JSON Schema definition is between ---BEGIN A2UI JSON SCHEMA--- and ---END A2UI JSON SCHEMA--- in the system instructions.", - ) - - def _get_declaration(self) -> genai_types.FunctionDeclaration | None: - return genai_types.FunctionDeclaration( - name=self.name, - description=self.description, - parameters=genai_types.Schema( - type=genai_types.Type.OBJECT, - properties={ - self.A2UI_JSON_ARG_NAME: genai_types.Schema( - type=genai_types.Type.STRING, - description="valid A2UI JSON Schema to send to the client.", - ), - }, - required=[self.A2UI_JSON_ARG_NAME], - ), - ) - - def get_a2ui_schema(self, tool_context: ToolContext) -> dict[str, Any]: - a2ui_schema = tool_context.state.get(A2UI_SCHEMA_STATE_KEY) - if not a2ui_schema: - raise ValueError("A2UI schema is empty") - a2ui_schema_object = {"type": "array", "items": a2ui_schema} # Make a list since we support multiple parts in this tool call - return a2ui_schema_object - - async def process_llm_request( - self, *, tool_context: ToolContext, llm_request: LlmRequest - ) -> None: - await super().process_llm_request( - tool_context=tool_context, llm_request=llm_request - ) - - a2ui_schema = self.get_a2ui_schema(tool_context) - - llm_request.append_instructions( - [ - f""" ----BEGIN A2UI JSON SCHEMA--- -{json.dumps(a2ui_schema)} ----END A2UI JSON SCHEMA--- -""" - ] - ) - - logger.info("Added a2ui_schema to system instructions") - - async def run_async( - self, *, args: dict[str, Any], tool_context: ToolContext - ) -> Any: - try: - a2ui_json = args.get(self.A2UI_JSON_ARG_NAME) - if not a2ui_json: - raise ValueError( - f"Failed to call tool {self.TOOL_NAME} because missing required arg {self.A2UI_JSON_ARG_NAME} " - ) - - a2ui_json_payload = json.loads(a2ui_json) - a2ui_schema = self.get_a2ui_schema(tool_context) - jsonschema.validate( - instance=a2ui_json_payload, schema=a2ui_schema - ) - - logger.info( - f"Validated call to tool {self.TOOL_NAME} with {self.A2UI_JSON_ARG_NAME}" - ) - - # Don't do a second LLM inference call for the None response - tool_context.actions.skip_summarization = True - - return None - except Exception as e: - err = f"Failed to call A2UI tool {self.TOOL_NAME}: {e}" - logger.error(err) - - return {"error": err} diff --git a/samples/agent/adk/rizzcharts/agent.py b/samples/agent/adk/rizzcharts/agent.py index 137c1773..892cbbda 100644 --- a/samples/agent/adk/rizzcharts/agent.py +++ b/samples/agent/adk/rizzcharts/agent.py @@ -17,58 +17,90 @@ import os from pathlib import Path from typing import Any + import jsonschema -from google.adk.models.lite_llm import LiteLlm +from a2ui.a2ui_extension import STANDARD_CATALOG_ID +from a2ui.a2ui_schema_utils import wrap_as_json_array +from a2ui.send_a2ui_to_client_toolset import SendA2uiToClientToolset, A2uiEnabledProvider, A2uiSchemaProvider from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.models.lite_llm import LiteLlm from google.adk.planners.built_in_planner import BuiltInPlanner from google.genai import types -from google.adk.agents.readonly_context import ReadonlyContext from tools import get_store_sales, get_sales_data -from a2ui_toolset import A2uiToolset -from a2ui_session_util import A2UI_ENABLED_STATE_KEY, A2UI_CATALOG_URI_STATE_KEY, A2UI_SCHEMA_STATE_KEY -from a2ui.a2ui_extension import STANDARD_CATALOG_ID logger = logging.getLogger(__name__) RIZZCHARTS_CATALOG_URI = "https://github.com/google/A2UI/blob/main/samples/agent/adk/rizzcharts/rizzcharts_catalog_definition.json" +A2UI_CATALOG_URI_STATE_KEY = "user:a2ui_catalog_uri" -class rizzchartsAgent: +class RizzchartsAgent: """An agent that runs an ecommerce dashboard""" SUPPORTED_CONTENT_TYPES = ["text", "text/plain"] - @classmethod - def get_a2ui_schema(cls, readonly_context: ReadonlyContext) -> dict[str, Any]: - a2ui_schema = readonly_context.state.get(A2UI_SCHEMA_STATE_KEY) - if not a2ui_schema: - raise ValueError("A2UI schema is empty") - a2ui_schema_object = {"type": "array", "items": a2ui_schema} # Make a list since we support multiple parts in this tool call - return a2ui_schema_object - - @classmethod - def load_example(cls, path: str, a2ui_schema: dict[str, Any]) -> dict[str, Any]: - example_str = Path(path).read_text() + def __init__(self, a2ui_enabled_provider: A2uiEnabledProvider, a2ui_schema_provider: A2uiSchemaProvider): + """Initializes the RizzchartsAgent. + + Args: + a2ui_enabled_provider: A provider to check if A2UI is enabled. + a2ui_schema_provider: A provider to retrieve the A2UI schema. + """ + self._a2ui_enabled_provider = a2ui_enabled_provider + self._a2ui_schema_provider = a2ui_schema_provider + + def get_a2ui_schema(self, ctx: ReadonlyContext) -> dict[str, Any]: + """Retrieves and wraps the A2UI schema from the session state. + + Args: + ctx: The ReadonlyContext for resolving the schema. + + Returns: + The wrapped A2UI schema. + """ + a2ui_schema = self._a2ui_schema_provider(ctx) + return wrap_as_json_array(a2ui_schema) + + def load_example(self, path: str, a2ui_schema: dict[str, Any]) -> dict[str, Any]: + """Loads an example JSON file and validates it against the A2UI schema. + + Args: + path: Relative path to the example JSON file. + a2ui_schema: The A2UI schema to validate against. + + Returns: + The loaded and validated JSON data. + """ + full_path = Path(__file__).parent / path + example_str = full_path.read_text() example_json = json.loads(example_str) jsonschema.validate( instance=example_json, schema=a2ui_schema ) return example_json - @classmethod - def get_instructions(cls, readonly_context: ReadonlyContext) -> str: - use_ui = readonly_context.state.get(A2UI_ENABLED_STATE_KEY) + def get_instructions(self, readonly_context: ReadonlyContext) -> str: + """Generates the system instructions for the agent. + + Args: + readonly_context: The ReadonlyContext for resolving instructions. + + Returns: + The generated system instructions. + """ + use_ui = self._a2ui_enabled_provider(readonly_context) if not use_ui: raise ValueError("A2UI must be enabled to run rizzcharts agent") - a2ui_schema = cls.get_a2ui_schema(readonly_context) + a2ui_schema = self.get_a2ui_schema(readonly_context) catalog_uri = readonly_context.state.get(A2UI_CATALOG_URI_STATE_KEY) if catalog_uri == RIZZCHARTS_CATALOG_URI: - map_example = cls.load_example("examples/rizzcharts_catalog/map.json", a2ui_schema) - chart_example = cls.load_example("examples/rizzcharts_catalog/chart.json", a2ui_schema) + map_example = self.load_example("examples/rizzcharts_catalog/map.json", a2ui_schema) + chart_example = self.load_example("examples/rizzcharts_catalog/chart.json", a2ui_schema) elif catalog_uri == STANDARD_CATALOG_ID: - map_example = cls.load_example("examples/standard_catalog/map.json", a2ui_schema) - chart_example = cls.load_example("examples/standard_catalog/chart.json", a2ui_schema) + map_example = self.load_example("examples/standard_catalog/map.json", a2ui_schema) + chart_example = self.load_example("examples/standard_catalog/chart.json", a2ui_schema) else: raise ValueError(f"Unsupported catalog uri: {catalog_uri if catalog_uri else 'None'}") @@ -131,8 +163,7 @@ def get_instructions(cls, readonly_context: ReadonlyContext) -> str: return final_prompt - @classmethod - def build_agent(cls) -> LlmAgent: + def build_agent(self) -> LlmAgent: """Builds the LLM agent for the rizzchartsAgent agent.""" LITELLM_MODEL = os.getenv("LITELLM_MODEL", "gemini/gemini-2.5-flash") @@ -140,8 +171,11 @@ def build_agent(cls) -> LlmAgent: model=LiteLlm(model=LITELLM_MODEL), name="rizzcharts_agent", description="An agent that lets sales managers request sales data.", - instruction=cls.get_instructions, - tools=[get_store_sales, get_sales_data, A2uiToolset()], + instruction=self.get_instructions, + tools=[get_store_sales, get_sales_data, SendA2uiToClientToolset( + a2ui_schema=self._a2ui_schema_provider, + a2ui_enabled=self._a2ui_enabled_provider, + )], planner=BuiltInPlanner( thinking_config=types.ThinkingConfig( include_thoughts=True, diff --git a/samples/agent/adk/rizzcharts/agent_executor.py b/samples/agent/adk/rizzcharts/agent_executor.py index de516cc4..3cc94ba3 100644 --- a/samples/agent/adk/rizzcharts/agent_executor.py +++ b/samples/agent/adk/rizzcharts/agent_executor.py @@ -13,41 +13,49 @@ # limitations under the License. import logging +from pathlib import Path from typing import override from a2a.server.agent_execution import RequestContext - +from a2a.types import AgentCapabilities, AgentCard, AgentExtension, AgentSkill +from a2ui.a2ui_extension import A2UI_CLIENT_CAPABILITIES_KEY +from a2ui.a2ui_extension import A2UI_EXTENSION_URI +from a2ui.a2ui_extension import STANDARD_CATALOG_ID +from a2ui.a2ui_extension import get_a2ui_agent_extension +from a2ui.a2ui_extension import try_activate_a2ui_extension +from a2ui.send_a2ui_to_client_toolset import convert_send_a2ui_to_client_genai_part_to_a2a_part +from agent import A2UI_CATALOG_URI_STATE_KEY +from agent import RIZZCHARTS_CATALOG_URI +from agent import RizzchartsAgent +from component_catalog_builder import ComponentCatalogBuilder +from google.adk.a2a.converters.request_converter import AgentRunRequest +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig from google.adk.agents.invocation_context import new_invocation_context_id +from google.adk.agents.readonly_context import ReadonlyContext from google.adk.artifacts import InMemoryArtifactService from google.adk.events.event import Event from google.adk.events.event_actions import EventActions from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.models import LlmRequest from google.adk.runners import Runner from google.adk.sessions import InMemorySessionService -from google.adk.a2a.converters.request_converter import AgentRunRequest -from google.adk.a2a.executor.a2a_agent_executor import ( - A2aAgentExecutorConfig, - A2aAgentExecutor, -) -from a2ui.a2ui_extension import A2UI_EXTENSION_URI, get_a2ui_agent_extension, try_activate_a2ui_extension, A2UI_CLIENT_CAPABILITIES_KEY -from component_catalog_builder import ComponentCatalogBuilder -from a2a.types import AgentCapabilities, AgentCard, AgentSkill -from a2a.types import AgentExtension -from a2ui_session_util import A2UI_ENABLED_STATE_KEY, A2UI_CATALOG_URI_STATE_KEY, A2UI_SCHEMA_STATE_KEY -from agent import RIZZCHARTS_CATALOG_URI -from a2ui.a2ui_extension import STANDARD_CATALOG_ID - -from agent import rizzchartsAgent -import part_converter -from pathlib import Path logger = logging.getLogger(__name__) +_A2UI_ENABLED_KEY = "system:a2ui_enabled" +_A2UI_SCHEMA_KEY = "system:a2ui_schema" + class RizzchartsAgentExecutor(A2aAgentExecutor): """Contact AgentExecutor Example.""" def __init__(self, base_url: str): + """Initializes the RizzchartsAgentExecutor. + + Args: + base_url: The base URL for the agent. + """ self._base_url = base_url spec_root = Path(__file__).parent / "../../../../specification/0.8/json" @@ -60,7 +68,10 @@ def __init__(self, base_url: str): }, default_catalog_uri=STANDARD_CATALOG_ID ) - agent = rizzchartsAgent.build_agent() + agent = RizzchartsAgent( + a2ui_schema_provider=self.get_a2ui_schema, + a2ui_enabled_provider=self.get_a2ui_enabled, + ).build_agent() runner = Runner( app_name=agent.name, agent=agent, @@ -68,20 +79,24 @@ def __init__(self, base_url: str): session_service=InMemorySessionService(), memory_service=InMemoryMemoryService(), ) - self._part_converter = part_converter.A2uiPartConverter() config = A2aAgentExecutorConfig( - gen_ai_part_converter=self._part_converter.convert_genai_part_to_a2a_part + gen_ai_part_converter=convert_send_a2ui_to_client_genai_part_to_a2a_part ) super().__init__(runner=runner, config=config) def get_agent_card(self) -> AgentCard: + """Returns the AgentCard defining this agent's metadata and skills. + + Returns: + An AgentCard object. + """ return AgentCard( name="Ecommerce Dashboard Agent", description="This agent visualizes ecommerce data, showing sales breakdowns, YOY revenue performance, and regional sales outliers.", url=self._base_url, version="1.0.0", - default_input_modes=rizzchartsAgent.SUPPORTED_CONTENT_TYPES, - default_output_modes=rizzchartsAgent.SUPPORTED_CONTENT_TYPES, + default_input_modes=RizzchartsAgent.SUPPORTED_CONTENT_TYPES, + default_output_modes=RizzchartsAgent.SUPPORTED_CONTENT_TYPES, capabilities=AgentCapabilities( streaming=True, extensions=[get_a2ui_agent_extension( @@ -111,6 +126,28 @@ def get_agent_card(self) -> AgentCard: ], ) + def get_a2ui_schema(self, ctx: ReadonlyContext): + """Retrieves the A2UI schema from the session state. + + Args: + ctx: The ReadonlyContext for resolving the schema. + + Returns: + The A2UI schema or None if not found. + """ + return ctx.state.get(_A2UI_SCHEMA_KEY) + + def get_a2ui_enabled(self, ctx: ReadonlyContext): + """Checks if A2UI is enabled in the current session. + + Args: + ctx: The ReadonlyContext for resolving enablement. + + Returns: + True if A2UI is enabled, False otherwise. + """ + return ctx.state.get(_A2UI_ENABLED_KEY, False) + @override async def _prepare_session( self, @@ -127,10 +164,12 @@ async def _prepare_session( use_ui = try_activate_a2ui_extension(context) if use_ui: - a2ui_schema, catalog_uri = self._component_catalog_builder.load_a2ui_schema(client_ui_capabilities=context.message.metadata.get(A2UI_CLIENT_CAPABILITIES_KEY) if context.message and context.message.metadata else None) + a2ui_schema, catalog_uri = self._component_catalog_builder.load_a2ui_schema( + client_ui_capabilities=context.message.metadata.get(A2UI_CLIENT_CAPABILITIES_KEY) + if context.message and context.message.metadata + else None + ) - self._part_converter.set_a2ui_schema(a2ui_schema) - await runner.session_service.append_event( session, Event( @@ -138,8 +177,8 @@ async def _prepare_session( author="system", actions=EventActions( state_delta={ - A2UI_ENABLED_STATE_KEY: use_ui, - A2UI_SCHEMA_STATE_KEY: a2ui_schema, + _A2UI_ENABLED_KEY: True, + _A2UI_SCHEMA_KEY: a2ui_schema, A2UI_CATALOG_URI_STATE_KEY: catalog_uri, } ), diff --git a/samples/agent/adk/rizzcharts/part_converter.py b/samples/agent/adk/rizzcharts/part_converter.py deleted file mode 100644 index 448c5552..00000000 --- a/samples/agent/adk/rizzcharts/part_converter.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import jsonschema -import logging -from typing import Any, List - -from a2a import types as a2a_types -from google.genai import types as genai_types - -from google.adk.a2a.converters import part_converter -from a2ui.a2ui_extension import create_a2ui_part -from a2ui_toolset import SendA2uiJsonToClientTool - -logger = logging.getLogger(__name__) - -class A2uiPartConverter: - - def __init__(self): - self._a2ui_schema = None - - def set_a2ui_schema(self, a2ui_schema: dict[str, Any]): - self._a2ui_schema = a2ui_schema - - def convert_genai_part_to_a2a_part(self, part: genai_types.Part) -> List[a2a_types.Part]: - if (function_call := part.function_call) and function_call.name == SendA2uiJsonToClientTool.TOOL_NAME: - if self._a2ui_schema is None: - raise Exception("A2UI schema is not set in part converter") - - try: - a2ui_json = function_call.args.get(SendA2uiJsonToClientTool.A2UI_JSON_ARG_NAME) - if a2ui_json is None: - raise ValueError(f"Failed to convert A2UI function call because required arg {SendA2uiJsonToClientTool.A2UI_JSON_ARG_NAME} not found in {str(part)}") - if not a2ui_json.strip(): - logger.info("Empty a2ui_json, skipping") - return [] - - logger.info(f"Converting a2ui json: {a2ui_json}") - - json_data = json.loads(a2ui_json) - a2ui_schema_object = {"type": "array", "items": self._a2ui_schema} # Make a list since we support multiple parts in this tool call - jsonschema.validate( - instance=json_data, schema=a2ui_schema_object - ) - - final_parts = [] - if isinstance(json_data, list): - logger.info( f"Found {len(json_data)} messages. Creating individual DataParts." ) - for message in json_data: - final_parts.append(create_a2ui_part(message)) - else: - # Handle the case where a single JSON object is returned - logger.info("Received a single JSON object. Creating a DataPart." ) - final_parts.append(create_a2ui_part(json_data)) - - return final_parts - except Exception as e: - logger.error(f"Error converting A2UI function call to A2A parts: {str(e)}") - return [] - - # Don't send a2ui tool responses - elif (function_response := part.function_response) and function_response.name == SendA2uiJsonToClientTool.TOOL_NAME: - return [] - - # Use default part converter for other types (images, etc) - converted_part = part_converter.convert_genai_part_to_a2a_part(part) - - logger.info(f"Returning converted part: {converted_part}" ) - return [converted_part] if converted_part else []