From 9b3193ae42d3a0ecba273c42969b117f8ad9dc54 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Feb 2026 08:24:10 +0000 Subject: [PATCH 1/2] Initial plan From 5550ee8aaaecb75b0cd2f672ecf715ef9680a9e3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 20 Feb 2026 08:49:31 +0000 Subject: [PATCH 2/2] fix: resolve all mypy type errors in network model support Co-authored-by: jpalm3r <28826351+jpalm3r@users.noreply.github.com> --- src/modelskill/comparison/_comparison.py | 9 ++++----- src/modelskill/matching.py | 14 +++++++++----- src/modelskill/model/network.py | 6 ++++-- src/modelskill/obs.py | 4 ++-- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/modelskill/comparison/_comparison.py b/src/modelskill/comparison/_comparison.py index 22c576733..eaa027dea 100644 --- a/src/modelskill/comparison/_comparison.py +++ b/src/modelskill/comparison/_comparison.py @@ -27,7 +27,7 @@ from .. import Quantity from ..types import GeometryType from ..obs import PointObservation, TrackObservation, NodeObservation -from ..model import PointModelResult, TrackModelResult, NetworkModelResult +from ..model import PointModelResult, TrackModelResult from ..timeseries._timeseries import _validate_data_var_name from ._comparer_plotter import ComparerPlotter from ..metrics import _parse_metric @@ -461,7 +461,7 @@ def __init__( matched_data: xr.Dataset, raw_mod_data: dict[ str, - PointModelResult | TrackModelResult | NetworkModelResult | NodeModelResult, + PointModelResult | TrackModelResult | NodeModelResult, ] | None = None, ) -> None: @@ -488,7 +488,6 @@ def from_matched_data( str, PointModelResult | TrackModelResult - | NetworkModelResult | NodeModelResult, ] ] = None, @@ -784,7 +783,7 @@ def _to_observation(self) -> PointObservation | TrackObservation | NodeObservati else: raise NotImplementedError(f"Unknown gtype: {self.gtype}") - def _to_model(self) -> list[PointModelResult | TrackModelResult]: + def _to_model(self) -> list[PointModelResult | TrackModelResult | NodeModelResult]: mods = list(self.raw_mod_data.values()) return mods @@ -1314,7 +1313,7 @@ def load(filename: Union[str, Path]) -> "Comparer": if data.gtype == "point": raw_mod_data: Dict[ - str, PointModelResult | TrackModelResult | NetworkModelResult + str, PointModelResult | TrackModelResult | NodeModelResult ] = {} for var in data.data_vars: diff --git a/src/modelskill/matching.py b/src/modelskill/matching.py index 0a1adaae0..781bedfdc 100644 --- a/src/modelskill/matching.py +++ b/src/modelskill/matching.py @@ -11,6 +11,7 @@ Sequence, TypeVar, Union, + cast, get_args, overload, ) @@ -350,17 +351,20 @@ def _match_single_obs( if len(names) != len(set(names)): raise ValueError(f"Duplicate model names found: {names}") - raw_mod_data = {} + raw_mod_data: dict[str, PointModelResult | TrackModelResult | NodeModelResult] = {} for m in models: if isinstance(m, (DfsuModelResult, GridModelResult, DummyModelResult)): # These model types support spatial interpolation - extracted = m.extract(obs, spatial_method=spatial_method) + extracted: PointModelResult | TrackModelResult | NodeModelResult = m.extract( + cast(PointObservation | TrackObservation, obs), + spatial_method=spatial_method, + ) elif isinstance(m, NetworkModelResult): # Network models use exact node selection (no spatial interpolation) - extracted = m.extract(obs) + extracted = m.extract(cast(NodeObservation, obs)) else: - # Other model types (e.g., already extracted TimeSeries) - extracted = m + # Other model types (already point/track timeseries - no extraction needed) + extracted = cast(PointModelResult | TrackModelResult, m) raw_mod_data[m.name] = extracted diff --git a/src/modelskill/model/network.py b/src/modelskill/model/network.py index c4f1e8e9e..ec6fef8d1 100644 --- a/src/modelskill/model/network.py +++ b/src/modelskill/model/network.py @@ -8,7 +8,7 @@ from modelskill.obs import Observation from modelskill.timeseries import TimeSeries, _parse_network_node_input -from ._base import SpatialField, _validate_overlap_in_time, SelectedItems +from ._base import _validate_overlap_in_time, SelectedItems from ..obs import NodeObservation from ..quantity import Quantity from ..types import PointType @@ -74,6 +74,8 @@ def __init__( def node(self) -> int: """Node ID of model result""" node_val = self.data.coords.get("node") + if node_val is None: + raise ValueError("No node coordinate found in data") return int(node_val.item()) def interp_time(self, observation: Observation, **kwargs: Any) -> NodeModelResult: @@ -155,7 +157,7 @@ def _get_valid_times( return df[valid_idx].index -class NetworkModelResult(SpatialField): +class NetworkModelResult: """Model result for network data with time and node dimensions. Construct a NetworkModelResult from an xarray.Dataset with time and node coordinates diff --git a/src/modelskill/obs.py b/src/modelskill/obs.py index 73261b74c..fc1c79011 100644 --- a/src/modelskill/obs.py +++ b/src/modelskill/obs.py @@ -408,7 +408,7 @@ def __init__( super().__init__(data=data, weight=weight, attrs=attrs) @property - def node(self) -> int: + def node(self) -> Optional[int]: """Node ID of observation""" node_val = self.data.coords.get("node") if node_val is not None: @@ -509,7 +509,7 @@ def __iter__(self): @property def nodes(self) -> list[int]: """List of node IDs for all observations""" - return [obs.node for obs in self.observations] + return [obs.node for obs in self.observations if obs.node is not None] @property def names(self) -> list[str]: