Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a76005c
Add NetworkModelResult and supporting classes for network data handling
jpalm3r Feb 18, 2026
e4d9362
Simplify parsing logic
jpalm3r Feb 18, 2026
d0c0f92
feat: Add NetworkModelResult for handling network data in model skill
jpalm3r Feb 19, 2026
91d4513
small fixes
jpalm3r Feb 19, 2026
3d7df6c
feat: Add NodeModelResult and NodeObservation for network node handling
jpalm3r Feb 19, 2026
1fc1ecf
feat: Add extract_multiple method for batch extraction of NodeModelRe…
jpalm3r Feb 19, 2026
35882ab
feat: Enhance matching functionality to support NodeModelResult and N…
jpalm3r Feb 19, 2026
5a692ef
feat: Update matching logic to extract observations from NetworkModel…
jpalm3r Feb 19, 2026
e15ae63
feat: Add NetworkObservation class for handling collections of node o…
jpalm3r Feb 19, 2026
1219a51
Fix test
jpalm3r Feb 19, 2026
98f8a80
Importing network test data
jpalm3r Feb 19, 2026
91daa72
feat: Update observation handling to include NodeObservation and adju…
jpalm3r Feb 20, 2026
26a6d45
feat: Simplify node retrieval in NodeModelResult and enhance coordina…
jpalm3r Feb 20, 2026
89d4925
feat: Enhance time coordinate handling in NetworkModelResult and impr…
jpalm3r Feb 20, 2026
caeee13
feat: Refactor observation handling to support network geometry and s…
jpalm3r Feb 20, 2026
c8c3ebd
refactor: Simplify node retrieval in NodeObservation by removing unne…
jpalm3r Feb 20, 2026
f826ea5
Update src/modelskill/matching.py
jpalm3r Feb 20, 2026
87e1335
commit with errors
jpalm3r Feb 20, 2026
a4476ad
Fix mypy issues. Introduce Nework1D protocol
jpalm3r Feb 20, 2026
0966458
Removing NetworkObservation
jpalm3r Feb 20, 2026
4e5cd78
refactor: streamline NetworkModelResult extraction method by removing…
jpalm3r Feb 20, 2026
0f5af4d
refactor: simplify observation matching logic in _match_single_obs fu…
jpalm3r Feb 20, 2026
d32e337
fix: add assertion message for non-empty datetime indices in _get_glo…
jpalm3r Feb 20, 2026
636bc0f
test: add unit tests for NetworkModelResult and NodeObservation classes
jpalm3r Feb 20, 2026
6bfa6bc
refactor: update notebook to use new network dataset and remove obsol…
jpalm3r Feb 20, 2026
f91a4df
Rename notebook
jpalm3r Feb 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
991 changes: 991 additions & 0 deletions notebooks/Collection_systems_network.ipynb

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions src/modelskill/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,10 @@
TrackModelResult,
GridModelResult,
DfsuModelResult,
NetworkModelResult,
DummyModelResult,
)
from .obs import (
observation,
PointObservation,
TrackObservation,
)
from .obs import observation, PointObservation, TrackObservation, NodeObservation
from .matching import from_matched, match
from .configuration import from_config
from .settings import options, get_option, set_option, reset_option, load_style
Expand Down Expand Up @@ -94,9 +91,11 @@ def load(filename: Union[str, Path]) -> Comparer | ComparerCollection:
"GridModelResult",
"DfsuModelResult",
"DummyModelResult",
"NetworkModelResult",
"observation",
"PointObservation",
"TrackObservation",
"NodeObservation",
"TimeSeries",
"match",
"from_matched",
Expand Down
103 changes: 78 additions & 25 deletions src/modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import xarray as xr
from copy import deepcopy

from ..model.network import NodeModelResult


from .. import metrics as mtr
from .. import Quantity
from ..types import GeometryType
from ..obs import PointObservation, TrackObservation
from ..obs import PointObservation, TrackObservation, NodeObservation
from ..model import PointModelResult, TrackModelResult
from ..timeseries._timeseries import _validate_data_var_name
from ._comparer_plotter import ComparerPlotter
Expand All @@ -49,6 +51,12 @@
Serializable = Union[str, int, float]


