Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,12 @@ def open_geotiff(source: str, *, window=None,
# Adjust coordinates for windowed read
r0, c0, r1, c1 = window
t = geo_info.transform
full_x = np.arange(c0, c1, dtype=np.float64) * t.pixel_width + t.origin_x + t.pixel_width * 0.5
full_y = np.arange(r0, r1, dtype=np.float64) * t.pixel_height + t.origin_y + t.pixel_height * 0.5
if geo_info.raster_type == RASTER_PIXEL_IS_POINT:
full_x = np.arange(c0, c1, dtype=np.float64) * t.pixel_width + t.origin_x
full_y = np.arange(r0, r1, dtype=np.float64) * t.pixel_height + t.origin_y
else:
full_x = np.arange(c0, c1, dtype=np.float64) * t.pixel_width + t.origin_x + t.pixel_width * 0.5
full_y = np.arange(r0, r1, dtype=np.float64) * t.pixel_height + t.origin_y + t.pixel_height * 0.5
coords = {'y': full_y, 'x': full_x}

if name is None:
Expand Down Expand Up @@ -402,6 +406,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,

geo_transform = None
epsg = None
wkt_fallback = None # WKT string when EPSG is not available
raster_type = RASTER_PIXEL_IS_AREA
x_res = None
y_res = None
Expand All @@ -414,6 +419,8 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
epsg = crs
elif isinstance(crs, str):
epsg = _wkt_to_epsg(crs) # try to extract EPSG from WKT/PROJ
if epsg is None:
wkt_fallback = crs

if isinstance(data, xr.DataArray):
# Handle CuPy-backed DataArrays: convert to numpy for CPU write
Expand All @@ -436,12 +443,16 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
if isinstance(crs_attr, str):
# WKT string from reproject() or other source
epsg = _wkt_to_epsg(crs_attr)
if epsg is None and wkt_fallback is None:
wkt_fallback = crs_attr
elif crs_attr is not None:
epsg = int(crs_attr)
if epsg is None:
wkt = data.attrs.get('crs_wkt')
if isinstance(wkt, str):
epsg = _wkt_to_epsg(wkt)
if epsg is None and wkt_fallback is None:
wkt_fallback = wkt
if nodata is None:
nodata = data.attrs.get('nodata')
if data.attrs.get('raster_type') == 'point':
Expand Down Expand Up @@ -477,10 +488,19 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
elif arr.dtype == np.bool_:
arr = arr.astype(np.uint8)

# Restore NaN pixels to the nodata sentinel value so the written file
# has sentinel values matching the GDAL_NODATA tag.
if nodata is not None and arr.dtype.kind == 'f' and not np.isnan(nodata):
nan_mask = np.isnan(arr)
if nan_mask.any():
arr = arr.copy()
arr[nan_mask] = arr.dtype.type(nodata)

write(
arr, path,
geo_transform=geo_transform,
crs_epsg=epsg,
crs_wkt=wkt_fallback if epsg is None else None,
nodata=nodata,
compression=compression,
tiled=tiled,
Expand Down
44 changes: 43 additions & 1 deletion xrspatial/geotiff/_geotags.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,9 +522,18 @@ def extract_geo_info(ifd: IFD, data: bytes | memoryview,
)


def _model_type_from_wkt(wkt: str) -> int:
"""Guess ModelType from a WKT string prefix."""
upper = wkt.strip().upper()
if upper.startswith(('GEOGCS', 'GEOGCRS')):
return MODEL_TYPE_GEOGRAPHIC
return MODEL_TYPE_PROJECTED


def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None,
nodata=None,
raster_type: int = RASTER_PIXEL_IS_AREA) -> dict[int, tuple]:
raster_type: int = RASTER_PIXEL_IS_AREA,
crs_wkt: str | None = None) -> dict[int, tuple]:
"""Build GeoTIFF IFD tag entries for writing.

Parameters
Expand All @@ -537,6 +546,11 @@ def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None,
NoData value.
raster_type : int
RASTER_PIXEL_IS_AREA (1) or RASTER_PIXEL_IS_POINT (2).
crs_wkt : str or None
WKT or PROJ string for the CRS. Used only when *crs_epsg* is
None so that custom (non-EPSG) coordinate systems survive
round-trips. Stored in the GeoAsciiParamsTag and referenced
from GTCitationGeoKey.

Returns
-------
Expand All @@ -562,6 +576,10 @@ def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None,
num_keys = 1 # at least RasterType
key_entries = []

# Collect ASCII params strings (pipe-delimited in GeoAsciiParamsTag)
ascii_parts = []
ascii_offset = 0

# ModelType
if crs_epsg is not None:
# Guess model type from EPSG (simple heuristic)
Expand All @@ -571,6 +589,10 @@ def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None,
model_type = MODEL_TYPE_PROJECTED
key_entries.append((GEOKEY_MODEL_TYPE, 0, 1, model_type))
num_keys += 1
elif crs_wkt is not None:
model_type = _model_type_from_wkt(crs_wkt)
key_entries.append((GEOKEY_MODEL_TYPE, 0, 1, model_type))
num_keys += 1

# RasterType
key_entries.append((GEOKEY_RASTER_TYPE, 0, 1, raster_type))
Expand All @@ -582,6 +604,22 @@ def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None,
else:
key_entries.append((GEOKEY_PROJECTED_CS_TYPE, 0, 1, crs_epsg))
num_keys += 1
elif crs_wkt is not None:
# User-defined CRS: store 32767 and write WKT to GeoAsciiParams
if model_type == MODEL_TYPE_GEOGRAPHIC:
key_entries.append((GEOKEY_GEOGRAPHIC_TYPE, 0, 1, 32767))
else:
key_entries.append((GEOKEY_PROJECTED_CS_TYPE, 0, 1, 32767))
num_keys += 1
# GTCitationGeoKey -> GeoAsciiParams
wkt_with_pipe = crs_wkt + '|'
key_entries.append((
GEOKEY_CITATION, TAG_GEO_ASCII_PARAMS,
len(wkt_with_pipe), ascii_offset,
))
ascii_parts.append(wkt_with_pipe)
ascii_offset += len(wkt_with_pipe)
num_keys += 1

num_keys = len(key_entries)
header = [1, 1, 0, num_keys]
Expand All @@ -591,6 +629,10 @@ def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None,

tags[TAG_GEO_KEY_DIRECTORY] = tuple(flat)

# GeoAsciiParamsTag (34737)
if ascii_parts:
tags[TAG_GEO_ASCII_PARAMS] = ''.join(ascii_parts)

# GDAL_NODATA
if nodata is not None:
tags[TAG_GDAL_NODATA] = str(nodata)
Expand Down
13 changes: 10 additions & 3 deletions xrspatial/geotiff/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ._geotags import (
GeoTransform,
build_geo_tags,
TAG_GEO_ASCII_PARAMS,
TAG_GEO_KEY_DIRECTORY,
TAG_GDAL_NODATA,
TAG_MODEL_PIXEL_SCALE,
Expand Down Expand Up @@ -525,6 +526,7 @@ def _assemble_tiff(width: int, height: int, dtype: np.dtype,
nodata,
is_cog: bool = False,
raster_type: int = 1,
crs_wkt: str | None = None,
gdal_metadata_xml: str | None = None,
extra_tags: list | None = None,
x_resolution: float | None = None,
Expand Down Expand Up @@ -557,12 +559,14 @@ def _assemble_tiff(width: int, height: int, dtype: np.dtype,
geo_tags_dict = {}
if geo_transform is not None:
geo_tags_dict = build_geo_tags(
geo_transform, crs_epsg, nodata, raster_type=raster_type)
geo_transform, crs_epsg, nodata, raster_type=raster_type,
crs_wkt=crs_wkt)
else:
# No spatial reference -- still write CRS and nodata if provided
if crs_epsg is not None or nodata is not None:
if crs_epsg is not None or crs_wkt is not None or nodata is not None:
geo_tags_dict = build_geo_tags(
GeoTransform(), crs_epsg, nodata, raster_type=raster_type,
crs_wkt=crs_wkt,
)
# Remove the default pixel scale / tiepoint tags since we
# have no real transform -- keep only GeoKeys and NODATA.
Expand Down Expand Up @@ -641,6 +645,8 @@ def _assemble_tiff(width: int, height: int, dtype: np.dtype,
tags.append((gtag, DOUBLE, 6, list(gval)))
elif gtag == TAG_GEO_KEY_DIRECTORY:
tags.append((gtag, SHORT, len(gval), list(gval)))
elif gtag == TAG_GEO_ASCII_PARAMS:
tags.append((gtag, ASCII, len(str(gval)) + 1, str(gval)))
elif gtag == TAG_GDAL_NODATA:
tags.append((gtag, ASCII, len(str(gval)) + 1, str(gval)))

Expand Down Expand Up @@ -846,6 +852,7 @@ def _assemble_cog_layout(header_size: int,
def write(data: np.ndarray, path: str, *,
geo_transform: GeoTransform | None = None,
crs_epsg: int | None = None,
crs_wkt: str | None = None,
nodata=None,
compression: str = 'zstd',
tiled: bool = True,
Expand Down Expand Up @@ -939,7 +946,7 @@ def write(data: np.ndarray, path: str, *,
file_bytes = _assemble_tiff(
w, h, data.dtype, comp_tag, predictor, tiled, tile_size,
parts, geo_transform, crs_epsg, nodata, is_cog=cog,
raster_type=raster_type,
raster_type=raster_type, crs_wkt=crs_wkt,
gdal_metadata_xml=gdal_metadata_xml,
extra_tags=extra_tags,
x_resolution=x_resolution, y_resolution=y_resolution,
Expand Down
Loading
Loading