From 42673a76e181d08c32f23d8bec9c28f5a3e20028 Mon Sep 17 00:00:00 2001 From: Keshav Priyadarshi Date: Wed, 11 Jun 2025 14:03:38 +0530 Subject: [PATCH 01/11] Throttle API users based on user group Signed-off-by: Keshav Priyadarshi --- vulnerabilities/api.py | 6 +++--- vulnerabilities/api_extension.py | 10 +++++----- vulnerabilities/models.py | 5 +++++ vulnerabilities/throttling.py | 25 +++++++++++++++++++------ vulnerablecode/settings.py | 17 ++++++++++++----- 5 files changed, 44 insertions(+), 19 deletions(-) diff --git a/vulnerabilities/api.py b/vulnerabilities/api.py index 1fd480ce9..d23dd7adb 100644 --- a/vulnerabilities/api.py +++ b/vulnerabilities/api.py @@ -34,7 +34,7 @@ from vulnerabilities.models import get_purl_query_lookups from vulnerabilities.severity_systems import EPSS from vulnerabilities.severity_systems import SCORING_SYSTEMS -from vulnerabilities.throttling import StaffUserRateThrottle +from vulnerabilities.throttling import GroupUserRateThrottle from vulnerabilities.utils import get_severity_range @@ -471,7 +471,7 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = PackageSerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = PackageFilterSet - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + throttle_classes = [AnonRateThrottle, GroupUserRateThrottle] def get_queryset(self): return super().get_queryset().with_is_vulnerable() @@ -688,7 +688,7 @@ def get_queryset(self): serializer_class = VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = VulnerabilityFilterSet - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + throttle_classes = [AnonRateThrottle, GroupUserRateThrottle] class CPEFilterSet(filters.FilterSet): diff --git a/vulnerabilities/api_extension.py b/vulnerabilities/api_extension.py index 7a13baf42..df765137c 100644 --- a/vulnerabilities/api_extension.py +++ b/vulnerabilities/api_extension.py @@ -33,7 +33,7 @@ from vulnerabilities.models import VulnerabilitySeverity from vulnerabilities.models import Weakness from vulnerabilities.models import get_purl_query_lookups -from vulnerabilities.throttling import StaffUserRateThrottle +from vulnerabilities.throttling import GroupUserRateThrottle class SerializerExcludeFieldsMixin: @@ -259,7 +259,7 @@ class V2PackageViewSet(viewsets.ReadOnlyModelViewSet): lookup_field = "purl" filter_backends = (filters.DjangoFilterBackend,) filterset_class = V2PackageFilterSet - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + throttle_classes = [GroupUserRateThrottle, AnonRateThrottle] def get_queryset(self): return super().get_queryset().with_is_vulnerable().prefetch_related("vulnerabilities") @@ -345,7 +345,7 @@ class VulnerabilityViewSet(viewsets.ReadOnlyModelViewSet): lookup_field = "vulnerability_id" filter_backends = (filters.DjangoFilterBackend,) filterset_class = V2VulnerabilityFilterSet - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + throttle_classes = [GroupUserRateThrottle, AnonRateThrottle] def get_queryset(self): """ @@ -381,7 +381,7 @@ class CPEViewSet(viewsets.ReadOnlyModelViewSet): ).distinct() serializer_class = V2VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + throttle_classes = [GroupUserRateThrottle, AnonRateThrottle] filterset_class = CPEFilterSet @action(detail=False, methods=["post"]) @@ -420,4 +420,4 @@ class AliasViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = V2VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = AliasFilterSet - throttle_classes = [StaffUserRateThrottle, AnonRateThrottle] + throttle_classes = [GroupUserRateThrottle, AnonRateThrottle] diff --git a/vulnerabilities/models.py b/vulnerabilities/models.py index ab01010d7..7a2705e16 100644 --- a/vulnerabilities/models.py +++ b/vulnerabilities/models.py @@ -28,6 +28,7 @@ from cwe2.mappings import xml_database_path from cwe2.weakness import Weakness as DBWeakness from django.contrib.auth import get_user_model +from django.contrib.auth.models import Group from django.contrib.auth.models import UserManager from django.core import exceptions from django.core.exceptions import ValidationError @@ -1472,6 +1473,10 @@ def create_api_user(self, username, first_name="", last_name="", **extra_fields) user.set_unusable_password() user.save() + # Assign the default basic group + default_group, _ = Group.objects.get_or_create(name="silver") + user.groups.add(default_group) + Token._default_manager.get_or_create(user=user) return user diff --git a/vulnerabilities/throttling.py b/vulnerabilities/throttling.py index 99b1d7756..ce3a17176 100644 --- a/vulnerabilities/throttling.py +++ b/vulnerabilities/throttling.py @@ -6,18 +6,31 @@ # See https://github.com/aboutcode-org/vulnerablecode for support or download. # See https://aboutcode.org for more information about nexB OSS projects. # + from rest_framework.exceptions import Throttled from rest_framework.throttling import UserRateThrottle from rest_framework.views import exception_handler -class StaffUserRateThrottle(UserRateThrottle): +class GroupUserRateThrottle(UserRateThrottle): + scope = "bronze" + def allow_request(self, request, view): - """ - Do not apply throttling for superusers and admins. - """ - if request.user.is_superuser or request.user.is_staff: - return True + user = request.user + + if user and user.is_authenticated: + if user.is_superuser or user.is_staff: + return True + + user_groups = user.groups.all() + if any([group.name == "gold" for group in user_groups]): + return True + + if any([group.name == "silver" for group in user_groups]): + self.scope = "silver" + + self.rate = self.THROTTLE_RATES.get(self.scope) + self.num_requests, self.duration = self.parse_rate(self.rate) return super().allow_request(request, view) diff --git a/vulnerablecode/settings.py b/vulnerablecode/settings.py index 6040f99b9..2db44bee8 100644 --- a/vulnerablecode/settings.py +++ b/vulnerablecode/settings.py @@ -190,12 +190,20 @@ LOGIN_REDIRECT_URL = "/" LOGOUT_REDIRECT_URL = "/" -REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = {"anon": "3600/hour", "user": "10800/hour"} +REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = { + # No throttling for users in gold group. + "silver": "10800/hour", + "bronze": "7200/hour", + "anon": "3600/hour", +} if IS_TESTS: VULNERABLECODEIO_REQUIRE_AUTHENTICATION = False - REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = {"anon": "10/day", "user": "20/day"} - + REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = { + "silver": "20/day", + "bronze": "15/day", + "anon": "10/day", + } USE_L10N = True @@ -235,9 +243,8 @@ "rest_framework.filters.SearchFilter", ), "DEFAULT_THROTTLE_CLASSES": [ - "vulnerabilities.throttling.StaffUserRateThrottle", + "vulnerabilities.throttling.GroupUserRateThrottle", "rest_framework.throttling.AnonRateThrottle", - "rest_framework.throttling.UserRateThrottle", ], "DEFAULT_THROTTLE_RATES": REST_FRAMEWORK_DEFAULT_THROTTLE_RATES, "EXCEPTION_HANDLER": "vulnerabilities.throttling.throttled_exception_handler", From 677ff99dd87b9c92327aa199d5728625bfff0c26 Mon Sep 17 00:00:00 2001 From: Keshav Priyadarshi Date: Wed, 11 Jun 2025 14:12:46 +0530 Subject: [PATCH 02/11] Add test for group based throttling Signed-off-by: Keshav Priyadarshi --- vulnerabilities/tests/test_api.py | 4 +- vulnerabilities/tests/test_api_v2.py | 4 +- vulnerabilities/tests/test_throttling.py | 58 +++++++++++++++++++----- 3 files changed, 51 insertions(+), 15 deletions(-) diff --git a/vulnerabilities/tests/test_api.py b/vulnerabilities/tests/test_api.py index a5f80aa06..bad51a121 100644 --- a/vulnerabilities/tests/test_api.py +++ b/vulnerabilities/tests/test_api.py @@ -452,7 +452,7 @@ def add_aliases(vuln, aliases): class APIPerformanceTest(TestCase): def setUp(self): - self.user = ApiUser.objects.create_api_user(username="e@mail.com") + self.user = ApiUser.objects.create_api_user(username="e@mail.com", is_staff=True) self.auth = f"Token {self.user.auth_token.key}" self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) @@ -572,7 +572,7 @@ def test_api_packages_bulk_lookup(self): class APITestCasePackage(TestCase): def setUp(self): - self.user = ApiUser.objects.create_api_user(username="e@mail.com") + self.user = ApiUser.objects.create_api_user(username="e@mail.com", is_staff=True) self.auth = f"Token {self.user.auth_token.key}" self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) diff --git a/vulnerabilities/tests/test_api_v2.py b/vulnerabilities/tests/test_api_v2.py index 071a4450c..ff7f53bdf 100644 --- a/vulnerabilities/tests/test_api_v2.py +++ b/vulnerabilities/tests/test_api_v2.py @@ -61,7 +61,7 @@ def setUp(self): ) self.reference2.vulnerabilities.add(self.vuln2) - self.user = ApiUser.objects.create_api_user(username="e@mail.com") + self.user = ApiUser.objects.create_api_user(username="e@mail.com", is_staff=True) self.auth = f"Token {self.user.auth_token.key}" self.client = APIClient(enforce_csrf_checks=True) self.client.credentials(HTTP_AUTHORIZATION=self.auth) @@ -210,7 +210,7 @@ def setUp(self): self.package1.affected_by_vulnerabilities.add(self.vuln1) self.package2.fixing_vulnerabilities.add(self.vuln2) - self.user = ApiUser.objects.create_api_user(username="e@mail.com") + self.user = ApiUser.objects.create_api_user(username="e@mail.com", is_staff=True) self.auth = f"Token {self.user.auth_token.key}" self.client = APIClient(enforce_csrf_checks=True) self.client.credentials(HTTP_AUTHORIZATION=self.auth) diff --git a/vulnerabilities/tests/test_throttling.py b/vulnerabilities/tests/test_throttling.py index 174761045..8be404db2 100644 --- a/vulnerabilities/tests/test_throttling.py +++ b/vulnerabilities/tests/test_throttling.py @@ -9,6 +9,7 @@ import json +from django.contrib.auth.models import Group from django.core.cache import cache from rest_framework.test import APIClient from rest_framework.test import APITestCase @@ -16,18 +17,35 @@ from vulnerabilities.models import ApiUser -class ThrottleApiTests(APITestCase): +class GroupUserRateThrottleApiTests(APITestCase): def setUp(self): # Reset the api throttling to properly test the rate limit on anon users. # DRF stores throttling state in cache, clear cache to reset throttling. # See https://www.django-rest-framework.org/api-guide/throttling/#setting-up-the-cache cache.clear() - # create a basic user - self.user = ApiUser.objects.create_api_user(username="e@mail.com") - self.auth = f"Token {self.user.auth_token.key}" - self.csrf_client = APIClient(enforce_csrf_checks=True) - self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) + # User in bronze group + self.bronze_user = ApiUser.objects.create_api_user(username="bronze@mail.com") + bronze, _ = Group.objects.get_or_create(name="bronze") + self.bronze_user.groups.clear() + self.bronze_user.groups.add(bronze) + self.bronze_auth = f"Token {self.bronze_user.auth_token.key}" + self.bronze_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.bronze_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.bronze_auth) + + # User in silver group (default group for api user) + self.silver_user = ApiUser.objects.create_api_user(username="silver@mail.com") + self.silver_auth = f"Token {self.silver_user.auth_token.key}" + self.silver_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.silver_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.silver_auth) + + # User in gold group + self.gold_user = ApiUser.objects.create_api_user(username="gold@mail.com") + gold, _ = Group.objects.get_or_create(name="gold") + self.gold_user.groups.add(gold) + self.gold_auth = f"Token {self.gold_user.auth_token.key}" + self.gold_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.gold_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.gold_auth) # create a staff user self.staff_user = ApiUser.objects.create_api_user(username="staff@mail.com", is_staff=True) @@ -39,16 +57,34 @@ def setUp(self): self.csrf_client_anon_1 = APIClient(enforce_csrf_checks=True) def test_package_endpoint_throttling(self): - for i in range(0, 20): - response = self.csrf_client.get("/api/packages") + for i in range(0, 15): + response = self.bronze_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, 200) - response = self.staff_csrf_client.get("/api/packages") + + response = self.bronze_user_csrf_client.get("/api/packages") + # 429 - too many requests for bronze user + self.assertEqual(response.status_code, 429) + + for i in range(0, 20): + response = self.silver_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, 200) - response = self.csrf_client.get("/api/packages") - # 429 - too many requests for basic user + response = self.silver_user_csrf_client.get("/api/packages") + # 429 - too many requests for silver user self.assertEqual(response.status_code, 429) + for i in range(0, 30): + response = self.gold_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, 200) + + response = self.gold_user_csrf_client.get("/api/packages", format="json") + # 200 - gold user can access API unlimited times + self.assertEqual(response.status_code, 200) + + for i in range(0, 30): + response = self.staff_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, 200) + response = self.staff_csrf_client.get("/api/packages", format="json") # 200 - staff user can access API unlimited times self.assertEqual(response.status_code, 200) From 8e9607b8b79aa03868f0f9b248c42cf70a762fbb Mon Sep 17 00:00:00 2001 From: Keshav Priyadarshi Date: Fri, 13 Jun 2025 17:30:07 +0530 Subject: [PATCH 03/11] Throttle API requests based on user permissions Signed-off-by: Keshav Priyadarshi --- vulnerabilities/api.py | 6 ++--- vulnerabilities/api_extension.py | 10 ++++---- .../migrations/0093_alter_apiuser_options.py | 23 +++++++++++++++++++ vulnerabilities/models.py | 13 +++++------ vulnerabilities/throttling.py | 20 ++++++---------- vulnerablecode/settings.py | 15 +++--------- 6 files changed, 47 insertions(+), 40 deletions(-) create mode 100644 vulnerabilities/migrations/0093_alter_apiuser_options.py diff --git a/vulnerabilities/api.py b/vulnerabilities/api.py index d23dd7adb..50403583d 100644 --- a/vulnerabilities/api.py +++ b/vulnerabilities/api.py @@ -34,7 +34,7 @@ from vulnerabilities.models import get_purl_query_lookups from vulnerabilities.severity_systems import EPSS from vulnerabilities.severity_systems import SCORING_SYSTEMS -from vulnerabilities.throttling import GroupUserRateThrottle +from vulnerabilities.throttling import PermissionBasedUserRateThrottle from vulnerabilities.utils import get_severity_range @@ -471,7 +471,7 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = PackageSerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = PackageFilterSet - throttle_classes = [AnonRateThrottle, GroupUserRateThrottle] + throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] def get_queryset(self): return super().get_queryset().with_is_vulnerable() @@ -688,7 +688,7 @@ def get_queryset(self): serializer_class = VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = VulnerabilityFilterSet - throttle_classes = [AnonRateThrottle, GroupUserRateThrottle] + throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] class CPEFilterSet(filters.FilterSet): diff --git a/vulnerabilities/api_extension.py b/vulnerabilities/api_extension.py index df765137c..89ee644bf 100644 --- a/vulnerabilities/api_extension.py +++ b/vulnerabilities/api_extension.py @@ -33,7 +33,7 @@ from vulnerabilities.models import VulnerabilitySeverity from vulnerabilities.models import Weakness from vulnerabilities.models import get_purl_query_lookups -from vulnerabilities.throttling import GroupUserRateThrottle +from vulnerabilities.throttling import PermissionBasedUserRateThrottle class SerializerExcludeFieldsMixin: @@ -259,7 +259,7 @@ class V2PackageViewSet(viewsets.ReadOnlyModelViewSet): lookup_field = "purl" filter_backends = (filters.DjangoFilterBackend,) filterset_class = V2PackageFilterSet - throttle_classes = [GroupUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle, AnonRateThrottle] def get_queryset(self): return super().get_queryset().with_is_vulnerable().prefetch_related("vulnerabilities") @@ -345,7 +345,7 @@ class VulnerabilityViewSet(viewsets.ReadOnlyModelViewSet): lookup_field = "vulnerability_id" filter_backends = (filters.DjangoFilterBackend,) filterset_class = V2VulnerabilityFilterSet - throttle_classes = [GroupUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle, AnonRateThrottle] def get_queryset(self): """ @@ -381,7 +381,7 @@ class CPEViewSet(viewsets.ReadOnlyModelViewSet): ).distinct() serializer_class = V2VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) - throttle_classes = [GroupUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle, AnonRateThrottle] filterset_class = CPEFilterSet @action(detail=False, methods=["post"]) @@ -420,4 +420,4 @@ class AliasViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = V2VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = AliasFilterSet - throttle_classes = [GroupUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle, AnonRateThrottle] diff --git a/vulnerabilities/migrations/0093_alter_apiuser_options.py b/vulnerabilities/migrations/0093_alter_apiuser_options.py new file mode 100644 index 000000000..771a3779b --- /dev/null +++ b/vulnerabilities/migrations/0093_alter_apiuser_options.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.22 on 2025-06-13 08:07 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("vulnerabilities", "0092_pipelineschedule_pipelinerun"), + ] + + operations = [ + migrations.AlterModelOptions( + name="apiuser", + options={ + "permissions": [ + ("throttle_unrestricted", "Exempt from API throttling limits"), + ("throttle_18000_hour", "Can make 18000 API requests per hour"), + ("throttle_14400_hour", "Can make 14400 API requests per hour"), + ] + }, + ), + ] diff --git a/vulnerabilities/models.py b/vulnerabilities/models.py index 7a2705e16..6a792e80b 100644 --- a/vulnerabilities/models.py +++ b/vulnerabilities/models.py @@ -1473,10 +1473,6 @@ def create_api_user(self, username, first_name="", last_name="", **extra_fields) user.set_unusable_password() user.save() - # Assign the default basic group - default_group, _ = Group.objects.get_or_create(name="silver") - user.groups.add(default_group) - Token._default_manager.get_or_create(user=user) return user @@ -1494,14 +1490,17 @@ def _validate_username(self, email): class ApiUser(UserModel): - """ - A User proxy model to facilitate simplified admin API user creation. - """ + """A User proxy model to facilitate simplified admin API user creation.""" objects = ApiUserManager() class Meta: proxy = True + permissions = [ + ("throttle_unrestricted", "Exempt from API throttling limits"), + ("throttle_18000_hour", "Can make 18000 API requests per hour"), + ("throttle_14400_hour", "Can make 14400 API requests per hour"), + ] class ChangeLog(models.Model): diff --git a/vulnerabilities/throttling.py b/vulnerabilities/throttling.py index ce3a17176..d6b0840eb 100644 --- a/vulnerabilities/throttling.py +++ b/vulnerabilities/throttling.py @@ -12,25 +12,19 @@ from rest_framework.views import exception_handler -class GroupUserRateThrottle(UserRateThrottle): - scope = "bronze" - +class PermissionBasedUserRateThrottle(UserRateThrottle): def allow_request(self, request, view): user = request.user if user and user.is_authenticated: - if user.is_superuser or user.is_staff: - return True - - user_groups = user.groups.all() - if any([group.name == "gold" for group in user_groups]): + if user.has_perm("vulnerabilities.throttle_unrestricted"): return True + elif user.has_perm("vulnerabilities.throttle_18000_hour"): + self.rate = "18000/hour" + elif user.has_perm("vulnerabilities.throttle_14400_hour"): + self.rate = "14400/hour" - if any([group.name == "silver" for group in user_groups]): - self.scope = "silver" - - self.rate = self.THROTTLE_RATES.get(self.scope) - self.num_requests, self.duration = self.parse_rate(self.rate) + self.num_requests, self.duration = self.parse_rate(self.rate) return super().allow_request(request, view) diff --git a/vulnerablecode/settings.py b/vulnerablecode/settings.py index 2db44bee8..63810397c 100644 --- a/vulnerablecode/settings.py +++ b/vulnerablecode/settings.py @@ -190,20 +190,11 @@ LOGIN_REDIRECT_URL = "/" LOGOUT_REDIRECT_URL = "/" -REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = { - # No throttling for users in gold group. - "silver": "10800/hour", - "bronze": "7200/hour", - "anon": "3600/hour", -} +REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = {"anon": "3600/hour", "user": "10800/hour"} + if IS_TESTS: VULNERABLECODEIO_REQUIRE_AUTHENTICATION = False - REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = { - "silver": "20/day", - "bronze": "15/day", - "anon": "10/day", - } USE_L10N = True @@ -243,7 +234,7 @@ "rest_framework.filters.SearchFilter", ), "DEFAULT_THROTTLE_CLASSES": [ - "vulnerabilities.throttling.GroupUserRateThrottle", + "vulnerabilities.throttling.PermissionBasedUserRateThrottle", "rest_framework.throttling.AnonRateThrottle", ], "DEFAULT_THROTTLE_RATES": REST_FRAMEWORK_DEFAULT_THROTTLE_RATES, From dba5b6c6c4b583f77bbb2cb81bda9a6df4d05123 Mon Sep 17 00:00:00 2001 From: Keshav Priyadarshi Date: Fri, 13 Jun 2025 17:37:56 +0530 Subject: [PATCH 04/11] Enable throttling for v2 API endpoint Signed-off-by: Keshav Priyadarshi --- vulnerabilities/api_v2.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vulnerabilities/api_v2.py b/vulnerabilities/api_v2.py index 4915dda63..e9f967b79 100644 --- a/vulnerabilities/api_v2.py +++ b/vulnerabilities/api_v2.py @@ -23,6 +23,7 @@ from rest_framework.permissions import BasePermission from rest_framework.response import Response from rest_framework.reverse import reverse +from rest_framework.throttling import AnonRateThrottle from vulnerabilities.models import AdvisoryReference from vulnerabilities.models import AdvisorySeverity @@ -38,6 +39,7 @@ from vulnerabilities.models import VulnerabilityReference from vulnerabilities.models import VulnerabilitySeverity from vulnerabilities.models import Weakness +from vulnerabilities.throttling import PermissionBasedUserRateThrottle class WeaknessV2Serializer(serializers.ModelSerializer): @@ -199,6 +201,7 @@ class VulnerabilityV2ViewSet(viewsets.ReadOnlyModelViewSet): queryset = Vulnerability.objects.all() serializer_class = VulnerabilityV2Serializer lookup_field = "vulnerability_id" + throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] def get_queryset(self): queryset = super().get_queryset() @@ -394,6 +397,7 @@ class PackageV2ViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = PackageV2Serializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = PackageV2FilterSet + throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] def get_queryset(self): queryset = super().get_queryset() @@ -721,6 +725,7 @@ class CodeFixViewSet(viewsets.ReadOnlyModelViewSet): queryset = CodeFix.objects.all() serializer_class = CodeFixSerializer + throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] def get_queryset(self): """ @@ -863,6 +868,7 @@ class PipelineScheduleV2ViewSet(CreateListRetrieveUpdateViewSet): serializer_class = PipelineScheduleAPISerializer lookup_field = "pipeline_id" lookup_value_regex = r"[\w.]+" + throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] def get_serializer_class(self): if self.action == "create": From 68a375c84636c642f02e4c83fa081e4b32eaa144 Mon Sep 17 00:00:00 2001 From: Keshav Priyadarshi Date: Fri, 13 Jun 2025 17:41:01 +0530 Subject: [PATCH 05/11] Add tests for user permission based API throttling Signed-off-by: Keshav Priyadarshi --- vulnerabilities/tests/test_api.py | 25 ++-- vulnerabilities/tests/test_api_v2.py | 62 ++++---- vulnerabilities/tests/test_throttling.py | 173 ++++++++++++++--------- 3 files changed, 150 insertions(+), 110 deletions(-) diff --git a/vulnerabilities/tests/test_api.py b/vulnerabilities/tests/test_api.py index bad51a121..9ed647099 100644 --- a/vulnerabilities/tests/test_api.py +++ b/vulnerabilities/tests/test_api.py @@ -11,6 +11,7 @@ import os from urllib.parse import quote +from django.core.cache import cache from django.test import TestCase from django.test import TransactionTestCase from django.test.client import RequestFactory @@ -452,10 +453,8 @@ def add_aliases(vuln, aliases): class APIPerformanceTest(TestCase): def setUp(self): - self.user = ApiUser.objects.create_api_user(username="e@mail.com", is_staff=True) - self.auth = f"Token {self.user.auth_token.key}" + cache.clear() self.csrf_client = APIClient(enforce_csrf_checks=True) - self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) # This setup creates the following data: # vulnerabilities: vul1, vul2, vul3 @@ -503,7 +502,7 @@ def setUp(self): set_as_fixing(package=self.pkg_2_13_2, vulnerability=self.vul1) def test_api_packages_all_num_queries(self): - with self.assertNumQueries(4): + with self.assertNumQueries(3): # There are 4 queries: # 1. SAVEPOINT # 2. Authenticating user @@ -519,22 +518,22 @@ def test_api_packages_all_num_queries(self): ] def test_api_packages_single_num_queries(self): - with self.assertNumQueries(8): + with self.assertNumQueries(7): self.csrf_client.get(f"/api/packages/{self.pkg_2_14_0_rc1.id}", format="json") def test_api_packages_single_with_purl_in_query_num_queries(self): - with self.assertNumQueries(9): + with self.assertNumQueries(8): self.csrf_client.get(f"/api/packages/?purl={self.pkg_2_14_0_rc1.purl}", format="json") def test_api_packages_single_with_purl_no_version_in_query_num_queries(self): - with self.assertNumQueries(64): + with self.assertNumQueries(63): self.csrf_client.get( f"/api/packages/?purl=pkg:maven/com.fasterxml.jackson.core/jackson-databind", format="json", ) def test_api_packages_bulk_search(self): - with self.assertNumQueries(45): + with self.assertNumQueries(44): packages = [self.pkg_2_12_6, self.pkg_2_12_6_1, self.pkg_2_13_1] purls = [p.purl for p in packages] @@ -547,7 +546,7 @@ def test_api_packages_bulk_search(self): ).json() def test_api_packages_with_lookup(self): - with self.assertNumQueries(14): + with self.assertNumQueries(13): data = {"purl": self.pkg_2_12_6.purl} resp = self.csrf_client.post( @@ -557,7 +556,7 @@ def test_api_packages_with_lookup(self): ).json() def test_api_packages_bulk_lookup(self): - with self.assertNumQueries(45): + with self.assertNumQueries(44): packages = [self.pkg_2_12_6, self.pkg_2_12_6_1, self.pkg_2_13_1] purls = [p.purl for p in packages] @@ -572,10 +571,8 @@ def test_api_packages_bulk_lookup(self): class APITestCasePackage(TestCase): def setUp(self): - self.user = ApiUser.objects.create_api_user(username="e@mail.com", is_staff=True) - self.auth = f"Token {self.user.auth_token.key}" + cache.clear() self.csrf_client = APIClient(enforce_csrf_checks=True) - self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth) # This setup creates the following data: # vulnerabilities: vul1, vul2, vul3 @@ -766,7 +763,7 @@ def test_api_with_wrong_namespace_filter(self): self.assertEqual(response["count"], 0) def test_api_with_all_vulnerable_packages(self): - with self.assertNumQueries(4): + with self.assertNumQueries(3): # There are 4 queries: # 1. SAVEPOINT # 2. Authenticating user diff --git a/vulnerabilities/tests/test_api_v2.py b/vulnerabilities/tests/test_api_v2.py index ff7f53bdf..6bdfa77f8 100644 --- a/vulnerabilities/tests/test_api_v2.py +++ b/vulnerabilities/tests/test_api_v2.py @@ -61,10 +61,8 @@ def setUp(self): ) self.reference2.vulnerabilities.add(self.vuln2) - self.user = ApiUser.objects.create_api_user(username="e@mail.com", is_staff=True) - self.auth = f"Token {self.user.auth_token.key}" + cache.clear() self.client = APIClient(enforce_csrf_checks=True) - self.client.credentials(HTTP_AUTHORIZATION=self.auth) def test_list_vulnerabilities(self): """ @@ -73,7 +71,7 @@ def test_list_vulnerabilities(self): """ url = reverse("vulnerability-v2-list") response = self.client.get(url, format="json") - with self.assertNumQueries(5): + with self.assertNumQueries(4): response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("results", response.data) @@ -88,7 +86,7 @@ def test_retrieve_vulnerability_detail(self): Test retrieving vulnerability details by vulnerability_id. """ url = reverse("vulnerability-v2-detail", kwargs={"vulnerability_id": "VCID-1234"}) - with self.assertNumQueries(8): + with self.assertNumQueries(7): response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["vulnerability_id"], "VCID-1234") @@ -102,7 +100,7 @@ def test_filter_vulnerability_by_vulnerability_id(self): Test filtering vulnerabilities by vulnerability_id. """ url = reverse("vulnerability-v2-list") - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.get(url, {"vulnerability_id": "VCID-1234"}, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["vulnerability_id"], "VCID-1234") @@ -112,7 +110,7 @@ def test_filter_vulnerability_by_alias(self): Test filtering vulnerabilities by alias. """ url = reverse("vulnerability-v2-list") - with self.assertNumQueries(5): + with self.assertNumQueries(4): response = self.client.get(url, {"alias": "CVE-2021-5678"}, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("results", response.data) @@ -127,7 +125,7 @@ def test_filter_vulnerabilities_multiple_ids(self): Test filtering vulnerabilities by multiple vulnerability_ids. """ url = reverse("vulnerability-v2-list") - with self.assertNumQueries(5): + with self.assertNumQueries(4): response = self.client.get( url, {"vulnerability_id": ["VCID-1234", "VCID-5678"]}, format="json" ) @@ -139,7 +137,7 @@ def test_filter_vulnerabilities_multiple_aliases(self): Test filtering vulnerabilities by multiple aliases. """ url = reverse("vulnerability-v2-list") - with self.assertNumQueries(5): + with self.assertNumQueries(4): response = self.client.get( url, {"alias": ["CVE-2021-1234", "CVE-2021-5678"]}, format="json" ) @@ -152,7 +150,7 @@ def test_invalid_vulnerability_id(self): Should return 404 Not Found. """ url = reverse("vulnerability-v2-detail", kwargs={"vulnerability_id": "VCID-9999"}) - with self.assertNumQueries(5): + with self.assertNumQueries(4): response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) @@ -210,10 +208,8 @@ def setUp(self): self.package1.affected_by_vulnerabilities.add(self.vuln1) self.package2.fixing_vulnerabilities.add(self.vuln2) - self.user = ApiUser.objects.create_api_user(username="e@mail.com", is_staff=True) - self.auth = f"Token {self.user.auth_token.key}" + cache.clear() self.client = APIClient(enforce_csrf_checks=True) - self.client.credentials(HTTP_AUTHORIZATION=self.auth) def test_list_packages(self): """ @@ -221,7 +217,7 @@ def test_list_packages(self): Should return a list of packages with their details and associated vulnerabilities. """ url = reverse("package-v2-list") - with self.assertNumQueries(32): + with self.assertNumQueries(31): response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("results", response.data) @@ -243,7 +239,7 @@ def test_filter_packages_by_purl(self): Test filtering packages by one or more PURLs. """ url = reverse("package-v2-list") - with self.assertNumQueries(20): + with self.assertNumQueries(19): response = self.client.get(url, {"purl": "pkg:pypi/django@3.2"}, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]["packages"]), 1) @@ -254,7 +250,7 @@ def test_filter_packages_by_affected_vulnerability(self): Test filtering packages by affected_by_vulnerability. """ url = reverse("package-v2-list") - with self.assertNumQueries(20): + with self.assertNumQueries(19): response = self.client.get( url, {"affected_by_vulnerability": "VCID-1234"}, format="json" ) @@ -267,7 +263,7 @@ def test_filter_packages_by_fixing_vulnerability(self): Test filtering packages by fixing_vulnerability. """ url = reverse("package-v2-list") - with self.assertNumQueries(18): + with self.assertNumQueries(17): response = self.client.get(url, {"fixing_vulnerability": "VCID-5678"}, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["results"]["packages"]), 1) @@ -356,7 +352,7 @@ def test_invalid_vulnerability_filter(self): Should return an empty list. """ url = reverse("package-v2-list") - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.get( url, {"affected_by_vulnerability": "VCID-9999"}, format="json" ) @@ -369,7 +365,7 @@ def test_invalid_purl_filter(self): Should return an empty list. """ url = reverse("package-v2-list") - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.get( url, {"purl": "pkg:nonexistent/package@1.0.0"}, format="json" ) @@ -421,7 +417,7 @@ def test_bulk_lookup_with_valid_purls(self): """ url = reverse("package-v2-bulk-lookup") data = {"purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"]} - with self.assertNumQueries(28): + with self.assertNumQueries(27): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("packages", response.data) @@ -446,7 +442,7 @@ def test_bulk_lookup_with_invalid_purls(self): """ url = reverse("package-v2-bulk-lookup") data = {"purls": ["pkg:pypi/nonexistent@1.0.0", "pkg:npm/unknown@0.0.1"]} - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Since the packages don't exist, the response should be empty @@ -460,7 +456,7 @@ def test_bulk_lookup_with_empty_purls(self): """ url = reverse("package-v2-bulk-lookup") data = {"purls": []} - with self.assertNumQueries(3): + with self.assertNumQueries(2): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("error", response.data) @@ -474,7 +470,7 @@ def test_bulk_search_with_valid_purls(self): """ url = reverse("package-v2-bulk-search") data = {"purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"]} - with self.assertNumQueries(28): + with self.assertNumQueries(27): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("packages", response.data) @@ -502,7 +498,7 @@ def test_bulk_search_with_purl_only_true(self): "purls": ["pkg:pypi/django@3.2", "pkg:npm/lodash@4.17.20"], "purl_only": True, } - with self.assertNumQueries(17): + with self.assertNumQueries(16): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Since purl_only=True, response should be a list of PURLs @@ -529,7 +525,7 @@ def test_bulk_search_with_plain_purl_true(self): "purls": ["pkg:pypi/django@3.2", "pkg:pypi/django@3.2?extension=tar.gz"], "plain_purl": True, } - with self.assertNumQueries(16): + with self.assertNumQueries(15): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("packages", response.data) @@ -550,7 +546,7 @@ def test_bulk_search_with_purl_only_and_plain_purl_true(self): "purl_only": True, "plain_purl": True, } - with self.assertNumQueries(11): + with self.assertNumQueries(10): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Response should be a list of plain PURLs @@ -566,7 +562,7 @@ def test_bulk_search_with_invalid_purls(self): """ url = reverse("package-v2-bulk-search") data = {"purls": ["pkg:pypi/nonexistent@1.0.0", "pkg:npm/unknown@0.0.1"]} - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Since the packages don't exist, the response should be empty @@ -580,7 +576,7 @@ def test_bulk_search_with_empty_purls(self): """ url = reverse("package-v2-bulk-search") data = {"purls": []} - with self.assertNumQueries(3): + with self.assertNumQueries(2): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("error", response.data) @@ -592,7 +588,7 @@ def test_all_vulnerable_packages(self): Test the 'all' endpoint that returns all vulnerable package URLs. """ url = reverse("package-v2-all") - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # Since package1 is vulnerable, it should be returned @@ -606,7 +602,7 @@ def test_lookup_with_valid_purl(self): """ url = reverse("package-v2-lookup") data = {"purl": "pkg:pypi/django@3.2"} - with self.assertNumQueries(13): + with self.assertNumQueries(12): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(1, len(response.data)) @@ -635,7 +631,7 @@ def test_lookup_with_invalid_purl(self): """ url = reverse("package-v2-lookup") data = {"purl": "pkg:pypi/nonexistent@1.0.0"} - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # No packages or vulnerabilities should be returned @@ -648,7 +644,7 @@ def test_lookup_with_missing_purl(self): """ url = reverse("package-v2-lookup") data = {} - with self.assertNumQueries(3): + with self.assertNumQueries(2): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("error", response.data) @@ -662,7 +658,7 @@ def test_lookup_with_invalid_purl_format(self): """ url = reverse("package-v2-lookup") data = {"purl": "invalid_purl_format"} - with self.assertNumQueries(4): + with self.assertNumQueries(3): response = self.client.post(url, data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) # No packages or vulnerabilities should be returned diff --git a/vulnerabilities/tests/test_throttling.py b/vulnerabilities/tests/test_throttling.py index 8be404db2..62fbde4ef 100644 --- a/vulnerabilities/tests/test_throttling.py +++ b/vulnerabilities/tests/test_throttling.py @@ -9,94 +9,141 @@ import json -from django.contrib.auth.models import Group +from django.contrib.auth.models import Permission from django.core.cache import cache +from rest_framework import status from rest_framework.test import APIClient from rest_framework.test import APITestCase +from rest_framework.throttling import AnonRateThrottle +from vulnerabilities.api import PermissionBasedUserRateThrottle from vulnerabilities.models import ApiUser -class GroupUserRateThrottleApiTests(APITestCase): +def simulate_throttle_usage( + url, + client, + mock_use_count, + throttle_cls=PermissionBasedUserRateThrottle, +): + throttle = throttle_cls() + request = client.get(url).wsgi_request + + if cache_key := throttle.get_cache_key(request, view=None): + now = throttle.timer() + cache.set(cache_key, [now] * mock_use_count) + + +class PermissionBasedRateThrottleApiTests(APITestCase): def setUp(self): # Reset the api throttling to properly test the rate limit on anon users. # DRF stores throttling state in cache, clear cache to reset throttling. # See https://www.django-rest-framework.org/api-guide/throttling/#setting-up-the-cache cache.clear() - # User in bronze group - self.bronze_user = ApiUser.objects.create_api_user(username="bronze@mail.com") - bronze, _ = Group.objects.get_or_create(name="bronze") - self.bronze_user.groups.clear() - self.bronze_user.groups.add(bronze) - self.bronze_auth = f"Token {self.bronze_user.auth_token.key}" - self.bronze_user_csrf_client = APIClient(enforce_csrf_checks=True) - self.bronze_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.bronze_auth) - - # User in silver group (default group for api user) - self.silver_user = ApiUser.objects.create_api_user(username="silver@mail.com") - self.silver_auth = f"Token {self.silver_user.auth_token.key}" - self.silver_user_csrf_client = APIClient(enforce_csrf_checks=True) - self.silver_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.silver_auth) - - # User in gold group - self.gold_user = ApiUser.objects.create_api_user(username="gold@mail.com") - gold, _ = Group.objects.get_or_create(name="gold") - self.gold_user.groups.add(gold) - self.gold_auth = f"Token {self.gold_user.auth_token.key}" - self.gold_user_csrf_client = APIClient(enforce_csrf_checks=True) - self.gold_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.gold_auth) - - # create a staff user - self.staff_user = ApiUser.objects.create_api_user(username="staff@mail.com", is_staff=True) - self.staff_auth = f"Token {self.staff_user.auth_token.key}" - self.staff_csrf_client = APIClient(enforce_csrf_checks=True) - self.staff_csrf_client.credentials(HTTP_AUTHORIZATION=self.staff_auth) + permission_14400 = Permission.objects.get(codename="throttle_14400_hour") + permission_18000 = Permission.objects.get(codename="throttle_18000_hour") + permission_unrestricted = Permission.objects.get(codename="throttle_unrestricted") + + # basic user without any special throttling perm + self.basic_user = ApiUser.objects.create_api_user(username="a@mail.com") + self.basic_user_auth = f"Token {self.basic_user.auth_token.key}" + self.basic_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.basic_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.basic_user_auth) + + # 14400/hour permission + self.th_14400_user = ApiUser.objects.create_api_user(username="b@mail.com") + self.th_14400_user.user_permissions.add(permission_14400) + self.th_14400_user_auth = f"Token {self.th_14400_user.auth_token.key}" + self.th_14400_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_14400_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_14400_user_auth) + + # 18000/hour permission + self.th_18000_user = ApiUser.objects.create_api_user(username="c@mail.com") + self.th_18000_user.user_permissions.add(permission_18000) + self.th_18000_user_auth = f"Token {self.th_18000_user.auth_token.key}" + self.th_18000_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_18000_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_18000_user_auth) + + # unrestricted throttling perm + self.th_unrestricted_user = ApiUser.objects.create_api_user(username="d@mail.com") + self.th_unrestricted_user.user_permissions.add(permission_unrestricted) + self.th_unrestricted_user_auth = f"Token {self.th_unrestricted_user.auth_token.key}" + self.th_unrestricted_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_unrestricted_user_csrf_client.credentials( + HTTP_AUTHORIZATION=self.th_unrestricted_user_auth + ) self.csrf_client_anon = APIClient(enforce_csrf_checks=True) self.csrf_client_anon_1 = APIClient(enforce_csrf_checks=True) - def test_package_endpoint_throttling(self): - for i in range(0, 15): - response = self.bronze_user_csrf_client.get("/api/packages") - self.assertEqual(response.status_code, 200) + def test_basic_user_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.basic_user_csrf_client, + mock_use_count=10799, + ) - response = self.bronze_user_csrf_client.get("/api/packages") - # 429 - too many requests for bronze user - self.assertEqual(response.status_code, 429) + response = self.basic_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_200_OK) - for i in range(0, 20): - response = self.silver_user_csrf_client.get("/api/packages") - self.assertEqual(response.status_code, 200) + # exhausted 10800/hr allowed requests for basic user. + response = self.basic_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) - response = self.silver_user_csrf_client.get("/api/packages") - # 429 - too many requests for silver user - self.assertEqual(response.status_code, 429) + def test_user_with_14400_perm_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.th_14400_user_csrf_client, + mock_use_count=14399, + ) - for i in range(0, 30): - response = self.gold_user_csrf_client.get("/api/packages") - self.assertEqual(response.status_code, 200) + response = self.th_14400_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_200_OK) - response = self.gold_user_csrf_client.get("/api/packages", format="json") - # 200 - gold user can access API unlimited times - self.assertEqual(response.status_code, 200) + # exhausted 14400/hr allowed requests for user with 14400 perm. + response = self.th_14400_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) - for i in range(0, 30): - response = self.staff_csrf_client.get("/api/packages") - self.assertEqual(response.status_code, 200) + def test_user_with_18000_perm_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.th_18000_user_csrf_client, + mock_use_count=17999, + ) - response = self.staff_csrf_client.get("/api/packages", format="json") - # 200 - staff user can access API unlimited times - self.assertEqual(response.status_code, 200) + response = self.th_18000_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_200_OK) - # A anonymous user can only access /packages endpoint 10 times a day - for _i in range(0, 10): - response = self.csrf_client_anon.get("/api/packages") - self.assertEqual(response.status_code, 200) + # exhausted 18000/hr allowed requests for user with 18000 perm. + response = self.th_18000_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + + def test_user_with_unrestricted_perm_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.th_unrestricted_user_csrf_client, + mock_use_count=20000, + ) + + # no throttling for user with unrestricted perm. + response = self.th_unrestricted_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_200_OK) + + def test_anon_throttling(self): + simulate_throttle_usage( + throttle_cls=AnonRateThrottle, + url="/api/packages", + client=self.csrf_client_anon, + mock_use_count=3599, + ) response = self.csrf_client_anon.get("/api/packages") - # 429 - too many requests for anon user - self.assertEqual(response.status_code, 429) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # exhausted 3600/hr allowed requests for anon. + response = self.csrf_client_anon.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) self.assertEqual( response.data.get("message"), "Your request has been throttled. Please contact support@nexb.com", @@ -104,7 +151,7 @@ def test_package_endpoint_throttling(self): response = self.csrf_client_anon.get("/api/vulnerabilities") # 429 - too many requests for anon user - self.assertEqual(response.status_code, 429) + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) self.assertEqual( response.data.get("message"), "Your request has been throttled. Please contact support@nexb.com", @@ -116,7 +163,7 @@ def test_package_endpoint_throttling(self): "/api/packages/bulk_search", data=data, content_type="application/json" ) # 429 - too many requests for anon user - self.assertEqual(response.status_code, 429) + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) self.assertEqual( response.data.get("message"), "Your request has been throttled. Please contact support@nexb.com", From f3c04370e5d1428484ffd7957384f44f2762f509 Mon Sep 17 00:00:00 2001 From: Keshav Priyadarshi Date: Fri, 13 Jun 2025 17:42:15 +0530 Subject: [PATCH 06/11] Enable admin login page Signed-off-by: Keshav Priyadarshi --- vulnerablecode/urls.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vulnerablecode/urls.py b/vulnerablecode/urls.py index 245b8e917..81ba5cadb 100644 --- a/vulnerablecode/urls.py +++ b/vulnerablecode/urls.py @@ -171,10 +171,10 @@ def __init__(self, *args, **kwargs): TemplateView.as_view(template_name="tos.html"), name="api_tos", ), - # path( - # "admin/", - # admin.site.urls, - # ), + path( + "admin/", + admin.site.urls, + ), ] if DEBUG: From 94dd104efa747545c23c526b37f413f1fa9add53 Mon Sep 17 00:00:00 2001 From: Keshav Priyadarshi Date: Fri, 13 Jun 2025 18:34:49 +0530 Subject: [PATCH 07/11] Add perm to demote user to anon throttle rate Signed-off-by: Keshav Priyadarshi --- .../migrations/0093_alter_apiuser_options.py | 9 +++---- vulnerabilities/models.py | 7 +++--- vulnerabilities/tests/test_throttling.py | 24 ++++++++++++++++++- vulnerabilities/throttling.py | 8 +++++++ 4 files changed, 40 insertions(+), 8 deletions(-) diff --git a/vulnerabilities/migrations/0093_alter_apiuser_options.py b/vulnerabilities/migrations/0093_alter_apiuser_options.py index 771a3779b..9709439cc 100644 --- a/vulnerabilities/migrations/0093_alter_apiuser_options.py +++ b/vulnerabilities/migrations/0093_alter_apiuser_options.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.22 on 2025-06-13 08:07 +# Generated by Django 4.2.22 on 2025-06-13 12:44 from django.db import migrations @@ -14,9 +14,10 @@ class Migration(migrations.Migration): name="apiuser", options={ "permissions": [ - ("throttle_unrestricted", "Exempt from API throttling limits"), - ("throttle_18000_hour", "Can make 18000 API requests per hour"), - ("throttle_14400_hour", "Can make 14400 API requests per hour"), + ("throttle_unrestricted", "Can make api requests without throttling limits"), + ("throttle_18000_hour", "Can make 18000 api requests per hour"), + ("throttle_14400_hour", "Can make 14400 api requests per hour"), + ("throttle_3600_hour", "Can make 3600 api requests per hour"), ] }, ), diff --git a/vulnerabilities/models.py b/vulnerabilities/models.py index 6a792e80b..777781e40 100644 --- a/vulnerabilities/models.py +++ b/vulnerabilities/models.py @@ -1497,9 +1497,10 @@ class ApiUser(UserModel): class Meta: proxy = True permissions = [ - ("throttle_unrestricted", "Exempt from API throttling limits"), - ("throttle_18000_hour", "Can make 18000 API requests per hour"), - ("throttle_14400_hour", "Can make 14400 API requests per hour"), + ("throttle_unrestricted", "Can make api requests without throttling limits"), + ("throttle_18000_hour", "Can make 18000 api requests per hour"), + ("throttle_14400_hour", "Can make 14400 api requests per hour"), + ("throttle_3600_hour", "Can make 3600 api requests per hour"), ] diff --git a/vulnerabilities/tests/test_throttling.py b/vulnerabilities/tests/test_throttling.py index 62fbde4ef..d89b69c11 100644 --- a/vulnerabilities/tests/test_throttling.py +++ b/vulnerabilities/tests/test_throttling.py @@ -41,10 +41,18 @@ def setUp(self): # See https://www.django-rest-framework.org/api-guide/throttling/#setting-up-the-cache cache.clear() + permission_3600 = Permission.objects.get(codename="throttle_3600_hour") permission_14400 = Permission.objects.get(codename="throttle_14400_hour") permission_18000 = Permission.objects.get(codename="throttle_18000_hour") permission_unrestricted = Permission.objects.get(codename="throttle_unrestricted") + # user with 3600/hour permission + self.th_3600_user = ApiUser.objects.create_api_user(username="z@mail.com") + self.th_3600_user.user_permissions.add(permission_3600) + self.th_3600_user_auth = f"Token {self.th_3600_user.auth_token.key}" + self.th_3600_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_3600_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_3600_user_auth) + # basic user without any special throttling perm self.basic_user = ApiUser.objects.create_api_user(username="a@mail.com") self.basic_user_auth = f"Token {self.basic_user.auth_token.key}" @@ -77,6 +85,20 @@ def setUp(self): self.csrf_client_anon = APIClient(enforce_csrf_checks=True) self.csrf_client_anon_1 = APIClient(enforce_csrf_checks=True) + def test_user_with_3600_perm_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.th_3600_user_csrf_client, + mock_use_count=3599, + ) + + response = self.th_3600_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_200_OK) + + # exhausted 3600/hr allowed requests. + response = self.th_3600_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) + def test_basic_user_throttling(self): simulate_throttle_usage( url="/api/packages", @@ -87,7 +109,7 @@ def test_basic_user_throttling(self): response = self.basic_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, status.HTTP_200_OK) - # exhausted 10800/hr allowed requests for basic user. + # exhausted 10800/hr allowed requests. response = self.basic_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) diff --git a/vulnerabilities/throttling.py b/vulnerabilities/throttling.py index d6b0840eb..fd96acdef 100644 --- a/vulnerabilities/throttling.py +++ b/vulnerabilities/throttling.py @@ -13,6 +13,12 @@ class PermissionBasedUserRateThrottle(UserRateThrottle): + """ + Throttle authenticated users based on their assigned permissions. + If no throttling permission is assigned, default to rate for `user` + scope provided via `DEFAULT_THROTTLE_RATES` in settings.py. + """ + def allow_request(self, request, view): user = request.user @@ -23,6 +29,8 @@ def allow_request(self, request, view): self.rate = "18000/hour" elif user.has_perm("vulnerabilities.throttle_14400_hour"): self.rate = "14400/hour" + elif user.has_perm("vulnerabilities.throttle_3600_hour"): + self.rate = "3600/hour" self.num_requests, self.duration = self.parse_rate(self.rate) From 2f03f11503f6f5997e80af57cd2b87d18cf0cdd6 Mon Sep 17 00:00:00 2001 From: Keshav Priyadarshi Date: Sat, 21 Jun 2025 00:38:57 +0530 Subject: [PATCH 08/11] Add custom group admin with user selection Signed-off-by: Keshav Priyadarshi --- vulnerabilities/admin.py | 50 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/vulnerabilities/admin.py b/vulnerabilities/admin.py index eecef0276..176e3c0c0 100644 --- a/vulnerabilities/admin.py +++ b/vulnerabilities/admin.py @@ -9,6 +9,10 @@ from django import forms from django.contrib import admin +from django.contrib.admin.widgets import FilteredSelectMultiple +from django.contrib.auth.admin import GroupAdmin as BasicGroupAdmin +from django.contrib.auth.models import Group +from django.contrib.auth.models import User from django.core.validators import validate_email from vulnerabilities.models import ApiUser @@ -97,3 +101,49 @@ def get_form(self, request, obj=None, **kwargs): defaults["form"] = self.add_form defaults.update(kwargs) return super().get_form(request, obj, **defaults) + + +class GroupWithUsersForm(forms.ModelForm): + users = forms.ModelMultipleChoiceField( + queryset=User.objects.all(), + required=False, + widget=FilteredSelectMultiple("Users", is_stacked=False), + label="Users", + ) + + class Meta: + model = Group + fields = "__all__" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.fields["users"].label_from_instance = lambda user: ( + f"{user.username} | {user.email}" if user.email else user.username + ) + if self.instance.pk: + self.fields["users"].initial = self.instance.user_set.all() + + def save(self, commit=True): + group = super().save(commit=commit) + self.save_m2m() + group.user_set.set(self.cleaned_data["users"]) + return group + + +admin.site.unregister(Group) + + +@admin.register(Group) +class GroupAdmin(admin.ModelAdmin): + form = GroupWithUsersForm + search_fields = ("name",) + ordering = ("name",) + filter_horizontal = ("permissions",) + + def formfield_for_manytomany(self, db_field, request=None, **kwargs): + if db_field.name == "permissions": + qs = kwargs.get("queryset", db_field.remote_field.model.objects) + # Avoid a major performance hit resolving permission names which + # triggers a content_type load: + kwargs["queryset"] = qs.select_related("content_type") + return super().formfield_for_manytomany(db_field, request=request, **kwargs) From 103afbbb2e0f6566fe9c061a4380658ca97000dc Mon Sep 17 00:00:00 2001 From: Keshav Priyadarshi Date: Thu, 26 Jun 2025 14:20:32 +0530 Subject: [PATCH 09/11] Add high, medium, and low tier based throttling permissions Signed-off-by: Keshav Priyadarshi --- vulnerabilities/api.py | 5 +- vulnerabilities/api_extension.py | 9 +- .../migrations/0093_alter_apiuser_options.py | 22 +++-- vulnerabilities/models.py | 20 ++++- vulnerabilities/tests/test_throttling.py | 88 +++++++++---------- vulnerabilities/throttling.py | 40 ++++++--- vulnerablecode/settings.py | 13 ++- 7 files changed, 118 insertions(+), 79 deletions(-) diff --git a/vulnerabilities/api.py b/vulnerabilities/api.py index 50403583d..d994b297d 100644 --- a/vulnerabilities/api.py +++ b/vulnerabilities/api.py @@ -22,7 +22,6 @@ from rest_framework import viewsets from rest_framework.decorators import action from rest_framework.response import Response -from rest_framework.throttling import AnonRateThrottle from vulnerabilities.models import Alias from vulnerabilities.models import Exploit @@ -471,7 +470,7 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = PackageSerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = PackageFilterSet - throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle] def get_queryset(self): return super().get_queryset().with_is_vulnerable() @@ -688,7 +687,7 @@ def get_queryset(self): serializer_class = VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = VulnerabilityFilterSet - throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle] class CPEFilterSet(filters.FilterSet): diff --git a/vulnerabilities/api_extension.py b/vulnerabilities/api_extension.py index 89ee644bf..01d98ca99 100644 --- a/vulnerabilities/api_extension.py +++ b/vulnerabilities/api_extension.py @@ -23,7 +23,6 @@ from rest_framework.serializers import ModelSerializer from rest_framework.serializers import Serializer from rest_framework.serializers import ValidationError -from rest_framework.throttling import AnonRateThrottle from vulnerabilities.api import BaseResourceSerializer from vulnerabilities.models import Exploit @@ -259,7 +258,7 @@ class V2PackageViewSet(viewsets.ReadOnlyModelViewSet): lookup_field = "purl" filter_backends = (filters.DjangoFilterBackend,) filterset_class = V2PackageFilterSet - throttle_classes = [PermissionBasedUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle] def get_queryset(self): return super().get_queryset().with_is_vulnerable().prefetch_related("vulnerabilities") @@ -345,7 +344,7 @@ class VulnerabilityViewSet(viewsets.ReadOnlyModelViewSet): lookup_field = "vulnerability_id" filter_backends = (filters.DjangoFilterBackend,) filterset_class = V2VulnerabilityFilterSet - throttle_classes = [PermissionBasedUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle] def get_queryset(self): """ @@ -381,7 +380,7 @@ class CPEViewSet(viewsets.ReadOnlyModelViewSet): ).distinct() serializer_class = V2VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) - throttle_classes = [PermissionBasedUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle] filterset_class = CPEFilterSet @action(detail=False, methods=["post"]) @@ -420,4 +419,4 @@ class AliasViewSet(viewsets.ReadOnlyModelViewSet): serializer_class = V2VulnerabilitySerializer filter_backends = (filters.DjangoFilterBackend,) filterset_class = AliasFilterSet - throttle_classes = [PermissionBasedUserRateThrottle, AnonRateThrottle] + throttle_classes = [PermissionBasedUserRateThrottle] diff --git a/vulnerabilities/migrations/0093_alter_apiuser_options.py b/vulnerabilities/migrations/0093_alter_apiuser_options.py index 9709439cc..61c5b183d 100644 --- a/vulnerabilities/migrations/0093_alter_apiuser_options.py +++ b/vulnerabilities/migrations/0093_alter_apiuser_options.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.22 on 2025-06-13 12:44 +# Generated by Django 4.2.22 on 2025-06-25 18:56 from django.db import migrations @@ -14,10 +14,22 @@ class Migration(migrations.Migration): name="apiuser", options={ "permissions": [ - ("throttle_unrestricted", "Can make api requests without throttling limits"), - ("throttle_18000_hour", "Can make 18000 api requests per hour"), - ("throttle_14400_hour", "Can make 14400 api requests per hour"), - ("throttle_3600_hour", "Can make 3600 api requests per hour"), + ( + "throttle_3_unrestricted", + "Can make unlimited API requests without any throttling limits", + ), + ( + "throttle_2_high", + "Can make high number of API requests with minimal throttling", + ), + ( + "throttle_1_medium", + "Can make medium number of API requests with standard throttling", + ), + ( + "throttle_0_low", + "Can make low number of API requests with strict throttling", + ), ] }, ), diff --git a/vulnerabilities/models.py b/vulnerabilities/models.py index 777781e40..e1c656e2b 100644 --- a/vulnerabilities/models.py +++ b/vulnerabilities/models.py @@ -1497,10 +1497,22 @@ class ApiUser(UserModel): class Meta: proxy = True permissions = [ - ("throttle_unrestricted", "Can make api requests without throttling limits"), - ("throttle_18000_hour", "Can make 18000 api requests per hour"), - ("throttle_14400_hour", "Can make 14400 api requests per hour"), - ("throttle_3600_hour", "Can make 3600 api requests per hour"), + ( + "throttle_3_unrestricted", + "Can make unlimited API requests without any throttling limits", + ), + ( + "throttle_2_high", + "Can make high number of API requests with minimal throttling", + ), + ( + "throttle_1_medium", + "Can make medium number of API requests with standard throttling", + ), + ( + "throttle_0_low", + "Can make low number of API requests with strict throttling", + ), ] diff --git a/vulnerabilities/tests/test_throttling.py b/vulnerabilities/tests/test_throttling.py index d89b69c11..25af231f8 100644 --- a/vulnerabilities/tests/test_throttling.py +++ b/vulnerabilities/tests/test_throttling.py @@ -14,22 +14,17 @@ from rest_framework import status from rest_framework.test import APIClient from rest_framework.test import APITestCase -from rest_framework.throttling import AnonRateThrottle from vulnerabilities.api import PermissionBasedUserRateThrottle from vulnerabilities.models import ApiUser -def simulate_throttle_usage( - url, - client, - mock_use_count, - throttle_cls=PermissionBasedUserRateThrottle, -): - throttle = throttle_cls() +def simulate_throttle_usage(url, client, mock_use_count): + throttle = PermissionBasedUserRateThrottle() request = client.get(url).wsgi_request if cache_key := throttle.get_cache_key(request, view=None): + print(cache_key) now = throttle.timer() cache.set(cache_key, [now] * mock_use_count) @@ -41,17 +36,17 @@ def setUp(self): # See https://www.django-rest-framework.org/api-guide/throttling/#setting-up-the-cache cache.clear() - permission_3600 = Permission.objects.get(codename="throttle_3600_hour") - permission_14400 = Permission.objects.get(codename="throttle_14400_hour") - permission_18000 = Permission.objects.get(codename="throttle_18000_hour") - permission_unrestricted = Permission.objects.get(codename="throttle_unrestricted") + permission_low = Permission.objects.get(codename="throttle_0_low") + permission_medium = Permission.objects.get(codename="throttle_1_medium") + permission_high = Permission.objects.get(codename="throttle_2_high") + permission_unrestricted = Permission.objects.get(codename="throttle_3_unrestricted") - # user with 3600/hour permission - self.th_3600_user = ApiUser.objects.create_api_user(username="z@mail.com") - self.th_3600_user.user_permissions.add(permission_3600) - self.th_3600_user_auth = f"Token {self.th_3600_user.auth_token.key}" - self.th_3600_user_csrf_client = APIClient(enforce_csrf_checks=True) - self.th_3600_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_3600_user_auth) + # user with low permission + self.th_low_user = ApiUser.objects.create_api_user(username="z@mail.com") + self.th_low_user.user_permissions.add(permission_low) + self.th_low_user_auth = f"Token {self.th_low_user.auth_token.key}" + self.th_low_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_low_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_low_user_auth) # basic user without any special throttling perm self.basic_user = ApiUser.objects.create_api_user(username="a@mail.com") @@ -59,19 +54,19 @@ def setUp(self): self.basic_user_csrf_client = APIClient(enforce_csrf_checks=True) self.basic_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.basic_user_auth) - # 14400/hour permission - self.th_14400_user = ApiUser.objects.create_api_user(username="b@mail.com") - self.th_14400_user.user_permissions.add(permission_14400) - self.th_14400_user_auth = f"Token {self.th_14400_user.auth_token.key}" - self.th_14400_user_csrf_client = APIClient(enforce_csrf_checks=True) - self.th_14400_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_14400_user_auth) + # medium permission + self.th_medium_user = ApiUser.objects.create_api_user(username="b@mail.com") + self.th_medium_user.user_permissions.add(permission_medium) + self.th_medium_user_auth = f"Token {self.th_medium_user.auth_token.key}" + self.th_medium_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_medium_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_medium_user_auth) - # 18000/hour permission - self.th_18000_user = ApiUser.objects.create_api_user(username="c@mail.com") - self.th_18000_user.user_permissions.add(permission_18000) - self.th_18000_user_auth = f"Token {self.th_18000_user.auth_token.key}" - self.th_18000_user_csrf_client = APIClient(enforce_csrf_checks=True) - self.th_18000_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_18000_user_auth) + # high permission + self.th_high_user = ApiUser.objects.create_api_user(username="c@mail.com") + self.th_high_user.user_permissions.add(permission_high) + self.th_high_user_auth = f"Token {self.th_high_user.auth_token.key}" + self.th_high_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_high_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_high_user_auth) # unrestricted throttling perm self.th_unrestricted_user = ApiUser.objects.create_api_user(username="d@mail.com") @@ -85,60 +80,60 @@ def setUp(self): self.csrf_client_anon = APIClient(enforce_csrf_checks=True) self.csrf_client_anon_1 = APIClient(enforce_csrf_checks=True) - def test_user_with_3600_perm_throttling(self): + def test_user_with_low_perm_throttling(self): simulate_throttle_usage( url="/api/packages", - client=self.th_3600_user_csrf_client, - mock_use_count=3599, + client=self.th_low_user_csrf_client, + mock_use_count=10799, ) - response = self.th_3600_user_csrf_client.get("/api/packages") + response = self.th_low_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, status.HTTP_200_OK) - # exhausted 3600/hr allowed requests. - response = self.th_3600_user_csrf_client.get("/api/packages") + # exhausted 10800/hr allowed requests. + response = self.th_low_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) def test_basic_user_throttling(self): simulate_throttle_usage( url="/api/packages", client=self.basic_user_csrf_client, - mock_use_count=10799, + mock_use_count=14399, ) response = self.basic_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, status.HTTP_200_OK) - # exhausted 10800/hr allowed requests. + # exhausted 14400/hr allowed requests. response = self.basic_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) - def test_user_with_14400_perm_throttling(self): + def test_user_with_medium_perm_throttling(self): simulate_throttle_usage( url="/api/packages", - client=self.th_14400_user_csrf_client, + client=self.th_medium_user_csrf_client, mock_use_count=14399, ) - response = self.th_14400_user_csrf_client.get("/api/packages") + response = self.th_medium_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, status.HTTP_200_OK) # exhausted 14400/hr allowed requests for user with 14400 perm. - response = self.th_14400_user_csrf_client.get("/api/packages") + response = self.th_medium_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) - def test_user_with_18000_perm_throttling(self): + def test_user_with_high_perm_throttling(self): simulate_throttle_usage( url="/api/packages", - client=self.th_18000_user_csrf_client, + client=self.th_high_user_csrf_client, mock_use_count=17999, ) - response = self.th_18000_user_csrf_client.get("/api/packages") + response = self.th_high_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, status.HTTP_200_OK) # exhausted 18000/hr allowed requests for user with 18000 perm. - response = self.th_18000_user_csrf_client.get("/api/packages") + response = self.th_high_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) def test_user_with_unrestricted_perm_throttling(self): @@ -154,7 +149,6 @@ def test_user_with_unrestricted_perm_throttling(self): def test_anon_throttling(self): simulate_throttle_usage( - throttle_cls=AnonRateThrottle, url="/api/packages", client=self.csrf_client_anon, mock_use_count=3599, diff --git a/vulnerabilities/throttling.py b/vulnerabilities/throttling.py index fd96acdef..e14c1a1c0 100644 --- a/vulnerabilities/throttling.py +++ b/vulnerabilities/throttling.py @@ -7,6 +7,7 @@ # See https://aboutcode.org for more information about nexB OSS projects. # +from django.core.exceptions import ImproperlyConfigured from rest_framework.exceptions import Throttled from rest_framework.throttling import UserRateThrottle from rest_framework.views import exception_handler @@ -14,28 +15,41 @@ class PermissionBasedUserRateThrottle(UserRateThrottle): """ - Throttle authenticated users based on their assigned permissions. - If no throttling permission is assigned, default to rate for `user` - scope provided via `DEFAULT_THROTTLE_RATES` in settings.py. + Throttles authenticated users based on their assigned permissions. + If no throttling permission is assigned, defaults to `medium` throttling + for authenticated users and `anon` for unauthenticated users. """ + def __init__(self): + pass + def allow_request(self, request, view): user = request.user + throttling_tier = "medium" - if user and user.is_authenticated: - if user.has_perm("vulnerabilities.throttle_unrestricted"): - return True - elif user.has_perm("vulnerabilities.throttle_18000_hour"): - self.rate = "18000/hour" - elif user.has_perm("vulnerabilities.throttle_14400_hour"): - self.rate = "14400/hour" - elif user.has_perm("vulnerabilities.throttle_3600_hour"): - self.rate = "3600/hour" + if not user or not user.is_authenticated: + throttling_tier = "anon" + elif user.has_perm("vulnerabilities.throttle_3_unrestricted"): + return True + elif user.has_perm("vulnerabilities.throttle_2_high"): + throttling_tier = "high" + elif user.has_perm("vulnerabilities.throttle_1_medium"): + throttling_tier = "medium" + elif user.has_perm("vulnerabilities.throttle_0_low"): + throttling_tier = "low" - self.num_requests, self.duration = self.parse_rate(self.rate) + self.rate = self.get_throttle_rate(throttling_tier) + self.num_requests, self.duration = self.parse_rate(self.rate) return super().allow_request(request, view) + def get_throttle_rate(self, tier): + try: + return self.THROTTLE_RATES[tier] + except KeyError: + msg = f"No throttle rate set for {tier}." + raise ImproperlyConfigured(msg) + def throttled_exception_handler(exception, context): """ diff --git a/vulnerablecode/settings.py b/vulnerablecode/settings.py index 63810397c..b1a51c0a7 100644 --- a/vulnerablecode/settings.py +++ b/vulnerablecode/settings.py @@ -190,7 +190,17 @@ LOGIN_REDIRECT_URL = "/" LOGOUT_REDIRECT_URL = "/" -REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = {"anon": "3600/hour", "user": "10800/hour"} +THROTTLE_RATES_ANON = env.str("THROTTLE_RATES_ANON", default="3600/hour") +THROTTLE_RATES_USER_HIGH = env.str("THROTTLE_RATES_USER_HIGH", default="18000/hour") +THROTTLE_RATES_USER_MEDIUM = env.str("THROTTLE_RATES_USER_MEDIUM", default="14400/hour") +THROTTLE_RATES_USER_LOW = env.str("THROTTLE_RATES_USER_LOW", default="10800/hour") + +REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = { + "anon": THROTTLE_RATES_ANON, + "low": THROTTLE_RATES_USER_LOW, + "medium": THROTTLE_RATES_USER_MEDIUM, + "high": THROTTLE_RATES_USER_HIGH, +} if IS_TESTS: @@ -235,7 +245,6 @@ ), "DEFAULT_THROTTLE_CLASSES": [ "vulnerabilities.throttling.PermissionBasedUserRateThrottle", - "rest_framework.throttling.AnonRateThrottle", ], "DEFAULT_THROTTLE_RATES": REST_FRAMEWORK_DEFAULT_THROTTLE_RATES, "EXCEPTION_HANDLER": "vulnerabilities.throttling.throttled_exception_handler", From 722935f3b39bcba4abee0b848afe21f94c8b1abf Mon Sep 17 00:00:00 2001 From: Keshav Priyadarshi Date: Thu, 26 Jun 2025 15:01:54 +0530 Subject: [PATCH 10/11] Test throttling behavior for user in group Signed-off-by: Keshav Priyadarshi --- vulnerabilities/tests/test_throttling.py | 22 ++++++++++++++++++++++ vulnerablecode/settings.py | 16 ++++++++-------- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/vulnerabilities/tests/test_throttling.py b/vulnerabilities/tests/test_throttling.py index 25af231f8..4ff83c70d 100644 --- a/vulnerabilities/tests/test_throttling.py +++ b/vulnerabilities/tests/test_throttling.py @@ -9,6 +9,7 @@ import json +from django.contrib.auth.models import Group from django.contrib.auth.models import Permission from django.core.cache import cache from rest_framework import status @@ -77,6 +78,16 @@ def setUp(self): HTTP_AUTHORIZATION=self.th_unrestricted_user_auth ) + # unrestricted throttling for group user + group, _ = Group.objects.get_or_create(name="Test Unrestricted") + group.permissions.add(permission_unrestricted) + + self.th_group_user = ApiUser.objects.create_api_user(username="g@mail.com") + self.th_group_user.groups.add(group) + self.th_group_user_auth = f"Token {self.th_group_user.auth_token.key}" + self.th_group_user_csrf_client = APIClient(enforce_csrf_checks=True) + self.th_group_user_csrf_client.credentials(HTTP_AUTHORIZATION=self.th_group_user_auth) + self.csrf_client_anon = APIClient(enforce_csrf_checks=True) self.csrf_client_anon_1 = APIClient(enforce_csrf_checks=True) @@ -147,6 +158,17 @@ def test_user_with_unrestricted_perm_throttling(self): response = self.th_unrestricted_user_csrf_client.get("/api/packages") self.assertEqual(response.status_code, status.HTTP_200_OK) + def test_user_in_group_with_unrestricted_perm_throttling(self): + simulate_throttle_usage( + url="/api/packages", + client=self.th_group_user_csrf_client, + mock_use_count=20000, + ) + + # no throttling for user in group with unrestricted perm. + response = self.th_group_user_csrf_client.get("/api/packages") + self.assertEqual(response.status_code, status.HTTP_200_OK) + def test_anon_throttling(self): simulate_throttle_usage( url="/api/packages", diff --git a/vulnerablecode/settings.py b/vulnerablecode/settings.py index b1a51c0a7..0d4207b23 100644 --- a/vulnerablecode/settings.py +++ b/vulnerablecode/settings.py @@ -190,16 +190,16 @@ LOGIN_REDIRECT_URL = "/" LOGOUT_REDIRECT_URL = "/" -THROTTLE_RATES_ANON = env.str("THROTTLE_RATES_ANON", default="3600/hour") -THROTTLE_RATES_USER_HIGH = env.str("THROTTLE_RATES_USER_HIGH", default="18000/hour") -THROTTLE_RATES_USER_MEDIUM = env.str("THROTTLE_RATES_USER_MEDIUM", default="14400/hour") -THROTTLE_RATES_USER_LOW = env.str("THROTTLE_RATES_USER_LOW", default="10800/hour") +THROTTLE_RATE_ANON = env.str("THROTTLE_RATE_ANON", default="3600/hour") +THROTTLE_RATE_USER_HIGH = env.str("THROTTLE_RATE_USER_HIGH", default="18000/hour") +THROTTLE_RATE_USER_MEDIUM = env.str("THROTTLE_RATE_USER_MEDIUM", default="14400/hour") +THROTTLE_RATE_USER_LOW = env.str("THROTTLE_RATE_USER_LOW", default="10800/hour") REST_FRAMEWORK_DEFAULT_THROTTLE_RATES = { - "anon": THROTTLE_RATES_ANON, - "low": THROTTLE_RATES_USER_LOW, - "medium": THROTTLE_RATES_USER_MEDIUM, - "high": THROTTLE_RATES_USER_HIGH, + "anon": THROTTLE_RATE_ANON, + "low": THROTTLE_RATE_USER_LOW, + "medium": THROTTLE_RATE_USER_MEDIUM, + "high": THROTTLE_RATE_USER_HIGH, } From be5edc35298dda9a2fd9eb5b540708909099afae Mon Sep 17 00:00:00 2001 From: Keshav Priyadarshi Date: Tue, 1 Jul 2025 17:45:45 +0530 Subject: [PATCH 11/11] Resolve migration conflicts Signed-off-by: Keshav Priyadarshi --- ...alter_apiuser_options.py => 0095_alter_apiuser_options.py} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename vulnerabilities/migrations/{0093_alter_apiuser_options.py => 0095_alter_apiuser_options.py} (87%) diff --git a/vulnerabilities/migrations/0093_alter_apiuser_options.py b/vulnerabilities/migrations/0095_alter_apiuser_options.py similarity index 87% rename from vulnerabilities/migrations/0093_alter_apiuser_options.py rename to vulnerabilities/migrations/0095_alter_apiuser_options.py index 61c5b183d..2f30298a4 100644 --- a/vulnerabilities/migrations/0093_alter_apiuser_options.py +++ b/vulnerabilities/migrations/0095_alter_apiuser_options.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.22 on 2025-06-25 18:56 +# Generated by Django 4.2.22 on 2025-07-01 11:59 from django.db import migrations @@ -6,7 +6,7 @@ class Migration(migrations.Migration): dependencies = [ - ("vulnerabilities", "0092_pipelineschedule_pipelinerun"), + ("vulnerabilities", "0094_advisoryalias_advisoryreference_advisoryseverity_and_more"), ] operations = [