diff --git a/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py b/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py index 13782d69..1b0fc965 100644 --- a/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py +++ b/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py @@ -4,6 +4,15 @@ if TYPE_CHECKING: from dna.models.entity import EntityBase, Playlist, Project, User, Version +AUTH_MODE_PASSWORDLESS = "passwordless" +AUTH_MODE_SELF_HOSTED = "self_hosted" +AUTH_MODE_SSO = "sso" +SUPPORTED_SHOTGRID_AUTH_MODES = { + AUTH_MODE_PASSWORDLESS, + AUTH_MODE_SELF_HOSTED, + AUTH_MODE_SSO, +} + class UserNotFoundError(Exception): """Raised when a user is not found in the production tracking system.""" @@ -11,6 +20,12 @@ class UserNotFoundError(Exception): pass +class AuthModeNotImplementedError(Exception): + """Raised when an authentication mode is recognized but not implemented.""" + + pass + + class ProdtrackProviderBase: def __init__(self): pass @@ -88,6 +103,10 @@ def get_user_by_email(self, user_email: str) -> "User": """ raise NotImplementedError("Subclasses must implement this method.") + def get_user_by_login(self, login: str) -> "User": + """Get a user by their login/username.""" + raise NotImplementedError("Subclasses must implement this method.") + def get_projects_for_user(self, user_email: str) -> list["Project"]: """Get projects accessible by a user. @@ -149,11 +168,58 @@ def publish_note( raise NotImplementedError("Subclasses must implement this method.") -def get_prodtrack_provider() -> ProdtrackProviderBase: +def get_shotgrid_auth_mode() -> str: + """Get configured ShotGrid auth mode. + + Supported values: + - passwordless + - self_hosted + - sso + """ + auth_mode = os.getenv("SHOTGRID_AUTH_MODE", AUTH_MODE_PASSWORDLESS).strip().lower() + if auth_mode not in SUPPORTED_SHOTGRID_AUTH_MODES: + supported = ", ".join(sorted(SUPPORTED_SHOTGRID_AUTH_MODES)) + raise ValueError( + f"Invalid SHOTGRID_AUTH_MODE '{auth_mode}'. Supported values: {supported}" + ) + return auth_mode + + +def get_prodtrack_provider(session_token: str | None = None) -> ProdtrackProviderBase: """Get the production tracking provider.""" from dna.prodtrack_providers.shotgrid import ShotgridProvider provider_type = os.getenv("PRODTRACK_PROVIDER", "shotgrid") if provider_type == "shotgrid": - return ShotgridProvider() + return ShotgridProvider(session_token=session_token) raise ValueError(f"Unknown production tracking provider: {provider_type}") + + +def authenticate_user(username: str, password: str | None = None) -> dict[str, Any]: + """Authenticate a user using the configured provider and auth mode.""" + provider_type = os.getenv("PRODTRACK_PROVIDER", "shotgrid") + if provider_type != "shotgrid": + raise ValueError(f"Unknown production tracking provider: {provider_type}") + + from dna.prodtrack_providers.shotgrid_auth import ShotgridAuthenticationProvider + + auth_mode = get_shotgrid_auth_mode() + if auth_mode == AUTH_MODE_PASSWORDLESS: + result = ShotgridAuthenticationProvider.authenticate_passwordless(username) + result["mode"] = AUTH_MODE_PASSWORDLESS + return result + + if auth_mode == AUTH_MODE_SELF_HOSTED: + if not password: + raise ValueError("Password is required for self-hosted authentication.") + result = ShotgridAuthenticationProvider.authenticate(username, password) + result["mode"] = AUTH_MODE_SELF_HOSTED + return result + + if auth_mode == AUTH_MODE_SSO: + raise AuthModeNotImplementedError( + "SSO authentication mode is not implemented yet." + ) + + # Defensive fallback; get_shotgrid_auth_mode() validates supported values. + raise ValueError(f"Unsupported SHOTGRID_AUTH_MODE: {auth_mode}") diff --git a/backend/src/dna/prodtrack_providers/shotgrid.py b/backend/src/dna/prodtrack_providers/shotgrid.py index 75e91143..ecf9aee7 100644 --- a/backend/src/dna/prodtrack_providers/shotgrid.py +++ b/backend/src/dna/prodtrack_providers/shotgrid.py @@ -126,6 +126,7 @@ def __init__( url: Optional[str] = None, script_name: Optional[str] = None, api_key: Optional[str] = None, + session_token: Optional[str] = None, sudo_user: Optional[str] = None, connect: bool = True, ): @@ -135,6 +136,7 @@ def __init__( url: ShotGrid server URL. Defaults to SHOTGRID_URL env var. script_name: API script name. Defaults to SHOTGRID_SCRIPT_NAME env var. api_key: API key for authentication. Defaults to SHOTGRID_API_KEY env var. + session_token: Optional user session token. sudo_user: Optional user login to perform actions as. connect: Whether to connect immediately. """ @@ -143,12 +145,16 @@ def __init__( self.url = url or os.getenv("SHOTGRID_URL") self.script_name = script_name or os.getenv("SHOTGRID_SCRIPT_NAME") self.api_key = api_key or os.getenv("SHOTGRID_API_KEY") + self.session_token = session_token self.sudo_user = sudo_user or os.getenv("SHOTGRID_SUDO_USER") - if not all([self.url, self.script_name, self.api_key]): + if not self.url: + raise ValueError("ShotGrid URL not provided.") + + if not self.session_token and not (self.script_name and self.api_key): raise ValueError( - "ShotGrid credentials not provided. Set SHOTGRID_URL, " - "SHOTGRID_SCRIPT_NAME, and SHOTGRID_API_KEY environment variables." + "ShotGrid credentials not provided. Provide either session_token " + "or (script_name and api_key)." ) self.sg = None @@ -163,7 +169,10 @@ def connect(self, sudo_user: Optional[str] = None): sudo_user: Optional user login to perform actions as. If provided, overrides the instance's sudo_user. """ - # Close existing connection if any (though Shotgun API doesn't really require explicit close) + if self.session_token: + self.sg = Shotgun(self.url, session_token=self.session_token) + return + self.sg = Shotgun( self.url, self.script_name, @@ -177,6 +186,8 @@ def set_sudo_user(self, sudo_user: str): Args: sudo_user: The user login to perform actions as. """ + if self.session_token: + raise ValueError("Cannot set sudo user when using a session token.") self.sudo_user = sudo_user self.connect() @@ -189,6 +200,9 @@ def sudo(self, user_login: str): Args: user_login: The user login to perform actions as. """ + if self.session_token: + raise ValueError("Cannot use sudo context when using a session token.") + original_connection = self._sudo_connection try: # Create a temporary connection for this user @@ -456,7 +470,7 @@ def search( List of lightweight entity representations with type, id, name, and type-specific fields (email for users, description for shots/assets/versions) """ - if not self.sg: + if not self._sg: raise ValueError("Not connected to ShotGrid") results = [] @@ -501,7 +515,7 @@ def search( ) # Query ShotGrid directly with minimal fields for performance - sg_results = self.sg.find( + sg_results = self._sg.find( sg_entity_type, filters=sg_filters, fields=sg_fields, @@ -568,6 +582,35 @@ def get_user_by_email(self, user_email: str) -> User: sg_user, entity_mapping, "user", resolve_links=False ) + def get_user_by_login(self, login: str) -> User: + """Get a user by their login/username. + + Args: + login: The login/username of the user + + Returns: + User entity + + Raises: + ValueError: If user is not found + """ + if not self._sg: + raise ValueError("Not connected to ShotGrid") + + sg_user = self._sg.find_one( + "HumanUser", + filters=[["login", "is", login]], + fields=["id", "name", "email", "login"], + ) + + if not sg_user: + raise ValueError(f"User not found: {login}") + + entity_mapping = FIELD_MAPPING["user"] + return self._convert_sg_entity_to_dna_entity( + sg_user, entity_mapping, "user", resolve_links=False + ) + def get_projects_for_user(self, user_email: str) -> list[Project]: """Get projects accessible by a user. @@ -800,7 +843,7 @@ def publish_note( f"Author not found in ShotGrid: {author_email}" ) from e - if author_login: + if author_login and not self.session_token: with self.sudo(author_login): result = self._sg.create("Note", note_data) else: diff --git a/backend/src/dna/prodtrack_providers/shotgrid_auth.py b/backend/src/dna/prodtrack_providers/shotgrid_auth.py new file mode 100644 index 00000000..22205252 --- /dev/null +++ b/backend/src/dna/prodtrack_providers/shotgrid_auth.py @@ -0,0 +1,71 @@ +"""ShotGrid authentication provider implementation.""" + +import os +from typing import Any, Optional + +from shotgun_api3 import Shotgun + +from dna.prodtrack_providers.shotgrid import ShotgridProvider + + +class ShotgridAuthenticationProvider: + """Provider for ShotGrid authentication.""" + + @staticmethod + def authenticate( + username: str, password: str, provider_url: Optional[str] = None + ) -> dict[str, Any]: + """Authenticate a user with ShotGrid and return session token and user info. + + Args: + username: User login/username + password: User password + provider_url: Optional ShotGrid URL. If not provided, uses SHOTGRID_URL env var. + + Returns: + Dictionary containing 'token' and 'email' + + Raises: + ValueError: If authentication fails + """ + url = provider_url or os.getenv("SHOTGRID_URL") + if not url: + raise ValueError("SHOTGRID_URL not configured") + + try: + sg = Shotgun(url, login=username, password=password) + token = sg.get_session_token() + user_data = sg.find_one( + "HumanUser", filters=[["login", "is", username]], fields=["email"] + ) + + if not user_data: + raise ValueError(f"User not found: {username}") + + email = user_data.get("email") + if not email: + raise ValueError("User has no email address configured") + + return {"token": token, "email": email} + except ValueError: + raise + except Exception as e: + raise ValueError(f"Authentication failed: {str(e)}") from e + + @staticmethod + def authenticate_passwordless(username: str) -> dict[str, Any]: + """Resolve user identity without password and without issuing a token.""" + provider = ShotgridProvider() + + try: + if "@" in username: + user = provider.get_user_by_email(username) + else: + user = provider.get_user_by_login(username) + except ValueError as exc: + raise ValueError(f"Authentication failed: {str(exc)}") from exc + + if not user.email: + raise ValueError("User has no email address configured") + + return {"token": None, "email": user.email} diff --git a/backend/src/main.py b/backend/src/main.py index 81dfd77a..5a51e0c0 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -1,11 +1,18 @@ """FastAPI application entry point.""" -import os from functools import lru_cache from typing import Annotated, Optional, cast -from fastapi import Depends, FastAPI, HTTPException, WebSocket, WebSocketDisconnect +from fastapi import ( + Depends, + FastAPI, + Header, + HTTPException, + WebSocket, + WebSocketDisconnect, +) from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel from dna.events import EventType, get_event_publisher from dna.llm_providers.default_prompt import DEFAULT_PROMPT @@ -42,7 +49,9 @@ ) from dna.models.entity import ENTITY_MODELS, EntityBase from dna.prodtrack_providers.prodtrack_provider_base import ( + AuthModeNotImplementedError, ProdtrackProviderBase, + authenticate_user, get_prodtrack_provider, ) from dna.storage_providers.storage_provider_base import ( @@ -55,6 +64,12 @@ ) from dna.transcription_service import TranscriptionService, get_transcription_service + +class LoginRequest(BaseModel): + username: str + password: str | None = None + + # API metadata for Swagger documentation API_TITLE = "DNA Backend" API_DESCRIPTION = """ @@ -78,6 +93,10 @@ # Define API tags for organizing endpoints tags_metadata = [ + { + "name": "Auth", + "description": "Authentication endpoints", + }, { "name": "Health", "description": "Health check and status endpoints", @@ -169,12 +188,35 @@ # ----------------------------------------------------------------------------- +def get_token_header( + authorization: Annotated[str | None, Header()] = None, +) -> str | None: + """Extract token from Authorization header.""" + if not authorization: + return None + if authorization.startswith("Bearer "): + return authorization.split(" ")[1] + return authorization + + @lru_cache def get_prodtrack_provider_cached() -> ProdtrackProviderBase: """Get or create the production tracking provider singleton.""" return get_prodtrack_provider() +def get_prodtrack_provider_dep( + token: Annotated[str | None, Depends(get_token_header)], + cached_provider: Annotated[ + ProdtrackProviderBase, Depends(get_prodtrack_provider_cached) + ], +) -> ProdtrackProviderBase: + """Get the production tracking provider with user session.""" + if token: + return get_prodtrack_provider(session_token=token) + return cached_provider + + @lru_cache def get_storage_provider_cached() -> StorageProviderBase: """Get or create the storage provider singleton.""" @@ -194,7 +236,7 @@ def get_llm_provider_cached() -> LLMProviderBase: ProdtrackProviderDep = Annotated[ - ProdtrackProviderBase, Depends(get_prodtrack_provider_cached) + ProdtrackProviderBase, Depends(get_prodtrack_provider_dep) ] StorageProviderDep = Annotated[ @@ -239,6 +281,27 @@ async def shutdown_event(): await service.close() +# ----------------------------------------------------------------------------- +# Auth endpoints +# ----------------------------------------------------------------------------- + + +@app.post( + "/auth/login", + tags=["Auth"], + summary="Login to Production Tracking", + description="Authenticate with username and password to get a session token.", +) +async def login(request: LoginRequest): + """Login to ShotGrid.""" + try: + return authenticate_user(request.username, request.password) + except AuthModeNotImplementedError as e: + raise HTTPException(status_code=501, detail=str(e)) + except ValueError as e: + raise HTTPException(status_code=401, detail=str(e)) + + # ----------------------------------------------------------------------------- # Health endpoints # ----------------------------------------------------------------------------- diff --git a/backend/tests/providers/test_shotgrid_provider.py b/backend/tests/providers/test_shotgrid_provider.py index 1423149f..9214c688 100644 --- a/backend/tests/providers/test_shotgrid_provider.py +++ b/backend/tests/providers/test_shotgrid_provider.py @@ -51,7 +51,10 @@ def test_get_version(shotgrid_provider): def test_missing_credentials_raises_error(): """Test that missing credentials raises ValueError.""" with mock.patch.dict("os.environ", {}, clear=True): - with pytest.raises(ValueError, match="ShotGrid credentials not provided"): + with pytest.raises( + ValueError, + match="ShotGrid credentials not provided|ShotGrid URL not provided", + ): ShotgridProvider(url=None, script_name=None, api_key=None, connect=False) diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py index 3672b894..94b76c43 100644 --- a/backend/tests/test_main.py +++ b/backend/tests/test_main.py @@ -8,6 +8,7 @@ app, get_llm_provider_cached, get_prodtrack_provider_cached, + get_prodtrack_provider_dep, get_storage_provider_cached, ) @@ -51,7 +52,7 @@ def test_create_note_returns_201(self, mock_provider): project={"type": "Project", "id": 85}, ) - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -80,7 +81,7 @@ def test_create_note_with_links(self, mock_provider): project={"type": "Project", "id": 85}, ) - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -106,7 +107,7 @@ def test_create_note_with_links(self, mock_provider): def test_create_note_missing_project_returns_422(self, mock_provider): """Test that missing required project field returns 422.""" - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -138,7 +139,7 @@ def test_find_returns_200_with_results(self, mock_provider): Project(id=2, name="Project Two"), ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -164,7 +165,7 @@ def test_find_with_filters(self, mock_provider): mock_provider.find.return_value = [Shot(id=100, name="shot_010")] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -193,7 +194,7 @@ def test_find_with_uppercase_entity_type(self, mock_provider): mock_provider.find.return_value = [Project(id=1, name="Test")] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -211,7 +212,7 @@ def test_find_with_uppercase_entity_type(self, mock_provider): def test_find_unsupported_entity_type_returns_400(self, mock_provider): """Test that find returns 400 for unsupported entity types.""" - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -232,7 +233,7 @@ def test_find_returns_empty_list_when_no_results(self, mock_provider): """Test that find returns empty list when no entities match.""" mock_provider.find.return_value = [] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -254,7 +255,7 @@ def test_find_provider_error_returns_400(self, mock_provider): """Test that find returns 400 when provider raises ValueError.""" mock_provider.find.side_effect = ValueError("Unknown field 'bad_field'") - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -274,7 +275,7 @@ def test_find_provider_error_returns_400(self, mock_provider): def test_find_missing_entity_type_returns_422(self, mock_provider): """Test that find returns 422 when entity_type is missing.""" - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -293,7 +294,7 @@ def test_find_with_multiple_filters(self, mock_provider): mock_provider.find.return_value = [Version(id=1, name="v001", status="apr")] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -320,7 +321,7 @@ def test_find_default_filters_is_empty_list(self, mock_provider): mock_provider.find.return_value = [Project(id=1, name="Test")] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -684,7 +685,7 @@ def test_get_projects_for_user_returns_200_with_results(self, mock_provider): Project(id=2, name="Project Two"), ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/projects/user/testuser") @@ -706,7 +707,7 @@ def test_get_projects_for_user_calls_provider_with_email(self, mock_provider): Project(id=1, name="Test Project") ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: client.get("/projects/user/jsmith@example.com") @@ -720,7 +721,7 @@ def test_get_projects_for_user_returns_empty_list(self, mock_provider): """Test that get_projects_for_user returns empty list when user has no projects.""" mock_provider.get_projects_for_user.return_value = [] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/projects/user/newuser") @@ -736,7 +737,7 @@ def test_get_projects_for_user_returns_404_when_user_not_found(self, mock_provid "User not found: unknownuser" ) - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/projects/user/unknownuser") @@ -764,7 +765,7 @@ def test_get_playlists_for_project_returns_200_with_results(self, mock_provider) Playlist(id=2, code="Final Review"), ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/projects/42/playlists") @@ -788,7 +789,7 @@ def test_get_playlists_for_project_calls_provider_with_project_id( Playlist(id=1, code="Test Playlist") ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: client.get("/projects/123/playlists") @@ -800,7 +801,7 @@ def test_get_playlists_for_project_returns_empty_list(self, mock_provider): """Test that get_playlists_for_project returns empty list when no playlists.""" mock_provider.get_playlists_for_project.return_value = [] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/projects/999/playlists") @@ -816,7 +817,7 @@ def test_get_playlists_for_project_returns_404_on_error(self, mock_provider): "Project not found" ) - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/projects/999/playlists") @@ -844,7 +845,7 @@ def test_get_versions_for_playlist_returns_200_with_results(self, mock_provider) Version(id=2, name="shot_020_v002", status="apr"), ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/playlists/42/versions") @@ -868,7 +869,7 @@ def test_get_versions_for_playlist_calls_provider_with_playlist_id( Version(id=1, name="v001") ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: client.get("/playlists/123/versions") @@ -880,7 +881,7 @@ def test_get_versions_for_playlist_returns_empty_list(self, mock_provider): """Test that get_versions_for_playlist returns empty list when no versions.""" mock_provider.get_versions_for_playlist.return_value = [] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/playlists/999/versions") @@ -896,7 +897,7 @@ def test_get_versions_for_playlist_returns_404_on_error(self, mock_provider): "Playlist not found" ) - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/playlists/999/versions") @@ -907,6 +908,61 @@ def test_get_versions_for_playlist_returns_404_on_error(self, mock_provider): app.dependency_overrides.clear() +class TestAuthEndpoints: + """Tests for authentication endpoints.""" + + def test_login_success(self): + """Test successful login returns token and email.""" + with mock.patch("main.authenticate_user") as mock_auth: + mock_auth.return_value = { + "token": "fake-token", + "email": "test@example.com", + } + + response = client.post( + "/auth/login", + json={"username": "testuser", "password": "password"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["token"] == "fake-token" + assert data["email"] == "test@example.com" + mock_auth.assert_called_once_with("testuser", "password") + + def test_login_failure(self): + """Test failed login returns 401.""" + with mock.patch("main.authenticate_user") as mock_auth: + mock_auth.side_effect = ValueError("Invalid credentials") + + response = client.post( + "/auth/login", + json={"username": "testuser", "password": "wrong-password"}, + ) + + assert response.status_code == 401 + assert "Invalid credentials" in response.json()["detail"] + + def test_login_not_implemented_mode(self): + """Test unimplemented auth mode returns 501.""" + from dna.prodtrack_providers.prodtrack_provider_base import ( + AuthModeNotImplementedError, + ) + + with mock.patch("main.authenticate_user") as mock_auth: + mock_auth.side_effect = AuthModeNotImplementedError( + "SSO authentication mode is not implemented yet." + ) + + response = client.post( + "/auth/login", + json={"username": "testuser", "password": "password"}, + ) + + assert response.status_code == 501 + assert "not implemented" in response.json()["detail"].lower() + + class TestGenerateNoteEndpoint: """Tests for POST /generate-note endpoint.""" diff --git a/backend/tests/test_providers_base.py b/backend/tests/test_providers_base.py index 2edb038c..e2260983 100644 --- a/backend/tests/test_providers_base.py +++ b/backend/tests/test_providers_base.py @@ -1,8 +1,16 @@ """Tests for base provider classes and additional coverage.""" +from unittest import mock + import pytest from dna.llm_providers.llm_provider_base import LLMProviderBase +from dna.prodtrack_providers.prodtrack_provider_base import ( + AUTH_MODE_PASSWORDLESS, + AuthModeNotImplementedError, + authenticate_user, + get_shotgrid_auth_mode, +) from dna.transcription_providers.transcription_provider_base import ( TranscriptionProviderBase, ) @@ -43,3 +51,59 @@ def test_init_exists(self): """Test that TranscriptionProviderBase can be instantiated.""" provider = TranscriptionProviderBase() assert provider is not None + + +class TestShotgridAuthModes: + """Tests for auth mode selection helpers.""" + + def test_get_shotgrid_auth_mode_defaults_to_passwordless(self): + with mock.patch.dict("os.environ", {}, clear=True): + assert get_shotgrid_auth_mode() == AUTH_MODE_PASSWORDLESS + + def test_get_shotgrid_auth_mode_rejects_invalid_mode(self): + with mock.patch.dict("os.environ", {"SHOTGRID_AUTH_MODE": "invalid-mode"}): + with pytest.raises(ValueError, match="Invalid SHOTGRID_AUTH_MODE"): + get_shotgrid_auth_mode() + + def test_authenticate_user_passwordless(self): + with mock.patch.dict( + "os.environ", + {"PRODTRACK_PROVIDER": "shotgrid", "SHOTGRID_AUTH_MODE": "passwordless"}, + ): + with mock.patch( + "dna.prodtrack_providers.shotgrid_auth.ShotgridAuthenticationProvider.authenticate_passwordless" + ) as mock_auth: + mock_auth.return_value = {"token": None, "email": "test@example.com"} + result = authenticate_user("test@example.com") + + assert result["email"] == "test@example.com" + assert result["token"] is None + assert result["mode"] == "passwordless" + mock_auth.assert_called_once_with("test@example.com") + + def test_authenticate_user_self_hosted(self): + with mock.patch.dict( + "os.environ", + {"PRODTRACK_PROVIDER": "shotgrid", "SHOTGRID_AUTH_MODE": "self_hosted"}, + ): + with mock.patch( + "dna.prodtrack_providers.shotgrid_auth.ShotgridAuthenticationProvider.authenticate" + ) as mock_auth: + mock_auth.return_value = { + "token": "session-token", + "email": "test@example.com", + } + result = authenticate_user("testuser", "password") + + assert result["email"] == "test@example.com" + assert result["token"] == "session-token" + assert result["mode"] == "self_hosted" + mock_auth.assert_called_once_with("testuser", "password") + + def test_authenticate_user_sso_not_implemented(self): + with mock.patch.dict( + "os.environ", + {"PRODTRACK_PROVIDER": "shotgrid", "SHOTGRID_AUTH_MODE": "sso"}, + ): + with pytest.raises(AuthModeNotImplementedError, match="not implemented"): + authenticate_user("testuser", "password") diff --git a/backend/tests/test_shotgrid_provider.py b/backend/tests/test_shotgrid_provider.py index 89e90b95..a27a504b 100644 --- a/backend/tests/test_shotgrid_provider.py +++ b/backend/tests/test_shotgrid_provider.py @@ -70,6 +70,37 @@ def test_init_with_sudo_user(self, mock_shotgun): ) assert provider.sudo_user == "admin" + def test_init_with_session_token(self, mock_shotgun): + """Test that __init__ supports session token authentication.""" + with mock.patch.dict( + os.environ, + { + "SHOTGRID_URL": "https://test.shotgunstudio.com", + "SHOTGRID_SCRIPT_NAME": "test_script", + "SHOTGRID_API_KEY": "test_key", + }, + ): + provider = ShotgridProvider(session_token="session-token") + mock_shotgun.assert_called_once_with( + "https://test.shotgunstudio.com", + session_token="session-token", + ) + assert provider.session_token == "session-token" + + def test_set_sudo_user_raises_with_session_token(self, mock_shotgun): + """Test sudo override is rejected in session token mode.""" + with mock.patch.dict( + os.environ, + { + "SHOTGRID_URL": "https://test.shotgunstudio.com", + "SHOTGRID_SCRIPT_NAME": "test_script", + "SHOTGRID_API_KEY": "test_key", + }, + ): + provider = ShotgridProvider(session_token="session-token") + with pytest.raises(ValueError, match="session token"): + provider.set_sudo_user("admin") + def test_set_sudo_user_reconnects(self, provider, mock_shotgun): """Test that set_sudo_user updates sudo_user and reconnects.""" # Reset mock to clear init call diff --git a/frontend/packages/app/src/App.test.tsx b/frontend/packages/app/src/App.test.tsx index da9603ca..b0d889f8 100644 --- a/frontend/packages/app/src/App.test.tsx +++ b/frontend/packages/app/src/App.test.tsx @@ -11,6 +11,6 @@ describe('App', () => { it('should render the project selector initially', () => { render(); expect(screen.getByText('Welcome to DNA')).toBeInTheDocument(); - expect(screen.getByPlaceholderText('you@example.com')).toBeInTheDocument(); + expect(screen.getByPlaceholderText('username')).toBeInTheDocument(); }); }); diff --git a/frontend/packages/app/src/components/ProjectSelector.tsx b/frontend/packages/app/src/components/ProjectSelector.tsx index e1437e3d..dee9db06 100644 --- a/frontend/packages/app/src/components/ProjectSelector.tsx +++ b/frontend/packages/app/src/components/ProjectSelector.tsx @@ -2,7 +2,11 @@ import { useState, useEffect } from 'react'; import styled from 'styled-components'; import { Button, Flex, Select, Spinner } from '@radix-ui/themes'; import { Playlist, Project } from '@dna/core'; -import { useGetProjectsForUser, useGetPlaylistsForProject } from '../api'; +import { + useGetProjectsForUser, + useGetPlaylistsForProject, + apiHandler, +} from '../api'; import { Logo } from './Logo'; import { StyledTextField, @@ -13,8 +17,12 @@ import { export const STORAGE_KEYS = { USER_EMAIL: 'dna_user_email', PROJECT: 'dna_selected_project', + TOKEN: 'dna_user_token', + AUTH_MODE: 'dna_auth_mode', }; +type AuthMode = 'passwordless' | 'self_hosted' | 'sso'; + interface ProjectSelectorProps { onSelectionComplete: ( project: Project, @@ -208,14 +216,71 @@ function clearStoredProject(): void { } } +function getStoredToken(): string | null { + try { + return localStorage.getItem(STORAGE_KEYS.TOKEN); + } catch { + return null; + } +} + +function saveToken(token: string): void { + try { + localStorage.setItem(STORAGE_KEYS.TOKEN, token); + } catch { + // Ignore storage errors + } +} + +function clearStoredToken(): void { + try { + localStorage.removeItem(STORAGE_KEYS.TOKEN); + } catch { + // Ignore storage errors + } +} + +function getStoredAuthMode(): AuthMode | null { + try { + const mode = localStorage.getItem(STORAGE_KEYS.AUTH_MODE); + if (mode === 'passwordless' || mode === 'self_hosted' || mode === 'sso') { + return mode; + } + return null; + } catch { + return null; + } +} + +function saveAuthMode(mode: AuthMode): void { + try { + localStorage.setItem(STORAGE_KEYS.AUTH_MODE, mode); + } catch { + // Ignore storage errors + } +} + +function clearStoredAuthMode(): void { + try { + localStorage.removeItem(STORAGE_KEYS.AUTH_MODE); + } catch { + // Ignore storage errors + } +} + export function clearUserSession(): void { clearStoredEmail(); clearStoredProject(); + clearStoredToken(); + clearStoredAuthMode(); } export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) { const [step, setStep] = useState('loading'); const [email, setEmail] = useState(''); + const [password, setPassword] = useState(''); + const [loginError, setLoginError] = useState(null); + const [isLoggingIn, setIsLoggingIn] = useState(false); const [submittedEmail, setSubmittedEmail] = useState(null); const [selectedProject, setSelectedProject] = useState(null); const [selectedPlaylistId, setSelectedPlaylistId] = useState(''); @@ -223,13 +288,24 @@ export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) { useEffect(() => { const storedEmail = getStoredEmail(); const storedProject = getStoredProject(); + const storedToken = getStoredToken(); + const storedAuthMode = getStoredAuthMode(); + + if (storedToken && storedEmail) { + apiHandler.setUser({ id: '0', email: storedEmail, token: storedToken }); + } else { + apiHandler.setUser(null); + } - if (storedEmail && storedProject) { + const tokenRequired = storedAuthMode !== 'passwordless'; + const hasValidAuth = storedToken || !tokenRequired; + + if (storedEmail && storedProject && hasValidAuth) { setSubmittedEmail(storedEmail); setEmail(storedEmail); setSelectedProject(storedProject); setStep('playlist'); - } else if (storedEmail) { + } else if (storedEmail && hasValidAuth) { setSubmittedEmail(storedEmail); setEmail(storedEmail); setStep('project'); @@ -252,13 +328,44 @@ export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) { error: playlistsError, } = useGetPlaylistsForProject(selectedProject?.id ?? null); - const handleEmailSubmit = (e: React.FormEvent) => { + const handleLogin = async (e: React.FormEvent) => { e.preventDefault(); - if (email.trim()) { - const trimmedEmail = email.trim(); - setSubmittedEmail(trimmedEmail); - saveEmail(trimmedEmail); - setStep('project'); + setLoginError(null); + setIsLoggingIn(true); + + try { + if (email.trim()) { + const { + token, + email: userEmail, + mode, + } = await apiHandler.login({ + username: email.trim(), + password: password.trim() || undefined, + }); + + const authMode = mode || (token ? 'self_hosted' : 'passwordless'); + saveAuthMode(authMode); + saveEmail(userEmail); + + if (token) { + apiHandler.setUser({ id: '0', email: userEmail, token }); + saveToken(token); + } else { + apiHandler.setUser(null); + clearStoredToken(); + } + + setSubmittedEmail(userEmail); + setStep('project'); + } + } catch (err: any) { + setLoginError( + err.response?.data?.detail || + 'Login failed. Please check your credentials.' + ); + } finally { + setIsLoggingIn(false); } }; @@ -290,6 +397,11 @@ export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) { const handleBackToEmail = () => { clearStoredEmail(); clearStoredProject(); + clearStoredToken(); + clearStoredAuthMode(); + apiHandler.setUser(null); + setPassword(''); + setLoginError(null); setSubmittedEmail(null); setSelectedProject(null); setSelectedPlaylistId(''); @@ -330,26 +442,38 @@ export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) { Dailies Notes Assistant {step === 'email' && ( - + - + setEmail(e.target.value)} required /> + + setPassword(e.target.value)} + /> + + {loginError && {loginError}} + )} diff --git a/frontend/packages/core/src/apiHandler.test.ts b/frontend/packages/core/src/apiHandler.test.ts index 39b1efb9..4dc66c6d 100644 --- a/frontend/packages/core/src/apiHandler.test.ts +++ b/frontend/packages/core/src/apiHandler.test.ts @@ -675,4 +675,47 @@ describe('ApiHandler', () => { ).rejects.toThrow('Server error'); }); }); + + describe('login', () => { + it('should post credentials to auth endpoint', async () => { + const api = createApiHandler({ baseURL: 'http://localhost:8000' }); + const authResponse = { + token: 'session-token', + email: 'user@example.com', + mode: 'self_hosted', + }; + mockAxiosInstance.post.mockResolvedValue({ data: authResponse }); + + const result = await api.login({ + username: 'user@example.com', + password: 'secret', + }); + + expect(mockAxiosInstance.post).toHaveBeenCalledWith( + '/auth/login', + { username: 'user@example.com', password: 'secret' }, + undefined + ); + expect(result).toEqual(authResponse); + }); + + it('should support passwordless login request', async () => { + const api = createApiHandler({ baseURL: 'http://localhost:8000' }); + const authResponse = { + token: null, + email: 'user@example.com', + mode: 'passwordless', + }; + mockAxiosInstance.post.mockResolvedValue({ data: authResponse }); + + const result = await api.login({ username: 'user@example.com' }); + + expect(mockAxiosInstance.post).toHaveBeenCalledWith( + '/auth/login', + { username: 'user@example.com' }, + undefined + ); + expect(result).toEqual(authResponse); + }); + }); }); diff --git a/frontend/packages/core/src/apiHandler.ts b/frontend/packages/core/src/apiHandler.ts index 667bba57..5f53974b 100644 --- a/frontend/packages/core/src/apiHandler.ts +++ b/frontend/packages/core/src/apiHandler.ts @@ -4,6 +4,8 @@ import { GetPlaylistsForProjectParams, GetVersionsForPlaylistParams, GetUserByEmailParams, + LoginParams, + AuthResponse, GetDraftNoteParams, GetAllDraftNotesParams, UpsertDraftNoteParams, @@ -146,6 +148,10 @@ class ApiHandler { return this.get(`/users/${encodeURIComponent(params.userEmail)}`); } + async login(params: LoginParams): Promise { + return this.post('/auth/login', params); + } + async getDraftNote(params: GetDraftNoteParams): Promise { return this.get( `/playlists/${params.playlistId}/versions/${params.versionId}/draft-notes/${encodeURIComponent(params.userEmail)}` @@ -262,7 +268,9 @@ class ApiHandler { return this.get(`/playlists/${playlistId}/draft-notes`); } - async publishNotes(params: PublishNotesParams): Promise { + async publishNotes( + params: PublishNotesParams + ): Promise { return this.post( `/playlists/${params.playlistId}/publish-notes`, params.request diff --git a/frontend/packages/core/src/interfaces.ts b/frontend/packages/core/src/interfaces.ts index 5e2c88ef..a9901da9 100644 --- a/frontend/packages/core/src/interfaces.ts +++ b/frontend/packages/core/src/interfaces.ts @@ -141,6 +141,17 @@ export interface GetUserByEmailParams { userEmail: string; } +export interface LoginParams { + username: string; + password?: string; +} + +export interface AuthResponse { + token?: string | null; + email: string; + mode?: 'passwordless' | 'self_hosted' | 'sso'; +} + export interface DraftNoteLink { entity_type: string; entity_id: number;