diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 4e3b8753..7bafb696 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -25,17 +25,25 @@ __all__ = [ "angle", "apply_where", + "argpartition", "atleast_nd", "broadcast_shapes", "cov", "create_diagonal", + "default_dtype", "expand_dims", + "isclose", + "isin", "kron", + "nan_to_num", "nunique", + "one_hot", "pad", + "partition", "searchsorted", "setdiff1d", "sinc", + "union1d", ] diff --git a/tests/test_funcs.py b/tests/test_funcs.py index c212129d..ff14c0d4 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -1,3 +1,4 @@ +import inspect import math import warnings from collections.abc import Callable @@ -38,6 +39,7 @@ from array_api_extra import ( searchsorted as xpx_searchsorted, ) +from array_api_extra._lib import _funcs as functions from array_api_extra._lib._backends import NUMPY_VERSION, Backend from array_api_extra._lib._funcs import searchsorted as _funcs_searchsorted from array_api_extra._lib._utils._compat import ( @@ -73,6 +75,18 @@ lazy_xp_function(_funcs_searchsorted) +def test_all_contains_all_public_functions(): + public_functions = { + name + for name, obj in inspect.getmembers(functions, inspect.isfunction) + if not name.startswith("_") and obj.__module__ == functions.__name__ + } + assert public_functions == set(functions.__all__), ( + f"Missing from __all__: {sorted(public_functions - set(functions.__all__))}\t" + f"Extra in __all__: {sorted(set(functions.__all__) - public_functions)}" + ) + + class TestApplyWhere: @staticmethod def f1(x: Array, y: Array | int = 10) -> Array: