From 0d78417ce18d2a24fc6d5c4b0a3822614ff8f5bb Mon Sep 17 00:00:00 2001 From: aviralgarg05 Date: Tue, 20 Jan 2026 00:18:29 +0530 Subject: [PATCH 1/4] feat: Implement Shotgrid authentication with username/password Signed-off-by: aviralgarg05 --- .../prodtrack_provider_base.py | 13 ++- .../src/dna/prodtrack_providers/shotgrid.py | 82 +++++++++++++++- backend/src/main.py | 76 +++++++++++++-- .../app/src/components/ProjectSelector.tsx | 94 ++++++++++++++++--- frontend/packages/core/src/apiHandler.ts | 5 + frontend/packages/core/src/interfaces.ts | 10 ++ 6 files changed, 253 insertions(+), 27 deletions(-) diff --git a/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py b/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py index d68b69d2..adda161c 100644 --- a/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py +++ b/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py @@ -51,6 +51,10 @@ def get_user_by_email(self, user_email: str) -> "User": """ raise NotImplementedError("Subclasses must implement this method.") + def get_user_by_login(self, login: str) -> "User": + """Get a user by their login/username.""" + raise NotImplementedError("Subclasses must implement this method.") + def get_projects_for_user(self, user_email: str) -> list["Project"]: """Get projects accessible by a user. @@ -84,12 +88,17 @@ def get_versions_for_playlist(self, playlist_id: int) -> list["Version"]: """ raise NotImplementedError("Subclasses must implement this method.") + @staticmethod + def authenticate_user(url: str, login: str, password: str) -> str: + """Authenticate a user and return a session token.""" + raise NotImplementedError("Subclasses must implement this method.") + -def get_prodtrack_provider() -> ProdtrackProviderBase: +def get_prodtrack_provider(session_token: str | None = None) -> ProdtrackProviderBase: """Get the production tracking provider.""" from dna.prodtrack_providers.shotgrid import ShotgridProvider provider_type = os.getenv("PRODTRACK_PROVIDER", "shotgrid") if provider_type == "shotgrid": - return ShotgridProvider() + return ShotgridProvider(session_token=session_token) raise ValueError(f"Unknown production tracking provider: {provider_type}") diff --git a/backend/src/dna/prodtrack_providers/shotgrid.py b/backend/src/dna/prodtrack_providers/shotgrid.py index a62237f2..6a0e772f 100644 --- a/backend/src/dna/prodtrack_providers/shotgrid.py +++ b/backend/src/dna/prodtrack_providers/shotgrid.py @@ -122,6 +122,7 @@ def __init__( url: Optional[str] = None, script_name: Optional[str] = None, api_key: Optional[str] = None, + session_token: Optional[str] = None, connect: bool = True, ): """Initialize the ShotGrid connection. @@ -130,17 +131,22 @@ def __init__( url: ShotGrid server URL. Defaults to SHOTGRID_URL env var. script_name: API script name. Defaults to SHOTGRID_SCRIPT_NAME env var. api_key: API key for authentication. Defaults to SHOTGRID_API_KEY env var. + session_token: Session token for user authentication. """ super().__init__() self.url = url or os.getenv("SHOTGRID_URL") self.script_name = script_name or os.getenv("SHOTGRID_SCRIPT_NAME") self.api_key = api_key or os.getenv("SHOTGRID_API_KEY") + self.session_token = session_token - if not all([self.url, self.script_name, self.api_key]): + if not self.url: + raise ValueError("ShotGrid URL not provided.") + + if not self.session_token and not (self.script_name and self.api_key): raise ValueError( - "ShotGrid credentials not provided. Set SHOTGRID_URL, " - "SHOTGRID_SCRIPT_NAME, and SHOTGRID_API_KEY environment variables." + "ShotGrid credentials not provided. Provide either session_token " + "or (script_name and api_key)." ) self.sg = None @@ -149,7 +155,11 @@ def __init__( def _connect(self): """Connect to ShotGrid.""" - self.sg = Shotgun(self.url, self.script_name, self.api_key) + if self.session_token: + # When using session token, we don't use script credentials + self.sg = Shotgun(self.url, session_token=self.session_token) + else: + self.sg = Shotgun(self.url, self.script_name, self.api_key) def _convert_sg_entity_to_dna_entity( self, @@ -398,6 +408,35 @@ def get_user_by_email(self, user_email: str) -> User: sg_user, entity_mapping, "user", resolve_links=False ) + def get_user_by_login(self, login: str) -> User: + """Get a user by their login/username. + + Args: + login: The login/username of the user + + Returns: + User entity + + Raises: + ValueError: If user is not found + """ + if not self.sg: + raise ValueError("Not connected to ShotGrid") + + sg_user = self.sg.find_one( + "HumanUser", + filters=[["login", "is", login]], + fields=["id", "name", "email", "login"], + ) + + if not sg_user: + raise ValueError(f"User not found: {login}") + + entity_mapping = FIELD_MAPPING["user"] + return self._convert_sg_entity_to_dna_entity( + sg_user, entity_mapping, "user", resolve_links=False + ) + def get_projects_for_user(self, user_email: str) -> list[Project]: """Get projects accessible by a user. @@ -538,6 +577,41 @@ def get_versions_for_playlist(self, playlist_id: int) -> list[Version]: return versions + @staticmethod + def authenticate_user(url: str, login: str, password: str) -> str: + """Authenticate a user with ShotGrid and return a session token. + + Args: + url: ShotGrid server URL + login: User login/username + password: User password + + Returns: + Session token string + + Raises: + ValueError: If authentication fails + """ + try: + # Shotgun.authenticate_human_user returns the user object, but we need the session_token. + # However, the standard way to get a token is to just create a connection which validates creds. + # But wait, Shotgun API structure specifically for auth: + # We can use the simple authentication helper or instantiate to get token. + # Actually, standard shotgun_api3 doesn't easily expose 'authenticate_human_user' to get a token string directly + # without internals. + # Let's instantiate a connection to verify and get session_token if available or standard auth flow. + # The pattern usually is: sg = Shotgun(url, login=login, password=password) then sg.get_session_token(). + + # Note: shotgun_api3 v3.3.0+ supports `login` and `password` in constructor for script-based auth, + # but for human user relying on session token: + + sg = Shotgun(url, login=login, password=password) + # This establishes connection. Now implementation detail: how to get the token? + # The 'get_session_token()' method provides it. + return sg.get_session_token() + except Exception as e: + raise ValueError(f"Authentication failed: {str(e)}") + def _get_dna_entity_type(sg_entity_type: str) -> str: """Get the DNA entity type from the ShotGrid entity type.""" for entity_type, entity_data in FIELD_MAPPING.items(): diff --git a/backend/src/main.py b/backend/src/main.py index 57678f4c..99e4e001 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -3,8 +3,9 @@ from functools import lru_cache from typing import Annotated, cast -from fastapi import Depends, FastAPI, HTTPException +from fastapi import Depends, FastAPI, HTTPException, Header from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel from dna.models import ( Asset, @@ -23,6 +24,13 @@ ProdtrackProviderBase, get_prodtrack_provider, ) +from dna.prodtrack_providers.shotgrid import ShotgridProvider + + +class LoginRequest(BaseModel): + username: str + password: str + # API metadata for Swagger documentation API_TITLE = "DNA Backend" @@ -47,6 +55,10 @@ # Define API tags for organizing endpoints tags_metadata = [ + { + "name": "Auth", + "description": "Authentication endpoints", + }, { "name": "Health", "description": "Health check and status endpoints", @@ -126,17 +138,69 @@ # ----------------------------------------------------------------------------- -@lru_cache -def get_prodtrack_provider_cached() -> ProdtrackProviderBase: - """Get or create the production tracking provider singleton.""" - return get_prodtrack_provider() +def get_token_header(authorization: Annotated[str | None, Header()] = None) -> str | None: + """Extract token from Authorization header.""" + if not authorization: + return None + if authorization.startswith("Bearer "): + return authorization.split(" ")[1] + return authorization + + +def get_prodtrack_provider_dep( + token: Annotated[str | None, Depends(get_token_header)], +) -> ProdtrackProviderBase: + """Get the production tracking provider with user session.""" + return get_prodtrack_provider(session_token=token) ProdtrackProviderDep = Annotated[ - ProdtrackProviderBase, Depends(get_prodtrack_provider_cached) + ProdtrackProviderBase, Depends(get_prodtrack_provider_dep) ] +# ----------------------------------------------------------------------------- +# Auth endpoints +# ----------------------------------------------------------------------------- + + +@app.post( + "/auth/login", + tags=["Auth"], + summary="Login to Production Tracking", + description="Authenticate with username and password to get a session token.", +) +async def login(request: LoginRequest): + """Login to ShotGrid.""" + try: + # We need a provider instance to access the static method if we want to keep it clean, + # or just import the class. We imported ShotgridProvider above. + # But we need the URL from the environment or default provider. + # Let's instantiate a default provider to get config, or just use the class method + # and assume env vars are set for URL if not passed? + # The static method requires URL. + + # Helper to get base URL + import os + url = os.getenv("SHOTGRID_URL") + if not url: + raise HTTPException(status_code=500, detail="SHOTGRID_URL not configured") + + token = ShotgridProvider.authenticate_user(url, request.username, request.password) + + # Create a provider with this token to fetch the user details (email) + provider = ShotgridProvider(url=url, session_token=token) + user = provider.get_user_by_login(request.username) + + if not user.email: + raise HTTPException(status_code=400, detail="User has no email address configured") + + return {"token": token, "email": user.email} + except ValueError as e: + raise HTTPException(status_code=401, detail=str(e)) + + + # ----------------------------------------------------------------------------- # Health endpoints # ----------------------------------------------------------------------------- diff --git a/frontend/packages/app/src/components/ProjectSelector.tsx b/frontend/packages/app/src/components/ProjectSelector.tsx index a3ba6fa2..c932549d 100644 --- a/frontend/packages/app/src/components/ProjectSelector.tsx +++ b/frontend/packages/app/src/components/ProjectSelector.tsx @@ -2,7 +2,7 @@ import { useState, useEffect } from 'react'; import styled from 'styled-components'; import { Button, Flex, Select, Spinner } from '@radix-ui/themes'; import { Playlist, Project } from '@dna/core'; -import { useGetProjectsForUser, useGetPlaylistsForProject } from '../api'; +import { useGetProjectsForUser, useGetPlaylistsForProject, apiHandler } from '../api'; import { Logo } from './Logo'; import { StyledTextField, @@ -13,6 +13,7 @@ import { export const STORAGE_KEYS = { USER_EMAIL: 'dna_user_email', PROJECT: 'dna_selected_project', + TOKEN: 'dna_user_token', }; interface ProjectSelectorProps { @@ -207,14 +208,42 @@ function clearStoredProject(): void { } } +function getStoredToken(): string | null { + try { + return localStorage.getItem(STORAGE_KEYS.TOKEN); + } catch { + return null; + } +} + +function saveToken(token: string): void { + try { + localStorage.setItem(STORAGE_KEYS.TOKEN, token); + } catch { + // Ignore storage errors + } +} + +function clearStoredToken(): void { + try { + localStorage.removeItem(STORAGE_KEYS.TOKEN); + } catch { + // Ignore storage errors + } +} + export function clearUserSession(): void { clearStoredEmail(); clearStoredProject(); + clearStoredToken(); } export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) { const [step, setStep] = useState('loading'); const [email, setEmail] = useState(''); + const [password, setPassword] = useState(''); + const [loginError, setLoginError] = useState(null); + const [isLoggingIn, setIsLoggingIn] = useState(false); const [submittedEmail, setSubmittedEmail] = useState(null); const [selectedProject, setSelectedProject] = useState(null); const [selectedPlaylistId, setSelectedPlaylistId] = useState(''); @@ -222,13 +251,19 @@ export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) { useEffect(() => { const storedEmail = getStoredEmail(); const storedProject = getStoredProject(); + const storedToken = getStoredToken(); - if (storedEmail && storedProject) { + // Restore session if token exists + if (storedToken && storedEmail) { + apiHandler.setUser({ id: '0', email: storedEmail, token: storedToken }); + } + + if (storedEmail && storedProject && storedToken) { setSubmittedEmail(storedEmail); setEmail(storedEmail); setSelectedProject(storedProject); setStep('playlist'); - } else if (storedEmail) { + } else if (storedEmail && storedToken) { setSubmittedEmail(storedEmail); setEmail(storedEmail); setStep('project'); @@ -251,13 +286,29 @@ export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) { error: playlistsError, } = useGetPlaylistsForProject(selectedProject?.id ?? null); - const handleEmailSubmit = (e: React.FormEvent) => { + const handleLogin = async (e: React.FormEvent) => { e.preventDefault(); - if (email.trim()) { - const trimmedEmail = email.trim(); - setSubmittedEmail(trimmedEmail); - saveEmail(trimmedEmail); - setStep('project'); + setLoginError(null); + setIsLoggingIn(true); + + try { + if (email.trim() && password.trim()) { + const { token, email: userEmail } = await apiHandler.login({ + username: email.trim(), + password: password.trim(), + }); + + apiHandler.setUser({ id: '0', email: userEmail, token }); // ID is not returned by login yet, using stub or need to update user interface/login response + + setSubmittedEmail(userEmail); + saveEmail(userEmail); + saveToken(token); + setStep('project'); + } + } catch (err: any) { + setLoginError(err.response?.data?.detail || 'Login failed. Please check your credentials.'); + } finally { + setIsLoggingIn(false); } }; @@ -329,26 +380,39 @@ export function ProjectSelector({ onSelectionComplete }: ProjectSelectorProps) { Dailies Notes Assistant {step === 'email' && ( - + - + setEmail(e.target.value)} required /> + + setPassword(e.target.value)} + required + /> + + {loginError && {loginError}} + )} diff --git a/frontend/packages/core/src/apiHandler.ts b/frontend/packages/core/src/apiHandler.ts index 1dae7ac0..6e2e3609 100644 --- a/frontend/packages/core/src/apiHandler.ts +++ b/frontend/packages/core/src/apiHandler.ts @@ -4,6 +4,7 @@ import { GetPlaylistsForProjectParams, GetVersionsForPlaylistParams, GetUserByEmailParams, + LoginParams, Playlist, Project, User as DNAUser, @@ -103,6 +104,10 @@ class ApiHandler { async getUserByEmail(params: GetUserByEmailParams): Promise { return this.get(`/users/${encodeURIComponent(params.userEmail)}`); } + + async login(params: LoginParams): Promise<{ token: string }> { + return this.post<{ token: string }>('/auth/login', params); + } } export const createApiHandler = (config: ApiHandlerConfig): ApiHandler => { diff --git a/frontend/packages/core/src/interfaces.ts b/frontend/packages/core/src/interfaces.ts index d54726f2..aaf60d2d 100644 --- a/frontend/packages/core/src/interfaces.ts +++ b/frontend/packages/core/src/interfaces.ts @@ -140,3 +140,13 @@ export interface GetVersionsForPlaylistParams { export interface GetUserByEmailParams { userEmail: string; } + +export interface LoginParams { + username: string; + password?: string; +} + +export interface AuthResponse { + token: string; + email: string; +} From 3f5130a1803bc183ba2015bd2c92c16fdf81957f Mon Sep 17 00:00:00 2001 From: aviralgarg05 Date: Tue, 20 Jan 2026 02:14:37 +0530 Subject: [PATCH 2/4] feat: Refactor `authenticate_user` to accept only username and password, return a dictionary containing the session token and user email, and simplify the `main` login endpoint. Signed-off-by: aviralgarg05 --- .../prodtrack_provider_base.py | 4 +- .../src/dna/prodtrack_providers/shotgrid.py | 43 ++++++++++--------- backend/src/main.py | 26 +---------- 3 files changed, 26 insertions(+), 47 deletions(-) diff --git a/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py b/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py index adda161c..25682130 100644 --- a/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py +++ b/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py @@ -89,8 +89,8 @@ def get_versions_for_playlist(self, playlist_id: int) -> list["Version"]: raise NotImplementedError("Subclasses must implement this method.") @staticmethod - def authenticate_user(url: str, login: str, password: str) -> str: - """Authenticate a user and return a session token.""" + def authenticate_user(username: str, password: str) -> dict[str, Any]: + """Authenticate a user and return a session token and user info.""" raise NotImplementedError("Subclasses must implement this method.") diff --git a/backend/src/dna/prodtrack_providers/shotgrid.py b/backend/src/dna/prodtrack_providers/shotgrid.py index 6a0e772f..19e9e79f 100644 --- a/backend/src/dna/prodtrack_providers/shotgrid.py +++ b/backend/src/dna/prodtrack_providers/shotgrid.py @@ -578,37 +578,38 @@ def get_versions_for_playlist(self, playlist_id: int) -> list[Version]: @staticmethod - def authenticate_user(url: str, login: str, password: str) -> str: - """Authenticate a user with ShotGrid and return a session token. + def authenticate_user(username: str, password: str) -> dict[str, Any]: + """Authenticate a user with ShotGrid and return session token and user info. Args: - url: ShotGrid server URL - login: User login/username + username: User login/username password: User password Returns: - Session token string + Dictionary containing 'token' and 'email' Raises: ValueError: If authentication fails """ + url = os.getenv("SHOTGRID_URL") + if not url: + raise ValueError("SHOTGRID_URL not configured") + try: - # Shotgun.authenticate_human_user returns the user object, but we need the session_token. - # However, the standard way to get a token is to just create a connection which validates creds. - # But wait, Shotgun API structure specifically for auth: - # We can use the simple authentication helper or instantiate to get token. - # Actually, standard shotgun_api3 doesn't easily expose 'authenticate_human_user' to get a token string directly - # without internals. - # Let's instantiate a connection to verify and get session_token if available or standard auth flow. - # The pattern usually is: sg = Shotgun(url, login=login, password=password) then sg.get_session_token(). - - # Note: shotgun_api3 v3.3.0+ supports `login` and `password` in constructor for script-based auth, - # but for human user relying on session token: - - sg = Shotgun(url, login=login, password=password) - # This establishes connection. Now implementation detail: how to get the token? - # The 'get_session_token()' method provides it. - return sg.get_session_token() + # Initialize connection to verify credentials and get token + sg = Shotgun(url, login=username, password=password) + token = sg.get_session_token() + + # Create a provider instance with the new token to fetch user details + # This reuses the existing entity mapping logic + provider = ShotgridProvider(url=url, session_token=token) + user = provider.get_user_by_login(username) + + if not user.email: + raise ValueError("User has no email address configured") + + return {"token": token, "email": user.email} + except Exception as e: raise ValueError(f"Authentication failed: {str(e)}") diff --git a/backend/src/main.py b/backend/src/main.py index 99e4e001..8e29bb75 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -1,6 +1,6 @@ """FastAPI application entry point.""" -from functools import lru_cache + from typing import Annotated, cast from fastapi import Depends, FastAPI, HTTPException, Header @@ -173,29 +173,7 @@ def get_prodtrack_provider_dep( async def login(request: LoginRequest): """Login to ShotGrid.""" try: - # We need a provider instance to access the static method if we want to keep it clean, - # or just import the class. We imported ShotgridProvider above. - # But we need the URL from the environment or default provider. - # Let's instantiate a default provider to get config, or just use the class method - # and assume env vars are set for URL if not passed? - # The static method requires URL. - - # Helper to get base URL - import os - url = os.getenv("SHOTGRID_URL") - if not url: - raise HTTPException(status_code=500, detail="SHOTGRID_URL not configured") - - token = ShotgridProvider.authenticate_user(url, request.username, request.password) - - # Create a provider with this token to fetch the user details (email) - provider = ShotgridProvider(url=url, session_token=token) - user = provider.get_user_by_login(request.username) - - if not user.email: - raise HTTPException(status_code=400, detail="User has no email address configured") - - return {"token": token, "email": user.email} + return ShotgridProvider.authenticate_user(request.username, request.password) except ValueError as e: raise HTTPException(status_code=401, detail=str(e)) From a8dc9d4d26e3634db08960cef165ffa4a646a751 Mon Sep 17 00:00:00 2001 From: aviralgarg05 Date: Tue, 20 Jan 2026 02:17:19 +0530 Subject: [PATCH 3/4] test: Update tests to match refactored auth logic and fix broken imports Signed-off-by: aviralgarg05 --- backend/tests/test_main.py | 82 ++++++++++++++++++++++++++------------ 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py index c7fb3ed1..7ab4fbf8 100644 --- a/backend/tests/test_main.py +++ b/backend/tests/test_main.py @@ -4,7 +4,7 @@ import pytest from fastapi.testclient import TestClient -from main import app, get_prodtrack_provider_cached +from main import app, get_prodtrack_provider_dep client = TestClient(app) @@ -46,7 +46,7 @@ def test_create_note_returns_201(self, mock_provider): project={"type": "Project", "id": 85}, ) - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -75,7 +75,7 @@ def test_create_note_with_links(self, mock_provider): project={"type": "Project", "id": 85}, ) - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -101,7 +101,7 @@ def test_create_note_with_links(self, mock_provider): def test_create_note_missing_project_returns_422(self, mock_provider): """Test that missing required project field returns 422.""" - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -133,7 +133,7 @@ def test_find_returns_200_with_results(self, mock_provider): Project(id=2, name="Project Two"), ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -159,7 +159,7 @@ def test_find_with_filters(self, mock_provider): mock_provider.find.return_value = [Shot(id=100, name="shot_010")] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -188,7 +188,7 @@ def test_find_with_uppercase_entity_type(self, mock_provider): mock_provider.find.return_value = [Project(id=1, name="Test")] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -206,7 +206,7 @@ def test_find_with_uppercase_entity_type(self, mock_provider): def test_find_unsupported_entity_type_returns_400(self, mock_provider): """Test that find returns 400 for unsupported entity types.""" - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -227,7 +227,7 @@ def test_find_returns_empty_list_when_no_results(self, mock_provider): """Test that find returns empty list when no entities match.""" mock_provider.find.return_value = [] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -249,7 +249,7 @@ def test_find_provider_error_returns_400(self, mock_provider): """Test that find returns 400 when provider raises ValueError.""" mock_provider.find.side_effect = ValueError("Unknown field 'bad_field'") - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -269,7 +269,7 @@ def test_find_provider_error_returns_400(self, mock_provider): def test_find_missing_entity_type_returns_422(self, mock_provider): """Test that find returns 422 when entity_type is missing.""" - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -288,7 +288,7 @@ def test_find_with_multiple_filters(self, mock_provider): mock_provider.find.return_value = [Version(id=1, name="v001", status="apr")] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -315,7 +315,7 @@ def test_find_default_filters_is_empty_list(self, mock_provider): mock_provider.find.return_value = [Project(id=1, name="Test")] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.post( @@ -394,7 +394,7 @@ def test_get_projects_for_user_returns_200_with_results(self, mock_provider): Project(id=2, name="Project Two"), ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/projects/user/testuser") @@ -416,7 +416,7 @@ def test_get_projects_for_user_calls_provider_with_email(self, mock_provider): Project(id=1, name="Test Project") ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: client.get("/projects/user/jsmith@example.com") @@ -430,7 +430,7 @@ def test_get_projects_for_user_returns_empty_list(self, mock_provider): """Test that get_projects_for_user returns empty list when user has no projects.""" mock_provider.get_projects_for_user.return_value = [] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/projects/user/newuser") @@ -446,7 +446,7 @@ def test_get_projects_for_user_returns_404_when_user_not_found(self, mock_provid "User not found: unknownuser" ) - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/projects/user/unknownuser") @@ -474,7 +474,7 @@ def test_get_playlists_for_project_returns_200_with_results(self, mock_provider) Playlist(id=2, code="Final Review"), ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/projects/42/playlists") @@ -498,7 +498,7 @@ def test_get_playlists_for_project_calls_provider_with_project_id( Playlist(id=1, code="Test Playlist") ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: client.get("/projects/123/playlists") @@ -510,7 +510,7 @@ def test_get_playlists_for_project_returns_empty_list(self, mock_provider): """Test that get_playlists_for_project returns empty list when no playlists.""" mock_provider.get_playlists_for_project.return_value = [] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/projects/999/playlists") @@ -526,7 +526,7 @@ def test_get_playlists_for_project_returns_404_on_error(self, mock_provider): "Project not found" ) - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/projects/999/playlists") @@ -554,7 +554,7 @@ def test_get_versions_for_playlist_returns_200_with_results(self, mock_provider) Version(id=2, name="shot_020_v002", status="apr"), ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/playlists/42/versions") @@ -578,7 +578,7 @@ def test_get_versions_for_playlist_calls_provider_with_playlist_id( Version(id=1, name="v001") ] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: client.get("/playlists/123/versions") @@ -590,7 +590,7 @@ def test_get_versions_for_playlist_returns_empty_list(self, mock_provider): """Test that get_versions_for_playlist returns empty list when no versions.""" mock_provider.get_versions_for_playlist.return_value = [] - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/playlists/999/versions") @@ -606,7 +606,7 @@ def test_get_versions_for_playlist_returns_404_on_error(self, mock_provider): "Playlist not found" ) - app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider + app.dependency_overrides[get_prodtrack_provider_dep] = lambda: mock_provider try: response = client.get("/playlists/999/versions") @@ -615,3 +615,35 @@ def test_get_versions_for_playlist_returns_404_on_error(self, mock_provider): assert "Playlist not found" in data["detail"] finally: app.dependency_overrides.clear() + +class TestAuthEndpoints: + """Tests for authentication endpoints.""" + + def test_login_success(self): + """Test successful login returns token and email.""" + with mock.patch("main.ShotgridProvider.authenticate_user") as mock_auth: + mock_auth.return_value = {"token": "fake-token", "email": "test@example.com"} + + response = client.post( + "/auth/login", + json={"username": "testuser", "password": "password"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["token"] == "fake-token" + assert data["email"] == "test@example.com" + mock_auth.assert_called_once_with("testuser", "password") + + def test_login_failure(self): + """Test failed login returns 401.""" + with mock.patch("main.ShotgridProvider.authenticate_user") as mock_auth: + mock_auth.side_effect = ValueError("Invalid credentials") + + response = client.post( + "/auth/login", + json={"username": "testuser", "password": "wrong-password"}, + ) + + assert response.status_code == 401 + assert "Invalid credentials" in response.json()["detail"] From 3b04d3ed2f3dd6071baaecc4447d9567ebe24c8e Mon Sep 17 00:00:00 2001 From: aviralgarg05 Date: Tue, 20 Jan 2026 14:14:32 +0530 Subject: [PATCH 4/4] feat: Separate authentication logic into ShotgridAuthenticationProvider and decouple main.py Signed-off-by: aviralgarg05 --- .../prodtrack_provider_base.py | 17 ++++-- .../src/dna/prodtrack_providers/shotgrid.py | 36 ----------- .../dna/prodtrack_providers/shotgrid_auth.py | 59 +++++++++++++++++++ backend/src/main.py | 12 ++-- backend/tests/test_main.py | 18 +++--- 5 files changed, 89 insertions(+), 53 deletions(-) create mode 100644 backend/src/dna/prodtrack_providers/shotgrid_auth.py diff --git a/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py b/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py index 25682130..83720a81 100644 --- a/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py +++ b/backend/src/dna/prodtrack_providers/prodtrack_provider_base.py @@ -88,10 +88,7 @@ def get_versions_for_playlist(self, playlist_id: int) -> list["Version"]: """ raise NotImplementedError("Subclasses must implement this method.") - @staticmethod - def authenticate_user(username: str, password: str) -> dict[str, Any]: - """Authenticate a user and return a session token and user info.""" - raise NotImplementedError("Subclasses must implement this method.") + # Removed authenticate_user static method as it is now handled by the module-level factory function below. def get_prodtrack_provider(session_token: str | None = None) -> ProdtrackProviderBase: @@ -102,3 +99,15 @@ def get_prodtrack_provider(session_token: str | None = None) -> ProdtrackProvide if provider_type == "shotgrid": return ShotgridProvider(session_token=session_token) raise ValueError(f"Unknown production tracking provider: {provider_type}") + + +def authenticate_user(username: str, password: str) -> dict[str, Any]: + """Authenticate a user using the configured provider.""" + provider_type = os.getenv("PRODTRACK_PROVIDER", "shotgrid") + + if provider_type == "shotgrid": + from dna.prodtrack_providers.shotgrid_auth import ShotgridAuthenticationProvider + + return ShotgridAuthenticationProvider.authenticate(username, password) + + raise ValueError(f"Unknown production tracking provider: {provider_type}") diff --git a/backend/src/dna/prodtrack_providers/shotgrid.py b/backend/src/dna/prodtrack_providers/shotgrid.py index 19e9e79f..e6c50f04 100644 --- a/backend/src/dna/prodtrack_providers/shotgrid.py +++ b/backend/src/dna/prodtrack_providers/shotgrid.py @@ -577,42 +577,6 @@ def get_versions_for_playlist(self, playlist_id: int) -> list[Version]: return versions - @staticmethod - def authenticate_user(username: str, password: str) -> dict[str, Any]: - """Authenticate a user with ShotGrid and return session token and user info. - - Args: - username: User login/username - password: User password - - Returns: - Dictionary containing 'token' and 'email' - - Raises: - ValueError: If authentication fails - """ - url = os.getenv("SHOTGRID_URL") - if not url: - raise ValueError("SHOTGRID_URL not configured") - - try: - # Initialize connection to verify credentials and get token - sg = Shotgun(url, login=username, password=password) - token = sg.get_session_token() - - # Create a provider instance with the new token to fetch user details - # This reuses the existing entity mapping logic - provider = ShotgridProvider(url=url, session_token=token) - user = provider.get_user_by_login(username) - - if not user.email: - raise ValueError("User has no email address configured") - - return {"token": token, "email": user.email} - - except Exception as e: - raise ValueError(f"Authentication failed: {str(e)}") - def _get_dna_entity_type(sg_entity_type: str) -> str: """Get the DNA entity type from the ShotGrid entity type.""" for entity_type, entity_data in FIELD_MAPPING.items(): diff --git a/backend/src/dna/prodtrack_providers/shotgrid_auth.py b/backend/src/dna/prodtrack_providers/shotgrid_auth.py new file mode 100644 index 00000000..93569441 --- /dev/null +++ b/backend/src/dna/prodtrack_providers/shotgrid_auth.py @@ -0,0 +1,59 @@ +"""ShotGrid authentication provider implementation.""" + +import os +from typing import Any, Optional + +from shotgun_api3 import Shotgun + + +class ShotgridAuthenticationProvider: + """Provider for ShotGrid authentication.""" + + @staticmethod + def authenticate( + username: str, password: str, provider_url: Optional[str] = None + ) -> dict[str, Any]: + """Authenticate a user with ShotGrid and return session token and user info. + + Args: + username: User login/username + password: User password + provider_url: Optional ShotGrid URL. If not provided, uses SHOTGRID_URL env var. + + Returns: + Dictionary containing 'token' and 'email' + + Raises: + ValueError: If authentication fails + """ + url = provider_url or os.getenv("SHOTGRID_URL") + if not url: + raise ValueError("SHOTGRID_URL not configured") + + try: + # Initialize connection to verify credentials and get token + sg = Shotgun(url, login=username, password=password) + token = sg.get_session_token() + + # We need to fetch the user details (email) + # We can use the same connection object 'sg' if it supports find/find_one after auth + # OR create a new connection with the token. + # Using the existing 'sg' instance is more efficient as it's already authenticated/connected (presumably). + # However, shotgun_api3 behaviour: 'sg' initialized with login/pass IS valid. + + # Implementation note: We need to find the user by login to get the email. + user_data = sg.find_one( + "HumanUser", filters=[["login", "is", username]], fields=["email"] + ) + + if not user_data: + raise ValueError(f"User not found: {username}") + + email = user_data.get("email") + if not email: + raise ValueError("User has no email address configured") + + return {"token": token, "email": email} + + except Exception as e: + raise ValueError(f"Authentication failed: {str(e)}") diff --git a/backend/src/main.py b/backend/src/main.py index 8e29bb75..eb2ddfb9 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -1,9 +1,8 @@ """FastAPI application entry point.""" - from typing import Annotated, cast -from fastapi import Depends, FastAPI, HTTPException, Header +from fastapi import Depends, FastAPI, Header, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel @@ -22,9 +21,9 @@ from dna.models.entity import ENTITY_MODELS, EntityBase from dna.prodtrack_providers.prodtrack_provider_base import ( ProdtrackProviderBase, + authenticate_user, get_prodtrack_provider, ) -from dna.prodtrack_providers.shotgrid import ShotgridProvider class LoginRequest(BaseModel): @@ -138,7 +137,9 @@ class LoginRequest(BaseModel): # ----------------------------------------------------------------------------- -def get_token_header(authorization: Annotated[str | None, Header()] = None) -> str | None: +def get_token_header( + authorization: Annotated[str | None, Header()] = None, +) -> str | None: """Extract token from Authorization header.""" if not authorization: return None @@ -173,12 +174,11 @@ def get_prodtrack_provider_dep( async def login(request: LoginRequest): """Login to ShotGrid.""" try: - return ShotgridProvider.authenticate_user(request.username, request.password) + return authenticate_user(request.username, request.password) except ValueError as e: raise HTTPException(status_code=401, detail=str(e)) - # ----------------------------------------------------------------------------- # Health endpoints # ----------------------------------------------------------------------------- diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py index 7ab4fbf8..8ee9aeb3 100644 --- a/backend/tests/test_main.py +++ b/backend/tests/test_main.py @@ -616,19 +616,23 @@ def test_get_versions_for_playlist_returns_404_on_error(self, mock_provider): finally: app.dependency_overrides.clear() + class TestAuthEndpoints: """Tests for authentication endpoints.""" def test_login_success(self): """Test successful login returns token and email.""" - with mock.patch("main.ShotgridProvider.authenticate_user") as mock_auth: - mock_auth.return_value = {"token": "fake-token", "email": "test@example.com"} - + with mock.patch("main.authenticate_user") as mock_auth: + mock_auth.return_value = { + "token": "fake-token", + "email": "test@example.com", + } + response = client.post( "/auth/login", json={"username": "testuser", "password": "password"}, ) - + assert response.status_code == 200 data = response.json() assert data["token"] == "fake-token" @@ -637,13 +641,13 @@ def test_login_success(self): def test_login_failure(self): """Test failed login returns 401.""" - with mock.patch("main.ShotgridProvider.authenticate_user") as mock_auth: + with mock.patch("main.authenticate_user") as mock_auth: mock_auth.side_effect = ValueError("Invalid credentials") - + response = client.post( "/auth/login", json={"username": "testuser", "password": "wrong-password"}, ) - + assert response.status_code == 401 assert "Invalid credentials" in response.json()["detail"]