Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backend/src/dna/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
FindRequest,
GenerateNoteRequest,
GenerateNoteResponse,
LoginRequest,
LoginResponse,
PublishNotesRequest,
PublishNotesResponse,
SearchRequest,
Expand Down Expand Up @@ -78,6 +80,8 @@
"GenerateNoteResponse",
"SearchRequest",
"SearchResult",
"LoginRequest",
"LoginResponse",
"PublishNotesRequest",
"PublishNotesResponse",
"DraftNote",
Expand Down
17 changes: 17 additions & 0 deletions backend/src/dna/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,20 @@ class PublishNotesResponse(BaseModel):
skipped_count: int
failed_count: int
total: int


class LoginRequest(BaseModel):
"""Request model for user authentication."""

username: str = Field(description="ShotGrid login name")
password: str = Field(description="ShotGrid password")


class LoginResponse(BaseModel):
"""Response model for successful authentication."""

user_id: int = Field(description="ShotGrid user ID")
login: str = Field(description="ShotGrid login name")
name: str = Field(description="User display name")
email: str = Field(description="User email address")
session_token: str = Field(description="Session token for authenticated requests")
13 changes: 13 additions & 0 deletions backend/src/dna/prodtrack_providers/prodtrack_provider_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,19 @@ def get_versions_for_playlist(self, playlist_id: int) -> list["Version"]:
"""
raise NotImplementedError("Subclasses must implement this method.")

def authenticate_user(self, login: str, password: str) -> dict | None:
"""Authenticate a user with login credentials.

Args:
login: The user's login name.
password: The user's password.

Returns:
A dict with user_id, login, name, email, and session_token
if authentication succeeded. None if authentication failed.
"""
raise NotImplementedError("Subclasses must implement this method.")

def publish_note(
self,
version_id: int,
Expand Down
38 changes: 38 additions & 0 deletions backend/src/dna/prodtrack_providers/shotgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,44 @@ def update_note(
print(f"Error updating note {note_id}: {e}")
return False

def authenticate_user(self, login: str, password: str) -> dict | None:
"""Authenticate a user with ShotGrid credentials.

Uses ShotGrid's authenticate_human_user to validate the credentials.
If successful, retrieves a session token.

Args:
login: The user's ShotGrid login name.
password: The user's password.

