From 23246f462981e9a45e7cfca194fbcd2bc3bfcb81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E7=8E=89=E6=B6=B5?= Date: Tue, 27 Jan 2026 22:42:26 +0800 Subject: [PATCH 1/2] Add OAuth2 middleware and enhance inbound auth docs Introduced veadk.auth.middleware.oauth2_auth.py, providing OAuth2 3LO middleware for Starlette/FastAPI with VeIdentity integration, including session management, PKCE, and extensible state storage. Updated inbound authentication documentation to detail both API Gateway and middleware OAuth2 SSO integration, with code samples, configuration options, and deployment guidance. Create __init__.py Refactor exception handling and formatting in oauth2_auth.py Simplified exception raising in _fetch_oidc_discovery by consolidating RuntimeError messages into single lines. Improved code readability by reformatting function calls to follow consistent indentation. --- docs/docs/auth/inbound.md | 242 ++- veadk/auth/middleware/__init__.py | 13 + veadk/auth/middleware/oauth2_auth.py | 1386 +++++++++++++++++ .../ve_identity/identity_client.py | 139 +- 4 files changed, 1730 insertions(+), 50 deletions(-) create mode 100644 veadk/auth/middleware/__init__.py create mode 100644 veadk/auth/middleware/oauth2_auth.py diff --git a/docs/docs/auth/inbound.md b/docs/docs/auth/inbound.md index 691e8f95..1c2142cd 100644 --- a/docs/docs/auth/inbound.md +++ b/docs/docs/auth/inbound.md @@ -1,41 +1,48 @@ --- -titie: 入站认证 +title: 入站认证 description: 通过入站认证访问 Agent navigation: icon: i-lucide-lock --- -VeADK 支持 API key 和 OAuth2 方式的入站认证。 +VeADK 支持 API Key 和 OAuth2 方式的入站认证。 ## API Key 认证 -API key 认证是通过唯一字符串密钥验证请求方身份、授权访问 API 资源的常见认证方式。VeADK 约定将 API key 通过 URL 的 `token` 参数传递。 +API Key 认证是通过唯一字符串密钥验证请求方身份、授权访问 API 资源的常见认证方式。VeADK 约定将 API Key 通过 URL 的 `token` 参数传递。 !!! tip - API key 仅适用于 A2A/MCP Server 部署模式,不建议在 VeADK Web 部署模式中使用,更推荐采用 OAuth2 认证。 + API Key 仅适用于 A2A/MCP Server 部署模式,不建议在 VeADK Web 部署模式中使用,更推荐采用 OAuth2 认证。 ### 使用方式 -您可以通过脚手架创建 Agent 时指定 API key 认证方式,或者部署已有项目时添加 `--auth-method=api-key` 参数启用该认证。 +您可以通过脚手架创建 Agent 时指定 API Key 认证方式,或者部署已有项目时添加 `--auth-method=api-key` 参数启用该认证。 -当用户访问应用时,API 网关将验证用户 `token` URL 参数中携带的 API key。 +当用户访问应用时,API 网关将验证用户 `token` URL 参数中携带的 API Key。 ## OAuth2 单点登录 OAuth2 是一种开放标准的授权框架,通过令牌而非直接暴露账号密码,安全实现第三方应用对资源的有限访问。 -OAuth2 单点登录是基于 OAuth2 授权框架实现的身份认证方案,用户一次登录后可免重复验证访问多个关联应用。VeADK Web 支持 OAuth2 单点登录的认证方式。 +OAuth2 单点登录是基于 OAuth2 授权框架实现的身份认证方案,用户一次登录后可免重复验证访问多个关联应用。VeADK 提供两种 OAuth2 单点登录的接入方式: +| 方式 | 适用场景 | 说明 | +|------|----------|------| +| API 网关模式 | VeFaaS 云端部署 | 通过脚手架部署,由 API 网关处理认证 | +| Starlette/FastAPI 中间件 | 本地开发 / 自托管部署 | 在应用内集成 OAuth2 中间件 | -!!! tip - 使用 OAuth2 单点登录需要版本为 4.0.0 及以上的 API 网关。 +### 方式一:API 网关模式(VeFaaS 部署) +适用于通过 VeFaaS 部署的 VeADK Web 应用,由 API 网关处理 OAuth2 认证流程。 -### 使用方式 +!!! tip + 使用 API 网关模式需要版本为 4.0.0 及以上的 API 网关。 + +#### 使用方式 您可以通过脚手架创建 Agent 时指定 OAuth2 认证方式,或者部署已有项目时添加 `--auth-method=oauth2` 参数启用该认证,VeADK 将自动为您创建 Identity 用户池和客户端。如果您需要使用已有的用户池或客户端,您可以在部署时添加 `--user-pool-name` 和 `--client-name` 参数指定用户池和客户端。 -在部署 VeADK Web 应用后,您可以在 Identity 中创建用户。 +在部署 VeADK Web 应用后,您可以在 Identity 中创建用户: 1. 登录火山引擎控制台,导航到 Agent Identity 服务 2. 在左侧导航树中,选择 身份认证 > 用户池管理,选择用户池 @@ -43,6 +50,213 @@ OAuth2 单点登录是基于 OAuth2 授权框架实现的身份认证方案, 当用户访问 VeADK Web 应用时,API 网关将引导用户至登录页完成登录。您可以在 `Authorization` 请求头中获得用户的 JWT 令牌。 +### 方式二:Starlette/FastAPI 中间件(本地/自托管) + +适用于本地开发或自托管部署场景,通过 VeADK 提供的中间件在应用内处理 OAuth2 认证。支持所有基于 Starlette 的框架,包括 FastAPI。 + +#### 快速开始(FastAPI) + +推荐使用 `OAuth2Config.from_veidentity()` 方法,自动配置 VeIdentity User Pool: + +```python +from fastapi import FastAPI +from veadk.auth.middleware.oauth2_auth import OAuth2Config, setup_oauth2 + +app = FastAPI() + +setup_oauth2( + app, + OAuth2Config.from_veidentity( + user_pool_name="my-app", + client_name="my-app-web", + redirect_uri="https://myapp.com/oauth2/callback", + ), +) +``` + +#### 快速开始(Starlette) + +```python +from starlette.applications import Starlette +from veadk.auth.middleware.oauth2_auth import OAuth2Config, setup_oauth2 + +app = Starlette() + +setup_oauth2( + app, + OAuth2Config.from_veidentity( + user_pool_name="my-app", + client_name="my-app-web", + redirect_uri="https://myapp.com/oauth2/callback", + ), +) +``` + +该方法会自动: + +- 创建用户池(如不存在) +- 创建用户池客户端(如不存在) +- 注册回调 URL +- 配置 OAuth2 端点 + +#### 使用已有资源 + +如果您已有用户池和客户端,可以禁用自动创建: + +```python +setup_oauth2( + app, + OAuth2Config.from_veidentity( + user_pool_name="existing-pool", + client_name="existing-client", + redirect_uri="https://myapp.com/oauth2/callback", + auto_create=False, # 资源不存在时报错 + auto_register_callback=False, # 不修改回调 URL + ), +) +``` + +#### 本地开发配置 + +本地开发时需要禁用 HTTPS cookie: + +```python +setup_oauth2( + app, + OAuth2Config.from_veidentity( + user_pool_name="my-app", + client_name="my-app-web", + redirect_uri="http://localhost:8000/oauth2/callback", + cookie_secure=False, # 本地 HTTP 开发 + ), +) +``` + +#### 自定义 OAuth2 Provider + +如需接入非 VeIdentity 的 OAuth2 提供商,可直接配置 `OAuth2Config`: + +```python +setup_oauth2( + app, + OAuth2Config( + authorize_url="https://provider.com/oauth2/authorize", + token_url="https://provider.com/oauth2/token", + userinfo_url="https://provider.com/oauth2/userinfo", + client_id="your-client-id", + client_secret="your-client-secret", + redirect_uri="https://myapp.com/oauth2/callback", + ), +) +``` + +#### 路由说明 + +中间件会自动注册以下路由: + +| 路由 | 说明 | +|------|------| +| `/oauth2/login` | 发起 OAuth2 登录流程 | +| `/oauth2/callback` | OAuth2 回调处理 | +| `/oauth2/logout` | 登出并清除会话 | +| `/oauth2/userinfo` | 获取当前用户信息 | + +#### 免认证路径 + +可以配置跳过认证的路径: + +```python +setup_oauth2( + app, + config, + exempt_paths=["/health", "/metrics"], # 精确匹配 + exempt_prefixes=["/public/", "/static/"], # 前缀匹配 +) +``` + +#### API 请求处理 + +中间件会根据请求类型自动选择响应方式: + +- **浏览器请求**:重定向到登录页面 +- **API 请求**:返回 `401 Unauthorized` JSON 响应 + +API 请求通过以下方式识别: + +- `Accept: application/json` 请求头 +- 路径前缀匹配(默认 `/api/`) +- `X-Requested-With: XMLHttpRequest` 请求头 + +可通过 `api_path_prefixes` 参数自定义: + +```python +OAuth2Config.from_veidentity( + # ... + api_path_prefixes=["/api/", "/graphql"], +) +``` + +#### 配置参考 + +##### OAuth2Config.from_veidentity() 参数 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `user_pool_name` | (必填) | VeIdentity 用户池名称 | +| `client_name` | (必填) | 用户池客户端名称 | +| `redirect_uri` | (必填) | OAuth2 回调 URL | +| `auto_create` | `True` | 资源不存在时自动创建 | +| `auto_register_callback` | `True` | 自动注册回调 URL | +| `client_type` | `WEB_APPLICATION` | 客户端类型 | +| `scope` | `"openid profile email"` | OAuth2 作用域 | +| `**extra_config` | - | 其他 OAuth2Config 参数 | + +##### OAuth2Config 参数 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `session_timeout_seconds` | `3600` | 会话超时时间(秒) | +| `cookie_secure` | `True` | 是否启用安全 cookie | +| `auto_refresh_token` | `True` | 自动刷新令牌 | +| `token_refresh_threshold_seconds` | `300` | 令牌刷新阈值(秒) | +| `api_path_prefixes` | `["/api/"]` | API 路径前缀 | + +#### 分布式部署 + +默认的 `InMemoryStateStore` 仅适用于单进程部署。分布式场景需要使用 Redis 等外部存储: + +```python +class RedisStateStore: + def __init__(self, redis_client, ttl: int = 300): + self._redis = redis_client + self._ttl = ttl + + def create_state(self, redirect_after_auth: str = "/", code_verifier=None) -> str: + import secrets, json + state = secrets.token_urlsafe(32) + self._redis.setex( + f"oauth2:{state}", + self._ttl, + json.dumps({ + "redirect_after_auth": redirect_after_auth, + "code_verifier": code_verifier, + }), + ) + return state + + def validate_and_consume_state(self, state: str): + import json + key = f"oauth2:{state}" + data = self._redis.get(key) + if not data: + return None + self._redis.delete(key) + return json.loads(data) + +# 使用 +setup_oauth2(app, config, state_store=RedisStateStore(redis_client)) +``` + ## OAuth2 JWT 认证 OAuth2 JWT 认证是将 OAuth2 授权框架与 JWT 结合,用 JWT 格式承载授权令牌的认证方式。A2A/MCP Server 支持 OAuth2 JWT 的认证方式。 @@ -51,13 +265,13 @@ OAuth2 JWT 认证是将 OAuth2 授权框架与 JWT 结合,用 JWT 格式承载 您可以通过脚手架创建 Agent 时指定 OAuth2 认证方式,或者部署已有项目时添加 `--auth-method=oauth2` 参数启用该认证,VeADK 将自动为您创建 Identity 用户池。如果您需要使用已有的用户池,您可以在部署时添加 `--user-pool-name` 参数指定用户池。 -在部署 A2A/MCP Server 应用后,您可以在 Identity 中管理客户端。 +在部署 A2A/MCP Server 应用后,您可以在 Identity 中管理客户端: 1. 登录火山引擎控制台,导航到 Agent Identity 服务 2. 在左侧导航树中,选择 身份认证 > 用户池管理,选择用户池 3. 在客户端的用户标签中,点击 新建客户端,填写 客户端名称,选择 客户端类型 并点击确定 -您可以创建 M2M 类型的客户端用于验证。您可以使用以下 curl 命令生成 JWT 令牌。 +您可以创建 M2M 类型的客户端用于验证。您可以使用以下 curl 命令生成 JWT 令牌: ```bash REGION="cn-beijing" @@ -71,4 +285,4 @@ curl --location "https://userpool-${USER_POOL_ID}.userpool.auth.id.${REGION}.vol --data-urlencode "grant_type=client_credentials" ``` -当用户访问 A2A/MCP Server 应用时,API 网关将验证用户携带的 JWT 令牌。您可以在 `Authorization` 请求头中获得用户的 JWT 令牌。 \ No newline at end of file +当用户访问 A2A/MCP Server 应用时,API 网关将验证用户携带的 JWT 令牌。您可以在 `Authorization` 请求头中获得用户的 JWT 令牌。 diff --git a/veadk/auth/middleware/__init__.py b/veadk/auth/middleware/__init__.py new file mode 100644 index 00000000..7f463206 --- /dev/null +++ b/veadk/auth/middleware/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/veadk/auth/middleware/oauth2_auth.py b/veadk/auth/middleware/oauth2_auth.py new file mode 100644 index 00000000..6decb5be --- /dev/null +++ b/veadk/auth/middleware/oauth2_auth.py @@ -0,0 +1,1386 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth2 3LO middleware for Starlette/FastAPI with VeIdentity User Pool integration. + +This middleware works with any Starlette-based framework, including FastAPI, Starlette, +and other ASGI frameworks built on Starlette. + +Quick start with FastAPI (recommended - using VeIdentity User Pool): + + from fastapi import FastAPI + from veadk.auth.middleware.oauth2_auth import OAuth2Config, setup_oauth2 + + app = FastAPI() + + setup_oauth2( + app, + OAuth2Config.from_veidentity( + user_pool_name="my-app", + client_name="my-app-web", + redirect_uri="https://myapp.com/oauth2/callback", + ), + ) + +Quick start with Starlette: + + from starlette.applications import Starlette + from veadk.auth.middleware.oauth2_auth import OAuth2Config, setup_oauth2 + + app = Starlette() + + setup_oauth2( + app, + OAuth2Config.from_veidentity( + user_pool_name="my-app", + client_name="my-app-web", + redirect_uri="https://myapp.com/oauth2/callback", + ), + ) + +For custom OAuth2 providers: + + setup_oauth2( + app, + OAuth2Config( + authorize_url="https://provider.com/oauth2/authorize", + token_url="https://provider.com/oauth2/token", + client_id="your-client-id", + client_secret="your-client-secret", + redirect_uri="https://myapp.com/oauth2/callback", + ), + ) +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import json +import logging +import random +import secrets +import time +import urllib.parse +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol, runtime_checkable + +import httpx +from pydantic import BaseModel, Field +from starlette.applications import Starlette +from starlette.exceptions import HTTPException +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse +from starlette.routing import Route + +if TYPE_CHECKING: + from veadk.integrations.ve_identity import IdentityClient + +# Maximum cookie size before warning (browsers typically limit to 4KB). +_MAX_COOKIE_SIZE_WARNING = 3800 + +logger = logging.getLogger(__name__) + + +def _get_origin_from_url(url: str) -> str: + """Extract origin (scheme + host + port) from a URL.""" + parsed = urllib.parse.urlparse(url) + return f"{parsed.scheme}://{parsed.netloc}" + + +def _get_identity_client() -> "IdentityClient": + """Get IdentityClient from global config or create new instance.""" + try: + # Prefer global config for connection pool reuse + from veadk.config import settings + + return settings.veidentity.get_identity_client() + except Exception: + pass + + # Fallback to creating new instance + try: + from veadk.integrations.ve_identity import IdentityClient + + return IdentityClient() + except ImportError as e: + raise RuntimeError( + "VeIdentity integration requires veadk.integrations.ve_identity. " + "Ensure the veadk package is properly installed." + ) from e + + +__all__ = [ + # Configuration classes + "OAuth2Config", + "UserPoolClientType", + "OIDCDiscoveryConfig", + "OAuth2Session", + "OAuth2RoutePaths", + # Core handler + "OAuth2Handler", + # State store interface and implementations + "StateStore", + "InMemoryStateStore", + # Setup function + "setup_oauth2", + # Lower-level functions for manual integration + "register_oauth2_routes", + "create_oauth2_middleware", +] + + +@dataclass(frozen=True) +class OAuth2RoutePaths: + """Route paths used by the OAuth2 integration.""" + + login: str = "/oauth2/login" + callback: str = "/oauth2/callback" + logout: str = "/oauth2/logout" + userinfo: str = "/oauth2/userinfo" + + def all_paths(self) -> set[str]: + """Return all configured paths as a set.""" + return {self.login, self.callback, self.logout, self.userinfo} + + +class UserPoolClientType(str, Enum): + """VeIdentity User Pool client types.""" + + WEB_APPLICATION = "WEB_APPLICATION" + MOBILE_APPLICATION = "MOBILE_APPLICATION" + SINGLE_PAGE_APPLICATION = "SINGLE_PAGE_APPLICATION" + + +class OIDCDiscoveryConfig(BaseModel): + """OIDC Discovery configuration from .well-known/openid-configuration.""" + + issuer: str + authorization_endpoint: str + token_endpoint: str + userinfo_endpoint: Optional[str] = None + end_session_endpoint: Optional[str] = None + jwks_uri: Optional[str] = None + introspection_endpoint: Optional[str] = None + revocation_endpoint: Optional[str] = None + scopes_supported: list[str] = Field(default_factory=list) + response_types_supported: list[str] = Field(default_factory=list) + grant_types_supported: list[str] = Field(default_factory=list) + code_challenge_methods_supported: list[str] = Field(default_factory=list) + + model_config = {"extra": "ignore"} + + +def _fetch_oidc_discovery( + base_url: str, + timeout: float = 10.0, +) -> OIDCDiscoveryConfig: + """Fetch OIDC discovery configuration from .well-known endpoint. + + Args: + base_url: Base URL of the OIDC provider (e.g., https://domain.com). + timeout: HTTP request timeout in seconds. + + Returns: + OIDCDiscoveryConfig with discovered endpoints. + + Raises: + RuntimeError: If discovery fails. + """ + import httpx + + discovery_url = f"{base_url}/.well-known/openid-configuration" + logger.debug("Fetching OIDC discovery from: %s", discovery_url) + + try: + with httpx.Client(timeout=timeout) as client: + response = client.get(discovery_url) + response.raise_for_status() + data = response.json() + config = OIDCDiscoveryConfig.model_validate(data) + logger.info( + "OIDC discovery successful: issuer=%s", + config.issuer, + ) + return config + except httpx.HTTPStatusError as e: + raise RuntimeError( + f"OIDC discovery failed: HTTP {e.response.status_code} from {discovery_url}" + ) from e + except httpx.RequestError as e: + raise RuntimeError(f"OIDC discovery failed: {e} for {discovery_url}") from e + except Exception as e: + raise RuntimeError(f"OIDC discovery failed: {e}") from e + + +class OAuth2Config(BaseModel): + """OAuth2 configuration for 3LO authentication. + + Can be created manually or via `OAuth2Config.from_veidentity()` for VeIdentity + User Pool integration. + + Example - Manual configuration: + config = OAuth2Config( + authorize_url="https://provider.com/oauth2/authorize", + token_url="https://provider.com/oauth2/token", + client_id="your-client-id", + client_secret="your-client-secret", + redirect_uri="https://myapp.com/oauth2/callback", + ) + + Example - VeIdentity User Pool (recommended): + config = OAuth2Config.from_veidentity( + user_pool_name="my-app", + client_name="my-app-web", + redirect_uri="https://myapp.com/oauth2/callback", + ) + """ + + # OAuth2 provider endpoints + authorize_url: str + token_url: str + + # Client credentials + client_id: str + client_secret: Optional[str] = None + + # OAuth2 parameters + scope: str = "openid profile" + response_type: str = "code" + redirect_uri: str + extra_authorize_params: dict[str, str] = Field(default_factory=dict) + extra_token_params: dict[str, str] = Field(default_factory=dict) + use_pkce: bool = False + + # Session + cookie configuration + session_cookie_name: str = "veadk_session" + session_timeout_seconds: int = 3600 # 1 hour + cookie_secure: bool = True + cookie_samesite: str = "lax" + cookie_domain: Optional[str] = None + cookie_path: str = "/" + cookie_signing_secret: Optional[str] = None + + # User info and logout behavior + userinfo_url: Optional[str] = None + end_session_url: Optional[str] = None + user_id_cookie_name: str = "veadk_user_id" + user_id_field: str = "sub" + logout_redirect_url: str = "/" + + # State store behavior + state_ttl_seconds: int = 300 + state_max_entries: int = 10000 + + # HTTP client behavior + http_timeout_seconds: float = 10.0 + http_max_connections: int = 100 + http_max_keepalive_connections: int = 20 + + # Token refresh behavior + token_refresh_threshold_seconds: int = 300 # Refresh when < 5 min remaining + auto_refresh_token: bool = True + + # API vs browser behavior + api_path_prefixes: list[str] = Field(default_factory=lambda: ["/api/"]) + + @classmethod + def from_veidentity( + cls, + *, + user_pool_name: Optional[str] = None, + user_pool_uid: Optional[str] = None, + client_name: Optional[str] = None, + client_uid: Optional[str] = None, + redirect_uri: str, + auto_create: bool = True, + auto_register_callback: bool = True, + client_type: UserPoolClientType = UserPoolClientType.WEB_APPLICATION, + web_origin: Optional[str] = None, + scope: str = "openid profile email", + identity_client: Optional["IdentityClient"] = None, + **extra_config: Any, + ) -> "OAuth2Config": + """Create OAuth2Config from VeIdentity User Pool (recommended). + + This method automatically: + - Gets or creates the user pool + - Gets or creates the user pool client + - Registers the callback URL + - Builds the OAuth2Config with correct endpoints + + Args: + user_pool_name: Name of the VeIdentity user pool (used if user_pool_uid not set). + user_pool_uid: UID of the VeIdentity user pool (takes precedence over name). + client_name: Name of the user pool client (used if client_uid not set). + client_uid: UID of the user pool client (takes precedence over name). + redirect_uri: OAuth2 callback URL (e.g., https://myapp.com/oauth2/callback). + auto_create: Create user pool and client if not found (default: True). + auto_register_callback: Register callback URL with client (default: True). + client_type: Client type for new clients (default: WEB_APPLICATION). + web_origin: Web origin for CORS. Auto-detected from redirect_uri if not set. + scope: OAuth2 scopes to request (default: "openid profile email"). + identity_client: Custom IdentityClient instance. Uses global config if not provided. + **extra_config: Additional OAuth2Config options (e.g., cookie_secure=False). + + Returns: + Configured OAuth2Config instance. + + Raises: + ValueError: If user pool or client not found and auto_create=False. + RuntimeError: If VeIdentity module is not available. + + Example: + # Auto-create resources by name + config = OAuth2Config.from_veidentity( + user_pool_name="my-app", + client_name="my-app-web", + redirect_uri="https://myapp.com/oauth2/callback", + ) + + # Use existing resources by UID + config = OAuth2Config.from_veidentity( + user_pool_uid="pool-xxxx", + client_uid="client-xxxx", + redirect_uri="https://myapp.com/oauth2/callback", + auto_create=False, + ) + + # With extra config options + config = OAuth2Config.from_veidentity( + user_pool_name="my-app", + client_name="my-app-web", + redirect_uri="http://localhost:8000/oauth2/callback", + cookie_secure=False, # For local development + api_path_prefixes=["/api/", "/graphql"], + ) + """ + # Validate inputs + if not user_pool_name and not user_pool_uid: + raise ValueError("Either user_pool_name or user_pool_uid must be provided") + if not client_name and not client_uid: + raise ValueError("Either client_name or client_uid must be provided") + + # Get or create IdentityClient + if identity_client is None: + identity_client = _get_identity_client() + + # Step 1: Get or create user pool + user_pool = identity_client.get_user_pool( + name=user_pool_name, uid=user_pool_uid + ) + if user_pool: + user_pool_id, user_pool_domain = user_pool + logger.info( + "Using existing user pool: %s", + user_pool_uid or user_pool_name, + ) + elif auto_create and user_pool_name: + user_pool_id, user_pool_domain = identity_client.create_user_pool( + name=user_pool_name + ) + logger.info( + "Created user pool: %s (domain: %s)", user_pool_name, user_pool_domain + ) + else: + identifier = user_pool_uid or user_pool_name + raise ValueError( + f"User pool '{identifier}' not found (auto_create=False or only UID provided)" + ) + + # Step 2: Get or create client + client = identity_client.get_user_pool_client( + user_pool_uid=user_pool_id, + name=client_name, + client_uid=client_uid, + ) + if client: + resolved_client_id, client_secret = client + logger.info("Using existing client: %s", client_uid or client_name) + elif auto_create and client_name: + client_type_value = ( + client_type.value + if isinstance(client_type, UserPoolClientType) + else client_type + ) + resolved_client_id, client_secret = identity_client.create_user_pool_client( + user_pool_uid=user_pool_id, + name=client_name, + client_type=client_type_value, + ) + logger.info("Created client: %s", client_name) + else: + identifier = client_uid or client_name + raise ValueError( + f"Client '{identifier}' not found (auto_create=False or only UID provided)" + ) + + # Step 3: Register callback URL + if auto_register_callback: + detected_origin = _get_origin_from_url(redirect_uri) + callback_origin = web_origin or detected_origin + try: + identity_client.register_callback_for_user_pool_client( + user_pool_uid=user_pool_id, + client_uid=resolved_client_id, + callback_url=redirect_uri, + web_origin=callback_origin, + ) + logger.info("Registered callback: %s", redirect_uri) + except Exception as e: + logger.warning("Callback registration skipped (may exist): %s", e) + + # Step 4: Fetch OIDC discovery configuration + base_url = f"https://{user_pool_domain}" + oidc_config = _fetch_oidc_discovery(base_url) + + # Step 5: Build OAuth2Config from discovered endpoints + return cls( + authorize_url=oidc_config.authorization_endpoint, + token_url=oidc_config.token_endpoint, + userinfo_url=oidc_config.userinfo_endpoint, + end_session_url=oidc_config.end_session_endpoint, + client_id=resolved_client_id, + client_secret=client_secret, + redirect_uri=redirect_uri, + scope=scope, + cookie_signing_secret=extra_config.pop( + "cookie_signing_secret", client_secret + ), + **extra_config, + ) + + +class OAuth2Session(BaseModel): + """OAuth2 session data stored in cookies.""" + + access_token: str + token_type: str = "Bearer" + expires_at: float + refresh_token: Optional[str] = None + user_info: Optional[dict[str, Any]] = None + + def is_expired(self) -> bool: + """Check if the access token is expired.""" + return time.time() >= self.expires_at + + def is_refresh_needed(self, threshold_seconds: int = 300) -> bool: + """Check if token refresh is needed (expires within threshold).""" + return time.time() >= (self.expires_at - threshold_seconds) + + def can_refresh(self) -> bool: + """Check if this session has a refresh token available.""" + return bool(self.refresh_token) + + def time_until_expiry(self) -> float: + """Return seconds until token expires (negative if expired).""" + return self.expires_at - time.time() + + def to_authorization_header(self) -> str: + """Convert to Authorization header value.""" + return f"{self.token_type} {self.access_token}" + + +@runtime_checkable +class StateStore(Protocol): + """Protocol for OAuth2 state storage backends. + + Implement this protocol to provide custom state storage (e.g., Redis, database). + + Example Redis implementation: + class RedisStateStore: + def __init__(self, redis_client, ttl_seconds: int = 300): + self._redis = redis_client + self._ttl = ttl_seconds + + def create_state( + self, redirect_after_auth: str = "/", code_verifier: Optional[str] = None + ) -> str: + state = secrets.token_urlsafe(32) + data = json.dumps({ + "redirect_after_auth": redirect_after_auth, + "code_verifier": code_verifier, + }) + self._redis.setex(f"oauth2_state:{state}", self._ttl, data) + return state + + def validate_and_consume_state(self, state: str) -> Optional[dict[str, Any]]: + key = f"oauth2_state:{state}" + data = self._redis.get(key) + if not data: + return None + self._redis.delete(key) + return json.loads(data) + """ + + def create_state( + self, redirect_after_auth: str = "/", code_verifier: Optional[str] = None + ) -> str: + """Create and store a new OAuth2 state parameter.""" + ... + + def validate_and_consume_state(self, state: str) -> Optional[dict[str, Any]]: + """Validate, consume and return state data. Returns None if invalid/expired.""" + ... + + +class InMemoryStateStore: + """In-memory store for OAuth2 state parameters. + + Suitable for single-process deployments. For multi-process or distributed + deployments, implement the StateStore protocol with Redis or a database. + """ + + def __init__( + self, + ttl_seconds: int = 300, + max_entries: int = 10000, + prune_probability: float = 0.01, + ) -> None: + self._states: dict[str, dict[str, Any]] = {} + self._ttl_seconds = ttl_seconds + self._max_entries = max_entries + self._prune_probability = prune_probability + + def create_state( + self, redirect_after_auth: str = "/", code_verifier: Optional[str] = None + ) -> str: + """Create a new OAuth2 state parameter.""" + # Probabilistic pruning to avoid performance hit on every call. + if random.random() < self._prune_probability: + self._prune_expired() + + if len(self._states) >= self._max_entries: + self._prune_oldest() + + state = secrets.token_urlsafe(32) + self._states[state] = { + "created_at": time.time(), + "redirect_after_auth": redirect_after_auth, + "code_verifier": code_verifier, + } + return state + + def validate_and_consume_state(self, state: str) -> Optional[dict[str, Any]]: + """Validate and consume an OAuth2 state parameter.""" + state_data = self._states.pop(state, None) + if not state_data: + return None + + if time.time() - state_data["created_at"] > self._ttl_seconds: + return None + + return state_data + + def _prune_expired(self) -> None: + """Remove expired states to keep memory bounded.""" + now = time.time() + expired_keys = [ + key + for key, value in self._states.items() + if now - value["created_at"] > self._ttl_seconds + ] + for key in expired_keys: + self._states.pop(key, None) + + def _prune_oldest(self) -> None: + """Remove the oldest states when size limit is reached.""" + if len(self._states) <= self._max_entries: + return + + items = sorted(self._states.items(), key=lambda item: item[1]["created_at"]) + to_remove = len(self._states) - self._max_entries + for key, _ in items[:to_remove]: + self._states.pop(key, None) + + +class OAuth2Handler: + """Handles OAuth2 authentication flow for Starlette/FastAPI apps.""" + + def __init__( + self, + config: OAuth2Config, + state_store: Optional[StateStore] = None, + ): + self.config = config + self.state_store: StateStore = state_store or InMemoryStateStore( + ttl_seconds=config.state_ttl_seconds, + max_entries=config.state_max_entries, + ) + # Configure HTTP client with connection pool limits. + limits = httpx.Limits( + max_keepalive_connections=config.http_max_keepalive_connections, + max_connections=config.http_max_connections, + ) + self._http_client = httpx.AsyncClient( + timeout=config.http_timeout_seconds, + limits=limits, + ) + + async def close(self) -> None: + """Close the HTTP client.""" + await self._http_client.aclose() + + def build_authorization_request(self, redirect_after_auth: str) -> tuple[str, str]: + """Create state and authorization URL for a redirect.""" + code_verifier = self._generate_code_verifier() if self.config.use_pkce else None + state = self.state_store.create_state( + redirect_after_auth=redirect_after_auth, + code_verifier=code_verifier, + ) + auth_url = self.get_authorization_url(state, code_verifier=code_verifier) + return state, auth_url + + def get_authorization_url( + self, state: str, code_verifier: Optional[str] = None + ) -> str: + """Generate the OAuth2 authorization URL.""" + params = { + "response_type": self.config.response_type, + "client_id": self.config.client_id, + "scope": self.config.scope, + "redirect_uri": self.config.redirect_uri, + "state": state, + } + + if self.config.use_pkce: + if not code_verifier: + raise HTTPException( + status_code=400, detail="Missing PKCE code verifier" + ) + params["code_challenge"] = self._build_code_challenge(code_verifier) + params["code_challenge_method"] = "S256" + + params.update(self.config.extra_authorize_params) + return f"{self.config.authorize_url}?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_token( + self, code: str, code_verifier: Optional[str] = None + ) -> OAuth2Session: + """Exchange authorization code for access token.""" + token_data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self.config.redirect_uri, + } + + if self.config.use_pkce: + if not code_verifier: + raise HTTPException( + status_code=400, detail="Missing PKCE code verifier" + ) + token_data["code_verifier"] = code_verifier + + token_data.update(self.config.extra_token_params) + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + # Use Basic Auth for client credentials when possible. + if self.config.client_secret: + credentials = f"{self.config.client_id}:{self.config.client_secret}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + headers["Authorization"] = f"Basic {encoded_credentials}" + else: + # Fallback to client_id in form data for public clients. + token_data["client_id"] = self.config.client_id + + try: + response = await self._http_client.post( + self.config.token_url, + data=token_data, + headers=headers, + ) + response.raise_for_status() + + try: + token_response = response.json() + except ValueError as exc: + raise HTTPException( + status_code=400, + detail="Token response is not valid JSON", + ) from exc + + if "access_token" not in token_response: + raise HTTPException( + status_code=400, + detail="Token response missing access_token", + ) + + expires_in = token_response.get("expires_in", 3600) + try: + expires_in = int(expires_in) + except (TypeError, ValueError): + expires_in = 3600 + + expires_at = time.time() + max(0, expires_in) + + session = OAuth2Session( + access_token=token_response["access_token"], + token_type=token_response.get("token_type", "Bearer"), + expires_at=expires_at, + refresh_token=token_response.get("refresh_token"), + ) + + if self.config.userinfo_url: + try: + user_info = await self._fetch_user_info(session.access_token) + session.user_info = user_info + logger.info( + "Successfully fetched user info for user: %s", + user_info.get("sub") + or user_info.get("email") + or user_info.get("id") + or "unknown", + ) + except Exception as e: + logger.warning("Failed to fetch user info: %s", e) + # Continue without user info. + + return session + + except httpx.HTTPStatusError as e: + logger.error("Token exchange failed: %s", e.response.text) + raise HTTPException( + status_code=400, + detail=f"Token exchange failed: {e.response.text}", + ) + except HTTPException: + raise + except Exception as e: + logger.error("Token exchange error: %s", e) + raise HTTPException(status_code=500, detail="Authentication failed") + + async def refresh_access_token( + self, session: OAuth2Session + ) -> Optional[OAuth2Session]: + """Refresh the access token using the refresh token. + + Returns a new OAuth2Session with updated tokens, or None if refresh fails. + The original session's user_info is preserved in the new session. + """ + if not session.refresh_token: + logger.debug("Cannot refresh: no refresh_token available") + return None + + token_data = { + "grant_type": "refresh_token", + "refresh_token": session.refresh_token, + } + token_data.update(self.config.extra_token_params) + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + + # Use Basic Auth for client credentials when possible. + if self.config.client_secret: + credentials = f"{self.config.client_id}:{self.config.client_secret}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + headers["Authorization"] = f"Basic {encoded_credentials}" + else: + token_data["client_id"] = self.config.client_id + + try: + response = await self._http_client.post( + self.config.token_url, + data=token_data, + headers=headers, + ) + response.raise_for_status() + + token_response = response.json() + + if "access_token" not in token_response: + logger.warning("Token refresh response missing access_token") + return None + + expires_in = token_response.get("expires_in", 3600) + try: + expires_in = int(expires_in) + except (TypeError, ValueError): + expires_in = 3600 + + expires_at = time.time() + max(0, expires_in) + + # Create new session with refreshed tokens, preserving user_info. + new_session = OAuth2Session( + access_token=token_response["access_token"], + token_type=token_response.get("token_type", "Bearer"), + expires_at=expires_at, + # Use new refresh_token if provided, otherwise keep the old one. + refresh_token=token_response.get( + "refresh_token", session.refresh_token + ), + user_info=session.user_info, + ) + + logger.info( + "Successfully refreshed access token, new expiry in %d seconds", + expires_in, + ) + return new_session + + except httpx.HTTPStatusError as e: + logger.warning("Token refresh failed: %s", e.response.text) + return None + except Exception as e: + logger.warning("Token refresh error: %s", e) + return None + + async def _fetch_user_info(self, access_token: str) -> dict[str, Any]: + """Fetch user information from the userinfo endpoint.""" + if not self.config.userinfo_url: + raise ValueError("userinfo_url not configured") + + headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/json", + } + + try: + response = await self._http_client.get( + self.config.userinfo_url, + headers=headers, + ) + response.raise_for_status() + + user_info = response.json() + logger.debug("Fetched user info: %s", user_info) + return user_info + + except httpx.HTTPStatusError as e: + logger.error("User info fetch failed: %s", e.response.text) + raise Exception(f"User info fetch failed: {e.response.text}") + except Exception as e: + logger.error("User info fetch error: %s", e) + raise Exception(f"User info fetch error: {e}") + + def encode_session(self, session: OAuth2Session) -> str: + """Encode OAuth2 session data for cookie storage. + + Warns if the encoded session exceeds the recommended cookie size limit. + """ + session_json = session.model_dump_json() + payload = self._base64url_encode(session_json.encode("utf-8")) + signing_key = self._get_cookie_signing_key() + + if not signing_key: + encoded = payload + else: + signature = hmac.new( + signing_key, payload.encode("ascii"), hashlib.sha256 + ).digest() + signature_text = self._base64url_encode(signature) + encoded = f"{payload}.{signature_text}" + + # Warn if cookie size approaches browser limits. + if len(encoded) > _MAX_COOKIE_SIZE_WARNING: + logger.warning( + "Session cookie size (%d bytes) approaching 4KB browser limit. " + "Consider reducing user_info stored in session or using server-side sessions.", + len(encoded), + ) + + return encoded + + def decode_session(self, encoded_session: str) -> Optional[OAuth2Session]: + """Decode OAuth2 session data from cookie.""" + try: + signing_key = self._get_cookie_signing_key() + + if "." in encoded_session: + payload, signature = encoded_session.split(".", 1) + if not signing_key: + logger.warning("Signed session cookie rejected without signing key") + return None + expected = hmac.new( + signing_key, payload.encode("ascii"), hashlib.sha256 + ).digest() + expected_signature = self._base64url_encode(expected) + if not hmac.compare_digest(signature, expected_signature): + logger.warning("Session signature mismatch") + return None + else: + payload = encoded_session + if signing_key: + logger.warning("Unsigned session cookie rejected") + return None + + session_bytes = self._base64url_decode(payload) + session_json = session_bytes.decode("utf-8") + session_data = json.loads(session_json) + return OAuth2Session.model_validate(session_data) + except Exception as e: + logger.warning("Failed to decode session: %s", e) + return None + + def get_session_from_request(self, request: Request) -> Optional[OAuth2Session]: + """Extract OAuth2 session from request cookies.""" + session_cookie = request.cookies.get(self.config.session_cookie_name) + if not session_cookie: + return None + + session = self.decode_session(session_cookie) + if not session or session.is_expired(): + return None + + return session + + def create_session_cookie(self, session: OAuth2Session) -> dict[str, Any]: + """Create session cookie parameters.""" + encoded_session = self.encode_session(session) + max_age = self._session_cookie_max_age(session) + + return { + "key": self.config.session_cookie_name, + "value": encoded_session, + "max_age": max_age, + "httponly": True, + "secure": self.config.cookie_secure, + "samesite": self.config.cookie_samesite, + "domain": self.config.cookie_domain, + "path": self.config.cookie_path, + } + + def create_user_id_cookie(self, session: OAuth2Session) -> Optional[dict[str, Any]]: + """Create user ID cookie for frontend access (non-HTTP-only).""" + if not session.user_info: + return None + + user_id = session.user_info.get(self.config.user_id_field) + if not user_id: + user_id = session.user_info.get("sub") or session.user_info.get("email") + if not user_id: + return None + + return { + "key": self.config.user_id_cookie_name, + "value": str(user_id), + "max_age": self._session_cookie_max_age(session), + "httponly": False, + "secure": self.config.cookie_secure, + "samesite": self.config.cookie_samesite, + "domain": self.config.cookie_domain, + "path": self.config.cookie_path, + } + + def _session_cookie_max_age(self, session: OAuth2Session) -> int: + token_ttl = max(0, int(session.expires_at - time.time())) + if self.config.session_timeout_seconds <= 0: + return token_ttl + return min(self.config.session_timeout_seconds, token_ttl) + + def _get_cookie_signing_key(self) -> Optional[bytes]: + # Fall back to client_secret when a dedicated signing secret is not provided. + secret = self.config.cookie_signing_secret or self.config.client_secret + if not secret: + return None + return secret.encode("utf-8") + + @staticmethod + def _base64url_encode(data: bytes) -> str: + return base64.urlsafe_b64encode(data).decode("ascii").rstrip("=") + + @staticmethod + def _base64url_decode(data: str) -> bytes: + padding = "=" * (-len(data) % 4) + return base64.urlsafe_b64decode(data + padding) + + @staticmethod + def _generate_code_verifier() -> str: + return OAuth2Handler._base64url_encode(secrets.token_bytes(32)) + + @staticmethod + def _build_code_challenge(code_verifier: str) -> str: + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + return OAuth2Handler._base64url_encode(digest) + + +def _resolve_redirect_after_auth(request: Request, redirect: Optional[str]) -> str: + """Resolve a safe redirect URL after login.""" + if not redirect: + return "/" + + redirect = redirect.strip() + if redirect.startswith("/"): + return redirect + + parsed = urllib.parse.urlparse(redirect) + if not parsed.scheme and not parsed.netloc: + return f"/{redirect.lstrip('/')}" + + current = urllib.parse.urlparse(str(request.url)) + if parsed.scheme == current.scheme and parsed.netloc == current.netloc: + return redirect + + logger.warning("Unsafe redirect ignored: %s", redirect) + return "/" + + +def register_oauth2_routes( + app: Starlette, + oauth2_handler: OAuth2Handler, + *, + routes: Optional[OAuth2RoutePaths] = None, +) -> OAuth2RoutePaths: + """Register OAuth2 callback/login/logout/userinfo routes. + + Works with both Starlette and FastAPI applications. + + Args: + app: The Starlette or FastAPI application instance. + oauth2_handler: The OAuth2Handler instance. + routes: Custom route paths (defaults to /oauth2/*). + + Returns: + The OAuth2RoutePaths used for registration. + """ + routes = routes or OAuth2RoutePaths() + + async def oauth2_login(request: Request) -> RedirectResponse: + """Start the OAuth2 authorization flow.""" + redirect = request.query_params.get("redirect") + redirect_after_auth = _resolve_redirect_after_auth(request, redirect) + _, auth_url = oauth2_handler.build_authorization_request(redirect_after_auth) + return RedirectResponse(url=auth_url, status_code=302) + + async def oauth2_callback(request: Request) -> RedirectResponse: + """Handle OAuth2 authorization callback.""" + params = request.query_params + error = params.get("error") + error_description = params.get("error_description") + code = params.get("code") + state = params.get("state") + + if error: + detail = error_description or error + raise HTTPException( + status_code=400, detail=f"OAuth2 authorization failed: {detail}" + ) + + if not code or not state: + raise HTTPException( + status_code=400, detail="Missing authorization code or state" + ) + + # Validate and consume the state to prevent replay attacks. + state_data = oauth2_handler.state_store.validate_and_consume_state(state) + if not state_data: + raise HTTPException( + status_code=400, detail="Invalid or expired state parameter" + ) + + try: + session = await oauth2_handler.exchange_code_for_token( + code, + code_verifier=state_data.get("code_verifier"), + ) + + # Create session cookie for subsequent requests. + session_cookie_params = oauth2_handler.create_session_cookie(session) + + redirect_url = state_data.get("redirect_after_auth") or "/" + response = RedirectResponse(url=redirect_url, status_code=302) + response.set_cookie(**session_cookie_params) + + # Set user ID cookie for frontend access (if user info available). + user_id_cookie_params = oauth2_handler.create_user_id_cookie(session) + if user_id_cookie_params: + response.set_cookie(**user_id_cookie_params) + logger.info( + "Set user ID cookie for user: %s", + user_id_cookie_params["value"], + ) + + return response + + except HTTPException: + raise + except Exception as e: + logger.error("OAuth2 callback error: %s", e) + raise HTTPException(status_code=500, detail="Authentication failed") + + async def oauth2_logout(request: Request) -> RedirectResponse: + """Logout and clear session cookies.""" + # Determine logout redirect URL + config = oauth2_handler.config + if config.end_session_url: + # Use OAuth2 provider's logout endpoint with post_logout redirect + logout_url = ( + f"{config.end_session_url}?" + f"post_logout_redirect_uri={urllib.parse.quote(config.logout_redirect_url)}&" + f"client_id={config.client_id}" + ) + else: + logout_url = config.logout_redirect_url + + response = RedirectResponse(url=logout_url, status_code=302) + response.delete_cookie( + config.session_cookie_name, + domain=config.cookie_domain, + path=config.cookie_path, + ) + response.delete_cookie( + config.user_id_cookie_name, + domain=config.cookie_domain, + path=config.cookie_path, + ) + logger.info("User logged out, cleared session and user ID cookies") + return response + + async def get_current_user_info(request: Request) -> JSONResponse: + """Get current user information from OAuth2 session.""" + session = oauth2_handler.get_session_from_request(request) + + if not session or session.is_expired(): + raise HTTPException(status_code=401, detail="Not authenticated") + + if not session.user_info: + if oauth2_handler.config.userinfo_url: + try: + user_info = await oauth2_handler._fetch_user_info( + session.access_token + ) + session.user_info = user_info + + session_cookie_params = oauth2_handler.create_session_cookie( + session + ) + response = JSONResponse(content=user_info) + response.set_cookie(**session_cookie_params) + + user_id_cookie_params = oauth2_handler.create_user_id_cookie( + session + ) + if user_id_cookie_params: + response.set_cookie(**user_id_cookie_params) + + return response + except Exception as e: + logger.error("Failed to fetch user info: %s", e) + raise HTTPException( + status_code=500, detail="Failed to fetch user info" + ) + + return JSONResponse( + content={ + "message": "User info not available", + "reason": "userinfo_url not configured", + } + ) + + return JSONResponse(content=session.user_info) + + # Register routes using Starlette's Route objects (works with both Starlette and FastAPI) + oauth2_routes = [ + Route(routes.login, oauth2_login, methods=["GET"]), + Route(routes.callback, oauth2_callback, methods=["GET"]), + Route(routes.logout, oauth2_logout, methods=["GET"]), + Route(routes.userinfo, get_current_user_info, methods=["GET"]), + ] + app.routes.extend(oauth2_routes) + + return routes + + +def setup_oauth2( + app: Starlette, + config: OAuth2Config, + *, + routes: Optional[OAuth2RoutePaths] = None, + exempt_paths: Optional[Iterable[str]] = None, + exempt_prefixes: Optional[Iterable[str]] = None, + state_store: Optional[StateStore] = None, +) -> OAuth2Handler: + """Install OAuth2 routes, middleware, and shutdown hook. + + Works with both Starlette and FastAPI applications. + + Example with VeIdentity User Pool (recommended): + setup_oauth2( + app, + OAuth2Config.from_veidentity( + user_pool_name="my-app", + client_name="my-app-web", + redirect_uri="https://myapp.com/oauth2/callback", + ), + ) + + Example with custom OAuth2 provider: + setup_oauth2( + app, + OAuth2Config( + authorize_url="https://provider.com/oauth2/authorize", + token_url="https://provider.com/oauth2/token", + client_id="...", + client_secret="...", + redirect_uri="https://myapp.com/oauth2/callback", + ), + ) + + Args: + app: The Starlette or FastAPI application instance. + config: OAuth2 configuration. Use OAuth2Config.from_veidentity() for VeIdentity. + routes: Custom route paths (defaults to /oauth2/*). + exempt_paths: Paths that skip authentication (exact match). + exempt_prefixes: Path prefixes that skip authentication. + state_store: Custom state store for distributed deployments. + + Returns: + The OAuth2Handler instance, also available at app.state.oauth2_handler. + """ + oauth2_handler = OAuth2Handler(config, state_store=state_store) + route_paths = register_oauth2_routes(app, oauth2_handler, routes=routes) + + merged_exempt_paths = set(route_paths.all_paths()) + if exempt_paths: + merged_exempt_paths.update(exempt_paths) + + app.add_middleware( + BaseHTTPMiddleware, + dispatch=create_oauth2_middleware( + oauth2_handler, + exempt_paths=merged_exempt_paths, + exempt_prefixes=exempt_prefixes, + ), + ) + + if hasattr(app, "add_event_handler"): + app.add_event_handler("shutdown", oauth2_handler.close) + if hasattr(app, "state"): + app.state.oauth2_handler = oauth2_handler + + return oauth2_handler + + +def _is_api_request(request: Request, api_prefixes: list[str]) -> bool: + """Determine if a request is an API request (should get 401, not redirect).""" + # Check Accept header for JSON preference. + accept = request.headers.get("accept", "") + if "application/json" in accept and "text/html" not in accept: + return True + + # Check configured API path prefixes. + path = request.url.path + for prefix in api_prefixes: + if path.startswith(prefix): + return True + + # Check X-Requested-With header (common for AJAX requests). + if request.headers.get("x-requested-with", "").lower() == "xmlhttprequest": + return True + + return False + + +def create_oauth2_middleware( + oauth2_handler: OAuth2Handler, + *, + exempt_paths: Optional[Iterable[str]] = None, + exempt_prefixes: Optional[Iterable[str]] = None, + allow_existing_authorization: bool = True, +): + """Create OAuth2 authentication middleware for Starlette/FastAPI. + + The middleware: + - Skips authentication for exempt paths/prefixes and OPTIONS requests. + - Passes through requests that already have an Authorization header. + - Injects the OAuth2 access token as an Authorization header for valid sessions. + - Auto-refreshes tokens when they are close to expiry (if configured). + - Returns 401 JSON for API requests, redirects browsers to login for others. + """ + exempt_paths_set = set(exempt_paths or []) + exempt_prefixes_tuple = tuple(exempt_prefixes or []) + config = oauth2_handler.config + + async def oauth2_middleware(request: Request, call_next): + """OAuth2 authentication middleware.""" + path = request.url.path + + # Allow preflight requests and explicitly exempted paths. + if request.method == "OPTIONS": + return await call_next(request) + if path in exempt_paths_set or any( + path.startswith(prefix) for prefix in exempt_prefixes_tuple + ): + return await call_next(request) + + # Pass through if there's already an Authorization header. + if allow_existing_authorization and "authorization" in request.headers: + return await call_next(request) + + session = oauth2_handler.get_session_from_request(request) + response_cookies: list[dict[str, Any]] = [] + + # Attempt token refresh if session is close to expiry. + if session and not session.is_expired(): + if ( + config.auto_refresh_token + and session.can_refresh() + and session.is_refresh_needed(config.token_refresh_threshold_seconds) + ): + logger.debug( + "Token expires in %.0f seconds, attempting refresh", + session.time_until_expiry(), + ) + refreshed = await oauth2_handler.refresh_access_token(session) + if refreshed: + session = refreshed + # Queue cookies to be set on the response. + response_cookies.append( + oauth2_handler.create_session_cookie(session) + ) + user_id_cookie = oauth2_handler.create_user_id_cookie(session) + if user_id_cookie: + response_cookies.append(user_id_cookie) + + if session and not session.is_expired(): + auth_header_value = session.to_authorization_header().encode() + + # Update the scope headers so downstream dependencies can read it. + headers = list(request.scope.get("headers", [])) + headers = [ + (name, value) + for name, value in headers + if name.lower() != b"authorization" + ] + headers.append((b"authorization", auth_header_value)) + request.scope["headers"] = headers + + logger.debug( + "Added Authorization header to request: %s...", + session.to_authorization_header()[:20], + ) + + response = await call_next(request) + + # Set any refreshed session cookies on the response. + for cookie_params in response_cookies: + response.set_cookie(**cookie_params) + + return response + + # No valid session - handle API vs browser requests differently. + if _is_api_request(request, config.api_path_prefixes): + return JSONResponse( + status_code=401, + content={"detail": "Not authenticated"}, + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Browser request: redirect to OAuth2 authorization. + _, auth_url = oauth2_handler.build_authorization_request(str(request.url)) + return RedirectResponse(url=auth_url, status_code=302) + + return oauth2_middleware diff --git a/veadk/integrations/ve_identity/identity_client.py b/veadk/integrations/ve_identity/identity_client.py index 91022703..076618c7 100644 --- a/veadk/integrations/ve_identity/identity_client.py +++ b/veadk/integrations/ve_identity/identity_client.py @@ -743,27 +743,59 @@ def create_user_pool(self, name: str) -> tuple[str, str]: return response.uid, response.domain - def get_user_pool(self, name: str) -> tuple[str, str] | None: + def get_user_pool( + self, + name: Optional[str] = None, + uid: Optional[str] = None, + ) -> tuple[str, str] | None: + """Get user pool by name or UID. + + Args: + name: User pool name (used for list query). + uid: User pool UID (used for direct get query). + + Returns: + Tuple of (uid, domain) if found, None otherwise. + + Raises: + ValueError: If neither name nor uid is provided. + """ from volcenginesdkid import ( ListUserPoolsRequest, ListUserPoolsResponse, + GetUserPoolRequest, + GetUserPoolResponse, FilterForListUserPoolsInput, DataForListUserPoolsOutput, ) - request = ListUserPoolsRequest( - page_number=1, - page_size=1, - filter=FilterForListUserPoolsInput( - name=name, - ), - ) - response: ListUserPoolsResponse = self._api_client.list_user_pools(request) - if response.total_count == 0: - return None + if uid: + # Direct get by UID + request = GetUserPoolRequest(uid=uid) + try: + response: GetUserPoolResponse = self._api_client.get_user_pool(request) + return response.uid, response.domain + except Exception as e: + logger.warning(f"Failed to get user pool by UID {uid}: {e}") + return None + + if name: + # List query by name + request = ListUserPoolsRequest( + page_number=1, + page_size=1, + filter=FilterForListUserPoolsInput( + name=name, + ), + ) + response: ListUserPoolsResponse = self._api_client.list_user_pools(request) + if response.total_count == 0: + return None + + user_pool: DataForListUserPoolsOutput = response.data[0] + return user_pool.uid, user_pool.domain - user_pool: DataForListUserPoolsOutput = response.data[0] - return user_pool.uid, user_pool.domain + raise ValueError("Either name or uid must be provided") def create_user_pool_client( self, user_pool_uid: str, name: str, client_type: str @@ -828,8 +860,24 @@ def register_callback_for_user_pool_client( self._api_client.update_user_pool_client(request2) def get_user_pool_client( - self, user_pool_uid: str, name: str + self, + user_pool_uid: str, + name: Optional[str] = None, + client_uid: Optional[str] = None, ) -> tuple[str, str] | None: + """Get user pool client by name or client UID. + + Args: + user_pool_uid: User pool UID (required). + name: Client name (used for list query). + client_uid: Client UID (used for direct get query). + + Returns: + Tuple of (client_uid, client_secret) if found, None otherwise. + + Raises: + ValueError: If neither name nor client_uid is provided. + """ from volcenginesdkid import ( ListUserPoolClientsRequest, ListUserPoolClientsResponse, @@ -839,26 +887,45 @@ def get_user_pool_client( GetUserPoolClientResponse, ) - request = ListUserPoolClientsRequest( - user_pool_uid=user_pool_uid, - page_number=1, - page_size=1, - filter=FilterForListUserPoolClientsInput( - name=name, - ), - ) - response: ListUserPoolClientsResponse = self._api_client.list_user_pool_clients( - request - ) - if response.total_count == 0: - return None + if client_uid: + # Direct get by client UID + request = GetUserPoolClientRequest( + user_pool_uid=user_pool_uid, + client_uid=client_uid, + ) + try: + response: GetUserPoolClientResponse = ( + self._api_client.get_user_pool_client(request) + ) + return response.uid, response.client_secret + except Exception as e: + logger.warning(f"Failed to get client by UID {client_uid}: {e}") + return None + + if name: + # List query by name + request = ListUserPoolClientsRequest( + user_pool_uid=user_pool_uid, + page_number=1, + page_size=1, + filter=FilterForListUserPoolClientsInput( + name=name, + ), + ) + response: ListUserPoolClientsResponse = ( + self._api_client.list_user_pool_clients(request) + ) + if response.total_count == 0: + return None - client: DataForListUserPoolClientsOutput = response.data[0] - request2 = GetUserPoolClientRequest( - user_pool_uid=user_pool_uid, - client_uid=client.uid, - ) - response2: GetUserPoolClientResponse = self._api_client.get_user_pool_client( - request2 - ) - return response2.uid, response2.client_secret + client: DataForListUserPoolClientsOutput = response.data[0] + request2 = GetUserPoolClientRequest( + user_pool_uid=user_pool_uid, + client_uid=client.uid, + ) + response2: GetUserPoolClientResponse = ( + self._api_client.get_user_pool_client(request2) + ) + return response2.uid, response2.client_secret + + raise ValueError("Either name or client_uid must be provided") From 4b862433f2cecb9d6b727eac05b2c5b04c85032a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E7=8E=89=E6=B6=B5?= Date: Tue, 27 Jan 2026 23:41:31 +0800 Subject: [PATCH 2/2] Set user pool sign-up and recovery options Explicitly disables self sign-up and enables self account recovery when creating a user pool in IdentityClient. --- veadk/integrations/ve_identity/identity_client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/veadk/integrations/ve_identity/identity_client.py b/veadk/integrations/ve_identity/identity_client.py index 076618c7..317ef621 100644 --- a/veadk/integrations/ve_identity/identity_client.py +++ b/veadk/integrations/ve_identity/identity_client.py @@ -738,6 +738,8 @@ def create_user_pool(self, name: str) -> tuple[str, str]: request = CreateUserPoolRequest( name=name, + self_sign_up_enabled=False, + self_account_recovery_enabled=True, ) response: CreateUserPoolResponse = self._api_client.create_user_pool(request)