def _drop_scalar_coords(data: xr.Dataset) -> xr.Dataset:
"""Drop scalar coordinate variables that shouldn't appear as columns in dataframes"""
coords_to_drop = ["x", "y", "z", "node"]
return data.drop_vars(coords_to_drop, errors="ignore")


def _parse_dataset(data: xr.Dataset) -> xr.Dataset:
if not isinstance(data, xr.Dataset):
raise ValueError("matched_data must be an xarray.Dataset")
Expand All @@ -60,12 +68,15 @@ def _parse_dataset(data: xr.Dataset) -> xr.Dataset:
raise ValueError("Observation data must not contain missing values.")

# coordinates
if "x" not in data.coords:
data.coords["x"] = np.nan
if "y" not in data.coords:
data.coords["y"] = np.nan
if "z" not in data.coords:
data.coords["z"] = np.nan
# Only add x, y, z coordinates if they don't exist and we don't have node coordinates
has_node_coords = "node" in data.coords
if not has_node_coords:
if "x" not in data.coords:
data.coords["x"] = np.nan
if "y" not in data.coords:
data.coords["y"] = np.nan
if "z" not in data.coords:
data.coords["z"] = np.nan

# Validate data
vars = [v for v in data.data_vars]
Expand Down Expand Up @@ -97,7 +108,11 @@ def _parse_dataset(data: xr.Dataset) -> xr.Dataset:

# Validate attrs
if "gtype" not in data.attrs:
data.attrs["gtype"] = str(GeometryType.POINT)
# Determine gtype based on available coordinates
if "node" in data.coords:
data.attrs["gtype"] = str(GeometryType.NETWORK)
else:
data.attrs["gtype"] = str(GeometryType.POINT)
# assert "gtype" in data.attrs, "data must have a gtype attribute"
# assert data.attrs["gtype"] in [
# str(GeometryType.POINT),
Expand Down Expand Up @@ -444,7 +459,11 @@ class Comparer:
def __init__(
self,
matched_data: xr.Dataset,
raw_mod_data: dict[str, PointModelResult | TrackModelResult] | None = None,
raw_mod_data: dict[
str,
PointModelResult | TrackModelResult | NodeModelResult,
]
| None = None,
) -> None:
self.data = _parse_dataset(matched_data)
self.raw_mod_data = (
Expand All @@ -464,7 +483,12 @@ def __init__(
@staticmethod
def from_matched_data(
data: xr.Dataset | pd.DataFrame,
raw_mod_data: Optional[Dict[str, PointModelResult | TrackModelResult]] = None,
raw_mod_data: Optional[
Dict[
str,
PointModelResult | TrackModelResult | NodeModelResult,
]
] = None,
obs_item: str | int | None = None,
mod_items: Optional[Iterable[str | int]] = None,
aux_items: Optional[Iterable[str | int]] = None,
Expand Down Expand Up @@ -582,7 +606,15 @@ def z(self) -> Any:
"""z-coordinate"""
return self._coordinate_values("z")

def _coordinate_values(self, coord: str) -> Any:
@property
def node(self) -> Any:
    """Node coordinate value(s), or None if the data has no node coordinate."""
    return self._coordinate_values("node")
def _coordinate_values(self, coord: str) -> None | Any:
"""Get coordinate values if they exist, otherwise return None"""
if coord not in self.data.coords:
return None
vals = self.data[coord].values
return np.atleast_1d(vals)[0] if vals.ndim == 0 else vals

Expand Down Expand Up @@ -707,10 +739,10 @@ def rename(

return Comparer(matched_data=data, raw_mod_data=raw_mod_data)

def _to_observation(self) -> PointObservation | TrackObservation:
def _to_observation(self) -> PointObservation | TrackObservation | NodeObservation:
"""Convert to Observation"""
if self.gtype == "point":
df = self.data.drop_vars(["x", "y", "z"])[self._obs_str].to_dataframe()
df = _drop_scalar_coords(self.data)[self._obs_str].to_dataframe()
return PointObservation(
data=df,
name=self.name,
Expand All @@ -721,7 +753,9 @@ def _to_observation(self) -> PointObservation | TrackObservation:
# TODO: add attrs
)
elif self.gtype == "track":
df = self.data.drop_vars(["z"])[[self._obs_str]].to_dataframe()
df = self.data.drop_vars(["z"], errors="ignore")[
[self._obs_str]
].to_dataframe()
return TrackObservation(
data=df,
item=0,
Expand All @@ -731,10 +765,21 @@ def _to_observation(self) -> PointObservation | TrackObservation:
quantity=self.quantity,
# TODO: add attrs
)
elif self.gtype == "network":
df = _drop_scalar_coords(self.data)[self._obs_str].to_dataframe()
return NodeObservation(
data=df,
name=self.name,
node=self.node,
quantity=self.quantity,
# TODO: add attrs
)
else:
raise NotImplementedError(f"Unknown gtype: {self.gtype}")

def _to_model(self) -> list[PointModelResult | TrackModelResult]:
def _to_model(
self,
) -> list[PointModelResult | TrackModelResult | NodeModelResult]:
mods = list(self.raw_mod_data.values())
return mods

Expand Down Expand Up @@ -896,9 +941,9 @@ def _to_long_dataframe(
) -> pd.DataFrame:
"""Return a copy of the data as a long-format pandas DataFrame (for groupby operations)"""

data = self.data.drop_vars("z", errors="ignore")
data = self.data.drop_vars(["z", "node"], errors="ignore")

# this step is necessary since we keep arbitrary derived data in the dataset, but not z
# this step is necessary since we keep arbitrary derived data in the dataset, but not z/node
# i.e. using a hardcoded whitelist of variables to keep is less flexible
id_vars = [v for v in data.variables if v not in self.mod_names]

Expand Down Expand Up @@ -981,8 +1026,8 @@ def skill(

df = cmp._to_long_dataframe()
res = _groupby_df(df, by=by, metrics=metrics)
res["x"] = np.nan if self.gtype == "track" else cmp.x
res["y"] = np.nan if self.gtype == "track" else cmp.y
res["x"] = np.nan if self.gtype == "track" or cmp.x is None else cmp.x
res["y"] = np.nan if self.gtype == "track" or cmp.y is None else cmp.y
res = self._add_as_col_if_not_in_index(df, skilldf=res)
return SkillTable(res)

Expand Down Expand Up @@ -1138,7 +1183,7 @@ def gridded_skill(

@property
def _residual(self) -> np.ndarray:
df = self.data.drop_vars(["x", "y", "z"]).to_dataframe()
df = _drop_scalar_coords(self.data).to_dataframe()
obs = df[self._obs_str].values
mod = df[self.mod_names].values
return mod - np.vstack(obs)
Expand Down Expand Up @@ -1202,12 +1247,17 @@ def to_dataframe(self) -> pd.DataFrame:
if self.gtype == str(GeometryType.POINT):
# we remove the scalar coordinate variables as they
# will otherwise be columns in the dataframe
return self.data.drop_vars(["x", "y", "z"]).to_dataframe()
return _drop_scalar_coords(self.data).to_dataframe()
elif self.gtype == str(GeometryType.TRACK):
df = self.data.drop_vars(["z"]).to_dataframe()
# make sure that x, y cols are first
cols = ["x", "y"] + [c for c in df.columns if c not in ["x", "y"]]
df = self.data.drop_vars(["z"], errors="ignore").to_dataframe()
# make sure that x, y cols are first if they exist
coord_cols = [c for c in ["x", "y"] if c in df.columns]
other_cols = [c for c in df.columns if c not in ["x", "y"]]
cols = coord_cols + other_cols
return df[cols]
elif self.gtype == str(GeometryType.NETWORK):
# For network data, drop node coordinate like other geometries drop their coordinates
return _drop_scalar_coords(self.data).to_dataframe()
else:
raise NotImplementedError(f"Unknown gtype: {self.gtype}")

Expand Down Expand Up @@ -1258,7 +1308,10 @@ def load(filename: Union[str, Path]) -> "Comparer":
return Comparer(matched_data=data)

if data.gtype == "point":
raw_mod_data: Dict[str, PointModelResult | TrackModelResult] = {}
raw_mod_data: Dict[
str,
PointModelResult | TrackModelResult | NodeModelResult,
] = {}

for var in data.data_vars:
var_name = str(var)
Expand Down
60 changes: 45 additions & 15 deletions src/modelskill/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,15 @@
from .model.dfsu import DfsuModelResult
from .model.dummy import DummyModelResult
from .model.grid import GridModelResult
from .model.network import NetworkModelResult, NodeModelResult
from .model.track import TrackModelResult
from .obs import Observation, PointObservation, TrackObservation
from .model.point import align_data
from .obs import (
Observation,
PointObservation,
TrackObservation,
NodeObservation,
)
from .timeseries import TimeSeries
from .types import Period

Expand All @@ -40,8 +47,13 @@
GridModelResult,
DfsuModelResult,
TrackModelResult,
NetworkModelResult,
DummyModelResult,
]
# Model results defined on a spatial field, which must be extracted at an
# observation location before matching.
FieldTypes = Union[
    GridModelResult,
    DfsuModelResult,
]
# Backward-compatible alias for the original (misspelled) name.
# NOTE(review): this alias appears unused in the visible diff — confirm and
# remove the misspelled name once no external code references it.
Fieldypes = FieldTypes
MRInputType = Union[
str,
Path,
Expand All @@ -56,7 +68,7 @@
TimeSeries,
MRTypes,
]
ObsTypes = Union[PointObservation, TrackObservation]
ObsTypes = Union[PointObservation, TrackObservation, NodeObservation]
ObsInputType = Union[
str,
Path,
Expand All @@ -67,7 +79,6 @@
pd.Series,
ObsTypes,
]

T = TypeVar("T", bound="TimeSeries")


Expand Down Expand Up @@ -248,6 +259,7 @@ def match(
--------
from_matched - Create a Comparer from observation and model results that are already matched
"""

if isinstance(obs, get_args(ObsInputType)):
return _match_single_obs(
obs,
Expand All @@ -267,7 +279,15 @@ def match(

if len(obs) > 1 and isinstance(mod, Collection) and len(mod) > 1:
if not all(
isinstance(m, (DfsuModelResult, GridModelResult, DummyModelResult))
isinstance(
m,
(
DfsuModelResult,
GridModelResult,
NetworkModelResult,
DummyModelResult,
),
)
for m in mod
):
raise ValueError(
Expand Down Expand Up @@ -317,14 +337,19 @@ def _match_single_obs(
if len(names) != len(set(names)):
raise ValueError(f"Duplicate model names found: {names}")

raw_mod_data = {
m.name: (
m.extract(obs, spatial_method=spatial_method)
if isinstance(m, (DfsuModelResult, GridModelResult, DummyModelResult))
else m
)
for m in models
}
raw_mod_data: dict[str, PointModelResult | TrackModelResult | NodeModelResult] = {}
for m in models:
is_field = isinstance(m, (GridModelResult, DfsuModelResult))
is_dummy = isinstance(m, DummyModelResult)
is_network = isinstance(m, NetworkModelResult)
if is_field or is_dummy:
matching_obs = m.extract(obs, spatial_method=spatial_method)
elif is_network:
matching_obs = m.extract(obs)
else:
matching_obs = m

raw_mod_data[m.name] = matching_obs

matched_data = _match_space_time(
observation=obs,
Expand All @@ -341,7 +366,7 @@ def _match_single_obs(


def _get_global_start_end(idxs: Iterable[pd.DatetimeIndex]) -> Period:
assert all([len(x) > 0 for x in idxs])
assert all([len(x) > 0 for x in idxs]), "All datetime indices must be non-empty"

starts = [x[0] for x in idxs]
ends = [x[-1] for x in idxs]
Expand All @@ -351,7 +376,9 @@ def _get_global_start_end(idxs: Iterable[pd.DatetimeIndex]) -> Period:

def _match_space_time(
observation: Observation,
raw_mod_data: Mapping[str, PointModelResult | TrackModelResult],
raw_mod_data: Mapping[
str, PointModelResult | TrackModelResult | NetworkModelResult | NodeModelResult
],
max_model_gap: float | None,
spatial_tolerance: float,
obs_no_overlap: Literal["ignore", "error", "warn"],
Expand All @@ -374,7 +401,10 @@ def _match_space_time(
observation, spatial_tolerance=spatial_tolerance
)
case PointModelResult() as pmr, PointObservation():
aligned = pmr.align(observation, max_gap=max_model_gap)
aligned = align_data(pmr.data, observation, max_gap=max_model_gap)
case NodeModelResult() as nmr, NodeObservation():
# mr is the extracted NodeModelResult
aligned = align_data(nmr.data, observation, max_gap=max_model_gap)
case _:
raise TypeError(
f"Matching not implemented for model type {type(mr)} and observation type {type(observation)}"
Expand Down
Loading