1010:class:`Field` is a class describing the attr related fields.
1111
1212"""
13+ from math import isclose as math_isclose
1314from sys import version_info
1415from typing import (
1516 Any ,
@@ -65,6 +66,7 @@ def __init__(
6566 ) -> None :
6667 self .loader : _Callable
6768 self .dumper : _Callable
69+ self .is_close : Callable [[Any , Any ], bool ]
6870
6971 self .is_dynamic = is_dynamic
7072 self .default = default
@@ -95,6 +97,7 @@ def __init__(self, key: Optional[str]) -> None:
9597 self .loader : _Callable
9698 self .dumper : _Callable
9799 self .key = key
100+ self .is_close : _Callable
98101
99102
100103class AttrsMixin :
@@ -123,6 +126,8 @@ def __init_subclass__(cls) -> None:
123126 field = getattr (cls , name , None )
124127 if isinstance (field , Field ):
125128 field .loader , field .dumper = _get_operators (type_ )
129+ if cls .__name__ == "_LabelBase" :
130+ field .is_close = _get_isclose (type_ )
126131 if hasattr (field , "key_converter" ):
127132 field .key = field .key_converter (name )
128133 attrs_fields [name ] = field
@@ -201,6 +206,45 @@ def _dumps(self) -> Dict[str, Any]:
201206 _key_dumper (field .key , contents , field .dumper (value ))
202207 return contents
203208
209+ def isclose (self , other : object , * , rel_tol : float = 1e-09 , abs_tol : float = 0.0 ) -> bool :
210+ """Determine whether this instance is close to another in value.
211+
212+ Arguments:
213+ other: The other instance to compare.
214+ rel_tol: Maximum difference for being considered "close", relative to the
215+ magnitude of the input values
216+ abs_tol: Maximum difference for being considered "close", regardless of the
217+ magnitude of the input values
218+
219+ Raises:
220+ NotImplementedError: When the class type is not LabeledBox3D.
221+
222+ Returns:
223+ A bool value indicating whether this instance is close to another.
224+
225+ """
226+ if self .__class__ .__name__ != "LabeledBox3D" :
227+ raise NotImplementedError
228+
229+ if not isinstance (other , self .__class__ ) or self .__dict__ .keys () != other .__dict__ .keys ():
230+ return False
231+
232+ base = getattr (self , _ATTRS_BASE , None )
233+ result = True
234+ if base and hasattr (base , "is_close" ):
235+ result = result and base .is_close (self , other , rel_tol = rel_tol , abs_tol = abs_tol )
236+
237+ if not result :
238+ return result
239+
240+ for name , field in self ._attrs_fields .items ():
241+ if not hasattr (self , name ):
242+ continue
243+
244+ result = result and field .is_close (getattr (self , name ), getattr (other , name ))
245+
246+ return result
247+
204248
205249def attr (
206250 * ,
@@ -291,6 +335,38 @@ def _get_origin_in_3_6(annotation: Any) -> Any:
291335_get_origin = _get_origin_in_3_6 if version_info < (3 , 7 ) else _get_origin_in_3_7
292336
293337
338+ def _get_isclose (annotation : Any ) -> Callable [[Any , Any ], Any ]:
339+ """Get attr isclose methods by annotations.
340+
341+ AttrsMixin has three operating types which are classified by attr annotation.
342+ 1. builtin types, like str, int, None
343+ 2. tensorbay custom class, like tensorbay.label.Classification
344+ 3. tensorbay custom class list or NameList, like List[tensorbay.label.LabeledBox2D]
345+
346+ Arguments:
347+ annotation: Type of the attr.
348+
349+ Returns:
350+ The ``isclose`` methods of the annotation.
351+
352+ """
353+ origin = _get_origin (annotation )
354+ if isinstance (origin , type ) and issubclass (origin , Sequence ):
355+ type_ = annotation .__args__ [0 ]
356+ return lambda self , other : all (_get_isclose (type_ )(i , j ) for i , j in zip (self , other ))
357+
358+ type_ = annotation
359+
360+ mod = getattr (type_ , "__module__" , None )
361+ if mod in _BUILTINS :
362+ if type_ in (int , float ):
363+ return math_isclose
364+ if mod == "typing" :
365+ return origin .__eq__
366+ return type_ .__eq__ # type: ignore[no-any-return]
367+ return type_ .isclose # type: ignore[no-any-return]
368+
369+
294370def _get_operators (annotation : Any ) -> Tuple [_Callable , _Callable ]:
295371 """Get attr operating methods by annotations.
296372
@@ -315,6 +391,9 @@ def _get_operators(annotation: Any) -> Tuple[_Callable, _Callable]:
315391 sequence = None
316392 type_ = annotation
317393
394+ if annotation == Dict [str , Union [str , int , float , bool , List [Union [str , int , float , bool ]]]]:
395+ print ("yes" , origin , type_ , getattr (type_ , "__module__" , None ))
396+
318397 if {getattr (sequence , "__module__" , None ), getattr (type_ , "__module__" , None )} < _BUILTINS :
319398 return _builtin_operator , _builtin_operator
320399
0 commit comments