diff --git a/CHANGELOG.md b/CHANGELOG.md index c745c0ef..2f79648c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### Unreleased #### Bug fixes and improvements +- Shut down the per-tile compression `ThreadPoolExecutor` on every exit path of the streaming tiled-write code in `to_geotiff`. The old code only called `shutdown(wait=True)` after the tile-row loop completed, so any mid-stream raise (compression failure, dask compute failure, file write failure) bypassed shutdown and leaked worker threads. The loop now runs inside `try/finally` and the finally calls `shutdown(wait=True, cancel_futures=True)` so queued tiles get dropped on the error path instead of blocking the unwind. The pool's workers carry an `xrspatial-geotiff-tile-compress` `thread_name_prefix` so leak-detection tests can tell them apart from dask's own offload/scheduler pools. (#2276) - Remove read-side emission of the 13 deprecated GeoTIFF attrs (`crs_name`, `geog_citation`, `datum_code`, `angular_units`, `semi_major_axis`, `inv_flattening`, `linear_units`, `projection_code`, `vertical_crs`, `vertical_citation`, `vertical_units`, `colormap_rgba`, `cmap`) and bump `attrs['_xrspatial_geotiff_contract']` from 1 to 2. Downstream code that read these via `attrs[key]` now sees `KeyError`; migrate to `attrs.get(key)` or derive the value from `attrs['crs']` / `attrs['crs_wkt']` with pyproj. The `.xrs.plot()` accessor still surfaces palette colormaps by building a `ListedColormap` from the canonical `attrs['colormap']`. (#2016) - Accept numpy integer scalars as the `crs=` argument to `to_geotiff` / `write_geotiff_gpu`. The validator already allowed `numbers.Integral`, but the writers gated EPSG assignment on `isinstance(crs, int)`, so `np.int32` / `np.int64` / `np.uint16` values passed validation then silently fell through with no EPSG written. (#2082) - Tighten the writer's no-georef sentinel for integer x/y coords. 
The pre-fix check treated any integer dtype on either axis as the read-side no-georef placeholder and skipped transform inference, which also caught user-authored projected grids with integer-spaced coords (e.g. `x=[100,101,102], y=[200,199]`) and silently stripped their georef on write. The sentinel now matches only the exact reader pattern: `int64` ascending contiguous-step-1 arange on both axes. User-authored integer-coord grids that don't match (descending, non-unit step, non-uniform, or non-`int64`) now produce a real transform or raise `NonUniformCoordsError`. Coord values round-trip exactly through the new path; dtype flips int->float on subsequent reads. (#2087) diff --git a/xrspatial/geotiff/_writer.py b/xrspatial/geotiff/_writer.py index 508c3799..e0ca6d84 100644 --- a/xrspatial/geotiff/_writer.py +++ b/xrspatial/geotiff/_writer.py @@ -144,6 +144,13 @@ # carrying computed offsets, dimensions, or layout. See issue #1769. _OVERRIDABLE_AUTO_TAG_IDS = frozenset({TAG_PHOTOMETRIC, TAG_EXTRA_SAMPLES}) +# Thread-name prefix for the per-tile compression ``ThreadPoolExecutor`` +# in the streaming write path. Tagging the workers lets leak-detection +# tests (issue #2276) tell our pool's threads apart from dask's +# offload/scheduler pools, which also use ``ThreadPoolExecutor`` and +# are kept alive deliberately by dask as singletons. 
+_TILE_POOL_THREAD_PREFIX = 'xrspatial-geotiff-tile-compress' + # TIFF Photometric Interpretation values (``PHOTOMETRIC_MINISBLACK``, # ``PHOTOMETRIC_RGB``) and the ``_PHOTOMETRIC_NAME_MAP`` friendly-name # table live in ``_encode.py`` and are re-exported above for @@ -1023,113 +1030,133 @@ def _write_streaming(dask_data, path: str, *, _pool_workers = min(tiles_per_segment, os.cpu_count() or 4) _use_pool = (comp_tag != COMPRESSION_NONE and _pool_workers > 1) - tile_pool = (ThreadPoolExecutor(max_workers=_pool_workers) - if _use_pool else None) - - for tr in range(tiles_down): - r0 = tr * th - r1 = min(r0 + th, height) - actual_h = r1 - r0 - - for seg_start in range(0, tiles_across, tiles_per_segment): - seg_end = min(seg_start + tiles_per_segment, - tiles_across) - seg_c0 = seg_start * tw - seg_c1 = min(seg_end * tw, width) - - # Compute just this horizontal segment - if dask_data.ndim == 3: - seg_np = np.asarray( - dask_data[r0:r1, seg_c0:seg_c1, :].compute()) - else: - seg_np = np.asarray( - dask_data[r0:r1, seg_c0:seg_c1].compute()) - if hasattr(seg_np, 'get'): - seg_np = seg_np.get() - - if seg_np.dtype != out_dtype: - seg_np = seg_np.astype(out_dtype) - - # NaN -> nodata sentinel - if (nodata is not None and seg_np.dtype.kind == 'f' - and not np.isnan(nodata) - and restore_sentinel): - nan_mask = np.isnan(seg_np) - if nan_mask.any(): - seg_np = seg_np.copy() - seg_np[nan_mask] = seg_np.dtype.type(nodata) - - # Build tile arrays for this segment - seg_tile_arrs = [] - for tc in range(seg_start, seg_end): - c0 = tc * tw - c1 = min(c0 + tw, width) - actual_w = c1 - c0 - - local_c0 = c0 - seg_c0 - local_c1 = c1 - seg_c0 - tile_slice = seg_np[:, local_c0:local_c1] - - if actual_h < th or actual_w < tw: - if seg_np.ndim == 3: - padded = np.zeros((th, tw, samples), - dtype=out_dtype) + # ``thread_name_prefix`` tags the worker threads so leak + # detection in tests (issue #2276) can tell our pool's + # workers apart from dask's offload/scheduler pools. 
+ tile_pool = ( + ThreadPoolExecutor( + max_workers=_pool_workers, + thread_name_prefix=_TILE_POOL_THREAD_PREFIX) + if _use_pool else None) + + # Wrap the tile loop in ``try/finally`` so the pool is + # always shut down before any exception (compression + # failure, dask compute failure, file write failure) + # propagates. The previous code only called + # ``shutdown`` after the loop completed and leaked + # worker threads on any mid-stream raise. See #2276. + try: + for tr in range(tiles_down): + r0 = tr * th + r1 = min(r0 + th, height) + actual_h = r1 - r0 + + for seg_start in range(0, tiles_across, tiles_per_segment): + seg_end = min(seg_start + tiles_per_segment, + tiles_across) + seg_c0 = seg_start * tw + seg_c1 = min(seg_end * tw, width) + + # Compute just this horizontal segment + if dask_data.ndim == 3: + seg_np = np.asarray( + dask_data[r0:r1, seg_c0:seg_c1, :].compute()) + else: + seg_np = np.asarray( + dask_data[r0:r1, seg_c0:seg_c1].compute()) + if hasattr(seg_np, 'get'): + seg_np = seg_np.get() + + if seg_np.dtype != out_dtype: + seg_np = seg_np.astype(out_dtype) + + # NaN -> nodata sentinel + if (nodata is not None and seg_np.dtype.kind == 'f' + and not np.isnan(nodata) + and restore_sentinel): + nan_mask = np.isnan(seg_np) + if nan_mask.any(): + seg_np = seg_np.copy() + seg_np[nan_mask] = seg_np.dtype.type(nodata) + + # Build tile arrays for this segment + seg_tile_arrs = [] + for tc in range(seg_start, seg_end): + c0 = tc * tw + c1 = min(c0 + tw, width) + actual_w = c1 - c0 + + local_c0 = c0 - seg_c0 + local_c1 = c1 - seg_c0 + tile_slice = seg_np[:, local_c0:local_c1] + + if actual_h < th or actual_w < tw: + if seg_np.ndim == 3: + padded = np.zeros((th, tw, samples), + dtype=out_dtype) + else: + padded = np.zeros((th, tw), dtype=out_dtype) + padded[:actual_h, :actual_w] = tile_slice + tile_arr = padded else: - padded = np.zeros((th, tw), dtype=out_dtype) - padded[:actual_h, :actual_w] = tile_slice - tile_arr = padded + tile_arr = 
np.ascontiguousarray(tile_slice) + + seg_tile_arrs.append(tile_arr) + + # Parallel compress on the hoisted ``tile_pool`` + # when it exists. zlib/zstd/LZW release the GIL, + # so threading actually parallelises the C-level + # work. Peak memory while the segment is in + # flight covers BOTH the uncompressed + # ``seg_tile_arrs`` (one full tile per column, + # released after the futures resolve) AND the + # compressed buffers ``seg_compressed`` (held + # until the sequential write loop drains them). + # Both lists are bounded by ``tiles_per_segment`` + # which the streaming buffer cap sets; fall + # through to a serial path when the pool is None + # (no compression / single core) or when only + # one tile sits in this segment. + n_seg_tiles = len(seg_tile_arrs) + if tile_pool is None or n_seg_tiles <= 1: + seg_compressed = [ + _compress_block( + ta, tw, th, samples, out_dtype, + bytes_per_sample, pred_int, comp_tag, + compression_level, max_z_error) + for ta in seg_tile_arrs + ] else: - tile_arr = np.ascontiguousarray(tile_slice) - - seg_tile_arrs.append(tile_arr) - - # Parallel compress on the hoisted ``tile_pool`` - # when it exists. zlib/zstd/LZW release the GIL, - # so threading actually parallelises the C-level - # work. Peak memory while the segment is in - # flight covers BOTH the uncompressed - # ``seg_tile_arrs`` (one full tile per column, - # released after the futures resolve) AND the - # compressed buffers ``seg_compressed`` (held - # until the sequential write loop drains them). - # Both lists are bounded by ``tiles_per_segment`` - # which the streaming buffer cap sets; fall - # through to a serial path when the pool is None - # (no compression / single core) or when only - # one tile sits in this segment. 
- n_seg_tiles = len(seg_tile_arrs) - if tile_pool is None or n_seg_tiles <= 1: - seg_compressed = [ - _compress_block( - ta, tw, th, samples, out_dtype, - bytes_per_sample, pred_int, comp_tag, - compression_level, max_z_error) - for ta in seg_tile_arrs - ] - else: - futures = [ - tile_pool.submit( - _compress_block, - ta, tw, th, samples, out_dtype, - bytes_per_sample, pred_int, comp_tag, - compression_level, max_z_error, - True) - for ta in seg_tile_arrs - ] - seg_compressed = [ - fut.result() for fut in futures] - - # Sequential file write to preserve on-disk tile order - for compressed in seg_compressed: - actual_offsets.append(current_offset) - actual_counts.append(len(compressed)) - f.write(compressed) - current_offset += len(compressed) - - del seg_np, seg_tile_arrs, seg_compressed - - if tile_pool is not None: - tile_pool.shutdown(wait=True) + futures = [ + tile_pool.submit( + _compress_block, + ta, tw, th, samples, out_dtype, + bytes_per_sample, pred_int, comp_tag, + compression_level, max_z_error, + True) + for ta in seg_tile_arrs + ] + seg_compressed = [ + fut.result() for fut in futures] + + # Sequential file write to preserve on-disk tile order + for compressed in seg_compressed: + actual_offsets.append(current_offset) + actual_counts.append(len(compressed)) + f.write(compressed) + current_offset += len(compressed) + + del seg_np, seg_tile_arrs, seg_compressed + finally: + # ``cancel_futures=True`` (Python 3.9+) drops any + # queued-but-not-started compress jobs on the + # error path so ``wait=True`` only blocks on work + # already in flight. The previous shutdown call + # lived past the for-loop and never ran when an + # exception escaped, leaking worker threads. See + # issue #2276. 
+ if tile_pool is not None: + tile_pool.shutdown(wait=True, cancel_futures=True) else: # Strip layout for i in range(n_entries): diff --git a/xrspatial/geotiff/tests/test_streaming_write_pool_leak_2276.py b/xrspatial/geotiff/tests/test_streaming_write_pool_leak_2276.py new file mode 100644 index 00000000..000b62f5 --- /dev/null +++ b/xrspatial/geotiff/tests/test_streaming_write_pool_leak_2276.py @@ -0,0 +1,261 @@ +"""ThreadPoolExecutor leak on mid-stream failure in tiled writes (#2276). + +The streaming tiled-write path in ``_write_streaming`` builds a +``ThreadPoolExecutor`` for parallel per-tile compression. Before this +fix, ``shutdown`` only ran after the tile-row loop completed -- if any +mid-stream step raised (compression error, dask compute error, file +write error), the shutdown was skipped and worker threads survived the +failure path. + +These tests inject a failure mid-stream and assert that: + +1. No worker threads owned by the writer's pool remain alive after the + call returns, and +2. The injected exception propagates cleanly out of ``to_geotiff`` (no + "swallowed" failures), and +3. The pool the writer constructed is shut down (``_shutdown`` flag + set). + +We monkey-patch ``concurrent.futures.ThreadPoolExecutor`` to capture +the pool ``_write_streaming`` constructs and inspect its state after +the failure path. The test also walks ``threading.enumerate()`` to +check that no threads with the writer's distinctive +``thread_name_prefix`` (``_TILE_POOL_THREAD_PREFIX`` from the writer +module) remain. Dask spins up its own ``ThreadPoolExecutor`` instances +during ``.compute()`` -- those use a different prefix and are +deliberately kept alive as singletons, so filtering on the writer's +prefix avoids false positives. 
+""" +from __future__ import annotations + +import os +import threading +import time +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import to_geotiff +from xrspatial.geotiff import _writer as writer_mod + + +# Re-use the writer's own constant so the test does not silently drift +# if the prefix ever changes on the writer side. ``_writer`` exposes +# ``_TILE_POOL_THREAD_PREFIX`` for exactly this purpose (#2276). +_WRITER_POOL_PREFIX = writer_mod._TILE_POOL_THREAD_PREFIX + + +def _make_dataarray(shape, dtype=np.float32, seed=20260521): + rng = np.random.default_rng(seed) + arr = rng.random(shape, dtype=dtype) + h, w = shape + y = np.linspace(41.0, 40.0, h) + x = np.linspace(-106.0, -105.0, w) + return xr.DataArray( + arr, dims=['y', 'x'], coords={'y': y, 'x': x}, + attrs={'crs': 4326, 'nodata': -9999.0}) + + +def _list_writer_pool_worker_threads(): + """Return live threads owned by the writer's tile-compress pool. + + Dask spins up its own ``ThreadPoolExecutor`` instances (the + ``Dask-Offload`` singleton and the threaded scheduler) that survive + deliberately, so filtering on the writer's distinctive prefix + avoids false positives. + """ + return [t for t in threading.enumerate() + if t.name.startswith(_WRITER_POOL_PREFIX) and t.is_alive()] + + +@pytest.fixture +def captured_pools(monkeypatch): + """Capture only the ``ThreadPoolExecutor`` instances ``_writer`` + constructs (filtered by ``thread_name_prefix``). + + ``_write_streaming`` does ``from concurrent.futures import + ThreadPoolExecutor`` inside the function body, so we patch the + symbol on the ``concurrent.futures`` module the import resolves + through. Dask also constructs its own ``ThreadPoolExecutor`` + instances during ``.compute()``; those use a different + ``thread_name_prefix`` and are filtered out here. 
+ """ + pools = [] + real_cls = ThreadPoolExecutor + + class _RecordingPool(real_cls): # type: ignore[misc, valid-type] + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + prefix = kwargs.get('thread_name_prefix', '') + if prefix.startswith(_WRITER_POOL_PREFIX): + pools.append(self) + + import concurrent.futures as _cf + monkeypatch.setattr(_cf, 'ThreadPoolExecutor', _RecordingPool) + return pools + + +def test_pool_shutdown_on_compress_failure( + captured_pools, monkeypatch, tmp_path): + """A raise inside ``_compress_block`` must shut the pool down.""" + # Force cpu_count high enough that ``_use_pool`` triggers. + monkeypatch.setattr(os, 'cpu_count', lambda: 4) + + pre_existing_pool_threads = _list_writer_pool_worker_threads() + pre_existing_names = {t.name for t in pre_existing_pool_threads} + + real_compress = writer_mod._compress_block + call_count = {'n': 0} + lock = threading.Lock() + + class _InjectedError(RuntimeError): + pass + + def failing_compress(*args, **kwargs): + # Let the first few calls succeed so the executor genuinely + # spins up worker threads, then raise from inside a worker. + with lock: + call_count['n'] += 1 + n = call_count['n'] + if n >= 3: + raise _InjectedError( + f"injected failure on _compress_block call #{n}") + return real_compress(*args, **kwargs) + + monkeypatch.setattr(writer_mod, '_compress_block', failing_compress) + + # Sized to produce many tiles per segment so the parallel branch + # fires (pool only used when ``n_seg_tiles > 1``). + shape = (8 * 256, 8 * 256) + da = _make_dataarray(shape) + dask_da = da.chunk({'y': 256 * 2, 'x': 256 * 2}) + + out_path = str(tmp_path / 'tmp_2276_compress_fail.tif') + + with pytest.raises(_InjectedError): + to_geotiff(dask_da, out_path, + compression='deflate', tile_size=256) + + # At least one pool should have been constructed by _write_streaming. 
+ assert len(captured_pools) >= 1, ( + "Expected _write_streaming to construct a ThreadPoolExecutor; " + "none were captured.") + + # Every captured pool must be shut down by the time we get here. + for idx, pool in enumerate(captured_pools): + assert pool._shutdown, ( + f"Captured pool #{idx} was NOT shut down after the " + f"mid-stream failure -- ThreadPoolExecutor leak.") + + # Give workers a moment to actually exit after shutdown(wait=True); + # _shutdown=True plus shutdown(wait=True) should already mean + # threads have joined, but defensive sleep avoids a race on slow + # CI runners. + deadline = time.monotonic() + 2.0 + leaked = [] + while time.monotonic() < deadline: + current = _list_writer_pool_worker_threads() + leaked = [t for t in current if t.name not in pre_existing_names] + if not leaked: + break + time.sleep(0.05) + + assert not leaked, ( + f"ThreadPoolExecutor worker threads still alive after failed " + f"streaming write: {[t.name for t in leaked]}") + + +def test_pool_shutdown_on_file_write_failure( + captured_pools, monkeypatch, tmp_path): + """A raise from the sequential file-write step (after the parallel + compress has already run for a segment) must still shut the pool + down. This covers the second class of mid-stream failure: the + pool's work finished cleanly but the consumer of those compressed + buffers failed before the loop reached the bottom of the function. + """ + monkeypatch.setattr(os, 'cpu_count', lambda: 4) + + pre_existing = {t.name for t in _list_writer_pool_worker_threads()} + + class _InjectedWriteError(IOError): + pass + + # Wrap ``os.fdopen`` so the file object's ``write`` raises after a + # configurable number of calls. The streaming writer opens the + # output file via ``os.fdopen(fd, 'wb')`` once and then calls + # ``f.write(...)`` for each header, IFD chunk, and compressed + # tile. 
Letting a generous number of early writes through gets us + # past the header/IFD and into the per-tile write loop where the + # pool is actively in use. + real_fdopen = os.fdopen + write_count = {'n': 0} + + def wrapping_fdopen(fd, mode='r', *args, **kwargs): + f = real_fdopen(fd, mode, *args, **kwargs) + real_write = f.write + + def counting_write(data): + write_count['n'] += 1 + if write_count['n'] >= 12: + raise _InjectedWriteError( + f"injected file-write failure on call " + f"#{write_count['n']}") + return real_write(data) + + f.write = counting_write + return f + + monkeypatch.setattr(os, 'fdopen', wrapping_fdopen) + + shape = (8 * 256, 8 * 256) + da = _make_dataarray(shape) + dask_da = da.chunk({'y': 512, 'x': 512}) + + out_path = str(tmp_path / 'tmp_2276_write_fail.tif') + + with pytest.raises(_InjectedWriteError): + to_geotiff(dask_da, out_path, + compression='deflate', tile_size=256) + + assert len(captured_pools) >= 1, ( + "Expected _write_streaming to construct a writer pool before " + "the file-write failure.") + for idx, pool in enumerate(captured_pools): + assert pool._shutdown, ( + f"Captured pool #{idx} not shut down after file-write " + f"failure -- ThreadPoolExecutor leak.") + + deadline = time.monotonic() + 2.0 + leaked = [] + while time.monotonic() < deadline: + leaked = [t for t in _list_writer_pool_worker_threads() + if t.name not in pre_existing] + if not leaked: + break + time.sleep(0.05) + assert not leaked, ( + f"Pool worker threads leaked after file-write failure: " + f"{[t.name for t in leaked]}") + + +def test_pool_shutdown_on_happy_path( + captured_pools, monkeypatch, tmp_path): + """Regression guard: the success path must still shut the pool + down -- the ``try/finally`` rewrite must not regress the original + behaviour.""" + monkeypatch.setattr(os, 'cpu_count', lambda: 4) + + shape = (4 * 256, 4 * 256) + da = _make_dataarray(shape) + dask_da = da.chunk({'y': 512, 'x': 512}) + + out_path = str(tmp_path / 'tmp_2276_happy.tif') + 
to_geotiff(dask_da, out_path, + compression='deflate', tile_size=256) + + assert len(captured_pools) >= 1 + for idx, pool in enumerate(captured_pools): + assert pool._shutdown, ( + f"Captured pool #{idx} not shut down on the success path.")