99
1010import dataclasses
1111import datetime
12+ import functools
1213import logging
1314import os
1415import shutil
4647logger = logging .getLogger (__name__ )
4748
4849
49- @dataclasses .dataclass (order = True )
50+ @dataclasses .dataclass (eq = True )
51+ @functools .total_ordering
5052class VulnerabilitySeverity :
5153 # FIXME: this should be named scoring_system, like in the model
5254 system : ScoringSystem
@@ -55,15 +57,26 @@ class VulnerabilitySeverity:
5557 published_at : Optional [datetime .datetime ] = None
5658
5759 def to_dict (self ):
58- published_at_dict = (
59- {"published_at" : self .published_at .isoformat ()} if self .published_at else {}
60- )
61- return {
60+ data = {
6261 "system" : self .system .identifier ,
6362 "value" : self .value ,
6463 "scoring_elements" : self .scoring_elements ,
65- ** published_at_dict ,
6664 }
65+ if self .published_at :
66+ if isinstance (self .published_at , datetime .datetime ):
67+ data ["published_at" ] = self .published_at .isoformat ()
68+ else :
69+ data ["published_at" ] = self .published_at
70+ return data
71+
72+ def __lt__ (self , other ):
73+ if not isinstance (other , VulnerabilitySeverity ):
74+ return NotImplemented
75+ return self ._cmp_key () < other ._cmp_key ()
76+
77+ # TODO: Add cache
78+ def _cmp_key (self ):
79+ return (self .system .identifier , self .value , self .scoring_elements , self .published_at )
6780
6881 @classmethod
6982 def from_dict (cls , severity : dict ):
@@ -79,7 +92,8 @@ def from_dict(cls, severity: dict):
7992 )
8093
8194
82- @dataclasses .dataclass (order = True )
95+ @dataclasses .dataclass (eq = True )
96+ @functools .total_ordering
8397class Reference :
8498 reference_id : str = ""
8599 reference_type : str = ""
@@ -90,27 +104,28 @@ def __post_init__(self):
90104 if not self .url :
91105 raise TypeError ("Reference must have a url" )
92106
93- def normalized (self ):
94- severities = sorted ( self . severities )
95- return Reference (
96- reference_id = self . reference_id ,
97- url = self . url ,
98- severities = severities ,
99- reference_type = self . reference_type ,
100- )
107+ def __lt__ (self , other ):
108+ if not isinstance ( other , Reference ):
109+ return NotImplemented
110+ return self . _cmp_key () < other . _cmp_key ()
111+
112+ # TODO: Add cache
113+ def _cmp_key ( self ):
114+ return ( self . reference_id , self . reference_type , self . url , tuple ( self . severities ) )
101115
102116 def to_dict (self ):
117+ """Return a normalized dictionary representation"""
103118 return {
104119 "reference_id" : self .reference_id ,
105120 "reference_type" : self .reference_type ,
106121 "url" : self .url ,
107- "severities" : [severity .to_dict () for severity in self .severities ],
122+ "severities" : [severity .to_dict () for severity in sorted ( self .severities ) ],
108123 }
109124
110125 @classmethod
111126 def from_dict (cls , ref : dict ):
112127 return cls (
113- reference_id = ref ["reference_id" ],
128+ reference_id = str ( ref ["reference_id" ]) ,
114129 reference_type = ref .get ("reference_type" ) or "" ,
115130 url = ref ["url" ],
116131 severities = [
@@ -140,7 +155,8 @@ class NoAffectedPackages(Exception):
140155 """
141156
142157
143- @dataclasses .dataclass (order = True , frozen = True )
158+ @functools .total_ordering
159+ @dataclasses .dataclass (eq = True )
144160class AffectedPackage :
145161 """
146162 Relate a Package URL with a range of affected versions and a fixed version.
@@ -170,6 +186,19 @@ def get_fixed_purl(self):
170186 raise ValueError (f"Affected Package { self .package !r} does not have a fixed version" )
171187 return update_purl_version (purl = self .package , version = str (self .fixed_version ))
172188
189+ def __lt__ (self , other ):
190+ if not isinstance (other , AffectedPackage ):
191+ return NotImplemented
192+ return self ._cmp_key () < other ._cmp_key ()
193+
194+ # TODO: Add cache
195+ def _cmp_key (self ):
196+ return (
197+ str (self .package ),
198+ str (self .affected_version_range or "" ),
199+ str (self .fixed_version or "" ),
200+ )
201+
173202 @classmethod
174203 def merge (
175204 cls , affected_packages : Iterable
0 commit comments