Skip to content
Merged
50 changes: 50 additions & 0 deletions vulnerabilities/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@

from django import forms
from django.contrib import admin
from django.contrib.admin.widgets import FilteredSelectMultiple
from django.contrib.auth.admin import GroupAdmin as BasicGroupAdmin
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.core.validators import validate_email

from vulnerabilities.models import ApiUser
Expand Down Expand Up @@ -97,3 +101,49 @@ def get_form(self, request, obj=None, **kwargs):
defaults["form"] = self.add_form
defaults.update(kwargs)
return super().get_form(request, obj, **defaults)


class GroupWithUsersForm(forms.ModelForm):
users = forms.ModelMultipleChoiceField(
queryset=User.objects.all(),
required=False,
widget=FilteredSelectMultiple("Users", is_stacked=False),
label="Users",
)

class Meta:
model = Group
fields = "__all__"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fields["users"].label_from_instance = lambda user: (
f"{user.username} | {user.email}" if user.email else user.username
)
if self.instance.pk:
self.fields["users"].initial = self.instance.user_set.all()

def save(self, commit=True):
group = super().save(commit=commit)
self.save_m2m()
group.user_set.set(self.cleaned_data["users"])
return group


admin.site.unregister(Group)


@admin.register(Group)
class GroupAdmin(admin.ModelAdmin):
form = GroupWithUsersForm
search_fields = ("name",)
ordering = ("name",)
filter_horizontal = ("permissions",)

def formfield_for_manytomany(self, db_field, request=None, **kwargs):
if db_field.name == "permissions":
qs = kwargs.get("queryset", db_field.remote_field.model.objects)
# Avoid a major performance hit resolving permission names which
# triggers a content_type load:
kwargs["queryset"] = qs.select_related("content_type")
return super().formfield_for_manytomany(db_field, request=request, **kwargs)
7 changes: 3 additions & 4 deletions vulnerabilities/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.throttling import AnonRateThrottle

from vulnerabilities.models import Alias
from vulnerabilities.models import Exploit
Expand All @@ -34,7 +33,7 @@
from vulnerabilities.models import get_purl_query_lookups
from vulnerabilities.severity_systems import EPSS
from vulnerabilities.severity_systems import SCORING_SYSTEMS
from vulnerabilities.throttling import StaffUserRateThrottle
from vulnerabilities.throttling import PermissionBasedUserRateThrottle
from vulnerabilities.utils import get_severity_range


Expand Down Expand Up @@ -471,7 +470,7 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet):
serializer_class = PackageSerializer
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = PackageFilterSet
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
throttle_classes = [PermissionBasedUserRateThrottle]

def get_queryset(self):
return super().get_queryset().with_is_vulnerable()
Expand Down Expand Up @@ -688,7 +687,7 @@ def get_queryset(self):
serializer_class = VulnerabilitySerializer
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = VulnerabilityFilterSet
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
throttle_classes = [PermissionBasedUserRateThrottle]


class CPEFilterSet(filters.FilterSet):
Expand Down
11 changes: 5 additions & 6 deletions vulnerabilities/api_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from rest_framework.serializers import ModelSerializer
from rest_framework.serializers import Serializer
from rest_framework.serializers import ValidationError
from rest_framework.throttling import AnonRateThrottle

from vulnerabilities.api import BaseResourceSerializer
from vulnerabilities.models import Exploit
Expand All @@ -33,7 +32,7 @@
from vulnerabilities.models import VulnerabilitySeverity
from vulnerabilities.models import Weakness
from vulnerabilities.models import get_purl_query_lookups
from vulnerabilities.throttling import StaffUserRateThrottle
from vulnerabilities.throttling import PermissionBasedUserRateThrottle


class SerializerExcludeFieldsMixin:
Expand Down Expand Up @@ -259,7 +258,7 @@ class V2PackageViewSet(viewsets.ReadOnlyModelViewSet):
lookup_field = "purl"
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = V2PackageFilterSet
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
throttle_classes = [PermissionBasedUserRateThrottle]

