1010:class:`Field` is a class describing the attr related fields.
1111
1212"""
13+ from math import isclose as math_isclose
1314from sys import version_info
1415from typing import (
1516 Any ,
@@ -64,6 +65,7 @@ def __init__(
6465 ) -> None :
6566 self .loader : _Callable
6667 self .dumper : _Callable
68+ self .allclose : Callable [[Any , Any , float , float ], bool ]
6769
6870 self .is_dynamic = is_dynamic
6971 self .default = default
@@ -94,6 +96,7 @@ def __init__(self, key: Optional[str]) -> None:
9496 self .loader : _Callable
9597 self .dumper : _Callable
9698 self .key = key
99+ self .allclose : Callable [[Any , Any , float , float ], bool ]
97100
98101
99102class AttrsMixin :
@@ -109,6 +112,7 @@ class AttrsMixin:
109112
110113 def __init_subclass__ (cls ) -> None :
111114 type_ = cls .__annotations__ .pop (_ATTRS_BASE , None )
115+ support_allclose = getattr (cls , "_support_allclose" , False )
112116 if type_ :
113117 cls ._attrs_base .loader = type_ ._loads # pylint: disable=protected-access
114118 cls ._attrs_base .dumper = type_ .dumps
@@ -122,6 +126,8 @@ def __init_subclass__(cls) -> None:
122126 field = getattr (cls , name , None )
123127 if isinstance (field , Field ):
124128 field .loader , field .dumper = _get_operators (type_ )
129+ if support_allclose :
130+ field .allclose = _get_allclose (type_ )
125131 if hasattr (field , "key_converter" ):
126132 field .key = field .key_converter (name )
127133 attrs_fields [name ] = field
@@ -200,6 +206,32 @@ def _dumps(self) -> Dict[str, Any]:
200206 _key_dumper (field .key , contents , field .dumper (value ))
201207 return contents
202208
209+ def _allclose (self , other : object , * , rel_tol : float = 1e-09 , abs_tol : float = 0.0 ) -> bool :
210+ """Determine whether this instance is close to another in value.
211+
212+ Arguments:
213+ other: The other instance to compare.
214+ rel_tol: Maximum difference for being considered "close", relative to the
215+ magnitude of the input values
216+ abs_tol: Maximum difference for being considered "close", regardless of the
217+ magnitude of the input values
218+
219+ Returns:
220+ A bool value indicating whether this instance is close to another.
221+
222+ """
223+ result = True
224+
225+ for name , field in self ._attrs_fields .items ():
226+ if not hasattr (self , name ):
227+ continue
228+
229+ result = result and field .allclose ( # type: ignore[call-arg]
230+ getattr (self , name ), getattr (other , name ), rel_tol = rel_tol , abs_tol = abs_tol
231+ )
232+
233+ return result
234+
203235
204236def attr (
205237 * ,
@@ -292,6 +324,43 @@ def _get_origin_in_3_6(annotation: Any) -> Any:
292324_get_origin = _get_origin_in_3_6 if version_info < (3 , 7 ) else _get_origin_in_3_7
293325
294326
327+ def _get_allclose (annotation : Any ) -> Callable [[Any , Any , float , float ], bool ]:
328+ """Get attr allclose methods by annotations.
329+
330+ AttrsMixin has three operating types which are classified by attr annotation.
331+ 1. builtin types, like str, int, None
332+ 2. tensorbay custom class, like tensorbay.label.Classification
333+ 3. tensorbay custom class list or NameList, like List[tensorbay.label.LabeledBox2D]
334+
335+ Arguments:
336+ annotation: Type of the attr.
337+
338+ Returns:
339+ The ``_allclose`` methods of the annotation.
340+
341+ """
342+ origin = _get_origin (annotation )
343+ if isinstance (origin , type ) and issubclass (origin , Sequence ):
344+ type_ = annotation .__args__ [0 ]
345+ return lambda self , other , rel_tol = 1e-09 , abs_tol = 0.0 : all ( # type: ignore[misc]
346+ _get_allclose (type_ )(i , j , rel_tol = rel_tol , abs_tol = abs_tol ) # type: ignore[call-arg]
347+ for i , j in zip (self , other )
348+ )
349+
350+ type_ = annotation
351+
352+ mod = getattr (type_ , "__module__" , None )
353+ if mod in _BUILTINS :
354+ if type_ in (int , float ):
355+ return math_isclose # type: ignore[return-value]
356+ return _eq_allclose # type: ignore[return-value]
357+ return type_ ._allclose # type: ignore[no-any-return] # pylint: disable=protected-access
358+
359+
360+ def _eq_allclose (object_1 : object , object_2 : object , ** _ : float ) -> bool :
361+ return object_1 == object_2
362+
363+
295364def _get_operators (annotation : Any ) -> Tuple [_Callable , _Callable ]:
296365 """Get attr operating methods by annotations.
297366
0 commit comments