Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/modelskill/comparison/_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .. import Quantity
from ..types import GeometryType
from ..obs import PointObservation, TrackObservation, NodeObservation
from ..model import PointModelResult, TrackModelResult, NetworkModelResult
from ..model import PointModelResult, TrackModelResult
from ..timeseries._timeseries import _validate_data_var_name
from ._comparer_plotter import ComparerPlotter
from ..metrics import _parse_metric
Expand Down Expand Up @@ -461,7 +461,7 @@ def __init__(
matched_data: xr.Dataset,
raw_mod_data: dict[
str,
PointModelResult | TrackModelResult | NetworkModelResult | NodeModelResult,
PointModelResult | TrackModelResult | NodeModelResult,
]
| None = None,
) -> None:
Expand All @@ -488,7 +488,6 @@ def from_matched_data(
str,
PointModelResult
| TrackModelResult
| NetworkModelResult
| NodeModelResult,
]
] = None,
Expand Down Expand Up @@ -784,7 +783,7 @@ def _to_observation(self) -> PointObservation | TrackObservation | NodeObservati
else:
raise NotImplementedError(f"Unknown gtype: {self.gtype}")

def _to_model(self) -> list[PointModelResult | TrackModelResult]:
def _to_model(self) -> list[PointModelResult | TrackModelResult | NodeModelResult]:
mods = list(self.raw_mod_data.values())
return mods

Expand Down Expand Up @@ -1314,7 +1313,7 @@ def load(filename: Union[str, Path]) -> "Comparer":

if data.gtype == "point":
raw_mod_data: Dict[
str, PointModelResult | TrackModelResult | NetworkModelResult
str, PointModelResult | TrackModelResult | NodeModelResult
] = {}

for var in data.data_vars:
Expand Down
14 changes: 9 additions & 5 deletions src/modelskill/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Sequence,
TypeVar,
Union,
cast,
get_args,
overload,
)
Expand Down Expand Up @@ -350,17 +351,20 @@ def _match_single_obs(
if len(names) != len(set(names)):
raise ValueError(f"Duplicate model names found: {names}")

raw_mod_data = {}
raw_mod_data: dict[str, PointModelResult | TrackModelResult | NodeModelResult] = {}
for m in models:
if isinstance(m, (DfsuModelResult, GridModelResult, DummyModelResult)):
# These model types support spatial interpolation
extracted = m.extract(obs, spatial_method=spatial_method)
extracted: PointModelResult | TrackModelResult | NodeModelResult = m.extract(
cast(PointObservation | TrackObservation, obs),
spatial_method=spatial_method,
)
elif isinstance(m, NetworkModelResult):
# Network models use exact node selection (no spatial interpolation)
extracted = m.extract(obs)
extracted = m.extract(cast(NodeObservation, obs))
else:
# Other model types (e.g., already extracted TimeSeries)
extracted = m
# Other model types (already point/track timeseries - no extraction needed)
extracted = cast(PointModelResult | TrackModelResult, m)

raw_mod_data[m.name] = extracted

Expand Down
6 changes: 4 additions & 2 deletions src/modelskill/model/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from modelskill.obs import Observation
from modelskill.timeseries import TimeSeries, _parse_network_node_input

from ._base import SpatialField, _validate_overlap_in_time, SelectedItems
from ._base import _validate_overlap_in_time, SelectedItems
from ..obs import NodeObservation
from ..quantity import Quantity
from ..types import PointType
Expand Down Expand Up @@ -74,6 +74,8 @@ def __init__(
def node(self) -> int:
"""Node ID of model result"""
node_val = self.data.coords.get("node")
if node_val is None:
raise ValueError("No node coordinate found in data")
return int(node_val.item())

def interp_time(self, observation: Observation, **kwargs: Any) -> NodeModelResult:
Expand Down Expand Up @@ -155,7 +157,7 @@ def _get_valid_times(
return df[valid_idx].index


class NetworkModelResult(SpatialField):
class NetworkModelResult:
"""Model result for network data with time and node dimensions.

Construct a NetworkModelResult from an xarray.Dataset with time and node coordinates
Expand Down
4 changes: 2 additions & 2 deletions src/modelskill/obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def __init__(
super().__init__(data=data, weight=weight, attrs=attrs)

@property
def node(self) -> int:
def node(self) -> Optional[int]:
"""Node ID of observation"""
node_val = self.data.coords.get("node")
if node_val is not None:
Expand Down Expand Up @@ -509,7 +509,7 @@ def __iter__(self):
@property
def nodes(self) -> list[int]:
"""List of node IDs for all observations"""
return [obs.node for obs in self.observations]
return [obs.node for obs in self.observations if obs.node is not None]

@property
def names(self) -> list[str]:
Expand Down