Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/authsome/auth/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version
from typing import Literal

from pydantic import BaseModel, Field

Expand All @@ -28,10 +29,22 @@ class EncryptionConfig(BaseModel):
mode: str = "local_key"


class ProxyConfig(BaseModel):
"""Proxy configuration block."""

mode: Literal[
"connected_allow",
"connected_deny",
"configured_allow",
"configured_deny",
] = "connected_allow"


class ServerConfig(BaseModel):
"""Daemon-owned server configuration."""

spec_version: int = Field(default_factory=current_spec_version)
encryption: EncryptionConfig = Field(default_factory=EncryptionConfig)
proxy: ProxyConfig | None = Field(default_factory=ProxyConfig)

model_config = {"extra": "allow"}
88 changes: 52 additions & 36 deletions src/authsome/auth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,52 +153,68 @@ async def is_local_provider(self, provider: str) -> bool:
val = await self._vault.get(provider, collection="providers")
return val is not None

async def proxy_mode(self) -> str:
"""Return the configured proxy mode (e.g. "connected_allow")."""
config = await self.vault.get_config()
if config.proxy is not None:
return config.proxy.mode
return "connected_allow"

async def proxy_routes(self) -> dict[str, Any]:
"""Build the list of routes for proxy routing."""
from urllib.parse import urlparse
mode = await self.proxy_mode()
scope = mode.split("_", 1)[0]

routes = []
for provider_group in await self.list_connections():
provider_name = provider_group["name"]
selected_connections = provider_group["connections"]
if scope == "connected":
for provider_group in await self.list_connections():
provider_name = provider_group["name"]
selected_connections = provider_group["connections"]

try:
definition = await self.get_provider(provider_name)
except Exception:
continue
if not definition.host_url:
continue
try:
definition = await self.get_provider(provider_name)
except Exception:
continue

# Find the default connection
default_conn = next((c for c in selected_connections if c.get("is_default")), None)
if not default_conn:
continue
if not definition.host_url:
continue

# Find the default connection
default_conn = next((c for c in selected_connections if c.get("is_default")), None)
if not default_conn:
continue

routes.append(self._build_route_entry(definition, default_conn.get("connection_name", "default")))
else: # configured
for definition in await self.list_providers():
if not definition.host_url:
continue
routes.append(self._build_route_entry(definition, "default"))

paths: set[str] = set()
if definition.oauth:
for raw_url in [
definition.oauth.authorization_url,
definition.oauth.token_url,
definition.oauth.revocation_url,
definition.oauth.device_authorization_url,
(definition.registration.registration_endpoint if definition.registration else None),
]:
if not raw_url:
continue
parsed = urlparse(raw_url)
paths.add(parsed.path or "/")

routes.append(
{
"provider": provider_name,
"connection": default_conn.get("connection_name", "default"),
"host_url": definition.host_url,
"auth_endpoint_paths": sorted(list(paths)),
}
)
routes.sort(key=lambda r: (r["host_url"].startswith("regex:"), r["provider"]))
return {"routes": routes}

def _build_route_entry(self, definition: ProviderDefinition, connection_name: str) -> dict[str, Any]:
paths: set[str] = set()
if definition.oauth:
for raw_url in [
definition.oauth.authorization_url,
definition.oauth.token_url,
definition.oauth.revocation_url,
definition.oauth.device_authorization_url,
(definition.registration.registration_endpoint if definition.registration else None),
]:
if not raw_url:
continue
parsed = urlparse(raw_url)
paths.add(parsed.path or "/")
return {
"provider": definition.name,
"connection": connection_name,
"host_url": definition.host_url,
"auth_endpoint_paths": sorted(list(paths)),
}

async def resolve_credentials(self, **kwargs: Any) -> dict[str, Any]:
"""Resolve credentials for a provider/connection pair."""
provider = kwargs["provider"]
Expand Down
5 changes: 5 additions & 0 deletions src/authsome/cli/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ async def proxy_routes(self) -> dict[str, Any]:
"""Return proxy routes from a PoP-protected daemon endpoint."""
return await self._get("/proxy/routes")

async def proxy_mode(self) -> str:
"""Return the configured proxy mode from the daemon."""
data = await self._get("/proxy/mode")
return data["mode"]

