Skip to content

Commit d430f9e

Browse files
committed
feat: add "isclose" method for Vector, Box3D and LabeledBox3D
1 parent a4f2b66 commit d430f9e

7 files changed

Lines changed: 177 additions & 13 deletions

File tree

tensorbay/geometry/box.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,3 +564,24 @@ def dumps(self) -> Dict[str, Dict[str, float]]:
564564
contents = self._transform.dumps()
565565
contents["size"] = self.size.dumps()
566566
return contents
567+
568+
def isclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
569+
"""Determine whether this 3D box is close to another in value.
570+
571+
Arguments:
572+
other: The other object to compare.
573+
rel_tol: Maximum difference for being considered "close", relative to the
574+
magnitude of the input values
575+
abs_tol: Maximum difference for being considered "close", regardless of the
576+
magnitude of the input values
577+
578+
Returns:
579+
A bool value indicating whether this vector is close to another.
580+
581+
"""
582+
if not isinstance(other, self.__class__):
583+
return False
584+
585+
return self._size.isclose(
586+
other.size, rel_tol=rel_tol, abs_tol=abs_tol
587+
) and self._transform.isclose(other.transform, rel_tol=rel_tol, abs_tol=abs_tol)

tensorbay/geometry/tests/test_box.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,12 @@ def test_rmul(self):
116116
assert box3d.__rmul__(transform) == Box3D(
117117
size=(1, 1, 1), translation=[2, 0, 0], rotation=quaternion(-1, 0, 0, 0)
118118
)
119-
assert box3d.__rmul__(quaternion_1) == Box3D(
120-
size=(1, 1, 1),
121-
translation=[1.7999999999999996, 2, 2.6],
122-
rotation=quaternion(-2, 1, 4, -3),
119+
assert box3d.__rmul__(quaternion_1).isclose(
120+
Box3D(
121+
size=(1, 1, 1),
122+
translation=[1.7999999999999996, 2, 2.6],
123+
rotation=quaternion(-2, 1, 4, -3),
124+
)
123125
)
124126
assert box3d.__rmul__(1) == NotImplemented
125127

tensorbay/geometry/tests/test_vector.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ def test_abs(self):
125125
assert abs(Vector(1, 1)) == 1.4142135623730951
126126
assert abs(Vector(1, 1, 1)) == 1.7320508075688772
127127

128+
def test_isclose(self):
129+
assert Vector(1, 2).isclose(Vector2D(1.000000000001, 2))
130+
assert Vector(1, 2, 3).isclose(Vector3D(1.000000000001, 2, 2.999999999996))
131+
assert not Vector(1, 2, 3).isclose(Vector3D(1.100000000001, 2, 2.999999999996))
132+
128133
def test_repr_head(self):
129134
vector = Vector(1, 2)
130135
assert vector._repr_head() == "Vector2D(1, 2)"

tensorbay/geometry/transform.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""
1515

1616
import warnings
17+
from math import isclose as math_isclose
1718
from typing import Dict, Iterable, Optional, Type, TypeVar, Union, overload
1819

1920
import numpy as np
@@ -23,7 +24,9 @@
2324

2425
with warnings.catch_warnings():
2526
warnings.simplefilter("ignore")
26-
from quaternion import as_rotation_matrix, from_rotation_matrix, quaternion, rotate_vectors
27+
from quaternion import as_rotation_matrix, from_rotation_matrix
28+
from quaternion import isclose as quaternion_isclose
29+
from quaternion import quaternion, rotate_vectors
2730

2831
_T = TypeVar("_T", bound="Transform3D")
2932

@@ -323,3 +326,24 @@ def inverse(self: _T) -> _T:
323326
rotation = self._rotation.inverse()
324327
translation = Vector3D(*rotate_vectors(rotation, -self._translation))
325328
return self._create(translation, rotation)
329+
330+
def isclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
331+
"""Determine whether this 3D transform is close to another in value.
332+
333+
Arguments:
334+
other: The other object to compare.
335+
rel_tol: Maximum difference for being considered "close", relative to the
336+
magnitude of the input values
337+
abs_tol: Maximum difference for being considered "close", regardless of the
338+
magnitude of the input values
339+
340+
Returns:
341+
A bool value indicating whether this vector is close to another.
342+
343+
"""
344+
if not isinstance(other, self.__class__):
345+
return False
346+
347+
return self._translation.isclose(
348+
other.translation, rel_tol=rel_tol, abs_tol=abs_tol
349+
) and all(quaternion_isclose(self._rotation, other.rotation))

tensorbay/geometry/vector.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
"""
1616

1717
from itertools import zip_longest
18-
from math import hypot, sqrt
18+
from math import hypot
19+
from math import isclose as math_isclose
20+
from math import sqrt
1921
from sys import version_info
2022
from typing import Dict, Iterable, Optional, Sequence, Tuple, Type, TypeVar, Union
2123

@@ -204,6 +206,35 @@ def loads(contents: Dict[str, float]) -> _T:
204206
return Vector3D.loads(contents)
205207
return Vector2D.loads(contents)
206208

209+
def isclose(
210+
self, other: Iterable[float], *, rel_tol: float = 1e-09, abs_tol: float = 0.0
211+
) -> bool:
212+
"""Determine whether this vector is close to another in value.
213+
214+
Arguments:
215+
other: The other object to compare.
216+
rel_tol: Maximum difference for being considered "close", relative to the
217+
magnitude of the input values
218+
abs_tol: Maximum difference for being considered "close", regardless of the
219+
magnitude of the input values
220+
221+
Raises:
222+
TypeError: When other have inconsistent dimension.
223+
224+
Returns:
225+
A bool value indicating whether this vector is close to another.
226+
227+
"""
228+
try:
229+
return all(
230+
math_isclose(i, j, rel_tol=rel_tol, abs_tol=abs_tol)
231+
for i, j in zip_longest(self._data, other)
232+
)
233+
except TypeError as error:
234+
raise TypeError(
235+
f"The other object must have the dimension of {self._DIMENSION}"
236+
) from error
237+
207238

208239
class Vector2D(Vector):
209240
"""This class defines the concept of Vector2D.

tensorbay/label/tests/test_label_box.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,15 @@ def test_rmul(self):
156156
instance="12345",
157157
)
158158

