Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/backend/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
readiness_routes,
suggestion_routes,
certification_levels_routes,
directory_routes,
)

from src.common.database import init_db, get_session_factory, SQLAlchemySession
Expand Down Expand Up @@ -383,6 +384,7 @@ async def shutdown_event():
self_service_routes.register_routes(app)
workflows_routes.register_routes(app)
settings_routes.register_routes(app)
directory_routes.register_routes(app)
connection_routes.register_routes(app)
schema_import_routes.register_routes(app)

Expand Down
3 changes: 3 additions & 0 deletions src/backend/src/common/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from src.controller.business_owners_manager import BusinessOwnersManager
from src.controller.delivery_methods_manager import DeliveryMethodsManager
from src.controller.ontology_generator_manager import OntologyGeneratorManager
from src.controller.directory_manager import DirectoryManager

# Import base dependencies
from src.common.database import get_session_factory # Import the factory function
Expand Down Expand Up @@ -55,6 +56,7 @@
get_business_owners_manager,
get_delivery_methods_manager,
get_ontology_generator_manager,
get_directory_manager,
)
# Import workspace client getter separately as it might be structured differently
from src.common.workspace_client import get_workspace_client_dependency # Fixed to use proper wrapper
Expand Down Expand Up @@ -132,6 +134,7 @@ async def get_current_user(user_details: UserInfo = Depends(get_user_details_fro
BusinessOwnersManagerDep = Annotated[BusinessOwnersManager, Depends(get_business_owners_manager)]
DeliveryMethodsManagerDep = Annotated[DeliveryMethodsManager, Depends(get_delivery_methods_manager)]
OntologyGeneratorManagerDep = Annotated[OntologyGeneratorManager, Depends(get_ontology_generator_manager)]
DirectoryManagerDep = Annotated[DirectoryManager, Depends(get_directory_manager)]

# Permission Checker Dependency
PermissionCheckerDep = AuthorizationManagerDep
Expand Down
8 changes: 8 additions & 0 deletions src/backend/src/common/manager_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from src.controller.business_owners_manager import BusinessOwnersManager
from src.controller.delivery_methods_manager import DeliveryMethodsManager
from src.controller.ontology_generator_manager import OntologyGeneratorManager
from src.controller.directory_manager import DirectoryManager

# Import other dependencies needed by these providers
from src.common.database import get_db
Expand Down Expand Up @@ -221,6 +222,13 @@ def get_ontology_generator_manager(request: Request) -> OntologyGeneratorManager
raise HTTPException(status_code=503, detail="Ontology Generator service not configured.")
return manager

def get_directory_manager(request: Request) -> DirectoryManager:
    """Return the ``DirectoryManager`` stored on ``request.app.state``.

    Args:
        request: Incoming request; only ``request.app.state`` is read.

    Returns:
        The object found under ``app.state.directory_manager``.

    Raises:
        HTTPException: 503 when the attribute is absent or falsy. A
            critical log line is emitted first.
    """
    directory_manager = getattr(request.app.state, "directory_manager", None)
    if directory_manager:
        return directory_manager

    logger.critical("DirectoryManager not found in application state!")
    raise HTTPException(status_code=503, detail="Directory service not configured.")

# --- Add other manager getters if needed --- #
# Example:
# def get_data_products_manager(request: Request) -> DataProductsManager:
Expand Down
47 changes: 47 additions & 0 deletions src/backend/src/common/uc_connections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Shared helpers for working with Unity Catalog Connections.

Today these are HTTP-type connections used by workflow webhook steps
and by the Directory layer. The listing logic is centralised here so
that both routes return the same shape and behave identically on SDK
errors.
"""

from typing import Any, Dict, List

from src.common.logging import get_logger

logger = get_logger(__name__)


def list_http_connections(ws_client: Any) -> List[Dict[str, Any]]:
    """List the Unity Catalog connections whose type mentions HTTP.

    Each result is a dict with the keys ``name``, ``connection_type``,
    ``comment``, ``owner``, ``created_at`` and ``updated_at`` (the shape
    of the legacy ``/api/workflows/http-connections`` response).

    Args:
        ws_client: Object exposing ``connections.list()``.

    Returns:
        One dict per matching connection, in listing order. Any
        exception raised while listing or reading a connection is
        logged as a warning and an empty list is returned instead, so
        callers get the no-op set rather than an error.
    """

    def _as_row(entry: Any, type_label: str) -> Dict[str, Any]:
        # Key order is part of the response shape; keep it stable.
        return {
            "name": entry.name,
            "connection_type": type_label,
            "comment": entry.comment,
            "owner": entry.owner,
            "created_at": entry.created_at,
            "updated_at": entry.updated_at,
        }

    http_rows: List[Dict[str, Any]] = []
    try:
        for entry in ws_client.connections.list():
            # The type is compared as text because ``ConnectionType.HTTP``
            # is not available in every SDK version.
            type_label = str(entry.connection_type) if entry.connection_type else ""
            if "HTTP" in type_label.upper():
                http_rows.append(_as_row(entry, type_label))
    except Exception as exc:
        logger.warning(f"Failed to list HTTP connections: {exc}")
        return []
    return http_rows
219 changes: 219 additions & 0 deletions src/backend/src/controller/directory_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
"""Directory layer manager.

Reads provider configuration from ``app_settings``, dispatches to the
right concrete ``DirectoryProvider``, and caches search results in
memory for 5 minutes per (provider_type, connection_name, kind,
normalised query + limit) key. The cache is per-instance and the
manager is held as a singleton on ``app.state``.

The manager itself is provider-agnostic: adding a new provider is a
matter of registering another class in ``_PROVIDER_REGISTRY``.
"""

import time
from threading import Lock
from typing import Any, Callable, Dict, List, Optional, Tuple

from sqlalchemy.orm import Session

from src.common.logging import get_logger
from src.controller.directory_providers import (
DirectoryError,
DirectoryProvider,
EntraIdProvider,
)
from src.models.directory import (
DirectoryProviderType,
DirectoryStatus,
Principal,
SETTING_KEY_CONNECTION_NAME,
SETTING_KEY_PROVIDER_TYPE,
)
from src.repositories.app_settings_repository import app_settings_repo

logger = get_logger(__name__)


# How long a cached search result is considered fresh: 5 minutes, in seconds.
_CACHE_TTL_SECONDS = 5 * 60
# Default value of the ``limit`` parameter of ``DirectoryManager.search``.
_DEFAULT_SEARCH_LIMIT = 20


# Provider registry. Adding a new provider requires only an entry here
# plus an implementation in src.controller.directory_providers; the
# manager, routes, and models stay untouched.
# Keys are provider-type strings; values are factories invoked as
# ``factory(ws_client, connection_name)``.
_PROVIDER_REGISTRY: Dict[str, Callable[[Any, str], DirectoryProvider]] = {
    DirectoryProviderType.ENTRA.value: EntraIdProvider,
}


class DirectoryManager:
    """Provider dispatcher with a per-instance, in-memory TTL cache.

    Settings are re-read from ``app_settings`` on every call and a
    provider object is built per call; the only state kept between
    calls is the search-result cache. That cache is guarded by a lock so
    the methods can be called from concurrent request handlers.
    """

    def __init__(self) -> None:
        # key:   (provider_type, connection_name, kind, "<normalised query>|<limit>")
        # value: (time.monotonic() at store time, cached principals)
        self._cache: Dict[Tuple[str, str, str, str], Tuple[float, List[Principal]]] = {}
        self._lock = Lock()
        # Track which (provider_type, connection_name) tuple the cache
        # was filled for; flip => purge.
        self._cache_keyed_on: Optional[Tuple[str, str]] = None

    # ----- public API ---------------------------------------------------------

    def get_status(self, db: Session) -> DirectoryStatus:
        """Return the live ``configured`` flag plus the stored settings.

        ``configured`` is true only when both settings are non-empty and
        the provider type has an entry in ``_PROVIDER_REGISTRY``. Empty
        settings are reported as ``None``.
        """

        # Same read path as search()/test() so all three agree on what
        # "not set" means (falsy values collapse to None).
        provider_type, connection_name = self._read_settings(db)
        configured = (
            provider_type is not None
            and connection_name is not None
            and provider_type in _PROVIDER_REGISTRY
        )
        return DirectoryStatus(
            configured=configured,
            provider_type=provider_type,
            connection_name=connection_name,
        )

    def search(
        self,
        db: Session,
        ws_client: Any,
        query: str,
        types: List[str],
        limit: int = _DEFAULT_SEARCH_LIMIT,
    ) -> List[Principal]:
        """Return up to ``limit`` principals matching ``query`` across ``types``.

        Args:
            db: Session used to read the directory settings.
            ws_client: Passed through to the provider factory.
            query: Search text handed to the provider unchanged.
            types: Any combination of ``"user"`` and ``"group"``. Unknown
                values are ignored; if nothing valid remains (including
                ``None`` or an empty list) both kinds are searched.
            limit: Per-kind limit given to the provider and the cap on
                the merged result.

        Returns:
            Users first, then groups, de-duplicated by ``(type, id)`` and
            truncated to ``limit``. An empty list when the provider type
            or connection name is not set.

        Raises:
            DirectoryError: Unknown provider type, or the provider lookup
                failed.
        """

        provider_type, connection_name = self._read_settings(db)
        if not provider_type or not connection_name:
            return []

        self._invalidate_if_keyed_changed(provider_type, connection_name)

        wanted = {t for t in (types or ()) if t in {"user", "group"}} or {"user", "group"}
        provider = self._build_provider(provider_type, ws_client, connection_name)

        # Fixed order: users are merged before groups.
        lookups: Tuple[Tuple[str, Callable[[], List[Principal]]], ...] = (
            ("user", lambda: provider.search_users(query, limit)),
            ("group", lambda: provider.search_groups(query, limit)),
        )

        results: List[Principal] = []
        seen: set = set()
        for kind, loader in lookups:
            if kind not in wanted:
                continue
            for principal in self._cached(
                provider_type, connection_name, kind, query, limit, loader
            ):
                identity = (principal.type, principal.id)
                if identity not in seen:
                    seen.add(identity)
                    results.append(principal)

        # Honour the caller's overall limit even after cross-type merge.
        return results[:limit]

    def test(self, db: Session, ws_client: Any) -> None:
        """Probe the configured provider.

        Raises:
            DirectoryError: A setting is missing, the provider type is
                unknown, or the provider's own ``test()`` raises it.
        """

        provider_type, connection_name = self._read_settings(db)
        if not provider_type:
            raise DirectoryError("Directory provider is not configured")
        if not connection_name:
            raise DirectoryError("UC HTTP connection name is not configured")
        provider = self._build_provider(provider_type, ws_client, connection_name)
        provider.test()

    def invalidate_cache(self) -> None:
        """Drop all cached results. Call after any setting change."""

        with self._lock:
            self._cache.clear()
            self._cache_keyed_on = None

    # ----- internals ----------------------------------------------------------

    def _read_settings(self, db: Session) -> Tuple[Optional[str], Optional[str]]:
        """Read ``(provider_type, connection_name)``; falsy values become ``None``."""

        provider_type = app_settings_repo.get_by_key(db, SETTING_KEY_PROVIDER_TYPE)
        connection_name = app_settings_repo.get_by_key(db, SETTING_KEY_CONNECTION_NAME)
        return (provider_type or None), (connection_name or None)

    def _build_provider(
        self,
        provider_type: str,
        ws_client: Any,
        connection_name: str,
    ) -> DirectoryProvider:
        """Instantiate the registered provider for ``provider_type``.

        Raises:
            DirectoryError: No factory is registered under that type.
        """

        factory = _PROVIDER_REGISTRY.get(provider_type)
        if factory is None:
            raise DirectoryError(
                f"Unknown directory provider type: {provider_type!r}"
            )
        return factory(ws_client, connection_name)

    def _invalidate_if_keyed_changed(self, provider_type: str, connection_name: str) -> None:
        """Purge the cache when the active (provider, connection) pair changes."""

        with self._lock:
            current = (provider_type, connection_name)
            if self._cache_keyed_on is not None and self._cache_keyed_on != current:
                self._cache.clear()
            self._cache_keyed_on = current

    def _cached(
        self,
        provider_type: str,
        connection_name: str,
        kind: str,
        query: str,
        limit: int,
        loader: Callable[[], List[Principal]],
    ) -> List[Principal]:
        """Return the cached result for this lookup, loading it on a miss.

        Any non-``DirectoryError`` exception from ``loader`` is wrapped
        in ``DirectoryError``; failures are never cached. The loader runs
        outside the lock, so concurrent misses for the same key may each
        call it.
        """

        # Normalise the query so capitalisation / surrounding whitespace
        # doesn't bypass the cache.
        cache_key = (provider_type, connection_name, kind, f"{query.strip().lower()}|{limit}")
        now = time.monotonic()
        with self._lock:
            entry = self._cache.get(cache_key)
            if entry is not None and (now - entry[0]) < _CACHE_TTL_SECONDS:
                return entry[1]

        try:
            values = loader()
        except DirectoryError:
            raise
        except Exception as exc:
            raise DirectoryError(f"Directory lookup failed: {exc}") from exc

        with self._lock:
            stored_at = time.monotonic()
            # Evict everything that has already expired. Without this,
            # entries for queries that are never repeated would stay in
            # the dict for the lifetime of the process.
            cutoff = stored_at - _CACHE_TTL_SECONDS
            expired = [key for key, (ts, _) in self._cache.items() if ts <= cutoff]
            for key in expired:
                del self._cache[key]
            self._cache[cache_key] = (stored_at, values)
        return values


# Provider-type keys exposed to routes that only need to know what is
# registered. This is a snapshot taken when the module is imported;
# later changes to ``_PROVIDER_REGISTRY`` are not reflected in it.
SUPPORTED_PROVIDER_TYPES: List[str] = list(_PROVIDER_REGISTRY)


def register_provider(
    provider_type: str,
    factory: Callable[[Any, str], DirectoryProvider],
) -> None:
    """Add or replace the factory registered under ``provider_type``.

    The entry is written into the module-level ``_PROVIDER_REGISTRY``,
    overwriting any existing entry with the same key. Intended for
    tests that need to inject stub providers; pair it with
    :func:`unregister_provider` in teardown.

    NOTE(review): ``SUPPORTED_PROVIDER_TYPES`` is built once from the
    registry at import time and is not updated by this call -- confirm
    that callers relying on that list do not need to see the new type.

    Args:
        provider_type: Registry key.
        factory: Callable invoked as ``factory(ws_client, connection_name)``.
    """

    _PROVIDER_REGISTRY.update({provider_type: factory})


def unregister_provider(provider_type: str) -> None:
    """Remove ``provider_type`` from the registry if it is present.

    Inverse of :func:`register_provider`, primarily for test teardown.
    Removing a type that is not registered is a no-op.
    """

    if provider_type in _PROVIDER_REGISTRY:
        del _PROVIDER_REGISTRY[provider_type]
25 changes: 25 additions & 0 deletions src/backend/src/controller/directory_providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Directory provider plug-ins.

Each concrete provider talks to its IdP exclusively via a Unity Catalog
HTTP Connection so UC owns OAuth2 client-credentials acquisition,
caching, and refresh. The app stores no client secret and no token
cache.

Field mapping (Graph ``userPrincipalName`` vs Okta ``profile.login`` vs
...) lives entirely inside each provider; the manager and routes only
ever see normalised ``Principal`` instances.
"""

from src.controller.directory_providers.base import (
DirectoryError,
DirectoryProvider,
)
from src.controller.directory_providers.entra_id_provider import (
EntraIdProvider,
)

__all__ = [
"DirectoryError",
"DirectoryProvider",
"EntraIdProvider",
]
Loading