def get_queryset(self):
return super().get_queryset().with_is_vulnerable().prefetch_related("vulnerabilities")
Expand Down Expand Up @@ -345,7 +344,7 @@ class VulnerabilityViewSet(viewsets.ReadOnlyModelViewSet):
lookup_field = "vulnerability_id"
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = V2VulnerabilityFilterSet
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
throttle_classes = [PermissionBasedUserRateThrottle]

def get_queryset(self):
"""
Expand Down Expand Up @@ -381,7 +380,7 @@ class CPEViewSet(viewsets.ReadOnlyModelViewSet):
).distinct()
serializer_class = V2VulnerabilitySerializer
filter_backends = (filters.DjangoFilterBackend,)
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
throttle_classes = [PermissionBasedUserRateThrottle]
filterset_class = CPEFilterSet

@action(detail=False, methods=["post"])
Expand Down Expand Up @@ -420,4 +419,4 @@ class AliasViewSet(viewsets.ReadOnlyModelViewSet):
serializer_class = V2VulnerabilitySerializer
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = AliasFilterSet
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
throttle_classes = [PermissionBasedUserRateThrottle]
6 changes: 6 additions & 0 deletions vulnerabilities/api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from rest_framework.permissions import BasePermission
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.throttling import AnonRateThrottle

from vulnerabilities.models import AdvisoryReference
from vulnerabilities.models import AdvisorySeverity
Expand All @@ -38,6 +39,7 @@
from vulnerabilities.models import VulnerabilityReference
from vulnerabilities.models import VulnerabilitySeverity
from vulnerabilities.models import Weakness
from vulnerabilities.throttling import PermissionBasedUserRateThrottle


class WeaknessV2Serializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -199,6 +201,7 @@ class VulnerabilityV2ViewSet(viewsets.ReadOnlyModelViewSet):
queryset = Vulnerability.objects.all()
serializer_class = VulnerabilityV2Serializer
lookup_field = "vulnerability_id"
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]

def get_queryset(self):
queryset = super().get_queryset()
Expand Down Expand Up @@ -394,6 +397,7 @@ class PackageV2ViewSet(viewsets.ReadOnlyModelViewSet):
serializer_class = PackageV2Serializer
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = PackageV2FilterSet
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]

def get_queryset(self):
queryset = super().get_queryset()
Expand Down Expand Up @@ -721,6 +725,7 @@ class CodeFixViewSet(viewsets.ReadOnlyModelViewSet):

queryset = CodeFix.objects.all()
serializer_class = CodeFixSerializer
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]

def get_queryset(self):
"""
Expand Down Expand Up @@ -863,6 +868,7 @@ class PipelineScheduleV2ViewSet(CreateListRetrieveUpdateViewSet):
serializer_class = PipelineScheduleAPISerializer
lookup_field = "pipeline_id"
lookup_value_regex = r"[\w.]+"
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]

def get_serializer_class(self):
if self.action == "create":
Expand Down
36 changes: 36 additions & 0 deletions vulnerabilities/migrations/0095_alter_apiuser_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Generated by Django 4.2.22 on 2025-07-01 11:59

from django.db import migrations


class Migration(migrations.Migration):

dependencies = [
("vulnerabilities", "0094_advisoryalias_advisoryreference_advisoryseverity_and_more"),
]

operations = [
migrations.AlterModelOptions(
name="apiuser",
options={
"permissions": [
(
"throttle_3_unrestricted",
"Can make unlimited API requests without any throttling limits",
),
(
"throttle_2_high",
"Can make high number of API requests with minimal throttling",
),
(
"throttle_1_medium",
"Can make medium number of API requests with standard throttling",
),
(
"throttle_0_low",
"Can make low number of API requests with strict throttling",
),
]
},
),
]
23 changes: 20 additions & 3 deletions vulnerabilities/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from cwe2.mappings import xml_database_path
from cwe2.weakness import Weakness as DBWeakness
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.contrib.auth.models import UserManager
from django.core import exceptions
from django.core.exceptions import ValidationError
Expand Down Expand Up @@ -1489,14 +1490,30 @@ def _validate_username(self, email):


class ApiUser(UserModel):
"""
A User proxy model to facilitate simplified admin API user creation.
"""
"""A User proxy model to facilitate simplified admin API user creation."""

objects = ApiUserManager()

class Meta:
proxy = True
permissions = [
(
"throttle_3_unrestricted",
"Can make unlimited API requests without any throttling limits",
),
(
"throttle_2_high",
"Can make high number of API requests with minimal throttling",
),
(
"throttle_1_medium",
"Can make medium number of API requests with standard throttling",
),
(
"throttle_0_low",
"Can make low number of API requests with strict throttling",
),
]


class ChangeLog(models.Model):
Expand Down
25 changes: 11 additions & 14 deletions vulnerabilities/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
from urllib.parse import quote

from django.core.cache import cache
from django.test import TestCase
from django.test import TransactionTestCase
from django.test.client import RequestFactory
Expand Down Expand Up @@ -452,10 +453,8 @@ def add_aliases(vuln, aliases):

class APIPerformanceTest(TestCase):
def setUp(self):
self.user = ApiUser.objects.create_api_user(username="e@mail.com")
self.auth = f"Token {self.user.auth_token.key}"
cache.clear()
self.csrf_client = APIClient(enforce_csrf_checks=True)
self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth)

# This setup creates the following data:
# vulnerabilities: vul1, vul2, vul3
Expand Down Expand Up @@ -503,7 +502,7 @@ def setUp(self):
set_as_fixing(package=self.pkg_2_13_2, vulnerability=self.vul1)

def test_api_packages_all_num_queries(self):
with self.assertNumQueries(4):
with self.assertNumQueries(3):
# There are 4 queries:
# 1. SAVEPOINT
# 2. Authenticating user
Expand All @@ -519,22 +518,22 @@ def test_api_packages_all_num_queries(self):
]

def test_api_packages_single_num_queries(self):
with self.assertNumQueries(8):
with self.assertNumQueries(7):
self.csrf_client.get(f"/api/packages/{self.pkg_2_14_0_rc1.id}", format="json")

def test_api_packages_single_with_purl_in_query_num_queries(self):
with self.assertNumQueries(9):
with self.assertNumQueries(8):
self.csrf_client.get(f"/api/packages/?purl={self.pkg_2_14_0_rc1.purl}", format="json")

def test_api_packages_single_with_purl_no_version_in_query_num_queries(self):
with self.assertNumQueries(64):
with self.assertNumQueries(63):
self.csrf_client.get(
f"/api/packages/?purl=pkg:maven/com.fasterxml.jackson.core/jackson-databind",
format="json",
)

def test_api_packages_bulk_search(self):
with self.assertNumQueries(45):
with self.assertNumQueries(44):
packages = [self.pkg_2_12_6, self.pkg_2_12_6_1, self.pkg_2_13_1]
purls = [p.purl for p in packages]

Expand All @@ -547,7 +546,7 @@ def test_api_packages_bulk_search(self):
).json()

def test_api_packages_with_lookup(self):
with self.assertNumQueries(14):
with self.assertNumQueries(13):
data = {"purl": self.pkg_2_12_6.purl}

resp = self.csrf_client.post(
Expand All @@ -557,7 +556,7 @@ def test_api_packages_with_lookup(self):
).json()

def test_api_packages_bulk_lookup(self):
with self.assertNumQueries(45):
with self.assertNumQueries(44):
packages = [self.pkg_2_12_6, self.pkg_2_12_6_1, self.pkg_2_13_1]
purls = [p.purl for p in packages]

Expand All @@ -572,10 +571,8 @@ def test_api_packages_bulk_lookup(self):

class APITestCasePackage(TestCase):
def setUp(self):
self.user = ApiUser.objects.create_api_user(username="e@mail.com")
self.auth = f"Token {self.user.auth_token.key}"
cache.clear()
self.csrf_client = APIClient(enforce_csrf_checks=True)
self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth)

# This setup creates the following data:
# vulnerabilities: vul1, vul2, vul3
Expand Down Expand Up @@ -766,7 +763,7 @@ def test_api_with_wrong_namespace_filter(self):
self.assertEqual(response["count"], 0)

def test_api_with_all_vulnerable_packages(self):
with self.assertNumQueries(4):
with self.assertNumQueries(3):
# There are 4 queries:
# 1. SAVEPOINT
# 2. Authenticating user
Expand Down
Loading