Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions xrspatial/reproject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,24 @@ def reproject(
f"{sorted(_VERTICAL_DATUM_EPSG)} or None."
)

# Normalize 3-D inputs to canonical (y, x, band) layout.
# The per-chunk workers slice the source as ``source_data[r:, c:]`` and
# assume the band axis is trailing. A rasterio/rioxarray-style
# ``(band, y, x)`` input would otherwise slice the band/y axes instead
# of the y/x axes and either crash or return wrong-shape data (#2182).
# We record the input's original dim order so the output can be
# transposed back at the end, preserving downstream expectations.
_input_dims = tuple(raster.dims)
if raster.ndim == 3:
_ydim_in, _xdim_in = _find_spatial_dims(raster)
_band_dims_in = [d for d in _input_dims
if d not in (_ydim_in, _xdim_in)]
_band_dim_in = _band_dims_in[0] if _band_dims_in else None
if _band_dim_in is not None:
_canonical = (_ydim_in, _xdim_in, _band_dim_in)
if _input_dims != _canonical:
raster = raster.transpose(*_canonical)

# Resolve CRS
src_crs = _resolve_crs(source_crs)
if src_crs is None:
Expand Down Expand Up @@ -956,6 +974,14 @@ def reproject(
name=name or raster.name,
attrs=out_attrs,
)

# Preserve the input's dim order so a ``(band, y, x)`` source produces a
# ``(band, y, x)`` output (#2182). The internal pipeline always builds the
# array as ``(y, x, band)`` for 3-D rasters; transpose back here.
if result.ndim == 3 and set(_input_dims) == set(result.dims):
if tuple(result.dims) != _input_dims:
result = result.transpose(*_input_dims)

return result


Expand Down
152 changes: 152 additions & 0 deletions xrspatial/tests/test_reproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4328,6 +4328,158 @@ def test_merge_rejects_3d_dataarray(self):
merge([a, b], resolution=1.0)


# =====================================================================
# Issue #2182: 3-D (band, y, x) inputs across all backends
# =====================================================================

