diff --git a/backend/api/views/lockbox.py b/backend/api/views/lockbox.py index 6f68d30bb..38a89e451 100644 --- a/backend/api/views/lockbox.py +++ b/backend/api/views/lockbox.py @@ -5,11 +5,12 @@ Lockbox, ) +from django.db import transaction from django.http import HttpResponse from django.views.decorators.csrf import csrf_exempt from django.utils import timezone -from django.db.models import Q +from django.db.models import Q, F from rest_framework.permissions import AllowAny from rest_framework.response import Response @@ -39,25 +40,18 @@ def dispatch(self, request, *args, **kwargs): def get(self, request, box_id): try: - box = Lockbox.objects.get( - Q(id=box_id) - & (Q(expires_at__gte=timezone.now()) | Q(expires_at__isnull=True)) - ) - if box.allowed_views is None or box.views < box.allowed_views: - serializer = LockboxSerializer(box) - return Response(serializer.data, status=status.HTTP_200_OK) - else: - return HttpResponse(status=status.HTTP_403_FORBIDDEN) - - except Lockbox.DoesNotExist: - return HttpResponse(status=status.HTTP_404_NOT_FOUND) - - def put(self, request, box_id): - try: - box = Lockbox.objects.get(id=box_id) - box.views += 1 - box.save() - return HttpResponse(status=status.HTTP_200_OK) + with transaction.atomic(): + box = Lockbox.objects.select_for_update().get( + Q(id=box_id) + & (Q(expires_at__gte=timezone.now()) | Q(expires_at__isnull=True)) + ) + if box.allowed_views is None or box.views < box.allowed_views: + serializer = LockboxSerializer(box) + # Atomically increment view count on read + Lockbox.objects.filter(id=box_id).update(views=F("views") + 1) + return Response(serializer.data, status=status.HTTP_200_OK) + else: + return HttpResponse(status=status.HTTP_403_FORBIDDEN) except Lockbox.DoesNotExist: return HttpResponse(status=status.HTTP_404_NOT_FOUND) diff --git a/backend/tests/api/views/__init__.py b/backend/tests/api/views/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/tests/api/views/test_lockbox.py b/backend/tests/api/views/test_lockbox.py new file mode 100644 index 000000000..e59d1b0f1 --- /dev/null +++ b/backend/tests/api/views/test_lockbox.py @@ -0,0 +1,128 @@ +import pytest +from unittest.mock import MagicMock, patch +from django.test import RequestFactory +from api.views.lockbox import LockboxView +from api.models import Lockbox as RealLockbox + + +class TestLockboxViewCounting: + """Tests that lockbox view counting is enforced server-side.""" + + @patch("api.views.lockbox.transaction") + @patch("api.views.lockbox.Lockbox") + @patch("api.views.lockbox.LockboxSerializer") + def test_get_increments_view_count_atomically( + self, MockSerializer, MockLockbox, mock_transaction + ): + """GET should atomically increment the view count.""" + MockLockbox.DoesNotExist = RealLockbox.DoesNotExist + mock_transaction.atomic.return_value.__enter__ = MagicMock() + mock_transaction.atomic.return_value.__exit__ = MagicMock(return_value=False) + + mock_box = MagicMock() + mock_box.id = "box-123" + mock_box.allowed_views = 3 + mock_box.views = 0 + + mock_qs = MagicMock() + mock_qs.get.return_value = mock_box + MockLockbox.objects.select_for_update.return_value = mock_qs + + mock_filter_qs = MagicMock() + MockLockbox.objects.filter.return_value = mock_filter_qs + + MockSerializer.return_value = MagicMock(data={"data": "encrypted"}) + + factory = RequestFactory() + request = factory.get("/lockbox/box-123") + + view = LockboxView() + response = view.get(request, "box-123") + + assert response.status_code == 200 + MockLockbox.objects.filter.assert_called_once_with(id="box-123") + mock_filter_qs.update.assert_called_once() + + @patch("api.views.lockbox.transaction") + @patch("api.views.lockbox.Lockbox") + def test_get_rejects_when_view_limit_reached(self, MockLockbox, mock_transaction): + """GET should return 403 when allowed_views is exhausted.""" + MockLockbox.DoesNotExist = RealLockbox.DoesNotExist + mock_transaction.atomic.return_value.__enter__ = MagicMock() + mock_transaction.atomic.return_value.__exit__ = MagicMock(return_value=False) + + mock_box = MagicMock() + mock_box.id = "box-123" + mock_box.allowed_views = 1 + mock_box.views = 1 + + mock_qs = MagicMock() + mock_qs.get.return_value = mock_box + MockLockbox.objects.select_for_update.return_value = mock_qs + + factory = RequestFactory() + request = factory.get("/lockbox/box-123") + + view = LockboxView() + response = view.get(request, "box-123") + + assert response.status_code == 403 + + @patch("api.views.lockbox.transaction") + @patch("api.views.lockbox.Lockbox") + @patch("api.views.lockbox.LockboxSerializer") + def test_get_allows_unlimited_views_when_allowed_views_is_none( + self, MockSerializer, MockLockbox, mock_transaction + ): + """GET should allow reads when allowed_views is None (unlimited).""" + MockLockbox.DoesNotExist = RealLockbox.DoesNotExist + mock_transaction.atomic.return_value.__enter__ = MagicMock() + mock_transaction.atomic.return_value.__exit__ = MagicMock(return_value=False) + + mock_box = MagicMock() + mock_box.id = "box-123" + mock_box.allowed_views = None + mock_box.views = 100 + + mock_qs = MagicMock() + mock_qs.get.return_value = mock_box + MockLockbox.objects.select_for_update.return_value = mock_qs + + mock_filter_qs = MagicMock() + MockLockbox.objects.filter.return_value = mock_filter_qs + + MockSerializer.return_value = MagicMock(data={"data": "encrypted"}) + + factory = RequestFactory() + request = factory.get("/lockbox/box-123") + + view = LockboxView() + response = view.get(request, "box-123") + + assert response.status_code == 200 + + @patch("api.views.lockbox.transaction") + @patch("api.views.lockbox.Lockbox") + def test_get_returns_404_for_nonexistent_box(self, MockLockbox, mock_transaction): + """GET should return 404 for missing lockboxes.""" + MockLockbox.DoesNotExist = RealLockbox.DoesNotExist + mock_transaction.atomic.return_value.__enter__ = MagicMock() + mock_transaction.atomic.return_value.__exit__ = MagicMock(return_value=False) + + mock_qs = MagicMock() + mock_qs.get.side_effect = RealLockbox.DoesNotExist + MockLockbox.objects.select_for_update.return_value = mock_qs + + factory = RequestFactory() + request = factory.get("/lockbox/missing-id") + + view = LockboxView() + response = view.get(request, "missing-id") + + assert response.status_code == 404 + + def test_no_put_method(self): + """LockboxView should not have a PUT method.""" + assert not hasattr(LockboxView, "put") or not callable( + getattr(LockboxView, "put", None) + ) diff --git a/frontend/components/lockbox/LockboxViewer.tsx b/frontend/components/lockbox/LockboxViewer.tsx index e9f28eaa5..187ebadb1 100644 --- a/frontend/components/lockbox/LockboxViewer.tsx +++ b/frontend/components/lockbox/LockboxViewer.tsx @@ -1,7 +1,7 @@ 'use client' import { LockboxType } from '@/apollo/graphql' -import { boxExpiryString, updateBoxViewCount } from '@/utils/lockbox' +import { boxExpiryString } from '@/utils/lockbox' import { useEffect, useState } from 'react' import { Button } from '../common/Button' import CopyButton from '../common/CopyButton' @@ -36,8 +36,6 @@ export const LockboxViewer = (props: { box: LockboxType }) => { } catch (err) { toast.error('Something wrong opening this box. Please check the link and try again!') } - - updateBoxViewCount(box.id) } return ( diff --git a/frontend/utils/lockbox.ts b/frontend/utils/lockbox.ts index 7d7714667..ebd723dda 100644 --- a/frontend/utils/lockbox.ts +++ b/frontend/utils/lockbox.ts @@ -12,16 +12,6 @@ export const getBox = async (boxId: string) => { return res.json() } -export const updateBoxViewCount = async (boxId: string) => { - const res = await fetch(`${process.env.NEXT_PUBLIC_BACKEND_API_BASE}/lockbox/${boxId}`, { - method: 'PUT', - headers: { - 'Content-Type': 'application/json', - }, - credentials: 'omit', - }) -} - export const boxExpiryString = (expiresAt?: number, allowedViews?: Maybe) => { if (!expiresAt && !allowedViews) { return 'This box will never expire'