diff --git a/xarray/compat/array_api_compat.py b/xarray/compat/array_api_compat.py index e1e5d5c5bdc..575f8cdf07d 100644 --- a/xarray/compat/array_api_compat.py +++ b/xarray/compat/array_api_compat.py @@ -1,3 +1,5 @@ +from types import ModuleType + import numpy as np from xarray.namedarray.pycompat import array_type @@ -46,7 +48,7 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype: return _future_array_api_result_type(*arrays_and_dtypes, xp=xp) -def get_array_namespace(*values): +def get_array_namespace(*values) -> ModuleType: def _get_single_namespace(x): if hasattr(x, "__array_namespace__"): return x.__array_namespace__() diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 7884c7bd74a..acefe6243c9 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -9,6 +9,7 @@ from contextlib import suppress from dataclasses import dataclass, field from datetime import timedelta +from types import ModuleType from typing import TYPE_CHECKING, Any, cast, overload import numpy as np @@ -16,6 +17,7 @@ from numpy.typing import DTypeLike from packaging.version import Version +from xarray.compat.array_api_compat import get_array_namespace from xarray.compat.npcompat import HAS_STRING_DTYPE from xarray.core import duck_array_ops from xarray.core.coordinate_transform import CoordinateTransform @@ -693,7 +695,10 @@ def __array__( else: return np.asarray(to_numpy(self.get_duck_array()), dtype=dtype) - def get_duck_array(self): + def __array_namespace__(self: Any) -> ModuleType: + return get_array_namespace(self.array) + + def get_duck_array(self) -> duckarray: return self.array.get_duck_array() def __getitem__(self, key: Any): @@ -932,6 +937,9 @@ def get_duck_array(self): async def async_get_duck_array(self): return await self.array.async_get_duck_array() + def __array_namespace__(self: Any) -> ModuleType: + return get_array_namespace(self.array) + def _oindex_get(self, indexer: OuterIndexer): return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) @@ -1763,6 +1771,9 @@ def __init__(self, array): ) self.array = array + def __array_namespace__(self: Any) -> ModuleType: + return get_array_namespace(self.array) + class ArrayApiIndexingAdapter(IndexingAdapter): """Wrap an array API array to use explicit indexing.""" diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index c535fcb0bc4..9481c0ce036 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -854,7 +854,7 @@ def chunk( # Using OuterIndexer is a pragmatic choice: dask does not yet handle # different indexing types in an explicit way: # https://github.com/dask/dask/issues/2883 - ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) # type: ignore[assignment] + ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) if is_dict_like(chunks): chunks = tuple(starmap(chunks.get, enumerate(ndata.shape))) diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index eb01a150c18..c03b9a4da13 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -68,8 +68,9 @@ def from_array( import dask.array as da if isinstance(data, ImplicitToExplicitIndexingAdapter): - # lazily loaded backend array classes should use NumPy array operations. - kwargs["meta"] = np.ndarray + # lazily loaded backend array classes should use NumPy or CuPy array operations. + xp = data.__array_namespace__() + kwargs["meta"] = xp.ndarray return da.from_array( data, diff --git a/xarray/namedarray/pycompat.py b/xarray/namedarray/pycompat.py index a192930cea7..52349b13830 100644 --- a/xarray/namedarray/pycompat.py +++ b/xarray/namedarray/pycompat.py @@ -140,7 +140,7 @@ def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType, return loaded_data if isinstance(data, ExplicitlyIndexed | ImplicitToExplicitIndexingAdapter): - return data.get_duck_array() # type: ignore[no-untyped-call, no-any-return] + return data.get_duck_array() elif is_duck_array(data): return data else: diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py index 9bbc3d9b06a..f39dfae53fb 100644 --- a/xarray/tests/test_duck_array_wrapping.py +++ b/xarray/tests/test_duck_array_wrapping.py @@ -3,6 +3,7 @@ import pytest import xarray as xr +from xarray.tests import requires_dask # Don't run cupy in CI because it requires a GPU NAMESPACE_ARRAYS = { @@ -22,6 +23,7 @@ "argsort": "no argsort", "conjugate": "conj but no conjugate", "searchsorted": "dask.array.searchsorted but no Array.searchsorted", + "dask_chunk_compute_roundtrip": "no need to test dask with dask", }, }, "jax.numpy": { @@ -123,7 +125,7 @@ def setup_for_test(self, request, namespace): reason = NAMESPACE_ARRAYS[namespace]["xfails"][xarray_method] pytest.xfail(f"xfail for {self.namespace}: {reason}") - def get_test_dataarray(self): + def get_test_dataarray(self) -> xr.DataArray: data = np.asarray([[1, 2, 3, np.nan, 5]]) x = np.arange(5) data = self.constructor(data) @@ -516,3 +518,25 @@ def test_sortby(self): def test_broadcast_like(self): result = self.x.broadcast_like(self.x) assert isinstance(result.data, self.Array) + + +@pytest.mark.parametrize("namespace", NAMESPACE_ARRAYS) +class TestDatasetMethods(_BaseTest): + @pytest.fixture(autouse=True) + def setUp(self, request, namespace): + self.setup_for_test(request, namespace) + self.ds = self.get_test_dataarray().to_dataset() + + @requires_dask + def test_dask_chunk_compute_roundtrip(self): + """ + Ensure duck arrays chunked into a dask.Array get returned as duck arrays + (and not numpy array) after calling `.compute()`. + """ + chunked_ds = self.ds.chunk(x=2, chunked_array_type="dask") + assert isinstance(chunked_ds.foo.data._meta, self.Array) + + computed_ds = chunked_ds.compute() + assert isinstance(computed_ds.foo.data, self.Array), ( + f"Expected: {self.Array}, got {computed_ds.foo.data.__class__}" + )