Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 130 additions & 43 deletions src/backend/src/controller/directory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,44 @@

Reads provider configuration from ``app_settings``, dispatches to the
right concrete ``DirectoryProvider``, and caches search results in
memory for 5 minutes per (provider_type, query_shape) key. The cache
is per-instance and the manager is held as a singleton on
memory for 5 minutes per (provider_type, config_signature, query) key.
The cache is per-instance and the manager is held as a singleton on
``app.state``.

The manager itself is provider-agnostic: adding a new provider is a
matter of registering another class in ``_PROVIDER_REGISTRY``.
matter of:

1. Implementing ``DirectoryProvider`` in ``src.controller.directory_providers``.
2. Registering it in ``_PROVIDER_REGISTRY`` below.
3. (If the provider needs new settings) adding the keys to
``DirectoryProviderConfig`` and to the read/write paths here.

No changes to routes or models otherwise.
"""

import time
from threading import Lock
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple

from sqlalchemy.orm import Session

from src.common.logging import get_logger
from src.controller.directory_providers import (
DirectoryError,
DirectoryProvider,
DirectoryProviderConfig,
DirectoryProviderContext,
EntraIdProvider,
FileProvider,
LakebaseProvider,
)
from src.models.directory import (
DirectoryProviderType,
DirectoryStatus,
Principal,
SETTING_KEY_CONNECTION_NAME,
SETTING_KEY_FILE_PATH,
SETTING_KEY_LAKEBASE_TABLE,
SETTING_KEY_PROVIDER_TYPE,
)
from src.repositories.app_settings_repository import app_settings_repo
Expand All @@ -38,11 +51,27 @@
_DEFAULT_SEARCH_LIMIT = 20


# Provider registry. Adding a new provider requires only an entry here
# plus an implementation in src.controller.directory_providers; the
# manager, routes, and models stay untouched.
_PROVIDER_REGISTRY: Dict[str, Callable[[Any, str], DirectoryProvider]] = {
# Provider registry. Each factory takes (context, config) and returns
# a DirectoryProvider. Adding a new provider requires only an entry
# here plus an implementation in src.controller.directory_providers;
# routes and models stay untouched.
ProviderFactory = Callable[
[DirectoryProviderContext, DirectoryProviderConfig], DirectoryProvider
]
_PROVIDER_REGISTRY: Dict[str, ProviderFactory] = {
DirectoryProviderType.ENTRA.value: EntraIdProvider,
DirectoryProviderType.LAKEBASE.value: LakebaseProvider,
DirectoryProviderType.FILE.value: FileProvider,
}


# Each provider declares which settings keys are required so the
# manager can decide ``configured=True/False`` without instantiating
# the provider.
_REQUIRED_KEYS: Dict[str, Tuple[str, ...]] = {
DirectoryProviderType.ENTRA.value: (SETTING_KEY_CONNECTION_NAME,),
DirectoryProviderType.LAKEBASE.value: (SETTING_KEY_LAKEBASE_TABLE,),
DirectoryProviderType.FILE.value: (SETTING_KEY_FILE_PATH,),
}


Expand All @@ -54,30 +83,34 @@ class DirectoryManager:
"""

def __init__(self) -> None:
self._cache: Dict[Tuple[str, str, str, str], Tuple[float, List[Principal]]] = {}
self._cache: Dict[
Tuple[str, Tuple[Optional[str], ...], str, str],
Tuple[float, List[Principal]],
] = {}
self._lock = Lock()
# Track which (provider_type, connection_name) tuple the cache
# was filled for; flip => purge.
self._cache_keyed_on: Optional[Tuple[str, str]] = None
# Track which (provider_type, config_signature) tuple the
# cache was filled for; flip => purge.
self._cache_keyed_on: Optional[Tuple[str, Tuple[Optional[str], ...]]] = None

# ----- public API ---------------------------------------------------------

def get_status(self, db: Session) -> DirectoryStatus:
"""Return the live ``configured`` flag plus a redaction-safe summary."""

provider_type = app_settings_repo.get_by_key(db, SETTING_KEY_PROVIDER_TYPE)
connection_name = app_settings_repo.get_by_key(db, SETTING_KEY_CONNECTION_NAME)
configured = bool(provider_type) and bool(connection_name) and provider_type in _PROVIDER_REGISTRY
provider_type, config = self._read_settings(db)
configured = self._is_configured(provider_type, config)
return DirectoryStatus(
configured=configured,
provider_type=provider_type if provider_type else None,
connection_name=connection_name if connection_name else None,
provider_type=provider_type or None,
connection_name=config.connection_name,
lakebase_table=config.lakebase_table,
file_path=config.file_path,
)

