diff --git a/docs/docs/cli.md b/docs/docs/cli.md index 8b4271e5..e383fe5c 100644 --- a/docs/docs/cli.md +++ b/docs/docs/cli.md @@ -10,7 +10,7 @@ VeADK 提供如下命令便捷您的操作: | :-- | :-- | :-- | | `veadk init` | 生成可在 VeFaaS 中部署的项目脚手架 | 生成完整的项目目录结构,包括智能体定义文件、部署配置、依赖文件等,支持标准智能体和Web应用两种模板。 | | `veadk create` | 在当前目录中创建一个新的智能体 | 创建智能体项目结构,生成.env环境配置文件、agent.py智能体定义文件和__init__.py包初始化文件。 | -| `veadk web` | 支持长短期记忆、知识库的前端调试界面 | 启动本地Web服务器,生成调试界面访问地址,配置记忆服务和知识库集成环境。 | +| `veadk web` | 支持长短期记忆、知识库的前端调试界面 | 启动本地Web服务器,生成调试界面访问地址,配置记忆服务和知识库集成环境,并可通过 Agent Identity 提供 SSO 单点登录。 | | `veadk kb` | 知识库相关操作 | 创建知识库索引文件,支持多种后端存储,生成向量化文档和检索配置文件。 | | `veadk deploy` | 将某个项目部署到 VeFaaS 中 | 生成部署包文件,创建云资源配置,部署智能体到火山引擎函数计算服务并生成API访问端点。 | | `veadk eval` | 支持不同后端的评测 | 生成评测报告文件,包含性能指标数据和测试结果,支持本地和远程智能体评估。 | @@ -137,6 +137,8 @@ location-agent/ 该命令会启动一个支持 VeADK 智能体短期和长期记忆功能的 Web 服务器。它会自动检测并加载当前目录下的智能体,并配置相应的记忆服务。 +同时,`veadk web` 支持使用 Agent Identity(User Pool)提供 OAuth2/OIDC 的 SSO 单点登录能力。启用后,Web 调试界面会接入 Agent Identity 的用户池与客户端配置,自动完成登录重定向与回调处理。 + ### 参数说明 `veadk web` 命令全面兼容 Google ADK 中的 `adk web` 命令,支持相同的参数接口和行为模式。 @@ -146,6 +148,9 @@ location-agent/ | `--port` | INTEGER | 指定 Web 服务器监听的端口号,默认值为 8000 | | `--host` | TEXT | 指定 Web 服务器绑定的主机地址,默认值为 127.0.0.1 | | `--log-level` | [debug\|info\|warning\|error] | 设置日志输出级别,默认值为 info | +| `--oauth2-user-pool` | TEXT | 启用 Agent Identity 的 User Pool 名称 | +| `--oauth2-user-pool-client` | TEXT | 启用 Agent Identity 的 User Pool Client 名称 | +| `--oauth2-redirect-uri` | TEXT | OAuth2 回调地址,默认使用 `http://{host}:{port}/oauth2/callback` | | `--help` | | 显示此帮助信息并退出 | **长短期记忆机制**: @@ -181,6 +186,12 @@ veadk web veadk web --port 8080 ``` +启用基于Agent Identity的单点登录: + +```bash +veadk web --oauth2-user-pool my-user-pool --oauth2-user-pool-client my-web-client +``` + 该命令能够自动读取执行命令目录中的 `agent.py` 文件,并加载其中的 `root_agent` 全局变量。服务启动后,通常可以在 `http://127.0.0.1:8000` 访问。 ### 使用示例 diff --git a/veadk/auth/middleware/oauth2_auth.py b/veadk/auth/middleware/oauth2_auth.py index 6decb5be..66aeef6b 100644 --- a/veadk/auth/middleware/oauth2_auth.py +++ b/veadk/auth/middleware/oauth2_auth.py @@ -65,6 +65,7 @@ from __future__ import annotations +import asyncio import base64 import hashlib import hmac @@ -90,6 +91,17 @@ if TYPE_CHECKING: from veadk.integrations.ve_identity import IdentityClient +try: + from authlib.jose import JsonWebKey, jwt + from authlib.jose.errors import JoseError + + _AUTHLIB_AVAILABLE = True +except ImportError: # pragma: no cover - optional dependency + JsonWebKey = None + jwt = None + JoseError = Exception + _AUTHLIB_AVAILABLE = False + # Maximum cookie size before warning (browsers typically limit to 4KB). _MAX_COOKIE_SIZE_WARNING = 3800 @@ -295,6 +307,20 @@ class OAuth2Config(BaseModel): token_refresh_threshold_seconds: int = 300 # Refresh when < 5 min remaining auto_refresh_token: bool = True + # Access token validation (Authorization header + session) + issuer: Optional[str] = None + jwks_uri: Optional[str] = None + audience: Optional[str | list[str]] = None + allowed_algorithms: list[str] = Field(default_factory=lambda: ["RS256"]) + jwks_cache_ttl_seconds: int = 300 + jwks_kid_miss_cooldown_seconds: int = 30 + use_introspection: bool = False + introspection_url: Optional[str] = None + introspection_client_id: Optional[str] = None + introspection_client_secret: Optional[str] = None + introspection_cache_ttl_seconds: int = 300 + introspection_cache_max_entries: int = 1000 + # API vs browser behavior api_path_prefixes: list[str] = Field(default_factory=lambda: ["/api/"]) @@ -454,6 +480,9 @@ def from_veidentity( token_url=oidc_config.token_endpoint, userinfo_url=oidc_config.userinfo_endpoint, end_session_url=oidc_config.end_session_endpoint, + issuer=oidc_config.issuer, + jwks_uri=oidc_config.jwks_uri, + introspection_url=oidc_config.introspection_endpoint, client_id=resolved_client_id, client_secret=client_secret, redirect_uri=redirect_uri, @@ -630,6 +659,11 @@ def __init__( timeout=config.http_timeout_seconds, limits=limits, ) + self._jwks_cache: Optional[dict[str, Any]] = None + self._jwks_cache_time = 0.0 + self._jwks_last_kid_miss_refresh = 0.0 + self._jwks_lock = asyncio.Lock() + self._introspection_cache: dict[str, tuple[dict[str, Any], float]] = {} async def close(self) -> None: """Close the HTTP client.""" @@ -866,6 +900,279 @@ async def _fetch_user_info(self, access_token: str) -> dict[str, Any]: logger.error("User info fetch error: %s", e) raise Exception(f"User info fetch error: {e}") + def _normalized_audience(self) -> Optional[list[str]]: + audience = self.config.audience + if not audience: + return None + if isinstance(audience, str): + audience_list = [audience] + else: + audience_list = list(audience) + audience_list = [str(item) for item in audience_list if item is not None] + return audience_list or None + + def _decode_jwt_header(self, token: str) -> dict[str, Any]: + parts = token.split(".") + if len(parts) != 3: + raise HTTPException(status_code=401, detail="Invalid JWT format") + try: + header_bytes = self._base64url_decode(parts[0]) + header = json.loads(header_bytes.decode("utf-8")) + except Exception as e: + raise HTTPException(status_code=401, detail="Invalid JWT header") from e + if not isinstance(header, dict): + raise HTTPException(status_code=401, detail="Invalid JWT header") + return header + + def _ensure_allowed_algorithm(self, alg: Optional[str]) -> None: + if not alg or str(alg).lower() == "none": + raise HTTPException(status_code=401, detail="Unsupported token algorithm") + allowed = self.config.allowed_algorithms + if allowed and alg not in allowed: + raise HTTPException(status_code=401, detail="Unsupported token algorithm") + + @staticmethod + def _jwks_has_kid(jwks: dict[str, Any], kid: str) -> bool: + keys = jwks.get("keys", []) + if not isinstance(keys, list): + return False + return any(isinstance(key, dict) and key.get("kid") == kid for key in keys) + + async def _fetch_jwks(self) -> dict[str, Any]: + if not self.config.jwks_uri: + raise HTTPException(status_code=503, detail="JWKS URI not configured") + try: + response = await self._http_client.get(self.config.jwks_uri) + response.raise_for_status() + jwks = response.json() + except httpx.HTTPStatusError as e: + logger.error("JWKS fetch failed: %s", e.response.text) + raise HTTPException(status_code=503, detail="Failed to fetch JWKS") from e + except httpx.RequestError as e: + logger.error("JWKS request error: %s", e) + raise HTTPException(status_code=503, detail="Failed to fetch JWKS") from e + except Exception as e: + logger.error("JWKS decode error: %s", e) + raise HTTPException(status_code=503, detail="Invalid JWKS response") from e + + if not isinstance(jwks, dict) or "keys" not in jwks: + raise HTTPException(status_code=503, detail="Invalid JWKS response") + return jwks + + async def _get_jwks(self, force_refresh: bool = False) -> dict[str, Any]: + now = time.time() + if ( + not force_refresh + and self._jwks_cache + and now - self._jwks_cache_time < self.config.jwks_cache_ttl_seconds + ): + return self._jwks_cache + + async with self._jwks_lock: + now = time.time() + if ( + not force_refresh + and self._jwks_cache + and now - self._jwks_cache_time < self.config.jwks_cache_ttl_seconds + ): + return self._jwks_cache + + try: + jwks = await self._fetch_jwks() + except HTTPException as exc: + if self._jwks_cache: + logger.warning( + "JWKS fetch failed, using cached keys: %s", exc.detail + ) + self._jwks_cache_time = now + return self._jwks_cache + raise + except Exception as exc: + if self._jwks_cache: + logger.warning("JWKS fetch failed, using cached keys: %s", exc) + self._jwks_cache_time = now + return self._jwks_cache + raise + + self._jwks_cache = jwks + self._jwks_cache_time = now + return jwks + + async def _get_jwks_for_kid(self, kid: Optional[str]) -> dict[str, Any]: + jwks = await self._get_jwks() + if not kid: + return jwks + if self._jwks_has_kid(jwks, kid): + return jwks + now = time.time() + if ( + now - self._jwks_last_kid_miss_refresh + < self.config.jwks_kid_miss_cooldown_seconds + ): + return jwks + jwks = await self._get_jwks(force_refresh=True) + self._jwks_last_kid_miss_refresh = now + return jwks + + def _prune_introspection_cache(self) -> None: + if not self._introspection_cache: + return + now = time.time() + expired = [ + token + for token, (_, expires_at) in self._introspection_cache.items() + if expires_at <= now + ] + for token in expired: + self._introspection_cache.pop(token, None) + while ( + len(self._introspection_cache) > self.config.introspection_cache_max_entries + ): + self._introspection_cache.pop(next(iter(self._introspection_cache))) + + def _validate_audience(self, claims: dict[str, Any]) -> None: + audience = self._normalized_audience() + if not audience: + return + aud_claim = claims.get("aud") + if aud_claim is not None: + if isinstance(aud_claim, list): + aud_list = [str(item) for item in aud_claim if item is not None] + else: + aud_list = [str(aud_claim)] + if any(aud in audience for aud in aud_list): + return + raise HTTPException(status_code=403, detail="Token audience not allowed") + + client_id = claims.get("client_id") or claims.get("azp") + if client_id and str(client_id) in audience: + return + raise HTTPException(status_code=403, detail="Token audience not allowed") + + async def _validate_with_jwks(self, token: str) -> dict[str, Any]: + if not _AUTHLIB_AVAILABLE: + raise HTTPException( + status_code=503, detail="authlib is required for JWT validation" + ) + header = self._decode_jwt_header(token) + alg = header.get("alg") + self._ensure_allowed_algorithm(alg) + kid = header.get("kid") + + jwks = await self._get_jwks_for_kid(kid) + if kid and not self._jwks_has_kid(jwks, kid): + raise HTTPException(status_code=401, detail="Unknown token key") + + try: + key_set = JsonWebKey.import_key_set(jwks) + claims_options: dict[str, Any] = {"exp": {"essential": True}} + if self.config.issuer: + claims_options["iss"] = { + "essential": True, + "value": self.config.issuer, + } + claims = jwt.decode(token, key_set, claims_options=claims_options) + claims.validate() + claims_dict = dict(claims) + except JoseError as e: + logger.warning("JWT validation failed: %s", e) + raise HTTPException(status_code=401, detail="Invalid access token") from e + except Exception as e: + logger.error("JWT validation error: %s", e) + raise HTTPException(status_code=500, detail="Token validation error") from e + + self._validate_audience(claims_dict) + return claims_dict + + async def _validate_with_introspection(self, token: str) -> dict[str, Any]: + if not self.config.introspection_url: + raise HTTPException( + status_code=503, detail="Introspection endpoint not configured" + ) + + now = time.time() + cached = self._introspection_cache.get(token) + if cached and cached[1] > now: + return cached[0] + + data = {"token": token} + headers = {"Content-Type": "application/x-www-form-urlencoded"} + auth = None + if ( + self.config.introspection_client_id + and self.config.introspection_client_secret + ): + auth = ( + self.config.introspection_client_id, + self.config.introspection_client_secret, + ) + elif self.config.client_id and self.config.client_secret: + auth = (self.config.client_id, self.config.client_secret) + + try: + response = await self._http_client.post( + self.config.introspection_url, + data=data, + headers=headers, + auth=auth, + ) + response.raise_for_status() + result = response.json() + except httpx.HTTPStatusError as e: + logger.error("Introspection failed: %s", e.response.text) + raise HTTPException( + status_code=503, detail="Token introspection failed" + ) from e + except httpx.RequestError as e: + logger.error("Introspection request error: %s", e) + raise HTTPException( + status_code=503, detail="Token introspection failed" + ) from e + except Exception as e: + logger.error("Introspection decode error: %s", e) + raise HTTPException( + status_code=503, detail="Token introspection failed" + ) from e + + if not isinstance(result, dict): + raise HTTPException(status_code=503, detail="Token introspection failed") + + if not result.get("active", False): + raise HTTPException(status_code=401, detail="Inactive access token") + + exp = result.get("exp") + if exp is not None: + try: + exp = int(exp) + except (TypeError, ValueError): + exp = None + if exp is not None and exp <= now: + raise HTTPException(status_code=401, detail="Access token expired") + + if self.config.issuer: + issuer = result.get("iss") + if issuer and issuer != self.config.issuer: + raise HTTPException(status_code=403, detail="Token issuer not allowed") + + self._validate_audience(result) + + cache_until = now + self.config.introspection_cache_ttl_seconds + if exp is not None: + cache_until = min(cache_until, exp) + self._introspection_cache[token] = (result, cache_until) + if len(self._introspection_cache) > self.config.introspection_cache_max_entries: + self._prune_introspection_cache() + + return result + + async def validate_access_token(self, token: str) -> dict[str, Any]: + """Validate access token via introspection or JWKS.""" + if not token: + raise HTTPException(status_code=401, detail="Missing access token") + if self.config.use_introspection: + return await self._validate_with_introspection(token) + return await self._validate_with_jwks(token) + def encode_session(self, session: OAuth2Session) -> str: """Encode OAuth2 session data for cookie storage. @@ -1146,6 +1453,12 @@ async def get_current_user_info(request: Request) -> JSONResponse: if not session or session.is_expired(): raise HTTPException(status_code=401, detail="Not authenticated") + if not _is_access_token_already_validated(request, session.access_token): + try: + await oauth2_handler.validate_access_token(session.access_token) + except HTTPException as exc: + raise HTTPException(status_code=exc.status_code, detail=exc.detail) + if not session.user_info: if oauth2_handler.config.userinfo_url: try: @@ -1244,6 +1557,7 @@ def setup_oauth2( route_paths = register_oauth2_routes(app, oauth2_handler, routes=routes) merged_exempt_paths = set(route_paths.all_paths()) + merged_exempt_paths.discard(route_paths.userinfo) if exempt_paths: merged_exempt_paths.update(exempt_paths) @@ -1284,21 +1598,74 @@ def _is_api_request(request: Request, api_prefixes: list[str]) -> bool: return False +def _extract_bearer_token(authorization: str) -> Optional[str]: + """Extract bearer token from Authorization header.""" + if not authorization: + return None + parts = authorization.strip().split() + if len(parts) != 2: + return None + if parts[0].lower() != "bearer": + return None + return parts[1] or None + + +def _set_scope_user_from_claims( + request: Request, claims: dict[str, Any], user_id_field: str +) -> None: + if not claims: + return + user_id = claims.get(user_id_field) + if not user_id: + for key in ("sub", "user_id", "uid", "username", "email"): + if key == user_id_field: + continue + user_id = claims.get(key) + if user_id: + break + if not user_id: + return + try: + from starlette.authentication import SimpleUser + + request.scope["user"] = SimpleUser(str(user_id)) + except Exception: + request.scope["user"] = str(user_id) + + +def _mark_access_token_validated(request: Request, token: str) -> None: + try: + request.state.oauth2_access_token_validated = True + request.state.oauth2_access_token = token + except Exception: + pass + + +def _is_access_token_already_validated(request: Request, token: str) -> bool: + try: + return ( + getattr(request.state, "oauth2_access_token_validated", False) + and getattr(request.state, "oauth2_access_token", None) == token + ) + except Exception: + return False + + def create_oauth2_middleware( oauth2_handler: OAuth2Handler, *, exempt_paths: Optional[Iterable[str]] = None, exempt_prefixes: Optional[Iterable[str]] = None, - allow_existing_authorization: bool = True, ): """Create OAuth2 authentication middleware for Starlette/FastAPI. The middleware: - Skips authentication for exempt paths/prefixes and OPTIONS requests. - - Passes through requests that already have an Authorization header. + - Validates bearer tokens from Authorization headers (JWKS/introspection). - Injects the OAuth2 access token as an Authorization header for valid sessions. - Auto-refreshes tokens when they are close to expiry (if configured). - - Returns 401 JSON for API requests, redirects browsers to login for others. + - Returns 401 JSON for API requests or invalid Authorization tokens. + - Redirects browsers to login when no valid auth is present. """ exempt_paths_set = set(exempt_paths or []) exempt_prefixes_tuple = tuple(exempt_prefixes or []) @@ -1316,8 +1683,30 @@ async def oauth2_middleware(request: Request, call_next): ): return await call_next(request) - # Pass through if there's already an Authorization header. - if allow_existing_authorization and "authorization" in request.headers: + authorization = request.headers.get("authorization") + if authorization: + token = _extract_bearer_token(authorization) + if not token: + return JSONResponse( + status_code=401, + content={"detail": "Bearer token required"}, + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + claims = await oauth2_handler.validate_access_token(token) + except HTTPException as exc: + headers = exc.headers or {} + if exc.status_code == 401 and "WWW-Authenticate" not in headers: + headers = {**headers, "WWW-Authenticate": "Bearer"} + return JSONResponse( + status_code=exc.status_code, + content={"detail": exc.detail}, + headers=headers, + ) + _set_scope_user_from_claims( + request, claims, oauth2_handler.config.user_id_field + ) + _mark_access_token_validated(request, token) return await call_next(request) session = oauth2_handler.get_session_from_request(request) @@ -1346,30 +1735,45 @@ async def oauth2_middleware(request: Request, call_next): response_cookies.append(user_id_cookie) if session and not session.is_expired(): - auth_header_value = session.to_authorization_header().encode() - - # Update the scope headers so downstream dependencies can read it. - headers = list(request.scope.get("headers", [])) - headers = [ - (name, value) - for name, value in headers - if name.lower() != b"authorization" - ] - headers.append((b"authorization", auth_header_value)) - request.scope["headers"] = headers - - logger.debug( - "Added Authorization header to request: %s...", - session.to_authorization_header()[:20], - ) + try: + claims = await oauth2_handler.validate_access_token( + session.access_token + ) + except HTTPException as exc: + if exc.status_code >= 500: + return JSONResponse( + status_code=exc.status_code, content={"detail": exc.detail} + ) + session = None + else: + auth_header_value = session.to_authorization_header().encode() + + # Update the scope headers so downstream dependencies can read it. + headers = list(request.scope.get("headers", [])) + headers = [ + (name, value) + for name, value in headers + if name.lower() != b"authorization" + ] + headers.append((b"authorization", auth_header_value)) + request.scope["headers"] = headers + _set_scope_user_from_claims( + request, claims, oauth2_handler.config.user_id_field + ) + _mark_access_token_validated(request, session.access_token) + + logger.debug( + "Added Authorization header to request: %s...", + session.to_authorization_header()[:20], + ) - response = await call_next(request) + response = await call_next(request) - # Set any refreshed session cookies on the response. - for cookie_params in response_cookies: - response.set_cookie(**cookie_params) + # Set any refreshed session cookies on the response. + for cookie_params in response_cookies: + response.set_cookie(**cookie_params) - return response + return response # No valid session - handle API vs browser requests differently. if _is_api_request(request, config.api_path_prefixes): diff --git a/veadk/cli/cli_web.py b/veadk/cli/cli_web.py index 64a74055..7719e624 100644 --- a/veadk/cli/cli_web.py +++ b/veadk/cli/cli_web.py @@ -22,6 +22,50 @@ logger = get_logger(__name__) +def _patch_adkwebserver_oauth2( + user_pool_name: str, + user_pool_client: str, + redirect_uri: str, +) -> None: + """ + Monkey patch AdkWebServer to enable OAuth2 authentication. + + This function patches the AdkWebServer.get_fast_api_app method to add + OAuth2 authentication middleware using VeIdentity User Pool. + + Args: + user_pool_name: VeIdentity User Pool name. + user_pool_client: VeIdentity User Pool client name. + redirect_uri: OAuth2 redirect URI (e.g., http://127.0.0.1:8000/oauth2/callback). + """ + import google.adk.cli.adk_web_server + + from veadk.auth.middleware.oauth2_auth import OAuth2Config, setup_oauth2 + + original_get_fast_api = google.adk.cli.adk_web_server.AdkWebServer.get_fast_api_app + + def wrapped_get_fast_api(self, *args, **kwargs): + app = original_get_fast_api(self, *args, **kwargs) + + # Setup OAuth2 with VeIdentity User Pool + oauth2_config = OAuth2Config.from_veidentity( + user_pool_name=user_pool_name, + client_name=user_pool_client, + redirect_uri=redirect_uri, + ) + oauth2_config.cookie_secure = False + + setup_oauth2( + app, + oauth2_config, + ) + logger.info("OAuth2 middleware installed") + + return app + + google.adk.cli.adk_web_server.AdkWebServer.get_fast_api_app = wrapped_get_fast_api + + def patch_adkwebserver_disable_openapi(): """ Monkey patch AdkWebServer to disable OpenAPI documentation endpoints. @@ -58,8 +102,33 @@ def wrapped_get_fast_api(self, *args, **kwargs): @click.command( context_settings=dict(ignore_unknown_options=True, allow_extra_args=True) ) +@click.option( + "--oauth2-user-pool", + type=str, + default=None, + help="VeIdentity User Pool name for OAuth2 authentication.", +) +@click.option( + "--oauth2-user-pool-client", + type=str, + default=None, + help="VeIdentity User Pool client name for OAuth2 authentication.", +) +@click.option( + "--oauth2-redirect-uri", + type=str, + default=None, + help="OAuth2 redirect URI. Defaults to http://{host}:{port}/oauth2/callback.", +) @click.pass_context -def web(ctx, *args, **kwargs) -> None: +def web( + ctx, + oauth2_user_pool: str | None, + oauth2_user_pool_client: str | None, + oauth2_redirect_uri: str | None, + *args, + **kwargs, +) -> None: """ Launch a web server with VeADK agent support and memory integration. @@ -134,6 +203,30 @@ async def wrapper(*args, **kwargs) -> ADKRunner: from google.adk.cli.cli_tools_click import cli_web extra_args: list = ctx.args + + # Setup OAuth2 if configured + if oauth2_user_pool and oauth2_user_pool_client: + # Build redirect_uri from host/port if not provided + redirect_uri = oauth2_redirect_uri + if not redirect_uri: + # Parse host and port from extra_args + host = "127.0.0.1" + port = "8000" + if "--host" in extra_args: + host = extra_args[extra_args.index("--host") + 1] + if "--port" in extra_args: + port = extra_args[extra_args.index("--port") + 1] + redirect_uri = f"http://{host}:{port}/oauth2/callback" + + _patch_adkwebserver_oauth2( + user_pool_name=oauth2_user_pool, + user_pool_client=oauth2_user_pool_client, + redirect_uri=redirect_uri, + ) + logger.info( + f"OAuth2 enabled: user_pool={oauth2_user_pool}, " + f"client={oauth2_user_pool_client}, redirect_uri={redirect_uri}" + ) logger.debug(f"User args: {extra_args}") # set a default log level to avoid unnecessary outputs