Returns:
A dict with user_id, login, name, email, and session_token
if authentication succeeded. None if authentication failed.
"""
if not self._sg:
raise ValueError("Not connected to ShotGrid")

user = self._sg.authenticate_human_user(login, password)
if not user:
return None

session_token = self._sg.get_session_token()

# Fetch full user details including email
sg_user = self._sg.find_one(
"HumanUser",
filters=[["id", "is", user["id"]]],
fields=["id", "name", "email", "login"],
)

return {
"user_id": user["id"],
"login": sg_user.get("login", login) if sg_user else login,
"name": user.get("name", ""),
"email": sg_user.get("email", "") if sg_user else "",
"session_token": session_token,
}

def publish_note(
self,
version_id: int,
Expand Down
37 changes: 37 additions & 0 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
FindRequest,
GenerateNoteRequest,
GenerateNoteResponse,
LoginRequest,
LoginResponse,
Note,
Platform,
Playlist,
Expand Down Expand Up @@ -138,6 +140,10 @@
"name": "User Settings",
"description": "Operations for managing user settings and preferences",
},
{
"name": "Authentication",
"description": "User authentication endpoints",
},
]

app = FastAPI(
Expand Down Expand Up @@ -512,6 +518,37 @@ async def search_entities(
raise HTTPException(status_code=400, detail=str(e))


# -----------------------------------------------------------------------------
# Authentication endpoints
# -----------------------------------------------------------------------------


@app.post(
"/auth/login",
tags=["Authentication"],
summary="Authenticate a user",
description="Authenticate a user with ShotGrid credentials and return a session token.",
response_model=LoginResponse,
)
async def login(request: LoginRequest, provider: ProdtrackProviderDep) -> LoginResponse:
"""Authenticate a user with username and password."""
try:
result = provider.authenticate_user(request.username, request.password)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e))

if result is None:
raise HTTPException(status_code=401, detail="Invalid username or password")

return LoginResponse(
user_id=result["user_id"],
login=result["login"],
name=result["name"],
email=result["email"],
session_token=result["session_token"],
)


# -----------------------------------------------------------------------------
# User endpoints
# -----------------------------------------------------------------------------
Expand Down
90 changes: 90 additions & 0 deletions backend/tests/providers/test_shotgrid_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,3 +1431,93 @@ def test_create_shallow_entity_for_non_playlist(self, shotgrid_provider):

assert result.id == 100
assert result.name == "shot_010"


# ============================================================================
# Authentication tests
# ============================================================================


class TestAuthenticateUser:
"""Tests for the authenticate_user method."""

@pytest.fixture
def shotgrid_provider(self):
sg_provider = ShotgridProvider(connect=False)
mock_sg = mock.MagicMock()
sg_provider.sg = mock_sg
return sg_provider

def test_authenticate_user_success(self, shotgrid_provider):
"""Test authenticate_user returns user info and session token."""
shotgrid_provider.sg.authenticate_human_user.return_value = {
"type": "HumanUser",
"id": 42,
"name": "John Smith",
}
shotgrid_provider.sg.get_session_token.return_value = "session_abc123"
shotgrid_provider.sg.find_one.return_value = {
"id": 42,
"name": "John Smith",
"email": "jsmith@example.com",
"login": "jsmith",
}

result = shotgrid_provider.authenticate_user("jsmith", "secret")

assert result is not None
assert result["user_id"] == 42
assert result["login"] == "jsmith"
assert result["name"] == "John Smith"
assert result["email"] == "jsmith@example.com"
assert result["session_token"] == "session_abc123"

shotgrid_provider.sg.authenticate_human_user.assert_called_once_with(
"jsmith", "secret"
)
shotgrid_provider.sg.get_session_token.assert_called_once()

def test_authenticate_user_invalid_credentials(self, shotgrid_provider):
"""Test authenticate_user returns None for invalid credentials."""
shotgrid_provider.sg.authenticate_human_user.return_value = None

result = shotgrid_provider.authenticate_user("baduser", "badpass")

assert result is None
shotgrid_provider.sg.get_session_token.assert_not_called()

def test_authenticate_user_not_connected(self):
"""Test authenticate_user raises error when not connected."""
provider = ShotgridProvider(
url="https://test.shotgunstudio.com",
script_name="test_script",
api_key="test_key",
connect=False,
)
with pytest.raises(ValueError, match="Not connected to ShotGrid"):
provider.authenticate_user("jsmith", "secret")

def test_authenticate_user_base_raises_not_implemented(self):
"""Test that ProdtrackProviderBase.authenticate_user raises NotImplementedError."""
provider = ProdtrackProviderBase()
with pytest.raises(NotImplementedError, match="Subclasses must implement"):
provider.authenticate_user("user", "pass")

def test_authenticate_user_without_full_user_details(self, shotgrid_provider):
"""Test authenticate_user handles missing user details gracefully."""
shotgrid_provider.sg.authenticate_human_user.return_value = {
"type": "HumanUser",
"id": 99,
"name": "Minimal User",
}
shotgrid_provider.sg.get_session_token.return_value = "token_xyz"
shotgrid_provider.sg.find_one.return_value = None

result = shotgrid_provider.authenticate_user("minuser", "pass123")

assert result is not None
assert result["user_id"] == 99
assert result["login"] == "minuser"
assert result["name"] == "Minimal User"
assert result["email"] == ""
assert result["session_token"] == "token_xyz"
138 changes: 138 additions & 0 deletions backend/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,3 +1114,141 @@ def test_generate_note_returns_400_on_error(
assert "DB Error" in data["detail"]
finally:
app.dependency_overrides.clear()


class TestAuthLoginEndpoint:
"""Tests for POST /auth/login endpoint."""

@pytest.fixture
def mock_provider(self):
"""Create a mock ShotGrid provider."""
return mock.MagicMock()

def test_login_returns_200_with_valid_credentials(self, mock_provider):
"""Test that login returns 200 with valid credentials."""
mock_provider.authenticate_user.return_value = {
"user_id": 42,
"login": "jsmith",
"name": "John Smith",
"email": "jsmith@example.com",
"session_token": "abc123token",
}

app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider

try:
response = client.post(
"/auth/login",
json={
"username": "jsmith",
"password": "secret",
},
)
assert response.status_code == 200
data = response.json()
assert data["user_id"] == 42
assert data["login"] == "jsmith"
assert data["name"] == "John Smith"
assert data["email"] == "jsmith@example.com"
assert data["session_token"] == "abc123token"
finally:
app.dependency_overrides.clear()

def test_login_calls_provider_with_correct_args(self, mock_provider):
"""Test that login passes correct arguments to provider."""
mock_provider.authenticate_user.return_value = {
"user_id": 1,
"login": "testuser",
"name": "Test User",
"email": "test@example.com",
"session_token": "token123",
}

app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider

try:
client.post(
"/auth/login",
json={
"username": "testuser",
"password": "testpass",
},
)

mock_provider.authenticate_user.assert_called_once_with(
"testuser", "testpass"
)
finally:
app.dependency_overrides.clear()

def test_login_returns_401_with_invalid_credentials(self, mock_provider):
"""Test that login returns 401 with invalid credentials."""
mock_provider.authenticate_user.return_value = None

app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider

try:
response = client.post(
"/auth/login",
json={
"username": "baduser",
"password": "badpass",
},
)
assert response.status_code == 401
data = response.json()
assert "Invalid username or password" in data["detail"]
finally:
app.dependency_overrides.clear()

def test_login_returns_422_with_missing_username(self, mock_provider):
"""Test that login returns 422 when username is missing."""
app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider

try:
response = client.post(
"/auth/login",
json={
"password": "secret",
},
)
assert response.status_code == 422
finally:
app.dependency_overrides.clear()

def test_login_returns_422_with_missing_password(self, mock_provider):
"""Test that login returns 422 when password is missing."""
app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider

try:
response = client.post(
"/auth/login",
json={
"username": "jsmith",
},
)
assert response.status_code == 422
finally:
app.dependency_overrides.clear()

def test_login_returns_500_when_provider_raises_error(self, mock_provider):
"""Test that login returns 500 when provider raises ValueError."""
mock_provider.authenticate_user.side_effect = ValueError(
"Not connected to ShotGrid"
)

app.dependency_overrides[get_prodtrack_provider_cached] = lambda: mock_provider

try:
response = client.post(
"/auth/login",
json={
"username": "jsmith",
"password": "secret",
},
)
assert response.status_code == 500
data = response.json()
assert "Not connected to ShotGrid" in data["detail"]
finally:
app.dependency_overrides.clear()
Loading