From 281fa0cfcb6aec0fa5484d3411934b518a56c01f Mon Sep 17 00:00:00 2001
From: Brendan Collins
Date: Fri, 15 May 2026 06:29:19 -0700
Subject: [PATCH] geotiff: route write_vrt dtype through central resolver (#1914)

write_vrt built the VRT dataType attribute from a local if/elif/else
ladder keyed on sample_format and bps. The ladder had no entry for
bps=12, so unsigned 12-bit TIFF sources were tagged as Byte in the VRT
header. The reader path promotes those samples to uint16, so downstream
GDAL/VRT consumers could truncate values above 255.

Replace the ladder with _vrt_dtype_name_for, which calls
tiff_dtype_to_numpy and then looks up the GDAL name in the existing
_NP_TO_VRT_DTYPE table. A numpy dtype with no GDAL mapping now raises
ValueError instead of falling back to Byte. The helper is exposed at
module scope so the dtype path can be unit-tested without writing a
12-bit TIFF to disk.
---
 xrspatial/geotiff/_vrt.py                     |  57 ++++--
 .../tests/test_vrt_dtype_12bit_1914.py        | 177 ++++++++++++++++++
 2 files changed, 221 insertions(+), 13 deletions(-)
 create mode 100644 xrspatial/geotiff/tests/test_vrt_dtype_12bit_1914.py

diff --git a/xrspatial/geotiff/_vrt.py b/xrspatial/geotiff/_vrt.py
index a070016a1..845783983 100644
--- a/xrspatial/geotiff/_vrt.py
+++ b/xrspatial/geotiff/_vrt.py
@@ -1337,6 +1337,42 @@ def read_vrt(vrt_path: str, *, window=None,
 _NP_TO_VRT_DTYPE = {v: k for k, v in _DTYPE_MAP.items()}
 
 
+def _vrt_dtype_name_for(bps, sample_format):
+    """Map TIFF ``BitsPerSample`` + ``SampleFormat`` to a GDAL VRT dtype.
+
+    Routes through ``_dtypes.tiff_dtype_to_numpy`` so the VRT header
+    matches what the reader actually decodes (including ``bps=12, sf=1``
+    -> ``uint16``). Raises ``ValueError`` when the resolved numpy dtype
+    has no GDAL VRT name rather than silently falling back to ``Byte``.
+
+    Parameters
+    ----------
+    bps : int
+        Raw ``BitsPerSample`` (already resolved to a scalar; pass through
+        ``resolve_bits_per_sample`` first if you have a sequence).
+    sample_format : int or sequence of int
+        Raw ``SampleFormat`` tag value. Sequences are normalised through
+        ``resolve_sample_format``.
+
+    Returns
+    -------
+    str
+        A GDAL ``dataType`` name (``Byte``, ``UInt16``, ``Float32`` ...).
+    """
+    from ._dtypes import resolve_sample_format, tiff_dtype_to_numpy
+
+    sf = resolve_sample_format(sample_format)
+    np_dtype = tiff_dtype_to_numpy(bps, sf)
+    try:
+        return _NP_TO_VRT_DTYPE[np_dtype.type]
+    except KeyError:
+        raise ValueError(
+            f"Cannot map numpy dtype {np_dtype} (from bps={bps}, "
+            f"sample_format={sf}) to a GDAL VRT dataType. Supported "
+            f"VRT dtypes are: {sorted(_NP_TO_VRT_DTYPE.values())}."
+        ) from None  # hide the internal table-lookup KeyError
+
+
 def write_vrt(vrt_path: str, source_files: list[str], *,
               relative: bool = True,
               crs_wkt: str | None = None,
@@ -1504,19 +1540,14 @@ def _pixel_size_mismatch(a: float, b: float) -> bool:
     total_w = int(round((mosaic_x1 - mosaic_x0) / abs(res_x)))
     total_h = int(round((mosaic_y_top - mosaic_y_bottom) / abs(res_y)))
 
-    # Determine VRT dtype
-    sf = first['sample_format']
-    bps = first['bps']
-    if sf == 3:
-        vrt_dtype_name = 'Float64' if bps == 64 else 'Float32'
-    elif sf == 2:
-        vrt_dtype_name = {
-            8: 'Int8', 16: 'Int16', 32: 'Int32', 64: 'Int64',
-        }.get(bps, 'Int32')
-    else:
-        vrt_dtype_name = {
-            8: 'Byte', 16: 'UInt16', 32: 'UInt32', 64: 'UInt64',
-        }.get(bps, 'Byte')
+    # Determine VRT dtype via the central TIFF-to-numpy resolver so the
+    # VRT header agrees with what the reader will actually decode. The
+    # previous local if/elif/else ladder had no entry for sub-byte or
+    # 12-bit unsigned samples (reader promotes ``bps=12, sf=1`` to
+    # ``uint16``), so a VRT over a valid 12-bit source got tagged
+    # ``Byte`` and could be truncated by downstream GDAL readers. Issue
+    # #1914.
+    vrt_dtype_name = _vrt_dtype_name_for(first['bps'], first['sample_format'])
 
     srs = crs_wkt or first.get('crs_wkt') or ''
     nd = nodata if nodata is not None else first.get('nodata')
diff --git a/xrspatial/geotiff/tests/test_vrt_dtype_12bit_1914.py b/xrspatial/geotiff/tests/test_vrt_dtype_12bit_1914.py
new file mode 100644
index 000000000..f796e0407
--- /dev/null
+++ b/xrspatial/geotiff/tests/test_vrt_dtype_12bit_1914.py
@@ -0,0 +1,177 @@
+"""Regression tests for issue #1914.
+
+``write_vrt`` used to build its GDAL ``dataType`` attribute from a local
+if/elif/else ladder keyed on ``sample_format`` and ``bps`` rather than
+going through the central ``tiff_dtype_to_numpy`` resolver. For
+``bps=12, sample_format=1`` that ladder fell through to ``'Byte'``,
+even though the reader promotes the same sample to ``uint16``. A VRT
+over a valid 12-bit unsigned source would then advertise a narrower
+type and could be truncated by downstream GDAL/VRT consumers.
+
+These tests pin the VRT dtype name for every TIFF (bps, sample_format)
+the resolver supports, including the 12-bit regression case. The
+helper ``_vrt_dtype_name_for`` is called directly because the on-disk
+writers in ``to_geotiff`` only emit standard 8/16/32/64-bit samples,
+so a real 12-bit TIFF can't be round-tripped through ``write_vrt`` in
+the test environment.
+
+A second block exercises ``write_vrt`` end-to-end with normal 16-bit
+unsigned tiles and asserts the generated XML carries ``UInt16`` (not
+``Byte``) so the actual writer path is covered too.
+"""
+from __future__ import annotations
+
+import os
+import uuid
+
+import numpy as np
+import pytest
+import xarray as xr
+
+from xrspatial.geotiff import to_geotiff
+from xrspatial.geotiff._vrt import (
+    _NP_TO_VRT_DTYPE,
+    _vrt_dtype_name_for,
+    write_vrt,
+)
+
+
+# ---------------------------------------------------------------------------
+# Direct helper tests: every supported (bps, sample_format) pair
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize(
+    "bps,sf,expected",
+    [
+        # The regression case. ``_dtypes.tiff_dtype_to_numpy`` promotes
+        # ``bps=12, sf=1`` to ``uint16``; the VRT dtype must follow.
+        (12, 1, "UInt16"),
+        (12, 4, "UInt16"),  # SAMPLE_FORMAT_UNDEFINED -> uint
+        # Sub-byte unsigned bit depths the reader promotes to uint8.
+        (1, 1, "Byte"),
+        (2, 1, "Byte"),
+        (4, 1, "Byte"),
+        # Standard unsigned bit depths.
+        (8, 1, "Byte"),
+        (16, 1, "UInt16"),
+        (32, 1, "UInt32"),
+        (64, 1, "UInt64"),
+        # Signed integers.
+        (8, 2, "Int8"),
+        (16, 2, "Int16"),
+        (32, 2, "Int32"),
+        (64, 2, "Int64"),
+        # Floats. The previous ladder defaulted ``sf=3`` to ``Float32``
+        # for any non-64 bps, which silently downcast ``bps=16`` halfs.
+        # The resolver rejects bps=16 floats outright (no IEEE half
+        # support), so we only pin the supported widths here.
+        (32, 3, "Float32"),
+        (64, 3, "Float64"),
+    ],
+)
+def test_vrt_dtype_name_for_supported(bps, sf, expected):
+    assert _vrt_dtype_name_for(bps, sf) == expected
+
+
+def test_vrt_dtype_name_for_sample_format_sequence_resolves():
+    # ``ifd.sample_format`` is sometimes a tuple of per-band values; the
+    # helper must funnel it through ``resolve_sample_format`` rather
+    # than treating the raw tuple as an int.
+    assert _vrt_dtype_name_for(8, [1, 1]) == "Byte"
+    assert _vrt_dtype_name_for(16, (2, 2)) == "Int16"
+
+
+def test_vrt_dtype_name_for_unsupported_raises():
+    # bps=24 / sf=2 isn't in ``tiff_dtype_to_numpy``; should surface as
+    # a ValueError from the resolver rather than silently mapping to
+    # 'Byte' or 'Int32' the way the old ladder did.
+    with pytest.raises(ValueError):
+        _vrt_dtype_name_for(24, 2)
+
+
+def test_np_to_vrt_dtype_table_covers_all_resolver_outputs():
+    # Defensive: every numpy dtype that ``tiff_dtype_to_numpy`` can
+    # produce must have an entry in ``_NP_TO_VRT_DTYPE`` or else
+    # ``_vrt_dtype_name_for`` will raise the catch-all ValueError in
+    # normal use. This catches future additions to the resolver that
+    # forget to wire up a GDAL name.
+    from xrspatial.geotiff._dtypes import tiff_dtype_to_numpy
+
+    pairs = [
+        (8, 1), (8, 2),
+        (16, 1), (16, 2),
+        (32, 1), (32, 2), (32, 3),
+        (64, 1), (64, 2), (64, 3),
+        (1, 1), (2, 1), (4, 1), (12, 1),
+    ]
+    for bps, sf in pairs:
+        np_dtype = tiff_dtype_to_numpy(bps, sf)
+        assert np_dtype.type in _NP_TO_VRT_DTYPE, (
+            f"resolver yields {np_dtype} for bps={bps}, sf={sf} but "
+            f"_NP_TO_VRT_DTYPE has no entry for it"
+        )
+
+
+# ---------------------------------------------------------------------------
+# End-to-end write_vrt sanity: uint16 source produces UInt16 in the VRT
+# ---------------------------------------------------------------------------
+
+
+def _unique_dir(tmp_path, label: str) -> str:
+    d = tmp_path / f"vrt_1914_{label}_{uuid.uuid4().hex[:8]}"
+    d.mkdir()
+    return str(d)
+
+
+def _write_uint16_tif(path: str, *, h: int = 4, w: int = 4,
+                      origin_x: float = 0.0) -> None:
+    arr = np.arange(h * w, dtype=np.uint16).reshape(h, w)
+    y = 100.0 + (np.arange(h) + 0.5) * -1.0
+    x = origin_x + (np.arange(w) + 0.5) * 1.0
+    da = xr.DataArray(
+        arr, dims=['y', 'x'],
+        coords={'y': y, 'x': x},
+        attrs={'crs': 4326},
+    )
+    to_geotiff(da, path, compression='none')
+
+
+def test_uint16_source_writes_uint16_vrt_datatype(tmp_path):
+    # Pre-fix this would have produced ``dataType="UInt16"`` too, since
+    # bps=16/sf=1 happened to be in the old ladder. The point of this
+    # test is to assert the XML still says UInt16 after refactoring the
+    # dtype path through the central resolver -- i.e. we didn't
+    # accidentally regress the easy case while fixing the 12-bit one.
+    d = _unique_dir(tmp_path, "u16")
+    a = os.path.join(d, "a.tif")
+    b = os.path.join(d, "b.tif")
+    _write_uint16_tif(a)
+    _write_uint16_tif(b, origin_x=4.0)
+    vrt = os.path.join(d, "out.vrt")
+    write_vrt(vrt, [a, b])
+    with open(vrt) as f:
+        xml = f.read()
+    assert 'dataType="UInt16"' in xml
+    assert 'dataType="Byte"' not in xml
+
+
+def test_int16_source_writes_int16_vrt_datatype(tmp_path):
+    # The old ladder mapped sf=2, bps=16 to 'Int16' correctly; pin that
+    # behaviour so the new resolver path doesn't drift.
+    d = _unique_dir(tmp_path, "i16")
+    a = os.path.join(d, "a.tif")
+    arr = np.arange(16, dtype=np.int16).reshape(4, 4)
+    y = 100.0 + (np.arange(4) + 0.5) * -1.0
+    x = (np.arange(4) + 0.5) * 1.0
+    da = xr.DataArray(
+        arr, dims=['y', 'x'],
+        coords={'y': y, 'x': x},
+        attrs={'crs': 4326},
+    )
+    to_geotiff(da, a, compression='none')
+    vrt = os.path.join(d, "out.vrt")
+    write_vrt(vrt, [a])
+    with open(vrt) as f:
+        xml = f.read()
+    assert 'dataType="Int16"' in xml