Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions src/modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from .. import metrics as mtr
from .. import Quantity
from ..types import GeometryType
from ..types import GeometryType, VariableKind
from ..obs import PointObservation, TrackObservation
from ..model import PointModelResult, TrackModelResult
from ..timeseries._timeseries import _validate_data_var_name
Expand Down Expand Up @@ -79,8 +79,13 @@ def _parse_dataset(data: xr.Dataset) -> xr.Dataset:
assert (
list(data[v].dims)[0] == "time"
), f"All data arrays must have a time dimension; {v} has dimensions {data[v].dims}"
if "kind" not in data[v].attrs:
data[v].attrs["kind"] = "auxiliary"
# Normalize kind attribute (Postel's Law: be liberal in what you accept)
match data[v].attrs.get("kind"):
case None | "auxiliary":
data[v].attrs["kind"] = VariableKind.AUXILIARY.value
case k if k not in (VariableKind.OBSERVATION.value, VariableKind.MODEL.value, VariableKind.AUXILIARY.value):
valid = [e.value for e in VariableKind]
raise ValueError(f"Invalid kind '{k}' for variable '{v}'. Must be one of {valid}")

n_mod = sum([_is_model(da) for da in data.data_vars.values()])
n_obs = sum([_is_observation(da) for da in data.data_vars.values()])
Expand All @@ -95,14 +100,12 @@ def _parse_dataset(data: xr.Dataset) -> xr.Dataset:
"dataset must have at least one model array (marked by the kind attribute)"
)

# Validate attrs
# Validate gtype attribute
if "gtype" not in data.attrs:
data.attrs["gtype"] = str(GeometryType.POINT)
# assert "gtype" in data.attrs, "data must have a gtype attribute"
# assert data.attrs["gtype"] in [
# str(GeometryType.POINT),
# str(GeometryType.TRACK),
# ], f"data attribute 'gtype' must be one of {GeometryType.POINT} or {GeometryType.TRACK}"
elif data.attrs["gtype"] not in (str(GeometryType.POINT), str(GeometryType.TRACK)):
valid = [str(GeometryType.POINT), str(GeometryType.TRACK)]
raise ValueError(f"Invalid gtype '{data.attrs['gtype']}'. Must be one of {valid}")

if "color" not in data["Observation"].attrs:
data["Observation"].attrs["color"] = "black"
Expand All @@ -121,11 +124,11 @@ def _parse_dataset(data: xr.Dataset) -> xr.Dataset:


def _is_observation(da: xr.DataArray) -> bool:
return str(da.attrs["kind"]) == "observation"
return str(da.attrs["kind"]) == VariableKind.OBSERVATION.value


def _is_model(da: xr.DataArray) -> bool:
return str(da.attrs["kind"]) == "model"
return str(da.attrs["kind"]) == VariableKind.MODEL.value


def _validate_metrics(metrics: Iterable[Any]) -> None:
Expand Down Expand Up @@ -350,11 +353,11 @@ def _matched_data_to_xarray(
assert isinstance(ds, xr.Dataset)

ds.attrs["name"] = name if name is not None else items.obs
ds["Observation"].attrs["kind"] = "observation"
ds["Observation"].attrs["kind"] = VariableKind.OBSERVATION.value
for m in items.model:
ds[m].attrs["kind"] = "model"
ds[m].attrs["kind"] = VariableKind.MODEL.value
for a in items.aux:
ds[a].attrs["kind"] = "auxiliary"
ds[a].attrs["kind"] = VariableKind.AUXILIARY.value

if x_item is not None:
ds = ds.rename({items.x: "x"}).set_coords("x")
Expand Down Expand Up @@ -454,7 +457,7 @@ def __init__(
# key: ModelResult(value, gtype=self.data.gtype, name=key, x=self.x, y=self.y)
str(key): PointModelResult(self.data[[str(key)]], name=str(key))
for key, value in matched_data.data_vars.items()
if value.attrs["kind"] == "model"
if value.attrs["kind"] == VariableKind.MODEL.value
}
)

Expand Down Expand Up @@ -602,11 +605,11 @@ def __contains__(self, key: str) -> bool:
@property
def aux_names(self) -> List[str]:
"""List of auxiliary data names"""
# we don't require the kind attribute to be "auxiliary"
# Postel's Law: normalize at the boundary, be strict internally
return [
str(k)
for k, v in self.data.data_vars.items()
if v.attrs["kind"] not in ["observation", "model"]
if v.attrs["kind"] == VariableKind.AUXILIARY.value
]

