diff --git a/docs/api-reference.md b/docs/api-reference.md index 75724d0..39a184a 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -9,3 +9,5 @@ ::: compyre.builtin.equal_fns ::: compyre.alias + +::: compyre.utils diff --git a/src/compyre/builtin/_numpy.py b/src/compyre/builtin/_numpy.py index fbd79bc..12172b1 100644 --- a/src/compyre/builtin/_numpy.py +++ b/src/compyre/builtin/_numpy.py @@ -1,10 +1,8 @@ from typing import Annotated -from compyre import alias, api +from compyre import alias, api, utils from compyre._availability import available_if -from ._utils import both_isinstance - @available_if("numpy") def numpy_ndarray( @@ -39,7 +37,7 @@ def numpy_ndarray( """ import numpy as np - if not both_isinstance(p, np.ndarray): + if not utils.both_isinstance(p, np.ndarray): return None try: diff --git a/src/compyre/builtin/_pandas.py b/src/compyre/builtin/_pandas.py index bd04837..50317e5 100644 --- a/src/compyre/builtin/_pandas.py +++ b/src/compyre/builtin/_pandas.py @@ -1,10 +1,8 @@ from typing import Annotated -from compyre import alias, api +from compyre import alias, api, utils from compyre._availability import available_if -from ._utils import both_isinstance - @available_if("pandas") def pandas_dataframe( @@ -34,7 +32,7 @@ def pandas_dataframe( """ import pandas as pd - if not both_isinstance(p, pd.DataFrame): + if not utils.both_isinstance(p, pd.DataFrame): return None try: @@ -77,7 +75,7 @@ def pandas_series( """ import pandas as pd - if not both_isinstance(p, pd.Series): + if not utils.both_isinstance(p, pd.Series): return None try: diff --git a/src/compyre/builtin/_pydantic.py b/src/compyre/builtin/_pydantic.py index 16c71ef..9c6b8bd 100644 --- a/src/compyre/builtin/_pydantic.py +++ b/src/compyre/builtin/_pydantic.py @@ -1,14 +1,13 @@ +from compyre import api, utils from compyre._availability import available_if -from compyre.api import Pair, UnpackFnResult from ._stdlib import collections_mapping -from ._utils import both_isinstance __all__ = ["pydantic_model"] @available_if("pydantic>=2,<3") -def pydantic_model(p: Pair, /) -> UnpackFnResult: +def pydantic_model(p: api.Pair, /) -> api.UnpackFnResult: """Unpack [pydantic.BaseModel][]s using [pydantic.BaseModel.model_dump][]. Args: @@ -29,7 +28,7 @@ def pydantic_model(p: Pair, /) -> UnpackFnResult: """ import pydantic - if not both_isinstance(p, pydantic.BaseModel): + if not utils.both_isinstance(p, pydantic.BaseModel): return None try: @@ -38,4 +37,6 @@ def pydantic_model(p: Pair, /) -> UnpackFnResult: except Exception as result: return result - return collections_mapping(Pair(index=p.index, actual=actual, expected=expected)) + return collections_mapping( + api.Pair(index=p.index, actual=actual, expected=expected) + ) diff --git a/src/compyre/builtin/_stdlib.py b/src/compyre/builtin/_stdlib.py index aa3882a..87b8fee 100644 --- a/src/compyre/builtin/_stdlib.py +++ b/src/compyre/builtin/_stdlib.py @@ -7,9 +7,7 @@ from collections.abc import Mapping, Sequence from typing import Annotated -from compyre import alias, api - -from ._utils import both_isinstance, either_isinstance +from compyre import alias, api, utils __all__ = [ "builtins_number", @@ -36,7 +34,7 @@ def collections_mapping(p: api.Pair, /) -> api.UnpackFnResult: (ValueError): If the keys of [`p.actual`][compyre.api.Pair] and [`p.expected`][compyre.api.Pair] mismatch. """ - if not both_isinstance(p, Mapping): + if not utils.both_isinstance(p, Mapping): return None extra = p.actual.keys() - p.expected.keys() @@ -73,7 +71,7 @@ def collections_sequence(p: api.Pair, /) -> api.UnpackFnResult: (ValueError): If the length of [`p.actual`][compyre.api.Pair] and [`p.expected`][compyre.api.Pair] mismatch. """ - if not both_isinstance(p, Sequence) or either_isinstance(p, str): + if not utils.both_isinstance(p, Sequence) or utils.either_isinstance(p, str): return None if (la := len(p.actual)) != (le := len(p.expected)): @@ -106,7 +104,7 @@ def collections_ordered_dict(p: api.Pair, /) -> api.UnpackFnResult: mismatch. """ - if not both_isinstance(p, OrderedDict): + if not utils.both_isinstance(p, OrderedDict): return None if (aks := list(p.actual.keys())) != (eks := list(p.expected.keys())): @@ -145,10 +143,12 @@ def builtins_number( (AssertionError): If [math.isclose][] or [cmath.isclose][] returns [False][] for the input pair. """ - if not both_isinstance(p, (int, float, complex)) or either_isinstance(p, bool): + if not utils.both_isinstance(p, (int, float, complex)) or utils.either_isinstance( + p, bool + ): return None - isclose = cmath.isclose if either_isinstance(p, complex) else math.isclose + isclose = cmath.isclose if utils.either_isinstance(p, complex) else math.isclose if isclose(p.actual, p.expected, abs_tol=abs_tol, rel_tol=rel_tol): return True @@ -232,7 +232,7 @@ def dataclasses_dataclass(p: api.Pair, /) -> api.UnpackFnResult: # dataclasses.is_dataclass returns True for dataclass instances and types, but we only handle the former if not ( dataclasses.is_dataclass(p.actual) and dataclasses.is_dataclass(p.expected) - ) or either_isinstance(p, type): + ) or utils.either_isinstance(p, type): return None return collections_mapping( diff --git a/src/compyre/builtin/_torch.py b/src/compyre/builtin/_torch.py index 37689c6..30a6cee 100644 --- a/src/compyre/builtin/_torch.py +++ b/src/compyre/builtin/_torch.py @@ -1,10 +1,8 @@ from typing import Annotated -from compyre import alias, api +from compyre import alias, api, utils from compyre._availability import available_if -from ._utils import both_isinstance - @available_if("torch") def torch_tensor( @@ -39,7 +37,7 @@ def torch_tensor( """ import torch - if not both_isinstance(p, torch.Tensor): + if not utils.both_isinstance(p, torch.Tensor): return None try: diff --git a/src/compyre/builtin/_utils.py b/src/compyre/builtin/_utils.py deleted file mode 100644 index 318abad..0000000 --- a/src/compyre/builtin/_utils.py +++ /dev/null @@ -1,11 +0,0 @@ -from compyre.api import Pair - -__all__ = ["both_isinstance", "either_isinstance"] - - -def both_isinstance(pair: Pair, t: type | tuple[type, ...]) -> bool: - return isinstance(pair.actual, t) and isinstance(pair.expected, t) - - -def either_isinstance(pair: Pair, t: type | tuple[type, ...]) -> bool: - return isinstance(pair.actual, t) or isinstance(pair.expected, t) diff --git a/src/compyre/utils.py b/src/compyre/utils.py new file mode 100644 index 0000000..0f8aeb0 --- /dev/null +++ b/src/compyre/utils.py @@ -0,0 +1,31 @@ +from compyre import api + +__all__ = ["both_isinstance", "either_isinstance"] + + +def both_isinstance(pair: api.Pair, t: type | tuple[type, ...], /) -> bool: + """Check whether both values in a pair are instances of a given type. + + Args: + pair: Pair to be checked + t: The type or tuple of types to check the `pair`'s values against. + + Returns: + Whether both [`p.actual`][compyre.api.Pair] and [`p.expected`][compyre.api.Pair] are instances of `t`. + + """ + return isinstance(pair.actual, t) and isinstance(pair.expected, t) + + +def either_isinstance(pair: api.Pair, t: type | tuple[type, ...], /) -> bool: + """Check whether either value in a pair is an instance of a given type. + + Args: + pair: Pair to be checked + t: The type or tuple of types to check the `pair`'s values against. + + Returns: + Whether either [`p.actual`][compyre.api.Pair] or [`p.expected`][compyre.api.Pair] is an instances of `t`. + + """ + return isinstance(pair.actual, t) or isinstance(pair.expected, t)