async def resolve_credentials(self, **kwargs: Any) -> dict[str, Any]:
"""Resolve proxy credentials from a PoP-protected daemon endpoint."""
return await self._post("/credentials/resolve", kwargs)
Expand Down
59 changes: 58 additions & 1 deletion src/authsome/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pathlib
import sys
from pathlib import Path
from typing import Any
from typing import Any, Literal, cast

import click
import requests
Expand Down Expand Up @@ -962,6 +962,63 @@ async def ui(ctx_obj: ContextObj, no_browser: bool) -> None:
webbrowser.open(url)


PROXY_MODE_CHOICES = (
"connected_allow",
"connected_deny",
"configured_allow",
"configured_deny",
)


@cli.group(name="proxy")
def proxy() -> None:
"""Manage proxy behavior settings."""


@proxy.command(name="mode")
@click.argument("value", required=False, type=click.Choice(PROXY_MODE_CHOICES))
@auth_command
async def proxy_mode(ctx_obj: ContextObj, value: str | None) -> None:
"""Show or set the persisted proxy mode.

Without arguments, prints the current mode. With VALUE, updates the
persisted GlobalConfig.proxy.mode. Changes take effect on the next
`authsome run` invocation (the proxy reads the mode at startup).
"""
from authsome.auth.models.config import ProxyConfig
from authsome.store.local import LocalAppStore

home = Path(os.environ.get("AUTHSOME_HOME", str(Path.home() / ".authsome")))
store = LocalAppStore(home)
await store.ensure_initialized()

if value is None:
cfg = await store.get_config()
current = cfg.proxy.mode if cfg.proxy is not None else "connected_allow"
if ctx_obj.json_output:
ctx_obj.print_json({"mode": current})
else:
ctx_obj.echo(current)
return

mode_value = cast(
Literal["connected_allow", "connected_deny", "configured_allow", "configured_deny"],
value,
)
cfg = await store.get_config()
if cfg.proxy is None:
cfg.proxy = ProxyConfig(mode=mode_value)
else:
cfg.proxy.mode = mode_value
await store.save_config(cfg)
logger.info("proxy_mode_set mode={}", mode_value)

if ctx_obj.json_output:
ctx_obj.print_json({"status": "updated", "mode": value})
else:
ctx_obj.echo(f"proxy.mode = {value}", color="green")


@cli.group()
def daemon() -> None:
"""Manage the local Authsome daemon."""
Expand Down
2 changes: 2 additions & 0 deletions src/authsome/proxy/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ async def proxy_routes(self) -> Any: ...

async def list_providers_by_source(self) -> Any: ...

async def proxy_mode(self) -> str: ...


class ProxyRunner:
"""Launch a subprocess behind the Authsome local auth proxy."""
Expand Down
71 changes: 64 additions & 7 deletions src/authsome/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ async def resolve_credentials(self, **kwargs: Any) -> Any: ...

async def proxy_routes(self) -> Any: ...

async def proxy_mode(self) -> str: ...


@dataclass(frozen=True)
class _RouteTarget:
Expand Down Expand Up @@ -334,19 +336,27 @@ class AuthProxyAddon:
def __init__(self, client: ProxyClient) -> None:
self._client = client
self._router: ProxyRouter | None = None
self._mode: str | None = None
self._header_cache: dict[tuple[str, str], _HeaderCacheEntry] = {}
self._header_locks: dict[tuple[str, str], asyncio.Lock] = {}

async def _get_router(self) -> ProxyRouter:
async def _ensure_initialized(self) -> tuple[ProxyRouter, str]:
"""Build the router and read the proxy mode once at router-build time."""
if self._router is None:
self._router = await ProxyRouter.create(self._client)
return self._router
if self._mode is None:
self._mode = await self._client.proxy_mode()
return self._router, self._mode

async def request(self, flow: http.HTTPFlow) -> None:
router = await self._get_router()
router, mode = await self._ensure_initialized()
policy = mode.split("_", 1)[1]

