Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,25 @@
__all__ = [
"angle",
"apply_where",
"argpartition",
"atleast_nd",
"broadcast_shapes",
"cov",
"create_diagonal",
"default_dtype",
"expand_dims",
"isclose",
"isin",
"kron",
"nan_to_num",
"nunique",
"one_hot",
"pad",
"partition",
"searchsorted",
"setdiff1d",
"sinc",
"union1d",
]


Expand Down
14 changes: 14 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import math
import warnings
from collections.abc import Callable
Expand Down Expand Up @@ -38,6 +39,7 @@
from array_api_extra import (
searchsorted as xpx_searchsorted,
)
from array_api_extra._lib import _funcs as functions
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
from array_api_extra._lib._funcs import searchsorted as _funcs_searchsorted
from array_api_extra._lib._utils._compat import (
Expand Down Expand Up @@ -73,6 +75,18 @@
lazy_xp_function(_funcs_searchsorted)


def test_all_contains_all_public_functions():
public_functions = {
name
for name, obj in inspect.getmembers(functions, inspect.isfunction)
if not name.startswith("_") and obj.__module__ == functions.__name__
}
assert public_functions == set(functions.__all__), (
f"Missing from __all__: {sorted(public_functions - set(functions.__all__))}\t"
f"Extra in __all__: {sorted(set(functions.__all__) - public_functions)}"
)


class TestApplyWhere:
@staticmethod
def f1(x: Array, y: Array | int = 10) -> Array:
Expand Down
Loading