Skip to content

Commit bcbccba

Browse files
authored
geotiff: add read-finalization helpers (PR B of #2162) (#2200)
* geotiff: add read-finalization helpers _finalize_eager_read / _finalize_lazy_read_attrs (#2177) Wave 1 of #2162. Add two private helpers in xrspatial/geotiff/_attrs.py that capture the read-finalization pipelines duplicated across backends. The helpers are dead code until waves 2 (#2178 dask, #2179 eager) and 3 (#2180 VRT, GPU) consume them. _finalize_eager_read: validates geo_info, populates attrs, applies the sentinel mask, casts dtype, sets nodata attrs (with pixels_present as a bool), returns an xarray.DataArray. mask_sentinel is a parameter because the three GPU eager sites derive it three different ways (MinIsWhite inversion, CPU fallback, raw nodata). _finalize_lazy_read_attrs: validates geo_info, populates attrs, sets nodata attrs with pixels_present=None per the documented dask contract from #2135 (a strict per-chunk reduction would force eager .compute()). Returns the attrs dict only; the caller assembles the dask graph and builds the DataArray itself. _validate_read_geo_info runs first in both helpers so partial attrs do not leak on validation failure. No public API change. No call sites migrated. Helper signatures are frozen so wave 2 and 3 can depend on them. * Address review feedback (#2177) - Move ``import xarray as xr`` to module scope. xarray is already a hard dependency, so the per-call local import was unnecessary indirection. - Document the shallow-copy semantics of ``attrs_in`` on both helpers so wave 2 / wave 3 migrators know nested values are shared with the caller's seed dict. - Pin the int-dtype ``nodata_dtype_cast`` value in the lazy helper test so wave 2 catches any drift from the conflated-dtype semantics. - Strengthen the seed-dict-untouched assertions with a ``len(seed) == 1`` check so a future partial-leak that adds new keys is caught even if the original key still matches.
1 parent aa22f26 commit bcbccba

2 files changed

Lines changed: 900 additions & 0 deletions

File tree

xrspatial/geotiff/_attrs.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,10 @@
160160

161161
import numpy as np
162162

163+
import xarray as xr
164+
163165
from ._coords import (
166+
coords_from_geo_info as _coords_from_geo_info,
164167
transform_tuple_from_pixel_geometry as _transform_tuple_from_pixel_geometry,
165168
)
166169
from ._geotags import (
@@ -1128,3 +1131,306 @@ def _extract_rich_tags(attrs: dict) -> dict:
11281131
'y_resolution': attrs.get('y_resolution'),
11291132
'resolution_unit': res_unit,
11301133
}
1134+
1135+
1136+
def _apply_eager_nodata_mask(arr, *, mask_sentinel, mask_nodata):
1137+
"""Apply the nodata-to-NaN mask on an eager (host-side) numpy buffer.
1138+
1139+
Mirrors the inline block in ``open_geotiff`` so the eager helper can
1140+
share one implementation. Returns ``(arr, nodata_pixels_present)``
1141+
where ``arr`` may have been promoted from an integer dtype to float64
1142+
when the sentinel matched at least one pixel, and
1143+
``nodata_pixels_present`` is the bool used to populate
1144+
``attrs['nodata_pixels_present']``. ``None`` means "no scan was
1145+
appropriate for this dtype / sentinel combination."
1146+
1147+
The sentinel is taken as the ``mask_sentinel`` parameter rather than
1148+
being read from ``geo_info``. Three GPU eager sites derive it three
1149+
different ways (``_mw_mask_nodata`` local, the CPU-fallback
1150+
``_cpu_fallback_geo._mask_nodata``, raw ``nodata``), so the helper
1151+
accepts the sentinel value directly.
1152+
"""
1153+
nodata_pixels_present: bool | None = None
1154+
if mask_sentinel is None:
1155+
return arr, nodata_pixels_present
1156+
if mask_nodata:
1157+
if arr.dtype.kind == 'f':
1158+
if not np.isnan(mask_sentinel):
1159+
mask_f = arr == arr.dtype.type(mask_sentinel)
1160+
nodata_pixels_present = bool(mask_f.any())
1161+
if nodata_pixels_present:
1162+
arr[mask_f] = np.nan
1163+
else:
1164+
# NaN-only sentinel on a float buffer: ``mask_nodata`` is
1165+
# a no-op, but downstream may want to know if any NaN
1166+
# pixels already exist in the source so the attr stays
1167+
# informative.
1168+
nodata_pixels_present = bool(np.isnan(arr).any())
1169+
elif arr.dtype.kind in ('u', 'i'):
1170+
# Integer arrays: convert to float to represent NaN. Gate on
1171+
# finite + integer + in-range so a sentinel that cannot match
1172+
# an integer pixel resolves to ``False`` rather than crashing
1173+
# in the equality cast (mirrors the eager block in
1174+
# ``open_geotiff`` for #1774 / #1564 / #1616).
1175+
if (np.isfinite(mask_sentinel)
1176+
and float(mask_sentinel).is_integer()):
1177+
nodata_int = int(mask_sentinel)
1178+
info = np.iinfo(arr.dtype)
1179+
if info.min <= nodata_int <= info.max:
1180+
mask = arr == arr.dtype.type(nodata_int)
1181+
nodata_pixels_present = bool(mask.any())
1182+
if nodata_pixels_present:
1183+
arr = arr.astype(np.float64)
1184+
arr[mask] = np.nan
1185+
else:
1186+
nodata_pixels_present = False
1187+
else:
1188+
nodata_pixels_present = False
1189+
else:
1190+
# ``mask_nodata=False``: do not rewrite pixels, but still surface
1191+
# ``attrs['nodata_pixels_present']`` so callers know whether
1192+
# literal sentinel pixels survive in the buffer (issue #2135).
1193+
if arr.dtype.kind == 'f':
1194+
if np.isnan(mask_sentinel):
1195+
nodata_pixels_present = bool(np.isnan(arr).any())
1196+
else:
1197+
nodata_pixels_present = bool(
1198+
(arr == arr.dtype.type(mask_sentinel)).any()
1199+
)
1200+
elif arr.dtype.kind in ('u', 'i'):
1201+
if (np.isfinite(mask_sentinel)
1202+
and float(mask_sentinel).is_integer()):
1203+
nodata_int = int(mask_sentinel)
1204+
info = np.iinfo(arr.dtype)
1205+
if info.min <= nodata_int <= info.max:
1206+
nodata_pixels_present = bool(
1207+
(arr == arr.dtype.type(nodata_int)).any()
1208+
)
1209+
else:
1210+
nodata_pixels_present = False
1211+
else:
1212+
nodata_pixels_present = False
1213+
return arr, nodata_pixels_present
1214+
1215+
1216+
def _finalize_eager_read(
1217+
arr,
1218+
*,
1219+
geo_info,
1220+
nodata,
1221+
mask_sentinel,
1222+
mask_nodata,
1223+
dtype,
1224+
window,
1225+
name,
1226+
allow_rotated: bool = False,
1227+
allow_unparseable_crs: bool = False,
1228+
attrs_in: dict | None = None,
1229+
):
1230+
"""Validate, populate attrs, mask, cast, and build an eager DataArray.
1231+
1232+
Wave 1 of #2162 -- ties together the four steps every eager read path
1233+
runs after the bytes land in a host (or cupy) buffer:
1234+
1235+
1. :func:`_validate_read_geo_info` -- runs first so a rejected file
1236+
does not leak a partially-populated attrs dict.
1237+
2. :func:`_populate_attrs_from_geo_info` -- writes the canonical attrs
1238+
(transform / crs / georef_status / etc.) onto a fresh dict.
1239+
3. Mask nodata pixels to NaN using ``mask_sentinel`` when
1240+
``mask_nodata=True`` and the source declared one. Records the
1241+
``nodata_pixels_present`` bool either way.
1242+
4. Cast to ``dtype`` when explicit; record ``nodata_dtype_cast``.
1243+
5. :func:`_set_nodata_attrs` -- stamps the nodata lifecycle attrs.
1244+
6. Build an :class:`xarray.DataArray` with coords from
1245+
:func:`_coords_from_geo_info`.
1246+
1247+
The ``mask_sentinel`` parameter is intentionally separate from
1248+
``geo_info.nodata`` because the three GPU eager sites derive it three
1249+
different ways (``_mw_mask_nodata`` local on the stripped path, the
1250+
CPU-fallback ``_cpu_fallback_geo._mask_nodata`` on the tiled path,
1251+
raw ``nodata`` on the CPU-decode-then-upload path for URL / fsspec
1252+
sources). Read paths that don't need MinIsWhite inversion can pass
1253+
``mask_sentinel=nodata``.
1254+
1255+
Wave migration plan:
1256+
1257+
* Wave 2 (#2178 dask, #2179 eager numpy) migrates the eager numpy
1258+
paths. The mask block inside this helper matches the inline block
1259+
in ``open_geotiff`` field-for-field; the migration is a straight
1260+
swap.
1261+
* Wave 3 (#2180 VRT, GPU) migrates the VRT eager + three GPU eager
1262+
sites. The VRT eager path is host-side and works with the helper
1263+
as-is. The GPU sites apply masking through a CUDA kernel
1264+
(``_apply_nodata_mask_gpu_with_presence``); they can either
1265+
pre-mask and call the helper with ``nodata=None`` to skip the
1266+
helper's host-side mask block, or wave 3 can extend this
1267+
helper's signature with a ``mask_fn`` injection. Either choice
1268+
lives in #2180; the wave 1 contract here is the host-side path.
1269+
1270+
Returns a :class:`xarray.DataArray` ready for the caller to return
1271+
from the read function. The caller assembles the dask graph
1272+
separately when a lazy backend is in play; this helper is eager-only.
1273+
1274+
``attrs_in`` is shallow-copied via ``dict(attrs_in)``. Nested values
1275+
(e.g. ``extra_tags`` list, ``gdal_metadata`` dict) are shared between
1276+
the caller's seed dict and the returned DataArray's attrs; mutating
1277+
a nested value after the call propagates both ways. Callers that
1278+
care about isolation can ``copy.deepcopy(attrs_in)`` first.
1279+
"""
1280+
# Step 1: validate first so partial attrs never leak.
1281+
_validate_read_geo_info(
1282+
geo_info, window=window,
1283+
allow_rotated=allow_rotated,
1284+
allow_unparseable_crs=allow_unparseable_crs,
1285+
)
1286+
1287+
# Step 2: populate attrs from geo_info onto a fresh dict (or onto a
1288+
# caller-supplied seed dict, which lets the GPU/VRT migration carry
1289+
# backend-specific keys through without bypassing the helper).
1290+
attrs: dict = dict(attrs_in) if attrs_in else {}
1291+
_populate_attrs_from_geo_info(attrs, geo_info, window=window)
1292+
1293+
# Step 3: apply the nodata-to-NaN mask (or compute pixels_present
1294+
# without rewriting if ``mask_nodata=False``). Skipped entirely when
1295+
# the source declared no sentinel.
1296+
nodata_pixels_present: bool | None = None
1297+
if nodata is not None:
1298+
arr, nodata_pixels_present = _apply_eager_nodata_mask(
1299+
arr, mask_sentinel=mask_sentinel, mask_nodata=mask_nodata,
1300+
)
1301+
1302+
# Step 4: caller-requested dtype cast (post-mask so the integer
1303+
# promotion above runs first). ``_validate_dtype_cast`` lives in
1304+
# ``_validation``; local import keeps ``_attrs`` free of a top-level
1305+
# validation dependency for parity with ``_validate_read_geo_info``.
1306+
dtype_cast_attr: str | None = None
1307+
if dtype is not None:
1308+
from ._validation import _validate_dtype_cast
1309+
target = np.dtype(dtype)
1310+
_validate_dtype_cast(np.dtype(str(arr.dtype)), target)
1311+
arr = arr.astype(target)
1312+
dtype_cast_attr = target.name
1313+
1314+
# Step 5: stamp the nodata lifecycle attrs. ``masked`` is True iff
1315+
# the caller opted into masking AND the final buffer dtype is float,
1316+
# mirroring the existing call sites (the integer promotion above
1317+
# only runs when the sentinel matched at least one pixel, so an
1318+
# ``int`` buffer + ``mask_nodata=True`` here means "no pixels were
1319+
# masked" rather than "masking was disabled").
1320+
_set_nodata_attrs(
1321+
attrs, nodata,
1322+
masked=(mask_nodata and np.dtype(str(arr.dtype)).kind == 'f'),
1323+
pixels_present=nodata_pixels_present,
1324+
dtype_cast=dtype_cast_attr,
1325+
)
1326+
1327+
# Step 6: build the DataArray. ``_coords_from_geo_info`` honours the
1328+
# windowed-read contract (origin shifted to the window's top-left).
1329+
height, width = arr.shape[:2]
1330+
coords = _coords_from_geo_info(
1331+
geo_info, height, width, window=window,
1332+
)
1333+
if arr.ndim == 3:
1334+
dims = ['y', 'x', 'band']
1335+
coords['band'] = np.arange(arr.shape[2])
1336+
else:
1337+
dims = ['y', 'x']
1338+
1339+
return xr.DataArray(arr, dims=dims, coords=coords, name=name, attrs=attrs)
1340+
1341+
1342+
def _finalize_lazy_read_attrs(
1343+
*,
1344+
geo_info,
1345+
nodata,
1346+
mask_nodata,
1347+
dtype,
1348+
window,
1349+
allow_rotated: bool = False,
1350+
allow_unparseable_crs: bool = False,
1351+
attrs_in: dict | None = None,
1352+
):
1353+
"""Validate and populate attrs for dask-style lazy reads.
1354+
1355+
Wave 1 of #2162 -- the lazy counterpart of
1356+
:func:`_finalize_eager_read`. The dask + dask-GPU backends cannot
1357+
fold the nodata mask into a single eager step because masking runs
1358+
per-chunk inside the graph; they only need the attrs side of the
1359+
pipeline. This helper:
1360+
1361+
1. :func:`_validate_read_geo_info` -- runs first so partial attrs
1362+
never leak on validation failure.
1363+
2. :func:`_populate_attrs_from_geo_info` -- writes the canonical
1364+
attrs onto a fresh dict.
1365+
3. :func:`_set_nodata_attrs` -- ``masked`` is True iff the caller
1366+
opted into masking AND the graph dtype is float. ``dtype_cast``
1367+
is recorded when the caller passed an explicit ``dtype=`` kwarg.
1368+
``pixels_present=None`` is the documented dask contract from
1369+
issue #2135: a strict per-chunk reduction would force an eager
1370+
``.compute()`` and break the lazy contract, so the attr is left
1371+
absent on lazy outputs.
1372+
1373+
Returns the attrs ``dict`` only; the caller assembles the dask graph
1374+
and builds the :class:`xarray.DataArray` itself, so this helper
1375+
deliberately does not touch arrays or coords.
1376+
1377+
The ``dtype`` parameter accepts a numpy dtype, a string ('float64'),
1378+
or ``None``. It is the **resolved graph dtype** the dask backend
1379+
settled on (e.g. ``target_dtype`` after the int->float64 promotion
1380+
that ``mask_nodata=True`` triggers on int files): the helper uses
1381+
it to derive ``masked`` and writes it as ``nodata_dtype_cast`` when
1382+
non-None.
1383+
1384+
Wave 2 migration note: the current pre-helper dask backend
1385+
distinguishes "caller explicitly passed ``dtype=``" from
1386+
"graph dtype was auto-promoted by masking" so that
1387+
``nodata_dtype_cast`` surfaces only on the explicit-cast case.
1388+
This helper conflates the two -- whatever ``dtype`` value the
1389+
caller passes here becomes the ``nodata_dtype_cast`` attr. The
1390+
migration PR (#2178) can either accept that change, or split the
1391+
helper's ``dtype`` into two parameters at that point. Frozen
1392+
signature here per #2177 means we ship the one-``dtype`` shape
1393+
and leave the split for wave 2 if it turns out to matter.
1394+
1395+
``attrs_in`` is shallow-copied via ``dict(attrs_in)``. Nested values
1396+
are shared between the caller's seed dict and the returned attrs;
1397+
callers that care about isolation can ``copy.deepcopy(attrs_in)``
1398+
first.
1399+
"""
1400+
_validate_read_geo_info(
1401+
geo_info, window=window,
1402+
allow_rotated=allow_rotated,
1403+
allow_unparseable_crs=allow_unparseable_crs,
1404+
)
1405+
1406+
attrs: dict = dict(attrs_in) if attrs_in else {}
1407+
_populate_attrs_from_geo_info(attrs, geo_info, window=window)
1408+
1409+
# ``masked`` mirrors the eager helper's rule and the existing dask
1410+
# call site contract: the graph applies masking per-chunk only when
1411+
# ``mask_nodata=True`` AND the graph dtype is float, so an int graph
1412+
# with ``mask_nodata=True`` still carries literal sentinel values.
1413+
# ``dtype`` here is the resolved graph dtype; the dask backend
1414+
# promotes int -> float64 before calling this helper when the
1415+
# caller wants masking on an int source.
1416+
if dtype is None:
1417+
masked = False
1418+
else:
1419+
masked = bool(mask_nodata and np.dtype(dtype).kind == 'f')
1420+
1421+
# ``dtype_cast`` records the caller-supplied ``dtype=`` kwarg so
1422+
# consumers can tell float-because-masked from float-because-cast.
1423+
# The dask backend resolves ``dtype`` for the graph internally; the
1424+
# helper exposes it via ``attrs['nodata_dtype_cast']`` when set.
1425+
dtype_cast_attr = (
1426+
np.dtype(dtype).name if dtype is not None else None
1427+
)
1428+
1429+
_set_nodata_attrs(
1430+
attrs, nodata,
1431+
masked=masked,
1432+
pixels_present=None,
1433+
dtype_cast=dtype_cast_attr,
1434+
)
1435+
1436+
return attrs

0 commit comments

Comments
 (0)