diff --git a/py/packages/genkit/pyproject.toml b/py/packages/genkit/pyproject.toml index e0a4a271ea..8074700839 100644 --- a/py/packages/genkit/pyproject.toml +++ b/py/packages/genkit/pyproject.toml @@ -56,6 +56,7 @@ dependencies = [ "uvloop>=0.21.0; sys_platform != 'win32'", "anyio>=4.9.0", "opentelemetry-instrumentation-logging>=0.60b1", + "websockets>=15.0", ] description = "Genkit AI Framework" keywords = [ diff --git a/py/packages/genkit/src/genkit/ai/_base_async.py b/py/packages/genkit/src/genkit/ai/_base_async.py index 8f849c65b2..1f549beb14 100644 --- a/py/packages/genkit/src/genkit/ai/_base_async.py +++ b/py/packages/genkit/src/genkit/ai/_base_async.py @@ -30,7 +30,7 @@ from genkit.core.environment import is_dev_environment from genkit.core.logging import get_logger from genkit.core.plugin import Plugin -from genkit.core.reflection import create_reflection_asgi_app +from genkit.core.reflection_v1 import create_reflection_asgi_app from genkit.core.registry import Registry from genkit.web.manager._ports import find_free_port_sync diff --git a/py/packages/genkit/src/genkit/core/constants.py b/py/packages/genkit/src/genkit/core/constants.py index 77cbb4c17d..a76e21b14f 100644 --- a/py/packages/genkit/src/genkit/core/constants.py +++ b/py/packages/genkit/src/genkit/core/constants.py @@ -23,3 +23,7 @@ GENKIT_VERSION = DEFAULT_GENKIT_VERSION GENKIT_CLIENT_HEADER = f'genkit-python/{DEFAULT_GENKIT_VERSION}' + +# Reflection API specification version. +# This should match the value in JS (genkit-tools). +GENKIT_REFLECTION_API_SPEC_VERSION = 1 diff --git a/py/packages/genkit/src/genkit/core/reflection.py b/py/packages/genkit/src/genkit/core/reflection.py index b7604a6ce7..2e0839d148 100644 --- a/py/packages/genkit/src/genkit/core/reflection.py +++ b/py/packages/genkit/src/genkit/core/reflection.py @@ -14,573 +14,726 @@ # # SPDX-License-Identifier: Apache-2.0 -"""Development API for inspecting and interacting with Genkit. - -This module provides a reflection API server for inspection and interaction -during development. It exposes endpoints for health checks, action discovery, -and action execution. - -## Caveats - -The reflection API server predates the flows server implementation and differs -in the protocol it uses to interface with the Dev UI. The streaming protocol -uses unadorned JSON per streamed chunk. This may change in the future to use -Server-Sent Events (SSE). - -## Key endpoints - - | Method | Path | Handler | - |--------|---------------------|-----------------------| - | GET | /api/__health | Health check | - | GET | /api/actions | List actions | - | POST | /api/__quitquitquit | Trigger shutdown | - | POST | /api/notify | Handle notification | - | POST | /api/runAction | Run action (streaming)| +"""Reflection API v2 client using WebSocket and JSON-RPC 2.0. + +This module implements a WebSocket-based client that connects to a Genkit +runtime manager server. Unlike v1 which starts an HTTP server, v2 acts as +a client connecting to a centralized manager. + +Key Concepts (ELI5):: + + ┌─────────────────────┬────────────────────────────────────────────────┐ + │ Concept │ ELI5 Explanation │ + ├─────────────────────┼────────────────────────────────────────────────┤ + │ Reflection API v1 │ Genkit starts a server, tools connect to it. │ + │ │ Like opening a shop and waiting for customers. │ + ├─────────────────────┼────────────────────────────────────────────────┤ + │ Reflection API v2 │ Genkit connects to a manager as a client. │ + │ │ Like calling the headquarters to report in. │ + ├─────────────────────┼────────────────────────────────────────────────┤ + │ JSON-RPC 2.0 │ A simple protocol for remote procedure calls. │ + │ │ Like a structured phone conversation. │ + ├─────────────────────┼────────────────────────────────────────────────┤ + │ WebSocket │ A persistent two-way connection. │ + │ │ Like keeping a phone line open. │ + └─────────────────────┴────────────────────────────────────────────────┘ + +Architecture Comparison:: + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ Reflection API V1 (HTTP Server) │ + └─────────────────────────────────────────────────────────────────────────┘ + + Genkit Runtime starts an HTTP server; CLI/DevUI connect to it: + + ┌─────────────────────┐ ┌─────────────────────┐ + │ Genkit CLI │ │ Dev UI │ + │ (Client) │ │ (Client) │ + └──────────┬──────────┘ └──────────┬──────────┘ + │ │ + │ HTTP Requests │ + │ (GET /api/actions, etc) │ + │ │ + ▼ ▼ + ┌───────────────────────────────────────────────────────┐ + │ Genkit Runtime │ + │ ┌────────────────────┐ │ + │ │ HTTP Server │ │ + │ │ (port 3100) │ │ + │ └────────────────────┘ │ + │ ┌────────────────────┐ │ + │ │ Registry │ │ + │ │ (Actions, Flows) │ │ + │ └────────────────────┘ │ + └───────────────────────────────────────────────────────┘ + + Discovery: Runtime writes file to ~/.genkit/{runtimeId}.runtime.json + Connection: CLI reads file, finds port, connects via HTTP + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ Reflection API V2 (WebSocket) │ + └─────────────────────────────────────────────────────────────────────────┘ + + CLI acts as WebSocket server; Genkit Runtimes connect as clients: + + ┌───────────────────────────────────────────────────────┐ + │ Runtime Manager │ + │ (CLI WebSocket Server) │ + │ ┌────────────────────┐ │ + │ │ WebSocket Server │ │ + │ │ (port 4100) │ │ + │ └────────────────────┘ │ + │ ┌────────────────────┐ │ + │ │ Dev UI │ │ + │ └────────────────────┘ │ + └───────────────────────────────────────────────────────┘ + ▲ ▲ ▲ + │ │ │ + WebSocket │ │ │ WebSocket + Connect │ │ │ Connect + │ │ │ + ┌───────────────┴───┐ ┌─────┴─────┐ ┌───┴───────────────┐ + │ Genkit Runtime │ │ Runtime │ │ Genkit Runtime │ + │ (Python app) │ │ (JS app) │ │ (Go app) │ + └───────────────────┘ └───────────┘ └───────────────────┘ + + Discovery: Runtime reads GENKIT_REFLECTION_V2_SERVER env var + Connection: Runtime connects outbound to Manager via WebSocket + +Data Flow (V2):: + + Genkit Runtime Runtime Manager Server + │ │ + │ ──── WebSocket Connect ────► │ + │ │ + │ ──── register (JSON-RPC) ────► │ + │ │ + │ ◄──── configure notification ── │ + │ │ + │ ◄──── listActions request ──── │ + │ ──── response with actions ────► │ + │ │ + │ ◄──── runAction request ──── │ + │ ──── runActionState notif ────► │ (sends traceId early) + │ ──── streamChunk notification ──► │ (if streaming) + │ ──── response with result ────► │ + │ │ + │ ◄──── cancelAction request ──── │ + │ ──── response (cancelled) ────► │ + │ │ + +Protocol Methods (V2):: + + ┌──────────────────┬─────────────────┬─────────┬─────────────────────────┐ + │ Method │ Direction │ Type │ Description │ + ├──────────────────┼─────────────────┼─────────┼─────────────────────────┤ + │ register │ Runtime→Manager │ Notif │ Register runtime info │ + │ configure │ Manager→Runtime │ Notif │ Push config (telemetry) │ + │ listActions │ Manager→Runtime │ Request │ List available actions │ + │ listValues │ Manager→Runtime │ Request │ List values by type │ + │ runAction │ Manager→Runtime │ Request │ Execute an action │ + │ runActionState │ Runtime→Manager │ Notif │ Send traceId early │ + │ streamChunk │ Runtime→Manager │ Notif │ Stream output chunk │ + │ cancelAction │ Manager→Runtime │ Request │ Cancel running action │ + └──────────────────┴─────────────────┴─────────┴─────────────────────────┘ + +Environment Variables: + GENKIT_REFLECTION_V2_SERVER: WebSocket URL to connect to (e.g., ws://localhost:4100) + GENKIT_TELEMETRY_SERVER: Optional telemetry server URL + +Example: + >>> import asyncio + >>> from genkit.core.reflection import ReflectionClientV2 + >>> async def main(): + ... client = ReflectionClientV2(registry, 'ws://localhost:4100') + ... await client.run() + +See Also: + - RFC: https://github.com/firebase/genkit/pull/4211 + - V1 HTTP server implementation: genkit.core.reflection_v1 """ from __future__ import annotations import asyncio import json -from collections.abc import AsyncGenerator, Callable -from typing import Any, cast - -from starlette.applications import Starlette -from starlette.middleware import Middleware -from starlette.middleware.cors import CORSMiddleware -from starlette.requests import Request -from starlette.responses import JSONResponse, StreamingResponse -from starlette.routing import Route - -from genkit.codec import dump_dict, dump_json -from genkit.core.action import Action +import os +import traceback +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import websockets + +from genkit.codec import dump_dict from genkit.core.action.types import ActionKind -from genkit.core.constants import DEFAULT_GENKIT_VERSION +from genkit.core.constants import ( + DEFAULT_GENKIT_VERSION, + GENKIT_REFLECTION_API_SPEC_VERSION, +) from genkit.core.error import get_reflection_json from genkit.core.logging import get_logger -from genkit.core.registry import Registry -from genkit.web.manager.signals import terminate_all_servers -from genkit.web.requests import ( - is_streaming_requested, -) -from genkit.web.typing import ( - Application, - StartupHandler, -) + +if TYPE_CHECKING: + from genkit.core.action import Action + from genkit.core.registry import Registry logger = get_logger(__name__) +# Environment variable for v2 server URL +GENKIT_REFLECTION_V2_SERVER_ENV = 'GENKIT_REFLECTION_V2_SERVER' -async def _list_registered_actions(registry: Registry) -> dict[str, Action]: - """Return all locally registered actions keyed as `//`. - Uses resolve_actions_by_kind() to trigger lazy loading for any actions with - deferred metadata (e.g., file-based prompts), ensuring schemas are available - for the Dev UI. - """ - registered: dict[str, Action] = {} - for kind in ActionKind.__members__.values(): - for name, action in (await registry.resolve_actions_by_kind(kind)).items(): - registered[f'/{kind.value}/{name}'] = action - return registered - - -def _build_actions_payload( - *, - registered_actions: dict[str, Action], - plugin_metas: list[Any], -) -> dict[str, dict[str, Any]]: - """Build payload for GET /api/actions.""" - actions: dict[str, dict[str, Any]] = {} - - # 1) Registered actions (flows/tools/etc). - for key, action in registered_actions.items(): - actions[key] = { - 'key': key, - 'name': action.name, - 'type': action.kind.value, - 'description': action.description, - 'inputSchema': action.input_schema, - 'outputSchema': action.output_schema, - 'metadata': action.metadata, - } +@dataclass +class JsonRpcRequest: + """JSON-RPC 2.0 request or notification.""" - # 2) Plugin-advertised actions (may not be registered yet). - for meta in plugin_metas or []: - try: - key = f'/{meta.kind.value}/{meta.name}' - except Exception as exc: - # Defensive: skip unexpected plugin metadata objects. - logger.warning('Skipping invalid plugin action metadata', error=str(exc)) - continue + jsonrpc: str = '2.0' + method: str = '' + params: dict[str, Any] | list[Any] | None = None + id: int | str | None = None - advertised = { - 'key': key, - 'name': meta.name, - 'type': meta.kind.value, - 'description': getattr(meta, 'description', None), - 'inputSchema': getattr(meta, 'input_json_schema', None), - 'outputSchema': getattr(meta, 'output_json_schema', None), - 'metadata': getattr(meta, 'metadata', None), - } + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result: dict[str, Any] = {'jsonrpc': self.jsonrpc, 'method': self.method} + if self.params is not None: + result['params'] = self.params + if self.id is not None: + result['id'] = self.id + return result - if key not in actions: - actions[key] = advertised - continue - # Merge into the existing (registered) action entry; prefer registered data. - existing = actions[key] +@dataclass +class JsonRpcError: + """JSON-RPC 2.0 error object.""" - if not existing.get('description') and advertised.get('description'): - existing['description'] = advertised['description'] + code: int + message: str + data: Any = None - if not existing.get('inputSchema') and advertised.get('inputSchema'): - existing['inputSchema'] = advertised['inputSchema'] + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + result: dict[str, Any] = {'code': self.code, 'message': self.message} + if self.data is not None: + result['data'] = self.data + return result - if not existing.get('outputSchema') and advertised.get('outputSchema'): - existing['outputSchema'] = advertised['outputSchema'] - existing_meta = existing.get('metadata') or {} - advertised_meta = advertised.get('metadata') or {} - if isinstance(existing_meta, dict) and isinstance(advertised_meta, dict): - # Prefer registered action metadata on key conflicts. - existing['metadata'] = {**advertised_meta, **existing_meta} +@dataclass +class JsonRpcResponse: + """JSON-RPC 2.0 response.""" - return actions + jsonrpc: str = '2.0' + result: Any = None + error: JsonRpcError | None = None + id: int | str | None = None + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + response: dict[str, Any] = {'jsonrpc': self.jsonrpc, 'id': self.id} + if self.error is not None: + response['error'] = self.error.to_dict() + else: + response['result'] = self.result + return response -def create_reflection_asgi_app( - registry: Registry, - on_app_startup: StartupHandler | None = None, - on_app_shutdown: StartupHandler | None = None, - version: str = DEFAULT_GENKIT_VERSION, - _encoding: str = 'utf-8', -) -> Application: - """Create and return a ASGI application for the Genkit reflection API. - Caveats: +@dataclass +class ActiveAction: + """Represents an in-flight action that can be cancelled.""" - The reflection API server predates the flows server implementation and - differs in the protocol it uses to interface with the Dev UI. The - streaming protocol uses unadorned JSON per streamed chunk. This may - change in the future to use Server-Sent Events (SSE). + cancel: Callable[[], None] + start_time: float + trace_id: str + task: asyncio.Task[Any] | None = None - Key endpoints: - | Method | Path | Handler | - |--------|---------------------|-----------------------| - | GET | /api/__health | Health check | - | GET | /api/actions | List actions | - | POST | /api/__quitquitquit | Trigger shutdown | - | POST | /api/notify | Handle notification | - | POST | /api/runAction | Run action (streaming)| +@dataclass +class ActiveActionsMap: + """Thread-safe map of active actions.""" - Args: - registry: The registry to use for the reflection server. - on_app_startup: Optional callback to execute when the app's - lifespan starts. Must be an async function. - on_app_shutdown: Optional callback to execute when the app's - lifespan ends. Must be an async function. - version: The version string to use when setting the value of - the X-GENKIT-VERSION HTTP header. - encoding: The text encoding to use; default 'utf-8'. + _actions: dict[str, ActiveAction] = field(default_factory=dict) + _lock: asyncio.Lock = field(default_factory=asyncio.Lock) - Returns: - An ASGI application configured with the given registry. - """ + async def set(self, trace_id: str, action: ActiveAction) -> None: + """Add an active action.""" + async with self._lock: + self._actions[trace_id] = action - async def handle_health_check(_request: Request) -> JSONResponse: - """Handle health check requests. + async def get(self, trace_id: str) -> ActiveAction | None: + """Get an active action by trace ID.""" + async with self._lock: + return self._actions.get(trace_id) - Args: - _request: The Starlette request object (unused). + async def delete(self, trace_id: str) -> None: + """Remove an active action.""" + async with self._lock: + self._actions.pop(trace_id, None) - Returns: - A JSON response with status code 200. - """ - return JSONResponse(content={'status': 'OK'}) - async def handle_terminate(_request: Request) -> JSONResponse: - """Handle the quit endpoint. +class ReflectionClientV2: + """Reflection API v2 client using WebSocket and JSON-RPC 2.0. - Args: - _request: The Starlette request object (unused). + This client connects to a Genkit runtime manager server and handles + requests for listing actions, running actions, and other reflection + operations. - Returns: - An empty JSON response with status code 200. - """ - await logger.ainfo('Shutting down servers...') - terminate_all_servers() - return JSONResponse(content={'status': 'OK'}) + Attributes: + registry: The Genkit registry containing actions and values. + url: The WebSocket URL to connect to. + active_actions: Map of currently running actions for cancellation. + """ - async def handle_list_actions(_request: Request) -> JSONResponse: - """Handle the request for listing available actions. + def __init__( + self, + registry: Registry, + url: str, + *, + version: str = DEFAULT_GENKIT_VERSION, + configured_envs: list[str] | None = None, + ) -> None: + """Initialize the Reflection v2 client. Args: - _request: The Starlette request object (unused). - - Returns: - A JSON response containing all serializable actions. + registry: The Genkit registry. + url: WebSocket URL to connect to. + version: Genkit version string. + configured_envs: List of configured environments. + """ + self._registry = registry + self._url = url + self._version = version + self._configured_envs = configured_envs or ['dev'] + self._ws: Any = None # WebSocket connection + self._active_actions = ActiveActionsMap() + self._running = False + self._reconnect_delay = 1.0 # seconds + self._max_reconnect_delay = 60.0 # maximum delay for exponential backoff + + @property + def runtime_id(self) -> str: + """Generate a unique runtime ID based on process ID.""" + return str(os.getpid()) + + async def run(self) -> None: + """Run the reflection client with automatic reconnection. + + This method will continuously try to connect to the server and + handle messages. If the connection drops, it will attempt to + reconnect after a delay. """ - registered = await _list_registered_actions(registry) - metas = await registry.list_actions() - actions = _build_actions_payload(registered_actions=registered, plugin_metas=metas) - - return JSONResponse( - content=actions, - status_code=200, - headers={'x-genkit-version': version}, + self._running = True + logger.info(f'Connecting to Reflection v2 server: {self._url}') + + while self._running: + try: + async with websockets.connect(self._url) as ws: + self._ws = ws + self._reconnect_delay = 1.0 # Reset delay on successful connection + logger.info('Connected to Reflection v2 server') + + # Register immediately upon connection + await self._register() + + # Handle messages + async for message in ws: + if isinstance(message, bytes): + message = message.decode('utf-8') + asyncio.create_task(self._handle_message(message)) + + except asyncio.CancelledError: + logger.debug('Reflection v2 client cancelled') + break + except Exception as e: + delay = self._reconnect_delay + logger.debug(f'Failed to connect to Reflection v2 server, retrying in {delay:.1f}s: {e}') + self._ws = None + await asyncio.sleep(self._reconnect_delay) + self._reconnect_delay = min(self._reconnect_delay * 2, self._max_reconnect_delay) + + self._running = False + logger.info('Disconnected from Reflection v2 server') + + async def stop(self) -> None: + """Stop the reflection client.""" + self._running = False + if self._ws: + await self._ws.close() + + async def _register(self) -> None: + """Send registration message to the server.""" + request = JsonRpcRequest( + method='register', + params={ + 'id': self.runtime_id, + 'name': self.runtime_id, + 'pid': os.getpid(), + 'genkitVersion': f'python/{self._version}', + 'reflectionApiSpecVersion': GENKIT_REFLECTION_API_SPEC_VERSION, + 'envs': self._configured_envs, + }, ) + await self._send(request.to_dict()) - async def handle_list_values(request: Request) -> JSONResponse: - """Handle the request for listing registered values. + async def _send(self, message: dict[str, Any]) -> None: + """Send a message to the server.""" + if self._ws is None: + raise RuntimeError('WebSocket not connected') - Args: - request: The Starlette request object. + data = json.dumps(message) + logger.debug(f'Sending v2 message: {data}') + await self._ws.send(data) - Returns: - A JSON response containing value names. - """ - kind = request.query_params.get('type') - if not kind: - return JSONResponse(content='Query parameter "type" is required.', status_code=400) + async def _handle_message(self, data: str) -> None: + """Handle an incoming message from the server.""" + logger.debug(f'Received v2 message: {data}') - if kind != 'defaultModel': - return JSONResponse( - content=f"'type' {kind} is not supported. Only 'defaultModel' is supported", status_code=400 - ) + try: + msg = json.loads(data) + except json.JSONDecodeError as e: + logger.error(f'Failed to parse JSON-RPC message: {e}') + return + + method = msg.get('method', '') + msg_id = msg.get('id') + + if method: + if msg_id is not None: + # Request (has ID, expects response) + await self._handle_request(msg) + else: + # Notification (no ID, no response expected) + await self._handle_notification(msg) + elif msg_id is not None: + # Response to a request we sent + logger.debug(f'Received response for id={msg_id}') - values = registry.list_values(kind) - return JSONResponse(content=values, status_code=200) + async def _handle_request(self, msg: dict[str, Any]) -> None: + """Handle an incoming JSON-RPC request.""" + method = msg.get('method', '') + params = msg.get('params', {}) + msg_id = msg.get('id') - async def handle_list_envs(_request: Request) -> JSONResponse: - """Handle the request for listing environments. + result: Any = None + error: JsonRpcError | None = None - Args: - _request: The Starlette request object (unused). + try: + if method == 'listActions': + result = await self._handle_list_actions() + elif method == 'listValues': + result = await self._handle_list_values(params) + elif method == 'runAction': + # runAction handles its own response + await self._handle_run_action(msg) + return + elif method == 'cancelAction': + result, error = await self._handle_cancel_action(params) + else: + error = JsonRpcError( + code=-32601, + message=f'Method not found: {method}', + ) + except Exception as e: + logger.exception(f'Error handling request {method}') + error = JsonRpcError( + code=-32000, + message=str(e), + data={'stack': traceback.format_exc()}, + ) - Returns: - A JSON response containing environments. - """ - return JSONResponse(content=['dev'], status_code=200) + # Send response + response = JsonRpcResponse( + id=msg_id, + result=result, + error=error, + ) + await self._send(response.to_dict()) - async def handle_notify(_request: Request) -> JSONResponse: - """Handle the notification endpoint. + async def _handle_notification(self, msg: dict[str, Any]) -> None: + """Handle an incoming JSON-RPC notification.""" + method = msg.get('method', '') + params = msg.get('params', {}) - Args: - _request: The Starlette request object (unused). + if method == 'configure': + await self._handle_configure(params) + else: + logger.debug(f'Unknown notification: {method}') + + async def _handle_list_actions(self) -> dict[str, dict[str, Any]]: + """Handle listActions request. Returns: - An empty JSON response with status code 200. + Dictionary of action descriptors keyed by action key. """ - return JSONResponse( - content={}, - status_code=200, - headers={'x-genkit-version': version}, - ) + actions: dict[str, dict[str, Any]] = {} - # Map of active actions indexed by trace ID for cancellation support. - active_actions: dict[str, asyncio.Task[Any]] = {} + # Get registered actions (using resolve to trigger lazy loading) + for kind in ActionKind.__members__.values(): + for name, action in (await self._registry.resolve_actions_by_kind(kind)).items(): + key = f'/{kind.value}/{name}' + actions[key] = self._action_to_desc(action, key) - async def handle_cancel_action(request: Request) -> JSONResponse: - """Handle the cancelAction endpoint. + # Get plugin-advertised actions + metas = await self._registry.list_actions() + for meta in metas or []: + try: + key = f'/{meta.kind.value}/{meta.name}' + if key not in actions: + actions[key] = { + 'key': key, + 'name': meta.name, + 'type': meta.kind.value, + 'description': getattr(meta, 'description', None), + 'inputSchema': getattr(meta, 'input_json_schema', None), + 'outputSchema': getattr(meta, 'output_json_schema', None), + 'metadata': getattr(meta, 'metadata', None), + } + except Exception as e: + logger.warning(f'Skipping invalid plugin action metadata: {e}') - Args: - request: The Starlette request object. + return actions - Returns: - A JSON response. - """ - try: - payload = await request.json() - trace_id = payload.get('traceId') - if not trace_id: - return JSONResponse(content={'error': 'traceId is required'}, status_code=400) - - task = active_actions.get(trace_id) - if task: - _ = task.cancel() - return JSONResponse(content={'message': 'Action cancelled'}, status_code=200) - else: - return JSONResponse(content={'message': 'Action not found or already completed'}, status_code=404) - except Exception as e: - logger.error(f'Error cancelling action: {e}', exc_info=True) - return JSONResponse( - content={'error': 'An unexpected error occurred while cancelling the action.'}, - status_code=500, - ) - - async def handle_run_action( - request: Request, - ) -> JSONResponse | StreamingResponse: - """Handle the runAction endpoint for executing registered actions. + def _action_to_desc(self, action: Action, key: str) -> dict[str, Any]: + """Convert an Action to an action descriptor dictionary.""" + return { + 'key': key, + 'name': action.name, + 'type': action.kind.value, + 'description': action.description, + 'inputSchema': action.input_schema, + 'outputSchema': action.output_schema, + 'metadata': action.metadata, + } - Flow: - 1. Reads and validates the request payload - 2. Looks up the requested action - 3. Executes the action with the provided input - 4. Returns the action result as JSON with trace ID + async def _handle_list_values(self, params: dict[str, Any]) -> list[str]: + """Handle listValues request. Args: - request: The Starlette request object. + params: Request parameters containing 'type'. Returns: - A JSON or StreamingResponse with the action result, or an error - response. + List of value names. + + Raises: + ValueError: If type parameter is missing or unsupported. """ - # Get the action using async resolve. - payload = await request.json() - action = await registry.resolve_action_by_key(payload['key']) - if action is None: - return JSONResponse( - content={'error': f'Action not found: {payload["key"]}'}, - status_code=404, - ) + value_type = params.get('type') - # Run the action. - context = payload.get('context', {}) - action_input = payload.get('input') - stream = is_streaming_requested(request) + if not value_type: + raise ValueError("The 'type' parameter is required for listValues.") - # Wrap execution to track the task for cancellation support - task = asyncio.current_task() + if value_type != 'defaultModel': + raise ValueError(f"Value type '{value_type}' is not supported. Only 'defaultModel' is currently supported.") - def on_trace_start(trace_id: str) -> None: - if task: - active_actions[trace_id] = task + return self._registry.list_values(value_type) - handler = run_streaming_action if stream else run_standard_action + async def _handle_run_action(self, msg: dict[str, Any]) -> None: + """Handle runAction request with streaming support. - try: - return await handler(action, payload, action_input, context, version, on_trace_start) - except asyncio.CancelledError: - logger.info('Action execution cancelled.') - # Can't really send response if cancelled? Starlette/uvicorn closes connection? - # Just raise. - raise - - async def run_streaming_action( - action: Action, - payload: dict[str, Any], - _action_input: object, - context: dict[str, Any], - version: str, - on_trace_start: Callable[[str], None], - ) -> StreamingResponse: - """Handle streaming action execution with early header flushing. - - Uses early header flushing to send X-Genkit-Trace-Id immediately when - the trace starts, enabling the Dev UI to subscribe to SSE for real-time - trace updates. + This method handles its own response sending since it needs to + send intermediate notifications for streaming and telemetry. Args: - action: The action to execute. - payload: Request payload with input data. - action_input: The input for the action. - context: Execution context. - version: The Genkit version header value. - on_trace_start: Callback for trace start. - - Returns: - A StreamingResponse with JSON chunks containing result or error - events. + msg: The JSON-RPC request message. """ - # Use a queue to pass chunks from the callback to the generator - chunk_queue: asyncio.Queue[str | None] = asyncio.Queue() + msg_id = msg.get('id') + params = msg.get('params', {}) + + key = params.get('key', '') + action_input = params.get('input') + context = params.get('context', {}) + stream = params.get('stream', False) - # Event to signal when trace ID is available - trace_id_event: asyncio.Event = asyncio.Event() + # Look up action + action = await self._registry.resolve_action_by_key(key) + if action is None: + await self._send_error(msg_id, -32602, f'Action not found: {key}') + return + + # Get the current task to allow for cancellation + current_task = asyncio.current_task() + + # Track trace ID for telemetry run_trace_id: str | None = None + sent_trace_ids: set[str] = set() - def wrapped_on_trace_start(tid: str) -> None: + async def on_trace_start(tid: str) -> None: nonlocal run_trace_id + if tid in sent_trace_ids: + return + sent_trace_ids.add(tid) run_trace_id = tid - on_trace_start(tid) - trace_id_event.set() # Signal that trace ID is ready - - async def run_action_task() -> None: - """Run the action and put chunks on the queue.""" - try: - def send_chunk(chunk: Any) -> None: # noqa: ANN401 - """Callback that puts chunks on the queue.""" - out = dump_json(chunk) - chunk_queue.put_nowait(f'{out}\n') + # Register active action with task cancellation + # Wrap cancel() in lambda to discard bool return value (expected: () -> None) + def cancel_fn() -> None: + if current_task: + _ = current_task.cancel() + + await self._active_actions.set( + tid, + ActiveAction( + cancel=cancel_fn, + start_time=asyncio.get_event_loop().time(), + trace_id=tid, + task=current_task, + ), + ) - output = await action.arun_raw( - raw_input=payload.get('input'), - on_chunk=send_chunk, - context=context, - on_trace_start=wrapped_on_trace_start, + # Send runActionState notification + notification = JsonRpcRequest( + method='runActionState', + params={ + 'requestId': msg_id, + 'state': {'traceId': tid}, + }, + ) + await self._send(notification.to_dict()) + + # Streaming callback + async def send_chunk(chunk: Any) -> None: # noqa: ANN401 + if stream: + notification = JsonRpcRequest( + method='streamChunk', + params={ + 'requestId': msg_id, + 'chunk': dump_dict(chunk), + }, ) - final_response = { + await self._send(notification.to_dict()) + + # Set up synchronous wrapper for on_trace_start + # (action.arun_raw expects sync callback currently) + def sync_on_trace_start(tid: str) -> None: + asyncio.create_task(on_trace_start(tid)) + + # Synchronous chunk callback wrapper + def sync_send_chunk(chunk: Any) -> None: # noqa: ANN401 + asyncio.create_task(send_chunk(chunk)) + + try: + output = await action.arun_raw( + raw_input=action_input, + on_chunk=sync_send_chunk if stream else None, + context=context, + on_trace_start=sync_on_trace_start, + ) + + # Clean up active action + if run_trace_id: + await self._active_actions.delete(run_trace_id) + + # Send success response + await self._send_response( + msg_id, + { 'result': dump_dict(output.response), 'telemetry': {'traceId': output.trace_id}, - } - chunk_queue.put_nowait(json.dumps(final_response)) + }, + ) - except Exception as e: - error_response = get_reflection_json(e).model_dump(by_alias=True) - # Log with exc_info for pretty exception output via rich/structlog - logger.exception('Error streaming action', exc_info=e) - # Error response also should not have trailing newline (final message) - chunk_queue.put_nowait(json.dumps(error_response)) - # Ensure trace_id_event is set even on error - trace_id_event.set() - - finally: - if not trace_id_event.is_set(): - trace_id_event.set() - # Signal end of stream - chunk_queue.put_nowait(None) - if run_trace_id: - _ = active_actions.pop(run_trace_id, None) - - # Start the action task immediately so trace ID becomes available ASAP - action_task = asyncio.create_task(run_action_task()) - - # Wait for trace ID before returning response - this enables early header flushing - _ = await trace_id_event.wait() - - # Now we have the trace ID, include it in headers - headers = { - 'x-genkit-version': version, - 'Transfer-Encoding': 'chunked', - } - if run_trace_id: - headers['X-Genkit-Trace-Id'] = run_trace_id # pyright: ignore[reportUnreachable] + except asyncio.CancelledError: + logger.info(f'Action {key} with traceId {run_trace_id} was cancelled.') + if run_trace_id: + await self._active_actions.delete(run_trace_id) + await self._send_error( + msg_id, + -32000, + 'Action was cancelled by request.', + data={'traceId': run_trace_id} if run_trace_id else None, + ) - async def stream_generator() -> AsyncGenerator[str, None]: - """Yield chunks from the queue as they arrive.""" - try: - while True: - chunk = await chunk_queue.get() - if chunk is None: - break - yield chunk - finally: - # Cancel task if still running (no-op if already done) - _ = action_task.cancel() - - return StreamingResponse( - stream_generator(), - # Reflection server uses text/plain for streaming (not SSE format) - # to match Go implementation - media_type='text/plain', - headers=headers, - ) + except Exception as e: + logger.exception(f'Error running action {key}') - async def run_standard_action( - action: Action, - payload: dict[str, Any], - _action_input: object, - context: dict[str, Any], - version: str, - on_trace_start: Callable[[str], None], - ) -> StreamingResponse: - """Handle standard (non-streaming) action execution with early header flushing. + # Clean up active action + if run_trace_id: + await self._active_actions.delete(run_trace_id) - Uses StreamingResponse to enable sending the X-Genkit-Trace-Id header - immediately when the trace starts, allowing the Dev UI to subscribe to - the SSE stream for real-time trace updates. + # Send error response + error_data = get_reflection_json(e).model_dump(by_alias=True) + if run_trace_id: + error_data.setdefault('details', {})['traceId'] = run_trace_id + + await self._send_error( + msg_id, + -32000, + str(e), + data=error_data, + ) + + async def _handle_cancel_action(self, params: dict[str, Any]) -> tuple[dict[str, str] | None, JsonRpcError | None]: + """Handle cancelAction request. Args: - action: The action to execute. - payload: Request payload with input data. - action_input: The input for the action. - context: Execution context. - version: The Genkit version header value. - on_trace_start: Callback for trace start. + params: Request parameters containing 'traceId'. Returns: - A StreamingResponse that flushes headers early. + Tuple of (result, error). """ - # Event to signal when trace ID is available - trace_id_event: asyncio.Event = asyncio.Event() - run_trace_id: str | None = None - action_result: dict[str, Any] | None = None - action_error: Exception | None = None + trace_id = params.get('traceId', '') - def wrapped_on_trace_start(tid: str) -> None: - nonlocal run_trace_id - run_trace_id = tid - on_trace_start(tid) - trace_id_event.set() # Signal that trace ID is ready + if not trace_id: + return None, JsonRpcError(code=-32602, message='traceId is required') - async def run_action_and_get_result() -> None: - nonlocal action_result, action_error - try: - output = await action.arun_raw( - raw_input=payload.get('input'), - context=context, - on_trace_start=wrapped_on_trace_start, - ) - action_result = { - 'result': dump_dict(output.response), - 'telemetry': {'traceId': output.trace_id}, - } - except Exception as e: - action_error = e - finally: - if not trace_id_event.is_set(): - trace_id_event.set() - - # Start the action immediately so trace ID becomes available ASAP - action_task = asyncio.create_task(run_action_and_get_result()) - - # Wait for trace ID before returning response - _ = await trace_id_event.wait() - - # Now return streaming response - headers will include trace ID - async def body_generator() -> AsyncGenerator[bytes, None]: - # Wait for action to complete - await action_task - - if action_error: - error_response = get_reflection_json(action_error).model_dump(by_alias=True) - # Log with exc_info for pretty exception output via rich/structlog - logger.exception('Error executing action', exc_info=action_error) - yield json.dumps(error_response).encode('utf-8') - else: - yield json.dumps(action_result).encode('utf-8') + action = await self._active_actions.get(trace_id) + if action is None: + return None, JsonRpcError( + code=-32004, # JSON-RPC implementation-defined server error + message='Action not found or already completed', + ) - if run_trace_id: - _ = active_actions.pop(run_trace_id, None) + # Cancel the action + action.cancel() + await self._active_actions.delete(trace_id) - headers = { - 'x-genkit-version': version, - } - if run_trace_id: - headers['X-Genkit-Trace-Id'] = run_trace_id # pyright: ignore[reportUnreachable] + return {'message': 'Action cancelled'}, None + + async def _handle_configure(self, params: dict[str, Any]) -> None: + """Handle configure notification. - return StreamingResponse( - body_generator(), - media_type='application/json', - headers=headers, + Args: + params: Configuration parameters. + """ + telemetry_url = params.get('telemetryServerUrl', '') + + if not os.environ.get('GENKIT_TELEMETRY_SERVER') and telemetry_url: + # TODO(#4401): Implement telemetry server URL configuration + logger.debug(f'Telemetry server URL configured: {telemetry_url}') + + async def _send_response(self, msg_id: int | str | None, result: object) -> None: + """Send a success response.""" + response = JsonRpcResponse(id=msg_id, result=result) + await self._send(response.to_dict()) + + async def _send_error( + self, + msg_id: int | str | None, + code: int, + message: str, + data: object = None, + ) -> None: + """Send an error response.""" + response = JsonRpcResponse( + id=msg_id, + error=JsonRpcError(code=code, message=message, data=data), ) + await self._send(response.to_dict()) - app = Starlette( - routes=[ - Route('/api/__health', handle_health_check, methods=['GET']), - Route('/api/__quitquitquit', handle_terminate, methods=['GET', 'POST']), # Support both for parity - Route('/api/actions', handle_list_actions, methods=['GET']), - Route('/api/values', handle_list_values, methods=['GET']), - Route('/api/envs', handle_list_envs, methods=['GET']), - Route('/api/notify', handle_notify, methods=['POST']), - Route('/api/runAction', handle_run_action, methods=['POST']), - Route('/api/cancelAction', handle_cancel_action, methods=['POST']), - ], - middleware=[ - Middleware( - CORSMiddleware, # type: ignore[arg-type] - allow_origins=['*'], - allow_methods=['*'], - allow_headers=['*'], - expose_headers=['X-Genkit-Trace-Id', 'X-Genkit-Span-Id', 'x-genkit-version'], - ) - ], - on_startup=[on_app_startup] if on_app_startup else [], - on_shutdown=[on_app_shutdown] if on_app_shutdown else [], - ) - app.active_actions = active_actions # type: ignore[attr-defined] - return cast(Application, app) + +def is_reflection_v2_enabled() -> bool: + """Check if Reflection API v2 is enabled. + + Returns: + True if GENKIT_REFLECTION_V2_SERVER is set, False otherwise. + """ + return bool(os.environ.get(GENKIT_REFLECTION_V2_SERVER_ENV)) + + +def get_reflection_v2_url() -> str | None: + """Get the Reflection API v2 server URL. + + Returns: + The WebSocket URL if set, None otherwise. + """ + return os.environ.get(GENKIT_REFLECTION_V2_SERVER_ENV) diff --git a/py/packages/genkit/src/genkit/core/reflection_v1.py b/py/packages/genkit/src/genkit/core/reflection_v1.py new file mode 100644 index 0000000000..c717ab0695 --- /dev/null +++ b/py/packages/genkit/src/genkit/core/reflection_v1.py @@ -0,0 +1,653 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Reflection API v1: HTTP-based server for Genkit development tools. + +This module provides an HTTP reflection server that exposes endpoints for +health checks, action discovery, and action execution. The CLI and Dev UI +connect to this server as HTTP clients. + +Key Concepts (ELI5):: + + ┌─────────────────────┬────────────────────────────────────────────────┐ + │ Concept │ ELI5 Explanation │ + ├─────────────────────┼────────────────────────────────────────────────┤ + │ Reflection Server │ Genkit starts a mini web server that lets │ + │ │ tools peek inside and run your AI flows. │ + ├─────────────────────┼────────────────────────────────────────────────┤ + │ Action │ Any registered flow, tool, or model that can │ + │ │ be discovered and executed via this API. │ + ├─────────────────────┼────────────────────────────────────────────────┤ + │ Streaming Response │ Results sent piece by piece, like watching │ + │ │ a video stream instead of downloading first. │ + └─────────────────────┴────────────────────────────────────────────────┘ + +Architecture:: + + ┌─────────────────────┐ ┌─────────────────────┐ + │ Genkit CLI │ │ Dev UI │ + │ (HTTP Client) │ │ (HTTP Client) │ + └──────────┬──────────┘ └──────────┬──────────┘ + │ │ + │ HTTP Requests │ + │ (GET /api/actions, etc) │ + │ │ + ▼ ▼ + ┌───────────────────────────────────────────────────────┐ + │ Genkit Runtime │ + │ ┌────────────────────┐ │ + │ │ HTTP Server │ │ + │ │ (port 3100) │ │ + │ └────────────────────┘ │ + │ ┌────────────────────┐ │ + │ │ Registry │ │ + │ │ (Actions, Flows) │ │ + │ └────────────────────┘ │ + └───────────────────────────────────────────────────────┘ + + Discovery: Runtime writes file to ~/.genkit/{runtimeId}.runtime.json + Connection: CLI reads file, finds port, connects via HTTP + +Data Flow:: + + CLI / Dev UI Reflection Server + │ │ + │ ──── GET /api/__health ────► │ + │ ◄──── {"status": "OK"} ──── │ + │ │ + │ ──── GET /api/actions ────► │ + │ ◄──── {actions dict} ──── │ + │ │ + │ ──── POST /api/runAction ────► │ (key, input, context) + │ ◄──── X-Genkit-Trace-Id header ── │ (early flush) + │ ◄──── stream chunks ──── │ (if streaming) + │ ◄──── final result ──── │ + │ │ + │ ──── POST /api/cancelAction ──► │ (traceId) + │ ◄──── {"message": "..."} ──── │ + │ │ + +Key Endpoints:: + + ┌────────┬─────────────────────┬───────────────────────────────────────┐ + │ Method │ Path │ Description │ + ├────────┼─────────────────────┼───────────────────────────────────────┤ + │ GET │ /api/__health │ Health check │ + │ GET │ /api/actions │ List all registered actions │ + │ GET │ /api/values │ List values by type │ + │ GET │ /api/envs │ List configured environments │ + │ POST │ /api/runAction │ Execute action (supports streaming) │ + │ POST │ /api/cancelAction │ Cancel a running action by traceId │ + │ POST │ /api/notify │ Handle notifications │ + │ POST │ /api/__quitquitquit │ Trigger graceful shutdown │ + └────────┴─────────────────────┴───────────────────────────────────────┘ + +Caveats: + The reflection API server predates the flows server implementation and + differs in the protocol it uses to interface with the Dev UI. The + streaming protocol uses unadorned JSON per streamed chunk. This may + change in the future to use Server-Sent Events (SSE). + +See Also: + - V2 WebSocket implementation: genkit.core.reflection +""" + +from __future__ import annotations + +import asyncio +import json +from collections.abc import AsyncGenerator, Callable +from typing import Any, cast + +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, StreamingResponse +from starlette.routing import Route + +from genkit.codec import dump_dict, dump_json +from genkit.core.action import Action +from genkit.core.action.types import ActionKind +from genkit.core.constants import DEFAULT_GENKIT_VERSION +from genkit.core.error import get_reflection_json +from genkit.core.logging import get_logger +from genkit.core.registry import Registry +from genkit.web.manager.signals import terminate_all_servers +from genkit.web.requests import ( + is_streaming_requested, +) +from genkit.web.typing import ( + Application, + StartupHandler, +) + +logger = get_logger(__name__) + + +async def _list_registered_actions(registry: Registry) -> dict[str, Action]: + """Return all locally registered actions keyed as `//`. + + Uses resolve_actions_by_kind() to trigger lazy loading for any actions with + deferred metadata (e.g., file-based prompts), ensuring schemas are available + for the Dev UI. + """ + registered: dict[str, Action] = {} + for kind in ActionKind.__members__.values(): + for name, action in (await registry.resolve_actions_by_kind(kind)).items(): + registered[f'/{kind.value}/{name}'] = action + return registered + + +def _build_actions_payload( + *, + registered_actions: dict[str, Action], + plugin_metas: list[Any], +) -> dict[str, dict[str, Any]]: + """Build payload for GET /api/actions.""" + actions: dict[str, dict[str, Any]] = {} + + # 1) Registered actions (flows/tools/etc). + for key, action in registered_actions.items(): + actions[key] = { + 'key': key, + 'name': action.name, + 'type': action.kind.value, + 'description': action.description, + 'inputSchema': action.input_schema, + 'outputSchema': action.output_schema, + 'metadata': action.metadata, + } + + # 2) Plugin-advertised actions (may not be registered yet). + for meta in plugin_metas or []: + try: + key = f'/{meta.kind.value}/{meta.name}' + except Exception as exc: + # Defensive: skip unexpected plugin metadata objects. + logger.warning('Skipping invalid plugin action metadata', error=str(exc)) + continue + + advertised = { + 'key': key, + 'name': meta.name, + 'type': meta.kind.value, + 'description': getattr(meta, 'description', None), + 'inputSchema': getattr(meta, 'input_json_schema', None), + 'outputSchema': getattr(meta, 'output_json_schema', None), + 'metadata': getattr(meta, 'metadata', None), + } + + if key not in actions: + actions[key] = advertised + continue + + # Merge into the existing (registered) action entry; prefer registered data. + existing = actions[key] + + if not existing.get('description') and advertised.get('description'): + existing['description'] = advertised['description'] + + if not existing.get('inputSchema') and advertised.get('inputSchema'): + existing['inputSchema'] = advertised['inputSchema'] + + if not existing.get('outputSchema') and advertised.get('outputSchema'): + existing['outputSchema'] = advertised['outputSchema'] + + existing_meta = existing.get('metadata') or {} + advertised_meta = advertised.get('metadata') or {} + if isinstance(existing_meta, dict) and isinstance(advertised_meta, dict): + # Prefer registered action metadata on key conflicts. + existing['metadata'] = {**advertised_meta, **existing_meta} + + return actions + + +def create_reflection_asgi_app( + registry: Registry, + on_app_startup: StartupHandler | None = None, + on_app_shutdown: StartupHandler | None = None, + version: str = DEFAULT_GENKIT_VERSION, + _encoding: str = 'utf-8', +) -> Application: + """Create and return a ASGI application for the Genkit reflection API. + + Caveats: + + The reflection API server predates the flows server implementation and + differs in the protocol it uses to interface with the Dev UI. The + streaming protocol uses unadorned JSON per streamed chunk. This may + change in the future to use Server-Sent Events (SSE). + + Key endpoints: + + | Method | Path | Handler | + |--------|---------------------|-----------------------| + | GET | /api/__health | Health check | + | GET | /api/actions | List actions | + | POST | /api/__quitquitquit | Trigger shutdown | + | POST | /api/notify | Handle notification | + | POST | /api/runAction | Run action (streaming)| + + Args: + registry: The registry to use for the reflection server. + on_app_startup: Optional callback to execute when the app's + lifespan starts. Must be an async function. + on_app_shutdown: Optional callback to execute when the app's + lifespan ends. Must be an async function. + version: The version string to use when setting the value of + the X-GENKIT-VERSION HTTP header. + encoding: The text encoding to use; default 'utf-8'. + + Returns: + An ASGI application configured with the given registry. + """ + + async def handle_health_check(_request: Request) -> JSONResponse: + """Handle health check requests. + + Args: + _request: The Starlette request object (unused). + + Returns: + A JSON response with status code 200. + """ + return JSONResponse(content={'status': 'OK'}) + + async def handle_terminate(_request: Request) -> JSONResponse: + """Handle the quit endpoint. + + Args: + _request: The Starlette request object (unused). + + Returns: + An empty JSON response with status code 200. + """ + await logger.ainfo('Shutting down servers...') + terminate_all_servers() + return JSONResponse(content={'status': 'OK'}) + + async def handle_list_actions(_request: Request) -> JSONResponse: + """Handle the request for listing available actions. + + Args: + _request: The Starlette request object (unused). + + Returns: + A JSON response containing all serializable actions. + """ + registered = await _list_registered_actions(registry) + metas = await registry.list_actions() + actions = _build_actions_payload(registered_actions=registered, plugin_metas=metas) + + return JSONResponse( + content=actions, + status_code=200, + headers={'x-genkit-version': version}, + ) + + async def handle_list_values(request: Request) -> JSONResponse: + """Handle the request for listing registered values. + + Args: + request: The Starlette request object. + + Returns: + A JSON response containing value names. + """ + kind = request.query_params.get('type') + if not kind: + return JSONResponse(content='Query parameter "type" is required.', status_code=400) + + if kind != 'defaultModel': + return JSONResponse( + content=f"'type' {kind} is not supported. Only 'defaultModel' is supported", status_code=400 + ) + + values = registry.list_values(kind) + return JSONResponse(content=values, status_code=200) + + async def handle_list_envs(_request: Request) -> JSONResponse: + """Handle the request for listing environments. + + Args: + _request: The Starlette request object (unused). + + Returns: + A JSON response containing environments. + """ + return JSONResponse(content=['dev'], status_code=200) + + async def handle_notify(_request: Request) -> JSONResponse: + """Handle the notification endpoint. + + Args: + _request: The Starlette request object (unused). + + Returns: + An empty JSON response with status code 200. + """ + return JSONResponse( + content={}, + status_code=200, + headers={'x-genkit-version': version}, + ) + + # Map of active actions indexed by trace ID for cancellation support. + active_actions: dict[str, asyncio.Task[Any]] = {} + + async def handle_cancel_action(request: Request) -> JSONResponse: + """Handle the cancelAction endpoint. + + Args: + request: The Starlette request object. + + Returns: + A JSON response. + """ + try: + payload = await request.json() + trace_id = payload.get('traceId') + if not trace_id: + return JSONResponse(content={'error': 'traceId is required'}, status_code=400) + + task = active_actions.get(trace_id) + if task: + _ = task.cancel() + return JSONResponse(content={'message': 'Action cancelled'}, status_code=200) + else: + return JSONResponse(content={'message': 'Action not found or already completed'}, status_code=404) + except Exception as e: + logger.error(f'Error cancelling action: {e}', exc_info=True) + return JSONResponse( + content={'error': 'An unexpected error occurred while cancelling the action.'}, + status_code=500, + ) + + async def handle_run_action( + request: Request, + ) -> JSONResponse | StreamingResponse: + """Handle the runAction endpoint for executing registered actions. + + Flow: + 1. Reads and validates the request payload + 2. Looks up the requested action + 3. Executes the action with the provided input + 4. Returns the action result as JSON with trace ID + + Args: + request: The Starlette request object. + + Returns: + A JSON or StreamingResponse with the action result, or an error + response. + """ + # Get the action using async resolve. + payload = await request.json() + action = await registry.resolve_action_by_key(payload['key']) + if action is None: + return JSONResponse( + content={'error': f'Action not found: {payload["key"]}'}, + status_code=404, + ) + + # Run the action. + context = payload.get('context', {}) + action_input = payload.get('input') + stream = is_streaming_requested(request) + + # Wrap execution to track the task for cancellation support + task = asyncio.current_task() + + def on_trace_start(trace_id: str) -> None: + if task: + active_actions[trace_id] = task + + handler = run_streaming_action if stream else run_standard_action + + try: + return await handler(action, payload, action_input, context, version, on_trace_start) + except asyncio.CancelledError: + logger.info('Action execution cancelled.') + # Can't really send response if cancelled? Starlette/uvicorn closes connection? + # Just raise. + raise + + async def run_streaming_action( + action: Action, + payload: dict[str, Any], + _action_input: object, + context: dict[str, Any], + version: str, + on_trace_start: Callable[[str], None], + ) -> StreamingResponse: + """Handle streaming action execution with early header flushing. + + Uses early header flushing to send X-Genkit-Trace-Id immediately when + the trace starts, enabling the Dev UI to subscribe to SSE for real-time + trace updates. + + Args: + action: The action to execute. + payload: Request payload with input data. + action_input: The input for the action. + context: Execution context. + version: The Genkit version header value. + on_trace_start: Callback for trace start. + + Returns: + A StreamingResponse with JSON chunks containing result or error + events. + """ + # Use a queue to pass chunks from the callback to the generator + chunk_queue: asyncio.Queue[str | None] = asyncio.Queue() + + # Event to signal when trace ID is available + trace_id_event: asyncio.Event = asyncio.Event() + run_trace_id: str | None = None + + def wrapped_on_trace_start(tid: str) -> None: + nonlocal run_trace_id + run_trace_id = tid + on_trace_start(tid) + trace_id_event.set() # Signal that trace ID is ready + + async def run_action_task() -> None: + """Run the action and put chunks on the queue.""" + try: + + def send_chunk(chunk: Any) -> None: # noqa: ANN401 + """Callback that puts chunks on the queue.""" + out = dump_json(chunk) + chunk_queue.put_nowait(f'{out}\n') + + output = await action.arun_raw( + raw_input=payload.get('input'), + on_chunk=send_chunk, + context=context, + on_trace_start=wrapped_on_trace_start, + ) + final_response = { + 'result': dump_dict(output.response), + 'telemetry': {'traceId': output.trace_id}, + } + chunk_queue.put_nowait(json.dumps(final_response)) + + except Exception as e: + error_response = get_reflection_json(e).model_dump(by_alias=True) + # Log with exc_info for pretty exception output via rich/structlog + logger.exception('Error streaming action', exc_info=e) + # Error response also should not have trailing newline (final message) + chunk_queue.put_nowait(json.dumps(error_response)) + # Ensure trace_id_event is set even on error + trace_id_event.set() + + finally: + if not trace_id_event.is_set(): + trace_id_event.set() + # Signal end of stream + chunk_queue.put_nowait(None) + if run_trace_id: + _ = active_actions.pop(run_trace_id, None) + + # Start the action task immediately so trace ID becomes available ASAP + action_task = asyncio.create_task(run_action_task()) + + # Wait for trace ID before returning response - this enables early header flushing + _ = await trace_id_event.wait() + + # Now we have the trace ID, include it in headers + headers = { + 'x-genkit-version': version, + 'Transfer-Encoding': 'chunked', + } + if run_trace_id: + headers['X-Genkit-Trace-Id'] = run_trace_id # pyright: ignore[reportUnreachable] + + async def stream_generator() -> AsyncGenerator[str, None]: + """Yield chunks from the queue as they arrive.""" + try: + while True: + chunk = await chunk_queue.get() + if chunk is None: + break + yield chunk + finally: + # Cancel task if still running (no-op if already done) + _ = action_task.cancel() + + return StreamingResponse( + stream_generator(), + # Reflection server uses text/plain for streaming (not SSE format) + # to match Go implementation + media_type='text/plain', + headers=headers, + ) + + async def run_standard_action( + action: Action, + payload: dict[str, Any], + _action_input: object, + context: dict[str, Any], + version: str, + on_trace_start: Callable[[str], None], + ) -> StreamingResponse: + """Handle standard (non-streaming) action execution with early header flushing. + + Uses StreamingResponse to enable sending the X-Genkit-Trace-Id header + immediately when the trace starts, allowing the Dev UI to subscribe to + the SSE stream for real-time trace updates. + + Args: + action: The action to execute. + payload: Request payload with input data. + action_input: The input for the action. + context: Execution context. + version: The Genkit version header value. + on_trace_start: Callback for trace start. + + Returns: + A StreamingResponse that flushes headers early. + """ + # Event to signal when trace ID is available + trace_id_event: asyncio.Event = asyncio.Event() + run_trace_id: str | None = None + action_result: dict[str, Any] | None = None + action_error: Exception | None = None + + def wrapped_on_trace_start(tid: str) -> None: + nonlocal run_trace_id + run_trace_id = tid + on_trace_start(tid) + trace_id_event.set() # Signal that trace ID is ready + + async def run_action_and_get_result() -> None: + nonlocal action_result, action_error + try: + output = await action.arun_raw( + raw_input=payload.get('input'), + context=context, + on_trace_start=wrapped_on_trace_start, + ) + action_result = { + 'result': dump_dict(output.response), + 'telemetry': {'traceId': output.trace_id}, + } + except Exception as e: + action_error = e + finally: + if not trace_id_event.is_set(): + trace_id_event.set() + + # Start the action immediately so trace ID becomes available ASAP + action_task = asyncio.create_task(run_action_and_get_result()) + + # Wait for trace ID before returning response + _ = await trace_id_event.wait() + + # Now return streaming response - headers will include trace ID + async def body_generator() -> AsyncGenerator[bytes, None]: + # Wait for action to complete + await action_task + + if action_error: + error_response = get_reflection_json(action_error).model_dump(by_alias=True) + # Log with exc_info for pretty exception output via rich/structlog + logger.exception('Error executing action', exc_info=action_error) + yield json.dumps(error_response).encode('utf-8') + else: + yield json.dumps(action_result).encode('utf-8') + + if run_trace_id: + _ = active_actions.pop(run_trace_id, None) + + headers = { + 'x-genkit-version': version, + } + if run_trace_id: + headers['X-Genkit-Trace-Id'] = run_trace_id # pyright: ignore[reportUnreachable] + + return StreamingResponse( + body_generator(), + media_type='application/json', + headers=headers, + ) + + app = Starlette( + routes=[ + Route('/api/__health', handle_health_check, methods=['GET']), + Route('/api/__quitquitquit', handle_terminate, methods=['GET', 'POST']), # Support both for parity + Route('/api/actions', handle_list_actions, methods=['GET']), + Route('/api/values', handle_list_values, methods=['GET']), + Route('/api/envs', handle_list_envs, methods=['GET']), + Route('/api/notify', handle_notify, methods=['POST']), + Route('/api/runAction', handle_run_action, methods=['POST']), + Route('/api/cancelAction', handle_cancel_action, methods=['POST']), + ], + middleware=[ + Middleware( + CORSMiddleware, # type: ignore[arg-type] + allow_origins=['*'], + allow_methods=['*'], + allow_headers=['*'], + expose_headers=['X-Genkit-Trace-Id', 'X-Genkit-Span-Id', 'x-genkit-version'], + ) + ], + on_startup=[on_app_startup] if on_app_startup else [], + on_shutdown=[on_app_shutdown] if on_app_shutdown else [], + ) + app.active_actions = active_actions # type: ignore[attr-defined] + return cast(Application, app) diff --git a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py index 02e7539f01..af5e6e0f0f 100644 --- a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py +++ b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py @@ -46,7 +46,7 @@ from genkit.core.action import ActionMetadata from genkit.core.action.types import ActionKind -from genkit.core.reflection import create_reflection_asgi_app +from genkit.core.reflection_v1 import create_reflection_asgi_app from genkit.core.registry import Registry @@ -188,7 +188,7 @@ async def mock_resolve_action_by_key(key: str) -> AsyncMock: @pytest.mark.asyncio -@patch('genkit.core.reflection.is_streaming_requested') +@patch('genkit.core.reflection_v1.is_streaming_requested') async def test_run_action_streaming( mock_is_streaming: MagicMock, asgi_client: AsyncClient, diff --git a/py/packages/genkit/tests/genkit/core/endpoints/reflection_v1_test.py b/py/packages/genkit/tests/genkit/core/endpoints/reflection_v1_test.py new file mode 100644 index 0000000000..22e2612947 --- /dev/null +++ b/py/packages/genkit/tests/genkit/core/endpoints/reflection_v1_test.py @@ -0,0 +1,225 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the reflection API server. + +This module contains unit tests for the ASGI-based reflection API server +which provides endpoints for inspecting and interacting with Genkit during +development. + +Test coverage includes: +- Health check endpoint (/api/__health) +- Listing registered actions (/api/actions) +- Notification endpoint (/api/notify) +- Action execution with various scenarios (/api/runAction): + - Standard action execution + - Streaming action execution + - Error handling when action not found + - Context passing to actions + +The tests use an ASGI client with mocked Registry to isolate and verify +each endpoint's behavior. +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Awaitable, Callable +from typing import Any, cast +from unittest.mock import ANY, AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient + +from genkit.core.reflection_v1 import create_reflection_asgi_app +from genkit.core.registry import Registry + + +@pytest.fixture +def mock_registry() -> MagicMock: + """Create a mock Registry for testing.""" + return MagicMock(spec=Registry) + + +@pytest_asyncio.fixture +async def asgi_client(mock_registry: MagicMock) -> AsyncIterator[AsyncClient]: + """Create an ASGI test client with a mock registry. + + Args: + mock_registry: A mock Registry object. + + Returns: + An AsyncClient configured to make requests to the test ASGI app. + """ + app = create_reflection_asgi_app(mock_registry) + transport = ASGITransport(app=app) + client = AsyncClient(transport=transport, base_url='http://test') + try: + yield client + finally: + await client.aclose() + + +@pytest.mark.asyncio +async def test_health_check(asgi_client: AsyncClient) -> None: + """Test that the health check endpoint returns 200 OK.""" + response = await asgi_client.get('/api/__health') + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_list_actions(asgi_client: AsyncClient, mock_registry: MagicMock) -> None: + """Test that the actions list endpoint returns registered actions.""" + from genkit.core.action import ActionMetadata + from genkit.core.action.types import ActionKind + + # Mock the async list_actions method to return a list of ActionMetadata + async def mock_list_actions_async(allowed_kinds: list[ActionKind] | None = None) -> list[ActionMetadata]: + return [ + ActionMetadata( + kind=ActionKind.CUSTOM, + name='action1', + ) + ] + + # Mock resolve_actions_by_kind to return empty dict (no registered actions in this test) + async def mock_resolve_actions_by_kind(kind: ActionKind) -> dict: + return {} + + mock_registry.list_actions = mock_list_actions_async + mock_registry.resolve_actions_by_kind = mock_resolve_actions_by_kind + response = await asgi_client.get('/api/actions') + assert response.status_code == 200 + result = response.json() + assert '/custom/action1' in result + assert result['/custom/action1']['name'] == 'action1' + assert result['/custom/action1']['type'] == 'custom' + + +@pytest.mark.asyncio +async def test_notify_endpoint(asgi_client: AsyncClient) -> None: + """Test that the notify endpoint returns 200 OK.""" + response = await asgi_client.post('/api/notify') + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_run_action_not_found(asgi_client: AsyncClient, mock_registry: MagicMock) -> None: + """Test that requesting a non-existent action returns a 404 error.""" + + async def mock_resolve_action_by_key(key: str) -> None: + return None + + mock_registry.resolve_action_by_key = mock_resolve_action_by_key + response = await asgi_client.post( + '/api/runAction', + json={'key': 'non_existent_action', 'input': {'data': 'test'}}, + ) + assert response.status_code == 404 + assert 'error' in response.json() + + +@pytest.mark.asyncio +async def test_run_action_standard(asgi_client: AsyncClient, mock_registry: MagicMock) -> None: + """Test that a standard (non-streaming) action works correctly.""" + mock_action = AsyncMock() + mock_output = MagicMock() + mock_output.response = {'result': 'success'} + mock_output.trace_id = 'test_trace_id' + mock_action.arun_raw.return_value = mock_output + + async def mock_resolve_action_by_key(key: str) -> AsyncMock: + return mock_action + + mock_registry.resolve_action_by_key = mock_resolve_action_by_key + + response = await asgi_client.post('/api/runAction', json={'key': 'test_action', 'input': {'data': 'test'}}) + + assert response.status_code == 200 + response_data = response.json() + assert 'result' in response_data + assert 'telemetry' in response_data + assert response_data['telemetry']['traceId'] == 'test_trace_id' + mock_action.arun_raw.assert_called_once_with(raw_input={'data': 'test'}, context={}, on_trace_start=ANY) + + +@pytest.mark.asyncio +async def test_run_action_with_context(asgi_client: AsyncClient, mock_registry: MagicMock) -> None: + """Test that an action with context works correctly.""" + mock_action = AsyncMock() + mock_output = MagicMock() + mock_output.response = {'result': 'success'} + mock_output.trace_id = 'test_trace_id' + mock_action.arun_raw.return_value = mock_output + + async def mock_resolve_action_by_key(key: str) -> AsyncMock: + return mock_action + + mock_registry.resolve_action_by_key = mock_resolve_action_by_key + + response = await asgi_client.post( + '/api/runAction', + json={ + 'key': 'test_action', + 'input': {'data': 'test'}, + 'context': {'user': 'test_user'}, + }, + ) + + assert response.status_code == 200 + mock_action.arun_raw.assert_called_once_with( + raw_input={'data': 'test'}, + context={'user': 'test_user'}, + on_trace_start=ANY, + ) + + +@pytest.mark.asyncio +@patch('genkit.core.reflection_v1.is_streaming_requested') +async def test_run_action_streaming( + mock_is_streaming: MagicMock, + asgi_client: AsyncClient, + mock_registry: MagicMock, +) -> None: + """Test that streaming actions work correctly.""" + mock_is_streaming.return_value = True + mock_action = AsyncMock() + + async def mock_streaming( + raw_input: object, + on_chunk: object | None = None, + context: object | None = None, + **kwargs: Any, # noqa: ANN401 + ) -> MagicMock: + if on_chunk: + on_chunk_fn = cast(Callable[[object], Awaitable[None]], on_chunk) + await on_chunk_fn({'chunk': 1}) + await on_chunk_fn({'chunk': 2}) + mock_output = MagicMock() + mock_output.response = {'final': 'result'} + mock_output.trace_id = 'stream_trace_id' + return mock_output + + mock_action.arun_raw.side_effect = mock_streaming + mock_registry.resolve_action_by_key.return_value = mock_action + + response = await asgi_client.post( + '/api/runAction?stream=true', + json={'key': 'test_action', 'input': {'data': 'test'}}, + ) + + assert response.status_code == 200 + assert mock_is_streaming.called diff --git a/py/pyproject.toml b/py/pyproject.toml index 6d718aef8d..cbd6260b0d 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -75,7 +75,7 @@ dev = [ lint = [ "bandit>=1.7.0", "deptry>=0.22.0", - "litestar>=2.0.0", # For web/typing.py type resolution + "litestar>=2.0.0", # For web/typing.py type resolution "mypy>=1.14.0", "pip-audit>=2.7.0", "pypdf>=6.6.2", diff --git a/py/samples/multi-server/src/main.py b/py/samples/multi-server/src/main.py index 89736eb032..4117536ea8 100755 --- a/py/samples/multi-server/src/main.py +++ b/py/samples/multi-server/src/main.py @@ -88,7 +88,7 @@ from genkit.aio.loop import run_loop from genkit.core.environment import is_dev_environment from genkit.core.logging import get_logger -from genkit.core.reflection import create_reflection_asgi_app +from genkit.core.reflection_v1 import create_reflection_asgi_app from genkit.core.registry import Registry from genkit.web.manager import ( AbstractBaseServer, diff --git a/py/uv.lock b/py/uv.lock index 0d84d2ae2b..59c68181a0 100644 --- a/py/uv.lock +++ b/py/uv.lock @@ -1985,6 +1985,7 @@ dependencies = [ { name = "typing-extensions" }, { name = "uvicorn" }, { name = "uvloop", marker = "sys_platform != 'win32'" }, + { name = "websockets" }, ] [package.optional-dependencies] @@ -2040,6 +2041,7 @@ requires-dist = [ { name = "typing-extensions", specifier = ">=4.0" }, { name = "uvicorn", specifier = ">=0.34.0" }, { name = "uvloop", marker = "sys_platform != 'win32'", specifier = ">=0.21.0" }, + { name = "websockets", specifier = ">=15.0" }, ] provides-extras = ["dev-local-vectorstore", "flask", "google-cloud", "google-genai", "ollama", "openai", "vertex-ai"]