Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion xarray/compat/array_api_compat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from types import ModuleType

import numpy as np

from xarray.namedarray.pycompat import array_type
Expand Down Expand Up @@ -46,7 +48,7 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype:
return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)


def get_array_namespace(*values):
def get_array_namespace(*values) -> ModuleType:
def _get_single_namespace(x):
if hasattr(x, "__array_namespace__"):
return x.__array_namespace__()
Expand Down
13 changes: 12 additions & 1 deletion xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
from contextlib import suppress
from dataclasses import dataclass, field
from datetime import timedelta
from types import ModuleType
from typing import TYPE_CHECKING, Any, cast, overload

import numpy as np
import pandas as pd
from numpy.typing import DTypeLike
from packaging.version import Version

from xarray.compat.array_api_compat import get_array_namespace
from xarray.compat.npcompat import HAS_STRING_DTYPE
from xarray.core import duck_array_ops
from xarray.core.coordinate_transform import CoordinateTransform
Expand Down Expand Up @@ -693,7 +695,10 @@ def __array__(
else:
return np.asarray(to_numpy(self.get_duck_array()), dtype=dtype)

def get_duck_array(self):
def __array_namespace__(self: Any) -> ModuleType:
return get_array_namespace(self.array)

def get_duck_array(self) -> duckarray:
return self.array.get_duck_array()

def __getitem__(self, key: Any):
Expand Down Expand Up @@ -932,6 +937,9 @@ def get_duck_array(self):
async def async_get_duck_array(self):
return await self.array.async_get_duck_array()

def __array_namespace__(self: Any) -> ModuleType:
return get_array_namespace(self.array)

def _oindex_get(self, indexer: OuterIndexer):
return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer]))

Expand Down Expand Up @@ -1763,6 +1771,9 @@ def __init__(self, array):
)
self.array = array

def __array_namespace__(self: Any) -> ModuleType:
return get_array_namespace(self.array)


class ArrayApiIndexingAdapter(IndexingAdapter):
"""Wrap an array API array to use explicit indexing."""
Expand Down
2 changes: 1 addition & 1 deletion xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ def chunk(
# Using OuterIndexer is a pragmatic choice: dask does not yet handle
# different indexing types in an explicit way:
# https://github.com/dask/dask/issues/2883
ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) # type: ignore[assignment]
ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer)

if is_dict_like(chunks):
chunks = tuple(starmap(chunks.get, enumerate(ndata.shape)))
Expand Down
5 changes: 3 additions & 2 deletions xarray/namedarray/daskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def from_array(
import dask.array as da

if isinstance(data, ImplicitToExplicitIndexingAdapter):
# lazily loaded backend array classes should use NumPy array operations.
kwargs["meta"] = np.ndarray
# lazily loaded backend array classes should use NumPy or CuPy array operations.
xp = data.__array_namespace__()
kwargs["meta"] = xp.ndarray

return da.from_array(
data,
Expand Down
2 changes: 1 addition & 1 deletion xarray/namedarray/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType,
return loaded_data

if isinstance(data, ExplicitlyIndexed | ImplicitToExplicitIndexingAdapter):
return data.get_duck_array() # type: ignore[no-untyped-call, no-any-return]
return data.get_duck_array()
elif is_duck_array(data):
return data
else:
Expand Down
26 changes: 25 additions & 1 deletion xarray/tests/test_duck_array_wrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import xarray as xr
from xarray.tests import requires_dask

# Don't run cupy in CI because it requires a GPU
NAMESPACE_ARRAYS = {
Expand All @@ -22,6 +23,7 @@
"argsort": "no argsort",
"conjugate": "conj but no conjugate",
"searchsorted": "dask.array.searchsorted but no Array.searchsorted",
"dask_chunk_compute_roundtrip": "no need to test dask with dask",
},
},
"jax.numpy": {
Expand Down Expand Up @@ -123,7 +125,7 @@ def setup_for_test(self, request, namespace):
reason = NAMESPACE_ARRAYS[namespace]["xfails"][xarray_method]
pytest.xfail(f"xfail for {self.namespace}: {reason}")

def get_test_dataarray(self):
def get_test_dataarray(self) -> xr.DataArray:
data = np.asarray([[1, 2, 3, np.nan, 5]])
x = np.arange(5)
data = self.constructor(data)
Expand Down Expand Up @@ -516,3 +518,25 @@ def test_sortby(self):
def test_broadcast_like(self):
result = self.x.broadcast_like(self.x)
assert isinstance(result.data, self.Array)


@pytest.mark.parametrize("namespace", NAMESPACE_ARRAYS)
class TestDatasetMethods(_BaseTest):
@pytest.fixture(autouse=True)
def setUp(self, request, namespace):
self.setup_for_test(request, namespace)
self.ds = self.get_test_dataarray().to_dataset()

@requires_dask
def test_dask_chunk_compute_roundtrip(self):
"""
Ensure duck arrays chunked into a dask.Array get returned as duck arrays
(and not numpy array) after calling `.compute()`.
"""
chunked_ds = self.ds.chunk(x=2, chunked_array_type="dask")
assert isinstance(chunked_ds.foo.data._meta, self.Array)

computed_ds = chunked_ds.compute()
assert isinstance(computed_ds.foo.data, self.Array), (
f"Expected: {self.Array}, got {computed_ds.foo.data.__class__}"
)
Loading