Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 126 additions & 16 deletions xrspatial/reproject/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,18 @@ def _reproject_chunk_numpy(
c_min = np.nanmin(src_col_px)
c_max = np.nanmax(src_col_px)

# 3-D source: empty-chunk returns must carry the band axis or the
# dask map_blocks template (which is 3-D for 3-D sources) sees a
# shape mismatch (#2027).
if source_data.ndim == 3:
_empty_shape = (*chunk_shape, source_data.shape[2])
else:
_empty_shape = chunk_shape

if not np.isfinite(r_min) or not np.isfinite(r_max):
return np.full(chunk_shape, nodata, dtype=np.float64)
return np.full(_empty_shape, nodata, dtype=np.float64)
if not np.isfinite(c_min) or not np.isfinite(c_max):
return np.full(chunk_shape, nodata, dtype=np.float64)
return np.full(_empty_shape, nodata, dtype=np.float64)

r_min = int(np.floor(r_min)) - 2
r_max = int(np.ceil(r_max)) + 3
Expand All @@ -263,7 +271,7 @@ def _reproject_chunk_numpy(

# Check overlap
if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0:
return np.full(chunk_shape, nodata, dtype=np.float64)
return np.full(_empty_shape, nodata, dtype=np.float64)

# Clip to source bounds
r_min_clip = max(0, r_min)
Expand All @@ -276,7 +284,7 @@ def _reproject_chunk_numpy(
_MAX_WINDOW_PIXELS = 64 * 1024 * 1024 # 64 Mpix (~512 MB for float64)
win_pixels = (r_max_clip - r_min_clip) * (c_max_clip - c_min_clip)
if win_pixels > _MAX_WINDOW_PIXELS:
return np.full(chunk_shape, nodata, dtype=np.float64)
return np.full(_empty_shape, nodata, dtype=np.float64)

# Extract source window
window = source_data[r_min_clip:r_max_clip, c_min_clip:c_max_clip]
Expand Down Expand Up @@ -337,6 +345,13 @@ def _reproject_chunk_cupy(
src_crs = _crs_from_wkt(src_wkt)
tgt_crs = _crs_from_wkt(tgt_wkt)

# 3-D source: empty-chunk returns must carry the band axis to match
# the dask+cupy map_blocks template (#2027).
if source_data.ndim == 3:
_empty_shape = (*chunk_shape, source_data.shape[2])
else:
_empty_shape = chunk_shape

# Try CUDA transform first (keeps coordinates on-device)
cuda_result = None
if src_crs is not None and tgt_crs is not None:
Expand Down Expand Up @@ -372,7 +387,7 @@ def _reproject_chunk_cupy(
)
if not (np.isfinite(r_min_val) and np.isfinite(r_max_val)
and np.isfinite(c_min_val) and np.isfinite(c_max_val)):
return cp.full(chunk_shape, nodata, dtype=cp.float64)
return cp.full(_empty_shape, nodata, dtype=cp.float64)
r_min = int(np.floor(r_min_val)) - 2
r_max = int(np.ceil(r_max_val)) + 3
c_min = int(np.floor(c_min_val)) - 2
Expand Down Expand Up @@ -407,17 +422,17 @@ def _reproject_chunk_cupy(
c_min = np.nanmin(src_col_px)
c_max = np.nanmax(src_col_px)
if not np.isfinite(r_min) or not np.isfinite(r_max):
return cp.full(chunk_shape, nodata, dtype=cp.float64)
return cp.full(_empty_shape, nodata, dtype=cp.float64)
if not np.isfinite(c_min) or not np.isfinite(c_max):
return cp.full(chunk_shape, nodata, dtype=cp.float64)
return cp.full(_empty_shape, nodata, dtype=cp.float64)
r_min = int(np.floor(r_min)) - 2
r_max = int(np.ceil(r_max)) + 3
c_min = int(np.floor(c_min)) - 2
c_max = int(np.ceil(c_max)) + 3
_use_native_cuda = False

if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0:
return cp.full(chunk_shape, nodata, dtype=cp.float64)
return cp.full(_empty_shape, nodata, dtype=cp.float64)

r_min_clip = max(0, r_min)
r_max_clip = min(src_h, r_max)
Expand All @@ -429,20 +444,56 @@ def _reproject_chunk_cupy(
_MAX_WINDOW_PIXELS = 64 * 1024 * 1024 # 64 Mpix (~512 MB for float64)
win_pixels = (r_max_clip - r_min_clip) * (c_max_clip - c_min_clip)
if win_pixels > _MAX_WINDOW_PIXELS:
return cp.full(chunk_shape, nodata, dtype=cp.float64)
return cp.full(_empty_shape, nodata, dtype=cp.float64)

window = source_data[r_min_clip:r_max_clip, c_min_clip:c_max_clip]
if hasattr(window, 'compute'):
window = window.compute()
if not isinstance(window, cp.ndarray):
window = cp.asarray(window)
orig_dtype = window.dtype
window = window.astype(cp.float64)

# Adjust coordinates relative to window (stays on GPU if CuPy)
local_row = src_row_px - r_min_clip
local_col = src_col_px - c_min_clip

# Multi-band: reproject each band separately, share coordinates.
# Matches the 3-D branch in _reproject_chunk_numpy so 3-D inputs work
# on cupy and dask+cupy backends instead of crashing with a CUDA
# signature mismatch (#2027).
if window.ndim == 3:
n_bands = window.shape[2]
bands = []
for b in range(n_bands):
band_data = window[:, :, b].astype(cp.float64)
if _use_native_cuda:
# Native CUDA kernels do the nodata->NaN conversion
# internally, matching the 2-D path below.
band_result = _resample_cupy_native(
band_data, local_row, local_col,
resampling=resampling, nodata=nodata,
)
else:
# CPU coords path needs explicit conversion before
# cupyx.scipy.ndimage.map_coordinates.
if not np.isnan(nodata):
band_data = cp.where(
band_data == nodata, cp.nan, band_data,
)
band_result = _resample_cupy(
band_data, local_row, local_col,
resampling=resampling, nodata=nodata,
)
if np.issubdtype(orig_dtype, np.integer):
info = np.iinfo(orig_dtype)
band_result = cp.clip(
cp.round(band_result), info.min, info.max,
).astype(orig_dtype)
bands.append(band_result)
return cp.stack(bands, axis=-1)

window = window.astype(cp.float64)

if _use_native_cuda:
# Coordinates are already CuPy arrays -- use native CUDA kernels
# (nodata->NaN conversion is handled inside _resample_cupy_native)
Expand Down Expand Up @@ -1233,6 +1284,18 @@ def _reproject_dask_cupy(
src_res_x = (src_right - src_left) / src_w
src_res_y = (src_top - src_bottom) / src_h

# 3-D source: the fast path's inline loop assumes 2-D windows.
# Delegate to the map_blocks path which handles 3-D via
# _reproject_chunk_cupy's per-band loop (#2027).
if raster.data.ndim == 3:
return _reproject_dask(
raster, src_bounds, src_shape, y_desc,
src_wkt, tgt_wkt,
out_bounds, out_shape,
resampling, nodata, precision,
chunk_size or 2048, True, # is_cupy=True
)

# Memory check: if the full output doesn't fit in GPU memory,
# fall back to the map_blocks path which is O(chunk_size) memory.
estimated = out_shape[0] * out_shape[1] * 8 # float64
Expand Down Expand Up @@ -1444,21 +1507,37 @@ def _reproject_block_adapter(
src_wkt, tgt_wkt,
out_bounds, out_shape,
resampling, nodata, precision,
is_cupy, src_footprint_tgt,
is_cupy, src_footprint_tgt, n_bands=None,
):
"""``map_blocks`` adapter for reprojection.

Derives chunk bounds from *block_info* and delegates to the
per-chunk worker.

For 3-D sources the template carries the band axis, so each block
is ``(rh, rw, n_bands)``. The adapter strips the trailing band axis
when computing 2-D chunk bounds and the per-chunk worker returns a
3-D result that fits the template (#2027).
"""
info = block_info[0]
(row_start, row_end), (col_start, col_end) = info['array-location']
# 3-D template -> array-location is 3 entries; spatial dims are the
# first two. Band dim spans the full axis (single chunk).
spatial_loc = info['array-location'][:2]
(row_start, row_end), (col_start, col_end) = spatial_loc
chunk_shape = (row_end - row_start, col_end - col_start)
cb = _chunk_bounds(out_bounds, out_shape,
row_start, row_end, col_start, col_end)

is_3d = n_bands is not None

# Skip chunks that don't overlap the source footprint
if src_footprint_tgt is not None and not _bounds_overlap(cb, src_footprint_tgt):
if is_3d:
empty_shape = (*chunk_shape, n_bands)
if is_cupy:
import cupy as cp
return cp.full(empty_shape, nodata, dtype=cp.float64)
return np.full(empty_shape, nodata, dtype=np.float64)
return np.full(chunk_shape, nodata, dtype=np.float64)

chunk_fn = _reproject_chunk_cupy if is_cupy else _reproject_chunk_numpy
Expand Down Expand Up @@ -1487,6 +1566,13 @@ def _reproject_dask(
adding the full source as a dependency of every output block (which
would cause a MemoryError on distributed schedulers when the source
exceeds the worker memory limit).

For 3-D sources with shape ``(H, W, n_bands)`` the template is built
as ``(out_H, out_W, n_bands)`` so the lazy metadata matches the
actual chunk output shape (#2027). Without this, the lazy DataArray
advertised 2-D shape while the underlying chunks produced 3-D
arrays, causing a ``ValueError: replacement data has shape ...``
crash on ``.compute()``.
"""
import functools

Expand All @@ -1499,6 +1585,11 @@ def _reproject_dask(
src_bounds, src_wkt, tgt_wkt
)

# Detect 3-D source: chunks will return (rh, rw, n_bands) so the
# template must carry the band axis through to the lazy DataArray.
is_3d = raster.data.ndim == 3
n_bands = raster.data.shape[2] if is_3d else None

# Bind the source dask array and all scalar params via partial so
# map_blocks doesn't detect them as dask Array kwargs (which would
# add the full source as a dependency of every output block).
Expand All @@ -1517,6 +1608,7 @@ def _reproject_dask(
precision=precision,
is_cupy=is_cupy,
src_footprint_tgt=src_footprint_tgt,
n_bands=n_bands,
)

# Pick the template dtype to match the eager path: integer sources
Expand All @@ -1529,9 +1621,22 @@ def _reproject_dask(
else:
out_dtype = np.dtype(np.float64)

template = da.empty(
out_shape, dtype=out_dtype, chunks=(row_chunks, col_chunks)
)
if is_3d:
# Band axis is one chunk in the template regardless of how the
# source dask array is chunked along its band axis. The per-block
# worker reads all bands together (via a 2-D y/x slice that
# rejoins band chunks on compute) and emits the full band stack
# for its (rh, rw) tile, so multi-chunk output along bands would
# never get filled.
template = da.empty(
(*out_shape, n_bands),
dtype=out_dtype,
chunks=(row_chunks, col_chunks, (n_bands,)),
)
else:
template = da.empty(
out_shape, dtype=out_dtype, chunks=(row_chunks, col_chunks)
)

return da.map_blocks(
bound_adapter,
Expand Down Expand Up @@ -1635,8 +1740,13 @@ def merge(
raise ValueError("merge(): rasters list must not be empty")

for i, r in enumerate(rasters):
# merge() only supports 2-D rasters: the merge strategies, same-CRS
# placement, and output DataArray construction all assume (y, x).
# The 3-D (y, x, band) path was never implemented end-to-end and
# crashed at DataArray construction with a dims-vs-shape mismatch
# (#2027). Reject 3-D up front so callers get a clear error.
_validate_raster(r, func_name='merge', name=f'rasters[{i}]',
ndim=(2, 3))
ndim=(2,))

_validate_grid_params(
resolution=resolution,
Expand Down
Loading
Loading