diff --git a/src/authsome/auth/models/config.py b/src/authsome/auth/models/config.py index dd00397..ca2d63f 100644 --- a/src/authsome/auth/models/config.py +++ b/src/authsome/auth/models/config.py @@ -3,6 +3,7 @@ from __future__ import annotations from importlib.metadata import PackageNotFoundError, version +from typing import Literal from pydantic import BaseModel, Field @@ -28,10 +29,22 @@ class EncryptionConfig(BaseModel): mode: str = "local_key" +class ProxyConfig(BaseModel): + """Proxy configuration block.""" + + mode: Literal[ + "connected_allow", + "connected_deny", + "configured_allow", + "configured_deny", + ] = "connected_allow" + + class ServerConfig(BaseModel): """Daemon-owned server configuration.""" spec_version: int = Field(default_factory=current_spec_version) encryption: EncryptionConfig = Field(default_factory=EncryptionConfig) + proxy: ProxyConfig | None = Field(default_factory=ProxyConfig) model_config = {"extra": "allow"} diff --git a/src/authsome/auth/service.py b/src/authsome/auth/service.py index 9d75015..ac52f93 100644 --- a/src/authsome/auth/service.py +++ b/src/authsome/auth/service.py @@ -44,7 +44,7 @@ TokenExpiredError, UnsupportedFlowError, ) -from authsome.server.dependencies import list_registered_identity_handles +from authsome.server.dependencies import list_registered_identity_handles, load_server_config from authsome.utils import build_store_key, format_duration, is_filesystem_safe, parse_store_key, utc_now from authsome.vault import Vault @@ -153,52 +153,68 @@ async def is_local_provider(self, provider: str) -> bool: val = await self._vault.get(provider, collection="providers") return val is not None + async def proxy_mode(self) -> str: + """Return the configured proxy mode (e.g. "connected_allow").""" + config = load_server_config(self._vault.home) + if config.proxy is not None: + return config.proxy.mode + return "connected_allow" + async def proxy_routes(self) -> dict[str, Any]: """Build the list of routes for proxy routing.""" - from urllib.parse import urlparse + mode = await self.proxy_mode() + scope = mode.split("_", 1)[0] routes = [] - for provider_group in await self.list_connections(): - provider_name = provider_group["name"] - selected_connections = provider_group["connections"] + if scope == "connected": + for provider_group in await self.list_connections(): + provider_name = provider_group["name"] + selected_connections = provider_group["connections"] - try: - definition = await self.get_provider(provider_name) - except Exception: - continue - if not definition.host_url: - continue + try: + definition = await self.get_provider(provider_name) + except Exception: + continue - # Find the default connection - default_conn = next((c for c in selected_connections if c.get("is_default")), None) - if not default_conn: - continue + if not definition.host_url: + continue + + # Find the default connection + default_conn = next((c for c in selected_connections if c.get("is_default")), None) + if not default_conn: + continue + + routes.append(self._build_route_entry(definition, default_conn.get("connection_name", "default"))) + else: # configured + for definition in await self.list_providers(): + if not definition.host_url: + continue + routes.append(self._build_route_entry(definition, "default")) - paths: set[str] = set() - if definition.oauth: - for raw_url in [ - definition.oauth.authorization_url, - definition.oauth.token_url, - definition.oauth.revocation_url, - definition.oauth.device_authorization_url, - (definition.registration.registration_endpoint if definition.registration else None), - ]: - if not raw_url: - continue - parsed = urlparse(raw_url) - paths.add(parsed.path or "/") - - routes.append( - { - "provider": provider_name, - "connection": default_conn.get("connection_name", "default"), - "host_url": definition.host_url, - "auth_endpoint_paths": sorted(list(paths)), - } - ) routes.sort(key=lambda r: (r["host_url"].startswith("regex:"), r["provider"])) return {"routes": routes} + def _build_route_entry(self, definition: ProviderDefinition, connection_name: str) -> dict[str, Any]: + paths: set[str] = set() + if definition.oauth: + for raw_url in [ + definition.oauth.authorization_url, + definition.oauth.token_url, + definition.oauth.revocation_url, + definition.oauth.device_authorization_url, + (definition.registration.registration_endpoint if definition.registration else None), + ]: + if not raw_url: + continue + parsed = urlparse(raw_url) + paths.add(parsed.path or "/") + return { + "provider": definition.name, + "connection": connection_name, + "host_url": definition.host_url, + "auth_endpoint_paths": sorted(list(paths)), + } + async def resolve_credentials(self, **kwargs: Any) -> dict[str, Any]: """Resolve credentials for a provider/connection pair.""" provider = kwargs["provider"] diff --git a/src/authsome/cli/client.py b/src/authsome/cli/client.py index 8c2c48d..9213815 100644 --- a/src/authsome/cli/client.py +++ b/src/authsome/cli/client.py @@ -202,6 +202,11 @@ async def proxy_routes(self) -> dict[str, Any]: """Return proxy routes from a PoP-protected daemon endpoint.""" return await self._get("/proxy/routes") + async def proxy_mode(self) -> str: + """Return the configured proxy mode from the daemon.""" + data = await self._get("/proxy/mode") + return data["mode"] + async def resolve_credentials(self, **kwargs: Any) -> dict[str, Any]: """Resolve proxy credentials from a PoP-protected daemon endpoint.""" return await self._post("/credentials/resolve", kwargs) diff --git a/src/authsome/cli/main.py b/src/authsome/cli/main.py index 6a5c438..8621545 100644 --- a/src/authsome/cli/main.py +++ b/src/authsome/cli/main.py @@ -5,7 +5,7 @@ import pathlib import sys from pathlib import Path -from typing import Any +from typing import Any, Literal, cast import click import requests @@ -962,6 +962,61 @@ async def ui(ctx_obj: ContextObj, no_browser: bool) -> None: webbrowser.open(url) +PROXY_MODE_CHOICES = ( + "connected_allow", + "connected_deny", + "configured_allow", + "configured_deny", +) + + +@cli.group(name="proxy") +def proxy() -> None: + """Manage proxy behavior settings.""" + + +@proxy.command(name="mode") +@click.argument("value", required=False, type=click.Choice(PROXY_MODE_CHOICES)) +@auth_command +async def proxy_mode(ctx_obj: ContextObj, value: str | None) -> None: + """Show or set the persisted proxy mode. + + Without arguments, prints the current mode. With VALUE, updates the + persisted ServerConfig.proxy.mode. Changes take effect on the next + `authsome run` invocation (the proxy reads the mode at startup). + """ + from authsome.auth.models.config import ProxyConfig + from authsome.server.dependencies import load_server_config, save_server_config + + home = Path(os.environ.get("AUTHSOME_HOME", str(Path.home() / ".authsome"))) + + if value is None: + cfg = load_server_config(home) + current = cfg.proxy.mode if cfg.proxy is not None else "connected_allow" + if ctx_obj.json_output: + ctx_obj.print_json({"mode": current}) + else: + ctx_obj.echo(current) + return + + mode_value = cast( + Literal["connected_allow", "connected_deny", "configured_allow", "configured_deny"], + value, + ) + cfg = load_server_config(home) + if cfg.proxy is None: + cfg.proxy = ProxyConfig(mode=mode_value) + else: + cfg.proxy.mode = mode_value + save_server_config(cfg, home) + logger.info("proxy_mode_set mode={}", mode_value) + + if ctx_obj.json_output: + ctx_obj.print_json({"status": "updated", "mode": value}) + else: + ctx_obj.echo(f"proxy.mode = {value}", color="green") + + @cli.group() def daemon() -> None: """Manage the local Authsome daemon.""" diff --git a/src/authsome/proxy/runner.py b/src/authsome/proxy/runner.py index c6821bb..1a1fc60 100644 --- a/src/authsome/proxy/runner.py +++ b/src/authsome/proxy/runner.py @@ -25,6 +25,8 @@ async def proxy_routes(self) -> Any: ... async def list_providers_by_source(self) -> Any: ... + async def proxy_mode(self) -> str: ... + class ProxyRunner: """Launch a subprocess behind the Authsome local auth proxy.""" diff --git a/src/authsome/proxy/server.py b/src/authsome/proxy/server.py index eb25b11..f3f9ec0 100644 --- a/src/authsome/proxy/server.py +++ b/src/authsome/proxy/server.py @@ -17,6 +17,7 @@ from mitmproxy.options import Options from mitmproxy.tools.dump import DumpMaster +from authsome import audit from authsome.proxy.router import RouteMatch, RouteResolution from authsome.utils import utc_now @@ -36,6 +37,8 @@ async def resolve_credentials(self, **kwargs: Any) -> Any: ... async def proxy_routes(self) -> Any: ... + async def proxy_mode(self) -> str: ... + @dataclass(frozen=True) class _RouteTarget: @@ -334,19 +337,27 @@ class AuthProxyAddon: def __init__(self, client: ProxyClient) -> None: self._client = client self._router: ProxyRouter | None = None + self._mode: str | None = None self._header_cache: dict[tuple[str, str], _HeaderCacheEntry] = {} self._header_locks: dict[tuple[str, str], asyncio.Lock] = {} - async def _get_router(self) -> ProxyRouter: + async def _ensure_initialized(self) -> tuple[ProxyRouter, str]: + """Build the router and read the proxy mode once at router-build time.""" if self._router is None: self._router = await ProxyRouter.create(self._client) - return self._router + if self._mode is None: + self._mode = await self._client.proxy_mode() + return self._router, self._mode async def request(self, flow: http.HTTPFlow) -> None: - router = await self._get_router() + router, mode = await self._ensure_initialized() + policy = mode.split("_", 1)[1] + resolution = router.resolve(flow.request.scheme, flow.request.host, flow.request.port, flow.request.path) if resolution.match is None: - if resolution.miss_reason is not None: + if resolution.miss_reason == "no_match" and policy == "deny": + self._deny_request(flow, "no_match") + elif resolution.miss_reason is not None: normalized_host = _normalize_host(flow.request.host) logger.info( "client_event event=proxy_miss host={} reason={} method={} path={}", @@ -355,7 +366,7 @@ async def request(self, flow: http.HTTPFlow) -> None: flow.request.method, flow.request.path, ) - logger.error( + logger.debug( "Proxy miss: host={} reason={} {} {}", normalized_host, resolution.miss_reason, @@ -367,12 +378,23 @@ async def request(self, flow: http.HTTPFlow) -> None: match = resolution.match try: headers = await self._get_auth_headers(match) - except Exception: + except Exception as exc: + normalized_host = _normalize_host(flow.request.host) + audit.log( + "proxy_no_credentials", + host=normalized_host, + provider=match.provider, + connection=match.connection, + ) logger.warning( - "Failed to retrieve auth headers for provider={} connection={}. Forwarding unchanged.", + "No credentials for provider={} connection={} host={}: {}", match.provider, match.connection, + normalized_host, + exc, ) + if policy == "deny": + self._deny_request(flow, "no_credentials", match=match) return for key, value in headers.items(): @@ -387,6 +409,21 @@ async def request(self, flow: http.HTTPFlow) -> None: flow.request.path, ) + def _deny_request( + self, + flow: http.HTTPFlow, + reason: str, + *, + match: RouteMatch | None = None, + ) -> None: + host = _normalize_host(flow.request.host) + audit.log("proxy_deny", host=host, reason=reason) + logger.warning("Proxy deny: host={} reason={}", host, reason) + if flow.request.method.upper() == "CONNECT": + flow.kill() + return + flow.response = http.Response.make(403, _deny_body(reason, match).encode("utf-8")) + async def _get_auth_headers(self, match: RouteMatch) -> dict[str, str]: cache_key = (match.provider, match.connection or "") now = utc_now() @@ -425,6 +462,27 @@ async def _get_auth_headers(self, match: RouteMatch) -> dict[str, str]: return headers +def _deny_body(reason: str, match: RouteMatch | None) -> str: + """Build a human-readable 403 body for a denied proxy request. + + For `no_credentials` we surface the provider name plus a CLI command + and a dashboard URL so the agent (or human) can recover; other + reasons fall back to a generic message. + + The dashboard URL assumes the default local daemon on + `127.0.0.1:7998`. It still requires an active dashboard session + (`authsome ui`) to land on the connect screen directly. + """ + if reason == "no_credentials" and match is not None: + provider = match.provider + return ( + f"Forbidden: provider '{provider}' is configured but has no " + f"active connection. Run `authsome login {provider}` to connect, " + f"or visit http://127.0.0.1:7998/ui/apps/{provider}." + ) + return "Forbidden by Authsome proxy policy" + + def _header_cache_valid(entry: _HeaderCacheEntry, now: datetime) -> bool: if entry.expires_at is None: return True diff --git a/src/authsome/server/routes/proxy.py b/src/authsome/server/routes/proxy.py index caace1a..252f5f4 100644 --- a/src/authsome/server/routes/proxy.py +++ b/src/authsome/server/routes/proxy.py @@ -9,6 +9,7 @@ from authsome.server.schemas import ( CredentialResolutionRequest, CredentialResolutionResponse, + ProxyModeResponse, ProxyRoutesResponse, ) @@ -21,6 +22,11 @@ async def proxy_routes(auth: AuthService = Depends(get_protected_auth_service)) return ProxyRoutesResponse.model_validate(data) +@router.get("/proxy/mode", response_model=ProxyModeResponse) +async def proxy_mode(auth: AuthService = Depends(get_protected_auth_service)) -> ProxyModeResponse: + return ProxyModeResponse.model_validate({"mode": await auth.proxy_mode()}) + + @router.post("/credentials/resolve", response_model=CredentialResolutionResponse) async def resolve_credentials( body: CredentialResolutionRequest, diff --git a/src/authsome/server/schemas.py b/src/authsome/server/schemas.py index 06041ab..c7adb05 100644 --- a/src/authsome/server/schemas.py +++ b/src/authsome/server/schemas.py @@ -87,3 +87,12 @@ class ProviderRoute(BaseModel): class ProxyRoutesResponse(BaseModel): routes: list[ProviderRoute] + + +class ProxyModeResponse(BaseModel): + mode: Literal[ + "connected_allow", + "connected_deny", + "configured_allow", + "configured_deny", + ] diff --git a/src/authsome/server/ui/pages.py b/src/authsome/server/ui/pages.py index 1e309e5..99eef22 100644 --- a/src/authsome/server/ui/pages.py +++ b/src/authsome/server/ui/pages.py @@ -39,7 +39,7 @@ def input_page( optional_rows = [] for field in fields: row = _field_row(field) - if field.get("default") is None: + if field.get("default") is None or field.get("name") == "client_secret": required_rows.append(row) else: optional_rows.append(row) diff --git a/tests/cli/test_proxy.py b/tests/cli/test_proxy.py new file mode 100644 index 0000000..8db071b --- /dev/null +++ b/tests/cli/test_proxy.py @@ -0,0 +1,55 @@ +"""Tests for ``authsome proxy mode``.""" + +from __future__ import annotations + +import asyncio +import json +from pathlib import Path + +from click.testing import CliRunner + +from authsome.cli.main import cli +from authsome.server.dependencies import load_server_config +from authsome.store.local import LocalAppStore + + +def test_proxy_mode_defaults_to_connected_allow_when_unset( + runner: CliRunner, + tmp_path: Path, +) -> None: + asyncio.run(LocalAppStore(tmp_path).ensure_initialized()) + + result = runner.invoke(cli, ["--log-file", "", "proxy", "mode", "--json"]) + + assert result.exit_code == 0, result.output + data = json.loads(result.output) + assert data["mode"] == "connected_allow" + + +def test_proxy_mode_sets_and_persists_value( + runner: CliRunner, + tmp_path: Path, +) -> None: + asyncio.run(LocalAppStore(tmp_path).ensure_initialized()) + + set_result = runner.invoke(cli, ["--log-file", "", "proxy", "mode", "configured_deny", "--json"]) + assert set_result.exit_code == 0, set_result.output + set_data = json.loads(set_result.output) + assert set_data["status"] == "updated" + assert set_data["mode"] == "configured_deny" + + persisted = load_server_config(tmp_path) + assert persisted.proxy is not None + assert persisted.proxy.mode == "configured_deny" + + show_result = runner.invoke(cli, ["--log-file", "", "proxy", "mode", "--json"]) + assert show_result.exit_code == 0, show_result.output + assert json.loads(show_result.output)["mode"] == "configured_deny" + + +def test_proxy_mode_rejects_unknown_value(runner: CliRunner, tmp_path: Path) -> None: + asyncio.run(LocalAppStore(tmp_path).ensure_initialized()) + + result = runner.invoke(cli, ["--log-file", "", "proxy", "mode", "bogus"]) + assert result.exit_code != 0 + assert "Invalid value" in result.output or "invalid choice" in result.output.lower() diff --git a/tests/proxy/test_proxy.py b/tests/proxy/test_proxy.py index 16baa10..bb5b453 100644 --- a/tests/proxy/test_proxy.py +++ b/tests/proxy/test_proxy.py @@ -461,7 +461,7 @@ def _make_flow(self, scheme="https", host="api.openai.com", port=443, path="/v1/ flow.request.headers = headers if headers is not None else {} return flow - def _make_addon(self, auth, match, *, miss_reason=None): + def _make_addon(self, auth, match, *, miss_reason=None, mode="connected_allow"): router = mock.Mock() router.resolve.return_value = RouteResolution(match=match, miss_reason=miss_reason) @@ -471,6 +471,7 @@ def _make_addon(self, auth, match, *, miss_reason=None): patcher = patch("authsome.proxy.server.ProxyRouter.create", mock_create) patcher.start() + auth.proxy_mode.return_value = mode addon = AuthProxyAddon(client=auth) return addon, router, patcher @@ -515,6 +516,67 @@ async def test_addon_skips_unmatched_request(self) -> None: auth.resolve_credentials.assert_not_called() + @pytest.mark.asyncio + async def test_addon_denies_no_match_with_generic_body_in_connected_deny(self) -> None: + auth = mock.AsyncMock() + flow = self._make_flow(host="example.com", path="/") + + with patch("authsome.proxy.server.audit.log") as log_mock: + addon, _router, patcher = self._make_addon(auth, None, miss_reason="no_match", mode="connected_deny") + try: + await addon.request(flow) + finally: + patcher.stop() + + assert flow.response.status_code == 403 + assert flow.response.content == b"Forbidden by Authsome proxy policy" + log_mock.assert_called_once_with("proxy_deny", host="example.com", reason="no_match") + auth.resolve_credentials.assert_not_called() + + @pytest.mark.asyncio + async def test_addon_denies_no_credentials_with_provider_hint_in_configured_deny(self) -> None: + auth = mock.AsyncMock() + auth.resolve_credentials.side_effect = RuntimeError("no connection for openai") + flow = self._make_flow() + + with patch("authsome.proxy.server.audit.log") as log_mock: + addon, _router, patcher = self._make_addon( + auth, + RouteMatch(provider="openai", connection="default"), + mode="configured_deny", + ) + try: + await addon.request(flow) + finally: + patcher.stop() + + assert flow.response.status_code == 403 + body = flow.response.content.decode("utf-8") + assert "openai" in body + assert "authsome login openai" in body + assert "http://127.0.0.1:7998/ui/apps/openai" in body + log_mock.assert_any_call( + "proxy_no_credentials", + host="api.openai.com", + provider="openai", + connection="default", + ) + log_mock.assert_any_call("proxy_deny", host="api.openai.com", reason="no_credentials") + + @pytest.mark.asyncio + async def test_addon_kills_connect_tunnel_on_deny(self) -> None: + auth = mock.AsyncMock() + flow = self._make_flow(host="example.com", path="/") + flow.request.method = "CONNECT" + + addon, _router, patcher = self._make_addon(auth, None, miss_reason="no_match", mode="connected_deny") + try: + await addon.request(flow) + finally: + patcher.stop() + + flow.kill.assert_called_once() + @pytest.mark.asyncio async def test_addon_continues_on_header_retrieval_failure(self) -> None: auth = mock.AsyncMock()