diff --git a/src/modelskill/comparison/_comparison.py b/src/modelskill/comparison/_comparison.py index f8fedef75..82a00b0b2 100644 --- a/src/modelskill/comparison/_comparison.py +++ b/src/modelskill/comparison/_comparison.py @@ -23,7 +23,7 @@ from .. import metrics as mtr from .. import Quantity -from ..types import GeometryType +from ..types import GeometryType, VariableKind from ..obs import PointObservation, TrackObservation from ..model import PointModelResult, TrackModelResult from ..timeseries._timeseries import _validate_data_var_name @@ -79,8 +79,13 @@ def _parse_dataset(data: xr.Dataset) -> xr.Dataset: assert ( list(data[v].dims)[0] == "time" ), f"All data arrays must have a time dimension; {v} has dimensions {data[v].dims}" - if "kind" not in data[v].attrs: - data[v].attrs["kind"] = "auxiliary" + # Normalize kind attribute (Postel's Law: be liberal in what you accept) + match data[v].attrs.get("kind"): + case None | "auxiliary": + data[v].attrs["kind"] = VariableKind.AUXILIARY.value + case k if k not in (VariableKind.OBSERVATION.value, VariableKind.MODEL.value, VariableKind.AUXILIARY.value): + valid = [e.value for e in VariableKind] + raise ValueError(f"Invalid kind '{k}' for variable '{v}'. Must be one of {valid}") n_mod = sum([_is_model(da) for da in data.data_vars.values()]) n_obs = sum([_is_observation(da) for da in data.data_vars.values()]) @@ -95,14 +100,12 @@ def _parse_dataset(data: xr.Dataset) -> xr.Dataset: "dataset must have at least one model array (marked by the kind attribute)" ) - # Validate attrs + # Validate gtype attribute if "gtype" not in data.attrs: data.attrs["gtype"] = str(GeometryType.POINT) - # assert "gtype" in data.attrs, "data must have a gtype attribute" - # assert data.attrs["gtype"] in [ - # str(GeometryType.POINT), - # str(GeometryType.TRACK), - # ], f"data attribute 'gtype' must be one of {GeometryType.POINT} or {GeometryType.TRACK}" + elif data.attrs["gtype"] not in (str(GeometryType.POINT), str(GeometryType.TRACK)): + valid = [str(GeometryType.POINT), str(GeometryType.TRACK)] + raise ValueError(f"Invalid gtype '{data.attrs['gtype']}'. Must be one of {valid}") if "color" not in data["Observation"].attrs: data["Observation"].attrs["color"] = "black" @@ -121,11 +124,11 @@ def _parse_dataset(data: xr.Dataset) -> xr.Dataset: def _is_observation(da: xr.DataArray) -> bool: - return str(da.attrs["kind"]) == "observation" + return str(da.attrs["kind"]) == VariableKind.OBSERVATION.value def _is_model(da: xr.DataArray) -> bool: - return str(da.attrs["kind"]) == "model" + return str(da.attrs["kind"]) == VariableKind.MODEL.value def _validate_metrics(metrics: Iterable[Any]) -> None: @@ -350,11 +353,11 @@ def _matched_data_to_xarray( assert isinstance(ds, xr.Dataset) ds.attrs["name"] = name if name is not None else items.obs - ds["Observation"].attrs["kind"] = "observation" + ds["Observation"].attrs["kind"] = VariableKind.OBSERVATION.value for m in items.model: - ds[m].attrs["kind"] = "model" + ds[m].attrs["kind"] = VariableKind.MODEL.value for a in items.aux: - ds[a].attrs["kind"] = "auxiliary" + ds[a].attrs["kind"] = VariableKind.AUXILIARY.value if x_item is not None: ds = ds.rename({items.x: "x"}).set_coords("x") @@ -454,7 +457,7 @@ def __init__( # key: ModelResult(value, gtype=self.data.gtype, name=key, x=self.x, y=self.y) str(key): PointModelResult(self.data[[str(key)]], name=str(key)) for key, value in matched_data.data_vars.items() - if value.attrs["kind"] == "model" + if value.attrs["kind"] == VariableKind.MODEL.value } ) @@ -602,11 +605,11 @@ def __contains__(self, key: str) -> bool: @property def aux_names(self) -> List[str]: """List of auxiliary data names""" - # we don't require the kind attribute to be "auxiliary" + # Postel's Law: normalize at the boundary, be strict internally return [ str(k) for k, v in self.data.data_vars.items() - if v.attrs["kind"] not in ["observation", "model"] + if v.attrs["kind"] == VariableKind.AUXILIARY.value ] # TODO: always "Observation", necessary to have this property? diff --git a/src/modelskill/matching.py b/src/modelskill/matching.py index 763c9eea1..35f7243d8 100644 --- a/src/modelskill/matching.py +++ b/src/modelskill/matching.py @@ -35,7 +35,7 @@ observation, ) from .timeseries import TimeSeries -from .types import Period +from .types import Period, VariableKind TimeDeltaTypes = Union[float, int, np.timedelta64, pd.Timedelta, timedelta] IdxOrNameTypes = Optional[Union[int, str]] @@ -410,8 +410,8 @@ def _match_space_time( f"Matching not implemented for model type {type(mr)} and observation type {type(observation)}" ) - if overlapping := set(aligned.filter_by_attrs(kind="aux").data_vars) & set( - observation.data.filter_by_attrs(kind="aux").data_vars + if overlapping := set(aligned.filter_by_attrs(kind=VariableKind.AUXILIARY.value).data_vars) & set( + observation.data.filter_by_attrs(kind=VariableKind.AUXILIARY.value).data_vars ): raise ValueError( f"Aux variables are not allowed to have identical names. Choose either aux from obs or model. Overlapping: {overlapping}" @@ -422,7 +422,7 @@ def _match_space_time( # drop NaNs in model and observation columns (but allow NaNs in aux columns) def mo_kind(k: str) -> bool: - return k in ["model", "observation"] + return k in [VariableKind.MODEL.value, VariableKind.OBSERVATION.value] # TODO mo_cols vs non_aux_cols? mo_cols = data.filter_by_attrs(kind=mo_kind).data_vars diff --git a/src/modelskill/model/point.py b/src/modelskill/model/point.py index 5a11b1a30..c375a5f72 100644 --- a/src/modelskill/model/point.py +++ b/src/modelskill/model/point.py @@ -6,7 +6,7 @@ import pandas as pd from ..obs import Observation -from ..types import PointType +from ..types import PointType, VariableKind from ..quantity import Quantity from ..timeseries import TimeSeries, _parse_xyz_point_input @@ -67,7 +67,7 @@ def __init__( assert isinstance(data, xr.Dataset) data_var = str(list(data.data_vars)[0]) - data[data_var].attrs["kind"] = "model" + data[data_var].attrs["kind"] = VariableKind.MODEL.value super().__init__(data=data) def interp_time(self, observation: Observation, **kwargs: Any) -> PointModelResult: diff --git a/src/modelskill/model/track.py b/src/modelskill/model/track.py index cb37ca8e7..e193791b4 100644 --- a/src/modelskill/model/track.py +++ b/src/modelskill/model/track.py @@ -5,7 +5,7 @@ import numpy as np import xarray as xr -from ..types import TrackType +from ..types import TrackType, VariableKind from ..obs import TrackObservation from ..quantity import Quantity from ..timeseries import TimeSeries, _parse_track_input @@ -68,7 +68,7 @@ def __init__( assert isinstance(data, xr.Dataset) data_var = str(list(data.data_vars)[0]) - data[data_var].attrs["kind"] = "model" + data[data_var].attrs["kind"] = VariableKind.MODEL.value super().__init__(data=data) def subset_to( diff --git a/src/modelskill/obs.py b/src/modelskill/obs.py index 5828b2b03..5800f3aa7 100644 --- a/src/modelskill/obs.py +++ b/src/modelskill/obs.py @@ -16,7 +16,7 @@ import pandas as pd import xarray as xr -from .types import PointType, TrackType, GeometryType, DataInputType +from .types import PointType, TrackType, GeometryType, DataInputType, VariableKind from . import Quantity from .timeseries import ( TimeSeries, @@ -114,7 +114,7 @@ def __init__( assert isinstance(data, xr.Dataset) data_var = str(list(data.data_vars)[0]) - data[data_var].attrs["kind"] = "observation" + data[data_var].attrs["kind"] = VariableKind.OBSERVATION.value # check that user-defined attrs don't overwrite existing attrs! _validate_attrs(data.attrs, attrs) diff --git a/src/modelskill/timeseries/_point.py b/src/modelskill/timeseries/_point.py index 5e2a0c86a..a196530bb 100644 --- a/src/modelskill/timeseries/_point.py +++ b/src/modelskill/timeseries/_point.py @@ -9,7 +9,7 @@ import mikeio -from ..types import GeometryType, PointType +from ..types import GeometryType, PointType, VariableKind from ..quantity import Quantity from ..utils import _get_name from ._timeseries import _validate_data_var_name @@ -172,7 +172,7 @@ def _include_attributes( ds[name].attrs["is_directional"] = int(quantity.is_directional) for aux_item in sel_items.aux: - ds[aux_item].attrs["kind"] = "aux" + ds[aux_item].attrs["kind"] = VariableKind.AUXILIARY.value return ds diff --git a/src/modelskill/timeseries/_timeseries.py b/src/modelskill/timeseries/_timeseries.py index 52ff26ca4..e449630d2 100644 --- a/src/modelskill/timeseries/_timeseries.py +++ b/src/modelskill/timeseries/_timeseries.py @@ -7,7 +7,7 @@ import pandas as pd import xarray as xr -from ..types import GeometryType +from ..types import GeometryType, VariableKind from ..quantity import Quantity from ._plotter import TimeSeriesPlotter, MatplotlibTimeSeriesPlotter from .. import __version__ @@ -86,8 +86,8 @@ def _validate_dataset(ds: xr.Dataset) -> xr.Dataset: list(ds[v].dims)[0] == "time" ), f"All data arrays must have a time dimension; {v} has dimensions {ds[v].dims}" if "kind" not in ds[v].attrs: - ds[v].attrs["kind"] = "auxiliary" - if ds[v].attrs["kind"] in ["model", "observation"]: + ds[v].attrs["kind"] = VariableKind.AUXILIARY.value + if ds[v].attrs["kind"] in [VariableKind.MODEL.value, VariableKind.OBSERVATION.value]: n_primary += 1 name = v @@ -152,8 +152,8 @@ def _val_item(self) -> str: return [ str(v) for v in self.data.data_vars - if self.data[v].attrs["kind"] == "model" - or self.data[v].attrs["kind"] == "observation" + if self.data[v].attrs["kind"] == VariableKind.MODEL.value + or self.data[v].attrs["kind"] == VariableKind.OBSERVATION.value ][0] @property @@ -228,7 +228,7 @@ def _coordinate_values(self, coord: str) -> float | np.ndarray: @property def _is_modelresult(self) -> bool: - return bool(self.data[self.name].attrs["kind"] == "model") + return bool(self.data[self.name].attrs["kind"] == VariableKind.MODEL.value) @property def values(self) -> np.ndarray: @@ -242,7 +242,7 @@ def _values_as_series(self) -> pd.Series: @property def _aux_vars(self): - return list(self.data.filter_by_attrs(kind="aux").data_vars) + return list(self.data.filter_by_attrs(kind=VariableKind.AUXILIARY.value).data_vars) def __repr__(self) -> str: res = [] diff --git a/src/modelskill/timeseries/_track.py b/src/modelskill/timeseries/_track.py index c0779a10e..0ac4879a2 100644 --- a/src/modelskill/timeseries/_track.py +++ b/src/modelskill/timeseries/_track.py @@ -9,7 +9,7 @@ import mikeio -from ..types import GeometryType, TrackType +from ..types import GeometryType, TrackType, VariableKind from ..quantity import Quantity from ..utils import _get_name, make_unique_index from ._timeseries import _validate_data_var_name @@ -152,7 +152,7 @@ def _parse_track_input( ds[name].attrs["units"] = model_quantity.unit for aux_item in sel_items.aux: - ds[aux_item].attrs["kind"] = "aux" + ds[aux_item].attrs["kind"] = VariableKind.AUXILIARY.value ds.attrs["gtype"] = str(GeometryType.TRACK) assert isinstance(ds, xr.Dataset) diff --git a/src/modelskill/types.py b/src/modelskill/types.py index 0ba3dcd63..982a0cfc7 100644 --- a/src/modelskill/types.py +++ b/src/modelskill/types.py @@ -50,6 +50,14 @@ def from_string(s: str) -> "GeometryType": ) from e +class VariableKind(Enum): + """Kind of data variable in a matched dataset""" + + OBSERVATION = "observation" + MODEL = "model" + AUXILIARY = "aux" + + DataInputType = Union[ str, Path, diff --git a/tests/test_comparer.py b/tests/test_comparer.py index 97d4b1a93..15521941e 100644 --- a/tests/test_comparer.py +++ b/tests/test_comparer.py @@ -171,14 +171,14 @@ def test_matched_df_with_aux(pt_df): ) assert cmp.mod_names == ["m1", "m2"] assert cmp.n_points == 6 - assert cmp.data["wind"].attrs["kind"] == "auxiliary" + assert cmp.data["wind"].attrs["kind"] == "aux" assert "not_relevant" not in cmp.data.data_vars # if aux_items is a string, it is automatically converted to a list cmp = Comparer.from_matched_data( data=pt_df, mod_items=["m1", "m2"], aux_items="wind" ) - assert cmp.data["wind"].attrs["kind"] == "auxiliary" + assert cmp.data["wind"].attrs["kind"] == "aux" # if models are specified, it is NOT automatically considered an aux variable cmp = Comparer.from_matched_data(data=pt_df, mod_items=["m1", "m2"]) @@ -191,7 +191,7 @@ def test_aux_can_str_(pt_df): pt_df["area"] = ["a", "b", "c", "d", "e", "f"] cmp = Comparer.from_matched_data(pt_df, aux_items="area") - assert cmp.data["area"].attrs["kind"] == "auxiliary" + assert cmp.data["area"].attrs["kind"] == "aux" def test_mod_and_obs_must_be_numeric(): @@ -349,6 +349,52 @@ def test_minimal_matched_data(pt_df): assert cmp.n_models == 2 +def test_kind_must_be_observation_model_or_aux(pt_df): + """The kind attribute must be 'observation', 'model', or 'aux'.""" + data = xr.Dataset(pt_df) + data["Observation"].attrs["kind"] = "observation" + data["m1"].attrs["kind"] = "model" + data["m2"].attrs["kind"] = "aux" + data.attrs["name"] = "valid" + + cmp = Comparer.from_matched_data(data=data) + assert cmp.mod_names == ["m1"] + assert cmp.aux_names == ["m2"] + + # 'auxiliary' is normalized to 'aux' for backwards compatibility + data["m2"].attrs["kind"] = "auxiliary" + cmp = Comparer.from_matched_data(data=data) + assert cmp.data["m2"].attrs["kind"] == "aux" + + # invalid kind values are rejected + data["m2"].attrs["kind"] = "invalid" + with pytest.raises(ValueError, match="Invalid kind 'invalid'.*Must be one of"): + Comparer.from_matched_data(data=data) + + +def test_gtype_must_be_point_or_track(pt_df): + """The gtype attribute must be 'point' or 'track'.""" + data = xr.Dataset(pt_df) + data["Observation"].attrs["kind"] = "observation" + data["m1"].attrs["kind"] = "model" + data["m2"].attrs["kind"] = "model" + data.attrs["name"] = "valid" + + # valid gtype values are accepted + data.attrs["gtype"] = "point" + cmp = Comparer.from_matched_data(data=data) + assert cmp.gtype == "point" + + data.attrs["gtype"] = "track" + cmp = Comparer.from_matched_data(data=data) + assert cmp.gtype == "track" + + # invalid gtype values are rejected + data.attrs["gtype"] = "grid" + with pytest.raises(ValueError, match="Invalid gtype 'grid'.*Must be one of"): + Comparer.from_matched_data(data=data) + + def test_from_compared_data_doesnt_accept_missing_values_in_obs(): df = pd.DataFrame( { @@ -386,7 +432,7 @@ def test_multiple_forecasts_matched_data(): cmp = Comparer.from_matched_data(data=data) # no additional raw_mod_data assert len(cmp.raw_mod_data["m1"]) == 5 assert cmp.mod_names == ["m1"] - assert cmp.data["leadtime"].attrs["kind"] == "auxiliary" + assert cmp.data["leadtime"].attrs["kind"] == "aux" analysis = cmp.where(cmp.data["leadtime"] == 0) analysis.score() assert len(analysis.raw_mod_data["m1"]) == 5 @@ -407,7 +453,7 @@ def test_matched_aux_variables(pt_df): data["m2"].attrs["kind"] = "model" cmp = Comparer.from_matched_data(data=data) assert "wind" not in cmp.mod_names - assert cmp.data["wind"].attrs["kind"] == "auxiliary" + assert cmp.data["wind"].attrs["kind"] == "aux" def test_pc_properties(pc): diff --git a/tests/test_pointcompare.py b/tests/test_pointcompare.py index a75206574..f28b0eae0 100644 --- a/tests/test_pointcompare.py +++ b/tests/test_pointcompare.py @@ -326,5 +326,5 @@ def test_mod_aux_carried_over(klagshamn): cmp = ms.match(klagshamn, mr, spatial_method="contained") assert "U velocity" in cmp.data.data_vars assert cmp.data["U velocity"].values[0] == pytest.approx(-0.0360998) - assert cmp.data["U velocity"].attrs["kind"] == "aux" + assert cmp.data["U velocity"].attrs["kind"] == "aux" # normalized assert cmp.mod_names == ["Oresund2D_subset"]