diff --git a/src/authsome/audit/__init__.py b/src/authsome/audit/__init__.py index 9437d01..5be53ca 100644 --- a/src/authsome/audit/__init__.py +++ b/src/authsome/audit/__init__.py @@ -1,57 +1,83 @@ -"""Audit logging for Authsome operations.""" +"""Structured server-side event logging helpers.""" + +from __future__ import annotations import json +import threading +import uuid +from datetime import datetime from pathlib import Path from typing import Any -from loguru import logger +from pydantic import BaseModel, Field from authsome.utils import utc_now -class AuditLogger: - """Append-only structured audit logger.""" +class AuditEvent(BaseModel): + """Structured server-side event record.""" + + event_id: str = Field(default_factory=lambda: f"audit_{uuid.uuid4().hex}") + timestamp: datetime = Field(default_factory=utc_now) + event: str + provider: str | None = None + connection: str | None = None + identity: str | None = None + status: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) - def __init__(self, filepath: Path) -> None: - self.filepath = filepath - def log(self, event_type: str, **kwargs: Any) -> None: - """Write an event to the audit log.""" +_log_path: Path | None = None +_lock = threading.Lock() - # Ensure directory exists - if not self.filepath.parent.exists(): - try: - self.filepath.parent.mkdir(parents=True, exist_ok=True) - except Exception as e: - logger.error("Failed to create audit log directory {}: {}", self.filepath.parent, e) - return - # Filter out None values to keep the log clean - filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} +def _build_event(event_type: str, **kwargs: Any) -> AuditEvent: + filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + return AuditEvent( + event=event_type, + provider=filtered_kwargs.pop("provider", None), + connection=filtered_kwargs.pop("connection", None), + identity=filtered_kwargs.pop("identity", None), + status=filtered_kwargs.pop("status", None), + metadata=filtered_kwargs, + ) - entry = { - "timestamp": utc_now().isoformat(), - "event": event_type, - **filtered_kwargs, - } - try: - with open(self.filepath, "a", encoding="utf-8") as f: - f.write(json.dumps(entry) + "\n") - except Exception as e: - logger.error("Failed to write to audit log at {}: {}", self.filepath, e) +def setup(path: Path) -> None: + """Configure the server-side structured log path.""" + global _log_path + path.parent.mkdir(parents=True, exist_ok=True) + if not path.exists(): + path.touch() + _log_path = path -_logger_instance: AuditLogger | None = None +def clear() -> None: + """Clear configured server-side log state.""" + global _log_path + _log_path = None -def setup(filepath: Path) -> None: - """Initialize the global audit logger singleton.""" - global _logger_instance - _logger_instance = AuditLogger(filepath) +def _serialize_event(event: AuditEvent) -> str: + payload = event.model_dump(mode="json") + metadata = payload.pop("metadata", {}) + if isinstance(metadata, dict): + payload.update(metadata) + return json.dumps(payload, separators=(",", ":")) def log(event_type: str, **kwargs: Any) -> None: - """Write an event to the global audit log.""" - if _logger_instance is not None: - _logger_instance.log(event_type, **kwargs) + """Append a structured server event to the configured log file.""" + if _log_path is None: + return + line = _serialize_event(_build_event(event_type, **kwargs)) + with _lock: + _log_path.parent.mkdir(parents=True, exist_ok=True) + with _log_path.open("a", encoding="utf-8") as handle: + handle.write(line) + handle.write("\n") + + +async def alog(event_type: str, **kwargs: Any) -> None: + """Async wrapper around structured server event logging.""" + log(event_type, **kwargs) diff --git a/src/authsome/auth/models/__init__.py b/src/authsome/auth/models/__init__.py index 4dd7976..c85b06e 100644 --- a/src/authsome/auth/models/__init__.py +++ b/src/authsome/auth/models/__init__.py @@ -1,6 +1,6 @@ """auth.models — re-exports all model types used by the auth layer.""" -from authsome.auth.models.config import EncryptionConfig, GlobalConfig +from authsome.auth.models.config import EncryptionConfig, ServerConfig from authsome.auth.models.connection import ( AccountInfo, ConnectionRecord, @@ -32,11 +32,11 @@ "ExportConfig", "ExportFormat", "FlowType", - "GlobalConfig", "OAuthConfig", "ProviderClientRecord", "ProviderDefinition", "ProviderMetadataRecord", "ProviderStateRecord", + "ServerConfig", "Sensitive", ] diff --git a/src/authsome/auth/models/config.py b/src/authsome/auth/models/config.py index 9368974..dd00397 100644 --- a/src/authsome/auth/models/config.py +++ b/src/authsome/auth/models/config.py @@ -1,9 +1,8 @@ -"""Global configuration models.""" +"""Authsome configuration models and helpers.""" from __future__ import annotations from importlib.metadata import PackageNotFoundError, version -from typing import Any from pydantic import BaseModel, Field @@ -24,23 +23,15 @@ def current_spec_version() -> int: class EncryptionConfig(BaseModel): - """ - Encryption configuration block. - - Modes: - - "local_key": master key stored at ~/.authsome/server/master.key - - "keyring": master key stored in the OS keyring - """ + """Vault encryption backend settings for the daemon.""" mode: str = "local_key" -class GlobalConfig(BaseModel): - """Daemon configuration for the local Authsome install.""" +class ServerConfig(BaseModel): + """Daemon-owned server configuration.""" spec_version: int = Field(default_factory=current_spec_version) - encryption: EncryptionConfig | None = Field(default_factory=EncryptionConfig) - - extra_fields: dict[str, Any] = Field(default_factory=dict, exclude=True) + encryption: EncryptionConfig = Field(default_factory=EncryptionConfig) model_config = {"extra": "allow"} diff --git a/src/authsome/auth/service.py b/src/authsome/auth/service.py index 4a41076..9d75015 100644 --- a/src/authsome/auth/service.py +++ b/src/authsome/auth/service.py @@ -231,8 +231,8 @@ async def remove_provider(self, name: str) -> bool: """Remove a custom provider. Returns True if removed.""" return await self._vault.delete(name, collection="providers") - def _iter_registered_identity_handles(self) -> list[str]: - handles = list_registered_identity_handles(self._vault.home) + async def _iter_registered_identity_handles(self) -> list[str]: + handles = await list_registered_identity_handles(self._vault.home) return handles or [self._identity] def _ensure_local_provider_admin_operation_allowed(self, operation: str, provider: str) -> None: @@ -700,7 +700,7 @@ async def logout(self, provider: str, connection: str = "default") -> None: async def revoke(self, provider: str) -> None: self._ensure_local_provider_admin_operation_allowed("revoke", provider) await self.get_provider(provider) - for identity in self._iter_registered_identity_handles(): + for identity in await self._iter_registered_identity_handles(): identity_service = AuthService( vault=self._vault, identity=identity, @@ -894,7 +894,7 @@ async def _get_oauth_token(self, record: ConnectionRecord, provider: str, connec return refreshed.access_token except RefreshFailedError as exc: fallback_available = record.expires_at and now < record.expires_at - audit.log( + await audit.alog( "refresh_failed", provider=provider, connection=connection, diff --git a/src/authsome/auth/sessions.py b/src/authsome/auth/sessions.py index 065ce84..2e1765a 100644 --- a/src/authsome/auth/sessions.py +++ b/src/authsome/auth/sessions.py @@ -50,13 +50,13 @@ def is_expired(self) -> bool: class AuthSessionStore: - """In-memory auth session store for the daemon process.""" + """In-memory auth session state for the daemon process.""" def __init__(self) -> None: self._sessions: dict[str, AuthSession] = {} self._state_index: dict[str, str] = {} - def create( + async def create( self, *, provider: str, @@ -77,37 +77,52 @@ def create( self._sessions[session.session_id] = session return session - def get(self, session_id: str) -> AuthSession: + async def get(self, session_id: str) -> AuthSession: self.cleanup_expired() session = self._sessions.get(session_id) if session is None: raise KeyError(f"Session not found: {session_id}") if session.is_expired: - self.delete(session_id) session.state = AuthSessionStatus.EXPIRED + await self.delete(session_id) raise KeyError(f"Session expired: {session_id}") return session - def delete(self, session_id: str) -> None: + async def save(self, session: AuthSession) -> None: + session.updated_at = utc_now() + self._sessions[session.session_id] = session + oauth_state = session.payload.get("internal_state") + if oauth_state: + self._state_index[str(oauth_state)] = session.session_id + + async def delete(self, session_id: str) -> None: session = self._sessions.pop(session_id, None) if session: oauth_state = session.payload.get("internal_state") if oauth_state: self._state_index.pop(str(oauth_state), None) - def index_oauth_state(self, session: AuthSession) -> None: + async def index_oauth_state(self, session: AuthSession) -> None: oauth_state = session.payload.get("internal_state") if oauth_state: self._state_index[str(oauth_state)] = session.session_id + await self.save(session) - def get_by_oauth_state(self, state: str) -> AuthSession: + async def get_by_oauth_state(self, state: str) -> AuthSession: self.cleanup_expired() session_id = self._state_index.get(state) if session_id is None: raise KeyError(f"Session not found for OAuth state: {state}") - return self.get(session_id) + return await self.get(session_id) def cleanup_expired(self) -> None: expired = [session_id for session_id, session in self._sessions.items() if session.is_expired] for session_id in expired: - self.delete(session_id) + session = self._sessions.get(session_id) + if session is not None: + session.state = AuthSessionStatus.EXPIRED + self._sessions.pop(session_id, None) + if session is not None: + oauth_state = session.payload.get("internal_state") + if oauth_state: + self._state_index.pop(str(oauth_state), None) diff --git a/src/authsome/identity/client_config.py b/src/authsome/cli/client_config.py similarity index 67% rename from src/authsome/identity/client_config.py rename to src/authsome/cli/client_config.py index 1b6c952..9df39f0 100644 --- a/src/authsome/identity/client_config.py +++ b/src/authsome/cli/client_config.py @@ -1,4 +1,4 @@ -"""Client-side identity selection config.""" +"""Caller-local CLI config helpers.""" from __future__ import annotations @@ -6,16 +6,21 @@ from pydantic import BaseModel +from authsome import __version__ +from authsome.paths import get_client_home + class ClientConfig(BaseModel): """Caller-local config that should not live in daemon-owned storage.""" + version: str = __version__ active_identity: str | None = None + proxy_ca_installed: bool = False def client_config_path(home: Path) -> Path: """Return the caller-local config file path.""" - return home / "config.json" + return get_client_home(home) / "config.json" def load_client_config(home: Path) -> ClientConfig: @@ -31,5 +36,6 @@ def load_client_config(home: Path) -> ClientConfig: def save_client_config(home: Path, config: ClientConfig) -> None: """Persist caller-local config.""" - home.mkdir(parents=True, exist_ok=True) - client_config_path(home).write_text(config.model_dump_json(indent=2), encoding="utf-8") + path = client_config_path(home) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(config.model_dump_json(indent=2), encoding="utf-8") diff --git a/src/authsome/cli/context.py b/src/authsome/cli/context.py index bbb9167..9c8aea4 100644 --- a/src/authsome/cli/context.py +++ b/src/authsome/cli/context.py @@ -8,7 +8,6 @@ import click -from authsome import audit from authsome.cli.client import AuthsomeApiClient from authsome.cli.daemon_control import resolve_runtime_client from authsome.proxy.runner import ProxyRunner @@ -25,7 +24,7 @@ async def doctor(self) -> dict[str, Any]: return await self.runtime_client.doctor() def require_local_proxy(self) -> ProxyRunner: - return ProxyRunner(client=self.runtime_client) + return ProxyRunner(client=self.runtime_client, home=self.home) class ContextObj: @@ -40,7 +39,6 @@ def __init__(self, json_output: bool, quiet: bool, no_color: bool): async def initialize(self) -> CliRuntime: if self._ctx is None: self._ctx = CliRuntime(await resolve_runtime_client()) - audit.setup(self._ctx.home / "audit.log") return self._ctx def print_json(self, data: Any) -> None: diff --git a/src/authsome/cli/daemon_control.py b/src/authsome/cli/daemon_control.py index cef0c7c..a6fe6af 100644 --- a/src/authsome/cli/daemon_control.py +++ b/src/authsome/cli/daemon_control.py @@ -9,7 +9,6 @@ import subprocess import sys import time -from pathlib import Path from typing import Any from authsome.cli.client import ( @@ -18,10 +17,11 @@ is_managed_local_daemon_url, resolve_daemon_url, ) +from authsome.paths import get_authsome_home, get_server_home from authsome.server.daemon import DEFAULT_HOST, DEFAULT_PORT -AUTHSOME_HOME = Path(os.environ.get("AUTHSOME_HOME", str(Path.home() / ".authsome"))) -DAEMON_DIR = AUTHSOME_HOME / "server" / "daemon" +AUTHSOME_HOME = get_authsome_home() +DAEMON_DIR = get_server_home(AUTHSOME_HOME) / "daemon" PID_FILE = DAEMON_DIR / "daemon.pid" LOG_FILE = DAEMON_DIR / "daemon.log" STATE_FILE = DAEMON_DIR / "daemon.json" diff --git a/src/authsome/cli/main.py b/src/authsome/cli/main.py index 754955b..6a5c438 100644 --- a/src/authsome/cli/main.py +++ b/src/authsome/cli/main.py @@ -9,8 +9,9 @@ import click import requests +from loguru import logger -from authsome import AuthenticationFailedError, FlowType, __version__, audit +from authsome import AuthenticationFailedError, FlowType, __version__ from authsome.auth.models.enums import AuthType, ExportFormat from authsome.auth.models.provider import ProviderDefinition from authsome.cli.context import ContextObj, common_options @@ -27,6 +28,7 @@ auth_command, setup_logging, ) +from authsome.paths import get_client_log_path from authsome.utils import connection_is_active, format_error_code, format_expires_at, redact @@ -36,7 +38,7 @@ @click.option( "--log-file", "log_file", - default=str(Path(os.environ.get("AUTHSOME_HOME", str(Path.home() / ".authsome"))) / "logs" / "authsome.log"), + default=str(get_client_log_path(Path(os.environ.get("AUTHSOME_HOME", str(Path.home() / ".authsome"))))), show_default=True, metavar="PATH", help="Path for the rotating log file. Pass empty string to disable.", @@ -215,33 +217,29 @@ def render_row(row: dict[str, Any], is_header: bool = False, is_divider: bool = @click.option("-n", "--lines", default=50, metavar="COUNT", help="Number of lines to show.") @auth_command async def log_cmd(ctx_obj: ContextObj, lines: int) -> None: - """View the authsome audit log.""" - actx = await ctx_obj.initialize() - audit_file = actx.home / "audit.log" - if not audit_file.exists(): + """View the local authsome client log.""" + home = Path(os.environ.get("AUTHSOME_HOME", str(Path.home() / ".authsome"))) + log_path = get_client_log_path(home) + try: + entries = log_path.read_text(encoding="utf-8", errors="replace").splitlines()[-lines:] if ctx_obj.json_output: - ctx_obj.print_json([]) + ctx_obj.print_json({"log_file": str(log_path), "entries": entries}) else: - ctx_obj.echo("No audit log found.", err=True, color="yellow") - sys.exit(0) - - try: - with open(audit_file, encoding="utf-8") as f: - log_lines = f.readlines() - - target_lines = [line.strip() for line in log_lines[-lines:] if line.strip()] - + if not entries: + ctx_obj.echo("No log entries found.", err=True, color="yellow") + sys.exit(0) + for entry in entries: + ctx_obj.echo(entry) + except FileNotFoundError: if ctx_obj.json_output: - parsed_lines = [json_lib.loads(line) for line in target_lines] - ctx_obj.print_json({"lines": parsed_lines}) + ctx_obj.print_json({"log_file": str(log_path), "entries": []}) else: - for line in target_lines: - ctx_obj.echo(line) + ctx_obj.echo("No log entries found.", err=True, color="yellow") except Exception as e: if ctx_obj.json_output: ctx_obj.print_json({"error": str(e)}) else: - ctx_obj.echo(f"Error reading audit log: {e}", err=True, color="red") + ctx_obj.echo(f"Error reading log file: {e}", err=True, color="red") sys.exit(1) @@ -315,11 +313,15 @@ async def login( ) ctx_obj.echo(f"Session ID: {session_id}") - audit.log( - "login", provider=provider, connection=connection, flow=flow or "unknown", status=login_result["status"] + logger.info( + "client_event event=login provider={} connection={} flow={} status={}", + provider, + connection, + flow or "unknown", + login_result["status"], ) except Exception: - audit.log("login", provider=provider, connection=connection, status="failure") + logger.warning("client_event event=login provider={} connection={} status=failure", provider, connection) raise if ctx_obj.json_output: @@ -457,13 +459,12 @@ async def scan(ctx_obj: ContextObj, connection: str, auto_import: bool) -> None: imported += 1 results.append({"provider": provider_name, "status": "imported", "env_var": item["env_var"]}) - audit.log( - "scan", - provider=provider_name, - connection=connection, - source=item["source"], - source_env=item["env_var"], - status="success", + logger.info( + "client_event event=scan provider={} connection={} source={} source_env={} status=success", + provider_name, + connection, + item["source"], + item["env_var"], ) if ctx_obj.json_output: @@ -497,7 +498,7 @@ async def logout(ctx_obj: ContextObj, provider: str, connection: str) -> None: """Log out of the specified PROVIDER connection.""" actx = await ctx_obj.initialize() await actx.runtime_client.logout(provider, connection) - audit.log("logout", provider=provider, connection=connection) + logger.info("client_event event=logout provider={} connection={}", provider, connection) if ctx_obj.json_output: ctx_obj.print_json({"status": "logged_out", "provider": provider, "connection": connection}) @@ -526,7 +527,7 @@ async def revoke(ctx_obj: ContextObj, provider: str) -> None: """Reset and delete all stored connections and secrets for PROVIDER.""" actx = await ctx_obj.initialize() await actx.runtime_client.revoke(provider) - audit.log("revoke", provider=provider, connection="all") + logger.info("client_event event=revoke provider={} connection=all", provider) if ctx_obj.json_output: ctx_obj.print_json({"status": "revoked", "provider": provider}) @@ -541,7 +542,7 @@ async def remove(ctx_obj: ContextObj, provider: str) -> None: """Permanently uninstall the specified custom PROVIDER definition.""" actx = await ctx_obj.initialize() await actx.runtime_client.remove(provider) - audit.log("remove", provider=provider, connection="all") + logger.info("client_event event=remove provider={} connection=all", provider) if ctx_obj.json_output: ctx_obj.print_json({"status": "removed", "provider": provider}) @@ -570,7 +571,12 @@ async def get(ctx_obj: ContextObj, provider: str, connection: str, field: str | if not require_os_auth("reveal secrets"): raise AuthenticationFailedError("Authentication failed or cancelled.") - audit.log("get", provider=provider, connection=connection, field=field or "all") + logger.info( + "client_event event=get provider={} connection={} field={}", + provider, + connection, + field or "all", + ) data = redact(record) if not show_secret else record.model_dump(mode="json") # Decouple from internal schema fields @@ -648,7 +654,12 @@ async def export(ctx_obj: ContextObj, provider: str | None, connection: str, exp actx = await ctx_obj.initialize() fmt = ExportFormat(export_format) output = await actx.runtime_client.export(provider, connection, format=fmt.value) - audit.log("export", provider=provider, connection=connection, format=fmt.value) + logger.info( + "client_event event=export provider={} connection={} format={}", + provider, + connection, + fmt.value, + ) if ctx_obj.json_output: # Call with format=json and parse the result to properly wrap with version info output_str = await actx.runtime_client.export(provider, connection, format="json") @@ -722,7 +733,7 @@ async def register(ctx_obj: ContextObj, path: str, force: bool, yes: bool) -> No await actx.runtime_client.register_provider(definition.model_dump(mode="json"), force=force) endpoints = [ep for _, ep, _ in endpoints_to_check] - audit.log("register", provider=definition.name, endpoints=endpoints) + logger.info("client_event event=register provider={} endpoints={}", definition.name, endpoints) if ctx_obj.json_output: ctx_obj.print_json({"status": "registered", "provider": definition.name}) @@ -821,7 +832,7 @@ async def profile_create(ctx_obj: ContextObj, handle: str | None) -> None: @auth_command async def profile_use(ctx_obj: ContextObj, handle: str) -> None: """Select the active local profile.""" - from authsome.identity import load_client_config, save_client_config + from authsome.cli.client_config import load_client_config, save_client_config from authsome.identity.keys import load_identity home = Path(os.environ.get("AUTHSOME_HOME", str(Path.home() / ".authsome"))) diff --git a/src/authsome/identity/__init__.py b/src/authsome/identity/__init__.py index 6deddea..633648d 100644 --- a/src/authsome/identity/__init__.py +++ b/src/authsome/identity/__init__.py @@ -4,7 +4,6 @@ from pathlib import Path -from authsome.identity.client_config import ClientConfig, client_config_path, load_client_config, save_client_config from authsome.identity.keys import ( IdentityMetadata, create_identity, @@ -29,8 +28,6 @@ async def current_from_home(home: Path) -> IdentityMetadata: __all__ = [ "IdentityMetadata", - "ClientConfig", - "client_config_path", "create_identity", "current_from_home", "ensure_local_identity", @@ -38,12 +35,10 @@ async def current_from_home(home: Path) -> IdentityMetadata: "identity_exists", "identity_key_path", "identity_metadata_path", - "load_client_config", "load_identity", "load_private_key", "mark_registered", "public_key_from_did_key", "public_key_to_did_key", "remove_legacy_default_identity", - "save_client_config", ] diff --git a/src/authsome/identity/keys.py b/src/authsome/identity/keys.py index 7f60342..4601ed2 100644 --- a/src/authsome/identity/keys.py +++ b/src/authsome/identity/keys.py @@ -13,6 +13,8 @@ from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey from pydantic import BaseModel, Field +from authsome.paths import get_client_home + _ED25519_MULTICODEC_PREFIX = b"\xed\x01" _DID_KEY_PREFIX = "did:key:z" _HANDLE_RE = re.compile(r"^[a-z0-9][a-z0-9-]*[a-z0-9]$") @@ -60,7 +62,7 @@ class IdentityMetadata(BaseModel): def identities_dir(home: Path) -> Path: - return home / "identities" + return get_client_home(home) / "identities" def identity_metadata_path(home: Path, handle: str) -> Path: @@ -145,7 +147,7 @@ def identity_exists(home: Path, handle: str) -> bool: def create_identity(home: Path, handle: str | None = None) -> IdentityMetadata: """Create a local identity and private key, returning existing metadata if present.""" - from authsome.identity.client_config import load_client_config, save_client_config + from authsome.cli.client_config import load_client_config, save_client_config resolved_handle = validate_handle(handle or _unique_handle(home)) if identity_exists(home, resolved_handle): @@ -210,7 +212,7 @@ def ensure_local_identity(home: Path, active_handle: str | None = None) -> Ident re-creation, because the old profile's credentials would become inaccessible with no explanation. """ - from authsome.identity.client_config import load_client_config + from authsome.cli.client_config import load_client_config remove_legacy_default_identity(home) if active_handle is None: diff --git a/src/authsome/identity/registry.py b/src/authsome/identity/registry.py index c7ff143..e262ba4 100644 --- a/src/authsome/identity/registry.py +++ b/src/authsome/identity/registry.py @@ -3,7 +3,6 @@ from __future__ import annotations import json -import os from datetime import UTC, datetime from pathlib import Path @@ -26,57 +25,63 @@ class IdentityRegistrationError(ValueError): class IdentityRegistry: - """JSON-backed authoritative registry for daemon identity handles.""" + """Filesystem-backed authoritative registry for daemon identity handles.""" - def __init__(self, server_home: Path) -> None: - self._path = server_home / "identity_registry.json" + def __init__(self, path: Path) -> None: + self._path = path - def register(self, *, handle: str, did: str) -> IdentityRegistration: + def _load_all(self) -> list[IdentityRegistration]: + try: + raw = json.loads(self._path.read_text(encoding="utf-8")) + except FileNotFoundError: + return [] + if not isinstance(raw, list): + return [] + registrations: list[IdentityRegistration] = [] + for item in raw: + try: + registrations.append(IdentityRegistration.model_validate(item)) + except Exception: + continue + return registrations + + def _save_all(self, registrations: list[IdentityRegistration]) -> None: + self._path.parent.mkdir(parents=True, exist_ok=True) + self._path.write_text( + json.dumps([registration.model_dump(mode="json") for registration in registrations], indent=2), + encoding="utf-8", + ) + + async def register(self, *, handle: str, did: str) -> IdentityRegistration: """Register a handle/DID binding, idempotent only for the same pair.""" handle = validate_handle(handle) public_key_from_did_key(did) - entries = self._load() - existing = entries.get(handle) + registrations = self._load_all() + + existing = next((registration for registration in registrations if registration.handle == handle), None) if existing is not None: - registration = IdentityRegistration.model_validate(existing) - if registration.did == did: - return registration + if existing.did == did: + return existing raise IdentityRegistrationError(f"Identity handle '{handle}' is already registered") - for registered_handle, raw in entries.items(): - registration = IdentityRegistration.model_validate(raw) + for registration in registrations: if registration.did == did: - raise IdentityRegistrationError(f"DID is already registered to identity handle '{registered_handle}'") + raise IdentityRegistrationError(f"DID is already registered to identity handle '{registration.handle}'") now = datetime.now(UTC) registration = IdentityRegistration(handle=handle, did=did, created_at=now, updated_at=now) - entries[handle] = registration.model_dump(mode="json") - self._save(entries) + registrations.append(registration) + self._save_all(registrations) return registration - def resolve(self, handle: str) -> IdentityRegistration | None: - entries = self._load() - raw = entries.get(handle) - if raw is None: - return None - return IdentityRegistration.model_validate(raw) + async def resolve(self, handle: str) -> IdentityRegistration | None: + for registration in self._load_all(): + if registration.handle == handle: + return registration + return None - def list_handles(self) -> list[str]: + async def list_handles(self) -> list[str]: """Return all registered identity handles.""" - return sorted(self._load().keys()) - - def _load(self) -> dict[str, dict[str, object]]: - try: - data = json.loads(self._path.read_text(encoding="utf-8")) - except FileNotFoundError: - return {} - if not isinstance(data, dict): - return {} - return {str(key): value for key, value in data.items() if isinstance(value, dict)} - - def _save(self, entries: dict[str, dict[str, object]]) -> None: - self._path.parent.mkdir(parents=True, exist_ok=True) - tmp_path = self._path.with_suffix(".json.tmp") - tmp_path.write_text(json.dumps(entries, indent=2, sort_keys=True) + "\n", encoding="utf-8") - os.replace(tmp_path, self._path) + registrations = self._load_all() + return sorted(registration.handle for registration in registrations) diff --git a/src/authsome/paths.py b/src/authsome/paths.py new file mode 100644 index 0000000..ee056e9 --- /dev/null +++ b/src/authsome/paths.py @@ -0,0 +1,33 @@ +"""Filesystem layout helpers for Authsome.""" + +from __future__ import annotations + +import os +from pathlib import Path + + +def get_authsome_home(home: Path | None = None) -> Path: + """Return the root Authsome home directory.""" + if home is not None: + return home + return Path(os.environ.get("AUTHSOME_HOME", str(Path.home() / ".authsome"))) + + +def get_client_home(home: Path | None = None) -> Path: + """Return the client-owned Authsome directory.""" + return get_authsome_home(home) / "client" + + +def get_server_home(home: Path | None = None) -> Path: + """Return the server-owned Authsome directory.""" + return get_authsome_home(home) / "server" + + +def get_client_log_path(home: Path | None = None) -> Path: + """Return the default client log file path.""" + return get_client_home(home) / "logs" / "authsome.log" + + +def get_server_log_path(home: Path | None = None) -> Path: + """Return the default server log file path.""" + return get_server_home(home) / "logs" / "authsome.log" diff --git a/src/authsome/proxy/certs.py b/src/authsome/proxy/certs.py new file mode 100644 index 0000000..f05f7a2 --- /dev/null +++ b/src/authsome/proxy/certs.py @@ -0,0 +1,87 @@ +"""Local certificate trust helpers for the auth proxy.""" + +from __future__ import annotations + +import subprocess +import sys +from pathlib import Path + +from loguru import logger + +from authsome.cli.client_config import load_client_config, save_client_config + + +def ensure_local_proxy_ca(home: Path) -> None: + """Ensure local proxy trust setup has run at most once per client home.""" + config = load_client_config(home) + if config.proxy_ca_installed: + return + + if _ensure_macos_keychain_ca(): + config.proxy_ca_installed = True + save_client_config(home, config) + + +def _ensure_macos_keychain_ca() -> bool: + """Ensure the mitmproxy CA is generated and trusted in the macOS login keychain. + + Go's crypto/x509 on macOS uses the native Security framework and ignores + ``SSL_CERT_FILE``, so Go-based tools only trust the proxy CA once it is + added to the login keychain on the local machine. + """ + if sys.platform != "darwin": + return True + + keychain = Path.home() / "Library/Keychains/login.keychain-db" + if not keychain.exists(): + return True + + check = subprocess.run( + ["security", "find-certificate", "-c", "mitmproxy", str(keychain)], + capture_output=True, + text=True, + ) + if check.returncode == 0: + logger.debug("mitmproxy CA already present in macOS login keychain; skipping add") + return True + + confdir = Path.home() / ".mitmproxy" + ca_cert_path = confdir / "mitmproxy-ca-cert.pem" + if not ca_cert_path.exists(): + try: + from mitmproxy.certs import CertStore + + CertStore.from_store(confdir, "mitmproxy", 2048) + logger.debug("Generated mitmproxy CA certificate at {}", ca_cert_path) + except Exception as exc: + logger.debug("Failed to generate mitmproxy CA certificate: {}", exc) + return False + + if not ca_cert_path.exists(): + return False + + result = subprocess.run( + [ + "security", + "add-trusted-cert", + "-d", + "-r", + "trustRoot", + "-k", + str(keychain), + str(ca_cert_path), + ], + capture_output=True, + text=True, + timeout=60, + ) + if result.returncode == 0: + logger.debug("Added mitmproxy CA to macOS login keychain") + return True + + logger.warning( + "Could not add mitmproxy CA to macOS login keychain" + " (Go-based tools like gh/terraform/kubectl may fail with TLS errors): {}", + result.stderr.strip() or result.stdout.strip(), + ) + return False diff --git a/src/authsome/proxy/runner.py b/src/authsome/proxy/runner.py index dbb7088..c6821bb 100644 --- a/src/authsome/proxy/runner.py +++ b/src/authsome/proxy/runner.py @@ -10,6 +10,7 @@ from loguru import logger +from authsome.proxy.certs import ensure_local_proxy_ca from authsome.proxy.server import RunningProxy, start_proxy_server @@ -28,8 +29,9 @@ async def list_providers_by_source(self) -> Any: ... class ProxyRunner: """Launch a subprocess behind the Authsome local auth proxy.""" - def __init__(self, client: ProxyClient) -> None: + def __init__(self, client: ProxyClient, home: Path | None = None) -> None: self._client = client + self._home = home or Path(os.environ.get("AUTHSOME_HOME", str(Path.home() / ".authsome"))) async def run(self, command: list[str]) -> subprocess.CompletedProcess[str]: """Run *command* behind the auth-injecting proxy.""" @@ -72,6 +74,7 @@ async def run(self, command: list[str]) -> subprocess.CompletedProcess[str]: pass def _start_proxy(self) -> tuple[str, RunningProxy]: + ensure_local_proxy_ca(self._home) server = start_proxy_server(self._client) return server.url, server diff --git a/src/authsome/proxy/server.py b/src/authsome/proxy/server.py index ab0d01d..eb25b11 100644 --- a/src/authsome/proxy/server.py +++ b/src/authsome/proxy/server.py @@ -17,7 +17,6 @@ from mitmproxy.options import Options from mitmproxy.tools.dump import DumpMaster -from authsome import audit from authsome.proxy.router import RouteMatch, RouteResolution from authsome.utils import utc_now @@ -128,6 +127,8 @@ async def _build_routes( if hasattr(client, "proxy_routes"): try: route_data = await client.proxy_routes() + if not isinstance(route_data, dict): + raise TypeError("proxy_routes() must return a dict payload") for route in route_data.get("routes", []): route_match = RouteMatch(provider=route["provider"], connection=route.get("connection")) regex_pattern = _compile_host_regex(route["host_url"]) @@ -347,7 +348,13 @@ async def request(self, flow: http.HTTPFlow) -> None: if resolution.match is None: if resolution.miss_reason is not None: normalized_host = _normalize_host(flow.request.host) - audit.log("proxy_miss", host=normalized_host, reason=resolution.miss_reason) + logger.info( + "client_event event=proxy_miss host={} reason={} method={} path={}", + normalized_host, + resolution.miss_reason, + flow.request.method, + flow.request.path, + ) logger.error( "Proxy miss: host={} reason={} {} {}", normalized_host, @@ -371,13 +378,13 @@ async def request(self, flow: http.HTTPFlow) -> None: for key, value in headers.items(): flow.request.headers[key] = value - audit.log( - "proxy_inject", - provider=match.provider, - connection=match.connection, - host=_normalize_host(flow.request.host), - method=flow.request.method, - path=flow.request.path, + logger.info( + "client_event event=proxy_inject provider={} connection={} host={} method={} path={}", + match.provider, + match.connection, + _normalize_host(flow.request.host), + flow.request.method, + flow.request.path, ) async def _get_auth_headers(self, match: RouteMatch) -> dict[str, str]: diff --git a/src/authsome/server/app.py b/src/authsome/server/app.py index df40b12..3e03669 100644 --- a/src/authsome/server/app.py +++ b/src/authsome/server/app.py @@ -9,16 +9,20 @@ from fastapi.responses import JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles +from authsome import audit from authsome.auth import AuthService from authsome.auth.sessions import AuthSessionStore from authsome.errors import AuthsomeError from authsome.identity.proof import ReplayCache from authsome.identity.registry import IdentityRegistrationError, IdentityRegistry +from authsome.paths import get_server_log_path from authsome.server.dependencies import ( + create_app_store, create_vault, get_deployment_mode, + get_identity_registry_path, get_server_base_url, - get_server_home, + load_server_config, load_ui_session_signing_secret, ) from authsome.server.routes.auth import router as auth_router @@ -34,7 +38,10 @@ @asynccontextmanager async def lifespan(app: FastAPI): """Manage daemon lifecycle.""" - app.state.vault = await create_vault() + app.state.store = await create_app_store() + app.state.server_config = load_server_config(app.state.store.home) + audit.setup(get_server_log_path(app.state.store.home)) + app.state.vault = await create_vault(app.state.store) app.state.auth_service = AuthService( vault=app.state.vault, identity="server", @@ -43,9 +50,11 @@ async def lifespan(app: FastAPI): app.state.auth_sessions = AuthSessionStore() app.state.ui_sessions = UiSessionStore(load_ui_session_signing_secret(app.state.vault.home)) app.state.proof_replay_cache = ReplayCache() - app.state.identity_registry = IdentityRegistry(get_server_home(app.state.vault.home)) + app.state.identity_registry = IdentityRegistry(get_identity_registry_path(app.state.store.home)) app.state.server_base_url = get_server_base_url() yield + audit.clear() + await app.state.store.close() def create_app() -> FastAPI: diff --git a/src/authsome/server/dependencies.py b/src/authsome/server/dependencies.py index d122085..e015869 100644 --- a/src/authsome/server/dependencies.py +++ b/src/authsome/server/dependencies.py @@ -9,9 +9,14 @@ if TYPE_CHECKING: from authsome.auth import AuthService + from authsome.store.interfaces import AppStore +from authsome.auth.models.config import ServerConfig from authsome.identity import current_from_home from authsome.identity.registry import IdentityRegistry +from authsome.paths import get_authsome_home as _get_authsome_home +from authsome.paths import get_server_home as _get_server_home +from authsome.paths import get_server_log_path as _get_server_log_path from authsome.server.urls import build_server_base_url from authsome.store.local import LocalAppStore from authsome.vault import Vault @@ -19,12 +24,27 @@ def get_authsome_home() -> Path: """Return the local Authsome home directory.""" - return Path(os.environ.get("AUTHSOME_HOME", str(Path.home() / ".authsome"))) + return _get_authsome_home() def get_server_home(home: Path | None = None) -> Path: """Return the daemon-owned state directory.""" - return (home or get_authsome_home()) / "server" + return _get_server_home(home) + + +def get_server_config_path(home: Path | None = None) -> Path: + """Return the daemon-owned config file path.""" + return get_server_home(home) / "config.json" + + +def get_server_log_path(home: Path | None = None) -> Path: + """Return the daemon-owned structured log path.""" + return _get_server_log_path(home) + + +def get_identity_registry_path(home: Path | None = None) -> Path: + """Return the daemon-owned identity registry file path.""" + return get_server_home(home) / "identity_registry.json" def get_ui_session_secret_path(home: Path | None = None) -> Path: @@ -43,12 +63,6 @@ def get_deployment_mode() -> str: return "hosted" if mode == "hosted" else "local" -def list_registered_identity_handles(home: Path | None = None) -> list[str]: - """Return identity handles registered with this daemon.""" - registry = IdentityRegistry(get_server_home(home)) - return registry.list_handles() - - def load_ui_session_signing_secret(home: Path | None = None) -> str: """Load or create the hosted UI session signing secret.""" path = get_ui_session_secret_path(home) @@ -68,20 +82,46 @@ async def get_local_ui_identity(home: Path | None = None) -> str: return identity.handle -async def create_vault(home: Path | None = None) -> Vault: - """Create the daemon vault without requiring caller identity files.""" - from authsome import audit +def load_server_config(home: Path | None = None) -> ServerConfig: + """Load daemon-owned server config, defaulting when absent or invalid.""" + path = get_server_config_path(home) + try: + return ServerConfig.model_validate_json(path.read_text(encoding="utf-8")) + except Exception: + config = ServerConfig() + save_server_config(config, home) + return config + + +def save_server_config(config: ServerConfig, home: Path | None = None) -> None: + """Persist daemon-owned server config.""" + path = get_server_config_path(home) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(config.model_dump_json(indent=2), encoding="utf-8") + +async def create_app_store(home: Path | None = None) -> AppStore: + """Create the daemon application store.""" resolved_home = home or get_authsome_home() - audit.setup(resolved_home / "audit.log") + load_server_config(resolved_home) app_store = LocalAppStore(resolved_home) await app_store.ensure_initialized() + return app_store + + +async def list_registered_identity_handles(home: Path | None = None) -> list[str]: + """Return identity handles registered with this daemon.""" + registry = IdentityRegistry(get_identity_registry_path(home)) + return await registry.list_handles() + - config = await app_store.get_config() - crypto_mode = config.encryption.mode if config.encryption else "local_key" +async def create_vault(app_store: AppStore) -> Vault: + """Create the daemon vault from an initialized application store.""" + resolved_home = app_store.home + config = load_server_config(resolved_home) return Vault( app_store=app_store, - crypto_mode=crypto_mode, + crypto_mode=config.encryption.mode, master_key_path=get_server_home(resolved_home) / "master.key", ) @@ -92,5 +132,6 @@ async def create_auth_service(home: Path | None = None, identity: str | None = N if not identity: raise ValueError("create_auth_service requires an explicit identity handle") - vault = await create_vault(home) + store = await create_app_store(home) + vault = await create_vault(store) return AuthService(vault=vault, identity=identity, deployment_mode=get_deployment_mode()) diff --git a/src/authsome/server/routes/_deps.py b/src/authsome/server/routes/_deps.py index ca81a8c..8a6db35 100644 --- a/src/authsome/server/routes/_deps.py +++ b/src/authsome/server/routes/_deps.py @@ -51,7 +51,7 @@ async def get_protected_auth_service(request: Request) -> AuthService: except (ProofValidationError, ValueError) as exc: raise HTTPException(status_code=401, detail=str(exc)) from exc - registration = request.app.state.identity_registry.resolve(claims.subject) + registration = await request.app.state.identity_registry.resolve(claims.subject) if registration is None: raise HTTPException(status_code=401, detail="Unknown identity handle") if registration.did != claims.issuer: diff --git a/src/authsome/server/routes/auth.py b/src/authsome/server/routes/auth.py index 05ddfbf..dfab819 100644 --- a/src/authsome/server/routes/auth.py +++ b/src/authsome/server/routes/auth.py @@ -44,6 +44,14 @@ async def _ensure_browser_session_identity(request: Request, session: AuthSessio return identity == session.identity +async def _load_session_or_404(sessions: AuthSessionStore, session_id: str) -> AuthSession: + """Return an auth session or raise the route-level not-found response.""" + try: + return await sessions.get(session_id) + except KeyError as exc: + raise HTTPException(status_code=404, detail="Authentication session not found") from exc + + @router.post("/sessions", response_model=AuthSessionResponse) async def start_session( body: StartAuthSessionRequest, @@ -54,7 +62,7 @@ async def start_session( ) -> AuthSessionResponse: definition = await auth.get_provider(body.provider) flow = FlowType(body.flow) if body.flow else definition.flow - session = sessions.create( + session = await sessions.create( provider=body.provider, identity=auth.identity, connection_name=body.connection, @@ -77,6 +85,7 @@ async def start_session( ): session.state = AuthSessionStatus.COMPLETED session.status_message = "Already connected" + await sessions.save(session) return _session_response(session, server_base_url) except Exception: pass @@ -85,6 +94,7 @@ async def start_session( if fields: session.state = AuthSessionStatus.WAITING_FOR_USER session.payload["input_fields"] = [_field_to_payload(field) for field in fields] + await sessions.save(session) return _session_response(session, server_base_url) await auth.begin_login_flow( @@ -96,7 +106,7 @@ async def start_session( if FlowType(session.flow_type) == FlowType.DEVICE_CODE: _update_device_code_expiry(sessions, session) background_tasks.add_task(auth.background_resume, session) - sessions.index_oauth_state(session) + await sessions.index_oauth_state(session) return _session_response(session, server_base_url) @@ -107,7 +117,7 @@ async def get_session( sessions: AuthSessionStore = Depends(get_auth_sessions), server_base_url: str = Depends(get_server_base_url), ) -> AuthSessionResponse: - session = sessions.get(session_id) + session = await _load_session_or_404(sessions, session_id) if session.identity != auth.identity: raise HTTPException(status_code=404, detail="Authentication session not found") return _session_response(session, server_base_url) @@ -121,7 +131,7 @@ async def resume_session( sessions: AuthSessionStore = Depends(get_auth_sessions), server_base_url: str = Depends(get_server_base_url), ) -> AuthSessionResponse: - session = sessions.get(session_id) + session = await _load_session_or_404(sessions, session_id) if session.identity != auth.identity: raise HTTPException(status_code=404, detail="Authentication session not found") try: @@ -131,9 +141,11 @@ async def resume_session( else: session.state = AuthSessionStatus.COMPLETED session.status_message = "Login successful" + await sessions.save(session) except Exception as exc: session.state = AuthSessionStatus.FAILED session.error_message = str(exc) + await sessions.save(session) raise return _session_response(session, server_base_url) @@ -147,7 +159,7 @@ async def oauth_callback( if not state: return HTMLResponse(pages.message_page("Authentication failed", "Missing OAuth state."), status_code=400) try: - session = sessions.get_by_oauth_state(state) + session = await sessions.get_by_oauth_state(state) except KeyError: return HTMLResponse( pages.message_page("Authentication session expired", "Please run authsome login again."), @@ -164,9 +176,11 @@ async def oauth_callback( await auth.resume_login_flow(session, callback_data) session.state = AuthSessionStatus.COMPLETED session.status_message = "Login successful" + await sessions.save(session) except Exception as exc: session.state = AuthSessionStatus.FAILED session.error_message = str(exc) + await sessions.save(session) return HTMLResponse(pages.message_page("Authentication failed", str(exc)), status_code=400) if return_url := session.payload.get("return_url"): return RedirectResponse(str(return_url), status_code=303) @@ -181,7 +195,7 @@ async def input_page( server_base_url: str = Depends(get_server_base_url), ) -> HTMLResponse: try: - session = sessions.get(session_id) + session = await sessions.get(session_id) except KeyError: return HTMLResponse( pages.message_page("Authentication session expired", "Please run authsome login again."), @@ -218,7 +232,7 @@ async def device_page( sessions: AuthSessionStore = Depends(get_auth_sessions), ) -> HTMLResponse: try: - session = sessions.get(session_id) + session = await sessions.get(session_id) except KeyError: return HTMLResponse( pages.message_page("Authentication session expired", "Please run authsome login again."), @@ -251,7 +265,13 @@ async def submit_input( sessions: AuthSessionStore = Depends(get_auth_sessions), server_base_url: str = Depends(get_server_base_url), ): - session = sessions.get(session_id) + try: + session = await sessions.get(session_id) + except KeyError: + return HTMLResponse( + pages.message_page("Authentication session expired", "Please run authsome login again."), + status_code=404, + ) if not await _ensure_browser_session_identity(request, session): return HTMLResponse( pages.message_page("Dashboard session expired", "Run 'authsome ui' to reopen the hosted dashboard."), @@ -268,6 +288,7 @@ async def submit_input( await auth.resume_login_flow(session, {}) session.state = AuthSessionStatus.COMPLETED session.status_message = "Login successful" + await sessions.save(session) if return_url := session.payload.get("return_url"): return RedirectResponse(str(return_url), status_code=303) return HTMLResponse(pages.message_page("Authentication successful", "You can close this window.")) @@ -283,13 +304,16 @@ async def submit_input( _update_device_code_expiry(sessions, session) background_tasks.add_task(auth.background_resume, session) if session.payload.get("user_code") and session.payload.get("verification_uri"): + await sessions.save(session) return RedirectResponse(url=build_device_url(server_base_url, session.session_id), status_code=303) - sessions.index_oauth_state(session) + await sessions.index_oauth_state(session) auth_url = session.payload.get("auth_url") if auth_url: + await sessions.save(session) return RedirectResponse(str(auth_url), status_code=303) + await sessions.save(session) return HTMLResponse(pages.message_page("Authentication started", "Return to your terminal to continue.")) diff --git a/src/authsome/server/routes/health.py b/src/authsome/server/routes/health.py index f7e01de..3232e84 100644 --- a/src/authsome/server/routes/health.py +++ b/src/authsome/server/routes/health.py @@ -29,25 +29,9 @@ async def ready(auth: AuthService = Depends(get_auth_service)) -> ReadyResponse: issues: list[str] = [] warnings: list[str] = [] - # 1. Config & Schema Version Check - try: - config = await auth.vault.get_config() - checks["config"] = "ok" - - expected_spec_version = current_spec_version() - if getattr(config, "spec_version", None) != expected_spec_version: - issues.append( - f"config: spec_version mismatch (got {config.spec_version}, expected {expected_spec_version})" - ) - checks["version_compatibility"] = "failed" - else: - checks["version_compatibility"] = "ok" - except Exception as exc: - checks["config"] = "failed" - checks["version_compatibility"] = "failed" - issues.append(f"config: {exc}") + checks["spec_version"] = str(current_spec_version()) - # 2. Active Identity Check + # 1. Active Identity Check try: await auth.get_identity(auth.identity) checks["identity"] = "ok" @@ -55,7 +39,7 @@ async def ready(auth: AuthService = Depends(get_auth_service)) -> ReadyResponse: checks["identity"] = "failed" issues.append(f"identity: {exc}") - # 3. Providers List Check + # 2. Providers List Check try: await auth.list_providers() checks["providers"] = "ok" @@ -63,7 +47,7 @@ async def ready(auth: AuthService = Depends(get_auth_service)) -> ReadyResponse: checks["providers"] = "failed" issues.append(f"providers: {exc}") - # 4. Connected Providers Check + # 3. Connected Providers Check try: conn_list = await auth.list_connections() checks["connections"] = "ok" @@ -74,7 +58,7 @@ async def ready(auth: AuthService = Depends(get_auth_service)) -> ReadyResponse: checks["connections"] = "failed" issues.append(f"connections: {exc}") - # 5. Vault Roundtrip & Store Integrity Check + # 4. Vault Roundtrip & Store Integrity Check try: await auth.vault.put("__ready_test__", "ok", collection=f"vault:{auth.identity}") value = await auth.vault.get("__ready_test__", collection=f"vault:{auth.identity}") @@ -109,8 +93,7 @@ async def whoami( auth: AuthService = Depends(get_protected_auth_service), server_base_url: str = Depends(get_server_base_url), ) -> dict[str, str]: - config = await auth.vault.get_config() - enc_mode = config.encryption.mode if config.encryption else "local_key" + enc_mode = request.app.state.server_config.encryption.mode if enc_mode == "local_key": enc_desc = f"Local Key ({auth.vault.home / 'server' / 'master.key'})" elif enc_mode == "keyring": diff --git a/src/authsome/server/routes/identities.py b/src/authsome/server/routes/identities.py index c9fc9b2..4858b95 100644 --- a/src/authsome/server/routes/identities.py +++ b/src/authsome/server/routes/identities.py @@ -18,7 +18,7 @@ class RegisterIdentityRequest(BaseModel): @router.post("/register") async def register_identity(body: RegisterIdentityRequest, request: Request) -> dict[str, str]: try: - registration = request.app.state.identity_registry.register(handle=body.handle, did=body.did) + registration = await request.app.state.identity_registry.register(handle=body.handle, did=body.did) except IdentityRegistrationError: raise except ValueError as exc: diff --git a/src/authsome/server/routes/ui.py b/src/authsome/server/routes/ui.py index 447e74d..859904d 100644 --- a/src/authsome/server/routes/ui.py +++ b/src/authsome/server/routes/ui.py @@ -387,7 +387,7 @@ async def connect_app( definition = await auth.get_provider(provider_name) flow = definition.flow - session = sessions.create( + session = await sessions.create( provider=provider_name, identity=auth.identity, connection_name=connection_name, @@ -404,6 +404,7 @@ async def connect_app( existing = await auth.get_connection(provider_name, connection_name) if auth._connection_is_valid(existing): session.status_message = "Already connected" + await sessions.save(session) return _redirect(request, f"/ui/apps/{provider_name}") except Exception: pass @@ -411,6 +412,7 @@ async def connect_app( fields = await auth.get_required_inputs(session) if fields: session.payload["input_fields"] = [field.model_dump(mode="json", exclude_none=True) for field in fields] + await sessions.save(session) return _redirect(request, build_auth_input_url(server_base_url, session.session_id)) await auth.begin_login_flow(session=session, force=force) @@ -418,12 +420,15 @@ async def connect_app( _update_device_code_expiry(sessions, session) background_tasks.add_task(auth.background_resume, session) if session.payload.get("user_code") and session.payload.get("verification_uri"): + await sessions.save(session) return _redirect(request, build_device_url(server_base_url, session.session_id)) - sessions.index_oauth_state(session) + await sessions.index_oauth_state(session) auth_url = session.payload.get("auth_url") if auth_url: + await sessions.save(session) return _redirect(request, str(auth_url)) + await sessions.save(session) return _redirect(request, f"/ui/apps/{provider_name}") diff --git a/src/authsome/store/interfaces.py b/src/authsome/store/interfaces.py index e9a1d84..0f2725b 100644 --- a/src/authsome/store/interfaces.py +++ b/src/authsome/store/interfaces.py @@ -1,24 +1,15 @@ -"""Unified storage interfaces for Authsome. - -The AppStore handles bootstrapping (config and initialization) and -exposes the underlying async KV backend for the Vault to wrap with -encryption. -""" +"""Unified storage interfaces for Authsome.""" from __future__ import annotations from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING from key_value.aio.protocols.key_value import AsyncKeyValue -if TYPE_CHECKING: - from authsome.auth.models.config import GlobalConfig - class AppStore(ABC): - """Storage backend — config + raw async KV access.""" + """Storage backend for the encrypted vault KV.""" @property @abstractmethod @@ -49,18 +40,6 @@ async def check_integrity(self) -> bool: """Perform a health check on the storage medium.""" ... - # ── Config (unencrypted — needed before crypto is available) ────────── - - @abstractmethod - async def get_config(self) -> GlobalConfig: - """Get global configuration.""" - ... - - @abstractmethod - async def save_config(self, config: GlobalConfig) -> None: - """Save global configuration.""" - ... - @abstractmethod async def close(self) -> None: """Close all underlying storage connections.""" diff --git a/src/authsome/store/local.py b/src/authsome/store/local.py index b28658a..3d4be9b 100644 --- a/src/authsome/store/local.py +++ b/src/authsome/store/local.py @@ -2,31 +2,24 @@ from __future__ import annotations -import json -import subprocess -import sys from pathlib import Path from key_value.aio.protocols.key_value import AsyncKeyValue from key_value.aio.stores.disk import DiskStore -from loguru import logger -from authsome.auth.models.config import GlobalConfig +from authsome.paths import get_server_home from authsome.store.interfaces import AppStore +_CONFIG_COLLECTION = "config" -class LocalAppStore(AppStore): - """Disk-backed AppStore using py-key-value-aio's DiskStore. - All data lives inside a single ``kv_store/`` directory managed by - diskcache. Swapping to a remote backend (e.g. PostgresStore) - requires only replacing the DiskStore constructor call. - """ +class LocalAppStore(AppStore): + """Disk-backed AppStore for the daemon vault KV.""" def __init__(self, home_dir: Path) -> None: self._home = home_dir self._home.mkdir(parents=True, exist_ok=True) - self._server_home = self._home / "server" + self._server_home = get_server_home(home_dir) self._server_home.mkdir(parents=True, exist_ok=True) self._store = DiskStore(directory=str(self._server_home / "kv_store")) @@ -45,13 +38,9 @@ def kv(self) -> AsyncKeyValue: # ── Initialization ──────────────────────────────────────────────────── async def ensure_initialized(self) -> None: - _ensure_macos_keychain_ca() - if await self._store.get("version", collection="config") is not None: - config = await self.get_config() - await self.save_config(config) + if await self._store.get("version", collection=_CONFIG_COLLECTION) is not None: return - await self._store.put("version", {"data": "1"}, collection="config") - await self.save_config(GlobalConfig()) + await self._store.put("version", {"data": "1"}, collection=_CONFIG_COLLECTION) async def is_healthy(self) -> bool: return True @@ -59,92 +48,7 @@ async def is_healthy(self) -> bool: async def check_integrity(self) -> bool: return True - # ── Config (unencrypted) ────────────────────────────────────────────── - - async def get_config(self) -> GlobalConfig: - val = await self._store.get("global", collection="config") - if not val: - return GlobalConfig() - try: - return GlobalConfig.model_validate_json(val["data"]) - except Exception as exc: - logger.warning("Failed to parse config, using defaults: {}", exc) - return GlobalConfig() - - async def save_config(self, config: GlobalConfig) -> None: - data = config.model_dump(mode="json") - await self._store.put("global", {"data": json.dumps(data, indent=2)}, collection="config") - async def close(self) -> None: - pass - - -def _ensure_macos_keychain_ca() -> None: - """Ensure the mitmproxy CA is generated and trusted in the macOS login keychain. - - Go's crypto/x509 on macOS uses the native Security framework and - ignores SSL_CERT_FILE, so the only reliable way to make Go binaries - (gh, terraform, kubectl, …) trust the mitmproxy CA is to add it to - the login keychain directly. - - The certificate is added persistently once to avoid repeated OS password - prompts. It will skip addition on subsequent calls if already present. - """ - if sys.platform != "darwin": - return - - keychain = Path.home() / "Library/Keychains/login.keychain-db" - if not keychain.exists(): - return - - # Avoid double-adding: if the user already has a cert with CN=mitmproxy - # in their keychain (e.g. from a manual mitmproxy install), don't touch it. - check = subprocess.run( - ["security", "find-certificate", "-c", "mitmproxy", str(keychain)], - capture_output=True, - text=True, - ) - if check.returncode == 0: - logger.debug("mitmproxy CA already present in macOS login keychain; skipping add") - return - - # Ensure CA certificate is generated so we can register it - confdir = Path.home() / ".mitmproxy" - ca_cert_path = confdir / "mitmproxy-ca-cert.pem" - if not ca_cert_path.exists(): - try: - from mitmproxy.certs import CertStore - - CertStore.from_store(confdir, "mitmproxy", 2048) - logger.debug("Generated mitmproxy CA certificate at {}", ca_cert_path) - except Exception as e: - logger.debug("Failed to generate mitmproxy CA certificate: {}", e) - return - - if not ca_cert_path.exists(): - return - - result = subprocess.run( - [ - "security", - "add-trusted-cert", - "-d", - "-r", - "trustRoot", - "-k", - str(keychain), - str(ca_cert_path), - ], - capture_output=True, - text=True, - timeout=60, - ) - if result.returncode == 0: - logger.debug("Added mitmproxy CA to macOS login keychain") - return - - logger.warning( - "Could not add mitmproxy CA to macOS login keychain" - " (Go-based tools like gh/terraform/kubectl may fail with TLS errors): {}", - result.stderr.strip() or result.stdout.strip(), - ) + close = getattr(self._store, "close", None) + if callable(close): + await close() diff --git a/src/authsome/vault/__init__.py b/src/authsome/vault/__init__.py index 0ba159e..3279d25 100644 --- a/src/authsome/vault/__init__.py +++ b/src/authsome/vault/__init__.py @@ -1,12 +1,4 @@ -"""Vault — encrypted key-value layer over AppStore. - -The Vault wraps an AppStore's async KV backend and encrypts every value -before writing and decrypts after reading. It uses collections to -logically separate different data types (profiles, providers, credentials). - -The Vault knows nothing about credential types, profiles, or providers. -All key schema decisions belong to the caller (AuthService). -""" +"""Vault — encrypted key-value layer over AppStore.""" from __future__ import annotations @@ -18,7 +10,6 @@ from authsome.store.interfaces import AppStore if TYPE_CHECKING: - from authsome.auth.models.config import GlobalConfig from authsome.vault.crypto import VaultCrypto @@ -27,8 +18,6 @@ class Vault: All values are encrypted at rest using AES-256-GCM. The master key is managed by the configured VaultCrypto backend (local file or OS keyring). - - Config is delegated unencrypted (needed before crypto is available). """ def __init__( @@ -56,14 +45,6 @@ def home(self) -> Path: """Base directory for the storage system.""" return self._app_store.home - # ── Config (delegated, unencrypted — bootstrap dependency) ──────────── - - async def get_config(self) -> GlobalConfig: - return await self._app_store.get_config() - - async def save_config(self, config: GlobalConfig) -> None: - await self._app_store.save_config(config) - # ── Index helpers ───────────────────────────────────────────────────── async def _get_index(self, collection: str) -> builtins.list[str]: diff --git a/tests/auth/test_models.py b/tests/auth/test_models.py index 1046e0a..35b5fd9 100644 --- a/tests/auth/test_models.py +++ b/tests/auth/test_models.py @@ -1,6 +1,6 @@ """Tests for authsome data models.""" -from authsome.auth.models.config import GlobalConfig, current_spec_version +from authsome.auth.models.config import current_spec_version from authsome.auth.models.connection import ( ConnectionRecord, ProviderClientRecord, @@ -35,25 +35,11 @@ def test_export_format_values(self) -> None: assert ExportFormat.JSON.value == "json" -class TestGlobalConfig: - """Global config model tests.""" +class TestSpecVersion: + """Spec version helper tests.""" - def test_defaults(self) -> None: - config = GlobalConfig() - assert config.spec_version == current_spec_version() - assert config.encryption is not None - assert config.encryption.mode == "local_key" - - def test_json_roundtrip(self) -> None: - config = GlobalConfig(spec_version=1) - json_str = config.model_dump_json() - restored = GlobalConfig.model_validate_json(json_str) - assert restored.spec_version == 1 - - def test_extra_fields_preserved(self) -> None: - config = GlobalConfig.model_validate({"spec_version": 1, "custom": "val"}) - dumped = config.model_dump() - assert dumped.get("custom") == "val" + def test_current_spec_version_is_int(self) -> None: + assert isinstance(current_spec_version(), int) class TestIdentityMetadata: diff --git a/tests/auth/test_service.py b/tests/auth/test_service.py index 577e8dd..99fa4be 100644 --- a/tests/auth/test_service.py +++ b/tests/auth/test_service.py @@ -23,7 +23,8 @@ class TestAuthServiceRefreshLogs: def audit_log(self, tmp_path: Path) -> Path: log_file = tmp_path / "audit.log" audit.setup(log_file) - return log_file + yield log_file + audit.clear() @pytest.fixture def service(self) -> AuthService: diff --git a/tests/auth/test_service_provider_clients.py b/tests/auth/test_service_provider_clients.py index 6f6b3be..13dd645 100644 --- a/tests/auth/test_service_provider_clients.py +++ b/tests/auth/test_service_provider_clients.py @@ -15,7 +15,7 @@ from authsome.errors import OperationNotAllowedError from authsome.identity.keys import create_identity from authsome.identity.registry import IdentityRegistry -from authsome.server.dependencies import create_vault, get_server_home +from authsome.server.dependencies import create_app_store, create_vault, get_identity_registry_path from authsome.utils import build_store_key @@ -250,11 +250,12 @@ async def test_hosted_resume_login_flow_rejects_dcr_client_persistence() -> None async def test_revoke_local_deletes_shared_client_and_all_identity_connections(tmp_path) -> None: first_identity = create_identity(tmp_path, "steady-wisely-boldly-0042") second_identity = create_identity(tmp_path, "rapid-brightly-firmly-0007") - registry = IdentityRegistry(get_server_home(tmp_path)) - registry.register(handle=first_identity.handle, did=first_identity.did) - registry.register(handle=second_identity.handle, did=second_identity.did) + store = await create_app_store(tmp_path) + registry = IdentityRegistry(get_identity_registry_path(tmp_path)) + await registry.register(handle=first_identity.handle, did=first_identity.did) + await registry.register(handle=second_identity.handle, did=second_identity.did) - vault = await create_vault(tmp_path) + vault = await create_vault(store) try: service = AuthService(vault, identity="steady-wisely-boldly-0042", deployment_mode="local") @@ -368,3 +369,4 @@ async def test_revoke_local_deletes_shared_client_and_all_identity_connections(t ) finally: await vault.close() + await store.close() diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index d4da104..9bee07c 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -57,8 +57,7 @@ def mock_client() -> AsyncMock: def _patch_runtime(mock_client: AsyncMock, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: """Replace resolve_runtime_client so CLI commands get the mock client. - Also patch audit.setup and audit.log to prevent real file writes, - and redirect AUTHSOME_HOME to a temporary directory. + Redirect AUTHSOME_HOME to a temporary directory. """ monkeypatch.setenv("AUTHSOME_HOME", str(tmp_path)) @@ -69,11 +68,8 @@ def _patch_runtime(mock_client: AsyncMock, monkeypatch: pytest.MonkeyPatch, tmp_ monkeypatch.setattr(dc, "resolve_runtime_client", mock.AsyncMock(return_value=mock_client)) import authsome.cli.context as context_mod - import authsome.cli.main as main_mod monkeypatch.setattr(context_mod, "resolve_runtime_client", mock.AsyncMock(return_value=mock_client)) - monkeypatch.setattr(main_mod.audit, "setup", lambda *a, **kw: None) - monkeypatch.setattr(main_mod.audit, "log", lambda *a, **kw: None) import webbrowser diff --git a/tests/cli/test_client_signing.py b/tests/cli/test_client_signing.py index cb1a26d..ccbf9d0 100644 --- a/tests/cli/test_client_signing.py +++ b/tests/cli/test_client_signing.py @@ -5,7 +5,7 @@ import pytest from authsome.cli.client import AuthsomeApiClient -from authsome.identity.client_config import ClientConfig, load_client_config, save_client_config +from authsome.cli.client_config import ClientConfig, load_client_config, save_client_config from authsome.identity.keys import create_identity, mark_registered diff --git a/tests/cli/test_identity.py b/tests/cli/test_identity.py index e30bdd3..60bc751 100644 --- a/tests/cli/test_identity.py +++ b/tests/cli/test_identity.py @@ -8,8 +8,8 @@ from click.testing import CliRunner +from authsome.cli.client_config import load_client_config from authsome.cli.main import cli -from authsome.identity.client_config import load_client_config from authsome.identity.keys import load_identity diff --git a/tests/cli/test_import_env.py b/tests/cli/test_import_env.py index d768f63..d35f467 100644 --- a/tests/cli/test_import_env.py +++ b/tests/cli/test_import_env.py @@ -86,6 +86,7 @@ def test_scan_prompts_and_skips_import_when_declined( } monkeypatch.setenv("OPENAI_API_KEY", "sk-test-value") monkeypatch.setattr("authsome.cli.main.click.confirm", lambda *args, **kwargs: False) + mock_client.get_connection.return_value = {} result = runner.invoke(cli, ["--log-file", "", "scan"]) assert result.exit_code == 0, result.output diff --git a/tests/cli/test_init.py b/tests/cli/test_init.py index 9a98bab..ce5b29c 100644 --- a/tests/cli/test_init.py +++ b/tests/cli/test_init.py @@ -9,9 +9,10 @@ from click.testing import CliRunner +from authsome import __version__ +from authsome.cli.client_config import ClientConfig, load_client_config, save_client_config from authsome.cli.main import cli -from authsome.identity import mark_registered, save_client_config -from authsome.identity.client_config import ClientConfig, load_client_config +from authsome.identity import mark_registered from authsome.identity.keys import create_identity from authsome.store.local import LocalAppStore @@ -21,7 +22,7 @@ def test_init_removes_legacy_default_state_and_registers_identity( mock_client: MagicMock, tmp_path: Path, ) -> None: - identities = tmp_path / "identities" + identities = tmp_path / "client" / "identities" identities.mkdir(parents=True) (identities / "default.json").write_text("{}", encoding="utf-8") (identities / "default.key").write_text("legacy\n", encoding="utf-8") @@ -39,6 +40,7 @@ def test_init_removes_legacy_default_state_and_registers_identity( mock_client.register_identity.assert_called_once_with(data["profile"], data["did"]) config_data = load_client_config(tmp_path) + assert config_data.version == __version__ assert config_data.active_identity == data["profile"] diff --git a/tests/cli/test_login.py b/tests/cli/test_login.py index ba9caff..e2c509b 100644 --- a/tests/cli/test_login.py +++ b/tests/cli/test_login.py @@ -1,7 +1,7 @@ """Tests for `authsome login`. Covers: session started path, session already completed, --force flag, -JSON output shape, and that audit.log is called. +JSON output shape. """ import json diff --git a/tests/common/test_audit.py b/tests/common/test_audit.py index a325579..8108fa5 100644 --- a/tests/common/test_audit.py +++ b/tests/common/test_audit.py @@ -1,61 +1,28 @@ -"""Tests for the AuditLogger.""" +"""Tests for daemon audit event models.""" -import json -from pathlib import Path +from authsome.audit import AuditEvent -from authsome import audit +def test_audit_event_captures_known_fields() -> None: + event = AuditEvent( + event="login", + provider="github", + connection="default", + identity="steady-wisely-boldly-0042", + status="success", + ) -def test_audit_logger_initialization(tmp_path: Path): - filepath = tmp_path / "audit.log" - audit.setup(filepath) - assert audit._logger_instance is not None - assert audit._logger_instance.filepath == filepath + assert event.event == "login" + assert event.provider == "github" + assert event.connection == "default" + assert event.identity == "steady-wisely-boldly-0042" + assert event.status == "success" -def test_audit_logger_writes_json_line(tmp_path: Path): - filepath = tmp_path / "audit.log" - audit.setup(filepath) - audit.log("test_event", provider="test_provider", status="success") +def test_audit_event_metadata_defaults_to_empty_mapping() -> None: + event = AuditEvent(event="proxy_miss") - assert filepath.exists() - lines = filepath.read_text(encoding="utf-8").strip().split("\n") - assert len(lines) == 1 - - event_data = json.loads(lines[0]) - assert "timestamp" in event_data - assert event_data["event"] == "test_event" - assert event_data["provider"] == "test_provider" - assert event_data["status"] == "success" - - -def test_audit_logger_filters_none_values(tmp_path: Path): - filepath = tmp_path / "audit.log" - audit.setup(filepath) - audit.log("test_event", provider="test_provider", missing=None) - - lines = filepath.read_text(encoding="utf-8").strip().split("\n") - event_data = json.loads(lines[0]) - assert "provider" in event_data - assert "missing" not in event_data - - -def test_audit_logger_creates_parent_directory(tmp_path: Path): - filepath = tmp_path / "nested" / "dir" / "audit.log" - audit.setup(filepath) - audit.log("test_event") - - assert filepath.exists() - assert filepath.parent.exists() - - -def test_audit_logger_graceful_failure(tmp_path: Path, monkeypatch): - filepath = tmp_path / "audit.log" - audit.setup(filepath) - - def mock_open(*args, **kwargs): - raise OSError("Permission denied") - - monkeypatch.setattr("builtins.open", mock_open) - # This should not raise an exception - audit.log("test_event") + assert event.metadata == {} + payload = event.model_dump(mode="json") + assert payload["event"] == "proxy_miss" + assert "timestamp" in payload diff --git a/tests/identity/test_identity.py b/tests/identity/test_identity.py index 00464f7..17109ab 100644 --- a/tests/identity/test_identity.py +++ b/tests/identity/test_identity.py @@ -2,8 +2,8 @@ import pytest -from authsome.identity import current_from_home, load_client_config, save_client_config -from authsome.identity.client_config import ClientConfig +from authsome.cli.client_config import ClientConfig, load_client_config, save_client_config +from authsome.identity import current_from_home from authsome.identity.keys import ( create_identity, ensure_local_identity, diff --git a/tests/proxy/test_proxy.py b/tests/proxy/test_proxy.py index 1ebce83..16baa10 100644 --- a/tests/proxy/test_proxy.py +++ b/tests/proxy/test_proxy.py @@ -480,22 +480,13 @@ async def test_addon_injects_headers_for_matched_request(self) -> None: flow = self._make_flow() auth.resolve_credentials.return_value = {"headers": {"Authorization": "Bearer sk-test"}, "expires_at": None} - with patch("authsome.proxy.server.audit.log") as log_mock: - addon, _router, patcher = self._make_addon(auth, RouteMatch(provider="openai", connection="default")) - try: - await addon.request(flow) - finally: - patcher.stop() + addon, _router, patcher = self._make_addon(auth, RouteMatch(provider="openai", connection="default")) + try: + await addon.request(flow) + finally: + patcher.stop() assert flow.request.headers["Authorization"] == "Bearer sk-test" - log_mock.assert_any_call( - "proxy_inject", - provider="openai", - connection="default", - host="api.openai.com", - method="GET", - path="/v1/responses", - ) @pytest.mark.asyncio async def test_addon_overwrites_existing_authorization_header(self) -> None: @@ -516,15 +507,13 @@ async def test_addon_skips_unmatched_request(self) -> None: auth = mock.AsyncMock() flow = self._make_flow(host="example.com", path="/") - with patch("authsome.proxy.server.audit.log") as log_mock: - addon, _router, patcher = self._make_addon(auth, None, miss_reason="no_match") - try: - await addon.request(flow) - finally: - patcher.stop() + addon, _router, patcher = self._make_addon(auth, None, miss_reason="no_match") + try: + await addon.request(flow) + finally: + patcher.stop() auth.resolve_credentials.assert_not_called() - log_mock.assert_called_once_with("proxy_miss", host="example.com", reason="no_match") @pytest.mark.asyncio async def test_addon_continues_on_header_retrieval_failure(self) -> None: @@ -565,15 +554,17 @@ async def test_runner_sets_proxy_environment(self, tmp_path: Path) -> None: from authsome.proxy.runner import ProxyRunner auth = await _make_auth(tmp_path) - runner = ProxyRunner(auth) + runner = ProxyRunner(auth, home=tmp_path) with patch("authsome.proxy.runner.subprocess.run") as run_mock: run_mock.return_value.returncode = 0 - with patch.object(runner, "_start_proxy", return_value=("http://127.0.0.1:8899", mock.Mock())): - with patch.object(runner, "_build_ca_bundle", return_value=Path("/tmp/fake-ca.pem")): - await runner.run(["python", "-c", "print('ok')"]) + with patch("authsome.proxy.runner.ensure_local_proxy_ca") as ensure_ca_mock: + with patch.object(runner, "_start_proxy", return_value=("http://127.0.0.1:8899", mock.Mock())): + with patch.object(runner, "_build_ca_bundle", return_value=Path("/tmp/fake-ca.pem")): + await runner.run(["python", "-c", "print('ok')"]) env = run_mock.call_args.kwargs["env"] + ensure_ca_mock.assert_not_called() assert env["HTTP_PROXY"] == "http://127.0.0.1:8899" assert env["HTTPS_PROXY"] == "http://127.0.0.1:8899" assert env["http_proxy"] == "http://127.0.0.1:8899" @@ -591,7 +582,7 @@ async def test_runner_injects_dummy_credentials_for_connected_providers(self, tm auth = await _make_auth(tmp_path) await _save_connection_record(auth, "openai", "sk-real-padded-for-regex-12") - runner = ProxyRunner(auth) + runner = ProxyRunner(auth, home=tmp_path) with patch("authsome.proxy.runner.subprocess.run") as run_mock: run_mock.return_value.returncode = 0 @@ -608,7 +599,7 @@ async def test_runner_stops_proxy_on_subprocess_failure(self, tmp_path: Path) -> from authsome.proxy.runner import ProxyRunner auth = await _make_auth(tmp_path) - runner = ProxyRunner(auth) + runner = ProxyRunner(auth, home=tmp_path) server = mock.Mock() with patch("authsome.proxy.runner.subprocess.run", side_effect=RuntimeError("boom")): @@ -619,6 +610,42 @@ async def test_runner_stops_proxy_on_subprocess_failure(self, tmp_path: Path) -> server.shutdown.assert_called_once() + def test_start_proxy_ensures_local_proxy_ca_once(self, tmp_path: Path) -> None: + from authsome.proxy.runner import ProxyRunner + + runner = ProxyRunner(mock.Mock(), home=tmp_path) + server = mock.Mock(url="http://127.0.0.1:8899") + + with patch("authsome.proxy.runner.ensure_local_proxy_ca") as ensure_ca_mock: + with patch("authsome.proxy.runner.start_proxy_server", return_value=server) as start_mock: + proxy_url, returned_server = runner._start_proxy() + + ensure_ca_mock.assert_called_once_with(tmp_path) + start_mock.assert_called_once_with(runner._client) + assert proxy_url == "http://127.0.0.1:8899" + assert returned_server is server + + def test_ensure_local_proxy_ca_sets_flag_after_success(self, tmp_path: Path) -> None: + from authsome.cli.client_config import load_client_config + from authsome.proxy.certs import ensure_local_proxy_ca + + with patch("authsome.proxy.certs._ensure_macos_keychain_ca", return_value=True) as ensure_ca_mock: + ensure_local_proxy_ca(tmp_path) + + ensure_ca_mock.assert_called_once_with() + assert load_client_config(tmp_path).proxy_ca_installed is True + + def test_ensure_local_proxy_ca_skips_repeat_prompt_once_flagged(self, tmp_path: Path) -> None: + from authsome.cli.client_config import ClientConfig, save_client_config + from authsome.proxy.certs import ensure_local_proxy_ca + + save_client_config(tmp_path, ClientConfig(proxy_ca_installed=True)) + + with patch("authsome.proxy.certs._ensure_macos_keychain_ca") as ensure_ca_mock: + ensure_local_proxy_ca(tmp_path) + + ensure_ca_mock.assert_not_called() + def test_runner_merges_existing_no_proxy(self, tmp_path: Path) -> None: from authsome.proxy.runner import ProxyRunner diff --git a/tests/server/test_auth_sessions.py b/tests/server/test_auth_sessions.py index 2ce6d0d..85a815a 100644 --- a/tests/server/test_auth_sessions.py +++ b/tests/server/test_auth_sessions.py @@ -1,5 +1,6 @@ """Session ownership tests for protected auth routes.""" +import asyncio from pathlib import Path from fastapi.testclient import TestClient @@ -45,11 +46,13 @@ def test_get_session_rejects_other_identity(monkeypatch, tmp_path: Path) -> None json={"handle": stranger.handle, "did": stranger.did}, ) assert stranger_registration.status_code == 200 - session = client.app.state.auth_sessions.create( - provider="github", - identity=owner.handle, - connection_name="default", - flow_type=FlowType.PKCE.value, + session = asyncio.run( + client.app.state.auth_sessions.create( + provider="github", + identity=owner.handle, + connection_name="default", + flow_type=FlowType.PKCE.value, + ) ) response = client.get( @@ -80,11 +83,13 @@ def test_resume_session_rejects_other_identity(monkeypatch, tmp_path: Path) -> N json={"handle": stranger.handle, "did": stranger.did}, ) assert stranger_registration.status_code == 200 - session = client.app.state.auth_sessions.create( - provider="github", - identity=owner.handle, - connection_name="default", - flow_type=FlowType.PKCE.value, + session = asyncio.run( + client.app.state.auth_sessions.create( + provider="github", + identity=owner.handle, + connection_name="default", + flow_type=FlowType.PKCE.value, + ) ) response = client.post( @@ -100,3 +105,31 @@ def test_resume_session_rejects_other_identity(monkeypatch, tmp_path: Path) -> N assert response.status_code == 401 assert response.json()["detail"] == "Proof JWT body hash does not match request" + + +def test_sessions_do_not_survive_app_recreation(monkeypatch, tmp_path: Path) -> None: + monkeypatch.setenv("AUTHSOME_HOME", str(tmp_path)) + owner = create_identity(tmp_path, "steady-wisely-boldly-0042") + session_id = "" + + with TestClient(create_app()) as first_client: + registration = first_client.post("/identities/register", json={"handle": owner.handle, "did": owner.did}) + assert registration.status_code == 200 + session = asyncio.run( + first_client.app.state.auth_sessions.create( + provider="github", + identity=owner.handle, + connection_name="default", + flow_type=FlowType.PKCE.value, + ) + ) + session_id = session.session_id + + with TestClient(create_app()) as second_client: + response = second_client.get( + f"/auth/sessions/{session_id}", + headers=_auth_header(tmp_path, "GET", f"/auth/sessions/{session_id}", handle=owner.handle), + ) + + assert response.status_code == 404 + assert response.json()["detail"] == "Authentication session not found" diff --git a/tests/server/test_ui_sessions.py b/tests/server/test_ui_sessions.py index 0dcb459..d4e9b1b 100644 --- a/tests/server/test_ui_sessions.py +++ b/tests/server/test_ui_sessions.py @@ -97,14 +97,17 @@ def test_hosted_ui_auth_input_requires_matching_browser_session(monkeypatch, tmp app = create_app() with TestClient(app) as client: _register_identity(client, tmp_path, "steady-wisely-boldly-0042") - session = client.app.state.auth_sessions.create( - provider="github", - identity="steady-wisely-boldly-0042", - connection_name="default", - flow_type="pkce", + session = asyncio.run( + client.app.state.auth_sessions.create( + provider="github", + identity="steady-wisely-boldly-0042", + connection_name="default", + flow_type="pkce", + ) ) session.payload["ui_session_required"] = True session.payload["input_fields"] = [{"name": "client_id", "label": "Client ID", "secret": False}] + asyncio.run(client.app.state.auth_sessions.save(session)) response = client.get(f"/auth/sessions/{session.session_id}/input")