Skip to content

Commit 5fe4fb7

Browse files
committed
feat: add "isclose" method for Vector, Box3D and LabeledBox3D
1 parent 34f4cd9 commit 5fe4fb7

10 files changed

Lines changed: 219 additions & 14 deletions

File tree

tensorbay/geometry/box.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,25 @@ def _line_intersect(length1: float, length2: float, midpoint_distance: float) ->
407407
intersect_length = min(line1_max, line2_max) - max(line1_min, line2_min)
408408
return intersect_length if intersect_length > 0 else 0
409409

410+
def _allclose(self, other: _B3, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
411+
"""Determine whether this 3D box is close to another in value.
412+
413+
Arguments:
414+
other: The other object to compare.
415+
rel_tol: Maximum difference for being considered "close", relative to the
416+
magnitude of the input values
417+
abs_tol: Maximum difference for being considered "close", regardless of the
418+
magnitude of the input values
419+
420+
Returns:
421+
A bool value indicating whether this vector is close to another.
422+
423+
"""
424+
# pylint: disable=protected-access
425+
return self._size._allclose(
426+
other.size, rel_tol=rel_tol, abs_tol=abs_tol
427+
) and self._transform._allclose(other.transform, rel_tol=rel_tol, abs_tol=abs_tol)
428+
410429
def _loads(self, contents: Dict[str, Dict[str, float]]) -> None:
411430
self._size = Vector3D.loads(contents["size"])
412431
self._transform = Transform3D.loads(contents)

tensorbay/geometry/tests/test_box.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77
from quaternion import quaternion
88

9-
from ...utility import UserSequence
9+
from ...utility import UserSequence, allclose
1010
from .. import Box2D, Box3D, Transform3D, Vector2D, Vector3D
1111

1212
_DATA_2D = {"xmin": 1.0, "ymin": 2.0, "xmax": 3.0, "ymax": 4.0}
@@ -116,10 +116,13 @@ def test_rmul(self):
116116
assert box3d.__rmul__(transform) == Box3D(
117117
size=(1, 1, 1), translation=[2, 0, 0], rotation=quaternion(-1, 0, 0, 0)
118118
)
119-
assert box3d.__rmul__(quaternion_1) == Box3D(
120-
size=(1, 1, 1),
121-
translation=[1.7999999999999996, 2, 2.6],
122-
rotation=quaternion(-2, 1, 4, -3),
119+
assert allclose(
120+
box3d.__rmul__(quaternion_1),
121+
Box3D(
122+
size=(1, 1, 1),
123+
translation=[1.7999999999999996, 2, 2.6],
124+
rotation=quaternion(-2, 1, 4, -3),
125+
),
123126
)
124127
assert box3d.__rmul__(1) == NotImplemented
125128

tensorbay/geometry/tests/test_vector.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ def test_abs(self):
125125
assert abs(Vector(1, 1)) == 1.4142135623730951
126126
assert abs(Vector(1, 1, 1)) == 1.7320508075688772
127127

128+
def test__allclose(self):
129+
assert Vector(1, 2)._allclose(Vector2D(1.000000000001, 2))
130+
assert Vector(1, 2, 3)._allclose(Vector3D(1.000000000001, 2, 2.999999999996))
131+
assert not Vector(1, 2, 3)._allclose(Vector3D(1.100000000001, 2, 2.999999999996))
132+
128133
def test_repr_head(self):
129134
vector = Vector(1, 2)
130135
assert vector._repr_head() == "Vector2D(1, 2)"

tensorbay/geometry/transform.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323

2424
with warnings.catch_warnings():
2525
warnings.simplefilter("ignore")
26-
from quaternion import as_rotation_matrix, from_rotation_matrix, quaternion, rotate_vectors
26+
from quaternion import as_rotation_matrix, from_rotation_matrix
27+
from quaternion import isclose as quaternion_isclose
28+
from quaternion import quaternion, rotate_vectors
2729

2830
_T = TypeVar("_T", bound="Transform3D")
2931

@@ -154,6 +156,27 @@ def _mul_vector(self, other: Iterable[float]) -> Vector3D:
154156
# __radd__ is used to ensure the shape of the input object.
155157
return self._translation.__radd__(rotate_vectors(self._rotation, other))
156158

159+
def _allclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
160+
"""Determine whether this 3D transform is close to another in value.
161+
162+
Arguments:
163+
other: The other object to compare.
164+
rel_tol: Maximum difference for being considered "close", relative to the
165+
magnitude of the input values
166+
abs_tol: Maximum difference for being considered "close", regardless of the
167+
magnitude of the input values
168+
169+
Returns:
170+
A bool value indicating whether this vector is close to another.
171+
172+
"""
173+
if not isinstance(other, self.__class__):
174+
return False
175+
176+
return self._translation._allclose( # pylint: disable=protected-access
177+
other.translation, rel_tol=rel_tol, abs_tol=abs_tol
178+
) and all(quaternion_isclose(self._rotation, other.rotation))
179+
157180
def _loads(self, contents: Dict[str, Dict[str, float]]) -> None:
158181
self._translation = Vector3D.loads(contents["translation"])
159182
rotation_contents = contents["rotation"]

tensorbay/geometry/vector.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
"""
1616

1717
from itertools import zip_longest
18-
from math import hypot, sqrt
18+
from math import hypot
19+
from math import isclose as math_isclose
20+
from math import sqrt
1921
from sys import version_info
2022
from typing import Dict, Iterable, Optional, Sequence, Tuple, Type, TypeVar, Union
2123

@@ -180,6 +182,35 @@ def __abs__(self) -> float:
180182
def _repr_head(self) -> str:
181183
return f"{self.__class__.__name__}{self._data}"
182184

185+
def _allclose(
186+
self, other: Iterable[float], *, rel_tol: float = 1e-09, abs_tol: float = 0.0
187+
) -> bool:
188+
"""Determine whether this vector is close to another in value.
189+
190+
Arguments:
191+
other: The other object to compare.
192+
rel_tol: Maximum difference for being considered "close", relative to the
193+
magnitude of the input values
194+
abs_tol: Maximum difference for being considered "close", regardless of the
195+
magnitude of the input values
196+
197+
Raises:
198+
TypeError: When other have inconsistent dimension.
199+
200+
Returns:
201+
A bool value indicating whether this vector is close to another.
202+
203+
"""
204+
try:
205+
return all(
206+
math_isclose(i, j, rel_tol=rel_tol, abs_tol=abs_tol)
207+
for i, j in zip_longest(self._data, other)
208+
)
209+
except TypeError as error:
210+
raise TypeError(
211+
f"The other object must have the dimension of {self._DIMENSION}"
212+
) from error
213+
183214
@staticmethod
184215
def loads(contents: Dict[str, float]) -> _T:
185216
"""Loads a :class:`Vector` from a dict containing coordinates of the vector.

tensorbay/label/basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ class _LabelBase(AttrsMixin, TypeMixin[LabelType], ReprMixin):
143143
_label_attrs: Tuple[str, ...] = ("category", "attributes", "instance")
144144

145145
_repr_attrs = _label_attrs
146+
_support_allclose = True
146147

147148
_AttributeType = Dict[str, Union[str, int, float, bool, List[Union[str, int, float, bool]]]]
148149

tensorbay/label/tests/test_label_box.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from quaternion import quaternion
88

99
from ...geometry import Transform3D, Vector3D
10+
from ...utility import allclose
1011
from .. import Box2DSubcatalog, Box3DSubcatalog, LabeledBox2D, LabeledBox3D
1112

1213

@@ -156,13 +157,16 @@ def test_rmul(self):
156157
instance="12345",
157158
)
158159

159-
assert labeledbox3d.__rmul__(quaternion_1) == LabeledBox3D(
160-
size=size,
161-
translation=[1.7999999999999996, 2, 2.6],
162-
rotation=[-2, 1, 4, -3],
163-
category="cat",
164-
attributes={"gender": "male"},
165-
instance="12345",
160+
assert allclose(
161+
labeledbox3d.__rmul__(quaternion_1),
162+
LabeledBox3D(
163+
size=size,
164+
translation=[1.7999999999999996, 2, 2.6],
165+
rotation=[-2, 1, 4, -3],
166+
category="cat",
167+
attributes={"gender": "male"},
168+
instance="12345",
169+
),
166170
)
167171

168172
assert labeledbox3d.__rmul__(1) == NotImplemented

tensorbay/utility/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
EqMixin,
1414
KwargsDeprecated,
1515
MatrixType,
16+
allclose,
1617
common_loads,
1718
locked,
1819
)
@@ -42,6 +43,7 @@
4243
"UserMutableMapping",
4344
"UserMutableSequence",
4445
"UserSequence",
46+
"allclose",
4547
"attr",
4648
"attr_base",
4749
"camel",

tensorbay/utility/attr.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
:class:`Field` is a class describing the attr related fields.
1111
1212
"""
13+
from math import isclose as math_isclose
1314
from sys import version_info
1415
from typing import (
1516
Any,
@@ -64,6 +65,7 @@ def __init__(
6465
) -> None:
6566
self.loader: _Callable
6667
self.dumper: _Callable
68+
self.allclose: Callable[[Any, Any, float, float], bool]
6769

6870
self.is_dynamic = is_dynamic
6971
self.default = default
@@ -94,6 +96,7 @@ def __init__(self, key: Optional[str]) -> None:
9496
self.loader: _Callable
9597
self.dumper: _Callable
9698
self.key = key
99+
self.allclose: Callable[[Any, Any, float, float], bool]
97100

98101

99102
class AttrsMixin:
@@ -109,6 +112,7 @@ class AttrsMixin:
109112

110113
def __init_subclass__(cls) -> None:
111114
type_ = cls.__annotations__.pop(_ATTRS_BASE, None)
115+
support_allclose = getattr(cls, "_support_allclose", False)
112116
if type_:
113117
cls._attrs_base.loader = type_._loads # pylint: disable=protected-access
114118
cls._attrs_base.dumper = type_.dumps
@@ -122,6 +126,8 @@ def __init_subclass__(cls) -> None:
122126
field = getattr(cls, name, None)
123127
if isinstance(field, Field):
124128
field.loader, field.dumper = _get_operators(type_)
129+
if support_allclose:
130+
field.allclose = _get_allclose(type_)
125131
if hasattr(field, "key_converter"):
126132
field.key = field.key_converter(name)
127133
attrs_fields[name] = field
@@ -200,6 +206,32 @@ def _dumps(self) -> Dict[str, Any]:
200206
_key_dumper(field.key, contents, field.dumper(value))
201207
return contents
202208

209+
def _allclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
210+
"""Determine whether this instance is close to another in value.
211+
212+
Arguments:
213+
other: The other instance to compare.
214+
rel_tol: Maximum difference for being considered "close", relative to the
215+
magnitude of the input values
216+
abs_tol: Maximum difference for being considered "close", regardless of the
217+
magnitude of the input values
218+
219+
Returns:
220+
A bool value indicating whether this instance is close to another.
221+
222+
"""
223+
result = True
224+
225+
for name, field in self._attrs_fields.items():
226+
if not hasattr(self, name):
227+
continue
228+
229+
result = result and field.allclose( # type: ignore[call-arg]
230+
getattr(self, name), getattr(other, name), rel_tol=rel_tol, abs_tol=abs_tol
231+
)
232+
233+
return result
234+
203235

204236
def attr(
205237
*,
@@ -292,6 +324,43 @@ def _get_origin_in_3_6(annotation: Any) -> Any:
292324
_get_origin = _get_origin_in_3_6 if version_info < (3, 7) else _get_origin_in_3_7
293325

294326

327+
def _get_allclose(annotation: Any) -> Callable[[Any, Any, float, float], bool]:
328+
"""Get attr allclose methods by annotations.
329+
330+
AttrsMixin has three operating types which are classified by attr annotation.
331+
1. builtin types, like str, int, None
332+
2. tensorbay custom class, like tensorbay.label.Classification
333+
3. tensorbay custom class list or NameList, like List[tensorbay.label.LabeledBox2D]
334+
335+
Arguments:
336+
annotation: Type of the attr.
337+
338+
Returns:
339+
The ``_allclose`` methods of the annotation.
340+
341+
"""
342+
origin = _get_origin(annotation)
343+
if isinstance(origin, type) and issubclass(origin, Sequence):
344+
type_ = annotation.__args__[0]
345+
return lambda self, other, rel_tol=1e-09, abs_tol=0.0: all( # type: ignore[misc]
346+
_get_allclose(type_)(i, j, rel_tol=rel_tol, abs_tol=abs_tol) # type: ignore[call-arg]
347+
for i, j in zip(self, other)
348+
)
349+
350+
type_ = annotation
351+
352+
mod = getattr(type_, "__module__", None)
353+
if mod in _BUILTINS:
354+
if type_ in (int, float):
355+
return math_isclose # type: ignore[return-value]
356+
return _eq_allclose # type: ignore[return-value]
357+
return type_._allclose # type: ignore[no-any-return] # pylint: disable=protected-access
358+
359+
360+
def _eq_allclose(object_1: object, object_2: object, **_: float) -> bool:
361+
return object_1 == object_2
362+
363+
295364
def _get_operators(annotation: Any) -> Tuple[_Callable, _Callable]:
296365
"""Get attr operating methods by annotations.
297366

tensorbay/utility/common.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Any, Callable, DefaultDict, Optional, Sequence, Tuple, Type, TypeVar, Union
2121

2222
import numpy as np
23+
from typing_extensions import Protocol
2324

2425
_T = TypeVar("_T")
2526
_Callable = TypeVar("_Callable", bound=Callable[..., Any])
@@ -45,6 +46,53 @@ def common_loads(object_class: Type[_T], contents: Any) -> _T:
4546
return obj
4647

4748

49+
class _A(Protocol): # pylint: disable=too-few-public-methods
50+
def _allclose(self, other: "_A", *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
51+
"""Tell if all the data is close to the other object.
52+
53+
Arguments:
54+
other: The other object.
55+
rel_tol: Maximum difference for being considered "close", relative to the
56+
magnitude of the input values
57+
abs_tol: Maximum difference for being considered "close", regardless of the
58+
magnitude of the input values
59+
"""
60+
61+
62+
def allclose(object_1: _A, object_2: _A, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
63+
"""Determine whether object_1 is close to object_2 in value.
64+
65+
Arguments:
66+
object_1: The first object to compare.
67+
object_2: The second object to compare.
68+
rel_tol: Maximum difference for being considered "close", relative to the
69+
magnitude of the input values
70+
abs_tol: Maximum difference for being considered "close", regardless of the
71+
magnitude of the input values
72+
73+
Returns:
74+
A bool value indicating whether object_1 is close to object_2.
75+
76+
"""
77+
try:
78+
# pylint: disable=protected-access
79+
if issubclass(object_1.__class__, object_2.__class__) and hasattr(object_1, "_allclose"):
80+
print(1)
81+
return object_1._allclose(object_2, rel_tol=rel_tol, abs_tol=abs_tol)
82+
if issubclass(object_2.__class__, object_1.__class__) and hasattr(object_2, "_allclose"):
83+
print(2)
84+
return object_2._allclose(object_1, rel_tol=rel_tol, abs_tol=abs_tol)
85+
86+
if hasattr(object_1, "_allclose"):
87+
print(3)
88+
return object_1._allclose(object_2, rel_tol=rel_tol, abs_tol=abs_tol)
89+
print(4)
90+
return object_2._allclose(object_1, rel_tol=rel_tol, abs_tol=abs_tol)
91+
except Exception: # pylint: disable=broad-except
92+
print(5)
93+
return False
94+
95+
4896
class EqMixin: # pylint: disable=too-few-public-methods
4997
"""A mixin class to support __eq__() method.
5098

0 commit comments

Comments
 (0)