Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions test/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,25 @@ def test_to_raster(gridpath):
mesh_path = gridpath("mpas", "QU", "oQU480.231010.nc")
uxds = ux.open_dataset(mesh_path, mesh_path)

raster = uxds['bottomDepth'].to_raster(ax=ax)
with pytest.warns(UserWarning, match=r"Axes extent was default"):
raster = uxds['bottomDepth'].to_raster(ax=ax)

assert isinstance(raster, np.ndarray)


def test_to_raster_with_extra_dims(gridpath):
fig, ax = plt.subplots(
subplot_kw={'projection': ccrs.Robinson()},
constrained_layout=True,
figsize=(10, 5),
)

mesh_path = gridpath("mpas", "QU", "oQU480.231010.nc")
uxds = ux.open_dataset(mesh_path, mesh_path)

da = uxds['bottomDepth'].expand_dims(time=[0])
with pytest.warns(UserWarning, match=r"Axes extent was default"):
raster = da.to_raster(ax=ax)

assert isinstance(raster, np.ndarray)

Expand All @@ -121,9 +139,10 @@ def test_to_raster_reuse_mapping(gridpath, tmpdir):
uxds = ux.open_dataset(mesh_path, mesh_path)

# Returning
raster1, pixel_mapping = uxds['bottomDepth'].to_raster(
ax=ax, pixel_ratio=0.5, return_pixel_mapping=True
)
with pytest.warns(UserWarning, match=r"Axes extent was default"):
raster1, pixel_mapping = uxds['bottomDepth'].to_raster(
ax=ax, pixel_ratio=0.5, return_pixel_mapping=True
)
assert isinstance(raster1, np.ndarray)
assert isinstance(pixel_mapping, xr.DataArray)

Expand Down
40 changes: 37 additions & 3 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,16 +375,16 @@ def to_raster(
_RasterAxAttrs,
)

_ensure_dimensions(self)
data = _ensure_dimensions(self)

if not isinstance(ax, GeoAxes):
raise TypeError("`ax` must be an instance of cartopy.mpl.geoaxes.GeoAxes")

pixel_ratio_set = pixel_ratio is not None
if not pixel_ratio_set:
pixel_ratio = 1.0
input_ax_attrs = _RasterAxAttrs.from_ax(ax, pixel_ratio=pixel_ratio)
if pixel_mapping is not None:
input_ax_attrs = _RasterAxAttrs.from_ax(ax, pixel_ratio=pixel_ratio)
if isinstance(pixel_mapping, xr.DataArray):
pixel_ratio_input = pixel_ratio
pixel_ratio = pixel_mapping.attrs["pixel_ratio"]
Expand All @@ -403,9 +403,43 @@ def to_raster(
+ input_ax_attrs._value_comparison_message(pm_ax_attrs)
)
pixel_mapping = np.asarray(pixel_mapping, dtype=INT_DTYPE)
else:

def _is_default_extent() -> bool:
# Default extents are indicated by xlim/ylim being (0, 1)
# when autoscale is still on (no extent has been explicitly set)
if not ax.get_autoscale_on():
return False
xlim, ylim = ax.get_xlim(), ax.get_ylim()
return np.allclose(xlim, (0.0, 1.0)) and np.allclose(ylim, (0.0, 1.0))

if _is_default_extent():
try:
import cartopy.crs as ccrs

lon_min = float(self.uxgrid.node_lon.min(skipna=True).values)
lon_max = float(self.uxgrid.node_lon.max(skipna=True).values)
lat_min = float(self.uxgrid.node_lat.min(skipna=True).values)
lat_max = float(self.uxgrid.node_lat.max(skipna=True).values)
ax.set_extent(
(lon_min, lon_max, lat_min, lat_max),
crs=ccrs.PlateCarree(),
)
warn(
"Axes extent was default; auto-setting from grid lon/lat bounds for rasterization. "
"Set the extent explicitly to control this, e.g. via ax.set_global(), "
"ax.set_extent(...), or ax.set_xlim(...) + ax.set_ylim(...).",
stacklevel=2,
)
except Exception as e:
warn(
f"Failed to auto-set extent from grid bounds: {e}",
stacklevel=2,
)
input_ax_attrs = _RasterAxAttrs.from_ax(ax, pixel_ratio=pixel_ratio)

raster, pixel_mapping_np = _nearest_neighbor_resample(
self,
data,
ax,
pixel_ratio=pixel_ratio,
pixel_mapping=pixel_mapping,
Expand Down
18 changes: 11 additions & 7 deletions uxarray/plot/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,21 @@ def _ensure_dimensions(data: UxDataArray) -> UxDataArray:
ValueError
If the sole dimension is not named "n_face".
"""
# Check dimensionality
if data.ndim != 1:
# Allow extra singleton dimensions as long as there's exactly one non-singleton dim
non_trivial_dims = [dim for dim, size in zip(data.dims, data.shape) if size != 1]

if len(non_trivial_dims) != 1:
raise ValueError(
f"Expected a 1D DataArray over 'n_face', but got {data.ndim} dimensions: {data.dims}"
"Expected data with a single dimension (other axes may be length 1), "
f"but got dims {data.dims} with shape {data.shape}"
)

# Check dimension name
if data.dims[0] != "n_face":
raise ValueError(f"Expected dimension 'n_face', but got '{data.dims[0]}'")
sole_dim = non_trivial_dims[0]
if sole_dim != "n_face":
raise ValueError(f"Expected dimension 'n_face', but got '{sole_dim}'")

return data
# Squeeze any singleton axes to ensure we return a true 1D array over n_face
return data.squeeze()


class _RasterAxAttrs(NamedTuple):
Expand Down
1 change: 0 additions & 1 deletion uxarray/utils/computing.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def dot_fma(v1, v2):
----------
S. Graillat, Ph. Langlois, and N. Louvet. "Accurate dot products with FMA." Presented at RNC 7, 2007, Nancy, France.
DALI-LP2A Laboratory, University of Perpignan, France.
[Poster](https://www-pequan.lip6.fr/~graillat/papers/posterRNC7.pdf)
"""
if len(v1) != len(v2):
raise ValueError("Input vectors must be of the same length")
Expand Down
Loading