# TODO: always "Observation", necessary to have this property?
Expand Down
8 changes: 4 additions & 4 deletions src/modelskill/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
observation,
)
from .timeseries import TimeSeries
from .types import Period
from .types import Period, VariableKind

TimeDeltaTypes = Union[float, int, np.timedelta64, pd.Timedelta, timedelta]
IdxOrNameTypes = Optional[Union[int, str]]
Expand Down Expand Up @@ -410,8 +410,8 @@ def _match_space_time(
f"Matching not implemented for model type {type(mr)} and observation type {type(observation)}"
)

if overlapping := set(aligned.filter_by_attrs(kind="aux").data_vars) & set(
observation.data.filter_by_attrs(kind="aux").data_vars
if overlapping := set(aligned.filter_by_attrs(kind=VariableKind.AUXILIARY.value).data_vars) & set(
observation.data.filter_by_attrs(kind=VariableKind.AUXILIARY.value).data_vars
):
raise ValueError(
f"Aux variables are not allowed to have identical names. Choose either aux from obs or model. Overlapping: {overlapping}"
Expand All @@ -422,7 +422,7 @@ def _match_space_time(

# drop NaNs in model and observation columns (but allow NaNs in aux columns)
def mo_kind(k: str) -> bool:
return k in ["model", "observation"]
return k in [VariableKind.MODEL.value, VariableKind.OBSERVATION.value]

# TODO mo_cols vs non_aux_cols?
mo_cols = data.filter_by_attrs(kind=mo_kind).data_vars
Expand Down
4 changes: 2 additions & 2 deletions src/modelskill/model/point.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd

from ..obs import Observation
from ..types import PointType
from ..types import PointType, VariableKind
from ..quantity import Quantity
from ..timeseries import TimeSeries, _parse_xyz_point_input

Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(
assert isinstance(data, xr.Dataset)

data_var = str(list(data.data_vars)[0])
data[data_var].attrs["kind"] = "model"
data[data_var].attrs["kind"] = VariableKind.MODEL.value
super().__init__(data=data)

def interp_time(self, observation: Observation, **kwargs: Any) -> PointModelResult:
Expand Down
4 changes: 2 additions & 2 deletions src/modelskill/model/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import xarray as xr

from ..types import TrackType
from ..types import TrackType, VariableKind
from ..obs import TrackObservation
from ..quantity import Quantity
from ..timeseries import TimeSeries, _parse_track_input
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(

assert isinstance(data, xr.Dataset)
data_var = str(list(data.data_vars)[0])
data[data_var].attrs["kind"] = "model"
data[data_var].attrs["kind"] = VariableKind.MODEL.value
super().__init__(data=data)

def subset_to(
Expand Down
4 changes: 2 additions & 2 deletions src/modelskill/obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pandas as pd
import xarray as xr

from .types import PointType, TrackType, GeometryType, DataInputType
from .types import PointType, TrackType, GeometryType, DataInputType, VariableKind
from . import Quantity
from .timeseries import (
TimeSeries,
Expand Down Expand Up @@ -114,7 +114,7 @@ def __init__(
assert isinstance(data, xr.Dataset)

data_var = str(list(data.data_vars)[0])
data[data_var].attrs["kind"] = "observation"
data[data_var].attrs["kind"] = VariableKind.OBSERVATION.value

# check that user-defined attrs don't overwrite existing attrs!
_validate_attrs(data.attrs, attrs)
Expand Down
4 changes: 2 additions & 2 deletions src/modelskill/timeseries/_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import mikeio

from ..types import GeometryType, PointType
from ..types import GeometryType, PointType, VariableKind
from ..quantity import Quantity
from ..utils import _get_name
from ._timeseries import _validate_data_var_name
Expand Down Expand Up @@ -172,7 +172,7 @@ def _include_attributes(
ds[name].attrs["is_directional"] = int(quantity.is_directional)

for aux_item in sel_items.aux:
ds[aux_item].attrs["kind"] = "aux"
ds[aux_item].attrs["kind"] = VariableKind.AUXILIARY.value

return ds

Expand Down
14 changes: 7 additions & 7 deletions src/modelskill/timeseries/_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import xarray as xr

from ..types import GeometryType
from ..types import GeometryType, VariableKind
from ..quantity import Quantity
from ._plotter import TimeSeriesPlotter, MatplotlibTimeSeriesPlotter
from .. import __version__
Expand Down Expand Up @@ -86,8 +86,8 @@ def _validate_dataset(ds: xr.Dataset) -> xr.Dataset:
list(ds[v].dims)[0] == "time"
), f"All data arrays must have a time dimension; {v} has dimensions {ds[v].dims}"
if "kind" not in ds[v].attrs:
ds[v].attrs["kind"] = "auxiliary"
if ds[v].attrs["kind"] in ["model", "observation"]:
ds[v].attrs["kind"] = VariableKind.AUXILIARY.value
if ds[v].attrs["kind"] in [VariableKind.MODEL.value, VariableKind.OBSERVATION.value]:
n_primary += 1
name = v

Expand Down Expand Up @@ -152,8 +152,8 @@ def _val_item(self) -> str:
return [
str(v)
for v in self.data.data_vars
if self.data[v].attrs["kind"] == "model"
or self.data[v].attrs["kind"] == "observation"
if self.data[v].attrs["kind"] == VariableKind.MODEL.value
or self.data[v].attrs["kind"] == VariableKind.OBSERVATION.value
][0]

@property
Expand Down Expand Up @@ -228,7 +228,7 @@ def _coordinate_values(self, coord: str) -> float | np.ndarray:

@property
def _is_modelresult(self) -> bool:
return bool(self.data[self.name].attrs["kind"] == "model")
return bool(self.data[self.name].attrs["kind"] == VariableKind.MODEL.value)

@property
def values(self) -> np.ndarray:
Expand All @@ -242,7 +242,7 @@ def _values_as_series(self) -> pd.Series:

@property
def _aux_vars(self):
return list(self.data.filter_by_attrs(kind="aux").data_vars)
return list(self.data.filter_by_attrs(kind=VariableKind.AUXILIARY.value).data_vars)

def __repr__(self) -> str:
res = []
Expand Down
4 changes: 2 additions & 2 deletions src/modelskill/timeseries/_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import mikeio

from ..types import GeometryType, TrackType
from ..types import GeometryType, TrackType, VariableKind
from ..quantity import Quantity
from ..utils import _get_name, make_unique_index
from ._timeseries import _validate_data_var_name
Expand Down Expand Up @@ -152,7 +152,7 @@ def _parse_track_input(
ds[name].attrs["units"] = model_quantity.unit

for aux_item in sel_items.aux:
ds[aux_item].attrs["kind"] = "aux"
ds[aux_item].attrs["kind"] = VariableKind.AUXILIARY.value

ds.attrs["gtype"] = str(GeometryType.TRACK)
assert isinstance(ds, xr.Dataset)
Expand Down
8 changes: 8 additions & 0 deletions src/modelskill/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ def from_string(s: str) -> "GeometryType":
) from e


class VariableKind(Enum):
    """Role of a data variable, stored as its ``kind`` attribute.

    The member values are the exact strings written to and compared against
    ``attrs["kind"]``. Definition order is significant: iterating the enum
    yields the list of valid kinds that validation error messages display.

    The legacy spelling ``"auxiliary"`` is accepted on lookup, i.e.
    ``VariableKind("auxiliary")`` returns ``VariableKind.AUXILIARY``. Any
    other unknown value raises ``ValueError`` as for a plain ``Enum``.
    """

    OBSERVATION = "observation"
    MODEL = "model"
    AUXILIARY = "aux"

    def __str__(self) -> str:
        """Return the attribute string (e.g. ``"model"``), not ``"VariableKind.MODEL"``."""
        return str(self.value)

    @classmethod
    def _missing_(cls, value: object) -> "VariableKind | None":
        """Resolve values that match no member; called by ``Enum`` on failed lookup.

        Returning ``None`` makes ``Enum`` raise its standard ``ValueError``.
        """
        # Older datasets spell the auxiliary kind out in full.
        if value == "auxiliary":
            return cls.AUXILIARY
        return None


DataInputType = Union[
str,
Path,
Expand Down
56 changes: 51 additions & 5 deletions tests/test_comparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,14 @@ def test_matched_df_with_aux(pt_df):
)
assert cmp.mod_names == ["m1", "m2"]
assert cmp.n_points == 6
assert cmp.data["wind"].attrs["kind"] == "auxiliary"
assert cmp.data["wind"].attrs["kind"] == "aux"
assert "not_relevant" not in cmp.data.data_vars

# if aux_items is a string, it is automatically converted to a list
cmp = Comparer.from_matched_data(
data=pt_df, mod_items=["m1", "m2"], aux_items="wind"
)
assert cmp.data["wind"].attrs["kind"] == "auxiliary"
assert cmp.data["wind"].attrs["kind"] == "aux"

# if models are specified, it is NOT automatically considered an aux variable
cmp = Comparer.from_matched_data(data=pt_df, mod_items=["m1", "m2"])
Expand All @@ -191,7 +191,7 @@ def test_aux_can_str_(pt_df):
pt_df["area"] = ["a", "b", "c", "d", "e", "f"]

cmp = Comparer.from_matched_data(pt_df, aux_items="area")
assert cmp.data["area"].attrs["kind"] == "auxiliary"
assert cmp.data["area"].attrs["kind"] == "aux"


def test_mod_and_obs_must_be_numeric():
Expand Down Expand Up @@ -349,6 +349,52 @@ def test_minimal_matched_data(pt_df):
assert cmp.n_models == 2


def test_kind_must_be_observation_model_or_aux(pt_df):
    """Accepted kinds are 'observation', 'model' and 'aux'; 'auxiliary' is mapped to 'aux'."""
    ds = xr.Dataset(pt_df)
    ds.attrs["name"] = "valid"
    initial_kinds = (
        ("Observation", "observation"),
        ("m1", "model"),
        ("m2", "aux"),
    )
    for var_name, kind in initial_kinds:
        ds[var_name].attrs["kind"] = kind

    # all three canonical kinds are accepted and sorted into the right role
    comparer = Comparer.from_matched_data(data=ds)
    assert comparer.mod_names == ["m1"]
    assert comparer.aux_names == ["m2"]

    # the legacy spelling is still accepted, but stored in canonical form
    ds["m2"].attrs["kind"] = "auxiliary"
    comparer = Comparer.from_matched_data(data=ds)
    assert comparer.data["m2"].attrs["kind"] == "aux"

    # anything else is an error
    ds["m2"].attrs["kind"] = "invalid"
    with pytest.raises(ValueError, match="Invalid kind 'invalid'.*Must be one of"):
        Comparer.from_matched_data(data=ds)


def test_gtype_must_be_point_or_track(pt_df):
    """Only 'point' and 'track' are accepted as the dataset-level gtype attribute."""
    ds = xr.Dataset(pt_df)
    ds["Observation"].attrs["kind"] = "observation"
    for model_name in ("m1", "m2"):
        ds[model_name].attrs["kind"] = "model"
    ds.attrs["name"] = "valid"

    # each supported geometry type round-trips onto the comparer
    for accepted in ("point", "track"):
        ds.attrs["gtype"] = accepted
        comparer = Comparer.from_matched_data(data=ds)
        assert comparer.gtype == accepted

    # an unsupported geometry type is refused
    ds.attrs["gtype"] = "grid"
    with pytest.raises(ValueError, match="Invalid gtype 'grid'.*Must be one of"):
        Comparer.from_matched_data(data=ds)


def test_from_compared_data_doesnt_accept_missing_values_in_obs():
df = pd.DataFrame(
{
Expand Down Expand Up @@ -386,7 +432,7 @@ def test_multiple_forecasts_matched_data():
cmp = Comparer.from_matched_data(data=data) # no additional raw_mod_data
assert len(cmp.raw_mod_data["m1"]) == 5
assert cmp.mod_names == ["m1"]
assert cmp.data["leadtime"].attrs["kind"] == "auxiliary"
assert cmp.data["leadtime"].attrs["kind"] == "aux"
analysis = cmp.where(cmp.data["leadtime"] == 0)
analysis.score()
assert len(analysis.raw_mod_data["m1"]) == 5
Expand All @@ -407,7 +453,7 @@ def test_matched_aux_variables(pt_df):
data["m2"].attrs["kind"] = "model"
cmp = Comparer.from_matched_data(data=data)
assert "wind" not in cmp.mod_names
assert cmp.data["wind"].attrs["kind"] == "auxiliary"
assert cmp.data["wind"].attrs["kind"] == "aux"


def test_pc_properties(pc):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pointcompare.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,5 +326,5 @@ def test_mod_aux_carried_over(klagshamn):
cmp = ms.match(klagshamn, mr, spatial_method="contained")
assert "U velocity" in cmp.data.data_vars
assert cmp.data["U velocity"].values[0] == pytest.approx(-0.0360998)
assert cmp.data["U velocity"].attrs["kind"] == "aux"
assert cmp.data["U velocity"].attrs["kind"] == "aux" # normalized
assert cmp.mod_names == ["Oresund2D_subset"]