Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@
::: compyre.builtin.equal_fns

::: compyre.alias

::: compyre.utils
6 changes: 2 additions & 4 deletions src/compyre/builtin/_numpy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import Annotated

from compyre import alias, api
from compyre import alias, api, utils
from compyre._availability import available_if

from ._utils import both_isinstance


@available_if("numpy")
def numpy_ndarray(
Expand Down Expand Up @@ -39,7 +37,7 @@ def numpy_ndarray(
"""
import numpy as np

if not both_isinstance(p, np.ndarray):
if not utils.both_isinstance(p, np.ndarray):
return None

try:
Expand Down
8 changes: 3 additions & 5 deletions src/compyre/builtin/_pandas.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import Annotated

from compyre import alias, api
from compyre import alias, api, utils
from compyre._availability import available_if

from ._utils import both_isinstance


@available_if("pandas")
def pandas_dataframe(
Expand Down Expand Up @@ -34,7 +32,7 @@ def pandas_dataframe(
"""
import pandas as pd

if not both_isinstance(p, pd.DataFrame):
if not utils.both_isinstance(p, pd.DataFrame):
return None

try:
Expand Down Expand Up @@ -77,7 +75,7 @@ def pandas_series(
"""
import pandas as pd

if not both_isinstance(p, pd.Series):
if not utils.both_isinstance(p, pd.Series):
return None

try:
Expand Down
11 changes: 6 additions & 5 deletions src/compyre/builtin/_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from compyre import api, utils
from compyre._availability import available_if
from compyre.api import Pair, UnpackFnResult

from ._stdlib import collections_mapping
from ._utils import both_isinstance

__all__ = ["pydantic_model"]


@available_if("pydantic>=2,<3")
def pydantic_model(p: Pair, /) -> UnpackFnResult:
def pydantic_model(p: api.Pair, /) -> api.UnpackFnResult:
"""Unpack [pydantic.BaseModel][]s using [pydantic.BaseModel.model_dump][].

Args:
Expand All @@ -29,7 +28,7 @@ def pydantic_model(p: Pair, /) -> UnpackFnResult:
"""
import pydantic

if not both_isinstance(p, pydantic.BaseModel):
if not utils.both_isinstance(p, pydantic.BaseModel):
return None

try:
Expand All @@ -38,4 +37,6 @@ def pydantic_model(p: Pair, /) -> UnpackFnResult:
except Exception as result:
return result

return collections_mapping(Pair(index=p.index, actual=actual, expected=expected))
return collections_mapping(
api.Pair(index=p.index, actual=actual, expected=expected)
)
18 changes: 9 additions & 9 deletions src/compyre/builtin/_stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from collections.abc import Mapping, Sequence
from typing import Annotated

from compyre import alias, api

from ._utils import both_isinstance, either_isinstance
from compyre import alias, api, utils

__all__ = [
"builtins_number",
Expand All @@ -36,7 +34,7 @@ def collections_mapping(p: api.Pair, /) -> api.UnpackFnResult:
(ValueError): If the keys of [`p.actual`][compyre.api.Pair] and [`p.expected`][compyre.api.Pair] mismatch.

"""
if not both_isinstance(p, Mapping):
if not utils.both_isinstance(p, Mapping):
return None

extra = p.actual.keys() - p.expected.keys()
Expand Down Expand Up @@ -73,7 +71,7 @@ def collections_sequence(p: api.Pair, /) -> api.UnpackFnResult:
(ValueError): If the length of [`p.actual`][compyre.api.Pair] and [`p.expected`][compyre.api.Pair] mismatch.

"""
if not both_isinstance(p, Sequence) or either_isinstance(p, str):
if not utils.both_isinstance(p, Sequence) or utils.either_isinstance(p, str):
return None

if (la := len(p.actual)) != (le := len(p.expected)):
Expand Down Expand Up @@ -106,7 +104,7 @@ def collections_ordered_dict(p: api.Pair, /) -> api.UnpackFnResult:
mismatch.

"""
if not both_isinstance(p, OrderedDict):
if not utils.both_isinstance(p, OrderedDict):
return None

if (aks := list(p.actual.keys())) != (eks := list(p.expected.keys())):
Expand Down Expand Up @@ -145,10 +143,12 @@ def builtins_number(
(AssertionError): If [math.isclose][] or [cmath.isclose][] returns [False][] for the input pair.

"""
if not both_isinstance(p, (int, float, complex)) or either_isinstance(p, bool):
if not utils.both_isinstance(p, (int, float, complex)) or utils.either_isinstance(
p, bool
):
return None

isclose = cmath.isclose if either_isinstance(p, complex) else math.isclose
isclose = cmath.isclose if utils.either_isinstance(p, complex) else math.isclose
if isclose(p.actual, p.expected, abs_tol=abs_tol, rel_tol=rel_tol):
return True

Expand Down Expand Up @@ -232,7 +232,7 @@ def dataclasses_dataclass(p: api.Pair, /) -> api.UnpackFnResult:
# dataclasses.is_dataclass returns True for dataclass instances and types, but we only handle the former
if not (
dataclasses.is_dataclass(p.actual) and dataclasses.is_dataclass(p.expected)
) or either_isinstance(p, type):
) or utils.either_isinstance(p, type):
return None

return collections_mapping(
Expand Down
6 changes: 2 additions & 4 deletions src/compyre/builtin/_torch.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import Annotated

from compyre import alias, api
from compyre import alias, api, utils
from compyre._availability import available_if

from ._utils import both_isinstance


@available_if("torch")
def torch_tensor(
Expand Down Expand Up @@ -39,7 +37,7 @@ def torch_tensor(
"""
import torch

if not both_isinstance(p, torch.Tensor):
if not utils.both_isinstance(p, torch.Tensor):
return None

try:
Expand Down
11 changes: 0 additions & 11 deletions src/compyre/builtin/_utils.py

This file was deleted.

31 changes: 31 additions & 0 deletions src/compyre/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from compyre import api

__all__ = ["both_isinstance", "either_isinstance"]


def both_isinstance(pair: api.Pair, t: type | tuple[type, ...], /) -> bool:
"""Check whether both values in a pair are instances of a given type.

Args:
pair: Pair to be checked
t: The type or tuple of types to check the `pair`'s values against.

Returns:
Whether both [`p.actual`][compyre.api.Pair] and [`p.expected`][compyre.api.Pair] are instances of `t`.

"""
return isinstance(pair.actual, t) and isinstance(pair.expected, t)


def either_isinstance(pair: api.Pair, t: type | tuple[type, ...], /) -> bool:
"""Check whether either value in a pair is an instance of a given type.

Args:
pair: Pair to be checked
t: The type or tuple of types to check the `pair`'s values against.

Returns:
Whether either [`p.actual`][compyre.api.Pair] or [`p.expected`][compyre.api.Pair] is an instances of `t`.

"""
return isinstance(pair.actual, t) or isinstance(pair.expected, t)
Loading