|
8 | 8 | # |
9 | 9 |
|
10 | 10 | from django.db.models import Prefetch |
| 11 | +from django.test import TestCase |
11 | 12 | from django.urls import reverse |
| 13 | +from django.utils import timezone |
12 | 14 | from packageurl import PackageURL |
13 | 15 | from rest_framework import status |
14 | 16 | from rest_framework.test import APIClient |
15 | 17 | from rest_framework.test import APITestCase |
16 | 18 |
|
17 | 19 | from vulnerabilities.api_v2 import PackageV2Serializer |
18 | 20 | from vulnerabilities.api_v2 import VulnerabilityListSerializer |
| 21 | +from vulnerabilities.models import Advisory |
19 | 22 | from vulnerabilities.models import Alias |
20 | 23 | from vulnerabilities.models import ApiUser |
21 | 24 | from vulnerabilities.models import Package |
@@ -662,3 +665,59 @@ def test_lookup_with_invalid_purl_format(self): |
662 | 665 | self.assertEqual(response.status_code, status.HTTP_200_OK) |
663 | 666 | # No packages or vulnerabilities should be returned |
664 | 667 | self.assertEqual(len(response.data), 0) |
| 668 | + |
| 669 | + |
| 670 | +class AdvisoryAPITest(TestCase): |
| 671 | + def setUp(self): |
| 672 | + self.user = ApiUser.objects.create_api_user(username="test@test.com") |
| 673 | + self.auth = f"Token {self.user.auth_token.key}" |
| 674 | + self.client = APIClient(enforce_csrf_checks=True) |
| 675 | + self.client.credentials(HTTP_AUTHORIZATION=self.auth) |
| 676 | + |
| 677 | + self.now = timezone.now() |
| 678 | + self.advisories = [] |
| 679 | + for i in range(10): |
| 680 | + advisory = Advisory.objects.create( |
| 681 | + aliases=[f"CVE-2020-{i}"], |
| 682 | + summary=f"Test Advisory {i}", |
| 683 | + affected_packages=[{"package_url": f"pkg:npm/package{i}@1.0.0"}], |
| 684 | + references=[{"url": f"https://example.com/vuln/{i}"}], |
| 685 | + date_published=self.now, |
| 686 | + date_collected=self.now, |
| 687 | + created_by="test_importer", |
| 688 | + url=f"https://example.com/{i}", |
| 689 | + ) |
| 690 | + self.advisories.append(advisory) |
| 691 | + |
| 692 | + def test_advisory_list(self): |
| 693 | + with self.assertNumQueries(5): # save + auth + count + data + release |
| 694 | + response = self.client.get("/api/v2/advisories/", format="json") |
| 695 | + self.assertEqual(200, response.status_code) |
| 696 | + data = response.json() |
| 697 | + self.assertEqual(10, data["count"]) |
| 698 | + self.assertEqual(10, len(data["results"])) |
| 699 | + |
| 700 | + first_result = data["results"][0] |
| 701 | + expected_fields = { |
| 702 | + "aliases", |
| 703 | + "summary", |
| 704 | + "affected_packages", |
| 705 | + "references", |
| 706 | + "date_published", |
| 707 | + "url", |
| 708 | + } |
| 709 | + self.assertEqual(expected_fields, set(first_result.keys())) |
| 710 | + |
| 711 | + def test_advisory_pagination(self): |
| 712 | + with self.assertNumQueries(5): |
| 713 | + response = self.client.get("/api/v2/advisories/?page_size=5", format="json") |
| 714 | + self.assertEqual(200, response.status_code) |
| 715 | + data = response.json() |
| 716 | + self.assertEqual(10, data["count"]) |
| 717 | + self.assertEqual(5, len(data["results"])) |
| 718 | + self.assertIsNotNone(data["next"]) |
| 719 | + self.assertIsNone(data["previous"]) |
| 720 | + |
| 721 | + def test_advisory_invalid_page(self): |
| 722 | + response = self.client.get("/api/v2/advisories/?page=999", format="json") |
| 723 | + self.assertEqual(404, response.status_code) |
0 commit comments