diff --git a/src/modelskill/comparison/_collection_plotter.py b/src/modelskill/comparison/_collection_plotter.py index 2122b8488..0c8d7020a 100644 --- a/src/modelskill/comparison/_collection_plotter.py +++ b/src/modelskill/comparison/_collection_plotter.py @@ -804,10 +804,9 @@ def spatial_overview( """ from ..plotting import spatial_overview - obs = [cmp._to_observation() for cmp in self.cc] # TODO how to add model domain(s) - return spatial_overview(obs, ax=ax, figsize=figsize, title=title) + return spatial_overview(self.cc, ax=ax, figsize=figsize, title=title) def temporal_coverage( self, @@ -835,11 +834,10 @@ def temporal_coverage( """ from ..plotting import temporal_coverage - obs = [cmp._to_observation() for cmp in self.cc] - mod = self.cc[0]._to_model() + mod = list(self.cc[0].raw_mod_data.values()) return temporal_coverage( - obs=obs, + obs=list(self.cc), mod=mod, limit_to_model_period=limit_to_model_period, marker=marker, diff --git a/src/modelskill/comparison/_comparison.py b/src/modelskill/comparison/_comparison.py index f8fedef75..abdf73368 100644 --- a/src/modelskill/comparison/_comparison.py +++ b/src/modelskill/comparison/_comparison.py @@ -24,7 +24,6 @@ from .. import metrics as mtr from .. import Quantity from ..types import GeometryType -from ..obs import PointObservation, TrackObservation from ..model import PointModelResult, TrackModelResult from ..timeseries._timeseries import _validate_data_var_name from ._comparer_plotter import ComparerPlotter @@ -707,37 +706,6 @@ def rename( return Comparer(matched_data=data, raw_mod_data=raw_mod_data) - def _to_observation(self) -> PointObservation | TrackObservation: - """Convert to Observation""" - if self.gtype == "point": - df = self.data.drop_vars(["x", "y", "z"])[self._obs_str].to_dataframe() - return PointObservation( - data=df, - name=self.name, - x=self.x, - y=self.y, - z=self.z, - quantity=self.quantity, - # TODO: add attrs - ) - elif self.gtype == "track": - df = self.data.drop_vars(["z"])[[self._obs_str]].to_dataframe() - return TrackObservation( - data=df, - item=0, - x_item=1, - y_item=2, - name=self.name, - quantity=self.quantity, - # TODO: add attrs - ) - else: - raise NotImplementedError(f"Unknown gtype: {self.gtype}") - - def _to_model(self) -> list[PointModelResult | TrackModelResult]: - mods = list(self.raw_mod_data.values()) - return mods - def __add__(self, other): warnings.warn( "Merging comparers using + is deprecated, use .merge instead.", diff --git a/src/modelskill/plotting/_spatial_overview.py b/src/modelskill/plotting/_spatial_overview.py index 156af04c1..754c7e339 100644 --- a/src/modelskill/plotting/_spatial_overview.py +++ b/src/modelskill/plotting/_spatial_overview.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Optional, Iterable, Tuple, TYPE_CHECKING +from typing import Optional, Iterable, Protocol, Tuple, TYPE_CHECKING if TYPE_CHECKING: import matplotlib.axes @@ -8,12 +8,26 @@ from ..model.point import PointModelResult from ..model.track import TrackModelResult -from ..obs import Observation, PointObservation, TrackObservation from ._misc import _get_ax +class SpatialOverviewItem(Protocol): + """Protocol for items that can be plotted in spatial_overview.""" + + @property + def gtype(self) -> str: ... + @property + def x(self): ... + @property + def y(self): ... + @property + def name(self) -> str: ... + @property + def n_points(self) -> int: ... + + def spatial_overview( - obs: Observation | Iterable[Observation], + obs: SpatialOverviewItem | Iterable[SpatialOverviewItem], mod: Optional[ DfsuModelResult | GeometryFM2D @@ -82,9 +96,9 @@ def spatial_overview( g.plot.outline(ax=ax) # type: ignore for o in obs: - if isinstance(o, PointObservation): + if o.gtype == "point": ax.scatter(x=o.x, y=o.y, marker="x") - elif isinstance(o, TrackObservation): + elif o.gtype == "track": if o.n_points < 10000: ax.scatter(x=o.x, y=o.y, marker=".") else: @@ -92,14 +106,14 @@ def spatial_overview( # TODO: group by lonlat bin or sample randomly else: raise ValueError( - f"Could not show observation {o}. Only PointObservation and TrackObservation supported." + f"Could not show {o}. Only point and track geometry types supported." ) xlim = ax.get_xlim() offset_x = 0.02 * (xlim[1] - xlim[0]) for o in obs: - if isinstance(o, PointObservation): + if o.gtype == "point": # TODO adjust xlim to accomodate text ax.annotate(o.name, (o.x + offset_x, o.y)) # type: ignore