def search(
self,
db: Session,
ws_client: Any,
ctx: DirectoryProviderContext,
query: str,
types: List[str],
limit: int = _DEFAULT_SEARCH_LIMIT,
Expand All @@ -90,29 +123,35 @@ def search(
directory is not configured.
"""

provider_type, connection_name = self._read_settings(db)
if not provider_type or not connection_name:
provider_type, config = self._read_settings(db)
if not self._is_configured(provider_type, config):
return []

self._invalidate_if_keyed_changed(provider_type, connection_name)
assert provider_type is not None # narrowed by _is_configured
signature = config.signature()
self._invalidate_if_keyed_changed(provider_type, signature)

wanted = {t for t in types if t in {"user", "group"}} or {"user", "group"}
results: List[Principal] = []
seen: set = set()

provider = self._build_provider(provider_type, ws_client, connection_name)
provider = self._build_provider(provider_type, ctx, config)

if "user" in wanted:
for p in self._cached(provider_type, connection_name, "user", query, limit,
lambda: provider.search_users(query, limit)):
for p in self._cached(
provider_type, signature, "user", query, limit,
lambda: provider.search_users(query, limit),
):
key = (p.type, p.id)
if key not in seen:
seen.add(key)
results.append(p)

if "group" in wanted:
for p in self._cached(provider_type, connection_name, "group", query, limit,
lambda: provider.search_groups(query, limit)):
for p in self._cached(
provider_type, signature, "group", query, limit,
lambda: provider.search_groups(query, limit),
):
key = (p.type, p.id)
if key not in seen:
seen.add(key)
Expand All @@ -121,15 +160,19 @@ def search(
# Honour the caller's overall limit even after cross-type merge.
return results[:limit]

def test(self, db: Session, ws_client: Any) -> None:
def test(self, db: Session, ctx: DirectoryProviderContext) -> None:
"""Probe the configured provider. Raises ``DirectoryError`` if unhealthy."""

provider_type, connection_name = self._read_settings(db)
provider_type, config = self._read_settings(db)
if not provider_type:
raise DirectoryError("Directory provider is not configured")
if not connection_name:
raise DirectoryError("UC HTTP connection name is not configured")
provider = self._build_provider(provider_type, ws_client, connection_name)
if not self._is_configured(provider_type, config):
missing = _REQUIRED_KEYS.get(provider_type, ())
raise DirectoryError(
f"Provider {provider_type!r} is missing required setting(s): "
f"{', '.join(missing)}"
)
provider = self._build_provider(provider_type, ctx, config)
provider.test()

def invalidate_cache(self) -> None:
Expand All @@ -141,43 +184,80 @@ def invalidate_cache(self) -> None:

# ----- internals ----------------------------------------------------------

def _read_settings(self, db: Session) -> Tuple[Optional[str], Optional[str]]:
def _read_settings(
self, db: Session,
) -> Tuple[Optional[str], DirectoryProviderConfig]:
provider_type = app_settings_repo.get_by_key(db, SETTING_KEY_PROVIDER_TYPE)
connection_name = app_settings_repo.get_by_key(db, SETTING_KEY_CONNECTION_NAME)
return (provider_type or None), (connection_name or None)
config = DirectoryProviderConfig(
connection_name=app_settings_repo.get_by_key(db, SETTING_KEY_CONNECTION_NAME) or None,
lakebase_table=app_settings_repo.get_by_key(db, SETTING_KEY_LAKEBASE_TABLE) or None,
file_path=app_settings_repo.get_by_key(db, SETTING_KEY_FILE_PATH) or None,
)
return (provider_type or None), config

def _is_configured(
self, provider_type: Optional[str], config: DirectoryProviderConfig,
) -> bool:
if not provider_type or provider_type not in _PROVIDER_REGISTRY:
return False
required = _REQUIRED_KEYS.get(provider_type, ())
# Translate setting keys to config-field names.
key_to_field = {
SETTING_KEY_CONNECTION_NAME: "connection_name",
SETTING_KEY_LAKEBASE_TABLE: "lakebase_table",
SETTING_KEY_FILE_PATH: "file_path",
}
for key in required:
field = key_to_field.get(key)
if field is None:
# Defensive: an unknown required key means the
# registry is misconfigured at code level. Treat as
# not-configured rather than crash.
logger.warning(
"Provider %s declares unknown required setting key %s",
provider_type, key,
)
return False
if not getattr(config, field):
return False
return True

def _build_provider(
self,
provider_type: str,
ws_client: Any,
connection_name: str,
ctx: DirectoryProviderContext,
config: DirectoryProviderConfig,
) -> DirectoryProvider:
factory = _PROVIDER_REGISTRY.get(provider_type)
if factory is None:
raise DirectoryError(
f"Unknown directory provider type: {provider_type!r}"
)
return factory(ws_client, connection_name)
return factory(ctx, config)

def _invalidate_if_keyed_changed(self, provider_type: str, connection_name: str) -> None:
def _invalidate_if_keyed_changed(
self,
provider_type: str,
signature: Tuple[Optional[str], ...],
) -> None:
with self._lock:
current = (provider_type, connection_name)
current = (provider_type, signature)
if self._cache_keyed_on is not None and self._cache_keyed_on != current:
self._cache.clear()
self._cache_keyed_on = current

def _cached(
self,
provider_type: str,
connection_name: str,
signature: Tuple[Optional[str], ...],
kind: str,
query: str,
limit: int,
loader: Callable[[], List[Principal]],
) -> List[Principal]:
# Normalise the query so capitalisation / surrounding whitespace
# doesn't bypass the cache.
cache_key = (provider_type, connection_name, kind, f"{query.strip().lower()}|{limit}")
cache_key = (provider_type, signature, kind, f"{query.strip().lower()}|{limit}")
now = time.monotonic()
with self._lock:
entry = self._cache.get(cache_key)
Expand All @@ -202,18 +282,25 @@ def _cached(

def register_provider(
provider_type: str,
factory: Callable[[Any, str], DirectoryProvider],
factory: ProviderFactory,
*,
required_keys: Tuple[str, ...] = (),
) -> None:
"""Register an additional provider implementation at runtime.

Used by tests to inject stub providers without touching the
production registry.
production registry. ``required_keys`` mirrors the production
convention so the manager can compute ``configured`` correctly
for the stub too; default empty means "no required settings".
"""

_PROVIDER_REGISTRY[provider_type] = factory
if required_keys:
_REQUIRED_KEYS[provider_type] = required_keys


def unregister_provider(provider_type: str) -> None:
"""Inverse of :func:`register_provider`, primarily for test teardown."""

_PROVIDER_REGISTRY.pop(provider_type, None)
_REQUIRED_KEYS.pop(provider_type, None)
30 changes: 23 additions & 7 deletions src/backend/src/controller/directory_providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,41 @@
"""Directory provider plug-ins.

Each concrete provider talks to its IdP exclusively via a Unity Catalog
HTTP Connection so UC owns OAuth2 client-credentials acquisition,
caching, and refresh. The app stores no client secret and no token
cache.
Concrete providers shipped with the app:

Field mapping (Graph ``userPrincipalName`` vs Okta ``profile.login`` vs
...) lives entirely inside each provider; the manager and routes only
ever see normalised ``Principal`` instances.
- ``EntraIdProvider`` — Microsoft Entra ID via Microsoft Graph (UC HTTP)
- ``LakebaseProvider`` — Postgres table backed (the app's own Lakebase DB)
- ``FileProvider`` — Local CSV file (primarily for tests / demos)

Each provider receives a typed ``DirectoryProviderContext`` (transport
handles: SDK workspace client, SQLAlchemy engine, ...) and a
``DirectoryProviderConfig`` (all directory settings in one bag). The
provider reads only the fields relevant to its type and raises
``DirectoryError`` on missing / invalid required values.

Field mapping lives entirely inside each provider; the manager and
routes only ever see normalised ``Principal`` instances.
"""

from src.controller.directory_providers.base import (
DirectoryError,
DirectoryProvider,
DirectoryProviderConfig,
DirectoryProviderContext,
)
from src.controller.directory_providers.entra_id_provider import (
EntraIdProvider,
)
from src.controller.directory_providers.file_provider import FileProvider
from src.controller.directory_providers.lakebase_provider import (
LakebaseProvider,
)

__all__ = [
"DirectoryError",
"DirectoryProvider",
"DirectoryProviderConfig",
"DirectoryProviderContext",
"EntraIdProvider",
"FileProvider",
"LakebaseProvider",
]
Loading