Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,14 @@ def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
isinstance(key, tuple) and len(key) == 1
): # pyarrow type arrays can't handle single-length tuples
(key,) = key
# Pandas extension arrays are 1-D and don't support Ellipsis-based
# 0-d indexing. NumpyIndexingAdapter appends Ellipsis to force 0-d
# slices, but this causes issues with IntervalArray and similar types.
# Strip trailing Ellipsis for 1-D extension arrays (GH#11300).
if isinstance(key, tuple) and key[-1] is Ellipsis:
key = tuple(k for k in key if k is not Ellipsis)
if len(key) == 1:
(key,) = key
item = self.array[key]
if is_allowed_extension_array(item):
return PandasExtensionArray(item)
Expand Down
8 changes: 8 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,14 @@ def __getitem__(self, key) -> Self:

def _finalize_indexing_result(self, dims, data) -> Self:
"""Used by IndexVariable to return IndexVariable objects when possible."""
# PandasExtensionArray is always 1-D and cannot represent scalar (0-d)
# data. When indexing produces a scalar result (dims=()), convert to
# numpy array (GH#11300).
if not dims:
from xarray.core.extension_array import PandasExtensionArray

if isinstance(data, PandasExtensionArray):
data = data.array[0]
return self._replace(dims=dims, data=data)

def _getitem_with_mask(self, key, fill_value=dtypes.NA):
Expand Down
11 changes: 6 additions & 5 deletions xarray/tests/test_coding_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,20 +538,21 @@ def test_infer_datetime_units(freq, units) -> None:


@pytest.mark.parametrize(
["dates", "expected"],
["date_strings", "expected"],
[
(
pd.to_datetime(["1900-01-01", "1900-01-02", "NaT"], unit="ns"),
["1900-01-01", "1900-01-02", "NaT"],
"days since 1900-01-01 00:00:00",
),
(
pd.to_datetime(["NaT", "1900-01-01"], unit="ns"),
["NaT", "1900-01-01"],
"days since 1900-01-01 00:00:00",
),
(pd.to_datetime(["NaT"], unit="ns"), "days since 1970-01-01 00:00:00"),
(["NaT"], "days since 1970-01-01 00:00:00"),
],
)
def test_infer_datetime_units_with_NaT(dates, expected) -> None:
def test_infer_datetime_units_with_NaT(date_strings, expected) -> None:
dates = pd.to_datetime(date_strings, unit="ns")
assert expected == infer_datetime_units(dates)


Expand Down
22 changes: 22 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6997,6 +6997,28 @@ def test_idxminmax_dask(self, op: str, ndim: int) -> None:
assert_equal(getattr(ar0_dsk, op)(dim="x"), getattr(ar0_raw, op)(dim="x"))


def test_idxminmax_interval_coords() -> None:
# GH#11300 - idxmax/idxmin should work with IntervalIndex coordinates
import pandas as pd

idx = pd.IntervalIndex.from_breaks([0, 1, 2, 3])
da = xr.DataArray([False, True, True], dims=["z"], coords={"z": idx})

result = da.idxmax()
assert result.values.item() == pd.Interval(1, 2, closed="right")

result = da.idxmin()
assert result.values.item() == pd.Interval(0, 1, closed="right")

# Test with skipna=False
da2 = xr.DataArray([np.nan, 1.0, 2.0], dims=["z"], coords={"z": idx})
result2 = da2.idxmax()
assert result2.values.item() == pd.Interval(2, 3, closed="right")

result3 = da2.idxmin()
assert result3.values.item() == pd.Interval(1, 2, closed="right")


@pytest.mark.parametrize("da", ("repeating_ints",), indirect=True)
def test_isin(da) -> None:
expected = DataArray(
Expand Down
Loading