diff --git a/vulnerabilities/api_v2.py b/vulnerabilities/api_v2.py index ba41d0906..46ad2b80b 100644 --- a/vulnerabilities/api_v2.py +++ b/vulnerabilities/api_v2.py @@ -383,6 +383,10 @@ class PackageV2FilterSet(filters.FilterSet): ) fixing_vulnerability = filters.CharFilter(field_name="fixing_vulnerabilities__vulnerability_id") purl = filters.CharFilter(field_name="package_url") + is_vulnerable = filters.BooleanFilter(method="filter_is_vulnerable") + + def filter_is_vulnerable(self, queryset, name, value): + return queryset.filter(is_vulnerable=value) class AdvisoryPackageV2FilterSet(filters.FilterSet): @@ -424,6 +428,7 @@ def get_queryset(self): package_purls = self.request.query_params.getlist("purl") affected_by_vulnerability = self.request.query_params.get("affected_by_vulnerability") fixing_vulnerability = self.request.query_params.get("fixing_vulnerability") + is_vulnerable = self.request.query_params.get("is_vulnerable") if package_purls: queryset = queryset.filter(package_url__in=package_purls) @@ -435,6 +440,10 @@ def get_queryset(self): queryset = queryset.filter( fixing_vulnerabilities__vulnerability_id=fixing_vulnerability ) + if is_vulnerable is not None: + queryset = queryset.with_is_vulnerable() + is_vulnerable = is_vulnerable.lower() == "true" + queryset = queryset.filter(is_vulnerable=is_vulnerable) return queryset.with_is_vulnerable() def list(self, request, *args, **kwargs): diff --git a/vulnerabilities/forms.py b/vulnerabilities/forms.py index 03829cd52..47399f70b 100644 --- a/vulnerabilities/forms.py +++ b/vulnerabilities/forms.py @@ -23,6 +23,14 @@ class PackageSearchForm(forms.Form): attrs={"placeholder": "Package name, purl or purl fragment"}, ), ) + vulnerable_only = forms.ChoiceField( + required=False, + choices=[ + ("", "All Packages"), + ("true", "Vulnerable Only"), + ("false", "Non-Vulnerable Only"), + ], + ) class VulnerabilitySearchForm(forms.Form): diff --git a/vulnerabilities/templates/packages.html b/vulnerabilities/templates/packages.html index 1f7687429..bfc848baf 100644 --- a/vulnerabilities/templates/packages.html +++ b/vulnerabilities/templates/packages.html @@ -18,6 +18,37 @@
{{ page_obj.paginator.count|intcomma }} results
+ {% if is_paginated %} {% include 'includes/pagination.html' with page_obj=page_obj %} {% endif %} @@ -81,4 +112,4 @@ {% endif %} -{% endblock %} +{% endblock %} \ No newline at end of file diff --git a/vulnerabilities/tests/test_api_v2.py b/vulnerabilities/tests/test_api_v2.py index 662499ed9..138eb73f1 100644 --- a/vulnerabilities/tests/test_api_v2.py +++ b/vulnerabilities/tests/test_api_v2.py @@ -237,6 +237,27 @@ def test_list_packages(self): all(vuln_id in response.data["results"]["vulnerabilities"] for vuln_id in package_vulns) ) + def test_filter_packages_by_vulnerability_status(self): + vulnerability = Vulnerability.objects.create( + vulnerability_id="VCID-FILTER", summary="Test vulnerability for is_vulnerable filter" + ) + self.package1.affected_by_vulnerabilities.add(vulnerability) + url = reverse("package-v2-list") + response = self.client.get(url, {"is_vulnerable": "true"}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("results", response.data) + self.assertIn("packages", response.data["results"]) + package_purls = [pkg["purl"] for pkg in response.data["results"]["packages"]] + self.assertIn(self.package1.package_url, package_purls) + self.assertNotIn(self.package2.package_url, package_purls) + response = self.client.get(url, {"is_vulnerable": "false"}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("results", response.data) + self.assertIn("packages", response.data["results"]) + package_purls = [pkg["purl"] for pkg in response.data["results"]["packages"]] + self.assertNotIn(self.package1.package_url, package_purls) + self.assertIn(self.package2.package_url, package_purls) + def test_filter_packages_by_purl(self): """ Test filtering packages by one or more PURLs. diff --git a/vulnerabilities/tests/test_view.py b/vulnerabilities/tests/test_view.py index 471e0bf43..83f76b0a1 100644 --- a/vulnerabilities/tests/test_view.py +++ b/vulnerabilities/tests/test_view.py @@ -77,6 +77,29 @@ def test_package_detail_view(self): package = PackageDetails(kwargs={"purl": "pkg:nginx/nginx@1.0.15"}).get_object() assert package.purl == "pkg:nginx/nginx@1.0.15" + def test_package_vulnerability_filter(self): + vulnerability = Vulnerability.objects.create( + vulnerability_id="VCID-TEST", summary="Test Vulnerability for filtering" + ) + vulnerable_package = Package.objects.get(package_url="pkg:nginx/nginx@1.20.0") + AffectedByPackageRelatedVulnerability.objects.create( + vulnerability=vulnerability, package=vulnerable_package, created_by="test" + ) + response = self.client.get("/packages/search?search=nginx&vulnerable_only=true") + self.assertEqual(response.status_code, 200) + self.assertIn(vulnerable_package.purl, str(response.content)) + self.assertNotIn("pkg:nginx/nginx@1.21.0", str(response.content)) + + response = self.client.get("/packages/search?search=nginx&vulnerable_only=false") + self.assertEqual(response.status_code, 200) + self.assertNotIn(vulnerable_package.purl, str(response.content)) + self.assertIn("pkg:nginx/nginx@1.21.0", str(response.content)) + + response = self.client.get("/packages/search?search=nginx") + self.assertEqual(response.status_code, 200) + self.assertIn(vulnerable_package.purl, str(response.content)) + self.assertIn("pkg:nginx/nginx@1.21.0", str(response.content)) + def test_package_view_with_purl_fragment(self): qs = PackageSearch().get_queryset(query="nginx@1.0.15") pkgs = list(qs) diff --git a/vulnerabilities/views.py b/vulnerabilities/views.py index f4cd99dbe..97a4a5b46 100644 --- a/vulnerabilities/views.py +++ b/vulnerabilities/views.py @@ -64,13 +64,27 @@ def get_queryset(self, query=None): Make a best effort approach to find matching packages either based on exact purl, partial purl or just name and namespace. """ - query = query or self.request.GET.get("search") or "" - return ( + if query is not None: + queryset = ( + self.model.objects.search(query) + .with_vulnerability_counts() + .prefetch_related() + .order_by("package_url") + ) + return queryset + query = self.request.GET.get("search") or "" + queryset = ( self.model.objects.search(query) .with_vulnerability_counts() .prefetch_related() .order_by("package_url") ) + vulnerable_only = self.request.GET.get("vulnerable_only", "") + if vulnerable_only in ["true", "false"]: + queryset = queryset.with_is_vulnerable() + queryset = queryset.filter(is_vulnerable=vulnerable_only == "true") + + return queryset class PackageSearchV2(ListView):