diff --git a/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py b/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py
index 13782d69..1b0fc965 100644
--- a/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py
+++ b/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py
@@ -4,6 +4,15 @@
if TYPE_CHECKING:
from dna.models.entity import EntityBase, Playlist, Project, User, Version
+AUTH_MODE_PASSWORDLESS = "passwordless"
+AUTH_MODE_SELF_HOSTED = "self_hosted"
+AUTH_MODE_SSO = "sso"
+SUPPORTED_SHOTGRID_AUTH_MODES = {
+ AUTH_MODE_PASSWORDLESS,
+ AUTH_MODE_SELF_HOSTED,
+ AUTH_MODE_SSO,
+}
+
class UserNotFoundError(Exception):
"""Raised when a user is not found in the production tracking system."""
@@ -11,6 +20,12 @@ class UserNotFoundError(Exception):
pass
+class AuthModeNotImplementedError(Exception):
+ """Raised when an authentication mode is recognized but not implemented."""
+
+ pass
+
+
class ProdtrackProviderBase:
def __init__(self):
pass
@@ -88,6 +103,10 @@ def get_user_by_email(self, user_email: str) -> "User":
"""
raise NotImplementedError("Subclasses must implement this method.")
+ def get_user_by_login(self, login: str) -> "User":
+ """Get a user by their login/username."""
+ raise NotImplementedError("Subclasses must implement this method.")
+
def get_projects_for_user(self, user_email: str) -> list["Project"]:
"""Get projects accessible by a user.
@@ -149,11 +168,58 @@ def publish_note(
raise NotImplementedError("Subclasses must implement this method.")
-def get_prodtrack_provider() -> ProdtrackProviderBase:
+def get_shotgrid_auth_mode() -> str:
+ """Get configured ShotGrid auth mode.
+
+ Supported values:
+ - passwordless
+ - self_hosted
+ - sso
+ """
+ auth_mode = os.getenv("SHOTGRID_AUTH_MODE", AUTH_MODE_PASSWORDLESS).strip().lower()
+ if auth_mode not in SUPPORTED_SHOTGRID_AUTH_MODES:
+ supported = ", ".join(sorted(SUPPORTED_SHOTGRID_AUTH_MODES))
+ raise ValueError(
+ f"Invalid SHOTGRID_AUTH_MODE '{auth_mode}'. Supported values: {supported}"
+ )
+ return auth_mode
+
+
+def get_prodtrack_provider(session_token: str | None = None) -> ProdtrackProviderBase:
"""Get the production tracking provider."""
from dna.prodtrack_providers.shotgrid import ShotgridProvider
provider_type = os.getenv("PRODTRACK_PROVIDER", "shotgrid")
if provider_type == "shotgrid":
- return ShotgridProvider()
+ return ShotgridProvider(session_token=session_token)
raise ValueError(f"Unknown production tracking provider: {provider_type}")
+
+
+def authenticate_user(username: str, password: str | None = None) -> dict[str, Any]:
+ """Authenticate a user using the configured provider and auth mode."""
+ provider_type = os.getenv("PRODTRACK_PROVIDER", "shotgrid")
+ if provider_type != "shotgrid":
+ raise ValueError(f"Unknown production tracking provider: {provider_type}")
+
+ from dna.prodtrack_providers.shotgrid_auth import ShotgridAuthenticationProvider
+
+ auth_mode = get_shotgrid_auth_mode()
+ if auth_mode == AUTH_MODE_PASSWORDLESS:
+ result = ShotgridAuthenticationProvider.authenticate_passwordless(username)
+ result["mode"] = AUTH_MODE_PASSWORDLESS
+ return result
+
+ if auth_mode == AUTH_MODE_SELF_HOSTED:
+ if not password:
+ raise ValueError("Password is required for self-hosted authentication.")
+ result = ShotgridAuthenticationProvider.authenticate(username, password)
+ result["mode"] = AUTH_MODE_SELF_HOSTED
+ return result
+
+ if auth_mode == AUTH_MODE_SSO:
+ raise AuthModeNotImplementedError(
+ "SSO authentication mode is not implemented yet."
+ )
+
+ # Defensive fallback; get_shotgrid_auth_mode() validates supported values.
+ raise ValueError(f"Unsupported SHOTGRID_AUTH_MODE: {auth_mode}")
diff --git a/backend/src/dna/prodtrack_providers/shotgrid.py b/backend/src/dna/prodtrack_providers/shotgrid.py
index 75e91143..ecf9aee7 100644
--- a/backend/src/dna/prodtrack_providers/shotgrid.py
+++ b/backend/src/dna/prodtrack_providers/shotgrid.py
@@ -126,6 +126,7 @@ def __init__(
url: Optional[str] = None,
script_name: Optional[str] = None,
api_key: Optional[str] = None,
+ session_token: Optional[str] = None,
sudo_user: Optional[str] = None,
connect: bool = True,
):
@@ -135,6 +136,7 @@ def __init__(
url: ShotGrid server URL. Defaults to SHOTGRID_URL env var.
script_name: API script name. Defaults to SHOTGRID_SCRIPT_NAME env var.
api_key: API key for authentication. Defaults to SHOTGRID_API_KEY env var.
+ session_token: Optional user session token.
sudo_user: Optional user login to perform actions as.
connect: Whether to connect immediately.
"""
@@ -143,12 +145,16 @@ def __init__(
self.url = url or os.getenv("SHOTGRID_URL")
self.script_name = script_name or os.getenv("SHOTGRID_SCRIPT_NAME")
self.api_key = api_key or os.getenv("SHOTGRID_API_KEY")
+ self.session_token = session_token
self.sudo_user = sudo_user or os.getenv("SHOTGRID_SUDO_USER")
- if not all([self.url, self.script_name, self.api_key]):
+ if not self.url:
+ raise ValueError("ShotGrid URL not provided.")
+
+ if not self.session_token and not (self.script_name and self.api_key):
raise ValueError(
- "ShotGrid credentials not provided. Set SHOTGRID_URL, "
- "SHOTGRID_SCRIPT_NAME, and SHOTGRID_API_KEY environment variables."
+ "ShotGrid credentials not provided. Provide either session_token "
+ "or (script_name and api_key)."
)
self.sg = None
@@ -163,7 +169,10 @@ def connect(self, sudo_user: Optional[str] = None):
sudo_user: Optional user login to perform actions as.
If provided, overrides the instance's sudo_user.
"""
- # Close existing connection if any (though Shotgun API doesn't really require explicit close)
+ if self.session_token:
+ self.sg = Shotgun(self.url, session_token=self.session_token)
+ return
+
self.sg = Shotgun(
self.url,
self.script_name,
@@ -177,6 +186,8 @@ def set_sudo_user(self, sudo_user: str):
Args:
sudo_user: The user login to perform actions as.
"""
+ if self.session_token:
+ raise ValueError("Cannot set sudo user when using a session token.")
self.sudo_user = sudo_user
self.connect()
@@ -189,6 +200,9 @@ def sudo(self, user_login: str):
Args:
user_login: The user login to perform actions as.
"""
+ if self.session_token:
+ raise ValueError("Cannot use sudo context when using a session token.")
+
original_connection = self._sudo_connection
try:
# Create a temporary connection for this user
@@ -456,7 +470,7 @@ def search(
List of lightweight entity representations with type, id, name, and
type-specific fields (email for users, description for shots/assets/versions)
"""
- if not self.sg:
+ if not self._sg:
raise ValueError("Not connected to ShotGrid")
results = []
@@ -501,7 +515,7 @@ def search(
)
# Query ShotGrid directly with minimal fields for performance
- sg_results = self.sg.find(
+ sg_results = self._sg.find(
sg_entity_type,
filters=sg_filters,
fields=sg_fields,
@@ -568,6 +582,35 @@ def get_user_by_email(self, user_email: str) -> User:
sg_user, entity_mapping, "user", resolve_links=False
)
+ def get_user_by_login(self, login: str) -> User:
+ """Get a user by their login/username.
+
+ Args:
+ login: The login/username of the user
+
+ Returns:
+ User entity
+
+ Raises:
+ ValueError: If user is not found
+ """
+ if not self._sg:
+ raise ValueError("Not connected to ShotGrid")
+
+ sg_user = self._sg.find_one(
+ "HumanUser",
+ filters=[["login", "is", login]],
+ fields=["id", "name", "email", "login"],
+ )
+
+ if not sg_user:
+ raise ValueError(f"User not found: {login}")
+
+ entity_mapping = FIELD_MAPPING["user"]
+ return self._convert_sg_entity_to_dna_entity(
+ sg_user, entity_mapping, "user", resolve_links=False
+ )
+
def get_projects_for_user(self, user_email: str) -> list[Project]:
"""Get projects accessible by a user.
@@ -800,7 +843,7 @@ def publish_note(
f"Author not found in ShotGrid: {author_email}"
) from e
- if author_login:
+ if author_login and not self.session_token:
with self.sudo(author_login):
result = self._sg.create("Note", note_data)
else:
diff --git a/backend/src/dna/prodtrack_providers/shotgrid_auth.py b/backend/src/dna/prodtrack_providers/shotgrid_auth.py
new file mode 100644
index 00000000..22205252
--- /dev/null
+++ b/backend/src/dna/prodtrack_providers/shotgrid_auth.py
@@ -0,0 +1,71 @@
+"""ShotGrid authentication provider implementation."""
+
+import os
+from typing import Any, Optional
+
+from shotgun_api3 import Shotgun
+
+from dna.prodtrack_providers.shotgrid import ShotgridProvider
+
+
+class ShotgridAuthenticationProvider:
+ """Provider for ShotGrid authentication."""
+
+ @staticmethod
+ def authenticate(
+ username: str, password: str, provider_url: Optional[str] = None
+ ) -> dict[str, Any]:
+ """Authenticate a user with ShotGrid and return session token and user info.
+
+ Args:
+ username: User login/username
+ password: User password
+ provider_url: Optional ShotGrid URL. If not provided, uses SHOTGRID_URL env var.
+
+ Returns:
+ Dictionary containing 'token' and 'email'
+
+ Raises:
+ ValueError: If authentication fails
+ """
+ url = provider_url or os.getenv("SHOTGRID_URL")
+ if not url:
+ raise ValueError("SHOTGRID_URL not configured")
+
+ try:
+ sg = Shotgun(url, login=username, password=password)
+ token = sg.get_session_token()
+ user_data = sg.find_one(
+ "HumanUser", filters=[["login", "is", username]], fields=["email"]
+ )
+
+ if not user_data:
+ raise ValueError(f"User not found: {username}")
+
+ email = user_data.get("email")
+ if not email:
+ raise ValueError("User has no email address configured")
+
+ return {"token": token, "email": email}
+ except ValueError:
+ raise
+ except Exception as e:
+ raise ValueError(f"Authentication failed: {str(e)}") from e
+
+ @staticmethod
+ def authenticate_passwordless(username: str) -> dict[str, Any]:
+ """Resolve user identity without password and without issuing a token."""
+ provider = ShotgridProvider()
+
+ try:
+ if "@" in username:
+ user = provider.get_user_by_email(username)
+ else:
+ user = provider.get_user_by_login(username)
+ except ValueError as exc:
+ raise ValueError(f"Authentication failed: {str(exc)}") from exc
+
+ if not user.email:
+ raise ValueError("User has no email address configured")
+
+ return {"token": None, "email": user.email}
diff --git a/backend/src/main.py b/backend/src/main.py
index 81dfd77a..5a51e0c0 100644
--- a/backend/src/main.py
+++ b/backend/src/main.py
@@ -1,11 +1,18 @@
"""FastAPI application entry point."""
-import os
from functools import lru_cache
from typing import Annotated, Optional, cast
-from fastapi import Depends, FastAPI, HTTPException, WebSocket, WebSocketDisconnect
+from fastapi import (
+ Depends,
+ FastAPI,
+ Header,
+ HTTPException,
+ WebSocket,
+ WebSocketDisconnect,
+)
from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
from dna.events import EventType, get_event_publisher
from dna.llm_providers.default_prompt import DEFAULT_PROMPT
@@ -42,7 +49,9 @@
)
from dna.models.entity import ENTITY_MODELS, EntityBase
from dna.prodtrack_providers.prodtrack_provider_base import (
+ AuthModeNotImplementedError,
ProdtrackProviderBase,
+ authenticate_user,
get_prodtrack_provider,
)
from dna.storage_providers.storage_provider_base import (
@@ -55,6 +64,12 @@
)
from dna.transcription_service import TranscriptionService, get_transcription_service
+
+class LoginRequest(BaseModel):
+ username: str
+ password: str | None = None
+
+
# API metadata for Swagger documentation
API_TITLE = "DNA Backend"
API_DESCRIPTION = """
@@ -78,6 +93,10 @@
# Define API tags for organizing endpoints
tags_metadata = [
+ {
+ "name": "Auth",
+ "description": "Authentication endpoints",
+ },
{
"name": "Health",
"description": "Health check and status endpoints",
@@ -169,12 +188,35 @@
# -----------------------------------------------------------------------------
+def get_token_header(
+ authorization: Annotated[str | None, Header()] = None,
+) -> str | None:
+ """Extract token from Authorization header."""
+ if not authorization:
+ return None
+ if authorization.startswith("Bearer "):
+ return authorization.split(" ")[1]
+ return authorization
+
+
@lru_cache
def get_prodtrack_provider_cached() -> ProdtrackProviderBase:
"""Get or create the production tracking provider singleton."""
return get_prodtrack_provider()
+def get_prodtrack_provider_dep(
+ token: Annotated[str | None, Depends(get_token_header)],
+ cached_provider: Annotated[
+ ProdtrackProviderBase, Depends(get_prodtrack_provider_cached)
+ ],
+) -> ProdtrackProviderBase:
+ """Get the production tracking provider with user session."""
+ if token:
+ return get_prodtrack_provider(session_token=token)
+ return cached_provider
+
+
@lru_cache
def get_storage_provider_cached() -> StorageProviderBase:
"""Get or create the storage provider singleton."""
@@ -194,7 +236,7 @@ def get_llm_provider_cached() -> LLMProviderBase:
ProdtrackProviderDep = Annotated[
- ProdtrackProviderBase, Depends(get_prodtrack_provider_cached)
+ ProdtrackProviderBase, Depends(get_prodtrack_provider_dep)
]
StorageProviderDep = Annotated[
@@ -239,6 +281,27 @@ async def shutdown_event():
await service.close()
+# -----------------------------------------------------------------------------
+# Auth endpoints
+# -----------------------------------------------------------------------------
+
+
+@app.post(
+ "/auth/login",
+ tags=["Auth"],
+ summary="Login to Production Tracking",
+ description="Authenticate with username and password to get a session token.",
+)
+async def login(request: LoginRequest):
+ """Login to ShotGrid."""
+ try:
+ return authenticate_user(request.username, request.password)
+ except AuthModeNotImplementedError as e:
+ raise HTTPException(status_code=501, detail=str(e))
+ except ValueError as e:
+ raise HTTPException(status_code=401, detail=str(e))
+
+
# -----------------------------------------------------------------------------
# Health endpoints
# -----------------------------------------------------------------------------
diff --git a/backend/tests/providers/test_shotgrid_provider.py b/backend/tests/providers/test_shotgrid_provider.py
index 1423149f..9214c688 100644
--- a/backend/tests/providers/test_shotgrid_provider.py
+++ b/backend/tests/providers/test_shotgrid_provider.py
@@ -51,7 +51,10 @@ def test_get_version(shotgrid_provider):
def test_missing_credentials_raises_error():
"""Test that missing credentials raises ValueError."""
with mock.patch.dict("os.environ", {}, clear=True):
- with pytest.raises(ValueError, match="ShotGrid credentials not provided"):
+ with pytest.raises(
+ ValueError,
+ match="ShotGrid credentials not provided|ShotGrid URL not provided",
+ ):
ShotgridProvider(url=None, script_name=None, api_key=None, connect=False)
diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py
index 3672b894..94b76c43 100644
--- a/backend/tests/test_main.py
+++ b/backend/tests/test_main.py
@@ -8,6 +8,7 @@
app,
get_llm_provider_cached,
get_prodtrack_provider_cached,
+ get_prodtrack_provider_dep,
get_storage_provider_cached,
)
@@ -51,7 +52,7 @@ def test_create_note_returns_201(self, mock_provider):
project={"type": "Project", "id": 85},
)
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.post(
@@ -80,7 +81,7 @@ def test_create_note_with_links(self, mock_provider):
project={"type": "Project", "id": 85},
)
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.post(
@@ -106,7 +107,7 @@ def test_create_note_with_links(self, mock_provider):
def test_create_note_missing_project_returns_422(self, mock_provider):
"""Test that missing required project field returns 422."""
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.post(
@@ -138,7 +139,7 @@ def test_find_returns_200_with_results(self, mock_provider):
Project(id=2, name="Project Two"),
]
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.post(
@@ -164,7 +165,7 @@ def test_find_with_filters(self, mock_provider):
mock_provider.find.return_value = [Shot(id=100, name="shot_010")]
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.post(
@@ -193,7 +194,7 @@ def test_find_with_uppercase_entity_type(self, mock_provider):
mock_provider.find.return_value = [Project(id=1, name="Test")]
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.post(
@@ -211,7 +212,7 @@ def test_find_with_uppercase_entity_type(self, mock_provider):
def test_find_unsupported_entity_type_returns_400(self, mock_provider):
"""Test that find returns 400 for unsupported entity types."""
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.post(
@@ -232,7 +233,7 @@ def test_find_returns_empty_list_when_no_results(self, mock_provider):
"""Test that find returns empty list when no entities match."""
mock_provider.find.return_value = []
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.post(
@@ -254,7 +255,7 @@ def test_find_provider_error_returns_400(self, mock_provider):
"""Test that find returns 400 when provider raises ValueError."""
mock_provider.find.side_effect = ValueError("Unknown field 'bad_field'")
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.post(
@@ -274,7 +275,7 @@ def test_find_provider_error_returns_400(self, mock_provider):
def test_find_missing_entity_type_returns_422(self, mock_provider):
"""Test that find returns 422 when entity_type is missing."""
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.post(
@@ -293,7 +294,7 @@ def test_find_with_multiple_filters(self, mock_provider):
mock_provider.find.return_value = [Version(id=1, name="v001", status="apr")]
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.post(
@@ -320,7 +321,7 @@ def test_find_default_filters_is_empty_list(self, mock_provider):
mock_provider.find.return_value = [Project(id=1, name="Test")]
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.post(
@@ -684,7 +685,7 @@ def test_get_projects_for_user_returns_200_with_results(self, mock_provider):
Project(id=2, name="Project Two"),
]
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.get("/projects/user/testuser")
@@ -706,7 +707,7 @@ def test_get_projects_for_user_calls_provider_with_email(self, mock_provider):
Project(id=1, name="Test Project")
]
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
client.get("/projects/user/jsmith@example.com")
@@ -720,7 +721,7 @@ def test_get_projects_for_user_returns_empty_list(self, mock_provider):
"""Test that get_projects_for_user returns empty list when user has no projects."""
mock_provider.get_projects_for_user.return_value = []
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.get("/projects/user/newuser")
@@ -736,7 +737,7 @@ def test_get_projects_for_user_returns_404_when_user_not_found(self, mock_provid
"User not found: unknownuser"
)
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.get("/projects/user/unknownuser")
@@ -764,7 +765,7 @@ def test_get_playlists_for_project_returns_200_with_results(self, mock_provider)
Playlist(id=2, code="Final Review"),
]
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.get("/projects/42/playlists")
@@ -788,7 +789,7 @@ def test_get_playlists_for_project_calls_provider_with_project_id(
Playlist(id=1, code="Test Playlist")
]
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
client.get("/projects/123/playlists")
@@ -800,7 +801,7 @@ def test_get_playlists_for_project_returns_empty_list(self, mock_provider):
"""Test that get_playlists_for_project returns empty list when no playlists."""
mock_provider.get_playlists_for_project.return_value = []
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.get("/projects/999/playlists")
@@ -816,7 +817,7 @@ def test_get_playlists_for_project_returns_404_on_error(self, mock_provider):
"Project not found"
)
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.get("/projects/999/playlists")
@@ -844,7 +845,7 @@ def test_get_versions_for_playlist_returns_200_with_results(self, mock_provider)
Version(id=2, name="shot_020_v002", status="apr"),
]
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.get("/playlists/42/versions")
@@ -868,7 +869,7 @@ def test_get_versions_for_playlist_calls_provider_with_playlist_id(
Version(id=1, name="v001")
]
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
client.get("/playlists/123/versions")
@@ -880,7 +881,7 @@ def test_get_versions_for_playlist_returns_empty_list(self, mock_provider):
"""Test that get_versions_for_playlist returns empty list when no versions."""
mock_provider.get_versions_for_playlist.return_value = []
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.get("/playlists/999/versions")
@@ -896,7 +897,7 @@ def test_get_versions_for_playlist_returns_404_on_error(self, mock_provider):
"Playlist not found"
)
- app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider
+ app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider
try:
response = client.get("/playlists/999/versions")
@@ -907,6 +908,61 @@ def test_get_versions_for_playlist_returns_404_on_error(self, mock_provider):
app.dependency_overrides.clear()
+class TestAuthEndpoints:
+ """Tests for authentication endpoints."""
+
+ def test_login_success(self):
+ """Test successful login returns token and email."""
+ with mock.patch("main.authenticate_user") as mock_auth:
+ mock_auth.return_value = {
+ "token": "fake-token",
+ "email": "test@example.com",
+ }
+
+ response = client.post(
+ "/auth/login",
+ json={"username": "testuser", "password": "password"},
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+ assert data["token"] == "fake-token"
+ assert data["email"] == "test@example.com"
+ mock_auth.assert_called_once_with("testuser", "password")
+
+ def test_login_failure(self):
+ """Test failed login returns 401."""
+ with mock.patch("main.authenticate_user") as mock_auth:
+ mock_auth.side_effect = ValueError("Invalid credentials")
+
+ response = client.post(
+ "/auth/login",
+ json={"username": "testuser", "password": "wrong-password"},
+ )
+
+ assert response.status_code == 401
+ assert "Invalid credentials" in response.json()["detail"]
+
+ def test_login_not_implemented_mode(self):
+ """Test unimplemented auth mode returns 501."""
+ from dna.prodtrack_providers.prodtrack_provider_base import (
+ AuthModeNotImplementedError,
+ )
+
+ with mock.patch("main.authenticate_user") as mock_auth:
+ mock_auth.side_effect = AuthModeNotImplementedError(
+ "SSO authentication mode is not implemented yet."
+ )
+
+ response = client.post(
+ "/auth/login",
+ json={"username": "testuser", "password": "password"},
+ )
+
+ assert response.status_code == 501
+ assert "not implemented" in response.json()["detail"].lower()
+
+
class TestGenerateNoteEndpoint:
"""Tests for POST /generate-note endpoint."""
diff --git a/backend/tests/test_providers_base.py b/backend/tests/test_providers_base.py
index 2edb038c..e2260983 100644
--- a/backend/tests/test_providers_base.py
+++ b/backend/tests/test_providers_base.py
@@ -1,8 +1,16 @@
"""Tests for base provider classes and additional coverage."""
+from unittest import mock
+
import pytest
from dna.llm_providers.llm_provider_base import LLMProviderBase
+from dna.prodtrack_providers.prodtrack_provider_base import (
+ AUTH_MODE_PASSWORDLESS,
+ AuthModeNotImplementedError,
+ authenticate_user,
+ get_shotgrid_auth_mode,
+)
from dna.transcription_providers.transcription_provider_base import (
TranscriptionProviderBase,
)
@@ -43,3 +51,59 @@ def test_init_exists(self):
"""Test that TranscriptionProviderBase can be instantiated."""
provider = TranscriptionProviderBase()
assert provider is not None
+
+
+class TestShotgridAuthModes:
+ """Tests for auth mode selection helpers."""
+
+ def test_get_shotgrid_auth_mode_defaults_to_passwordless(self):
+ with mock.patch.dict("os.environ", {}, clear=True):
+ assert get_shotgrid_auth_mode() == AUTH_MODE_PASSWORDLESS
+
+ def test_get_shotgrid_auth_mode_rejects_invalid_mode(self):
+ with mock.patch.dict("os.environ", {"SHOTGRID_AUTH_MODE": "invalid-mode"}):
+ with pytest.raises(ValueError, match="Invalid SHOTGRID_AUTH_MODE"):
+ get_shotgrid_auth_mode()
+
+ def test_authenticate_user_passwordless(self):
+ with mock.patch.dict(
+ "os.environ",
+ {"PRODTRACK_PROVIDER": "shotgrid", "SHOTGRID_AUTH_MODE": "passwordless"},
+ ):
+ with mock.patch(
+ "dna.prodtrack_providers.shotgrid_auth.ShotgridAuthenticationProvider.authenticate_passwordless"
+ ) as mock_auth:
+ mock_auth.return_value = {"token": None, "email": "test@example.com"}
+ result = authenticate_user("test@example.com")
+
+ assert result["email"] == "test@example.com"
+ assert result["token"] is None
+ assert result["mode"] == "passwordless"
+ mock_auth.assert_called_once_with("test@example.com")
+
+ def test_authenticate_user_self_hosted(self):
+ with mock.patch.dict(
+ "os.environ",
+ {"PRODTRACK_PROVIDER": "shotgrid", "SHOTGRID_AUTH_MODE": "self_hosted"},
+ ):
+ with mock.patch(
+ "dna.prodtrack_providers.shotgrid_auth.ShotgridAuthenticationProvider.authenticate"
+ ) as mock_auth:
+ mock_auth.return_value = {
+ "token": "session-token",
+ "email": "test@example.com",
+ }
+ result = authenticate_user("testuser", "password")
+
+ assert result["email"] == "test@example.com"
+ assert result["token"] == "session-token"
+ assert result["mode"] == "self_hosted"
+ mock_auth.assert_called_once_with("testuser", "password")
+
+ def test_authenticate_user_sso_not_implemented(self):
+ with mock.patch.dict(
+ "os.environ",
+ {"PRODTRACK_PROVIDER": "shotgrid", "SHOTGRID_AUTH_MODE": "sso"},
+ ):
+ with pytest.raises(AuthModeNotImplementedError, match="not implemented"):
+ authenticate_user("testuser", "password")
diff --git a/backend/tests/test_shotgrid_provider.py b/backend/tests/test_shotgrid_provider.py
index 89e90b95..a27a504b 100644
--- a/backend/tests/test_shotgrid_provider.py
+++ b/backend/tests/test_shotgrid_provider.py
@@ -70,6 +70,37 @@ def test_init_with_sudo_user(self, mock_shotgun):
)
assert provider.sudo_user == "admin"
+ def test_init_with_session_token(self, mock_shotgun):
+ """Test that __init__ supports session token authentication."""
+ with mock.patch.dict(
+ os.environ,
+ {
+ "SHOTGRID_URL": "https://test.shotgunstudio.com",
+ "SHOTGRID_SCRIPT_NAME": "test_script",
+ "SHOTGRID_API_KEY": "test_key",
+ },
+ ):
+ provider = ShotgridProvider(session_token="session-token")
+ mock_shotgun.assert_called_once_with(
+ "https://test.shotgunstudio.com",
+ session_token="session-token",
+ )
+ assert provider.session_token == "session-token"
+
+ def test_set_sudo_user_raises_with_session_token(self, mock_shotgun):
+ """Test sudo override is rejected in session token mode."""
+ with mock.patch.dict(
+ os.environ,
+ {
+ "SHOTGRID_URL": "https://test.shotgunstudio.com",
+ "SHOTGRID_SCRIPT_NAME": "test_script",
+ "SHOTGRID_API_KEY": "test_key",
+ },
+ ):
+ provider = ShotgridProvider(session_token="session-token")
+ with pytest.raises(ValueError, match="session token"):
+ provider.set_sudo_user("admin")
+
def test_set_sudo_user_reconnects(self, provider, mock_shotgun):
"""Test that set_sudo_user updates sudo_user and reconnects."""
# Reset mock to clear init call
diff --git a/frontend/packages/app/src/App.test.tsx b/frontend/packages/app/src/App.test.tsx
index da9603ca..b0d889f8 100644
--- a/frontend/packages/app/src/App.test.tsx
+++ b/frontend/packages/app/src/App.test.tsx
@@ -11,6 +11,6 @@ describe('App', () => {
it('should render the project selector initially', () => {
render();
expect(screen.getByText('Welcome to DNA')).toBeInTheDocument();
- expect(screen.getByPlaceholderText('you@example.com')).toBeInTheDocument();
+ expect(screen.getByPlaceholderText('username')).toBeInTheDocument();
});
});
diff --git a/frontend/packages/app/src/components/ProjectSelector.tsx b/frontend/packages/app/src/components/ProjectSelector.tsx
index e1437e3d..dee9db06 100644
--- a/frontend/packages/app/src/components/ProjectSelector.tsx
+++ b/frontend/packages/app/src/components/ProjectSelector.tsx
@@ -2,7 +2,11 @@ import { useState, useEffect } from 'react';
import styled from 'styled-components';
import { Button, Flex, Select, Spinner } from '@radix-ui/themes';
import { Playlist, Project } from '@dna/core';
-import { useGetProjectsForUser, useGetPlaylistsForProject } from '../api';
+import {
+ useGetProjectsForUser,
+ useGetPlaylistsForProject,
+ apiHandler,
+} from '../api';
import { Logo } from './Logo';
import {
StyledTextField,
@@ -13,8 +17,12 @@ import {
export const STORAGE_KEYS = {
USER_EMAIL: 'dna_user_email',
PROJECT: 'dna_selected_project',
+ TOKEN: 'dna_user_token',
+ AUTH_MODE: 'dna_auth_mode',
};
+type AuthMode = 'passwordless' | 'self_hosted' | 'sso';
+
interface ProjectSelectorProps {
onSelectionComplete: (
project: Project,
@@ -208,14 +216,71 @@ function clearStoredProject(): void {
}
}
+function getStoredToken(): string | null {
+ try {
+ return localStorage.getItem(STORAGE_KEYS.TOKEN);
+ } catch {
+ return null;
+ }
+}
+
+function saveToken(token: string): void {
+ try {
+ localStorage.setItem(STORAGE_KEYS.TOKEN, token);
+ } catch {
+ // Ignore storage errors
+ }
+}
+
+function clearStoredToken(): void {
+ try {
+ localStorage.removeItem(STORAGE_KEYS.TOKEN);
+ } catch {
+ // Ignore storage errors
+ }
+}
+
+function getStoredAuthMode(): AuthMode | null {
+ try {
+ const mode = localStorage.getItem(STORAGE_KEYS.AUTH_MODE);
+ if (mode === 'passwordless' || mode === 'self_hosted' || mode === 'sso') {
+ return mode;
+ }
+ return null;
+ } catch {
+ return null;
+ }
+}
+
+function saveAuthMode(mode: AuthMode): void {
+ try {
+ localStorage.setItem(STORAGE_KEYS.AUTH_MODE, mode);
+ } catch {
+ // Ignore storage errors
+ }
+}
+
+function clearStoredAuthMode(): void {
+ try {
+ localStorage.removeItem(STORAGE_KEYS.AUTH_MODE);
+ } catch {
+ // Ignore storage errors
+ }
+}
+
export function clearUserSession(): void {
clearStoredEmail();
clearStoredProject();
+ clearStoredToken();
+ clearStoredAuthMode();
}
export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) {
const [step, setStep] = useState('loading');
const [email, setEmail] = useState('');
+ const [password, setPassword] = useState('');
+ const [loginError, setLoginError] = useState(null);
+ const [isLoggingIn, setIsLoggingIn] = useState(false);
const [submittedEmail, setSubmittedEmail] = useState(null);
const [selectedProject, setSelectedProject] = useState(null);
const [selectedPlaylistId, setSelectedPlaylistId] = useState('');
@@ -223,13 +288,24 @@ export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) {
useEffect(() => {
const storedEmail = getStoredEmail();
const storedProject = getStoredProject();
+ const storedToken = getStoredToken();
+ const storedAuthMode = getStoredAuthMode();
+
+ if (storedToken && storedEmail) {
+ apiHandler.setUser({ id: '0', email: storedEmail, token: storedToken });
+ } else {
+ apiHandler.setUser(null);
+ }
- if (storedEmail && storedProject) {
+ const tokenRequired = storedAuthMode !== 'passwordless';
+ const hasValidAuth = storedToken || !tokenRequired;
+
+ if (storedEmail && storedProject && hasValidAuth) {
setSubmittedEmail(storedEmail);
setEmail(storedEmail);
setSelectedProject(storedProject);
setStep('playlist');
- } else if (storedEmail) {
+ } else if (storedEmail && hasValidAuth) {
setSubmittedEmail(storedEmail);
setEmail(storedEmail);
setStep('project');
@@ -252,13 +328,44 @@ export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) {
error: playlistsError,
} = useGetPlaylistsForProject(selectedProject?.id ?? null);
- const handleEmailSubmit = (e: React.FormEvent) => {
+ const handleLogin = async (e: React.FormEvent) => {
e.preventDefault();
- if (email.trim()) {
- const trimmedEmail = email.trim();
- setSubmittedEmail(trimmedEmail);
- saveEmail(trimmedEmail);
- setStep('project');
+ setLoginError(null);
+ setIsLoggingIn(true);
+
+ try {
+ if (email.trim()) {
+ const {
+ token,
+ email: userEmail,
+ mode,
+ } = await apiHandler.login({
+ username: email.trim(),
+ password: password.trim() || undefined,
+ });
+
+ const authMode = mode || (token ? 'self_hosted' : 'passwordless');
+ saveAuthMode(authMode);
+ saveEmail(userEmail);
+
+ if (token) {
+ apiHandler.setUser({ id: '0', email: userEmail, token });
+ saveToken(token);
+ } else {
+ apiHandler.setUser(null);
+ clearStoredToken();
+ }
+
+ setSubmittedEmail(userEmail);
+ setStep('project');
+ }
+ } catch (err: any) {
+ setLoginError(
+ err.response?.data?.detail ||
+ 'Login failed. Please check your credentials.'
+ );
+ } finally {
+ setIsLoggingIn(false);
}
};
@@ -290,6 +397,11 @@ export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) {
const handleBackToEmail = () => {
clearStoredEmail();
clearStoredProject();
+ clearStoredToken();
+ clearStoredAuthMode();
+ apiHandler.setUser(null);
+ setPassword('');
+ setLoginError(null);
setSubmittedEmail(null);
setSelectedProject(null);
setSelectedPlaylistId('');
@@ -330,26 +442,38 @@ export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) {
Dailies Notes Assistant
{step === 'email' && (
-
+
-
+
setEmail(e.target.value)}
required
/>
+
+ setPassword(e.target.value)}
+ />
+
+ {loginError && {loginError}}
+
)}
diff --git a/frontend/packages/core/src/apiHandler.test.ts b/frontend/packages/core/src/apiHandler.test.ts
index 39b1efb9..4dc66c6d 100644
--- a/frontend/packages/core/src/apiHandler.test.ts
+++ b/frontend/packages/core/src/apiHandler.test.ts
@@ -675,4 +675,47 @@ describe('ApiHandler', () => {
).rejects.toThrow('Server error');
});
});
+
+ describe('login', () => {
+ it('should post credentials to auth endpoint', async () => {
+ const api = createApiHandler({ baseURL: 'http://localhost:8000' });
+ const authResponse = {
+ token: 'session-token',
+ email: 'user@example.com',
+ mode: 'self_hosted',
+ };
+ mockAxiosInstance.post.mockResolvedValue({ data: authResponse });
+
+ const result = await api.login({
+ username: 'user@example.com',
+ password: 'secret',
+ });
+
+ expect(mockAxiosInstance.post).toHaveBeenCalledWith(
+ '/auth/login',
+ { username: 'user@example.com', password: 'secret' },
+ undefined
+ );
+ expect(result).toEqual(authResponse);
+ });
+
+ it('should support passwordless login request', async () => {
+ const api = createApiHandler({ baseURL: 'http://localhost:8000' });
+ const authResponse = {
+ token: null,
+ email: 'user@example.com',
+ mode: 'passwordless',
+ };
+ mockAxiosInstance.post.mockResolvedValue({ data: authResponse });
+
+ const result = await api.login({ username: 'user@example.com' });
+
+ expect(mockAxiosInstance.post).toHaveBeenCalledWith(
+ '/auth/login',
+ { username: 'user@example.com' },
+ undefined
+ );
+ expect(result).toEqual(authResponse);
+ });
+ });
});
diff --git a/frontend/packages/core/src/apiHandler.ts b/frontend/packages/core/src/apiHandler.ts
index 667bba57..5f53974b 100644
--- a/frontend/packages/core/src/apiHandler.ts
+++ b/frontend/packages/core/src/apiHandler.ts
@@ -4,6 +4,8 @@ import {
GetPlaylistsForProjectParams,
GetVersionsForPlaylistParams,
GetUserByEmailParams,
+ LoginParams,
+ AuthResponse,
GetDraftNoteParams,
GetAllDraftNotesParams,
UpsertDraftNoteParams,
@@ -146,6 +148,10 @@ class ApiHandler {
return this.get(`/users/${encodeURIComponent(params.userEmail)}`);
}
+ async login(params: LoginParams): Promise {
+ return this.post('/auth/login', params);
+ }
+
async getDraftNote(params: GetDraftNoteParams): Promise {
return this.get(
`/playlists/${params.playlistId}/versions/${params.versionId}/draft-notes/${encodeURIComponent(params.userEmail)}`
@@ -262,7 +268,9 @@ class ApiHandler {
return this.get(`/playlists/${playlistId}/draft-notes`);
}
- async publishNotes(params: PublishNotesParams): Promise {
+ async publishNotes(
+ params: PublishNotesParams
+ ): Promise {
return this.post(
`/playlists/${params.playlistId}/publish-notes`,
params.request
diff --git a/frontend/packages/core/src/interfaces.ts b/frontend/packages/core/src/interfaces.ts
index 5e2c88ef..a9901da9 100644
--- a/frontend/packages/core/src/interfaces.ts
+++ b/frontend/packages/core/src/interfaces.ts
@@ -141,6 +141,17 @@ export interface GetUserByEmailParams {
userEmail: string;
}
+export interface LoginParams {
+ username: string;
+ password?: string;
+}
+
+export interface AuthResponse {
+ token?: string | null;
+ email: string;
+ mode?: 'passwordless' | 'self_hosted' | 'sso';
+}
+
export interface DraftNoteLink {
entity_type: string;
entity_id: number;