resolution = router.resolve(flow.request.scheme, flow.request.host, flow.request.port, flow.request.path)
if resolution.match is None:
if resolution.miss_reason is not None:
if resolution.miss_reason == "no_match" and policy == "deny":
self._deny_request(flow, "no_match")
elif resolution.miss_reason is not None:
normalized_host = _normalize_host(flow.request.host)
logger.info(
"client_event event=proxy_miss host={} reason={} method={} path={}",
Expand All @@ -355,7 +365,7 @@ async def request(self, flow: http.HTTPFlow) -> None:
flow.request.method,
flow.request.path,
)
logger.error(
logger.debug(
"Proxy miss: host={} reason={} {} {}",
normalized_host,
resolution.miss_reason,
Expand All @@ -367,12 +377,23 @@ async def request(self, flow: http.HTTPFlow) -> None:
match = resolution.match
try:
headers = await self._get_auth_headers(match)
except Exception:
except Exception as exc:
normalized_host = _normalize_host(flow.request.host)
logger.info(
"proxy_no_credentials host={} provider={} connection={}",
normalized_host,
match.provider,
match.connection,
)
logger.warning(
"Failed to retrieve auth headers for provider={} connection={}. Forwarding unchanged.",
"No credentials for provider={} connection={} host={}: {}",
match.provider,
match.connection,
normalized_host,
exc,
)
if policy == "deny":
self._deny_request(flow, "no_credentials", match=match)
return

for key, value in headers.items():
Expand All @@ -387,6 +408,21 @@ async def request(self, flow: http.HTTPFlow) -> None:
flow.request.path,
)

def _deny_request(
self,
flow: http.HTTPFlow,
reason: str,
*,
match: RouteMatch | None = None,
) -> None:
host = _normalize_host(flow.request.host)
logger.info("proxy_deny host={} reason={}", host, reason)
logger.warning("Proxy deny: host={} reason={}", host, reason)
if flow.request.method.upper() == "CONNECT":
flow.kill()
return
flow.response = http.Response.make(403, _deny_body(reason, match).encode("utf-8"))

async def _get_auth_headers(self, match: RouteMatch) -> dict[str, str]:
cache_key = (match.provider, match.connection or "")
now = utc_now()
Expand Down Expand Up @@ -425,6 +461,27 @@ async def _get_auth_headers(self, match: RouteMatch) -> dict[str, str]:
return headers


def _deny_body(reason: str, match: RouteMatch | None) -> str:
"""Build a human-readable 403 body for a denied proxy request.

For `no_credentials` we surface the provider name plus a CLI command
and a dashboard URL so the agent (or human) can recover; other
reasons fall back to a generic message.

The dashboard URL assumes the default local daemon on
`127.0.0.1:7998`. It still requires an active dashboard session
(`authsome ui`) to land on the connect screen directly.
"""
if reason == "no_credentials" and match is not None:
provider = match.provider
return (
f"Forbidden: provider '{provider}' is configured but has no "
f"active connection. Run `authsome login {provider}` to connect, "
f"or visit http://127.0.0.1:7998/ui/apps/{provider}."
)
return "Forbidden by Authsome proxy policy"


def _header_cache_valid(entry: _HeaderCacheEntry, now: datetime) -> bool:
if entry.expires_at is None:
return True
Expand Down
6 changes: 6 additions & 0 deletions src/authsome/server/routes/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from authsome.server.schemas import (
CredentialResolutionRequest,
CredentialResolutionResponse,
ProxyModeResponse,
ProxyRoutesResponse,
)

Expand All @@ -21,6 +22,11 @@ async def proxy_routes(auth: AuthService = Depends(get_protected_auth_service))
return ProxyRoutesResponse.model_validate(data)


@router.get("/proxy/mode", response_model=ProxyModeResponse)
async def proxy_mode(auth: AuthService = Depends(get_protected_auth_service)) -> ProxyModeResponse:
return ProxyModeResponse.model_validate({"mode": await auth.proxy_mode()})


@router.post("/credentials/resolve", response_model=CredentialResolutionResponse)
async def resolve_credentials(
body: CredentialResolutionRequest,
Expand Down
9 changes: 9 additions & 0 deletions src/authsome/server/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,12 @@ class ProviderRoute(BaseModel):

class ProxyRoutesResponse(BaseModel):
routes: list[ProviderRoute]


class ProxyModeResponse(BaseModel):
mode: Literal[
"connected_allow",
"connected_deny",
"configured_allow",
"configured_deny",
]
Loading
Loading