From 3ddc2f8899ec87bcd29a32cdc04fb7a93bdb7107 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 20 May 2026 07:56:31 -0700 Subject: [PATCH 1/2] reproject: handle (band, y, x) 3-D inputs (#2182) The per-chunk worker assumed the band axis was trailing -- it sliced ``source_data[r:, c:]`` and read ``window.shape[2]`` for the band count. For a rasterio/rioxarray-style ``(band, y, x)`` input, that slice landed on the band/y axes instead of y/x, so the call either crashed with a band coord-length mismatch or returned wrong-shape data. The entry point now transposes 3-D inputs to the canonical ``(y, x, band)`` layout before dispatch, runs the existing pipeline, then transposes the output back to the input's dim order so downstream rioxarray/rasterio code keeps working. Adds a ``TestReproject3DBandFirst`` class covering ``(band, y, x)`` inputs on the numpy, dask+numpy, cupy, and dask+cupy backends, plus a pixel-equality check that the two layouts agree. --- xrspatial/reproject/__init__.py | 25 +++++ xrspatial/tests/test_reproject.py | 153 ++++++++++++++++++++++++++++++ 2 files changed, 178 insertions(+) diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index cfab24881..3cc58de32 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -652,6 +652,23 @@ def reproject( f"{sorted(_VERTICAL_DATUM_EPSG)} or None." ) + # Normalize 3-D inputs to canonical (y, x, band) layout. + # The per-chunk workers slice the source as ``source_data[r:, c:]`` and + # assume the band axis is trailing. A rasterio/rioxarray-style + # ``(band, y, x)`` input would otherwise slice the band/y axes instead + # of the y/x axes and either crash or return wrong-shape data (#2182). + # We record the input's original dim order so the output can be + # transposed back at the end, preserving downstream expectations. 
+ _input_dims = tuple(raster.dims) + if raster.ndim == 3: + _ydim_in, _xdim_in = _find_spatial_dims(raster) + _band_dims_in = [d for d in _input_dims + if d not in (_ydim_in, _xdim_in)] + _band_dim_in = _band_dims_in[0] if _band_dims_in else None + _canonical = (_ydim_in, _xdim_in, _band_dim_in) + if _band_dim_in is not None and _input_dims != _canonical: + raster = raster.transpose(*_canonical) + # Resolve CRS src_crs = _resolve_crs(source_crs) if src_crs is None: @@ -858,6 +875,14 @@ def reproject( name=name or raster.name, attrs=out_attrs, ) + + # Preserve the input's dim order so a ``(band, y, x)`` source produces a + # ``(band, y, x)`` output (#2182). The internal pipeline always builds the + # array as ``(y, x, band)`` for 3-D rasters; transpose back here. + if result.ndim == 3 and set(_input_dims) == set(result.dims): + if tuple(result.dims) != _input_dims: + result = result.transpose(*_input_dims) + return result diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index 6edd7012f..2d58f2cc0 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -3888,3 +3888,156 @@ def test_merge_rejects_3d_dataarray(self): ) with pytest.raises(ValueError, match=r"must be 2D"): merge([a, b], resolution=1.0) + + +# ===================================================================== +# Issue #2182: 3-D (band, y, x) inputs across all backends +# ===================================================================== + +@pytest.mark.skipif(not HAS_PYPROJ, reason="pyproj not installed") +class TestReproject3DBandFirst: + """reproject() must accept (band, y, x) inputs (rasterio convention). + + Before the fix, the worker sliced the source as ``source_data[r:, c:]`` + and read ``window.shape[2]`` for the band count, both of which assume + a trailing band axis. 
A ``(band, y, x)`` source therefore sliced the + band/y axes instead of y/x and either crashed with a coord-length + mismatch or returned wrong-shape data (#2182). + """ + + @staticmethod + def _make_band_first_raster(rng_seed=2182, h=32, w=32, n_bands=3, + dtype=np.float32): + rng = np.random.default_rng(rng_seed) + data = rng.random((h, w, n_bands), dtype=np.float32).astype(dtype) + # Build (y, x, band) first so we can transpose to (band, y, x) and + # keep coords aligned to the same underlying values. + yxb = xr.DataArray( + data, + dims=['y', 'x', 'band'], + coords={ + 'y': np.linspace(55, 45, h), + 'x': np.linspace(-5, 5, w), + 'band': list(range(n_bands)), + }, + attrs={'crs': 'EPSG:4326', 'nodata': np.nan}, + ) + return yxb.transpose('band', 'y', 'x') + + def test_band_first_numpy_dims_preserved(self): + """``(band, y, x)`` input must produce ``(band, y, x)`` output.""" + from xrspatial.reproject import reproject + raster = self._make_band_first_raster() + result = reproject(raster, 'EPSG:32633') + assert result.dims == ('band', 'y', 'x') + assert result.shape[0] == 3 + assert np.any(np.isfinite(result.values)) + + def test_band_first_numpy_band_coord_preserved(self): + """Band coord values must round-trip through reproject.""" + from xrspatial.reproject import reproject + raster = self._make_band_first_raster(n_bands=3) + result = reproject(raster, 'EPSG:32633') + assert 'band' in result.coords + assert list(result.coords['band'].values) == [0, 1, 2] + + def test_band_first_matches_band_last(self): + """The two layouts must produce identical pixel values.""" + from xrspatial.reproject import reproject + bxy = self._make_band_first_raster() + yxb = bxy.transpose('y', 'x', 'band') + out_bxy = reproject(bxy, 'EPSG:32633').transpose('y', 'x', 'band') + out_yxb = reproject(yxb, 'EPSG:32633') + np.testing.assert_array_equal( + np.asarray(out_bxy.values), np.asarray(out_yxb.values), + ) + + def test_band_first_uint8_dtype_roundtrip(self): + """Integer (band, y, x) 
inputs round-trip to source dtype.""" + from xrspatial.reproject import reproject + rng = np.random.default_rng(11) + data = rng.integers(0, 255, (3, 32, 32), dtype=np.uint8) + raster = xr.DataArray( + data, + dims=['band', 'y', 'x'], + coords={ + 'band': [1, 2, 3], + 'y': np.linspace(55, 45, 32), + 'x': np.linspace(-5, 5, 32), + }, + attrs={'crs': 'EPSG:4326', 'nodata': 0}, + ) + result = reproject(raster, 'EPSG:32633') + assert result.dtype == np.uint8 + assert result.dims == ('band', 'y', 'x') + assert result.shape[0] == 3 + + @pytest.mark.skipif(not HAS_DASK, reason="dask required") + def test_band_first_dask_lazy_shape(self): + """Lazy dask (band, y, x) DataArray must advertise 3-D shape.""" + from xrspatial.reproject import reproject + raster = self._make_band_first_raster() + raster = raster.copy( + data=da.from_array(raster.values, chunks=(3, 16, 16)) + ) + result = reproject(raster, 'EPSG:32633') + assert result.ndim == 3 + assert result.dims == ('band', 'y', 'x') + assert result.shape[0] == 3 + + @pytest.mark.skipif(not HAS_DASK, reason="dask required") + def test_band_first_dask_compute(self): + """Computed dask result keeps band axis without ValueError.""" + from xrspatial.reproject import reproject + raster = self._make_band_first_raster() + raster = raster.copy( + data=da.from_array(raster.values, chunks=(3, 16, 16)) + ) + result = reproject(raster, 'EPSG:32633').compute() + assert result.dims == ('band', 'y', 'x') + assert result.shape[0] == 3 + assert np.any(np.isfinite(result.values)) + + @pytest.mark.skipif(not HAS_DASK, reason="dask required") + def test_band_first_dask_matches_numpy(self): + """Dask (band, y, x) output must match eager numpy output.""" + from xrspatial.reproject import reproject + eager = reproject(self._make_band_first_raster(), 'EPSG:32633') + lazy_src = self._make_band_first_raster().copy( + data=da.from_array( + self._make_band_first_raster().values, chunks=(3, 16, 16) + ) + ) + lazy = reproject(lazy_src, 
'EPSG:32633').compute() + np.testing.assert_allclose( + np.asarray(eager.values), np.asarray(lazy.values), + rtol=1e-6, atol=1e-6, equal_nan=True, + ) + + @pytest.mark.skipif(not HAS_CUPY, reason="CuPy not installed") + def test_band_first_cupy(self): + """CuPy (band, y, x) reproject keeps band dim and dim order.""" + from xrspatial.reproject import reproject + host = self._make_band_first_raster() + gpu_data = cp.asarray(host.values) + raster = host.copy(data=gpu_data) + result = reproject(raster, 'EPSG:32633') + assert result.dims == ('band', 'y', 'x') + assert result.shape[0] == 3 + out = (cp.asnumpy(result.data) if isinstance(result.data, cp.ndarray) + else np.asarray(result.values)) + assert np.any(np.isfinite(out)) + + @pytest.mark.skipif( + not (HAS_CUPY and HAS_DASK), reason="CuPy + dask required", + ) + def test_band_first_dask_cupy(self): + """dask+cupy (band, y, x) reproject keeps band dim and dim order.""" + from xrspatial.reproject import reproject + host = self._make_band_first_raster() + gpu_data = da.from_array(cp.asarray(host.values), chunks=(3, 16, 16)) + raster = host.copy(data=gpu_data) + result = reproject(raster, 'EPSG:32633') + assert result.dims == ('band', 'y', 'x') + computed = result.compute() + assert computed.shape[0] == 3 From 63987421323e1d2b13100dbb969e294d6baae18b Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Wed, 20 May 2026 08:06:02 -0700 Subject: [PATCH 2/2] Address review nits: clean up canonical-dim build and cache test raster (#2182) - Move the canonical (y, x, band) tuple construction inside the `_band_dim_in is not None` guard so the tuple never contains None. - Cache `_make_band_first_raster()` in `test_band_first_dask_matches_numpy` so the raster is built once instead of three times. 
--- xrspatial/reproject/__init__.py | 7 ++++--- xrspatial/tests/test_reproject.py | 9 ++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/xrspatial/reproject/__init__.py b/xrspatial/reproject/__init__.py index 3cc58de32..67a9f82dd 100644 --- a/xrspatial/reproject/__init__.py +++ b/xrspatial/reproject/__init__.py @@ -665,9 +665,10 @@ def reproject( _band_dims_in = [d for d in _input_dims if d not in (_ydim_in, _xdim_in)] _band_dim_in = _band_dims_in[0] if _band_dims_in else None - _canonical = (_ydim_in, _xdim_in, _band_dim_in) - if _band_dim_in is not None and _input_dims != _canonical: - raster = raster.transpose(*_canonical) + if _band_dim_in is not None: + _canonical = (_ydim_in, _xdim_in, _band_dim_in) + if _input_dims != _canonical: + raster = raster.transpose(*_canonical) # Resolve CRS src_crs = _resolve_crs(source_crs) diff --git a/xrspatial/tests/test_reproject.py b/xrspatial/tests/test_reproject.py index 2d58f2cc0..73190d717 100644 --- a/xrspatial/tests/test_reproject.py +++ b/xrspatial/tests/test_reproject.py @@ -4002,11 +4002,10 @@ def test_band_first_dask_compute(self): def test_band_first_dask_matches_numpy(self): """Dask (band, y, x) output must match eager numpy output.""" from xrspatial.reproject import reproject - eager = reproject(self._make_band_first_raster(), 'EPSG:32633') - lazy_src = self._make_band_first_raster().copy( - data=da.from_array( - self._make_band_first_raster().values, chunks=(3, 16, 16) - ) + host = self._make_band_first_raster() + eager = reproject(host, 'EPSG:32633') + lazy_src = host.copy( + data=da.from_array(host.values, chunks=(3, 16, 16)) ) lazy = reproject(lazy_src, 'EPSG:32633').compute() np.testing.assert_allclose(