From 15f91b4249581c5632d6a664f07eb5811b7dab50 Mon Sep 17 00:00:00 2001 From: Jon Sharkey Date: Tue, 13 Jan 2026 20:13:58 -0500 Subject: [PATCH] Support Rizzcharts example from internal Google infrastructure --- .../a2ui_extension/src/a2ui/a2ui_extension.py | 6 +- samples/agent/adk/rizzcharts/__main__.py | 48 ++++++++- samples/agent/adk/rizzcharts/agent.py | 87 +++++++++------ .../agent/adk/rizzcharts/agent_executor.py | 102 ++++++++---------- .../rizzcharts/component_catalog_builder.py | 49 ++++----- samples/agent/adk/rizzcharts/tools.py | 18 +++- 6 files changed, 188 insertions(+), 122 deletions(-) diff --git a/a2a_agents/python/a2ui_extension/src/a2ui/a2ui_extension.py b/a2a_agents/python/a2ui_extension/src/a2ui/a2ui_extension.py index 838b5dd2..1ade0231 100644 --- a/a2a_agents/python/a2ui_extension/src/a2ui/a2ui_extension.py +++ b/a2a_agents/python/a2ui_extension/src/a2ui/a2ui_extension.py @@ -123,7 +123,11 @@ def try_activate_a2ui_extension(context: RequestContext) -> bool: Returns: True if activated, False otherwise. """ - if A2UI_EXTENSION_URI in context.requested_extensions: + if A2UI_EXTENSION_URI in context.requested_extensions or ( + context.message + and context.message.extensions + and A2UI_EXTENSION_URI in context.message.extensions + ): context.add_activated_extension(A2UI_EXTENSION_URI) return True return False diff --git a/samples/agent/adk/rizzcharts/__main__.py b/samples/agent/adk/rizzcharts/__main__.py index bf85f625..ab459ee0 100644 --- a/samples/agent/adk/rizzcharts/__main__.py +++ b/samples/agent/adk/rizzcharts/__main__.py @@ -14,13 +14,20 @@ import logging import os +import pathlib import traceback import click from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore -from agent_executor import RizzchartsAgentExecutor +from agent_executor import RizzchartsAgentExecutor, get_a2ui_enabled, get_a2ui_schema +from agent import RizzchartsAgent +from google.adk.artifacts import InMemoryArtifactService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.models.lite_llm import LiteLlm +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService from dotenv import load_dotenv from starlette.middleware.cors import CORSMiddleware @@ -46,8 +53,45 @@ def main(host, port): "GEMINI_API_KEY environment variable not set and GOOGLE_GENAI_USE_VERTEXAI is not TRUE." ) + lite_llm_model = os.getenv("LITELLM_MODEL", "gemini/gemini-2.5-flash") + agent = RizzchartsAgent( + model=LiteLlm(model=lite_llm_model), + a2ui_enabled_provider=get_a2ui_enabled, + a2ui_schema_provider=get_a2ui_schema, + ) + runner = Runner( + app_name=agent.name, + agent=agent, + artifact_service=InMemoryArtifactService(), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + ) + + current_dir = pathlib.Path(__file__).resolve().parent + spec_root = current_dir / "../../../../specification/v0_8/json" + + try: + a2ui_schema_content = (spec_root / "server_to_client.json").read_text() + standard_catalog_content = ( + spec_root / "standard_catalog_definition.json" + ).read_text() + rizzcharts_catalog_content = ( + current_dir / "rizzcharts_catalog_definition.json" + ).read_text() + except FileNotFoundError as e: + logger.error(f"Failed to load required JSON files: {e}") + exit(1) + + logger.info(f"Loaded schema from {spec_root}") + base_url = f"http://{host}:{port}" - agent_executor = RizzchartsAgentExecutor(base_url=base_url) + agent_executor = RizzchartsAgentExecutor( + base_url=base_url, + runner=runner, + a2ui_schema_content=a2ui_schema_content, + standard_catalog_content=standard_catalog_content, + rizzcharts_catalog_content=rizzcharts_catalog_content, + ) request_handler = DefaultRequestHandler( agent_executor=agent_executor, diff --git a/samples/agent/adk/rizzcharts/agent.py b/samples/agent/adk/rizzcharts/agent.py index 892cbbda..00b81d6d 100644 --- a/samples/agent/adk/rizzcharts/agent.py +++ b/samples/agent/adk/rizzcharts/agent.py @@ -14,39 +14,67 @@ import json import logging -import os from pathlib import Path -from typing import Any - -import jsonschema +import pkgutil +from typing import Any, ClassVar from a2ui.a2ui_extension import STANDARD_CATALOG_ID from a2ui.a2ui_schema_utils import wrap_as_json_array from a2ui.send_a2ui_to_client_toolset import SendA2uiToClientToolset, A2uiEnabledProvider, A2uiSchemaProvider from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.readonly_context import ReadonlyContext -from google.adk.models.lite_llm import LiteLlm from google.adk.planners.built_in_planner import BuiltInPlanner from google.genai import types -from tools import get_store_sales, get_sales_data +import jsonschema +from pydantic import PrivateAttr + +try: + from .tools import get_sales_data, get_store_sales +except ImportError: + from tools import get_sales_data, get_store_sales logger = logging.getLogger(__name__) RIZZCHARTS_CATALOG_URI = "https://github.com/google/A2UI/blob/main/samples/agent/adk/rizzcharts/rizzcharts_catalog_definition.json" A2UI_CATALOG_URI_STATE_KEY = "user:a2ui_catalog_uri" -class RizzchartsAgent: +class RizzchartsAgent(LlmAgent): """An agent that runs an ecommerce dashboard""" - SUPPORTED_CONTENT_TYPES = ["text", "text/plain"] - - def __init__(self, a2ui_enabled_provider: A2uiEnabledProvider, a2ui_schema_provider: A2uiSchemaProvider): + SUPPORTED_CONTENT_TYPES: ClassVar[list[str]] = ["text", "text/plain"] + _a2ui_enabled_provider: A2uiEnabledProvider = PrivateAttr() + _a2ui_schema_provider: A2uiSchemaProvider = PrivateAttr() + + def __init__( + self, + model: Any, + a2ui_enabled_provider: A2uiEnabledProvider, + a2ui_schema_provider: A2uiSchemaProvider + ): """Initializes the RizzchartsAgent. Args: + model: The LLM model to use. a2ui_enabled_provider: A provider to check if A2UI is enabled. a2ui_schema_provider: A provider to retrieve the A2UI schema. """ + super().__init__( + model=model, + name="rizzcharts_agent", + description="An agent that lets sales managers request sales data.", + instruction=self.get_instructions, + tools=[get_store_sales, get_sales_data, SendA2uiToClientToolset( + a2ui_schema=a2ui_schema_provider, + a2ui_enabled=a2ui_enabled_provider, + )], + planner=BuiltInPlanner( + thinking_config=types.ThinkingConfig( + include_thoughts=True, + ) + ), + disallow_transfer_to_peers=True, + ) + self._a2ui_enabled_provider = a2ui_enabled_provider self._a2ui_schema_provider = a2ui_schema_provider @@ -72,8 +100,21 @@ def load_example(self, path: str, a2ui_schema: dict[str, Any]) -> dict[str, Any] Returns: The loaded and validated JSON data. """ - full_path = Path(__file__).parent / path - example_str = full_path.read_text() + data = None + try: + # Try pkgutil first (for Google3) + package_name = __package__ or "" + data = pkgutil.get_data(package_name, path) + except ImportError: + logger.info("pkgutil failed to get data, falling back to file system.") + + if data: + example_str = data.decode("utf-8") + else: + # Fallback to direct Path relative to this file (for local dev) + full_path = Path(__file__).parent / path + example_str = full_path.read_text() + example_json = json.loads(example_str) jsonschema.validate( instance=example_json, schema=a2ui_schema @@ -162,25 +203,3 @@ def get_instructions(self, readonly_context: ReadonlyContext) -> str: logger.info(f"Generated system instructions for A2UI {'ENABLED' if use_ui else 'DISABLED'} and catalog {catalog_uri}") return final_prompt - - def build_agent(self) -> LlmAgent: - """Builds the LLM agent for the rizzchartsAgent agent.""" - LITELLM_MODEL = os.getenv("LITELLM_MODEL", "gemini/gemini-2.5-flash") - - return LlmAgent( - model=LiteLlm(model=LITELLM_MODEL), - name="rizzcharts_agent", - description="An agent that lets sales managers request sales data.", - instruction=self.get_instructions, - tools=[get_store_sales, get_sales_data, SendA2uiToClientToolset( - a2ui_schema=self._a2ui_schema_provider, - a2ui_enabled=self._a2ui_enabled_provider, - )], - planner=BuiltInPlanner( - thinking_config=types.ThinkingConfig( - include_thoughts=True, - ) - ), - disallow_transfer_to_peers=True, - - ) diff --git a/samples/agent/adk/rizzcharts/agent_executor.py b/samples/agent/adk/rizzcharts/agent_executor.py index 81ef27e6..0c5e5ce5 100644 --- a/samples/agent/adk/rizzcharts/agent_executor.py +++ b/samples/agent/adk/rizzcharts/agent_executor.py @@ -24,61 +24,73 @@ from a2ui.a2ui_extension import get_a2ui_agent_extension from a2ui.a2ui_extension import try_activate_a2ui_extension from a2ui.send_a2ui_to_client_toolset import convert_send_a2ui_to_client_genai_part_to_a2a_part -from agent import A2UI_CATALOG_URI_STATE_KEY -from agent import RIZZCHARTS_CATALOG_URI -from agent import RizzchartsAgent -from component_catalog_builder import ComponentCatalogBuilder +try: + from .agent import A2UI_CATALOG_URI_STATE_KEY # pylint: disable=import-error + from .agent import RIZZCHARTS_CATALOG_URI # pylint: disable=import-error + from .agent import RizzchartsAgent # pylint: disable=import-error + from .component_catalog_builder import ComponentCatalogBuilder # pylint: disable=import-error +except ImportError: + from agent import A2UI_CATALOG_URI_STATE_KEY + from agent import RIZZCHARTS_CATALOG_URI + from agent import RizzchartsAgent + from component_catalog_builder import ComponentCatalogBuilder from google.adk.a2a.converters.request_converter import AgentRunRequest from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig from google.adk.agents.invocation_context import new_invocation_context_id from google.adk.agents.readonly_context import ReadonlyContext -from google.adk.artifacts import InMemoryArtifactService from google.adk.events.event import Event from google.adk.events.event_actions import EventActions -from google.adk.memory.in_memory_memory_service import InMemoryMemoryService -from google.adk.models import LlmRequest from google.adk.runners import Runner -from google.adk.sessions import InMemorySessionService logger = logging.getLogger(__name__) _A2UI_ENABLED_KEY = "system:a2ui_enabled" _A2UI_SCHEMA_KEY = "system:a2ui_schema" +def get_a2ui_schema(ctx: ReadonlyContext): + """Retrieves the A2UI schema from the session state. -class RizzchartsAgentExecutor(A2aAgentExecutor): - """Contact AgentExecutor Example.""" + Args: + ctx: The ReadonlyContext for resolving the schema. - def __init__(self, base_url: str): - """Initializes the RizzchartsAgentExecutor. + Returns: + The A2UI schema or None if not found. + """ + return ctx.state.get(_A2UI_SCHEMA_KEY) - Args: - base_url: The base URL for the agent. - """ - self._base_url = base_url +def get_a2ui_enabled(ctx: ReadonlyContext): + """Checks if A2UI is enabled in the current session. + + Args: + ctx: The ReadonlyContext for resolving enablement. + + Returns: + True if A2UI is enabled, False otherwise. + """ + return ctx.state.get(_A2UI_ENABLED_KEY, False) + +class RizzchartsAgentExecutor(A2aAgentExecutor): + """Executor for the Rizzcharts agent that handles A2UI session setup.""" - spec_root = Path(__file__).parent / "../../../../specification/v0_8/json" - + def __init__( + self, + base_url: str, + runner: Runner, + a2ui_schema_content: str, + standard_catalog_content: str, + rizzcharts_catalog_content: str, + ): + self._base_url = base_url self._component_catalog_builder = ComponentCatalogBuilder( - a2ui_schema_path=str(spec_root.joinpath("server_to_client.json")), - uri_to_local_catalog_path={ - STANDARD_CATALOG_ID: str(spec_root.joinpath("standard_catalog_definition.json")), - RIZZCHARTS_CATALOG_URI: "rizzcharts_catalog_definition.json", + a2ui_schema_content=a2ui_schema_content, + uri_to_local_catalog_content={ + STANDARD_CATALOG_ID: standard_catalog_content, + RIZZCHARTS_CATALOG_URI: rizzcharts_catalog_content, }, - default_catalog_uri=STANDARD_CATALOG_ID - ) - agent = RizzchartsAgent( - a2ui_schema_provider=self.get_a2ui_schema, - a2ui_enabled_provider=self.get_a2ui_enabled, - ).build_agent() - runner = Runner( - app_name=agent.name, - agent=agent, - artifact_service=InMemoryArtifactService(), - session_service=InMemorySessionService(), - memory_service=InMemoryMemoryService(), + default_catalog_uri=STANDARD_CATALOG_ID, ) + config = A2aAgentExecutorConfig( gen_ai_part_converter=convert_send_a2ui_to_client_genai_part_to_a2a_part ) @@ -126,28 +138,6 @@ def get_agent_card(self) -> AgentCard: ], ) - def get_a2ui_schema(self, ctx: ReadonlyContext): - """Retrieves the A2UI schema from the session state. - - Args: - ctx: The ReadonlyContext for resolving the schema. - - Returns: - The A2UI schema or None if not found. - """ - return ctx.state.get(_A2UI_SCHEMA_KEY) - - def get_a2ui_enabled(self, ctx: ReadonlyContext): - """Checks if A2UI is enabled in the current session. - - Args: - ctx: The ReadonlyContext for resolving enablement. - - Returns: - True if A2UI is enabled, False otherwise. - """ - return ctx.state.get(_A2UI_ENABLED_KEY, False) - @override async def _prepare_session( self, diff --git a/samples/agent/adk/rizzcharts/component_catalog_builder.py b/samples/agent/adk/rizzcharts/component_catalog_builder.py index af0d95ee..212a8f4a 100644 --- a/samples/agent/adk/rizzcharts/component_catalog_builder.py +++ b/samples/agent/adk/rizzcharts/component_catalog_builder.py @@ -12,26 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import cache -from typing import Any, List, Optional -from pathlib import Path import json import logging -from agent import RIZZCHARTS_CATALOG_URI -from a2ui.a2ui_extension import STANDARD_CATALOG_ID, SUPPORTED_CATALOG_IDS_KEY, INLINE_CATALOGS_KEY +from typing import Any, List, Optional +from a2ui.a2ui_extension import INLINE_CATALOGS_KEY, SUPPORTED_CATALOG_IDS_KEY +try: + from .agent import RIZZCHARTS_CATALOG_URI, STANDARD_CATALOG_ID +except ImportError: + from agent import RIZZCHARTS_CATALOG_URI, STANDARD_CATALOG_ID + logger = logging.getLogger(__name__) class ComponentCatalogBuilder: - def __init__(self, a2ui_schema_path: str, uri_to_local_catalog_path: dict[str, str], default_catalog_uri: Optional[str]): - self._a2ui_schema_path = a2ui_schema_path - self._uri_to_local_catalog_path = uri_to_local_catalog_path + def __init__(self, + a2ui_schema_content: str, + uri_to_local_catalog_content: dict[str, str], + default_catalog_uri: Optional[str], + ): + self._a2ui_schema_content = a2ui_schema_content + self._uri_to_local_catalog_content = uri_to_local_catalog_content self._default_catalog_uri = default_catalog_uri - pass - - @cache - def get_file_content(self, path: str) -> str: - return Path(path).read_text() def load_a2ui_schema(self, client_ui_capabilities: Optional[dict[str, Any]]) -> tuple[dict[str, Any], Optional[str]]: """ @@ -50,20 +51,19 @@ def load_a2ui_schema(self, client_ui_capabilities: Optional[dict[str, Any]]) -> else: catalog_uri = None - inline_catalog_str = client_ui_capabilities.get(INLINE_CATALOGS_KEY) + inline_catalog_str = client_ui_capabilities.get(INLINE_CATALOGS_KEY) elif self._default_catalog_uri: logger.info(f"Using default catalog {self._default_catalog_uri} since client UI capabilities not found") catalog_uri = self._default_catalog_uri inline_catalog_str = None else: - raise ValueError("Client UI capabilities not provided") + raise ValueError("Client UI capabilities not provided") if catalog_uri and inline_catalog_str: raise ValueError(f"Cannot set both {SUPPORTED_CATALOG_IDS_KEY} and {INLINE_CATALOGS_KEY} in ClientUiCapabilities: {client_ui_capabilities}") elif catalog_uri: - if local_path := self._uri_to_local_catalog_path.get(catalog_uri): - logger.info(f"Loading local component catalog with uri {catalog_uri} and local path {local_path}") - catalog_str = self.get_file_content(local_path) + if catalog_str := self._uri_to_local_catalog_content.get(catalog_uri): + logger.info(f"Loading local component catalog with uri {catalog_uri}") catalog_json = json.loads(catalog_str) else: raise ValueError(f"Local component catalog with URI {catalog_uri} not found") @@ -73,16 +73,13 @@ def load_a2ui_schema(self, client_ui_capabilities: Optional[dict[str, Any]]) -> else: raise ValueError("No supported catalogs found in client UI capabilities") - logger.info(f"Loading A2UI schema at {self._a2ui_schema_path}") - a2ui_schema = self.get_file_content(self._a2ui_schema_path) - a2ui_schema_json = json.loads(a2ui_schema) + logger.info("Loading A2UI schema") + a2ui_schema_json = json.loads(self._a2ui_schema_content) a2ui_schema_json["properties"]["surfaceUpdate"]["properties"]["components"]["items"]["properties"]["component"]["properties"] = catalog_json - + return a2ui_schema_json, catalog_uri - + except Exception as e: logger.error(f"Failed to a2ui schema with client ui capabilities {client_ui_capabilities}: {e}") - raise e - - \ No newline at end of file + raise e diff --git a/samples/agent/adk/rizzcharts/tools.py b/samples/agent/adk/rizzcharts/tools.py index 9c31b62a..61424872 100644 --- a/samples/agent/adk/rizzcharts/tools.py +++ b/samples/agent/adk/rizzcharts/tools.py @@ -12,19 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any import logging +from typing import Any logger = logging.getLogger(__name__) -def get_store_sales() -> dict[str, Any]: +def get_store_sales(region: str = "all", **kwargs: Any) -> dict[str, Any]: """ Gets individual store sales + Args: + region: The region to get store sales for. + **kwargs: Additional arguments. + Returns: A dict containing the stores with locations and their sales, and with outlier stores highlighted """ + logger.info("get_store_sales called with region=%s, kwargs=%s", region, kwargs) return { "center": {"lat": 34, "lng": -118.2437}, @@ -49,13 +54,20 @@ def get_store_sales() -> dict[str, Any]: } -def get_sales_data() -> dict[str, Any]: +def get_sales_data(time_period: str = "year", **kwargs: Any) -> dict[str, Any]: """ Gets the sales data. + Args: + time_period: The time period to get sales data for (e.g. 'Q1', 'year'). Defaults to 'year'. + **kwargs: Additional arguments. + Returns: A dict containing the sales breakdown by product category. """ + logger.info( + "get_sales_data called with time_period=%s, kwargs=%s", time_period, kwargs + ) return { "sales_data": [