Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion server/app/component/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ async def auth(


async def auth_must(
token: str = Depends(oauth2_scheme),
token: str | None = Depends(oauth2_scheme),
session: Session = Depends(session),
) -> Auth:
if token is None:
raise TokenException(code.token_invalid, _("Authentication required"))
model = Auth.decode_token(token)
user = session.get(User, model.id)
model._user = user
Expand Down
15 changes: 12 additions & 3 deletions server/app/controller/chat/share_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sqlmodel import Session, asc, select
from starlette.responses import StreamingResponse

from app.component.auth import Auth, auth_must
from app.component.database import session
from app.model.chat.chat_history import ChatHistory
from app.model.chat.chat_share import (
Expand Down Expand Up @@ -116,12 +117,20 @@ async def event_generator():


@router.post("/share", name="Generate sharable link for a task(1 day expiration)")
def create_share_link(data: ChatShareIn):
def create_share_link(data: ChatShareIn, auth: Auth = Depends(auth_must)):
"""Generate sharing token with 1-day expiration for task."""
user_id = auth.user.id
try:
share_token = ChatShare.generate_token(data.task_id)
logger.info("Share link created", extra={"task_id": data.task_id, "token_prefix": share_token[:10]})
logger.info(
"Share link created",
extra={"user_id": user_id, "task_id": data.task_id, "token_prefix": share_token[:10]},
)
return {"share_token": share_token}
except Exception as e:
logger.error("Share link creation failed", extra={"task_id": data.task_id, "error": str(e)}, exc_info=True)
logger.error(
"Share link creation failed",
extra={"user_id": user_id, "task_id": data.task_id, "error": str(e)},
exc_info=True,
)
raise HTTPException(status_code=500, detail="Internal server error")
14 changes: 12 additions & 2 deletions server/app/controller/chat/snapshot_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ async def list_chat_snapshots(
camel_task_id: str | None = None,
browser_url: str | None = None,
session: Session = Depends(session),
auth: Auth = Depends(auth_must),
):
"""List chat snapshots with optional filtering."""
query = select(ChatSnapshot)
user_id = auth.user.id
query = select(ChatSnapshot).where(ChatSnapshot.user_id == user_id)
if api_task_id is not None:
query = query.where(ChatSnapshot.api_task_id == api_task_id)
if camel_task_id is not None:
Expand All @@ -45,7 +47,8 @@ async def list_chat_snapshots(

snapshots = session.exec(query).all()
logger.debug(
"Snapshots listed", extra={"api_task_id": api_task_id, "camel_task_id": camel_task_id, "count": len(snapshots)}
"Snapshots listed",
extra={"user_id": user_id, "api_task_id": api_task_id, "camel_task_id": camel_task_id, "count": len(snapshots)},
)
return snapshots

Expand All @@ -60,6 +63,13 @@ async def get_chat_snapshot(snapshot_id: int, session: Session = Depends(session
logger.warning("Snapshot not found", extra={"user_id": user_id, "snapshot_id": snapshot_id})
raise HTTPException(status_code=404, detail=_("Chat snapshot not found"))

if snapshot.user_id != user_id:
logger.warning(
"Unauthorized snapshot access",
extra={"user_id": user_id, "snapshot_id": snapshot_id, "owner_id": snapshot.user_id},
)
raise HTTPException(status_code=403, detail=_("You are not allowed to view this snapshot"))

logger.debug(
"Snapshot retrieved",
extra={"user_id": user_id, "snapshot_id": snapshot_id, "api_task_id": snapshot.api_task_id},
Expand Down
145 changes: 145 additions & 0 deletions server/tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========

import inspect

import pytest

from app.controller.chat.share_controller import (
create_share_link,
get_share_info,
share_playback,
)


class TestAuthMustNoneTokenHandling:
"""Tests for auth_must handling of None tokens.

When oauth2_scheme is configured with auto_error=False, it returns
None instead of raising 401 when no token is provided. auth_must
must explicitly handle this case instead of passing None to
jwt.decode() which produces an opaque DecodeError.
"""

def test_auth_must_has_none_type_annotation(self):
"""auth_must should accept Optional[str] since oauth2_scheme
may return None with auto_error=False."""
from app.component.auth import auth_must

sig = inspect.signature(auth_must)
token_param = sig.parameters["token"]
annotation = str(token_param.annotation)
# Should accept None (str | None or Optional[str])
assert "None" in annotation or "Optional" in annotation

def test_auth_must_raises_on_none_token(self):
"""auth_must should raise TokenException immediately when
token is None, not pass it to jwt.decode()."""
import asyncio
from unittest.mock import MagicMock, patch

from app.component.auth import auth_must
from app.exception.exception import TokenException

mock_session = MagicMock()

with pytest.raises(TokenException):
asyncio.run(auth_must(token=None, session=mock_session))

def test_auth_must_does_not_call_decode_on_none(self):
"""Verify jwt.decode is never called with None token."""
import asyncio
from unittest.mock import MagicMock, patch

from app.component.auth import auth_must

mock_session = MagicMock()

with patch("app.component.auth.Auth.decode_token") as mock_decode:
try:
asyncio.run(auth_must(token=None, session=mock_session))
except Exception:
pass
mock_decode.assert_not_called()


class TestSnapshotEndpointAuthRequirements:
"""Tests verifying that all snapshot CRUD endpoints require authentication.

The list endpoint was previously missing the auth dependency, allowing
unauthenticated users to enumerate all snapshots across all users.
"""

def test_list_snapshots_requires_auth_dependency(self):
"""GET /snapshots must include auth_must as a dependency."""
from app.controller.chat.snapshot_controller import list_chat_snapshots

sig = inspect.signature(list_chat_snapshots)
param_names = list(sig.parameters.keys())
assert "auth" in param_names, (
"list_chat_snapshots is missing the 'auth' parameter — "
"unauthenticated users can list all snapshots"
)

def test_get_snapshot_requires_auth_dependency(self):
"""GET /snapshots/{id} must include auth_must as a dependency."""
from app.controller.chat.snapshot_controller import get_chat_snapshot

sig = inspect.signature(get_chat_snapshot)
param_names = list(sig.parameters.keys())
assert "auth" in param_names

def test_create_snapshot_requires_auth_dependency(self):
"""POST /snapshots must include auth_must as a dependency."""
from app.controller.chat.snapshot_controller import create_chat_snapshot

sig = inspect.signature(create_chat_snapshot)
param_names = list(sig.parameters.keys())
assert "auth" in param_names

def test_update_snapshot_requires_auth_dependency(self):
"""PUT /snapshots/{id} must include auth_must as a dependency."""
from app.controller.chat.snapshot_controller import update_chat_snapshot

sig = inspect.signature(update_chat_snapshot)
param_names = list(sig.parameters.keys())
assert "auth" in param_names

def test_delete_snapshot_requires_auth_dependency(self):
"""DELETE /snapshots/{id} must include auth_must as a dependency."""
from app.controller.chat.snapshot_controller import delete_chat_snapshot

sig = inspect.signature(delete_chat_snapshot)
param_names = list(sig.parameters.keys())
assert "auth" in param_names


def test_create_share_link_requires_auth_dependency():
"""POST /share must include auth_must as a dependency."""
sig = inspect.signature(create_share_link)
param_names = list(sig.parameters.keys())
assert "auth" in param_names, (
"create_share_link is missing the 'auth' parameter — "
"unauthenticated users can generate share tokens"
)


def test_share_read_endpoints_remain_public():
"""GET /share/info and /share/playback should remain public
since they verify the share token itself."""
# These endpoints use the share token for auth, not user auth
info_params = list(inspect.signature(get_share_info).parameters.keys())
playback_params = list(inspect.signature(share_playback).parameters.keys())
assert "token" in info_params
assert "token" in playback_params