159-
assert labeledbox3d.__rmul__(quaternion_1) == LabeledBox3D(
160-
size=size,
161-
translation=[1.7999999999999996, 2, 2.6],
162-
rotation=[-2, 1, 4, -3],
163-
category="cat",
164-
attributes={"gender": "male"},
165-
instance="12345",
159+
assert labeledbox3d.__rmul__(quaternion_1).isclose(
160+
LabeledBox3D(
161+
size=size,
162+
translation=[1.7999999999999996, 2, 2.6],
163+
rotation=[-2, 1, 4, -3],
164+
category="cat",
165+
attributes={"gender": "male"},
166+
instance="12345",
167+
)
166168
)
167169

168170
assert labeledbox3d.__rmul__(1) == NotImplemented

tensorbay/utility/attr.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
:class:`Field` is a class describing the attr related fields.
1111
1212
"""
13+
from math import isclose as math_isclose
1314
from sys import version_info
1415
from typing import (
1516
Any,
@@ -65,6 +66,7 @@ def __init__(
6566
) -> None:
6667
self.loader: _Callable
6768
self.dumper: _Callable
69+
self.is_close: Callable[[Any, Any], bool]
6870

6971
self.is_dynamic = is_dynamic
7072
self.default = default
@@ -95,6 +97,7 @@ def __init__(self, key: Optional[str]) -> None:
9597
self.loader: _Callable
9698
self.dumper: _Callable
9799
self.key = key
100+
self.is_close: _Callable
98101

99102

100103
class AttrsMixin:
@@ -123,6 +126,8 @@ def __init_subclass__(cls) -> None:
123126
field = getattr(cls, name, None)
124127
if isinstance(field, Field):
125128
field.loader, field.dumper = _get_operators(type_)
129+
if cls.__name__ == "_LabelBase":
130+
field.is_close = _get_isclose(type_)
126131
if hasattr(field, "key_converter"):
127132
field.key = field.key_converter(name)
128133
attrs_fields[name] = field
@@ -201,6 +206,45 @@ def _dumps(self) -> Dict[str, Any]:
201206
_key_dumper(field.key, contents, field.dumper(value))
202207
return contents
203208

209+
def isclose(self, other: object, *, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool:
210+
"""Determine whether this instance is close to another in value.
211+
212+
Arguments:
213+
other: The other instance to compare.
214+
rel_tol: Maximum difference for being considered "close", relative to the
215+
magnitude of the input values
216+
abs_tol: Maximum difference for being considered "close", regardless of the
217+
magnitude of the input values
218+
219+
Raises:
220+
NotImplementedError: When the class type is not LabeledBox3D.
221+
222+
Returns:
223+
A bool value indicating whether this instance is close to another.
224+
225+
"""
226+
if self.__class__.__name__ != "LabeledBox3D":
227+
raise NotImplementedError
228+
229+
if not isinstance(other, self.__class__) or self.__dict__.keys() != other.__dict__.keys():
230+
return False
231+
232+
base = getattr(self, _ATTRS_BASE, None)
233+
result = True
234+
if base and hasattr(base, "is_close"):
235+
result = result and base.is_close(self, other, rel_tol=rel_tol, abs_tol=abs_tol)
236+
237+
if not result:
238+
return result
239+
240+
for name, field in self._attrs_fields.items():
241+
if not hasattr(self, name):
242+
continue
243+
244+
result = result and field.is_close(getattr(self, name), getattr(other, name))
245+
246+
return result
247+
204248

205249
def attr(
206250
*,
@@ -291,6 +335,38 @@ def _get_origin_in_3_6(annotation: Any) -> Any:
291335
_get_origin = _get_origin_in_3_6 if version_info < (3, 7) else _get_origin_in_3_7
292336

293337

338+
def _get_isclose(annotation: Any) -> Callable[[Any, Any], Any]:
339+
"""Get attr isclose methods by annotations.
340+
341+
AttrsMixin has three operating types which are classified by attr annotation.
342+
1. builtin types, like str, int, None
343+
2. tensorbay custom class, like tensorbay.label.Classification
344+
3. tensorbay custom class list or NameList, like List[tensorbay.label.LabeledBox2D]
345+
346+
Arguments:
347+
annotation: Type of the attr.
348+
349+
Returns:
350+
The ``isclose`` methods of the annotation.
351+
352+
"""
353+
origin = _get_origin(annotation)
354+
if isinstance(origin, type) and issubclass(origin, Sequence):
355+
type_ = annotation.__args__[0]
356+
return lambda self, other: all(_get_isclose(type_)(i, j) for i, j in zip(self, other))
357+
358+
type_ = annotation
359+
360+
mod = getattr(type_, "__module__", None)
361+
if mod in _BUILTINS:
362+
if type_ in (int, float):
363+
return math_isclose
364+
if mod == "typing":
365+
return origin.__eq__
366+
return type_.__eq__ # type: ignore[no-any-return]
367+
return type_.isclose # type: ignore[no-any-return]
368+
369+
294370
def _get_operators(annotation: Any) -> Tuple[_Callable, _Callable]:
295371
"""Get attr operating methods by annotations.
296372
@@ -315,6 +391,9 @@ def _get_operators(annotation: Any) -> Tuple[_Callable, _Callable]:
315391
sequence = None
316392
type_ = annotation
317393

394+
if annotation == Dict[str, Union[str, int, float, bool, List[Union[str, int, float, bool]]]]:
395+
print("yes", origin, type_, getattr(type_, "__module__", None))
396+
318397
if {getattr(sequence, "__module__", None), getattr(type_, "__module__", None)} < _BUILTINS:
319398
return _builtin_operator, _builtin_operator
320399

0 commit comments

Comments
 (0)