@pytest.mark.skipif(not HAS_PYPROJ, reason="pyproj not installed")
class TestReproject3DBandFirst:
    """Regression tests for band-first ``(band, y, x)`` inputs (#2182).

    ``reproject()`` has to accept the rasterio-style layout in which the
    band axis leads. Prior to the fix the per-chunk worker indexed the
    source as ``source_data[r:, c:]`` and took the band count from
    ``window.shape[2]``; both presume the band axis is last. With a
    ``(band, y, x)`` source that meant slicing band/y rather than y/x,
    which surfaced either as a coord-length mismatch error or as output
    of the wrong shape.
    """

    # Literals shared by every test in this class.
    _TARGET_CRS = 'EPSG:32633'
    _BAND_FIRST = ('band', 'y', 'x')
    _BAND_LAST = ('y', 'x', 'band')
    _CHUNKS = (3, 16, 16)

    @classmethod
    def _to_utm(cls, raster):
        """Run ``reproject`` on *raster* towards the shared target CRS."""
        # Imported lazily, as each individual test did originally.
        from xrspatial.reproject import reproject
        return reproject(raster, cls._TARGET_CRS)

    @classmethod
    def _as_dask(cls, raster):
        """Return a copy of *raster* backed by a chunked dask array."""
        chunked = da.from_array(raster.values, chunks=cls._CHUNKS)
        return raster.copy(data=chunked)

    @staticmethod
    def _make_band_first_raster(rng_seed=2182, h=32, w=32, n_bands=3,
                                dtype=np.float32):
        """Build a seeded random EPSG:4326 raster in ``(band, y, x)`` order.

        The array is first assembled as ``(y, x, band)`` and only then
        transposed, so the coordinates stay attached to the same
        underlying values in either layout.
        """
        generator = np.random.default_rng(rng_seed)
        values = generator.random((h, w, n_bands), dtype=np.float32)
        values = values.astype(dtype)
        coords = {
            'y': np.linspace(55, 45, h),
            'x': np.linspace(-5, 5, w),
            'band': list(range(n_bands)),
        }
        band_last = xr.DataArray(
            values,
            dims=['y', 'x', 'band'],
            coords=coords,
            attrs={'crs': 'EPSG:4326', 'nodata': np.nan},
        )
        return band_last.transpose('band', 'y', 'x')

    def test_band_first_numpy_dims_preserved(self):
        """A ``(band, y, x)`` source yields a ``(band, y, x)`` result."""
        source = self._make_band_first_raster()
        out = self._to_utm(source)
        assert out.dims == self._BAND_FIRST
        assert out.shape[0] == 3
        assert np.any(np.isfinite(out.values))

    def test_band_first_numpy_band_coord_preserved(self):
        """The band coordinate values survive reprojection unchanged."""
        source = self._make_band_first_raster(n_bands=3)
        out = self._to_utm(source)
        assert 'band' in out.coords
        assert list(out.coords['band'].values) == [0, 1, 2]

    def test_band_first_matches_band_last(self):
        """Band-first and band-last inputs give identical pixels."""
        band_first = self._make_band_first_raster()
        band_last = band_first.transpose(*self._BAND_LAST)
        # Bring the band-first result into (y, x, band) before comparing.
        from_first = self._to_utm(band_first).transpose(*self._BAND_LAST)
        from_last = self._to_utm(band_last)
        np.testing.assert_array_equal(
            np.asarray(from_first.values), np.asarray(from_last.values),
        )

    def test_band_first_uint8_dtype_roundtrip(self):
        """Integer ``(band, y, x)`` sources keep their dtype and layout."""
        generator = np.random.default_rng(11)
        pixels = generator.integers(0, 255, (3, 32, 32), dtype=np.uint8)
        source = xr.DataArray(
            pixels,
            dims=['band', 'y', 'x'],
            coords={
                'band': [1, 2, 3],
                'y': np.linspace(55, 45, 32),
                'x': np.linspace(-5, 5, 32),
            },
            attrs={'crs': 'EPSG:4326', 'nodata': 0},
        )
        out = self._to_utm(source)
        assert out.dtype == np.uint8
        assert out.dims == self._BAND_FIRST
        assert out.shape[0] == 3

    @pytest.mark.skipif(not HAS_DASK, reason="dask required")
    def test_band_first_dask_lazy_shape(self):
        """A dask-backed ``(band, y, x)`` source reports a 3-D result."""
        source = self._as_dask(self._make_band_first_raster())
        out = self._to_utm(source)
        assert out.ndim == 3
        assert out.dims == self._BAND_FIRST
        assert out.shape[0] == 3

    @pytest.mark.skipif(not HAS_DASK, reason="dask required")
    def test_band_first_dask_compute(self):
        """Computing the dask result keeps the band axis intact."""
        source = self._as_dask(self._make_band_first_raster())
        out = self._to_utm(source).compute()
        assert out.dims == self._BAND_FIRST
        assert out.shape[0] == 3
        assert np.any(np.isfinite(out.values))

    @pytest.mark.skipif(not HAS_DASK, reason="dask required")
    def test_band_first_dask_matches_numpy(self):
        """The dask path agrees with the eager numpy path."""
        host = self._make_band_first_raster()
        eager = self._to_utm(host)
        lazy = self._to_utm(self._as_dask(host)).compute()
        np.testing.assert_allclose(
            np.asarray(eager.values), np.asarray(lazy.values),
            rtol=1e-6, atol=1e-6, equal_nan=True,
        )

    @pytest.mark.skipif(not HAS_CUPY, reason="CuPy not installed")
    def test_band_first_cupy(self):
        """A CuPy-backed source keeps its band dim and dim order."""
        host = self._make_band_first_raster()
        source = host.copy(data=cp.asarray(host.values))
        out = self._to_utm(source)
        assert out.dims == self._BAND_FIRST
        assert out.shape[0] == 3
        # Pull device data back to the host only when it is a CuPy array.
        if isinstance(out.data, cp.ndarray):
            host_values = cp.asnumpy(out.data)
        else:
            host_values = np.asarray(out.values)
        assert np.any(np.isfinite(host_values))

    @pytest.mark.skipif(
        not (HAS_CUPY and HAS_DASK), reason="CuPy + dask required",
    )
    def test_band_first_dask_cupy(self):
        """A dask+CuPy source keeps its band dim and dim order."""
        host = self._make_band_first_raster()
        device_chunks = da.from_array(
            cp.asarray(host.values), chunks=self._CHUNKS,
        )
        source = host.copy(data=device_chunks)
        out = self._to_utm(source)
        assert out.dims == self._BAND_FIRST
        computed = out.compute()
        assert computed.shape[0] == 3


# ---------------------------------------------------------------------------
# Integer dtype nodata handling (#2185)
# ---------------------------------------------------------------------------
Expand Down
Loading