diff --git a/vulnerabilities/admin.py b/vulnerabilities/admin.py index eecef0276..176e3c0c0 100644 --- a/vulnerabilities/admin.py +++ b/vulnerabilities/admin.py @@ -9,6 +9,10 @@ from django import forms from django.contrib import admin +from django.contrib.admin.widgets import FilteredSelectMultiple +from django.contrib.auth.admin import GroupAdmin as BasicGroupAdmin +from django.contrib.auth.models import Group +from django.contrib.auth.models import User from django.core.validators import validate_email from vulnerabilities.models import ApiUser @@ -97,3 +101,49 @@ def get_form(self, request, obj=None, **kwargs): defaults["form"] = self.add_form defaults.update(kwargs) return super().get_form(request, obj, **defaults) + + +class GroupWithUsersForm(forms.ModelForm): + users = forms.ModelMultipleChoiceField( + queryset=User.objects.all(), + required=False, + widget=FilteredSelectMultiple("Users", is_stacked=False), + label="Users", + ) + + class Meta: + model = Group + fields = "__all__" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fields["users"].label_from_instance = lambda user: ( + f"{user.username} | {user.email}" if user.email else user.username + ) + if self.instance.pk: + self.fields["users"].initial = self.instance.user_set.all() + + def save(self, commit=True): + group = super().save(commit=commit) + self.save_m2m() + group.user_set.set(self.cleaned_data["users"]) + return group + + +admin.site.unregister(Group) + + +@admin.register(Group) +class GroupAdmin(admin.ModelAdmin): + form = GroupWithUsersForm + search_fields = ("name",) + ordering = ("name",) + filter_horizontal = ("permissions",) + + def formfield_for_manytomany(self, db_field, request=None, **kwargs): + if db_field.name == "permissions": + qs = kwargs.get("queryset", db_field.remote_field.model.objects) + # Avoid a major performance hit resolving permission names which + # triggers a content_type load: + kwargs["queryset"] = qs.select_related("content_type") + return super().formfield_for_manytomany(db_field, request=request, **kwargs) diff --git a/vulnerabilities/api.py b/vulnerabilities/api.py index 1fd480ce9..d994b297d 100644 --- a/vulnerabilities/api.py +++ b/vulnerabilities/api.py @@ -22,7 +22,6 @@ from rest_framework import viewsets from rest_framework.decorators import action from rest_framework.response import Response -from rest_framework.throttling import AnonRateThrottle from vulnerabilities.models import Alias from vulnerabilities.models import Exploit @@ -34,7 +33,7 @@ from vulnerabilities.models import get_purl_query_lookups from vulnerabilities.severity_systems import EPSS from vulnerabilities.severity_systems import SCORING_SYSTEMS -from vulnerabilities.throttling import StaffUserRateThrottle +from vulnerabilities.throttling import PermissionBasedUserRateThrottle from vulnerabilities.utils import get_severity_range @@ -471,7 +470,7 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = PackageSerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = PackageFilterSet - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle] def get_queryset(self): return super().get_queryset().with_is_vulnerable() @@ -688,7 +687,7 @@ def get_queryset(self): serializer_class = VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = VulnerabilityFilterSet - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle] class CPEFilterSet(filters.FilterSet): diff --git a/vulnerabilities/api_extension.py b/vulnerabilities/api_extension.py index 7a13baf42..01d98ca99 100644 --- a/vulnerabilities/api_extension.py +++ b/vulnerabilities/api_extension.py @@ -23,7 +23,6 @@ from rest_framework.serializers import ModelSerializer from rest_framework.serializers import Serializer from rest_framework.serializers import ValidationError -from rest_framework.throttling import AnonRateThrottle from vulnerabilities.api import BaseResourceSerializer from vulnerabilities.models import Exploit @@ -33,7 +32,7 @@ from vulnerabilities.models import VulnerabilitySeverity from vulnerabilities.models import Weakness from vulnerabilities.models import get_purl_query_lookups -from vulnerabilities.throttling import StaffUserRateThrottle +from vulnerabilities.throttling import PermissionBasedUserRateThrottle class SerializerExcludeFieldsMixin: @@ -259,7 +258,7 @@ class V2PackageViewSet(viewsets.ReadOnlyModelViewSet): lookup_field = "purl" filter_backends = (filters.DjangoFilterBackend,) filterset_class = V2PackageFilterSet - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle] def get_queryset(self): return super().get_queryset().with_is_vulnerable().prefetch_related("vulnerabilities") @@ -345,7 +344,7 @@ class VulnerabilityViewSet(viewsets.ReadOnlyModelViewSet): lookup_field = "vulnerability_id" filter_backends = (filters.DjangoFilterBackend,) filterset_class = V2VulnerabilityFilterSet - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle] def get_queryset(self): """ @@ -381,7 +380,7 @@ class CPEViewSet(viewsets.ReadOnlyModelViewSet): ).distinct() serializer_class = V2VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle] filterset_class = CPEFilterSet @action(detail=False, methods=["post"]) @@ -420,4 +419,4 @@ class AliasViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = V2VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = AliasFilterSet - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle] diff --git a/vulnerabilities/api_v2.py b/vulnerabilities/api_v2.py index 4915dda63..e9f967b79 100644 --- a/vulnerabilities/api_v2.py +++ b/vulnerabilities/api_v2.py @@ -23,6 +23,7 @@ from rest_framework.permissions import BasePermission from rest_framework.response import Response from rest_framework.reverse import reverse +from rest_framework.throttling import AnonRateThrottle from vulnerabilities.models import AdvisoryReference from vulnerabilities.models import AdvisorySeverity @@ -38,6 +39,7 @@ from vulnerabilities.models import VulnerabilityReference from vulnerabilities.models import VulnerabilitySeverity from vulnerabilities.models import Weakness +from vulnerabilities.throttling import PermissionBasedUserRateThrottle class WeaknessV2Serializer(serializers.ModelSerializer): @@ -199,6 +201,7 @@ class VulnerabilityV2ViewSet(viewsets.ReadOnlyModelViewSet): queryset = Vulnerability.objects.all() serializer_class = VulnerabilityV2Serializer lookup_field = "vulnerability_id" + throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] def get_queryset(self): queryset = super().get_queryset() @@ -394,6 +397,7 @@ class PackageV2ViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = PackageV2Serializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = PackageV2FilterSet + throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] def get_queryset(self): queryset = super().get_queryset() @@ -721,6 +725,7 @@ class CodeFixViewSet(viewsets.ReadOnlyModelViewSet): queryset = CodeFix.objects.all() serializer_class = CodeFixSerializer + throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] def get_queryset(self): """ @@ -863,6 +868,7 @@ class PipelineScheduleV2ViewSet(CreateListRetrieveUpdateViewSet): serializer_class = PipelineScheduleAPISerializer lookup_field = "pipeline_id" lookup_value_regex = r"[\w.]+" + throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] def get_serializer_class(self): if self.action == "create": diff --git a/vulnerabilities/migrations/0095_alter_apiuser_options.py b/vulnerabilities/migrations/0095_alter_apiuser_options.py new file mode 100644 index 000000000..2f30298a4 --- /dev/null +++ b/vulnerabilities/migrations/0095_alter_apiuser_options.py @@ -0,0 +1,36 @@ +# Generated by Django 4.2.22 on 2025-07-01 11:59 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("vulnerabilities", "0094_advisoryalias_advisoryreference_advisoryseverity_and_more"), + ] + + operations = [ + migrations.AlterModelOptions( + name="apiuser", + options={ + "permissions": [ + ( + "throttle_3_unrestricted", + "Can make unlimited API requests without any throttling limits", + ), + ( + "throttle_2_high", + "Can make high number of API requests with minimal throttling", + ), + ( + "throttle_1_medium", + "Can make medium number of API requests with standard throttling", + ), + ( + "throttle_0_low", + "Can make low number of API requests with strict throttling", + ), + ] + }, + ), + ] diff --git a/vulnerabilities/models.py b/vulnerabilities/models.py index ab01010d7..e1c656e2b 100644 --- a/vulnerabilities/models.py +++ b/vulnerabilities/models.py @@ -28,6 +28,7 @@ from cwe2.mappings import xml_database_path from cwe2.weakness import Weakness as DBWeakness from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from django.contrib.auth.models import UserManager from django.core import exceptions from django.core.exceptions import ValidationError @@ -1489,14 +1490,30 @@ def _validate_username(self, email): class ApiUser(UserModel): - """ - A User proxy model to facilitate simplified admin API user creation. - """ + """A User proxy model to facilitate simplified admin API user creation.""" objects = ApiUserManager() class Meta: proxy = True + permissions = [ + ( + "throttle_3_unrestricted", + "Can make unlimited API requests without any throttling limits", + ), + ( + "throttle_2_high", + "Can make high number of API requests with minimal throttling", + ), + ( + "throttle_1_medium", + "Can make medium number of API requests with standard throttling", + ), + ( + "throttle_0_low", + "Can make low number of API requests with strict throttling", + ), + ] class ChangeLog(models.Model): diff --git a/vulnerabilities/tests/test_api.py b/vulnerabilities/tests/test_api.py index a5f80aa06..9ed647099 100644 --- a/vulnerabilities/tests/test_api.py +++ b/vulnerabilities/tests/test_api.py @@ -11,6 +11,7 @@ import os from urllib.parse import quote +from django.core.cache import cache from django.test import TestCase from django.test import TransactionTestCase from django.test.client import RequestFactory @@ -452,10 +453,8 @@ def add_aliases(vuln, aliases): class APIPerformanceTest(TestCase): def setUp(self): - self.user = ApiUser.objects.create_api_user(username="e@mail.com") - self.auth = f"Token {self.user.auth_token.key}" + cache.clear() self.csrf_client = APIClient(enforce_csrf_checks=True) - self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) # This setup creates the following data: # vulnerabilities: vul1, vul2, vul3 @@ -503,7 +502,7 @@ def setUp(self): set_as_fixing(package=self.pkg_2_13_2, vulnerability=self.vul1) def test_api_packages_all_num_queries(self): - with self.assertNumQueries(4): + with self.assertNumQueries(3): # There are 4 queries: # 1. SAVEPOINT # 2. Authenticating user @@ -519,22 +518,22 @@ def test_api_packages_all_num_queries(self): ] def test_api_packages_single_num_queries(self): - with self.assertNumQueries(8): + with self.assertNumQueries(7): self.csrf_client.get(f"/api/packages/{self.pkg_2_14_0_rc1.id}", format="json") def test_api_packages_single_with_purl_in_query_num_queries(self): - with self.assertNumQueries(9): + with self.assertNumQueries(8): self.csrf_client.get(f"/api/packages/?purl={self.pkg_2_14_0_rc1.purl}", format="json") def test_api_packages_single_with_purl_no_version_in_query_num_queries(self): - with self.assertNumQueries(64): + with self.assertNumQueries(63): self.csrf_client.get( f"/api/packages/?purl=pkg:maven/com.fasterxml.jackson.core/jackson-databind", format="json", ) def test_api_packages_bulk_search(self): - with self.assertNumQueries(45): + with self.assertNumQueries(44): packages = [self.pkg_2_12_6, self.pkg_2_12_6_1, self.pkg_2_13_1] purls = [p.purl for p in packages] @@ -547,7 +546,7 @@ def test_api_packages_bulk_search(self): ).json() def test_api_packages_with_lookup(self): - with self.assertNumQueries(14): + with self.assertNumQueries(13): data = {"purl": self.pkg_2_12_6.purl} resp = self.csrf_client.post( @@ -557,7 +556,7 @@ def test_api_packages_with_lookup(self): ).json() def test_api_packages_bulk_lookup(self): - with self.assertNumQueries(45): + with self.assertNumQueries(44): packages = [self.pkg_2_12_6, self.pkg_2_12_6_1, self.pkg_2_13_1] purls = [p.purl for p in packages] @@ -572,10 +571,8 @@ def test_api_packages_bulk_lookup(self): class APITestCasePackage(TestCase): def setUp(self): - self.user = ApiUser.objects.create_api_user(username="e@mail.com") - self.auth = f"Token {self.user.auth_token.key}" + cache.clear() self.csrf_client = APIClient(enforce_csrf_checks=True) - self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) # This setup creates the following data: # vulnerabilities: vul1, vul2, vul3 @@ -766,7 +763,7 @@ def test_api_with_wrong_namespace_filter(self): self.assertEqual(response["count"], 0) def test_api_with_all_vulnerable_packages(self): - with self.assertNumQueries(4): + with self.assertNumQueries(3): # There are 4 queries: # 1. SAVEPOINT # 2. Authenticating user diff --git a/vulnerabilities/tests/test_api_v2.py b/vulnerabilities/tests/test_api_v2.py index 071a4450c..6bdfa77f8 100644 --- a/vulnerabilities/tests/test_api_v2.py +++ b/vulnerabilities/tests/test_api_v2.py @@ -61,10 +61,8 @@ def setUp(self): ) self.reference2.vulnerabilities.add(self.vuln2) - self.user = ApiUser.objects.create_api_user(username="e@mail.com") - self.auth = f"Token {self.user.auth_token.key}" + cache.clear() self.client = APIClient(enforce_csrf_checks=True) - self.client.credentials(HTTP_AUTHORIZATION=self.auth) def test_list_vulnerabilities(self): """ @@ -73,7 +71,7 @@ def test_list_vulnerabilities(self): """ url = reverse("vulnerability-v2-list") response = self.client.get(url, format="json") - with self.assertNumQueries(5): + with self.assertNumQueries(4): response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("results", response.data) @@ -88,7 +86,7 @@ def test_retrieve_vulnerability_detail(self): Test retrieving vulnerability details by vulnerability_id. """ url = reverse("vulnerability-v2-detail", kwargs={"vulnerability_id": "VCID-1234"}) - with self.assertNumQueries(8): + with self.assertNumQueries(7): response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["vulnerability_id"], "VCID-1234") @@ -102,7 +100,7 @@ def test_filter_vulnerability_by_vulnerability_id(self): Test filtering vulnerabilities by vulnerability_id. """ url = reverse("vulnerability-v2-list") - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.get(url, {"vulnerability_id": "VCID-1234"}, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["vulnerability_id"], "VCID-1234") @@ -112,7 +110,7 @@ def test_filter_vulnerability_by_alias(self): Test filtering vulnerabilities by alias. """ url = reverse("vulnerability-v2-list") - with self.assertNumQueries(5): + with self.assertNumQueries(4): response = self.client.get(url, {"alias": "CVE-2021-5678"}, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("results", response.data) @@ -127,7 +125,7 @@ def test_filter_vulnerabilities_multiple_ids(self): Test filtering vulnerabilities by multiple vulnerability_ids. """ url = reverse("vulnerability-v2-list") - with self.assertNumQueries(5): + with self.assertNumQueries(4): response = self.client.get( url, {"vulnerability_id": ["VCID-1234", "VCID-5678"]}, format="json" ) @@ -139,7 +137,7 @@ def test_filter_vulnerabilities_multiple_aliases(self): Test filtering vulnerabilities by multiple aliases. """ url = reverse("vulnerability-v2-list") - with self.assertNumQueries(5): + with self.assertNumQueries(4): response = self.client.get( url, {"alias": ["CVE-2021-1234", "CVE-2021-5678"]}, format="json" ) @@ -152,7 +150,7 @@ def test_invalid_vulnerability_id(self): Should return 404 Not Found. """ url = reverse("vulnerability-v2-detail", kwargs={"vulnerability_id": "VCID-9999"}) - with self.assertNumQueries(5): + with self.assertNumQueries(4): response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) @@ -210,10 +208,8 @@ def setUp(self): self.package1.affected_by_vulnerabilities.add(self.vuln1) self.package2.fixing_vulnerabilities.add(self.vuln2) - self.user = ApiUser.objects.create_api_user(username="e@mail.com") - self.auth = f"Token {self.user.auth_token.key}" + cache.clear() self.client = APIClient(enforce_csrf_checks=True) - self.client.credentials(HTTP_AUTHORIZATION=self.auth) def test_list_packages(self): """ @@ -221,7 +217,7 @@ def test_list_packages(self): Should return a list of packages with their details and associated vulnerabilities. """ url = reverse("package-v2-list") - with self.assertNumQueries(32): + with self.assertNumQueries(31): response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("results", response.data) @@ -243,7 +239,7 @@ def test_filter_packages_by_purl(self): Test filtering packages by one or more PURLs. """ url = reverse("package-v2-list") - with self.assertNumQueries(20): + with self.assertNumQueries(19): response = self.client.get(url, {"purl": "pkg:pypi/django@3.2"}, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]["packages"]), 1) @@ -254,7 +250,7 @@ def test_filter_packages_by_affected_vulnerability(self): Test filtering packages by affected_by_vulnerability. """ url = reverse("package-v2-list") - with self.assertNumQueries(20): + with self.assertNumQueries(19): response = self.client.get( url, {"affected_by_vulnerability": "VCID-1234"}, format="json" ) @@ -267,7 +263,7 @@ def test_filter_packages_by_fixing_vulnerability(self): Test filtering packages by fixing_vulnerability. """ url = reverse("package-v2-list") - with self.assertNumQueries(18): + with self.assertNumQueries(17): response = self.client.get(url, {"fixing_vulnerability": "VCID-5678"}, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]["packages"]), 1) @@ -356,7 +352,7 @@ def test_invalid_vulnerability_filter(self): Should return an empty list. """ url = reverse("package-v2-list") - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.get( url, {"affected_by_vulnerability": "VCID-9999"}, format="json" ) @@ -369,7 +365,7 @@ def test_invalid_purl_filter(self): Should return an empty list. """ url = reverse("package-v2-list") - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.get( url, {"purl": "pkg:nonexistent/package@1.0.0"}, format="json" ) @@ -421,7 +417,7 @@ def test_bulk_lookup_with_valid_purls(self): """ url = reverse("package-v2-bulk-lookup") data = {"purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"]} - with self.assertNumQueries(28): + with self.assertNumQueries(27): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("packages", response.data) @@ -446,7 +442,7 @@ def test_bulk_lookup_with_invalid_purls(self): """ url = reverse("package-v2-bulk-lookup") data = {"purls": ["pkg:pypi/nonexistent@1.0.0", "pkg:npm/unknown@0.0.1"]} - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Since the packages don't exist, the response should be empty @@ -460,7 +456,7 @@ def test_bulk_lookup_with_empty_purls(self): """ url = reverse("package-v2-bulk-lookup") data = {"purls": []} - with self.assertNumQueries(3): + with self.assertNumQueries(2): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("error", response.data) @@ -474,7 +470,7 @@ def test_bulk_search_with_valid_purls(self): """ url = reverse("package-v2-bulk-search") data = {"purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"]} - with self.assertNumQueries(28): + with self.assertNumQueries(27): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("packages", response.data) @@ -502,7 +498,7 @@ def test_bulk_search_with_purl_only_true(self): "purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"], "purl_only": True, } - with self.assertNumQueries(17): + with self.assertNumQueries(16): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Since purl_only=True, response should be a list of PURLs @@ -529,7 +525,7 @@ def test_bulk_search_with_plain_purl_true(self): "purls": ["pkg:pypi/django@3.2", "pkg:pypi/django@3.2?extension=tar.gz"], "plain_purl": True, } - with self.assertNumQueries(16): + with self.assertNumQueries(15): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("packages", response.data) @@ -550,7 +546,7 @@ def test_bulk_search_with_purl_only_and_plain_purl_true(self): "purl_only": True, "plain_purl": True, } - with self.assertNumQueries(11): + with self.assertNumQueries(10): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Response should be a list of plain PURLs @@ -566,7 +562,7 @@ def test_bulk_search_with_invalid_purls(self): """ url = reverse("package-v2-bulk-search") data = {"purls": ["pkg:pypi/nonexistent@1.0.0", "pkg:npm/unknown@0.0.1"]} - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Since the packages don't exist, the response should be empty @@ -580,7 +576,7 @@ def test_bulk_search_with_empty_purls(self): """ url = reverse("package-v2-bulk-search") data = {"purls": []} - with self.assertNumQueries(3): + with self.assertNumQueries(2): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("error", response.data) @@ -592,7 +588,7 @@ def test_all_vulnerable_packages(self): Test the 'all' endpoint that returns all vulnerable package URLs. """ url = reverse("package-v2-all") - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Since package1 is vulnerable, it should be returned @@ -606,7 +602,7 @@ def test_lookup_with_valid_purl(self): """ url = reverse("package-v2-lookup") data = {"purl": "pkg:pypi/django@3.2"} - with self.assertNumQueries(13): + with self.assertNumQueries(12): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(1, len(response.data)) @@ -635,7 +631,7 @@ def test_lookup_with_invalid_purl(self): """ url = reverse("package-v2-lookup") data = {"purl": "pkg:pypi/nonexistent@1.0.0"} - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # No packages or vulnerabilities should be returned @@ -648,7 +644,7 @@ def test_lookup_with_missing_purl(self): """ url = reverse("package-v2-lookup") data = {} - with self.assertNumQueries(3): + with self.assertNumQueries(2): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("error", response.data) @@ -662,7 +658,7 @@ def test_lookup_with_invalid_purl_format(self): """ url = reverse("package-v2-lookup") data = {"purl": "invalid_purl_format"} - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # No packages or vulnerabilities should be returned diff --git a/vulnerabilities/tests/test_throttling.py b/vulnerabilities/tests/test_throttling.py index 174761045..4ff83c70d 100644 --- a/vulnerabilities/tests/test_throttling.py +++ b/vulnerabilities/tests/test_throttling.py @@ -9,58 +9,179 @@ import json +from django.contrib.auth.models import Group +from django.contrib.auth.models import Permission from django.core.cache import cache +from rest_framework import status from rest_framework.test import APIClient from rest_framework.test import APITestCase +from vulnerabilities.api import PermissionBasedUserRateThrottle from vulnerabilities.models import ApiUser -class ThrottleApiTests(APITestCase): +def simulate_throttle_usage(url, client, mock_use_count): + throttle = PermissionBasedUserRateThrottle() + request = client.get(url).wsgi_request + + if cache_key := throttle.get_cache_key(request, view=None): + print(cache_key) + now = throttle.timer() + cache.set(cache_key, [now] * mock_use_count) + + +class PermissionBasedRateThrottleApiTests(APITestCase): def setUp(self): # Reset the api throttling to properly test the rate limit on anon users. # DRF stores throttling state in cache, clear cache to reset throttling. # See https://www.django-rest-framework.org/api-guide/throttling/#setting-up-the-cache cache.clear() - # create a basic user - self.user = ApiUser.objects.create_api_user(username="e@mail.com") - self.auth = f"Token {self.user.auth_token.key}" - self.csrf_client = APIClient(enforce_csrf_checks=True) - self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) + permission_low = Permission.objects.get(codename="throttle_0_low") + permission_medium = Permission.objects.get(codename="throttle_1_medium") + permission_high = Permission.objects.get(codename="throttle_2_high") + permission_unrestricted = Permission.objects.get(codename="throttle_3_unrestricted") + + # user with low permission + self.th_low_user = ApiUser.objects.create_api_user(username="z@mail.com") + self.th_low_user.user_permissions.add(permission_low) + self.th_low_user_auth = f"Token {self.th_low_user.auth_token.key}" + self.th_low_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_low_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_low_user_auth) + + # basic user without any special throttling perm + self.basic_user = ApiUser.objects.create_api_user(username="a@mail.com") + self.basic_user_auth = f"Token {self.basic_user.auth_token.key}" + self.basic_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.basic_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.basic_user_auth) + + # medium permission + self.th_medium_user = ApiUser.objects.create_api_user(username="b@mail.com") + self.th_medium_user.user_permissions.add(permission_medium) + self.th_medium_user_auth = f"Token {self.th_medium_user.auth_token.key}" + self.th_medium_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_medium_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_medium_user_auth) + + # high permission + self.th_high_user = ApiUser.objects.create_api_user(username="c@mail.com") + self.th_high_user.user_permissions.add(permission_high) + self.th_high_user_auth = f"Token {self.th_high_user.auth_token.key}" + self.th_high_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_high_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_high_user_auth) + + # unrestricted throttling perm + self.th_unrestricted_user = ApiUser.objects.create_api_user(username="d@mail.com") + self.th_unrestricted_user.user_permissions.add(permission_unrestricted) + self.th_unrestricted_user_auth = f"Token {self.th_unrestricted_user.auth_token.key}" + self.th_unrestricted_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_unrestricted_user_csrf_client.credentials( + HTTP_AUTHORIZATION=self.th_unrestricted_user_auth + ) - # create a staff user - self.staff_user = ApiUser.objects.create_api_user(username="staff@mail.com", is_staff=True) - self.staff_auth = f"Token {self.staff_user.auth_token.key}" - self.staff_csrf_client = APIClient(enforce_csrf_checks=True) - self.staff_csrf_client.credentials(HTTP_AUTHORIZATION=self.staff_auth) + # unrestricted throttling for group user + group, _ = Group.objects.get_or_create(name="Test Unrestricted") + group.permissions.add(permission_unrestricted) + + self.th_group_user = ApiUser.objects.create_api_user(username="g@mail.com") + self.th_group_user.groups.add(group) + self.th_group_user_auth = f"Token {self.th_group_user.auth_token.key}" + self.th_group_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_group_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_group_user_auth) self.csrf_client_anon = APIClient(enforce_csrf_checks=True) self.csrf_client_anon_1 = APIClient(enforce_csrf_checks=True) - def test_package_endpoint_throttling(self): - for i in range(0, 20): - response = self.csrf_client.get("/api/packages") - self.assertEqual(response.status_code, 200) - response = self.staff_csrf_client.get("/api/packages") - self.assertEqual(response.status_code, 200) + def test_user_with_low_perm_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.th_low_user_csrf_client, + mock_use_count=10799, + ) + + response = self.th_low_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # exhausted 10800/hr allowed requests. + response = self.th_low_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + + def test_basic_user_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.basic_user_csrf_client, + mock_use_count=14399, + ) + + response = self.basic_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # exhausted 14400/hr allowed requests. + response = self.basic_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + + def test_user_with_medium_perm_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.th_medium_user_csrf_client, + mock_use_count=14399, + ) + + response = self.th_medium_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # exhausted 14400/hr allowed requests for user with 14400 perm. + response = self.th_medium_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + + def test_user_with_high_perm_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.th_high_user_csrf_client, + mock_use_count=17999, + ) + + response = self.th_high_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # exhausted 18000/hr allowed requests for user with 18000 perm. + response = self.th_high_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + + def test_user_with_unrestricted_perm_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.th_unrestricted_user_csrf_client, + mock_use_count=20000, + ) + + # no throttling for user with unrestricted perm. + response = self.th_unrestricted_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self.csrf_client.get("/api/packages") - # 429 - too many requests for basic user - self.assertEqual(response.status_code, 429) + def test_user_in_group_with_unrestricted_perm_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.th_group_user_csrf_client, + mock_use_count=20000, + ) - response = self.staff_csrf_client.get("/api/packages", format="json") - # 200 - staff user can access API unlimited times - self.assertEqual(response.status_code, 200) + # no throttling for user in group with unrestricted perm. + response = self.th_group_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_200_OK) - # A anonymous user can only access /packages endpoint 10 times a day - for _i in range(0, 10): - response = self.csrf_client_anon.get("/api/packages") - self.assertEqual(response.status_code, 200) + def test_anon_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.csrf_client_anon, + mock_use_count=3599, + ) response = self.csrf_client_anon.get("/api/packages") - # 429 - too many requests for anon user - self.assertEqual(response.status_code, 429) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # exhausted 3600/hr allowed requests for anon. + response = self.csrf_client_anon.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) self.assertEqual( response.data.get("message"), "Your request has been throttled. Please contact support@nexb.com", @@ -68,7 +189,7 @@ def test_package_endpoint_throttling(self): response = self.csrf_client_anon.get("/api/vulnerabilities") # 429 - too many requests for anon user - self.assertEqual(response.status_code, 429) + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) self.assertEqual( response.data.get("message"), "Your request has been throttled. Please contact support@nexb.com", @@ -80,7 +201,7 @@ def test_package_endpoint_throttling(self): "/api/packages/bulk_search", data=data, content_type="application/json" ) # 429 - too many requests for anon user - self.assertEqual(response.status_code, 429) + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) self.assertEqual( response.data.get("message"), "Your request has been throttled. Please contact support@nexb.com", diff --git a/vulnerabilities/throttling.py b/vulnerabilities/throttling.py index 99b1d7756..e14c1a1c0 100644 --- a/vulnerabilities/throttling.py +++ b/vulnerabilities/throttling.py @@ -6,21 +6,50 @@ # See https://github.com/aboutcode-org/vulnerablecode for support or download. # See https://aboutcode.org for more information about nexB OSS projects. # + +from django.core.exceptions import ImproperlyConfigured from rest_framework.exceptions import Throttled from rest_framework.throttling import UserRateThrottle from rest_framework.views import exception_handler -class StaffUserRateThrottle(UserRateThrottle): +class PermissionBasedUserRateThrottle(UserRateThrottle): + """ + Throttles authenticated users based on their assigned permissions. + If no throttling permission is assigned, defaults to `medium` throttling + for authenticated users and `anon` for unauthenticated users. + """ + + def __init__(self): + pass + def allow_request(self, request, view): - """ - Do not apply throttling for superusers and admins. - """ - if request.user.is_superuser or request.user.is_staff: + user = request.user + throttling_tier = "medium" + + if not user or not user.is_authenticated: + throttling_tier = "anon" + elif user.has_perm("vulnerabilities.throttle_3_unrestricted"): return True + elif user.has_perm("vulnerabilities.throttle_2_high"): + throttling_tier = "high" + elif user.has_perm("vulnerabilities.throttle_1_medium"): + throttling_tier = "medium" + elif user.has_perm("vulnerabilities.throttle_0_low"): + throttling_tier = "low" + + self.rate = self.get_throttle_rate(throttling_tier) + self.num_requests, self.duration = self.parse_rate(self.rate) return super().allow_request(request, view) + def get_throttle_rate(self, tier): + try: + return self.THROTTLE_RATES[tier] + except KeyError: + msg = f"No throttle rate set for {tier}." + raise ImproperlyConfigured(msg) + def throttled_exception_handler(exception, context): """ diff --git a/vulnerablecode/settings.py b/vulnerablecode/settings.py index 6040f99b9..0d4207b23 100644 --- a/vulnerablecode/settings.py +++ b/vulnerablecode/settings.py @@ -190,12 +190,21 @@ LOGIN_REDIRECT_URL = "/" LOGOUT_REDIRECT_URL = "/" -REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = {"anon": "3600/hour", "user": "10800/hour"} +THROTTLE_RATE_ANON = env.str("THROTTLE_RATE_ANON", default="3600/hour") +THROTTLE_RATE_USER_HIGH = env.str("THROTTLE_RATE_USER_HIGH", default="18000/hour") +THROTTLE_RATE_USER_MEDIUM = env.str("THROTTLE_RATE_USER_MEDIUM", default="14400/hour") +THROTTLE_RATE_USER_LOW = env.str("THROTTLE_RATE_USER_LOW", default="10800/hour") + +REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = { + "anon": THROTTLE_RATE_ANON, + "low": THROTTLE_RATE_USER_LOW, + "medium": THROTTLE_RATE_USER_MEDIUM, + "high": THROTTLE_RATE_USER_HIGH, +} + if IS_TESTS: VULNERABLECODEIO_REQUIRE_AUTHENTICATION = False - REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = {"anon": "10/day", "user": "20/day"} - USE_L10N = True @@ -235,9 +244,7 @@ "rest_framework.filters.SearchFilter", ), "DEFAULT_THROTTLE_CLASSES": [ - "vulnerabilities.throttling.StaffUserRateThrottle", - "rest_framework.throttling.AnonRateThrottle", - "rest_framework.throttling.UserRateThrottle", + "vulnerabilities.throttling.PermissionBasedUserRateThrottle", ], "DEFAULT_THROTTLE_RATES": REST_FRAMEWORK_DEFAULT_THROTTLE_RATES, "EXCEPTION_HANDLER": "vulnerabilities.throttling.throttled_exception_handler", diff --git a/vulnerablecode/urls.py b/vulnerablecode/urls.py index 245b8e917..81ba5cadb 100644 --- a/vulnerablecode/urls.py +++ b/vulnerablecode/urls.py @@ -171,10 +171,10 @@ def __init__(self, *args, **kwargs): TemplateView.as_view(template_name="tos.html"), name="api_tos", ), - # path( - # "admin/", - # admin.site.urls, - # ), + path( + "admin/", + admin.site.urls, + ), ] if DEBUG: