Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions src/parcels/_core/utils/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def from_attrs(cls, d: dict[str, Hashable]) -> Self: ...

# Note that - for some optional attributes in the SGRID spec - these IDs are not available
# hence this isn't full coverage
_ID_FETCHERS_GRID2DMETADATA: dict[str, Callable[[Grid2DMetadata], Dim | Padding]] = {
_ID_FETCHERS_GRID2DMETADATA: dict[str, Callable[[SGrid2DMetadata], Dim | Padding]] = {
"node_dimension1": lambda meta: meta.node_dimensions[0],
"node_dimension2": lambda meta: meta.node_dimensions[1],
"face_dimension1": lambda meta: meta.face_dimensions[0].face,
Expand All @@ -65,7 +65,7 @@ def from_attrs(cls, d: dict[str, Hashable]) -> Self: ...
"type2": lambda meta: meta.face_dimensions[1].padding,
}

_ID_FETCHERS_GRID3DMETADATA: dict[str, Callable[[Grid3DMetadata], Dim | Padding]] = {
_ID_FETCHERS_GRID3DMETADATA: dict[str, Callable[[SGrid3DMetadata], Dim | Padding]] = {
"node_dimension1": lambda meta: meta.node_dimensions[0],
"node_dimension2": lambda meta: meta.node_dimensions[1],
"node_dimension3": lambda meta: meta.node_dimensions[2],
Expand All @@ -78,7 +78,7 @@ def from_attrs(cls, d: dict[str, Hashable]) -> Self: ...
}


class Grid2DMetadata(AttrsSerializable):
class SGrid2DMetadata(AttrsSerializable):
def __init__(
self,
cf_role: Literal["grid_topology"],
Expand Down Expand Up @@ -152,7 +152,7 @@ def __str__(self) -> str:
return _grid2d_to_ascii(self)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Grid2DMetadata):
if not isinstance(other, SGrid2DMetadata):
return NotImplemented
return self.to_attrs() == other.to_attrs()

Expand Down Expand Up @@ -200,7 +200,7 @@ def get_value_by_id(self, id: str) -> Dim | Padding:
return _ID_FETCHERS_GRID2DMETADATA[id](self)


class Grid3DMetadata(AttrsSerializable):
class SGrid3DMetadata(AttrsSerializable):
def __init__(
self,
cf_role: Literal["grid_topology"],
Expand Down Expand Up @@ -268,7 +268,7 @@ def __str__(self) -> str:
return _grid3d_to_ascii(self)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Grid3DMetadata):
if not isinstance(other, SGrid3DMetadata):
return NotImplemented
return self.to_attrs() == other.to_attrs()

Expand Down Expand Up @@ -431,14 +431,14 @@ class SGridParsingException(Exception):
pass


def parse_grid_attrs(attrs: dict[str, Hashable]) -> Grid2DMetadata | Grid3DMetadata:
grid: Grid2DMetadata | Grid3DMetadata
def parse_grid_attrs(attrs: dict[str, Hashable]) -> SGrid2DMetadata | SGrid3DMetadata:
grid: SGrid2DMetadata | SGrid3DMetadata
try:
grid = Grid2DMetadata.from_attrs(attrs)
grid = SGrid2DMetadata.from_attrs(attrs)
except Exception as e:
e.add_note("Failed to parse as 2D SGrid, trying 3D SGrid")
try:
grid = Grid3DMetadata.from_attrs(attrs)
grid = SGrid3DMetadata.from_attrs(attrs)
except Exception as e2:
e2.add_note("Failed to parse as 3D SGrid")
raise SGridParsingException("Failed to parse SGrid metadata as either 2D or 3D grid") from e2
Expand All @@ -464,10 +464,10 @@ def parse_sgrid(ds: xr.Dataset):
except Exception as e:
raise SGridParsingException(f"Error parsing {grid_topology=!r}") from e

if isinstance(grid, Grid2DMetadata):
if isinstance(grid, SGrid2DMetadata):
dimensions = grid.face_dimensions + (grid.vertical_dimensions or ())
else:
assert isinstance(grid, Grid3DMetadata)
assert isinstance(grid, SGrid3DMetadata)
dimensions = grid.volume_dimensions

xgcm_coords = {}
Expand Down Expand Up @@ -499,7 +499,7 @@ def rename(ds: xr.Dataset, name_dict: dict[str, str]) -> xr.Dataset:
return ds


def get_unique_names(grid: Grid2DMetadata | Grid3DMetadata) -> set[str]:
def get_unique_names(grid: SGrid2DMetadata | SGrid3DMetadata) -> set[str]:
dims = set()
dims.update(set(grid.node_dimensions))

Expand Down Expand Up @@ -635,7 +635,7 @@ def _face_node_padding_to_text(obj: FaceNodePadding) -> list[str]:
· = cell centre"""


def _grid2d_to_ascii(grid: Grid2DMetadata) -> str:
def _grid2d_to_ascii(grid: SGrid2DMetadata) -> str:
fd = grid.face_dimensions
nd = grid.node_dimensions
lines = [
Expand Down Expand Up @@ -667,7 +667,7 @@ def _grid2d_to_ascii(grid: Grid2DMetadata) -> str:
return "\n".join(lines)


def _grid3d_to_ascii(grid: Grid3DMetadata) -> str:
def _grid3d_to_ascii(grid: SGrid3DMetadata) -> str:
vd = grid.volume_dimensions
nd = grid.node_dimensions
lines = [
Expand All @@ -694,7 +694,7 @@ def _grid3d_to_ascii(grid: Grid3DMetadata) -> str:
return "\n".join(lines)


def _attach_sgrid_metadata(ds: xr.Dataset, grid: Grid2DMetadata | Grid3DMetadata):
def _attach_sgrid_metadata(ds: xr.Dataset, grid: SGrid2DMetadata | SGrid3DMetadata) -> xr.Dataset:
"""Copies the dataset and attaches the SGRID metadata in 'grid' variable. Modifies 'conventions' attribute."""
ds = ds.copy()
ds["grid"] = (
Expand All @@ -707,11 +707,11 @@ def _attach_sgrid_metadata(ds: xr.Dataset, grid: Grid2DMetadata | Grid3DMetadata


@overload
def _metadata_rename(grid: Grid2DMetadata, names_dict: dict[str, str]) -> Grid2DMetadata: ...
def _metadata_rename(grid: SGrid2DMetadata, names_dict: dict[str, str]) -> SGrid2DMetadata: ...


@overload
def _metadata_rename(grid: Grid3DMetadata, names_dict: dict[str, str]) -> Grid3DMetadata: ...
def _metadata_rename(grid: SGrid3DMetadata, names_dict: dict[str, str]) -> SGrid3DMetadata: ...


def _metadata_rename(grid, names_dict):
Expand Down
4 changes: 2 additions & 2 deletions src/parcels/_datasets/structured/generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from parcels._core.utils.sgrid import (
FaceNodePadding,
Grid2DMetadata,
Padding,
SGrid2DMetadata,
_attach_sgrid_metadata,
)
from parcels._core.utils.time import timedelta_to_float
Expand All @@ -30,7 +30,7 @@ def simple_UV_dataset(dims=(360, 2, 30, 4), maxdepth=1, mesh="spherical"):
},
).pipe(
_attach_sgrid_metadata,
Grid2DMetadata(
SGrid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("XG", "YG"),
Expand Down
6 changes: 3 additions & 3 deletions src/parcels/_datasets/structured/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from parcels._core.utils.sgrid import (
FaceNodePadding,
Grid2DMetadata,
Padding,
SGrid2DMetadata,
_attach_sgrid_metadata,
)
from parcels._core.utils.sgrid import (
Expand Down Expand Up @@ -249,7 +249,7 @@ def _unrolled_cone_curvilinear_grid():
datasets["ds_2d_left"]
.pipe(
_attach_sgrid_metadata,
Grid2DMetadata(
SGrid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("XG", "YG"),
Expand All @@ -270,7 +270,7 @@ def _unrolled_cone_curvilinear_grid():
datasets["ds_2d_right"]
.pipe(
_attach_sgrid_metadata,
Grid2DMetadata(
SGrid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=("XG", "YG"),
Expand Down
93 changes: 93 additions & 0 deletions src/parcels/_datasets/structured/strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import numpy as np
import xarray as xr
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays as np_arrays

import parcels._strategies as pst
from parcels._core.utils import sgrid
from parcels._core.utils.sgrid import _attach_sgrid_metadata


def _face_size(node_size: int, padding: sgrid.Padding) -> int:
if padding == sgrid.Padding.NONE:
return node_size - 1
elif padding in (sgrid.Padding.LOW, sgrid.Padding.HIGH):
return node_size
else: # Padding.BOTH
return node_size + 1


@st.composite
def sgrid_dataset(draw, grid: sgrid.SGrid2DMetadata | None = None) -> xr.Dataset:
"""Strategy to create Xarray Sgrid datasets for testing"""
if grid is None:
grid = draw(pst.sgrid.grid2Dmetadata(use_standard_names=True).filter(lambda g: g.node_coordinates is not None))
elif grid.node_coordinates is None:
raise ValueError("grid in Parcels must have node_coordinates set")
assert grid is not None
assert grid.node_coordinates is not None

N = draw(st.integers(min_value=5, max_value=100))
M = draw(st.integers(min_value=5, max_value=100))

node_dim1, node_dim2 = grid.node_dimensions
face_dim1 = grid.face_dimensions[0].face
face_dim2 = grid.face_dimensions[1].face
N_face = _face_size(N, grid.face_dimensions[0].padding)
M_face = _face_size(M, grid.face_dimensions[1].padding)

if has_vertical := grid.vertical_dimensions is not None:
P = draw(st.integers(min_value=5, max_value=20))
vert_node_dim = grid.vertical_dimensions[0].node
vert_face_dim = grid.vertical_dimensions[0].face
P_face = _face_size(P, grid.vertical_dimensions[0].padding)

has_curvilinear_grid = draw(st.booleans())
coord_name1, coord_name2 = grid.node_coordinates

if has_curvilinear_grid:
c1, c2 = np.meshgrid(np.linspace(0, 100, N), np.linspace(0, 100, M), indexing="ij")
coord1_dims = [node_dim1, node_dim2]
coord2_dims = [node_dim1, node_dim2]
else:
c1 = np.linspace(0, 100, N)
c2 = np.linspace(0, 100, M)
coord1_dims = [node_dim1]
coord2_dims = [node_dim2]

num_fields = draw(st.integers(min_value=1, max_value=4))
data_vars = {}

for i in range(num_fields):
dim1 = draw(st.sampled_from([node_dim1, face_dim1]))
size1 = N if dim1 == node_dim1 else N_face

dim2 = draw(st.sampled_from([node_dim2, face_dim2]))
size2 = M if dim2 == node_dim2 else M_face

shape: tuple[int, ...]
if has_vertical and draw(st.booleans()):
vert_dim = draw(st.sampled_from([vert_node_dim, vert_face_dim]))
vert_size = P if vert_dim == vert_node_dim else P_face
dims = [vert_dim, dim1, dim2]
shape = (vert_size, size1, size2)
else:
dims = [dim1, dim2]
shape = (size1, size2)

data = draw(
np_arrays(
dtype=np.float64,
shape=shape,
elements=st.floats(min_value=1e-3, max_value=100.0, allow_nan=False, allow_infinity=False),
)
)
data_vars[f"field_{i}"] = (dims, data)

coords = {
coord_name1: (coord1_dims, c1),
coord_name2: (coord2_dims, c2),
}

ds = xr.Dataset(data_vars=data_vars, coords=coords)
return _attach_sgrid_metadata(ds, grid)
13 changes: 13 additions & 0 deletions src/parcels/_strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# isort: skip_file

try:
import hypothesis # noqa: F401
except ImportError as err:
err.add_note(
"To use strategies you must have hypothesis installed. Install it from PyPI, Conda, or using your preffered package manager."
)
raise err
Comment on lines +3 to +9
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? Wouldn't hypothesis simply be part of our Pixi install?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its part of our Pixi install, but its not a Parcels run dependency (i.e., its not in pixi.toml::run-dependencies or the recipe.yaml/pyproject.toml.

This is our first "optional dependency" for Parcels (i.e., a part of the codebase that needs a specific package in order to fulfill a function, but where it doesn't make sense to include it for everyone since most people wont use the specific function).

People doing conda install parcels then import parcels._strategies will encounter this more informative error message.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I see. But then (in a next PR?) fix the type-o preffered to preferred?


from . import sgrid, time

__all__ = ["sgrid", "time"]
58 changes: 40 additions & 18 deletions tests/strategies/sgrid.py → src/parcels/_strategies/sgrid.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this is just good practice, but I'm surprised that the strategies are moved out of the tests directory. Would they not more logically belong there?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah - good question.

Most of our dataset generation code we have shipping with Parcels via parcels._datasets. This means that users can have easy access to Parcels-compatible example datasets and dataset generation for posting issues etc. . When adding strategies for creating example datasets, I needed to use existing strategies that were in the tests module. Since I cant put code in parcels that depends on the tests module - the options I saw were:

  1. Leave the strategies where they are in tests and carve out a part in tests/... for housing these dataset strategies. This would mean that some datasets are used via parcels/_datasets/... and some via tests/... which I felt was confusing for devs
  2. Move the dataset generation to tests - foregoing the benefit of users having easy access to datasets
  3. Move the strategies to parcels

I thought (3) was the best. It might seem weird to be putting "test" code in the parcels release, but its actually not that weird - some projects even put their whole test suites in the release itself.

Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,34 @@
dimension_name = xr_st.names().filter(
lambda s: " " not in s
) # assuming for now spaces are allowed in dimension names in SGrid convention
dim_dim_padding = (
face_node_padding = (
st.tuples(dimension_name, dimension_name, padding)
.filter(lambda t: t[0] != t[1])
.map(lambda t: sgrid.FaceNodePadding(*t))
)

mappings = st.lists(dim_dim_padding | dimension_name).map(tuple)
mappings = st.lists(face_node_padding | dimension_name).map(tuple)


@st.composite
def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata:
N = 8
names = draw(
st.lists(dimension_name, min_size=N, max_size=N, unique=True)
# Reserved, as 'grid' name is used in Parcels testing to store grid information
.filter(lambda names: "grid" not in names)
)
def grid2Dmetadata(draw, use_standard_names=False) -> sgrid.SGrid2DMetadata:
names = [
"node_dimension1",
"node_dimension2",
"face_dimension1",
"face_dimension2",
"node_coordinates_var1",
"node_coordinates_var2",
"vertical_dimensions_face",
"vertical_dimensions_node",
]
if not use_standard_names:
names = draw(
st.lists(dimension_name, min_size=len(names), max_size=len(names), unique=True)
# Reserved, as 'grid' name is used in Parcels testing to store grid information
.filter(lambda names: "grid" not in names)
)

node_dimension1 = names[0]
node_dimension2 = names[1]
face_dimension1 = names[2]
Expand Down Expand Up @@ -62,7 +73,7 @@ def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata:
else:
vertical_dimensions = None

return sgrid.Grid2DMetadata(
return sgrid.SGrid2DMetadata(
cf_role="grid_topology",
topology_dimension=2,
node_dimensions=(node_dimension1, node_dimension2),
Expand All @@ -76,13 +87,24 @@ def grid2Dmetadata(draw) -> sgrid.Grid2DMetadata:


@st.composite
def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata:
N = 9
names = draw(
st.lists(dimension_name, min_size=N, max_size=N, unique=True)
# Reserved, as 'grid' name is used in Parcels testing to store grid information
.filter(lambda names: "grid" not in names)
)
def grid3Dmetadata(draw, use_standard_names=False) -> sgrid.SGrid3DMetadata:
names = [
"node_dimension1",
"node_dimension2",
"node_dimension3",
"face_dimension1",
"face_dimension2",
"face_dimension3",
"node_coordinates_var1",
"node_coordinates_var2",
"node_coordinates_dim3",
]
if not use_standard_names:
names = draw(
st.lists(dimension_name, min_size=len(names), max_size=len(names), unique=True)
# Reserved, as 'grid' name is used in Parcels testing to store grid information
.filter(lambda names: "grid" not in names)
)
node_dimension1 = names[0]
node_dimension2 = names[1]
node_dimension3 = names[2]
Expand All @@ -103,7 +125,7 @@ def grid3Dmetadata(draw) -> sgrid.Grid3DMetadata:
else:
node_coordinates = None

return sgrid.Grid3DMetadata(
return sgrid.SGrid3DMetadata(
cf_role="grid_topology",
topology_dimension=3,
node_dimensions=(node_dimension1, node_dimension2, node_dimension3),
Expand Down
File renamed without changes.
Loading
Loading