diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index aea44a0d..97904ddb 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -291,7 +291,7 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01 m = atleast_nd(m, ndim=2, xp=xp) m = xp.astype(m, dtype) - avg = _helpers.mean(m, axis=-1, keepdims=True, xp=xp) + avg = xp.mean(m, axis=-1, keepdims=True) m_shape = eager_shape(m) fact = m_shape[-1] - 1 diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 4065c25f..097307a2 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -28,7 +28,6 @@ is_dask_namespace, is_jax_namespace, is_numpy_array, - is_pydata_sparse_namespace, is_torch_namespace, ) from ._typing import Array, Device @@ -53,7 +52,6 @@ def override(func): "in1d", "is_python_scalar", "jax_autojit", - "mean", "meta_namespace", "pickle_flatten", "pickle_unflatten", @@ -122,29 +120,6 @@ def in1d( return xp.take(ret, rev_idx, axis=0) -def mean( - x: Array, - /, - *, - axis: int | tuple[int, ...] | None = None, - keepdims: bool = False, - xp: ModuleType | None = None, -) -> Array: # numpydoc ignore=PR01,RT01 - """ - Complex mean, https://github.com/data-apis/array-api/issues/846. - """ - if xp is None: - xp = array_namespace(x) - - if xp.isdtype(x.dtype, "complex floating"): - x_real = xp.real(x) - x_imag = xp.imag(x) - mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims) - mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims) - return mean_real + (mean_imag * xp.asarray(1j)) - return xp.mean(x, axis=axis, keepdims=keepdims) - - def is_python_scalar(x: object) -> TypeIs[complex]: # numpydoc ignore=PR01,RT01 """Return True if `x` is a Python scalar, False otherwise.""" # isinstance(x, float) returns True for np.float64 @@ -332,14 +307,7 @@ def capabilities( Capabilities of the namespace. """ out = xp.__array_namespace_info__().capabilities() - if is_pydata_sparse_namespace(xp): - if out["boolean indexing"]: - # FIXME https://github.com/pydata/sparse/issues/876 - # boolean indexing is supported, but not when the index is a sparse array. - # boolean indexing by list or numpy array is not part of the Array API. - out = out.copy() - out["boolean indexing"] = False - elif is_jax_namespace(xp): + if is_jax_namespace(xp): if out["boolean indexing"]: # pragma: no cover # Backwards compatibility with jax <0.6.0 # https://github.com/jax-ml/jax/issues/27418