diff --git a/vulnerabilities/api_v2.py b/vulnerabilities/api_v2.py
index ba41d0906..46ad2b80b 100644
--- a/vulnerabilities/api_v2.py
+++ b/vulnerabilities/api_v2.py
@@ -383,6 +383,10 @@ class PackageV2FilterSet(filters.FilterSet):
)
fixing_vulnerability = filters.CharFilter(field_name="fixing_vulnerabilities__vulnerability_id")
purl = filters.CharFilter(field_name="package_url")
+ is_vulnerable = filters.BooleanFilter(method="filter_is_vulnerable")
+
+ def filter_is_vulnerable(self, queryset, name, value):
+ return queryset.filter(is_vulnerable=value)
class AdvisoryPackageV2FilterSet(filters.FilterSet):
@@ -424,6 +428,7 @@ def get_queryset(self):
package_purls = self.request.query_params.getlist("purl")
affected_by_vulnerability = self.request.query_params.get("affected_by_vulnerability")
fixing_vulnerability = self.request.query_params.get("fixing_vulnerability")
+ is_vulnerable = self.request.query_params.get("is_vulnerable")
if package_purls:
queryset = queryset.filter(package_url__in=package_purls)
@@ -435,6 +440,10 @@ def get_queryset(self):
queryset = queryset.filter(
fixing_vulnerabilities__vulnerability_id=fixing_vulnerability
)
+ if is_vulnerable is not None:
+ queryset = queryset.with_is_vulnerable()
+ is_vulnerable = is_vulnerable.lower() == "true"
+ queryset = queryset.filter(is_vulnerable=is_vulnerable)
return queryset.with_is_vulnerable()
def list(self, request, *args, **kwargs):
diff --git a/vulnerabilities/forms.py b/vulnerabilities/forms.py
index 03829cd52..47399f70b 100644
--- a/vulnerabilities/forms.py
+++ b/vulnerabilities/forms.py
@@ -23,6 +23,14 @@ class PackageSearchForm(forms.Form):
attrs={"placeholder": "Package name, purl or purl fragment"},
),
)
+ vulnerable_only = forms.ChoiceField(
+ required=False,
+ choices=[
+ ("", "All Packages"),
+ ("true", "Vulnerable Only"),
+ ("false", "Non-Vulnerable Only"),
+ ],
+ )
class VulnerabilitySearchForm(forms.Form):
diff --git a/vulnerabilities/templates/packages.html b/vulnerabilities/templates/packages.html
index 1f7687429..bfc848baf 100644
--- a/vulnerabilities/templates/packages.html
+++ b/vulnerabilities/templates/packages.html
@@ -18,6 +18,37 @@
{{ page_obj.paginator.count|intcomma }} results
+
+
+
+
+
+
{% if is_paginated %}
{% include 'includes/pagination.html' with page_obj=page_obj %}
{% endif %}
@@ -81,4 +112,4 @@
{% endif %}
-{% endblock %}
+{% endblock %}
\ No newline at end of file
diff --git a/vulnerabilities/tests/test_api_v2.py b/vulnerabilities/tests/test_api_v2.py
index 662499ed9..138eb73f1 100644
--- a/vulnerabilities/tests/test_api_v2.py
+++ b/vulnerabilities/tests/test_api_v2.py
@@ -237,6 +237,27 @@ def test_list_packages(self):
all(vuln_id in response.data["results"]["vulnerabilities"] for vuln_id in package_vulns)
)
+ def test_filter_packages_by_vulnerability_status(self):
+ vulnerability = Vulnerability.objects.create(
+ vulnerability_id="VCID-FILTER", summary="Test vulnerability for is_vulnerable filter"
+ )
+ self.package1.affected_by_vulnerabilities.add(vulnerability)
+ url = reverse("package-v2-list")
+ response = self.client.get(url, {"is_vulnerable": "true"}, format="json")
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn("results", response.data)
+ self.assertIn("packages", response.data["results"])
+ package_purls = [pkg["purl"] for pkg in response.data["results"]["packages"]]
+ self.assertIn(self.package1.package_url, package_purls)
+ self.assertNotIn(self.package2.package_url, package_purls)
+ response = self.client.get(url, {"is_vulnerable": "false"}, format="json")
+ self.assertEqual(response.status_code, status.HTTP_200_OK)
+ self.assertIn("results", response.data)
+ self.assertIn("packages", response.data["results"])
+ package_purls = [pkg["purl"] for pkg in response.data["results"]["packages"]]
+ self.assertNotIn(self.package1.package_url, package_purls)
+ self.assertIn(self.package2.package_url, package_purls)
+
def test_filter_packages_by_purl(self):
"""
Test filtering packages by one or more PURLs.
diff --git a/vulnerabilities/tests/test_view.py b/vulnerabilities/tests/test_view.py
index 471e0bf43..83f76b0a1 100644
--- a/vulnerabilities/tests/test_view.py
+++ b/vulnerabilities/tests/test_view.py
@@ -77,6 +77,29 @@ def test_package_detail_view(self):
package = PackageDetails(kwargs={"purl": "pkg:nginx/nginx@1.0.15"}).get_object()
assert package.purl == "pkg:nginx/nginx@1.0.15"
+ def test_package_vulnerability_filter(self):
+ vulnerability = Vulnerability.objects.create(
+ vulnerability_id="VCID-TEST", summary="Test Vulnerability for filtering"
+ )
+ vulnerable_package = Package.objects.get(package_url="pkg:nginx/nginx@1.20.0")
+ AffectedByPackageRelatedVulnerability.objects.create(
+ vulnerability=vulnerability, package=vulnerable_package, created_by="test"
+ )
+ response = self.client.get("/packages/search?search=nginx&vulnerable_only=true")
+ self.assertEqual(response.status_code, 200)
+ self.assertIn(vulnerable_package.purl, str(response.content))
+ self.assertNotIn("pkg:nginx/nginx@1.21.0", str(response.content))
+
+ response = self.client.get("/packages/search?search=nginx&vulnerable_only=false")
+ self.assertEqual(response.status_code, 200)
+ self.assertNotIn(vulnerable_package.purl, str(response.content))
+ self.assertIn("pkg:nginx/nginx@1.21.0", str(response.content))
+
+ response = self.client.get("/packages/search?search=nginx")
+ self.assertEqual(response.status_code, 200)
+ self.assertIn(vulnerable_package.purl, str(response.content))
+ self.assertIn("pkg:nginx/nginx@1.21.0", str(response.content))
+
def test_package_view_with_purl_fragment(self):
qs = PackageSearch().get_queryset(query="nginx@1.0.15")
pkgs = list(qs)
diff --git a/vulnerabilities/views.py b/vulnerabilities/views.py
index f4cd99dbe..97a4a5b46 100644
--- a/vulnerabilities/views.py
+++ b/vulnerabilities/views.py
@@ -64,13 +64,27 @@ def get_queryset(self, query=None):
Make a best effort approach to find matching packages either based
on exact purl, partial purl or just name and namespace.
"""
- query = query or self.request.GET.get("search") or ""
- return (
+ if query is not None:
+ queryset = (
+ self.model.objects.search(query)
+ .with_vulnerability_counts()
+ .prefetch_related()
+ .order_by("package_url")
+ )
+ return queryset
+ query = self.request.GET.get("search") or ""
+ queryset = (
self.model.objects.search(query)
.with_vulnerability_counts()
.prefetch_related()
.order_by("package_url")
)
+ vulnerable_only = self.request.GET.get("vulnerable_only", "")
+ if vulnerable_only in ["true", "false"]:
+ queryset = queryset.with_is_vulnerable()
+ queryset = queryset.filter(is_vulnerable=vulnerable_only == "true")
+
+ return queryset
class PackageSearchV2(ListView):