Skip to content

Commit 8117303

Browse files
authored
Lightweight CRS parser to remove hard pyproj dependency (#1072)
* Add design spec for fused_overlap and multi_overlap utilities
* Add implementation plan for fused_overlap and multi_overlap
* Add lightweight CRS class with embedded EPSG table (#1057)
* Fix ValueError message to include pyproj install hint (#1057)
* Two-tier CRS resolution: lite CRS first, pyproj fallback (#1057)
* Add scatter-point transform_points() for boundary estimation (#1057)
* Use Numba scatter-point transform for grid boundary estimation (#1057)
* Wire chunk functions and merge helper to use lite CRS (#1057)
* Add integration tests for pyproj-free CRS resolution (#1057)
1 parent 142fa1c commit 8117303

File tree

7 files changed

+1165
-70
lines changed

7 files changed

+1165
-70
lines changed

xrspatial/reproject/__init__.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,10 @@ def _reproject_chunk_numpy(
192192
Called inside ``dask.delayed`` for the dask path, or directly for numpy.
193193
CRS objects are passed as WKT strings for pickle safety.
194194
"""
195-
from ._crs_utils import _require_pyproj
195+
from ._crs_utils import _crs_from_wkt
196196

197-
pyproj = _require_pyproj()
198-
src_crs = pyproj.CRS.from_wkt(src_wkt)
199-
tgt_crs = pyproj.CRS.from_wkt(tgt_wkt)
197+
src_crs = _crs_from_wkt(src_wkt)
198+
tgt_crs = _crs_from_wkt(tgt_wkt)
200199

201200
# Try Numba fast path first (avoids creating pyproj Transformer)
202201
numba_result = None
@@ -212,6 +211,8 @@ def _reproject_chunk_numpy(
212211
src_y, src_x = numba_result
213212
else:
214213
# Fallback: create pyproj Transformer (expensive)
214+
from ._crs_utils import _require_pyproj
215+
pyproj = _require_pyproj()
215216
transformer = pyproj.Transformer.from_crs(
216217
tgt_crs, src_crs, always_xy=True
217218
)
@@ -321,15 +322,10 @@ def _reproject_chunk_cupy(
321322
"""CuPy variant of ``_reproject_chunk_numpy``."""
322323
import cupy as cp
323324

324-
from ._crs_utils import _require_pyproj
325+
from ._crs_utils import _crs_from_wkt
325326

326-
pyproj = _require_pyproj()
327-
src_crs = pyproj.CRS.from_wkt(src_wkt)
328-
tgt_crs = pyproj.CRS.from_wkt(tgt_wkt)
329-
330-
transformer = pyproj.Transformer.from_crs(
331-
tgt_crs, src_crs, always_xy=True
332-
)
327+
src_crs = _crs_from_wkt(src_wkt)
328+
tgt_crs = _crs_from_wkt(tgt_wkt)
333329

334330
# Try CUDA transform first (keeps coordinates on-device)
335331
cuda_result = None
@@ -371,6 +367,11 @@ def _reproject_chunk_cupy(
371367
_use_native_cuda = True
372368
else:
373369
# CPU fallback (Numba JIT or pyproj)
370+
from ._crs_utils import _require_pyproj
371+
pyproj = _require_pyproj()
372+
transformer = pyproj.Transformer.from_crs(
373+
tgt_crs, src_crs, always_xy=True
374+
)
374375
src_y, src_x = _transform_coords(
375376
transformer, chunk_bounds_tuple, chunk_shape, transform_precision,
376377
src_crs=src_crs, tgt_crs=tgt_crs,
@@ -513,16 +514,13 @@ def reproject(
513514
If vertical transformation was applied, ``attrs['vertical_crs']``
514515
records the target vertical datum.
515516
"""
516-
from ._crs_utils import _require_pyproj
517-
518517
if not isinstance(raster, xr.DataArray):
519518
raise TypeError(
520519
f"reproject(): raster must be an xr.DataArray, "
521520
f"got {type(raster).__name__}"
522521
)
523522

524523
_validate_resampling(resampling)
525-
_require_pyproj()
526524

527525
# Resolve CRS
528526
src_crs = _resolve_crs(source_crs)
@@ -984,11 +982,10 @@ def _reproject_dask_cupy(
984982
"""
985983
import cupy as cp
986984

987-
from ._crs_utils import _require_pyproj
985+
from ._crs_utils import _crs_from_wkt
988986

989-
pyproj = _require_pyproj()
990-
src_crs = pyproj.CRS.from_wkt(src_wkt)
991-
tgt_crs = pyproj.CRS.from_wkt(tgt_wkt)
987+
src_crs = _crs_from_wkt(src_wkt)
988+
tgt_crs = _crs_from_wkt(tgt_wkt)
992989

993990
# Use larger chunks for GPU to amortize kernel launch overhead
994991
gpu_chunk = chunk_size or 2048
@@ -1048,6 +1045,8 @@ def _reproject_dask_cupy(
10481045
c_max = int(np.ceil(c_max_val)) + 3
10491046
else:
10501047
# CPU fallback for this chunk
1048+
from ._crs_utils import _require_pyproj
1049+
pyproj = _require_pyproj()
10511050
transformer = pyproj.Transformer.from_crs(
10521051
tgt_crs, src_crs, always_xy=True
10531052
)
@@ -1120,30 +1119,44 @@ def _reproject_dask_cupy(
11201119

11211120

11221121
def _source_footprint_in_target(src_bounds, src_wkt, tgt_wkt):
1123-
"""Compute an approximate bounding box of the source raster in target CRS.
1124-
1125-
Transforms corners and edge midpoints (12 points) to handle non-linear
1126-
projections. Returns ``(left, bottom, right, top)`` in target CRS, or
1127-
*None* if the transform fails (e.g. out-of-domain).
1128-
"""
1122+
"""Compute approximate bounding box of source raster in target CRS."""
11291123
try:
1130-
from ._crs_utils import _require_pyproj
1131-
pyproj = _require_pyproj()
1132-
src_crs = pyproj.CRS(src_wkt)
1133-
tgt_crs = pyproj.CRS(tgt_wkt)
1134-
transformer = pyproj.Transformer.from_crs(
1135-
src_crs, tgt_crs, always_xy=True
1136-
)
1124+
from ._crs_utils import _crs_from_wkt, _resolve_crs
1125+
try:
1126+
src_crs = _crs_from_wkt(src_wkt)
1127+
except Exception:
1128+
src_crs = _resolve_crs(src_wkt)
1129+
try:
1130+
tgt_crs = _crs_from_wkt(tgt_wkt)
1131+
except Exception:
1132+
tgt_crs = _resolve_crs(tgt_wkt)
11371133
except Exception:
11381134
return None
11391135

11401136
sl, sb, sr, st = src_bounds
11411137
mx = (sl + sr) / 2
11421138
my = (sb + st) / 2
1143-
xs = [sl, mx, sr, sl, mx, sr, sl, mx, sr, sl, sr, mx]
1144-
ys = [sb, sb, sb, my, my, my, st, st, st, mx, mx, sb]
1139+
xs = np.array([sl, mx, sr, sl, mx, sr, sl, mx, sr, sl, sr, mx])
1140+
ys = np.array([sb, sb, sb, my, my, my, st, st, st, mx, mx, sb])
1141+
11451142
try:
1146-
tx, ty = transformer.transform(xs, ys)
1143+
from ._projections import transform_points
1144+
result = transform_points(src_crs, tgt_crs, xs, ys)
1145+
if result is not None:
1146+
tx, ty = result
1147+
tx = [v for v in tx if np.isfinite(v)]
1148+
ty = [v for v in ty if np.isfinite(v)]
1149+
if not tx or not ty:
1150+
return None
1151+
return (min(tx), min(ty), max(tx), max(ty))
1152+
except (ImportError, ModuleNotFoundError):
1153+
pass
1154+
1155+
try:
1156+
from ._crs_utils import _require_pyproj
1157+
pyproj = _require_pyproj()
1158+
transformer = pyproj.Transformer.from_crs(src_crs, tgt_crs, always_xy=True)
1159+
tx, ty = transformer.transform(xs.tolist(), ys.tolist())
11471160
tx = [v for v in tx if np.isfinite(v)]
11481161
ty = [v for v in ty if np.isfinite(v)]
11491162
if not tx or not ty:
@@ -1298,14 +1311,11 @@ def merge(
12981311
-------
12991312
xr.DataArray
13001313
"""
1301-
from ._crs_utils import _require_pyproj
1302-
13031314
if not rasters:
13041315
raise ValueError("merge(): rasters list must not be empty")
13051316

13061317
_validate_resampling(resampling)
13071318
_validate_strategy(strategy)
1308-
pyproj = _require_pyproj()
13091319

13101320
# Resolve target CRS
13111321
tgt_crs = _resolve_crs(target_crs)
@@ -1485,9 +1495,8 @@ def _merge_inmemory(
14851495
Detects same-CRS tiles and uses fast direct placement instead
14861496
of reprojection.
14871497
"""
1488-
from ._crs_utils import _require_pyproj
1489-
pyproj = _require_pyproj()
1490-
tgt_crs = pyproj.CRS.from_wkt(tgt_wkt)
1498+
from ._crs_utils import _crs_from_wkt
1499+
tgt_crs = _crs_from_wkt(tgt_wkt)
14911500

14921501
arrays = []
14931502
for info in raster_infos:

xrspatial/reproject/_crs_utils.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,86 @@
1-
"""CRS detection utilities and optional pyproj import guard."""
1+
"""CRS detection utilities and optional pyproj import guard.
2+
3+
Uses a two-tier strategy: try the lightweight built-in CRS first,
4+
then fall back to pyproj for codes/formats not in the built-in table.
5+
"""
26
from __future__ import annotations
37

8+
from xrspatial.reproject._lite_crs import CRS as LiteCRS
49

5-
def _require_pyproj():
6-
"""Import and return the pyproj module, raising a clear error if missing."""
10+
11+
def _try_import_pyproj():
12+
"""Try to import pyproj, returning the module or None."""
713
try:
814
import pyproj
915
return pyproj
1016
except ImportError:
17+
return None
18+
19+
20+
def _require_pyproj():
    """Import and return the pyproj module, raising a clear error if missing.

    Returns
    -------
    module
        The imported ``pyproj`` module.

    Raises
    ------
    ImportError
        If ``_try_import_pyproj()`` returns ``None``; the message carries
        install instructions.
    """
    pyproj = _try_import_pyproj()
    if pyproj is None:
        raise ImportError(
            "pyproj is required for CRS reprojection. "
            "Install it with: pip install pyproj "
            "or: pip install xarray-spatial[reproject]"
        )
    return pyproj
1630

1731

1832
def _resolve_crs(crs_input):
19-
"""Convert *crs_input* to a ``pyproj.CRS`` object.
20-
21-
Accepts anything ``pyproj.CRS()`` accepts: EPSG int, authority string,
22-
WKT, proj4 dict, or an existing ``pyproj.CRS`` instance.
23-
24-
Returns None if *crs_input* is None.
33+
"""Convert *crs_input* to a CRS object.
34+
35+
Resolution order:
36+
37+
1. ``None`` passes through as ``None``.
38+
2. An existing ``LiteCRS`` instance passes through unchanged.
39+
3. An existing ``pyproj.CRS`` instance passes through unchanged
40+
(only checked when pyproj is importable).
41+
4. Try ``LiteCRS(crs_input)`` -- covers EPSG ints and ``"EPSG:XXXX"``
42+
strings for codes in the built-in table.
43+
5. Fall back to ``pyproj.CRS(crs_input)`` -- raises ``ImportError``
44+
if pyproj is not installed.
2545
"""
2646
if crs_input is None:
2747
return None
28-
pyproj = _require_pyproj()
29-
if isinstance(crs_input, pyproj.CRS):
48+
49+
# Pass through existing LiteCRS
50+
if isinstance(crs_input, LiteCRS):
51+
return crs_input
52+
53+
# Pass through existing pyproj.CRS (if pyproj available)
54+
pyproj = _try_import_pyproj()
55+
if pyproj is not None and isinstance(crs_input, pyproj.CRS):
3056
return crs_input
57+
58+
# Try lite CRS first
59+
try:
60+
return LiteCRS(crs_input)
61+
except (ValueError, TypeError):
62+
pass
63+
64+
# Fall back to pyproj
65+
pyproj = _require_pyproj()
3166
return pyproj.CRS(crs_input)
3267

3368

69+
def _crs_from_wkt(wkt):
    """Build a CRS from an OGC WKT string.

    Tries ``LiteCRS.from_wkt`` first (extracts the AUTHORITY tag),
    then falls back to ``pyproj.CRS.from_wkt``.

    Raises
    ------
    ImportError
        If the lite parser raises ``ValueError``/``TypeError`` for *wkt*
        and pyproj is not installed (raised by ``_require_pyproj``).
    """
    try:
        return LiteCRS.from_wkt(wkt)
    except (ValueError, TypeError):
        # Lite parser could not handle this WKT; defer to pyproj below.
        pass

    pyproj = _require_pyproj()
    return pyproj.CRS.from_wkt(wkt)
82+
83+
3484
def _detect_source_crs(raster):
3585
"""Auto-detect the CRS of a DataArray.
3686
@@ -47,7 +97,7 @@ def _detect_source_crs(raster):
4797

4898
crs_wkt = raster.attrs.get('crs_wkt')
4999
if crs_wkt is not None:
50-
return _resolve_crs(crs_wkt)
100+
return _crs_from_wkt(crs_wkt)
51101

52102
# rioxarray fallback
53103
try:

xrspatial/reproject/_grid.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,38 @@
44
import numpy as np
55

66

7+
def _transform_boundary(source_crs, target_crs, xs, ys):
    """Transform coordinate arrays, preferring Numba fast path over pyproj.

    Parameters
    ----------
    source_crs, target_crs : CRS-like
        Source and target coordinate reference systems.
    xs, ys : ndarray
        1-D arrays of x and y coordinates in *source_crs*.

    Returns
    -------
    tx, ty : ndarray
        Transformed coordinates as numpy arrays.

    Raises
    ------
    ImportError
        If ``transform_points`` returns ``None`` for this CRS pair and
        pyproj is not installed (raised by ``_require_pyproj``).
    """
    from ._projections import transform_points

    result = transform_points(source_crs, target_crs, xs, ys)
    if result is not None:
        # Normalise so both paths honour the documented ndarray return.
        tx, ty = result
        return np.asarray(tx), np.asarray(ty)

    # Fall back to pyproj
    # NOTE(review): source_crs/target_crs may be lite CRS objects here;
    # presumably pyproj accepts them -- confirm.
    from ._crs_utils import _require_pyproj

    pyproj = _require_pyproj()
    transformer = pyproj.Transformer.from_crs(
        source_crs, target_crs, always_xy=True
    )
    tx, ty = transformer.transform(xs, ys)
    return np.asarray(tx), np.asarray(ty)
37+
38+
739
def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs,
840
resolution=None, bounds=None, width=None, height=None):
941
"""Compute the output raster grid parameters.
@@ -14,7 +46,7 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs,
1446
(left, bottom, right, top) in source CRS.
1547
source_shape : tuple
1648
(height, width) of source raster.
17-
source_crs, target_crs : pyproj.CRS
49+
source_crs, target_crs : CRS-like
1850
Source and target coordinate reference systems.
1951
resolution : float or tuple or None
2052
Target resolution. If tuple, (x_res, y_res).
@@ -27,13 +59,6 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs,
2759
-------
2860
dict with keys: bounds, shape, res_x, res_y
2961
"""
30-
from ._crs_utils import _require_pyproj
31-
32-
pyproj = _require_pyproj()
33-
transformer = pyproj.Transformer.from_crs(
34-
source_crs, target_crs, always_xy=True
35-
)
36-
3762
if bounds is None:
3863
# Transform source corners and edges to target CRS
3964
src_left, src_bottom, src_right, src_top = source_bounds
@@ -76,7 +101,7 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs,
76101
ixx, iyy = np.meshgrid(ix, iy)
77102
xs = np.concatenate([edge_xs, ixx.ravel()])
78103
ys = np.concatenate([edge_ys, iyy.ravel()])
79-
tx, ty = transformer.transform(xs, ys)
104+
tx, ty = _transform_boundary(source_crs, target_crs, xs, ys)
80105
tx = np.asarray(tx)
81106
ty = np.asarray(ty)
82107
# Filter out inf/nan from failed transforms
@@ -110,7 +135,9 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs,
110135
ix = np.linspace(src_left, src_right, n_dense)
111136
iy = np.linspace(src_bottom, src_top, n_dense)
112137
ixx, iyy = np.meshgrid(ix, iy)
113-
itx, ity = transformer.transform(ixx.ravel(), iyy.ravel())
138+
itx, ity = _transform_boundary(
139+
source_crs, target_crs, ixx.ravel(), iyy.ravel()
140+
)
114141
itx = np.asarray(itx)
115142
ity = np.asarray(ity)
116143
ivalid = np.isfinite(itx) & np.isfinite(ity)
@@ -150,13 +177,15 @@ def _compute_output_grid(source_bounds, source_shape, source_crs, target_crs,
150177
src_res_y = (src_top - src_bottom) / src_h
151178
center_x = (src_left + src_right) / 2
152179
center_y = (src_bottom + src_top) / 2
153-
tc_x, tc_y = transformer.transform(center_x, center_y)
154-
# Step along x only
155-
tx_x, tx_y = transformer.transform(center_x + src_res_x, center_y)
156-
dx = np.hypot(float(tx_x) - float(tc_x), float(tx_y) - float(tc_y))
157-
# Step along y only
158-
ty_x, ty_y = transformer.transform(center_x, center_y + src_res_y)
159-
dy = np.hypot(float(ty_x) - float(tc_x), float(ty_y) - float(tc_y))
180+
# Batch the three resolution-estimation points into one call
181+
pts_x = np.array([center_x, center_x + src_res_x, center_x])
182+
pts_y = np.array([center_y, center_y, center_y + src_res_y])
183+
tp_x, tp_y = _transform_boundary(source_crs, target_crs, pts_x, pts_y)
184+
tc_x, tc_y = float(tp_x[0]), float(tp_y[0])
185+
tx_x, tx_y = float(tp_x[1]), float(tp_y[1])
186+
ty_x, ty_y = float(tp_x[2]), float(tp_y[2])
187+
dx = np.hypot(tx_x - tc_x, tx_y - tc_y)
188+
dy = np.hypot(ty_x - tc_x, ty_y - tc_y)
160189
if dx == 0 or dy == 0:
161190
res_x = (right - left) / src_w
162191
res_y = (top - bottom) / src_h

0 commit comments

Comments
 (0)